diff --git a/include/envoy/server/lifecycle_notifier.h b/include/envoy/server/lifecycle_notifier.h index 32a5ae4c5da58..cd829f61b991f 100644 --- a/include/envoy/server/lifecycle_notifier.h +++ b/include/envoy/server/lifecycle_notifier.h @@ -30,6 +30,13 @@ class ServerLifecycleNotifier { ShutdownExit }; + // A handle to a callback registration. Deleting this handle will unregister the callback. + class Handle { + public: + virtual ~Handle() = default; + }; + using HandlePtr = std::unique_ptr; + /** * Callback invoked when the server reaches a certain lifecycle stage. * @@ -49,8 +56,8 @@ class ServerLifecycleNotifier { * The second version which takes a completion back is currently only supported * for the ShutdownExit stage. */ - virtual void registerCallback(Stage stage, StageCallback callback) PURE; - virtual void registerCallback(Stage stage, StageCallbackWithCompletion callback) PURE; + virtual HandlePtr registerCallback(Stage stage, StageCallback callback) PURE; + virtual HandlePtr registerCallback(Stage stage, StageCallbackWithCompletion callback) PURE; }; } // namespace Server diff --git a/source/server/config_validation/server.h b/source/server/config_validation/server.h index b492102c86d71..a63cbfd0dea6d 100644 --- a/source/server/config_validation/server.h +++ b/source/server/config_validation/server.h @@ -139,8 +139,12 @@ class ValidationInstance : Logger::Loggable, } // ServerLifecycleNotifier - void registerCallback(Stage, StageCallback) override {} - void registerCallback(Stage, StageCallbackWithCompletion) override {} + ServerLifecycleNotifier::HandlePtr registerCallback(Stage, StageCallback) override { + return nullptr; + } + ServerLifecycleNotifier::HandlePtr registerCallback(Stage, StageCallbackWithCompletion) override { + return nullptr; + } private: void initialize(const Options& options, Network::Address::InstanceConstSharedPtr local_address, diff --git a/source/server/server.cc b/source/server/server.cc index 0a039883d0f0f..55e569bdb4908 100644 --- a/source/server/server.cc +++ b/source/server/server.cc @@ -547,13 +547,19 @@ void InstanceImpl::shutdownAdmin() { restarter_.sendParentTerminateRequest(); } -void InstanceImpl::registerCallback(Stage stage, StageCallback callback) { - stage_callbacks_[stage].push_back(callback); +ServerLifecycleNotifier::HandlePtr InstanceImpl::registerCallback(Stage stage, + StageCallback callback) { + auto& callbacks = stage_callbacks_[stage]; + return absl::make_unique>( + callbacks, callbacks.insert(callbacks.end(), callback)); } -void InstanceImpl::registerCallback(Stage stage, StageCallbackWithCompletion callback) { +ServerLifecycleNotifier::HandlePtr +InstanceImpl::registerCallback(Stage stage, StageCallbackWithCompletion callback) { ASSERT(stage == Stage::ShutdownExit); - stage_completable_callbacks_[stage].push_back(callback); + auto& callbacks = stage_completable_callbacks_[stage]; + return absl::make_unique>( + callbacks, callbacks.insert(callbacks.end(), callback)); } void InstanceImpl::notifyCallbacksForStage(Stage stage, Event::PostCb completion_cb) { diff --git a/source/server/server.h b/source/server/server.h index 6b2fdf0c0cd3d..822cbd7f8a8bf 100644 --- a/source/server/server.h +++ b/source/server/server.h @@ -194,8 +194,9 @@ class InstanceImpl : Logger::Loggable, } // ServerLifecycleNotifier - void registerCallback(Stage stage, StageCallback callback) override; - void registerCallback(Stage stage, StageCallbackWithCompletion callback) override; + ServerLifecycleNotifier::HandlePtr registerCallback(Stage stage, StageCallback callback) override; + ServerLifecycleNotifier::HandlePtr + registerCallback(Stage stage, StageCallbackWithCompletion callback) override; private: ProtobufTypes::MessagePtr dumpBootstrapConfig(); @@ -260,8 +261,23 @@ class InstanceImpl : Logger::Loggable, Http::ContextImpl http_context_; std::unique_ptr heap_shrinker_; const std::thread::id main_thread_id_; - absl::flat_hash_map> stage_callbacks_; - absl::flat_hash_map> stage_completable_callbacks_; + + using LifecycleNotifierCallbacks = std::list; + using LifecycleNotifierCompletionCallbacks = std::list; + + template class LifecycleCallbackHandle : public ServerLifecycleNotifier::Handle { + public: + LifecycleCallbackHandle(T& callbacks, typename T::iterator it) + : callbacks_(callbacks), it_(it) {} + ~LifecycleCallbackHandle() override { callbacks_.erase(it_); } + + private: + T& callbacks_; + typename T::iterator it_; + }; + + absl::flat_hash_map stage_callbacks_; + absl::flat_hash_map stage_completable_callbacks_; }; } // namespace Server diff --git a/test/mocks/server/mocks.h b/test/mocks/server/mocks.h index 93ccbc954df52..26d0d62f3b80c 100644 --- a/test/mocks/server/mocks.h +++ b/test/mocks/server/mocks.h @@ -265,8 +265,9 @@ class MockServerLifecycleNotifier : public ServerLifecycleNotifier { MockServerLifecycleNotifier(); ~MockServerLifecycleNotifier(); - MOCK_METHOD2(registerCallback, void(Stage, StageCallback)); - MOCK_METHOD2(registerCallback, void(Stage, StageCallbackWithCompletion)); + MOCK_METHOD2(registerCallback, ServerLifecycleNotifier::HandlePtr(Stage, StageCallback)); + MOCK_METHOD2(registerCallback, + ServerLifecycleNotifier::HandlePtr(Stage, StageCallbackWithCompletion)); }; class MockWorkerFactory : public WorkerFactory { diff --git a/test/server/server_test.cc b/test/server/server_test.cc index c9d68e3f5a724..9ebd5790ecbe7 100644 --- a/test/server/server_test.cc +++ b/test/server/server_test.cc @@ -182,23 +182,32 @@ TEST_P(ServerInstanceImplTest, LifecycleNotifications) { // Run the server in a separate thread so we can test different lifecycle stages. auto server_thread = Thread::threadFactoryForTest().createThread([&] { initialize("test/server/node_bootstrap.yaml"); - server_->registerCallback(ServerLifecycleNotifier::Stage::Startup, [&] { + auto handle1 = server_->registerCallback(ServerLifecycleNotifier::Stage::Startup, [&] { startup = true; started.Notify(); }); - server_->registerCallback(ServerLifecycleNotifier::Stage::ShutdownExit, [&] { + auto handle2 = server_->registerCallback(ServerLifecycleNotifier::Stage::ShutdownExit, [&] { shutdown = true; shutdown_begin.Notify(); }); - server_->registerCallback(ServerLifecycleNotifier::Stage::ShutdownExit, - [&](Event::PostCb completion_cb) { - // Block till we're told to complete - completion_block.WaitForNotification(); - shutdown_with_completion = true; - server_->dispatcher().post(completion_cb); - completion_done.Notify(); - }); + auto handle3 = server_->registerCallback(ServerLifecycleNotifier::Stage::ShutdownExit, + [&](Event::PostCb completion_cb) { + // Block till we're told to complete + completion_block.WaitForNotification(); + shutdown_with_completion = true; + server_->dispatcher().post(completion_cb); + completion_done.Notify(); + }); + auto handle4 = + server_->registerCallback(ServerLifecycleNotifier::Stage::Startup, [&] { FAIL(); }); + handle4 = server_->registerCallback(ServerLifecycleNotifier::Stage::ShutdownExit, + [&](Event::PostCb) { FAIL(); }); + handle4 = nullptr; + server_->run(); + handle1 = nullptr; + handle2 = nullptr; + handle3 = nullptr; server_ = nullptr; thread_local_ = nullptr; });