diff --git a/rclpy/rclpy/impl/_rclpy_pybind11.pyi b/rclpy/rclpy/impl/_rclpy_pybind11.pyi index c8c306f76..878e87f53 100644 --- a/rclpy/rclpy/impl/_rclpy_pybind11.pyi +++ b/rclpy/rclpy/impl/_rclpy_pybind11.pyi @@ -195,6 +195,12 @@ class Client(Destroyable, Generic[SrvRequestT, SrvResponseT]): def get_logger_name(self) -> str: """Get the name of the logger associated with the node of the client.""" + def set_on_new_response_callback(self, callback: Callable[[int], None]) -> None: + """Set the on new response callback function for the client.""" + + def clear_on_new_response_callback(self) -> None: + """Clear the on new response callback function for the client.""" + class Context(Destroyable): @@ -286,6 +292,12 @@ class Service(Destroyable, Generic[SrvRequestT, SrvResponseT]): def get_logger_name(self) -> str: """Get the name of the logger associated with the node of the service.""" + def set_on_new_request_callback(self, callback: Callable[[int], None]) -> None: + """Set the on new request callback function for the service.""" + + def clear_on_new_request_callback(self) -> None: + """Clear the on new request callback function for the service.""" + class TypeDescriptionService(Destroyable): @@ -578,6 +590,12 @@ class Timer(Destroyable): def is_timer_canceled(self) -> bool: """Check if a timer is canceled.""" + def set_on_reset_callback(self, callback: Callable[[int], None]) -> None: + """Set the on reset callback function for the timer.""" + + def clear_on_reset_callback(self) -> None: + """Clear the on reset callback function for the timer.""" + class Subscription(Destroyable, Generic[MsgT]): @@ -600,6 +618,12 @@ class Subscription(Destroyable, Generic[MsgT]): def get_publisher_count(self) -> int: """Count the publishers from a subscription.""" + def set_on_new_message_callback(self, callback: Callable[[int], None]) -> None: + """Set the on new message callback function for the subscription.""" + + def clear_on_new_message_callback(self) -> None: + """Clear the on new message callback function for the subscription.""" + class rcl_time_point_t: diff --git a/rclpy/src/rclpy/client.cpp b/rclpy/src/rclpy/client.cpp index ccbddfd6a..0c12108da 100644 --- a/rclpy/src/rclpy/client.cpp +++ b/rclpy/src/rclpy/client.cpp @@ -23,6 +23,7 @@ #include #include +#include #include "client.hpp" #include "clock.hpp" @@ -30,13 +31,19 @@ #include "node.hpp" #include "python_allocator.hpp" #include "utils.hpp" +#include "events_executor/rcl_support.hpp" namespace rclpy { +using events_executor::RclEventCallbackTrampoline; void Client::destroy() { + try { + clear_on_new_response_callback(); + } catch (RCLError) { + } rcl_client_.reset(); node_.destroy(); } @@ -181,6 +188,41 @@ Client::get_logger_name() const return node_logger_name; } +void +Client::set_callback( + rcl_event_callback_t callback, + const void * user_data) +{ + rcl_ret_t ret = rcl_client_set_on_new_response_callback( + rcl_client_.get(), + callback, + user_data); + + if (RCL_RET_OK != ret) { + throw RCLError(std::string("Failed to set the on new response callback for client: ") + + rcl_get_error_string().str); + } +} + +void +Client::set_on_new_response_callback(std::function callback) +{ + clear_on_new_response_callback(); + on_new_response_callback_ = std::move(callback); + set_callback( + RclEventCallbackTrampoline, + static_cast(&on_new_response_callback_)); +} + +void +Client::clear_on_new_response_callback() +{ + if (on_new_response_callback_) { + set_callback(nullptr, nullptr); + on_new_response_callback_ = nullptr; + } +} + void define_client(py::object module) { @@ -208,6 +250,10 @@ define_client(py::object module) "Configure whether introspection is enabled") .def( "get_logger_name", &Client::get_logger_name, - "Get the name of the logger associated with the node of the client."); + "Get the name of the logger associated with the node of the client.") + .def( + "set_on_new_response_callback", &Client::set_on_new_response_callback, + py::arg("callback")) + .def("clear_on_new_response_callback", &Client::clear_on_new_response_callback); } } // namespace rclpy diff --git a/rclpy/src/rclpy/client.hpp b/rclpy/src/rclpy/client.hpp index 4428dfcc4..8fa8a3f07 100644 --- a/rclpy/src/rclpy/client.hpp +++ b/rclpy/src/rclpy/client.hpp @@ -16,6 +16,7 @@ #define RCLPY__CLIENT_HPP_ #include +#include #include #include @@ -120,10 +121,20 @@ class Client : public Destroyable, public std::enable_shared_from_this const char * get_logger_name() const; + void + set_on_new_response_callback(std::function callback); + + void + clear_on_new_response_callback(); + private: Node node_; + std::function on_new_response_callback_{nullptr}; std::shared_ptr rcl_client_; rosidl_service_type_support_t * srv_type_; + + void + set_callback(rcl_event_callback_t callback, const void * user_data); }; /// Define a pybind11 wrapper for an rclpy::Client diff --git a/rclpy/src/rclpy/events_executor/rcl_support.cpp b/rclpy/src/rclpy/events_executor/rcl_support.cpp index 3a8826f70..3a8e54c68 100644 --- a/rclpy/src/rclpy/events_executor/rcl_support.cpp +++ b/rclpy/src/rclpy/events_executor/rcl_support.cpp @@ -26,7 +26,13 @@ namespace events_executor extern "C" void RclEventCallbackTrampoline(const void * user_data, size_t number_of_events) { const auto cb = reinterpret_cast *>(user_data); - (*cb)(number_of_events); + try { + (*cb)(number_of_events); + } catch (const std::exception & e) { + // Catch and print any exception to avoid propagation to c code + std::fprintf(stderr, "%s\n", e.what()); + std::terminate(); + } } RclCallbackManager::RclCallbackManager(EventsQueue * events_queue) diff --git a/rclpy/src/rclpy/service.cpp b/rclpy/src/rclpy/service.cpp index e24d07b97..f7714d499 100644 --- a/rclpy/src/rclpy/service.cpp +++ b/rclpy/src/rclpy/service.cpp @@ -22,19 +22,26 @@ #include #include +#include #include "clock.hpp" #include "exceptions.hpp" #include "node.hpp" #include "service.hpp" #include "utils.hpp" +#include "events_executor/rcl_support.hpp" namespace rclpy { +using events_executor::RclEventCallbackTrampoline; void Service::destroy() { + try { + clear_on_new_request_callback(); + } catch (RCLError) { + } rcl_service_.reset(); node_.destroy(); } @@ -184,6 +191,41 @@ Service::configure_introspection( } } +void +Service::set_callback( + rcl_event_callback_t callback, + const void * user_data) +{ + rcl_ret_t ret = rcl_service_set_on_new_request_callback( + rcl_service_.get(), + callback, + user_data); + + if (RCL_RET_OK != ret) { + throw RCLError(std::string("Failed to set the on new request callback for service: ") + + rcl_get_error_string().str); + } +} + +void +Service::set_on_new_request_callback(std::function callback) +{ + clear_on_new_request_callback(); + on_new_request_callback_ = std::move(callback); + set_callback( + RclEventCallbackTrampoline, + static_cast(&on_new_request_callback_)); +} + +void +Service::clear_on_new_request_callback() +{ + if (on_new_request_callback_) { + set_callback(nullptr, nullptr); + on_new_request_callback_ = nullptr; + } +} + void define_service(py::object module) { @@ -211,6 +253,10 @@ define_service(py::object module) "Configure whether introspection is enabled") .def( "get_logger_name", &Service::get_logger_name, - "Get the name of the logger associated with the node of the service."); + "Get the name of the logger associated with the node of the service.") + .def( + "set_on_new_request_callback", &Service::set_on_new_request_callback, + py::arg("callback")) + .def("clear_on_new_request_callback", &Service::clear_on_new_request_callback); } } // namespace rclpy diff --git a/rclpy/src/rclpy/service.hpp b/rclpy/src/rclpy/service.hpp index 6b178e18b..6f63ebb04 100644 --- a/rclpy/src/rclpy/service.hpp +++ b/rclpy/src/rclpy/service.hpp @@ -16,6 +16,7 @@ #define RCLPY__SERVICE_HPP_ #include +#include #include #include @@ -125,10 +126,20 @@ class Service : public Destroyable, public std::enable_shared_from_this void destroy() override; + void + set_on_new_request_callback(std::function callback); + + void + clear_on_new_request_callback(); + private: Node node_; + std::function on_new_request_callback_{nullptr}; std::shared_ptr rcl_service_; rosidl_service_type_support_t * srv_type_; + + void + set_callback(rcl_event_callback_t callback, const void * user_data); }; /// Define a pybind11 wrapper for an rclpy::Service diff --git a/rclpy/src/rclpy/subscription.cpp b/rclpy/src/rclpy/subscription.cpp index eab4de9a6..e9700a198 100644 --- a/rclpy/src/rclpy/subscription.cpp +++ b/rclpy/src/rclpy/subscription.cpp @@ -24,17 +24,21 @@ #include #include #include +#include #include "exceptions.hpp" #include "node.hpp" #include "serialization.hpp" #include "subscription.hpp" #include "utils.hpp" +#include "events_executor/rcl_support.hpp" using pybind11::literals::operator""_a; namespace rclpy { +using events_executor::RclEventCallbackTrampoline; + Subscription::Subscription( Node & node, py::object pymsg_type, std::string topic, py::object pyqos_profile) @@ -87,6 +91,10 @@ Subscription::Subscription( void Subscription::destroy() { + try { + clear_on_new_message_callback(); + } catch (RCLError) { + } rcl_subscription_.reset(); node_.destroy(); } @@ -194,6 +202,41 @@ Subscription::get_publisher_count() const return count; } +void +Subscription::set_callback( + rcl_event_callback_t callback, + const void * user_data) +{ + rcl_ret_t ret = rcl_subscription_set_on_new_message_callback( + rcl_subscription_.get(), + callback, + user_data); + + if (RCL_RET_OK != ret) { + throw RCLError(std::string("Failed to set the on new message callback for subscription: ") + + rcl_get_error_string().str); + } +} + +void +Subscription::set_on_new_message_callback(std::function callback) +{ + clear_on_new_message_callback(); + on_new_message_callback_ = std::move(callback); + set_callback( + RclEventCallbackTrampoline, + static_cast(&on_new_message_callback_)); +} + +void +Subscription::clear_on_new_message_callback() +{ + if (on_new_message_callback_) { + set_callback(nullptr, nullptr); + on_new_message_callback_ = nullptr; + } +} + void define_subscription(py::object module) { @@ -215,6 +258,10 @@ define_subscription(py::object module) "Return the resolved topic name of a subscription.") .def( "get_publisher_count", &Subscription::get_publisher_count, - "Count the publishers from a subscription."); + "Count the publishers from a subscription.") + .def( + "set_on_new_message_callback", &Subscription::set_on_new_message_callback, + py::arg("callback")) + .def("clear_on_new_message_callback", &Subscription::clear_on_new_message_callback); } } // namespace rclpy diff --git a/rclpy/src/rclpy/subscription.hpp b/rclpy/src/rclpy/subscription.hpp index f46401ea2..7563de01b 100644 --- a/rclpy/src/rclpy/subscription.hpp +++ b/rclpy/src/rclpy/subscription.hpp @@ -16,6 +16,7 @@ #define RCLPY__SUBSCRIPTION_HPP_ #include +#include #include @@ -105,9 +106,19 @@ class Subscription : public Destroyable, public std::enable_shared_from_this callback); + + void + clear_on_new_message_callback(); + private: Node node_; + std::function on_new_message_callback_{nullptr}; std::shared_ptr rcl_subscription_; + + void + set_callback(rcl_event_callback_t callback, const void * user_data); }; /// Define a pybind11 wrapper for an rclpy::Subscription void define_subscription(py::object module); diff --git a/rclpy/src/rclpy/timer.cpp b/rclpy/src/rclpy/timer.cpp index 239e7b48d..4561eec78 100644 --- a/rclpy/src/rclpy/timer.cpp +++ b/rclpy/src/rclpy/timer.cpp @@ -19,17 +19,26 @@ #include #include +#include +#include #include "clock.hpp" #include "context.hpp" #include "exceptions.hpp" #include "timer.hpp" +#include "events_executor/rcl_support.hpp" namespace rclpy { +using events_executor::RclEventCallbackTrampoline; + void Timer::destroy() { + try { + clear_on_reset_callback(); + } catch (RCLError) { + } rcl_timer_.reset(); clock_.destroy(); context_.destroy(); @@ -175,6 +184,41 @@ bool Timer::is_timer_canceled() return is_canceled; } +void +Timer::set_callback( + rcl_event_callback_t callback, + const void * user_data) +{ + rcl_ret_t ret = rcl_timer_set_on_reset_callback( + rcl_timer_.get(), + callback, + user_data); + + if (RCL_RET_OK != ret) { + throw RCLError(std::string("Failed to set the on reset callback for timer: ") + + rcl_get_error_string().str); + } +} + +void +Timer::set_on_reset_callback(std::function callback) +{ + clear_on_reset_callback(); + on_reset_callback_ = std::move(callback); + set_callback( + RclEventCallbackTrampoline, + static_cast(&on_reset_callback_)); +} + +void +Timer::clear_on_reset_callback() +{ + if (on_reset_callback_) { + set_callback(nullptr, nullptr); + on_reset_callback_ = nullptr; + } +} + void define_timer(py::object module) { @@ -212,7 +256,11 @@ define_timer(py::object module) "Cancel a timer.") .def( "is_timer_canceled", &Timer::is_timer_canceled, - "Check if a timer is canceled."); + "Check if a timer is canceled.") + .def( + "set_on_reset_callback", &Timer::set_on_reset_callback, + py::arg("callback")) + .def("clear_on_reset_callback", &Timer::clear_on_reset_callback); } } // namespace rclpy diff --git a/rclpy/src/rclpy/timer.hpp b/rclpy/src/rclpy/timer.hpp index 5cbab817f..be0fa17c2 100644 --- a/rclpy/src/rclpy/timer.hpp +++ b/rclpy/src/rclpy/timer.hpp @@ -17,6 +17,7 @@ #include #include +#include #include @@ -145,10 +146,20 @@ class Timer : public Destroyable, public std::enable_shared_from_this /// Force an early destruction of this object void destroy() override; + void + set_on_reset_callback(std::function callback); + + void + clear_on_reset_callback(); + private: Context context_; Clock clock_; + std::function on_reset_callback_{nullptr}; std::shared_ptr rcl_timer_; + + void + set_callback(rcl_event_callback_t callback, const void * user_data); }; /// Define a pybind11 wrapper for an rcl_timer_t diff --git a/rclpy/test/test_client.py b/rclpy/test/test_client.py index 66d43317b..605cdec17 100644 --- a/rclpy/test/test_client.py +++ b/rclpy/test/test_client.py @@ -21,6 +21,7 @@ from typing import Tuple from typing import TYPE_CHECKING import unittest +from unittest.mock import Mock from rcl_interfaces.srv import GetParameters import rclpy @@ -295,6 +296,28 @@ def test_logger_name_is_equal_to_node_name(self) -> None: with self.node.create_client(GetParameters, 'get/parameters') as cli: self.assertEqual(cli.logger_name, 'TestClient') + def test_on_new_response_callback(self) -> None: + def _service(request, response): + return response + with self.node.create_client(Empty, '/service') as cli: + with self.node.create_service(Empty, '/service', _service): + executor = rclpy.executors.SingleThreadedExecutor(context=self.context) + try: + self.assertTrue(cli.wait_for_service(timeout_sec=20)) + executor.add_node(self.node) + cb = Mock() + cli.handle.set_on_new_response_callback(cb) + cb.assert_not_called() + cli.call_async(Empty.Request()) + executor.spin_once(0) + cb.assert_called_once_with(1) + cli.handle.clear_on_new_response_callback() + cli.call_async(Empty.Request()) + executor.spin_once(0) + cb.assert_called_once() + finally: + executor.shutdown() + if __name__ == '__main__': unittest.main() diff --git a/rclpy/test/test_service.py b/rclpy/test/test_service.py index ef8fea102..dc07254a4 100644 --- a/rclpy/test/test_service.py +++ b/rclpy/test/test_service.py @@ -15,6 +15,7 @@ from typing import Generator from typing import List from typing import Optional +from unittest.mock import Mock import pytest @@ -116,3 +117,16 @@ def test_service_context_manager() -> None: with node.create_service( srv_type=Empty, srv_name='empty_service', callback=lambda _, _1: None) as srv: assert srv.service_name == '/empty_service' + + +def test_set_on_new_request_callback(test_node) -> None: + cli = test_node.create_client(Empty, '/service') + srv = test_node.create_service(Empty, '/service', lambda req, res: res) + cb = Mock() + srv.handle.set_on_new_request_callback(cb) + cb.assert_not_called() + cli.call_async(Empty.Request()) + cb.assert_called_once_with(1) + srv.handle.clear_on_new_request_callback() + cli.call_async(Empty.Request()) + cb.assert_called_once() diff --git a/rclpy/test/test_subscription.py b/rclpy/test/test_subscription.py index 12dd78eb3..421eebcae 100644 --- a/rclpy/test/test_subscription.py +++ b/rclpy/test/test_subscription.py @@ -15,6 +15,7 @@ import time from typing import List from typing import Optional +from unittest.mock import Mock import pytest @@ -176,3 +177,21 @@ def test_subscription_publisher_count() -> None: sub.destroy() node.destroy_node() + + +def test_on_new_message_callback(test_node) -> None: + topic_name = '/topic' + cb = Mock() + sub = test_node.create_subscription( + msg_type=Empty, + topic=topic_name, + qos_profile=10, + callback=cb) + pub = test_node.create_publisher(Empty, topic_name, 10) + sub.handle.set_on_new_message_callback(cb) + cb.assert_not_called() + pub.publish(Empty()) + cb.assert_called_once_with(1) + sub.handle.clear_on_new_message_callback() + pub.publish(Empty()) + cb.assert_called_once() diff --git a/rclpy/test/test_timer.py b/rclpy/test/test_timer.py index 561dbfcad..f50692657 100644 --- a/rclpy/test/test_timer.py +++ b/rclpy/test/test_timer.py @@ -18,6 +18,7 @@ import time from typing import List from typing import Optional +from unittest.mock import Mock import pytest import rclpy @@ -27,6 +28,25 @@ from rclpy.timer import TimerInfo +@pytest.fixture +def context() -> None: + return rclpy.context.Context() + + +@pytest.fixture +def setup_ros(context) -> None: + rclpy.init(context=context) + yield + rclpy.shutdown(context=context) + + +@pytest.fixture +def test_node(context, setup_ros): + node = rclpy.create_node('test_node', context=context) + yield node + node.destroy_node() + + TEST_PERIODS = ( 0.1, pytest.param( @@ -328,3 +348,15 @@ def timer_callback(info: TimerInfo) -> None: node.destroy_timer(timer) node.destroy_node() rclpy.shutdown(context=context) + + +def test_on_reset_callback(test_node): + tmr = test_node.create_timer(1, lambda: None) + cb = Mock() + tmr.handle.set_on_reset_callback(cb) + cb.assert_not_called() + tmr.reset() + cb.assert_called_once_with(1) + tmr.handle.clear_on_reset_callback() + tmr.reset() + cb.assert_called_once()