diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index 48ecb4fcf..79422a032 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -13,11 +13,13 @@ # limitations under the License. from concurrent.futures import ThreadPoolExecutor +import inspect import multiprocessing from threading import Condition from threading import Lock from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy +from rclpy.task import Task from rclpy.timer import WallTimer from rclpy.utilities import ok from rclpy.utilities import timeout_sec_to_nsec @@ -71,6 +73,16 @@ def wait(self, timeout_sec=None): return True +async def await_or_execute(callback, *args): + """Await a callback if it is a coroutine, else execute it.""" + if inspect.iscoroutinefunction(callback): + # Await a coroutine + return await callback(*args) + else: + # Call a normal function + return callback(*args) + + class TimeoutException(Exception): """Signal that a timeout occurred.""" @@ -93,6 +105,9 @@ def __init__(self): super().__init__() self._nodes = set() self._nodes_lock = Lock() + # Tasks to be executed (oldest first) 3-tuple Task, Entity, Node + self._tasks = [] + self._tasks_lock = Lock() # This is triggered when wait_for_ready_callbacks should rebuild the wait list gc, gc_handle = _rclpy.rclpy_create_guard_condition() self._guard_condition = gc @@ -105,6 +120,22 @@ def __init__(self): self._last_args = None self._last_kwargs = None + def create_task(self, callback, *args, **kwargs): + """ + Add a callback or coroutine to be executed during :meth:`spin` and return a Future. + + Arguments to this function are passed to the callback. + + :param callback: A callback to be run in the executor + :type callback: callable or coroutine function + :rtype: :class:`rclpy.task.Future` instance + """ + task = Task(callback, args, kwargs, executor=self) + with self._tasks_lock: + self._tasks.append((task, None, None)) + # Task inherits from Future + return task + def shutdown(self, timeout_sec=None): """ Stop executing callbacks and wait for their completion. @@ -175,23 +206,23 @@ def spin_once(self, timeout_sec=None): def _take_timer(self, tmr): _rclpy.rclpy_call_timer(tmr.timer_handle) - def _execute_timer(self, tmr, _): - tmr.callback() + async def _execute_timer(self, tmr, _): + await await_or_execute(tmr.callback) def _take_subscription(self, sub): msg = _rclpy.rclpy_take(sub.subscription_handle, sub.msg_type) return msg - def _execute_subscription(self, sub, msg): + async def _execute_subscription(self, sub, msg): if msg: - sub.callback(msg) + await await_or_execute(sub.callback, msg) def _take_client(self, client): response = _rclpy.rclpy_take_response( client.client_handle, client.srv_type.Response, client.sequence_number) return response - def _execute_client(self, client, response): + async def _execute_client(self, client, response): if response: # clients spawn their own thread to wait for a response in the # wait_for_future function. Users can either use this mechanism or monitor @@ -203,42 +234,35 @@ def _take_service(self, srv): srv.service_handle, srv.srv_type.Request) return request_and_header - def _execute_service(self, srv, request_and_header): + async def _execute_service(self, srv, request_and_header): if request_and_header is None: return (request, header) = request_and_header if request: - response = srv.callback(request, srv.srv_type.Response()) + response = await await_or_execute(srv.callback, request, srv.srv_type.Response()) srv.send_response(response, header) def _take_guard_condition(self, gc): gc._executor_triggered = False - def _execute_guard_condition(self, gc, _): - gc.callback() + async def _execute_guard_condition(self, gc, _): + await await_or_execute(gc.callback) - def _make_handler(self, entity, take_from_wait_list, call_callback): + def _make_handler(self, entity, node, take_from_wait_list, call_coroutine): """ Make a handler that performs work on an entity. :param entity: An entity to wait on :param take_from_wait_list: Makes the entity to stop appearing in the wait list :type take_from_wait_list: callable - :param call_callback: Does the work the entity is ready for - :type call_callback: callable + :param call_coroutine: Does the work the entity is ready for + :type call_coroutine: coroutine function :rtype: callable """ - gc = self._guard_condition - work_tracker = self._work_tracker - is_shutdown = self._is_shutdown # Mark this so it doesn't get added back to the wait list entity._executor_event = True - def handler(): - nonlocal entity - nonlocal gc - nonlocal is_shutdown - nonlocal work_tracker + async def handler(entity, gc, is_shutdown, work_tracker): if is_shutdown or not entity.callback_group.beginning_execution(entity): # Didn't get the callback, or the executor has been ordered to stop entity._executor_event = False @@ -252,13 +276,18 @@ def handler(): _rclpy.rclpy_trigger_guard_condition(gc) try: - call_callback(entity, arg) + await call_coroutine(entity, arg) finally: entity.callback_group.ending_execution(entity) # Signal that work has been done so the next callback in a mutually exclusive # callback group can get executed _rclpy.rclpy_trigger_guard_condition(gc) - return handler + task = Task( + handler, (entity, self._guard_condition, self._is_shutdown, self._work_tracker), + executor=self) + with self._tasks_lock: + self._tasks.append((task, entity, node)) + return task def can_execute(self, entity): """ @@ -290,6 +319,18 @@ def _wait_for_ready_callbacks(self, timeout_sec=None, nodes=None): if nodes is None: nodes = self.get_nodes() + # Yield tasks in-progress before waiting for new work + tasks = None + with self._tasks_lock: + tasks = list(self._tasks) + if tasks: + for task, entity, node in reversed(tasks): + if not task.executing() and not task.done() and (node is None or node in nodes): + yield task, entity, node + with self._tasks_lock: + # Get rid of any tasks that are done + self._tasks = list(filter(lambda t_e_n: not t_e_n[0].done(), self._tasks)) + yielded_work = False while not yielded_work and not self._is_shutdown: # Gather entities that can be waited on @@ -368,7 +409,7 @@ def _wait_for_ready_callbacks(self, timeout_sec=None, nodes=None): if _rclpy.rclpy_is_timer_ready(tmr.timer_handle): if tmr.callback_group.can_execute(tmr): handler = self._make_handler( - tmr, self._take_timer, self._execute_timer) + tmr, node, self._take_timer, self._execute_timer) yielded_work = True yield handler, tmr, node @@ -376,7 +417,7 @@ def _wait_for_ready_callbacks(self, timeout_sec=None, nodes=None): if sub.subscription_pointer in subs_ready: if sub.callback_group.can_execute(sub): handler = self._make_handler( - sub, self._take_subscription, self._execute_subscription) + sub, node, self._take_subscription, self._execute_subscription) yielded_work = True yield handler, sub, node @@ -384,7 +425,8 @@ def _wait_for_ready_callbacks(self, timeout_sec=None, nodes=None): if gc._executor_triggered: if gc.callback_group.can_execute(gc): handler = self._make_handler( - gc, self._take_guard_condition, self._execute_guard_condition) + gc, node, self._take_guard_condition, + self._execute_guard_condition) yielded_work = True yield handler, gc, node @@ -392,7 +434,7 @@ def _wait_for_ready_callbacks(self, timeout_sec=None, nodes=None): if client.client_pointer in clients_ready: if client.callback_group.can_execute(client): handler = self._make_handler( - client, self._take_client, self._execute_client) + client, node, self._take_client, self._execute_client) yielded_work = True yield handler, client, node @@ -400,7 +442,7 @@ def _wait_for_ready_callbacks(self, timeout_sec=None, nodes=None): if srv.service_pointer in services_ready: if srv.callback_group.can_execute(srv): handler = self._make_handler( - srv, self._take_service, self._execute_service) + srv, node, self._take_service, self._execute_service) yielded_work = True yield handler, srv, node diff --git a/rclpy/rclpy/task.py b/rclpy/rclpy/task.py new file mode 100644 index 000000000..16269f195 --- /dev/null +++ b/rclpy/rclpy/task.py @@ -0,0 +1,230 @@ +# Copyright 2018 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import threading +import weakref + + +def _fake_weakref(): + """Return None when called to simulate a weak reference that has been garbage collected.""" + return None + + +class Future: + """Represent the outcome of a task in the future.""" + + def __init__(self, *, executor=None): + # true if the task is done or cancelled + self._done = False + # true if the task is cancelled + self._cancelled = False + # the final return value of the handler + self._result = None + # An exception raised by the handler when called + self._exception = None + # callbacks to be scheduled after this task completes + self._callbacks = [] + # Lock for threadsafety + self._lock = threading.Lock() + # An executor to use when scheduling done callbacks + self._executor = None + self._set_executor(executor) + + def __await__(self): + # Yield if the task is not finished + while not self._done: + yield + return self._result + + def cancel(self): + """Request cancellation of the running task if it is not done already.""" + with self._lock: + if not self._done: + self._cancelled = True + self._schedule_done_callbacks() + + def cancelled(self): + """ + Indicate if the task has been cancelled. + + :return: True if the task was cancelled + :rtype: bool + """ + return self._cancelled + + def done(self): + """ + Indicate if the task has finished executing. + + :return: True if the task is finished or raised while it was executing + :rtype: bool + """ + return self._done + + def result(self): + """ + Get the result of a done task. + + :return: The result set by the task + """ + return self._result + + def exception(self): + """ + Get an exception raised by a done task. + + :return: The exception raised by the task + """ + return self._exception + + def set_result(self, result): + """ + Set the result returned by a task. + + :param result: The output of a long running task. + """ + with self._lock: + self._result = result + self._done = True + self._cancelled = False + self._schedule_done_callbacks() + + def set_exception(self, exception): + """ + Set the exception raised by the task. + + :param result: The output of a long running task. + """ + with self._lock: + self._exception = exception + self._done = True + self._cancelled = False + self._schedule_done_callbacks() + + def _schedule_done_callbacks(self): + """Schedule done callbacks on the executor if possible.""" + executor = self._executor() + if executor is not None: + for callback in self._callbacks: + executor.create_task(callback, self) + self._callbacks = [] + + def _set_executor(self, executor): + """Set the executor this future is associated with.""" + with self._lock: + if executor is None: + self._executor = _fake_weakref + else: + self._executor = weakref.ref(executor) + + def add_done_callback(self, callback): + """ + Add a callback to be executed when the task is done. + + :param callback: a callback taking the future as an agrument to be run when completed + """ + with self._lock: + if self._done: + executor = self._executor() + if executor is not None: + executor.create_task(callback, self) + else: + self._callbacks.append(callback) + + +class Task(Future): + """ + Execute a function or coroutine. + + This executes either a normal function or a coroutine to completion. On completion it creates + tasks for any 'done' callbacks. + + This class should only be instantiated by :class:`rclpy.executors.Executor`. + """ + + def __init__(self, handler, args=None, kwargs=None, executor=None): + super().__init__(executor=executor) + # _handler is either a normal function or a coroutine + self._handler = handler + # Arguments passed into the function + if args is None: + args = [] + self._args = args + if kwargs is None: + kwargs = {} + self._kwargs = kwargs + if inspect.iscoroutinefunction(handler): + self._handler = handler(*args, **kwargs) + self._args = None + self._kwargs = None + # True while the task is being executed + self._executing = False + # Lock acquired to prevent task from executing in parallel with itself + self._task_lock = threading.Lock() + + def __call__(self): + """ + Run or resume a task. + + This attempts to execute a handler. If the handler is a coroutine it will attempt to + await it. If there are done callbacks it will schedule them with the executor. + + The return value of the handler is stored as the task result. + """ + if self._done or self._executing or not self._task_lock.acquire(blocking=False): + return + try: + if self._done: + return + self._executing = True + + if inspect.iscoroutine(self._handler): + # Execute a coroutine + try: + self._handler.send(None) + except StopIteration as e: + # The coroutine finished; store the result + self._handler.close() + self.set_result(e.value) + self._complete_task() + except Exception as e: + self.set_exception(e) + self._complete_task() + else: + # Execute a normal function + try: + self.set_result(self._handler(*self._args, **self._kwargs)) + except Exception as e: + self.set_exception(e) + self._complete_task() + + self._executing = False + finally: + self._task_lock.release() + + def _complete_task(self): + """Cleanup after task finished.""" + self._handler = None + self._args = None + self._kwargs = None + + def executing(self): + """ + Check if the task is currently being executed. + + :return: True if the task is currently executing. + :rtype: bool + """ + return self._executing diff --git a/rclpy/test/test_executor.py b/rclpy/test/test_executor.py index cd7d322b6..15ea6aa93 100644 --- a/rclpy/test/test_executor.py +++ b/rclpy/test/test_executor.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import time import unittest @@ -81,6 +82,123 @@ def test_executor_spin_non_blocking(self): end = time.monotonic() self.assertLess(start - end, 0.001) + def test_execute_coroutine_timer(self): + self.assertIsNotNone(self.node.handle) + executor = SingleThreadedExecutor() + executor.add_node(self.node) + + called1 = False + called2 = False + + async def coroutine(): + nonlocal called1 + nonlocal called2 + called1 = True + await asyncio.sleep(0) + called2 = True + + tmr = self.node.create_timer(0.1, coroutine) + try: + executor.spin_once(timeout_sec=1.23) + self.assertTrue(called1) + self.assertFalse(called2) + + called1 = False + executor.spin_once(timeout_sec=0) + self.assertFalse(called1) + self.assertTrue(called2) + finally: + self.node.destroy_timer(tmr) + + def test_execute_coroutine_guard_condition(self): + self.assertIsNotNone(self.node.handle) + executor = SingleThreadedExecutor() + executor.add_node(self.node) + + called1 = False + called2 = False + + async def coroutine(): + nonlocal called1 + nonlocal called2 + called1 = True + await asyncio.sleep(0) + called2 = True + + gc = self.node.create_guard_condition(coroutine) + try: + gc.trigger() + executor.spin_once(timeout_sec=0) + self.assertTrue(called1) + self.assertFalse(called2) + + called1 = False + executor.spin_once(timeout_sec=1) + self.assertFalse(called1) + self.assertTrue(called2) + finally: + self.node.destroy_guard_condition(gc) + + def test_create_task_coroutine(self): + self.assertIsNotNone(self.node.handle) + executor = SingleThreadedExecutor() + executor.add_node(self.node) + + async def coroutine(): + return 'Sentinel Result' + + future = executor.create_task(coroutine) + self.assertFalse(future.done()) + + executor.spin_once(timeout_sec=0) + self.assertTrue(future.done()) + self.assertEqual('Sentinel Result', future.result()) + + def test_create_task_normal_function(self): + self.assertIsNotNone(self.node.handle) + executor = SingleThreadedExecutor() + executor.add_node(self.node) + + def func(): + return 'Sentinel Result' + + future = executor.create_task(func) + self.assertFalse(future.done()) + + executor.spin_once(timeout_sec=0) + self.assertTrue(future.done()) + self.assertEqual('Sentinel Result', future.result()) + + def test_create_task_dependent_coroutines(self): + self.assertIsNotNone(self.node.handle) + executor = SingleThreadedExecutor() + executor.add_node(self.node) + + async def coro1(): + return 'Sentinel Result 1' + + future1 = executor.create_task(coro1) + + async def coro2(): + nonlocal future1 + await future1 + return 'Sentinel Result 2' + + future2 = executor.create_task(coro2) + + # Coro2 is newest task, so it gets to await future1 in this spin + executor.spin_once(timeout_sec=0) + # Coro1 execs in this spin + executor.spin_once(timeout_sec=0) + self.assertTrue(future1.done()) + self.assertEqual('Sentinel Result 1', future1.result()) + self.assertFalse(future2.done()) + + # Coro2 passes the await step here (timeout change forces new generator) + executor.spin_once(timeout_sec=1) + self.assertTrue(future2.done()) + self.assertEqual('Sentinel Result 2', future2.result()) + if __name__ == '__main__': unittest.main() diff --git a/rclpy/test/test_task.py b/rclpy/test/test_task.py new file mode 100644 index 000000000..9bba03c93 --- /dev/null +++ b/rclpy/test/test_task.py @@ -0,0 +1,264 @@ +# Copyright 2018 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import unittest + +from rclpy.task import Future +from rclpy.task import Task + + +class DummyExecutor: + + def __init__(self): + self.done_callbacks = [] + + def create_task(self, cb, *args): + self.done_callbacks.append((cb, args)) + + +class TestTask(unittest.TestCase): + + def test_task_normal_callable(self): + + def func(): + return 'Sentinel Result' + + t = Task(func) + t() + self.assertTrue(t.done()) + self.assertEqual('Sentinel Result', t.result()) + + def test_task_lambda(self): + + def func(): + return 'Sentinel Result' + + t = Task(lambda: func()) + t() + self.assertTrue(t.done()) + self.assertEqual('Sentinel Result', t.result()) + + def test_coroutine(self): + called1 = False + called2 = False + + async def coro(): + nonlocal called1 + nonlocal called2 + called1 = True + await asyncio.sleep(0) + called2 = True + return 'Sentinel Result' + + t = Task(coro) + t() + self.assertTrue(called1) + self.assertFalse(called2) + + called1 = False + t() + self.assertFalse(called1) + self.assertTrue(called2) + self.assertTrue(t.done()) + self.assertEqual('Sentinel Result', t.result()) + + def test_done_callback_scheduled(self): + executor = DummyExecutor() + + t = Task(lambda: None, executor=executor) + t.add_done_callback('Sentinel Value') + t() + self.assertTrue(t.done()) + self.assertEqual(1, len(executor.done_callbacks)) + self.assertEqual('Sentinel Value', executor.done_callbacks[0][0]) + args = executor.done_callbacks[0][1] + self.assertEqual(1, len(args)) + self.assertEqual(t, args[0]) + + def test_done_task_done_callback_scheduled(self): + executor = DummyExecutor() + + t = Task(lambda: None, executor=executor) + t() + self.assertTrue(t.done()) + t.add_done_callback('Sentinel Value') + self.assertEqual(1, len(executor.done_callbacks)) + self.assertEqual('Sentinel Value', executor.done_callbacks[0][0]) + args = executor.done_callbacks[0][1] + self.assertEqual(1, len(args)) + self.assertEqual(t, args[0]) + + def test_done_task_called(self): + called = False + + def func(): + nonlocal called + called = True + + t = Task(func) + t() + self.assertTrue(called) + self.assertTrue(t.done()) + called = False + t() + self.assertFalse(called) + self.assertTrue(t.done()) + + def test_cancelled(self): + t = Task(lambda: None) + t.cancel() + self.assertTrue(t.cancelled()) + + def test_done_task_cancelled(self): + t = Task(lambda: None) + t() + t.cancel() + self.assertFalse(t.cancelled()) + + def test_exception(self): + + def func(): + e = Exception() + e.sentinel_value = 'Sentinel Exception' + raise e + + t = Task(func) + t() + self.assertTrue(t.done()) + self.assertEqual('Sentinel Exception', t.exception().sentinel_value) + self.assertEqual(None, t.result()) + + def test_coroutine_exception(self): + + async def coro(): + e = Exception() + e.sentinel_value = 'Sentinel Exception' + raise e + + t = Task(coro) + t() + self.assertTrue(t.done()) + self.assertEqual('Sentinel Exception', t.exception().sentinel_value) + self.assertEqual(None, t.result()) + + def test_task_normal_callable_args(self): + arg_in = 'Sentinel Arg' + + def func(arg): + return arg + + t = Task(func, args=(arg_in,)) + t() + self.assertEqual('Sentinel Arg', t.result()) + + def test_coroutine_args(self): + arg_in = 'Sentinel Arg' + + async def coro(arg): + return arg + + t = Task(coro, args=(arg_in,)) + t() + self.assertEqual('Sentinel Arg', t.result()) + + def test_task_normal_callable_kwargs(self): + arg_in = 'Sentinel Arg' + + def func(kwarg=None): + return kwarg + + t = Task(func, kwargs={'kwarg': arg_in}) + t() + self.assertEqual('Sentinel Arg', t.result()) + + def test_coroutine_kwargs(self): + arg_in = 'Sentinel Arg' + + async def coro(kwarg=None): + return kwarg + + t = Task(coro, kwargs={'kwarg': arg_in}) + t() + self.assertEqual('Sentinel Arg', t.result()) + + def test_executing(self): + t = Task(lambda: None) + self.assertFalse(t.executing()) + + +class TestFuture(unittest.TestCase): + + def test_cancelled(self): + f = Future() + f.cancel() + self.assertTrue(f.cancelled()) + + def test_done(self): + f = Future() + self.assertFalse(f.done()) + f.set_result(None) + self.assertTrue(f.done()) + + def test_set_result(self): + f = Future() + f.set_result('Sentinel Result') + self.assertEqual('Sentinel Result', f.result()) + self.assertTrue(f.done()) + + def test_set_exception(self): + f = Future() + f.set_exception('Sentinel Exception') + self.assertEqual('Sentinel Exception', f.exception()) + self.assertTrue(f.done()) + + def test_await(self): + f = Future() + + async def coro(): + nonlocal f + return await f + + c = coro() + c.send(None) + f.set_result('Sentinel Result') + try: + c.send(None) + except StopIteration as e: + self.assertEqual('Sentinel Result', e.value) + + def test_cancel_schedules_callbacks(self): + executor = DummyExecutor() + f = Future(executor=executor) + f.add_done_callback(lambda f: None) + f.cancel() + self.assertTrue(executor.done_callbacks) + + def test_set_result_schedules_callbacks(self): + executor = DummyExecutor() + f = Future(executor=executor) + f.add_done_callback(lambda f: None) + f.set_result('Anything') + self.assertTrue(executor.done_callbacks) + + def test_set_exception_schedules_callbacks(self): + executor = DummyExecutor() + f = Future(executor=executor) + f.add_done_callback(lambda f: None) + f.set_exception('Anything') + self.assertTrue(executor.done_callbacks) + + +if __name__ == '__main__': + unittest.main()