diff --git a/docs/processes.rst b/docs/processes.rst index 19ec90fa6223ec092672a38d3ef7d6cc3ae6dd24..2b1ed0167337ec95c814e6129b58a583742c0ce4 100644 --- a/docs/processes.rst +++ b/docs/processes.rst @@ -1,4 +1,8 @@ processes ========= +Note that this function will consume stdout as soon as it's +available, so that you don't need to worry about +the buffer overflowing and such. + .. autofunction:: satella.processes.call_and_return_stdout diff --git a/satella/processes.py b/satella/processes.py index cb8da0550a104dc92bf496bd2e3517cd352301f2..0f94032953b59233955eb2e067738c4d3bfd1dcd 100644 --- a/satella/processes.py +++ b/satella/processes.py @@ -1,15 +1,30 @@ import subprocess import typing as tp +import threading from .exceptions import ProcessFailed +def read_nowait(process: subprocess.Popen, output_list: tp.List[str]): + try: + while process.stdout.readable(): + line = process.stdout.readline() + output_list.append(line) + except (IOError, OSError): + pass + + def call_and_return_stdout(args: tp.Union[str, tp.List[str]], + timeout: tp.Optional[int] = None, expected_return_code: int = 0, **kwargs) -> tp.Union[bytes, str]: """ Call a process and return it's stdout. + Everything in kwargs will be passed to subprocess.Popen + :param args: arguments to run the program with. If passed a string, it will be split on space. + :param timeout: amount of seconds to wait for the process result. If process does not complete + within this time, it will be sent a SIGKILL :param expected_return_code: an expected return code of this process. 0 is the default. If process returns anything else, ProcessFailed will be raise :param ProcessFailed: process' result code was different from the requested @@ -18,12 +33,26 @@ def call_and_return_stdout(args: tp.Union[str, tp.List[str]], args = args.split(' ') kwargs['capture_output'] = True + kwargs['stdout'] = subprocess.PIPE + + stdout_list = [] + + proc = subprocess.Popen(args, **kwargs) + reader_thread = threading.Thread(target=read_nowait, args=(proc, stdout_list), daemon=True) + reader_thread.start() - proc = subprocess.run(args, **kwargs) + try: + proc.wait(timeout=timeout) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait() if proc.returncode != expected_return_code: raise ProcessFailed(proc.returncode) else: - return proc.stdout + if kwargs.get('encoding', None) is None: + return b''.join(stdout_list) + else: + return ''.join(stdout_list) diff --git a/tests/test_processes.py b/tests/test_processes.py index 2a1e2510e6fc2ee7a0a7ab98559916944b8db318..f4bfc327a8b0fbd495312dbc00f09b8b18c52e10 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -8,3 +8,6 @@ class TestProcesses(unittest.TestCase): def test_return_stdout(self): output = call_and_return_stdout('cat /proc/meminfo', shell=True, encoding='utf8') self.assertIn('MemTotal', output) + + output = call_and_return_stdout('cat /proc/meminfo', shell=True) + self.assertIn(b'MemTotal', output)