diff --git a/include/envoy/network/drain_decision.h b/include/envoy/network/drain_decision.h index 5c5b60a2aac81..454bef1254cd1 100644 --- a/include/envoy/network/drain_decision.h +++ b/include/envoy/network/drain_decision.h @@ -13,7 +13,7 @@ class DrainDecision { * @return TRUE if a connection should be drained and closed. It is up to individual network * filters to determine when this should be called for the least impact possible. */ - virtual bool drainClose() PURE; + virtual bool drainClose() const PURE; }; } // namespace Network diff --git a/include/envoy/server/BUILD b/include/envoy/server/BUILD index 0188d9a4daf0f..750940fd09401 100644 --- a/include/envoy/server/BUILD +++ b/include/envoy/server/BUILD @@ -100,11 +100,11 @@ envoy_cc_library( name = "filter_config_interface", hdrs = ["filter_config.h"], deps = [ - ":drain_manager_interface", "//include/envoy/access_log:access_log_interface", "//include/envoy/http:filter_interface", "//include/envoy/init:init_interface", "//include/envoy/local_info:local_info_interface", + "//include/envoy/network:drain_decision_interface", "//include/envoy/ratelimit:ratelimit_interface", "//include/envoy/runtime:runtime_interface", "//include/envoy/thread_local:thread_local_interface", @@ -117,6 +117,7 @@ envoy_cc_library( name = "listener_manager_interface", hdrs = ["listener_manager.h"], deps = [ + ":drain_manager_interface", ":filter_config_interface", ":guarddog_interface", "//include/envoy/json:json_object_interface", diff --git a/include/envoy/server/drain_manager.h b/include/envoy/server/drain_manager.h index c28402347acb3..d9398ee00dfd3 100644 --- a/include/envoy/server/drain_manager.h +++ b/include/envoy/server/drain_manager.h @@ -8,20 +8,17 @@ namespace Envoy { namespace Server { /** - * Handles connection draining. An instance is generally shared across the entire server. + * Handles connection draining. 
This concept is used globally during hot restart / server draining + * as well as on individual listeners when they are being dynamically removed. */ class DrainManager : public Network::DrainDecision { public: /** - * @return TRUE if the manager is currently draining connections. + * Invoked to begin the drain procedure. (Making drain close operations more likely). + * @param completion supplies the completion that will be called when the drain sequence is + * finished. The parameter is optional and can be an unassigned function. */ - virtual bool draining() PURE; - - /** - * Invoked in the secondary process to begin the drain procedure. (Making drain close operations - * more likely). - */ - virtual void startDrainSequence() PURE; + virtual void startDrainSequence(std::function<void()> completion) PURE; /** * Invoked in the newly launched primary process to begin the parent shutdown sequence. At the end diff --git a/include/envoy/server/filter_config.h b/include/envoy/server/filter_config.h index 1f6ea104cec6b..41b8d20454ce5 100644 --- a/include/envoy/server/filter_config.h +++ b/include/envoy/server/filter_config.h @@ -6,10 +6,10 @@ #include "envoy/http/filter.h" #include "envoy/init/init.h" #include "envoy/json/json_object.h" +#include "envoy/network/drain_decision.h" #include "envoy/network/filter.h" #include "envoy/ratelimit/ratelimit.h" #include "envoy/runtime/runtime.h" -#include "envoy/server/drain_manager.h" #include "envoy/tracing/http_tracer.h" #include "envoy/upstream/cluster_manager.h" @@ -47,9 +47,10 @@ class FactoryContext { virtual Event::Dispatcher& dispatcher() PURE; /** - * @return DrainManager& singleton for use by the entire server. + * @return const Network::DrainDecision& a drain decision that filters can use to determine if + * they should be doing graceful closes on connections when possible. 
*/ - virtual DrainManager& drainManager() PURE; + virtual const Network::DrainDecision& drainDecision() PURE; /** * @return whether external healthchecks are currently failed or not. diff --git a/include/envoy/server/instance.h b/include/envoy/server/instance.h index b922a6f2c757b..a22f4ea4ed660 100644 --- a/include/envoy/server/instance.h +++ b/include/envoy/server/instance.h @@ -62,21 +62,15 @@ class Instance { */ virtual Network::DnsResolverSharedPtr dnsResolver() PURE; - /** - * @return TRUE if the server is currently draining. No new connections will be received and - * filters should shed connections where possible. - */ - virtual bool draining() PURE; - /** * Close the server's listening sockets and begin draining the listeners. */ virtual void drainListeners() PURE; /** - * @return DrainManager& singleton for use by the entire server. + * @return const DrainManager& singleton for use by the entire server. */ - virtual DrainManager& drainManager() PURE; + virtual const DrainManager& drainManager() PURE; /** * @return AccessLogManager for use by the entire server. diff --git a/include/envoy/server/listener_manager.h b/include/envoy/server/listener_manager.h index e3f6060ed9db1..1f9310859a7c7 100644 --- a/include/envoy/server/listener_manager.h +++ b/include/envoy/server/listener_manager.h @@ -3,6 +3,7 @@ #include "envoy/json/json_object.h" #include "envoy/network/filter.h" #include "envoy/network/listen_socket.h" +#include "envoy/server/drain_manager.h" #include "envoy/server/filter_config.h" #include "envoy/server/guarddog.h" #include "envoy/ssl/context.h" @@ -36,6 +37,11 @@ class ListenerComponentFactory { createFilterFactoryList(const std::vector& filters, Configuration::FactoryContext& context) PURE; + /** + * @return DrainManagerPtr a new drain manager. + */ + virtual DrainManagerPtr createDrainManager() PURE; + /** * @return uint64_t a listener tag usable for connection handler tracking. 
*/ diff --git a/source/common/http/conn_manager_impl.cc b/source/common/http/conn_manager_impl.cc index 5892b514c67c1..21851650122e1 100644 --- a/source/common/http/conn_manager_impl.cc +++ b/source/common/http/conn_manager_impl.cc @@ -48,7 +48,7 @@ ConnectionManagerTracingStats ConnectionManagerImpl::generateTracingStats(const } ConnectionManagerImpl::ConnectionManagerImpl(ConnectionManagerConfig& config, - Network::DrainDecision& drain_close, + const Network::DrainDecision& drain_close, Runtime::RandomGenerator& random_generator, Tracing::HttpTracer& tracer, Runtime::Loader& runtime, const LocalInfo::LocalInfo& local_info) diff --git a/source/common/http/conn_manager_impl.h b/source/common/http/conn_manager_impl.h index fd370780478a2..f12ef0dee4e74 100644 --- a/source/common/http/conn_manager_impl.h +++ b/source/common/http/conn_manager_impl.h @@ -255,7 +255,7 @@ class ConnectionManagerImpl : Logger::Loggable, public ServerConnectionCallbacks, public Network::ConnectionCallbacks { public: - ConnectionManagerImpl(ConnectionManagerConfig& config, Network::DrainDecision& drain_close, + ConnectionManagerImpl(ConnectionManagerConfig& config, const Network::DrainDecision& drain_close, Runtime::RandomGenerator& random_generator, Tracing::HttpTracer& tracer, Runtime::Loader& runtime, const LocalInfo::LocalInfo& local_info); ~ConnectionManagerImpl(); @@ -524,7 +524,7 @@ class ConnectionManagerImpl : Logger::Loggable, ServerConnectionPtr codec_; std::list streams_; Stats::TimespanPtr conn_length_; - Network::DrainDecision& drain_close_; + const Network::DrainDecision& drain_close_; DrainState drain_state_{DrainState::NotDraining}; UserAgent user_agent_; Event::TimerPtr idle_timer_; diff --git a/source/server/BUILD b/source/server/BUILD index b792962ede5b0..9eec2dd92940f 100644 --- a/source/server/BUILD +++ b/source/server/BUILD @@ -122,6 +122,7 @@ envoy_cc_library( hdrs = ["listener_manager_impl.h"], deps = [ ":configuration_lib", + ":drain_manager_lib", 
":init_manager_lib", "//include/envoy/registry", "//include/envoy/server:filter_config_interface", diff --git a/source/server/config/network/http_connection_manager.cc b/source/server/config/network/http_connection_manager.cc index 63fa6f3b9508b..88f8326a3ce8b 100644 --- a/source/server/config/network/http_connection_manager.cc +++ b/source/server/config/network/http_connection_manager.cc @@ -33,7 +33,7 @@ HttpConnectionManagerFilterConfigFactory::createFilterFactory(const Json::Object new HttpConnectionManagerConfig(config, context)); return [http_config, &context](Network::FilterManager& filter_manager) mutable -> void { filter_manager.addReadFilter(Network::ReadFilterSharedPtr{new Http::ConnectionManagerImpl( - *http_config, context.drainManager(), context.random(), context.httpTracer(), + *http_config, context.drainDecision(), context.random(), context.httpTracer(), context.runtime(), context.localInfo())}); }; } diff --git a/source/server/config_validation/server.h b/source/server/config_validation/server.h index e184cda35cabc..8b6052cb4bea1 100644 --- a/source/server/config_validation/server.h +++ b/source/server/config_validation/server.h @@ -60,7 +60,6 @@ class ValidationInstance : Logger::Loggable, Ssl::ContextManager& sslContextManager() override { return *ssl_context_manager_; } Event::Dispatcher& dispatcher() override { return *dispatcher_; } Network::DnsResolverSharedPtr dnsResolver() override { return dns_resolver_; } - bool draining() override { NOT_IMPLEMENTED; } void drainListeners() override { NOT_IMPLEMENTED; } DrainManager& drainManager() override { NOT_IMPLEMENTED; } AccessLog::AccessLogManager& accessLogManager() override { return access_log_manager_; } @@ -98,6 +97,7 @@ class ValidationInstance : Logger::Loggable, // validation mock. 
return nullptr; } + DrainManagerPtr createDrainManager() override { return nullptr; } uint64_t nextListenerTag() override { return 0; } // Server::WorkerFactory diff --git a/source/server/drain_manager_impl.cc b/source/server/drain_manager_impl.cc index 25a44e3c744e3..f5fe2506e92b5 100644 --- a/source/server/drain_manager_impl.cc +++ b/source/server/drain_manager_impl.cc @@ -15,7 +15,7 @@ namespace Server { DrainManagerImpl::DrainManagerImpl(Instance& server) : server_(server) {} -bool DrainManagerImpl::drainClose() { +bool DrainManagerImpl::drainClose() const { // If we are actively HC failed, always drain close. if (server_.healthCheckFailed()) { return true; } @@ -37,10 +37,13 @@ void DrainManagerImpl::drainSequenceTick() { if (drain_time_completed_ < server_.options().drainTime()) { drain_tick_timer_->enableTimer(std::chrono::milliseconds(1000)); + } else if (drain_sequence_completion_) { + drain_sequence_completion_(); } } -void DrainManagerImpl::startDrainSequence() { +void DrainManagerImpl::startDrainSequence(std::function<void()> completion) { + drain_sequence_completion_ = completion; ASSERT(!drain_tick_timer_); drain_tick_timer_ = server_.dispatcher().createTimer([this]() -> void { drainSequenceTick(); }); drainSequenceTick(); diff --git a/source/server/drain_manager_impl.h b/source/server/drain_manager_impl.h index c39eb7465c446..fc5ac1f1a1013 100644 --- a/source/server/drain_manager_impl.h +++ b/source/server/drain_manager_impl.h @@ -21,18 +21,19 @@ class DrainManagerImpl : Logger::Loggable<Logger::Id::main>, public DrainManager DrainManagerImpl(Instance& server); // Server::DrainManager - bool draining() override { return drain_tick_timer_ != nullptr; } - bool drainClose() override; - void startDrainSequence() override; + bool drainClose() const override; + void startDrainSequence(std::function<void()> completion) override; void startParentShutdownSequence() override; private: + bool draining() const { return drain_tick_timer_ != nullptr; } void drainSequenceTick(); Instance& server_; 
Event::TimerPtr drain_tick_timer_; std::chrono::seconds drain_time_completed_{}; Event::TimerPtr parent_shutdown_timer_; + std::function<void()> drain_sequence_completion_; }; } // namespace Server diff --git a/source/server/listener_manager_impl.cc b/source/server/listener_manager_impl.cc index 2acd7eba6d5a1..f02a80abcfe5f 100644 --- a/source/server/listener_manager_impl.cc +++ b/source/server/listener_manager_impl.cc @@ -9,6 +9,7 @@ #include "common/ssl/context_config_impl.h" #include "server/configuration_impl.h" // TODO(mattklein123): Remove post 1.4.0 +#include "server/drain_manager_impl.h" namespace Envoy { namespace Server { @@ -87,6 +88,10 @@ ProdListenerComponentFactory::createListenSocket(Network::Address::InstanceConst } } +DrainManagerPtr ProdListenerComponentFactory::createDrainManager() { + return DrainManagerPtr{new DrainManagerImpl(server_)}; +} + ListenerImpl::ListenerImpl(const Json::Object& json, ListenerManagerImpl& parent, const std::string& name, bool workers_started, uint64_t hash) : Json::Validator(json, Json::Schema::LISTENER_SCHEMA), parent_(parent), @@ -98,7 +103,8 @@ ListenerImpl::ListenerImpl(const Json::Object& json, ListenerManagerImpl& parent per_connection_buffer_limit_bytes_( json.getInteger("per_connection_buffer_limit_bytes", 1024 * 1024)), listener_tag_(parent_.factory_.nextListenerTag()), name_(name), - workers_started_(workers_started), hash_(hash) { + workers_started_(workers_started), hash_(hash), + local_drain_manager_(parent.factory_.createDrainManager()) { // ':' is a reserved char in statsd. Do the translation here to avoid costly inline translations // later. @@ -130,6 +136,13 @@ bool ListenerImpl::createFilterChain(Network::Connection& connection) { return Configuration::FilterChainUtility::buildFilterChain(connection, filter_factories_); } +bool ListenerImpl::drainClose() const { + // When a listener is draining, the "drain close" decision is the union of the per-listener drain + // manager and the server wide drain manager. 
This allows individual listeners to be drained and + // removed independently of a server-wide drain event (e.g., /healthcheck/fail or hot restart). + return local_drain_manager_->drainClose() || parent_.server_.drainManager().drainClose(); +} + void ListenerImpl::infoLog(const std::string& message) { ENVOY_LOG(info, "{}: name={}, hash={}, address={}", message, name_, hash_, address_->asString()); } @@ -261,8 +274,8 @@ bool ListenerManagerImpl::addOrUpdateListener(const Json::Object& json) { } void ListenerManagerImpl::drainListener(ListenerImplPtr&& listener) { - // TODO(mattklein123): Actually implement timed draining in a follow up. Currently we just - // correctly synchronize removal across all workers. + // First add the listener to the draining list. This must be done under the lock since it + // can race with remove completions. std::list::iterator draining_it; { std::lock_guard guard(draining_listeners_lock_); @@ -270,16 +283,30 @@ void ListenerManagerImpl::drainListener(ListenerImplPtr&& listener) { workers_.size()); } - draining_it->listener_->infoLog("removing listener"); + // Tell all workers to stop accepting new connections on this listener. + draining_it->listener_->infoLog("draining listener"); for (const auto& worker : workers_) { - worker->removeListener(*draining_it->listener_, [this, draining_it]() -> void { - std::lock_guard guard(draining_listeners_lock_); - if (--draining_it->workers_pending_removal_ == 0) { - draining_it->listener_->infoLog("listener removal complete"); - draining_listeners_.erase(draining_it); - } - }); + worker->stopListener(*draining_it->listener_); } + + // The following sets up 2 level lambda. The first completes when the listener's drain manager + // has completed draining at whatever the server configured drain times are. Once the drain time + // has completed via the drain manager's timer, we tell the workers to remove the listener. 
The + // 2nd lambda acquires the lock and determines when we can remove the listener from the draining + // list. This makes sure that we don't destroy the listener while filters might still be using its + // context (stats, etc.). + draining_it->listener_->localDrainManager().startDrainSequence([this, draining_it]() -> void { + draining_it->listener_->infoLog("removing listener"); + for (const auto& worker : workers_) { + worker->removeListener(*draining_it->listener_, [this, draining_it]() -> void { + std::lock_guard guard(draining_listeners_lock_); + if (--draining_it->workers_pending_removal_ == 0) { + draining_it->listener_->infoLog("listener removal complete"); + draining_listeners_.erase(draining_it); + } + }); + } + }); } ListenerManagerImpl::ListenerList::iterator @@ -306,6 +333,12 @@ std::vector> ListenerManagerImpl::listeners() { } void ListenerManagerImpl::onListenerWarmed(ListenerImpl& listener) { + // The warmed listener should be added first so that the worker will accept new connections + // when it stops listening on the old listener. + for (const auto& worker : workers_) { + worker->addListener(listener); + } + auto existing_active_listener = getListenerByName(active_listeners_, listener.name()); auto existing_warming_listener = getListenerByName(warming_listeners_, listener.name()); (*existing_warming_listener)->infoLog("warm complete. 
updating active listener"); @@ -317,10 +350,6 @@ void ListenerManagerImpl::onListenerWarmed(ListenerImpl& listener) { } warming_listeners_.erase(existing_warming_listener); - - for (const auto& worker : workers_) { - worker->addListener(listener); - } } uint64_t ListenerManagerImpl::numConnections() { diff --git a/source/server/listener_manager_impl.h b/source/server/listener_manager_impl.h index a19cb4ce4f19d..94d5497bfadda 100644 --- a/source/server/listener_manager_impl.h +++ b/source/server/listener_manager_impl.h @@ -38,6 +38,7 @@ class ProdListenerComponentFactory : public ListenerComponentFactory, } Network::ListenSocketSharedPtr createListenSocket(Network::Address::InstanceConstSharedPtr address, bool bind_to_port) override; + DrainManagerPtr createDrainManager() override; uint64_t nextListenerTag() override { return next_listener_tag_++; } private: @@ -96,8 +97,16 @@ class ListenerManagerImpl : public ListenerManager, Logger::Loggable draining_listeners_; std::list workers_; bool workers_started_{}; @@ -106,14 +115,18 @@ class ListenerManagerImpl : public ListenerManager, Logger::Loggable { @@ -131,12 +144,13 @@ class ListenerImpl : public Listener, bool workers_started, uint64_t hash); ~ListenerImpl(); - Network::Address::InstanceConstSharedPtr address() { return address_; } - const Network::ListenSocketSharedPtr& getSocket() { return socket_; } - uint64_t hash() { return hash_; } + Network::Address::InstanceConstSharedPtr address() const { return address_; } + const Network::ListenSocketSharedPtr& getSocket() const { return socket_; } + uint64_t hash() const { return hash_; } void infoLog(const std::string& message); void initialize(); - const std::string& name() { return name_; } + DrainManager& localDrainManager() const { return *local_drain_manager_; } + const std::string& name() const { return name_; } void setSocket(const Network::ListenSocketSharedPtr& socket); // Server::Listener @@ -156,7 +170,7 @@ class ListenerImpl : public Listener, } 
Upstream::ClusterManager& clusterManager() override { return parent_.server_.clusterManager(); } Event::Dispatcher& dispatcher() override { return parent_.server_.dispatcher(); } - DrainManager& drainManager() override { return parent_.server_.drainManager(); } + Network::DrainDecision& drainDecision() override { return *this; } bool healthCheckFailed() override { return parent_.server_.healthCheckFailed(); } Tracing::HttpTracer& httpTracer() override { return parent_.server_.httpTracer(); } Init::Manager& initManager() override; @@ -171,6 +185,9 @@ class ListenerImpl : public Listener, Stats::Scope& scope() override { return *global_scope_; } ThreadLocal::Instance& threadLocal() override { return parent_.server_.threadLocal(); } + // Network::DrainDecision + bool drainClose() const override; + // Network::FilterChainFactory bool createFilterChain(Network::Connection& connection) override; @@ -192,6 +209,7 @@ class ListenerImpl : public Listener, InitManagerImpl dynamic_init_manager_; bool initialize_canceled_{}; std::vector filter_factories_; + DrainManagerPtr local_drain_manager_; }; } // namespace Server diff --git a/source/server/server.cc b/source/server/server.cc index 7203680d7776e..cb129726f40ed 100644 --- a/source/server/server.cc +++ b/source/server/server.cc @@ -77,7 +77,7 @@ Tracing::HttpTracer& InstanceImpl::httpTracer() { return config_->httpTracer(); void InstanceImpl::drainListeners() { ENVOY_LOG(warn, "closing and draining listeners"); listener_manager_->stopListeners(); - drain_manager_->startDrainSequence(); + drain_manager_->startDrainSequence(nullptr); } void InstanceImpl::failHealthcheck(bool fail) { diff --git a/source/server/server.h b/source/server/server.h index 29f6b040c616c..d5100574aee69 100644 --- a/source/server/server.h +++ b/source/server/server.h @@ -99,7 +99,6 @@ class InstanceImpl : Logger::Loggable, public Instance { Ssl::ContextManager& sslContextManager() override { return *ssl_context_manager_; } Event::Dispatcher& 
dispatcher() override { return *dispatcher_; } Network::DnsResolverSharedPtr dnsResolver() override { return dns_resolver_; } - bool draining() override { return drain_manager_->draining(); } void drainListeners() override; DrainManager& drainManager() override { return *drain_manager_; } AccessLog::AccessLogManager& accessLogManager() override { return access_log_manager_; } diff --git a/test/integration/echo_integration_test.cc b/test/integration/echo_integration_test.cc index 8b3881aa3901b..e37d5ad08343d 100644 --- a/test/integration/echo_integration_test.cc +++ b/test/integration/echo_integration_test.cc @@ -56,14 +56,17 @@ TEST_P(EchoIntegrationTest, AddRemoveListener) { )EOF"; // Add the listener. - ConditionalInitializer listener_added; + ConditionalInitializer listener_added_by_worker; + ConditionalInitializer listener_added_by_manager; test_server_->setOnWorkerListenerAddedCb( - [&listener_added]() -> void { listener_added.setReady(); }); + [&listener_added_by_worker]() -> void { listener_added_by_worker.setReady(); }); Json::ObjectSharedPtr loader = TestEnvironment::jsonLoadFromString(json, GetParam()); - test_server_->server().dispatcher().post([this, loader]() -> void { + test_server_->server().dispatcher().post([this, loader, &listener_added_by_manager]() -> void { EXPECT_TRUE(test_server_->server().listenerManager().addOrUpdateListener(*loader)); + listener_added_by_manager.setReady(); }); - listener_added.waitReady(); + listener_added_by_worker.waitReady(); + listener_added_by_manager.waitReady(); EXPECT_EQ(2UL, test_server_->server().listenerManager().listeners().size()); uint32_t new_listener_port = test_server_->server() diff --git a/test/integration/server.h b/test/integration/server.h index e0f32313bfc6f..28af8529ebb32 100644 --- a/test/integration/server.h +++ b/test/integration/server.h @@ -37,9 +37,9 @@ class TestOptionsImpl : public Options { const std::string& configPath() override { return config_path_; } const std::string& 
adminAddressPath() override { return admin_address_path_; } Network::Address::IpVersion localAddressIpVersion() override { return local_address_ip_version_; } - std::chrono::seconds drainTime() override { return std::chrono::seconds(0); } + std::chrono::seconds drainTime() override { return std::chrono::seconds(1); } spdlog::level::level_enum logLevel() override { NOT_IMPLEMENTED; } - std::chrono::seconds parentShutdownTime() override { return std::chrono::seconds(0); } + std::chrono::seconds parentShutdownTime() override { return std::chrono::seconds(2); } uint64_t restartEpoch() override { return 0; } std::chrono::milliseconds fileFlushIntervalMsec() override { return std::chrono::milliseconds(10000); } @@ -55,9 +55,8 @@ class TestOptionsImpl : public Options { class TestDrainManager : public DrainManager { public: // Server::DrainManager - bool drainClose() override { return draining_; } - bool draining() override { return draining_; } - void startDrainSequence() override {} + bool drainClose() const override { return draining_; } + void startDrainSequence(std::function<void()>) override {} void startParentShutdownSequence() override {} bool draining_{}; diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index b34887a30fc07..9fe0e119f4077 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -185,7 +185,7 @@ class MockDrainDecision : public DrainDecision { MockDrainDecision(); ~MockDrainDecision(); - MOCK_METHOD0(drainClose, bool()); + MOCK_CONST_METHOD0(drainClose, bool()); }; class MockFilterChainFactory : public FilterChainFactory { diff --git a/test/mocks/server/mocks.cc b/test/mocks/server/mocks.cc index 7a083a233bd9a..335cc2ad429c6 100644 --- a/test/mocks/server/mocks.cc +++ b/test/mocks/server/mocks.cc @@ -25,7 +25,9 @@ MockOptions::~MockOptions() {} MockAdmin::MockAdmin() {} MockAdmin::~MockAdmin() {} -MockDrainManager::MockDrainManager() {} +MockDrainManager::MockDrainManager() { + ON_CALL(*this, 
startDrainSequence(_)).WillByDefault(SaveArg<0>(&drain_sequence_completion_)); +} MockDrainManager::~MockDrainManager() {} MockWatchDog::MockWatchDog() {} @@ -104,7 +106,7 @@ MockFactoryContext::MockFactoryContext() { ON_CALL(*this, accessLogManager()).WillByDefault(ReturnRef(access_log_manager_)); ON_CALL(*this, clusterManager()).WillByDefault(ReturnRef(cluster_manager_)); ON_CALL(*this, dispatcher()).WillByDefault(ReturnRef(dispatcher_)); - ON_CALL(*this, drainManager()).WillByDefault(ReturnRef(drain_manager_)); + ON_CALL(*this, drainDecision()).WillByDefault(ReturnRef(drain_manager_)); ON_CALL(*this, httpTracer()).WillByDefault(ReturnRef(http_tracer_)); ON_CALL(*this, initManager()).WillByDefault(ReturnRef(init_manager_)); ON_CALL(*this, localInfo()).WillByDefault(ReturnRef(local_info_)); diff --git a/test/mocks/server/mocks.h b/test/mocks/server/mocks.h index fabd9bda97fc0..ec376c49011eb 100644 --- a/test/mocks/server/mocks.h +++ b/test/mocks/server/mocks.h @@ -75,10 +75,11 @@ class MockDrainManager : public DrainManager { ~MockDrainManager(); // Server::DrainManager - MOCK_METHOD0(drainClose, bool()); - MOCK_METHOD0(draining, bool()); - MOCK_METHOD0(startDrainSequence, void()); + MOCK_CONST_METHOD0(drainClose, bool()); + MOCK_METHOD1(startDrainSequence, void(std::function<void()> completion)); MOCK_METHOD0(startParentShutdownSequence, void()); + + std::function<void()> drain_sequence_completion_; }; class MockWatchDog : public WatchDog { @@ -126,12 +127,15 @@ class MockListenerComponentFactory : public ListenerComponentFactory { MockListenerComponentFactory(); ~MockListenerComponentFactory(); + DrainManagerPtr createDrainManager() override { return DrainManagerPtr{createDrainManager_()}; } + MOCK_METHOD2(createFilterFactoryList, std::vector<Configuration::NetworkFilterFactoryCb>( const std::vector<Json::ObjectSharedPtr>& filters, Configuration::FactoryContext& context)); MOCK_METHOD2(createListenSocket, Network::ListenSocketSharedPtr(Network::Address::InstanceConstSharedPtr address, bool bind_to_port)); 
MOCK_METHOD0(createDrainManager_, DrainManager*()); MOCK_METHOD0(nextListenerTag, uint64_t()); std::shared_ptr socket_; @@ -221,7 +225,6 @@ class MockInstance : public Instance { MOCK_METHOD0(sslContextManager, Ssl::ContextManager&()); MOCK_METHOD0(dispatcher, Event::Dispatcher&()); MOCK_METHOD0(dnsResolver, Network::DnsResolverSharedPtr()); - MOCK_METHOD0(draining, bool()); MOCK_METHOD0(drainListeners, void()); MOCK_METHOD0(drainManager, DrainManager&()); MOCK_METHOD0(accessLogManager, AccessLog::AccessLogManager&()); @@ -303,7 +306,7 @@ class MockFactoryContext : public FactoryContext { MOCK_METHOD0(accessLogManager, AccessLog::AccessLogManager&()); MOCK_METHOD0(clusterManager, Upstream::ClusterManager&()); MOCK_METHOD0(dispatcher, Event::Dispatcher&()); - MOCK_METHOD0(drainManager, DrainManager&()); + MOCK_METHOD0(drainDecision, const Network::DrainDecision&()); MOCK_METHOD0(healthCheckFailed, bool()); MOCK_METHOD0(httpTracer, Tracing::HttpTracer&()); MOCK_METHOD0(initManager, Init::Manager&()); diff --git a/test/server/drain_manager_impl_test.cc b/test/server/drain_manager_impl_test.cc index d4b4d7ed79cc1..284aff67025e1 100644 --- a/test/server/drain_manager_impl_test.cc +++ b/test/server/drain_manager_impl_test.cc @@ -7,14 +7,17 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -namespace Envoy { +using testing::InSequence; using testing::Return; using testing::SaveArg; using testing::_; +namespace Envoy { namespace Server { TEST(DrainManagerImplTest, All) { + InSequence s; + NiceMock server; ON_CALL(server.options_, drainTime()).WillByDefault(Return(std::chrono::seconds(600))); ON_CALL(server.options_, parentShutdownTime()).WillByDefault(Return(std::chrono::seconds(900))); @@ -37,12 +40,15 @@ TEST(DrainManagerImplTest, All) { // Test drain sequence. 
Event::MockTimer* drain_timer = new Event::MockTimer(&server.dispatcher_); EXPECT_CALL(*drain_timer, enableTimer(_)); - drain_manager.startDrainSequence(); + ReadyWatcher drain_complete; + drain_manager.startDrainSequence([&drain_complete]() -> void { drain_complete.ready(); }); // 600s which is the default drain time. for (size_t i = 0; i < 599; i++) { if (i < 598) { EXPECT_CALL(*drain_timer, enableTimer(_)); + } else { + EXPECT_CALL(drain_complete, ready()); } drain_timer->callback_(); } diff --git a/test/server/listener_manager_impl_test.cc b/test/server/listener_manager_impl_test.cc index eaecbc80bdbc7..f2fac3f2291cb 100644 --- a/test/server/listener_manager_impl_test.cc +++ b/test/server/listener_manager_impl_test.cc @@ -23,11 +23,14 @@ namespace Server { class ListenerHandle { public: + ListenerHandle() { EXPECT_CALL(*drain_manager_, startParentShutdownSequence()).Times(0); } ~ListenerHandle() { onDestroy(); } MOCK_METHOD0(onDestroy, void()); Init::MockTarget target_; + MockDrainManager* drain_manager_ = new MockDrainManager(); + Configuration::FactoryContext* context_{}; }; class ListenerManagerImplTest : public testing::Test { @@ -38,17 +41,22 @@ class ListenerManagerImplTest : public testing::Test { } /** - * This routing sets up an expectation that does two things: + * This routing sets up an expectation that does various things: * 1) Allows us to track listener destruction via filter factory destruction. * 2) Allows us to register for init manager handling much like RDS, etc. would do. + * 3) Stores the factory context for later use. + * 4) Creates a mock local drain manager for the listener. 
*/ - ListenerHandle* expectFilterFactoryCreate(bool need_init) { + ListenerHandle* expectListenerCreate(bool need_init) { ListenerHandle* raw_listener = new ListenerHandle(); + EXPECT_CALL(listener_factory_, createDrainManager_()) + .WillOnce(Return(raw_listener->drain_manager_)); EXPECT_CALL(listener_factory_, createFilterFactoryList(_, _)) .WillOnce(Invoke([raw_listener, need_init](const std::vector&, - Server::Configuration::FactoryContext& context) - -> std::vector { + Configuration::FactoryContext& context) + -> std::vector { std::shared_ptr notifier(raw_listener); + raw_listener->context_ = &context; if (need_init) { context.initManager().registerTarget(notifier->target_); } @@ -72,10 +80,9 @@ class ListenerManagerImplWithRealFiltersTest : public ListenerManagerImplTest { // Use real filter loading by default. ON_CALL(listener_factory_, createFilterFactoryList(_, _)) .WillByDefault(Invoke([this](const std::vector& filters, - Server::Configuration::FactoryContext& context) - -> std::vector { - return Server::ProdListenerComponentFactory::createFilterFactoryList_(filters, server_, - context); + Configuration::FactoryContext& context) + -> std::vector { + return ProdListenerComponentFactory::createFilterFactoryList_(filters, server_, context); })); } }; @@ -217,7 +224,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, BadFilterType) { class TestStatsConfigFactory : public Configuration::NamedNetworkFilterConfigFactory { public: - // Server::Configuration::NamedNetworkFilterConfigFactory + // Configuration::NamedNetworkFilterConfigFactory Configuration::NetworkFilterFactoryCb createFilterFactory(const Json::Object&, Configuration::FactoryContext& context) override { context.scope().counter("bar").inc(); @@ -264,7 +271,7 @@ class TestDeprecatedEchoConfigFactory : public Configuration::NetworkFilterConfi // NetworkFilterConfigFactory Configuration::NetworkFilterFactoryCb tryCreateFilterFactory(Configuration::NetworkFilterType type, const std::string& name, - const 
Json::Object&, Server::Instance&) override { + const Json::Object&, Instance&) override { if (type != Configuration::NetworkFilterType::Read || name != "echo_deprecated") { return nullptr; } @@ -310,7 +317,7 @@ TEST_F(ListenerManagerImplTest, AddListenerAddressNotMatching) { )EOF"; Json::ObjectSharedPtr loader = Json::Factory::loadFromString(listener_foo_json); - ListenerHandle* listener_foo = expectFilterFactoryCreate(false); + ListenerHandle* listener_foo = expectListenerCreate(false); EXPECT_CALL(listener_factory_, createListenSocket(_, true)); EXPECT_TRUE(manager_->addOrUpdateListener(*loader)); @@ -324,7 +331,7 @@ TEST_F(ListenerManagerImplTest, AddListenerAddressNotMatching) { )EOF"; loader = Json::Factory::loadFromString(listener_foo_different_address_json); - ListenerHandle* listener_foo_different_address = expectFilterFactoryCreate(false); + ListenerHandle* listener_foo_different_address = expectListenerCreate(false); EXPECT_CALL(*listener_foo_different_address, onDestroy()); EXPECT_THROW_WITH_MESSAGE(manager_->addOrUpdateListener(*loader), EnvoyException, "error updating listener: 'foo' has a different address " @@ -346,7 +353,7 @@ TEST_F(ListenerManagerImplTest, AddOrUpdateListener) { )EOF"; Json::ObjectSharedPtr loader = Json::Factory::loadFromString(listener_foo_json); - ListenerHandle* listener_foo = expectFilterFactoryCreate(false); + ListenerHandle* listener_foo = expectListenerCreate(false); EXPECT_CALL(listener_factory_, createListenSocket(_, true)); EXPECT_TRUE(manager_->addOrUpdateListener(*loader)); @@ -365,7 +372,7 @@ TEST_F(ListenerManagerImplTest, AddOrUpdateListener) { )EOF"; loader = Json::Factory::loadFromString(listener_foo_update1_json); - ListenerHandle* listener_foo_update1 = expectFilterFactoryCreate(false); + ListenerHandle* listener_foo_update1 = expectListenerCreate(false); EXPECT_CALL(*listener_foo, onDestroy()); EXPECT_TRUE(manager_->addOrUpdateListener(*loader)); @@ -380,10 +387,13 @@ TEST_F(ListenerManagerImplTest, 
AddOrUpdateListener) { // Update foo. Should go into warming, have an immediate warming callback, and start immediate // removal. loader = Json::Factory::loadFromString(listener_foo_json); - ListenerHandle* listener_foo_update2 = expectFilterFactoryCreate(false); - EXPECT_CALL(*worker_, removeListener(_, _)); + ListenerHandle* listener_foo_update2 = expectListenerCreate(false); EXPECT_CALL(*worker_, addListener(_)); + EXPECT_CALL(*worker_, stopListener(_)); + EXPECT_CALL(*listener_foo_update1->drain_manager_, startDrainSequence(_)); EXPECT_TRUE(manager_->addOrUpdateListener(*loader)); + EXPECT_CALL(*worker_, removeListener(_, _)); + listener_foo_update1->drain_manager_->drain_sequence_completion_(); EXPECT_CALL(*listener_foo_update1, onDestroy()); worker_->callRemovalCompletion(); @@ -397,7 +407,7 @@ TEST_F(ListenerManagerImplTest, AddOrUpdateListener) { )EOF"; loader = Json::Factory::loadFromString(listener_bar_json); - ListenerHandle* listener_bar = expectFilterFactoryCreate(false); + ListenerHandle* listener_bar = expectListenerCreate(false); EXPECT_CALL(listener_factory_, createListenSocket(_, true)); EXPECT_CALL(*worker_, addListener(_)); EXPECT_TRUE(manager_->addOrUpdateListener(*loader)); @@ -413,7 +423,7 @@ TEST_F(ListenerManagerImplTest, AddOrUpdateListener) { )EOF"; loader = Json::Factory::loadFromString(listener_baz_json); - ListenerHandle* listener_baz = expectFilterFactoryCreate(true); + ListenerHandle* listener_baz = expectListenerCreate(true); EXPECT_CALL(listener_factory_, createListenSocket(_, true)); EXPECT_CALL(listener_baz->target_, initialize(_)); EXPECT_TRUE(manager_->addOrUpdateListener(*loader)); @@ -434,7 +444,7 @@ TEST_F(ListenerManagerImplTest, AddOrUpdateListener) { )EOF"; loader = Json::Factory::loadFromString(listener_baz_update1_json); - ListenerHandle* listener_baz_update1 = expectFilterFactoryCreate(true); + ListenerHandle* listener_baz_update1 = expectListenerCreate(true); EXPECT_CALL(*listener_baz, 
onDestroy()).WillOnce(Invoke([listener_baz]() -> void { // Call the initialize callback during destruction like RDS will. listener_baz->target_.callback_(); @@ -473,18 +483,21 @@ TEST_F(ListenerManagerImplTest, AddDrainingListener) { ON_CALL(*listener_factory_.socket_, localAddress()).WillByDefault(Return(local_address)); Json::ObjectSharedPtr loader = Json::Factory::loadFromString(listener_foo_json); - ListenerHandle* listener_foo = expectFilterFactoryCreate(false); + ListenerHandle* listener_foo = expectListenerCreate(false); EXPECT_CALL(listener_factory_, createListenSocket(_, true)); EXPECT_CALL(*worker_, addListener(_)); EXPECT_TRUE(manager_->addOrUpdateListener(*loader)); // Remove foo into draining. - EXPECT_CALL(*worker_, removeListener(_, _)); + EXPECT_CALL(*worker_, stopListener(_)); + EXPECT_CALL(*listener_foo->drain_manager_, startDrainSequence(_)); EXPECT_TRUE(manager_->removeListener("foo")); + EXPECT_CALL(*worker_, removeListener(_, _)); + listener_foo->drain_manager_->drain_sequence_completion_(); // Add foo again. We should use the socket from draining. 
loader = Json::Factory::loadFromString(listener_foo_json); - ListenerHandle* listener_foo2 = expectFilterFactoryCreate(false); + ListenerHandle* listener_foo2 = expectListenerCreate(false); EXPECT_CALL(*worker_, addListener(_)); EXPECT_TRUE(manager_->addOrUpdateListener(*loader)); @@ -509,13 +522,57 @@ TEST_F(ListenerManagerImplTest, CantBindSocket) { )EOF"; Json::ObjectSharedPtr loader = Json::Factory::loadFromString(listener_foo_json); - ListenerHandle* listener_foo = expectFilterFactoryCreate(true); + ListenerHandle* listener_foo = expectListenerCreate(true); EXPECT_CALL(listener_factory_, createListenSocket(_, true)) .WillOnce(Throw(EnvoyException("can't bind"))); EXPECT_CALL(*listener_foo, onDestroy()); EXPECT_THROW(manager_->addOrUpdateListener(*loader), EnvoyException); } +TEST_F(ListenerManagerImplTest, ListenerDraining) { + InSequence s; + + EXPECT_CALL(*worker_, start(_)); + manager_->startWorkers(guard_dog_); + + std::string listener_foo_json = R"EOF( + { + "name": "foo", + "address": "tcp://127.0.0.1:1234", + "filters": [] + } + )EOF"; + + Json::ObjectSharedPtr loader = Json::Factory::loadFromString(listener_foo_json); + ListenerHandle* listener_foo = expectListenerCreate(false); + EXPECT_CALL(listener_factory_, createListenSocket(_, true)); + EXPECT_CALL(*worker_, addListener(_)); + EXPECT_TRUE(manager_->addOrUpdateListener(*loader)); + + EXPECT_CALL(*listener_foo->drain_manager_, drainClose()).WillOnce(Return(false)); + EXPECT_CALL(server_.drain_manager_, drainClose()).WillOnce(Return(false)); + EXPECT_FALSE(listener_foo->context_->drainDecision().drainClose()); + + EXPECT_CALL(*worker_, stopListener(_)); + EXPECT_CALL(*listener_foo->drain_manager_, startDrainSequence(_)); + EXPECT_TRUE(manager_->removeListener("foo")); + + // NOTE: || short circuit here prevents the server drain manager from getting called. 
+ EXPECT_CALL(*listener_foo->drain_manager_, drainClose()).WillOnce(Return(true)); + EXPECT_TRUE(listener_foo->context_->drainDecision().drainClose()); + + EXPECT_CALL(*worker_, removeListener(_, _)); + listener_foo->drain_manager_->drain_sequence_completion_(); + + EXPECT_CALL(*listener_foo->drain_manager_, drainClose()).WillOnce(Return(false)); + EXPECT_CALL(server_.drain_manager_, drainClose()).WillOnce(Return(true)); + EXPECT_TRUE(listener_foo->context_->drainDecision().drainClose()); + + EXPECT_CALL(*listener_foo, onDestroy()); + worker_->callRemovalCompletion(); + EXPECT_EQ(0UL, manager_->listeners().size()); +} + TEST_F(ListenerManagerImplTest, RemoveListener) { InSequence s; @@ -535,7 +592,7 @@ TEST_F(ListenerManagerImplTest, RemoveListener) { )EOF"; Json::ObjectSharedPtr loader = Json::Factory::loadFromString(listener_foo_json); - ListenerHandle* listener_foo = expectFilterFactoryCreate(true); + ListenerHandle* listener_foo = expectListenerCreate(true); EXPECT_CALL(listener_factory_, createListenSocket(_, true)); EXPECT_CALL(listener_foo->target_, initialize(_)); EXPECT_TRUE(manager_->addOrUpdateListener(*loader)); @@ -547,7 +604,7 @@ TEST_F(ListenerManagerImplTest, RemoveListener) { EXPECT_EQ(0UL, manager_->listeners().size()); // Add foo again and initialize it. 
- listener_foo = expectFilterFactoryCreate(true); + listener_foo = expectListenerCreate(true); EXPECT_CALL(listener_factory_, createListenSocket(_, true)); EXPECT_CALL(listener_foo->target_, initialize(_)); EXPECT_TRUE(manager_->addOrUpdateListener(*loader)); @@ -567,15 +624,18 @@ TEST_F(ListenerManagerImplTest, RemoveListener) { )EOF"; loader = Json::Factory::loadFromString(listener_foo_update1_json); - ListenerHandle* listener_foo_update1 = expectFilterFactoryCreate(true); + ListenerHandle* listener_foo_update1 = expectListenerCreate(true); EXPECT_CALL(listener_foo_update1->target_, initialize(_)); EXPECT_TRUE(manager_->addOrUpdateListener(*loader)); EXPECT_EQ(1UL, manager_->listeners().size()); // Remove foo which should remove both warming and active. EXPECT_CALL(*listener_foo_update1, onDestroy()); - EXPECT_CALL(*worker_, removeListener(_, _)); + EXPECT_CALL(*worker_, stopListener(_)); + EXPECT_CALL(*listener_foo->drain_manager_, startDrainSequence(_)); EXPECT_TRUE(manager_->removeListener("foo")); + EXPECT_CALL(*worker_, removeListener(_, _)); + listener_foo->drain_manager_->drain_sequence_completion_(); EXPECT_CALL(*listener_foo, onDestroy()); worker_->callRemovalCompletion(); EXPECT_EQ(0UL, manager_->listeners().size());