diff --git a/cliar/cliar.py b/cliar/cliar.py index 68b09b4..9c80c8b 100644 --- a/cliar/cliar.py +++ b/cliar/cliar.py @@ -1,5 +1,6 @@ from argparse import ArgumentParser, RawTextHelpFormatter -from inspect import signature, getmembers, ismethod, isclass +from asyncio import get_event_loop +from inspect import signature, getmembers, ismethod, isclass, iscoroutine from collections import OrderedDict from typing import List, Iterable, Callable, Set, Type, get_type_hints @@ -291,7 +292,12 @@ def parse(self): for subcli in self._subclis: subcli.global_args = self.global_args - if command.handler(**handler_args) == NotImplemented: + result = command.handler(**handler_args) + + if iscoroutine(result): + result = get_event_loop().run_until_complete(result) + + if result == NotImplemented: command.handler.__self__._parser.print_help() def _root(self): diff --git a/tests/test_async_fns.py b/tests/test_async_fns.py new file mode 100644 index 0000000..a55f0d6 --- /dev/null +++ b/tests/test_async_fns.py @@ -0,0 +1,17 @@ +from subprocess import run + + +def test_help(capfd, datadir): + run(f'python {datadir/"async_fns.py"} wait -h', shell=True) + + output = capfd.readouterr().out + + assert '-s SECONDS-TO-WAIT, --seconds-to-wait SECONDS-TO-WAIT' in output + +def test_wait(capfd, datadir): + seconds_to_wait = 1.0 + + run(f'python {datadir/"async_fns.py"} wait -s "{seconds_to_wait}"', + shell=True) + seconds_awaited = float(capfd.readouterr().out.strip()) + assert round(seconds_awaited, 1) == round(seconds_to_wait, 1) diff --git a/tests/test_async_fns/async_fns.py b/tests/test_async_fns/async_fns.py new file mode 100644 index 0000000..e509086 --- /dev/null +++ b/tests/test_async_fns/async_fns.py @@ -0,0 +1,15 @@ +import asyncio +from time import perf_counter + +from cliar import Cliar + + +class AsyncFunctions(Cliar): + async def wait(self, seconds_to_wait: float = 1.0): + t1 = perf_counter() + await asyncio.sleep(seconds_to_wait) + elapsed = perf_counter() - t1 + print(elapsed) + +if __name__ == "__main__": + AsyncFunctions().parse()