diff --git a/CHANGELOG.md b/CHANGELOG.md index f53e9ffa2e96e24c38775504255edc95e6c7be90..5fb91b19853bdad2dd95c2596aede1064397f2a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ # v2.9.7 * add binary shift operations to `AtomicNumber` - +* add `default` to `read_in_file` diff --git a/satella/__init__.py b/satella/__init__.py index acf6f4bb456885882e773dc49f4c8e4707925a8e..8a93c15ad8b46cd0b27de093cc0b10aeb6fe8148 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.9.7_a2' +__version__ = '2.9.7' diff --git a/satella/files.py b/satella/files.py index 7a1d73defa08ea044f097f180fec4d6977808f14..2b7f5ef2bfdd13579838285e63cccaf502173d64 100644 --- a/satella/files.py +++ b/satella/files.py @@ -81,19 +81,32 @@ def write_to_file(path: str, data: tp.Union[bytes, str], file.close() -def read_in_file(path: str, encoding: tp.Optional[str] = None) -> tp.Union[bytes, str]: +def read_in_file(path: str, encoding: tp.Optional[str] = None, + default: tp.Optional[tp.Union[bytes, str]] = None) -> tp.Union[bytes, str]: """ Opens a file for reading, reads it in, converts to given encoding (or returns as bytes if not given), and closes it. :param path: path of file to read :param encoding: optional encoding. If default (None given), this will be returned as bytes + :param default: value to return when the file does not exist. Default (None) will raise a + FileNotFoundError :return: file content, either decoded as a str, or not as bytes """ - if encoding is None: - file = open(path, 'rb') - else: - file = codecs.open(path, 'rb', encoding) + if os.path.isdir(path): + if default: + return default + raise FileNotFoundError('%s found and is a directory' % (path, )) + + try: + if encoding is None: + file = open(path, 'rb') + else: + file = codecs.open(path, 'rb', encoding) + except FileNotFoundError: + if default: + return default + raise try: return file.read() diff --git a/tests/test_files.py b/tests/test_files.py index 4b21d8703f100d76ccccba63bc25e4c16ab832d9..64b2ebc7dc1a09008ae2352c64fbf846331752b0 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -14,6 +14,14 @@ def putfile(path: str) -> None: class TestFiles(unittest.TestCase): + def try_directory(self): + os.system('mkdir test') + self.assertRaises(FileNotFoundError, lambda: read_in_file('test')) + self.assertEqual(b'test', read_in_file('test', default=b'test')) + os.system('rm -rf test') + self.assertRaises(FileNotFoundError, lambda: read_in_file('test')) + self.assertEqual(b'test', read_in_file('test', default=b'test')) + def test_make_noncolliding_name(self): with open('test.txt', 'w') as f_out: f_out.write('test')