diff --git a/include/ur_client_library/control/reverse_interface.h b/include/ur_client_library/control/reverse_interface.h index dbfb407e..82202079 100644 --- a/include/ur_client_library/control/reverse_interface.h +++ b/include/ur_client_library/control/reverse_interface.h @@ -38,6 +38,7 @@ #include #include #include +#include namespace urcl { @@ -156,9 +157,31 @@ class ReverseInterface "commands.")]] virtual void setKeepaliveCount(const uint32_t count); - void registerDisconnectionCallback(std::function disconnection_fun) + /*! + * \brief Register a callback for the robot-based disconnection. + * + * The callback will be called when the robot disconnects from the reverse interface. + * + * \param disconnection_fun The function to be called on disconnection. + * + * \returns A unique handler ID for the registered callback. This can be used to unregister the + * callback later. + */ + uint32_t registerDisconnectionCallback(std::function disconnection_fun) + { + disconnect_callbacks_.push_back({ next_disconnect_callback_id_, disconnection_fun }); + return next_disconnect_callback_id_++; + } + + /*! \brief Unregisters a disconnection callback. + * + * \param handler_id The ID of the handler to be unregistered as obtained from + * registerDisconnectionCallback. + */ + void unregisterDisconnectionCallback(const uint32_t handler_id) { - disconnection_callback_ = disconnection_fun; + disconnect_callbacks_.remove_if( + [handler_id](const HandlerFunction& h) { return h.id == handler_id; }); } /*! @@ -178,7 +201,8 @@ class ReverseInterface virtual void messageCallback(const socket_t filedescriptor, char* buffer, int nbytesrecv); - std::function disconnection_callback_ = nullptr; + std::list> disconnect_callbacks_; + uint32_t next_disconnect_callback_id_ = 0; socket_t client_fd_; comm::TCPServer server_; diff --git a/include/ur_client_library/control/trajectory_point_interface.h b/include/ur_client_library/control/trajectory_point_interface.h index 1a131fcd..22d3442f 100644 --- a/include/ur_client_library/control/trajectory_point_interface.h +++ b/include/ur_client_library/control/trajectory_point_interface.h @@ -29,6 +29,7 @@ #ifndef UR_CLIENT_LIBRARY_TRAJECTORY_INTERFACE_H_INCLUDED #define UR_CLIENT_LIBRARY_TRAJECTORY_INTERFACE_H_INCLUDED +#include #include #include "ur_client_library/control/motion_primitives.h" @@ -133,10 +134,11 @@ class TrajectoryPointInterface : public ReverseInterface */ bool writeMotionPrimitive(const std::shared_ptr primitive); - void setTrajectoryEndCallback(std::function callback) - { - handle_trajectory_end_ = callback; - } + void setTrajectoryEndCallback(std::function callback); + + uint32_t addTrajectoryEndCallback(const std::function& callback); + + void removeTrajectoryEndCallback(const uint32_t callback_id); protected: virtual void connectionCallback(const socket_t filedescriptor) override; @@ -146,7 +148,8 @@ class TrajectoryPointInterface : public ReverseInterface virtual void messageCallback(const socket_t filedescriptor, char* buffer, int nbytesrecv) override; private: - std::function handle_trajectory_end_; + std::list> trajectory_end_callbacks_; + uint32_t next_done_callback_id_ = 0; }; } // namespace control diff --git a/include/ur_client_library/types.h b/include/ur_client_library/types.h index 9022daeb..86192577 100644 --- a/include/ur_client_library/types.h +++ b/include/ur_client_library/types.h @@ -22,7 +22,9 @@ #include #include +#include #include +#include "ur_client_library/log.h" namespace urcl { @@ -81,4 +83,20 @@ constexpr typename std::underlying_type::type toUnderlying(const E e) noexcep { return static_cast::type>(e); } -} // namespace urcl \ No newline at end of file + +template +struct HandlerFunction +{ + uint32_t id; + std::function function; + + HandlerFunction(uint32_t id, std::function function) : id(id), function(function) + { + } + + bool operator==(const HandlerFunction& other) const + { + return id == other.id; + } +}; +} // namespace urcl diff --git a/include/ur_client_library/ur/instruction_executor.h b/include/ur_client_library/ur/instruction_executor.h index 664b4e1c..ecec111d 100644 --- a/include/ur_client_library/ur/instruction_executor.h +++ b/include/ur_client_library/ur/instruction_executor.h @@ -42,16 +42,16 @@ class InstructionExecutor InstructionExecutor() = delete; InstructionExecutor(std::shared_ptr driver) : driver_(driver) { - driver_->registerTrajectoryDoneCallback( + traj_done_callback_handler_id_ = driver_->registerTrajectoryDoneCallback( std::bind(&InstructionExecutor::trajDoneCallback, this, std::placeholders::_1)); - driver_->registerTrajectoryInterfaceDisconnectedCallback( + disconnected_handler_id_ = driver_->registerTrajectoryInterfaceDisconnectedCallback( std::bind(&InstructionExecutor::trajDisconnectCallback, this, std::placeholders::_1)); } ~InstructionExecutor() { - driver_->registerTrajectoryDoneCallback(nullptr); - driver_->registerTrajectoryInterfaceDisconnectedCallback(nullptr); + driver_->unregisterTrajectoryDoneCallback(traj_done_callback_handler_id_); + driver_->unregisterTrajectoryInterfaceDisconnectedCallback(disconnected_handler_id_); } /** @@ -187,10 +187,13 @@ class InstructionExecutor return trajectory_running_; } -private: +protected: void trajDoneCallback(const urcl::control::TrajectoryResult& result); void trajDisconnectCallback(const int filedescriptor); + uint32_t traj_done_callback_handler_id_; + uint32_t disconnected_handler_id_; + std::shared_ptr driver_; std::atomic trajectory_running_ = false; std::atomic cancel_requested_ = false; diff --git a/include/ur_client_library/ur/ur_driver.h b/include/ur_client_library/ur/ur_driver.h index 49951940..2ba9efbc 100644 --- a/include/ur_client_library/ur/ur_driver.h +++ b/include/ur_client_library/ur/ur_driver.h @@ -823,10 +823,17 @@ class UrDriver * * \param trajectory_done_cb Callback function that will be triggered in the event of finishing * a trajectory execution + * + * \returns The ID of the callback that can be used to unregister the callback later. */ - void registerTrajectoryDoneCallback(std::function trajectory_done_cb) + uint32_t registerTrajectoryDoneCallback(std::function trajectory_done_cb) + { + return trajectory_interface_->addTrajectoryEndCallback(trajectory_done_cb); + } + + void unregisterTrajectoryDoneCallback(const uint32_t handler_id) { - trajectory_interface_->setTrajectoryEndCallback(trajectory_done_cb); + trajectory_interface_->removeTrajectoryEndCallback(handler_id); } /*! @@ -887,9 +894,32 @@ class UrDriver primary_client_->stop(); } - void registerTrajectoryInterfaceDisconnectedCallback(std::function fun) + /*! + * \brief Register a callback for the trajectory interface disconnection. + * + * This callback will be called when the trajectory interface is disconnected. + * + * \param fun Callback function that will be triggered in the event of disconnection + * + * \returns The ID of the callback that can be used to unregister the callback later. + */ + uint32_t registerTrajectoryInterfaceDisconnectedCallback(std::function fun) + { + return trajectory_interface_->registerDisconnectionCallback(fun); + } + + /*! + * \brief Unregister a callback for the trajectory interface disconnection. + * + * This will remove the callback that was registered with + * registerTrajectoryInterfaceDisconnectedCallback. + * + * \param handler_id The ID of the callback to be removed as obtained from + * registerTrajectoryInterfaceDisconnectedCallback. + */ + void unregisterTrajectoryInterfaceDisconnectedCallback(const uint32_t handler_id) { - trajectory_interface_->registerDisconnectionCallback(fun); + trajectory_interface_->unregisterDisconnectionCallback(handler_id); } /*! diff --git a/src/control/reverse_interface.cpp b/src/control/reverse_interface.cpp index ae65ff94..d9fec1c7 100644 --- a/src/control/reverse_interface.cpp +++ b/src/control/reverse_interface.cpp @@ -238,6 +238,10 @@ void ReverseInterface::disconnectionCallback(const socket_t filedescriptor) URCL_LOG_INFO("Connection to reverse interface dropped.", filedescriptor); client_fd_ = INVALID_SOCKET; handle_program_state_(false); + for (auto handler : disconnect_callbacks_) + { + handler.function(filedescriptor); + } } void ReverseInterface::messageCallback(const socket_t filedescriptor, char* buffer, int nbytesrecv) diff --git a/src/control/trajectory_point_interface.cpp b/src/control/trajectory_point_interface.cpp index cf61506f..88e90098 100644 --- a/src/control/trajectory_point_interface.cpp +++ b/src/control/trajectory_point_interface.cpp @@ -275,9 +275,9 @@ void TrajectoryPointInterface::connectionCallback(const socket_t filedescriptor) void TrajectoryPointInterface::disconnectionCallback(const socket_t filedescriptor) { URCL_LOG_DEBUG("Connection to trajectory interface dropped."); - if (disconnection_callback_ != nullptr) + for (auto handler : disconnect_callbacks_) { - disconnection_callback_(filedescriptor); + handler.function(filedescriptor); } client_fd_ = INVALID_SOCKET; } @@ -289,9 +289,12 @@ void TrajectoryPointInterface::messageCallback(const socket_t filedescriptor, ch int32_t* status = reinterpret_cast(buffer); URCL_LOG_DEBUG("Received message %d on TrajectoryPointInterface", be32toh(*status)); - if (handle_trajectory_end_) + if (!trajectory_end_callbacks_.empty()) { - handle_trajectory_end_(static_cast(be32toh(*status))); + for (auto handler : trajectory_end_callbacks_) + { + handler.function(static_cast(be32toh(*status))); + } } else { @@ -304,5 +307,23 @@ void TrajectoryPointInterface::messageCallback(const socket_t filedescriptor, ch nbytesrecv); } } + +void TrajectoryPointInterface::setTrajectoryEndCallback(std::function callback) +{ + addTrajectoryEndCallback(callback); +} + +uint32_t TrajectoryPointInterface::addTrajectoryEndCallback(const std::function& callback) +{ + trajectory_end_callbacks_.push_back({ next_done_callback_id_, callback }); + return next_done_callback_id_++; +} + +void TrajectoryPointInterface::removeTrajectoryEndCallback(const uint32_t handler_id) +{ + trajectory_end_callbacks_.remove_if( + [handler_id](const HandlerFunction& h) { return h.id == handler_id; }); +} + } // namespace control } // namespace urcl diff --git a/src/ur/instruction_executor.cpp b/src/ur/instruction_executor.cpp index 5828c830..48d8ffab 100644 --- a/src/ur/instruction_executor.cpp +++ b/src/ur/instruction_executor.cpp @@ -36,10 +36,13 @@ void urcl::InstructionExecutor::trajDoneCallback(const urcl::control::TrajectoryResult& result) { URCL_LOG_DEBUG("Trajectory result received: %s", control::trajectoryResultToString(result).c_str()); - std::unique_lock lock(trajectory_result_mutex_); - trajectory_done_cv_.notify_all(); - trajectory_result_ = result; - trajectory_running_ = false; + if (trajectory_running_) + { + std::unique_lock lock(trajectory_result_mutex_); + trajectory_done_cv_.notify_all(); + trajectory_result_ = result; + trajectory_running_ = false; + } } void urcl::InstructionExecutor::trajDisconnectCallback(const int filedescriptor) { diff --git a/tests/test_reverse_interface.cpp b/tests/test_reverse_interface.cpp index 3674b755..9950b111 100644 --- a/tests/test_reverse_interface.cpp +++ b/tests/test_reverse_interface.cpp @@ -32,8 +32,38 @@ #include #include #include +#include "ur_client_library/log.h" using namespace urcl; +std::mutex g_connection_mutex; +std::condition_variable g_connection_condition; + +class TestableReverseInterface : public control::ReverseInterface +{ +public: + TestableReverseInterface(const control::ReverseInterfaceConfig& config) : control::ReverseInterface(config) + { + } + + virtual void connectionCallback(const socket_t filedescriptor) + { + control::ReverseInterface::connectionCallback(filedescriptor); + connected = true; + std::lock_guard lk(g_connection_mutex); + g_connection_condition.notify_one(); + } + + virtual void disconnectionCallback(const socket_t filedescriptor) + { + URCL_LOG_DEBUG("There are %zu disconnection callbacks registered.", disconnect_callbacks_.size()); + control::ReverseInterface::disconnectionCallback(filedescriptor); + connected = false; + std::lock_guard lk(g_connection_mutex); + g_connection_condition.notify_one(); + } + + std::atomic connected = false; +}; class ReverseIntefaceTest : public ::testing::Test { @@ -153,8 +183,11 @@ class ReverseIntefaceTest : public ::testing::Test control::ReverseInterfaceConfig config; config.port = 50001; config.handle_program_state = std::bind(&ReverseIntefaceTest::handleProgramState, this, std::placeholders::_1); - reverse_interface_.reset(new control::ReverseInterface(config)); + reverse_interface_.reset(new TestableReverseInterface(config)); client_.reset(new Client(50001)); + std::unique_lock lk(g_connection_mutex); + g_connection_condition.wait_for(lk, std::chrono::seconds(1), + [&]() { return reverse_interface_->connected.load(); }); } void TearDown() @@ -187,7 +220,7 @@ class ReverseIntefaceTest : public ::testing::Test return false; } - std::unique_ptr reverse_interface_; + std::unique_ptr reverse_interface_; std::unique_ptr client_; private: @@ -450,9 +483,63 @@ TEST_F(ReverseIntefaceTest, deprecated_set_keep_alive_count) EXPECT_EQ(expected_read_timeout, received_read_timeout); } +TEST_F(ReverseIntefaceTest, disconnected_callbacks_are_called) +{ + // Wait for the client to connect to the server + EXPECT_TRUE(waitForProgramState(1000, true)); + + std::atomic disconnect_called_1 = false; + std::atomic disconnect_called_2 = false; + + // Register disconnection callbacks + int disconnection_callback_id_1 = + reverse_interface_->registerDisconnectionCallback([&disconnect_called_1](const int fd) { + std::cout << "Disconnection 1 callback called with fd: " << fd << std::endl; + disconnect_called_1 = true; + }); + int disconnection_callback_id_2 = + reverse_interface_->registerDisconnectionCallback([&disconnect_called_2](const int fd) { + std::cout << "Disconnection 2 callback called with fd: " << fd << std::endl; + disconnect_called_2 = true; + }); + + // Close the client connection + client_->close(); + EXPECT_TRUE(waitForProgramState(1000, false)); + std::unique_lock lk(g_connection_mutex); + g_connection_condition.wait_for(lk, std::chrono::seconds(1), [&]() { return !reverse_interface_->connected.load(); }); + EXPECT_TRUE(disconnect_called_1); + EXPECT_TRUE(disconnect_called_2); + + // Unregister 1. 2 should still be called + disconnect_called_1 = false; + disconnect_called_2 = false; + client_.reset(new Client(50001)); + EXPECT_TRUE(waitForProgramState(1000, true)); + reverse_interface_->unregisterDisconnectionCallback(disconnection_callback_id_1); + client_->close(); + g_connection_condition.wait_for(lk, std::chrono::seconds(1), [&]() { return !reverse_interface_->connected.load(); }); + EXPECT_TRUE(waitForProgramState(1000, false)); + EXPECT_FALSE(disconnect_called_1); + EXPECT_TRUE(disconnect_called_2); + + // Unregister both. None should be called + disconnect_called_1 = false; + disconnect_called_2 = false; + client_.reset(new Client(50001)); + EXPECT_TRUE(waitForProgramState(1000, true)); + reverse_interface_->unregisterDisconnectionCallback(disconnection_callback_id_2); + client_->close(); + g_connection_condition.wait_for(lk, std::chrono::seconds(1), [&]() { return !reverse_interface_->connected.load(); }); + EXPECT_TRUE(waitForProgramState(1000, false)); + EXPECT_FALSE(disconnect_called_1); + EXPECT_FALSE(disconnect_called_2); +} + int main(int argc, char* argv[]) { ::testing::InitGoogleTest(&argc, argv); + urcl::setLogLevel(LogLevel::INFO); return RUN_ALL_TESTS(); } diff --git a/tests/test_trajectory_point_interface.cpp b/tests/test_trajectory_point_interface.cpp index 57607b0e..bd843e43 100644 --- a/tests/test_trajectory_point_interface.cpp +++ b/tests/test_trajectory_point_interface.cpp @@ -37,6 +37,36 @@ using namespace urcl; +std::mutex g_connection_mutex; +std::condition_variable g_connection_condition; + +class TestableTrajectoryPointInterface : public control::TrajectoryPointInterface +{ +public: + TestableTrajectoryPointInterface(uint32_t port) : control::TrajectoryPointInterface(port) + { + } + + virtual void connectionCallback(const socket_t filedescriptor) override + { + control::TrajectoryPointInterface::connectionCallback(filedescriptor); + connected = true; + std::lock_guard lk(g_connection_mutex); + g_connection_condition.notify_one(); + } + + virtual void disconnectionCallback(const socket_t filedescriptor) override + { + URCL_LOG_DEBUG("There are %zu disconnection callbacks registered.", disconnect_callbacks_.size()); + control::TrajectoryPointInterface::disconnectionCallback(filedescriptor); + connected = false; + std::lock_guard lk(g_connection_mutex); + g_connection_condition.notify_one(); + } + + std::atomic connected = false; //!< True, if the interface is connected to the robot. +}; + class TrajectoryPointInterfaceTest : public ::testing::Test { protected: @@ -257,10 +287,12 @@ class TrajectoryPointInterfaceTest : public ::testing::Test void SetUp() { - traj_point_interface_.reset(new control::TrajectoryPointInterface(50003)); + traj_point_interface_.reset(new TestableTrajectoryPointInterface(50003)); client_.reset(new Client(50003)); // Need to be sure that the client has connected to the server - std::this_thread::sleep_for(std::chrono::seconds(1)); + std::unique_lock lk(g_connection_mutex); + g_connection_condition.wait_for(lk, std::chrono::seconds(1), + [&]() { return traj_point_interface_->connected.load(); }); } void TearDown() @@ -271,7 +303,7 @@ class TrajectoryPointInterfaceTest : public ::testing::Test } } - std::unique_ptr traj_point_interface_; + std::unique_ptr traj_point_interface_; std::unique_ptr client_; public: @@ -599,6 +631,59 @@ TEST_F(TrajectoryPointInterfaceTest, unsupported_motion_type_throws) EXPECT_THROW(traj_point_interface_->writeMotionPrimitive(primitive), urcl::UnsupportedMotionType); } +TEST_F(TrajectoryPointInterfaceTest, disconnected_callbacks_are_called_correctly) +{ + std::atomic disconnect_called_1 = false; + std::atomic disconnect_called_2 = false; + + // Register disconnection callbacks + int disconnection_callback_id_1 = + traj_point_interface_->registerDisconnectionCallback([&disconnect_called_1](const int fd) { + std::cout << "Disconnection 1 callback called with fd: " << fd << std::endl; + disconnect_called_1 = true; + }); + int disconnection_callback_id_2 = + traj_point_interface_->registerDisconnectionCallback([&disconnect_called_2](const int fd) { + std::cout << "Disconnection 2 callback called with fd: " << fd << std::endl; + disconnect_called_2 = true; + }); + + // Close the client connection + client_->close(); + std::unique_lock lk(g_connection_mutex); + g_connection_condition.wait_for(lk, std::chrono::seconds(1), + [&]() { return !traj_point_interface_->connected.load(); }); + EXPECT_TRUE(disconnect_called_1); + EXPECT_TRUE(disconnect_called_2); + + // Unregister 1. 2 should still be called + disconnect_called_1 = false; + disconnect_called_2 = false; + client_.reset(new Client(50003)); + g_connection_condition.wait_for(lk, std::chrono::seconds(1), + [&]() { return traj_point_interface_->connected.load(); }); + + traj_point_interface_->unregisterDisconnectionCallback(disconnection_callback_id_1); + client_->close(); + g_connection_condition.wait_for(lk, std::chrono::seconds(1), + [&]() { return !traj_point_interface_->connected.load(); }); + EXPECT_FALSE(disconnect_called_1); + EXPECT_TRUE(disconnect_called_2); + + // Unregister both. None should be called + disconnect_called_1 = false; + disconnect_called_2 = false; + client_.reset(new Client(50003)); + g_connection_condition.wait_for(lk, std::chrono::seconds(1), + [&]() { return traj_point_interface_->connected.load(); }); + traj_point_interface_->unregisterDisconnectionCallback(disconnection_callback_id_2); + client_->close(); + g_connection_condition.wait_for(lk, std::chrono::seconds(1), + [&]() { return !traj_point_interface_->connected.load(); }); + EXPECT_FALSE(disconnect_called_1); + EXPECT_FALSE(disconnect_called_2); +} + int main(int argc, char* argv[]) { ::testing::InitGoogleTest(&argc, argv);