diff --git a/test/common/stats/BUILD b/test/common/stats/BUILD index 3427be0dcce23..41f27fac08eac 100644 --- a/test/common/stats/BUILD +++ b/test/common/stats/BUILD @@ -273,6 +273,7 @@ envoy_cc_test( "//test/mocks/stats:stats_mocks", "//test/mocks/thread_local:thread_local_mocks", "//test/test_common:logging_lib", + "//test/test_common:real_threads_test_helper_lib", "//test/test_common:test_time_lib", "//test/test_common:utility_lib", "@envoy_api//envoy/config/metrics/v3:pkg_cc_proto", diff --git a/test/common/stats/thread_local_store_test.cc b/test/common/stats/thread_local_store_test.cc index 5bf5e67395437..0847d9605df29 100644 --- a/test/common/stats/thread_local_store_test.cc +++ b/test/common/stats/thread_local_store_test.cc @@ -20,11 +20,10 @@ #include "test/mocks/stats/mocks.h" #include "test/mocks/thread_local/mocks.h" #include "test/test_common/logging.h" +#include "test/test_common/real_threads_test_helper.h" #include "test/test_common/utility.h" #include "absl/strings/str_split.h" -#include "absl/synchronization/blocking_counter.h" -#include "absl/synchronization/notification.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -1565,75 +1564,36 @@ TEST_F(HistogramTest, ParentHistogramBucketSummary) { "B3.6e+06(1,1)", parent_histogram->bucketSummary()); } - -class ThreadLocalRealThreadsTestBase : public ThreadLocalStoreNoMocksTestBase { +class ThreadLocalRealThreadsTestBase : public Thread::RealThreadsTestHelper, + public ThreadLocalStoreNoMocksTestBase { protected: static constexpr uint32_t NumScopes = 1000; static constexpr uint32_t NumIters = 35; - // Helper class to block on a number of multi-threaded operations occurring. - class BlockingBarrier { - public: - explicit BlockingBarrier(uint32_t count) : blocking_counter_(count) {} - ~BlockingBarrier() { blocking_counter_.Wait(); } - - /** - * Returns a function that first executes 'f', and then decrements the count - * toward unblocking the scope. This is intended to be used as a post() callback. - * - * @param f the function to run prior to decrementing the count. - */ - std::function run(std::function f) { - return [this, f]() { - f(); - decrementCount(); - }; - } - - /** - * @return a function that, when run, decrements the count, intended for passing to post(). - */ - std::function decrementCountFn() { - return [this] { decrementCount(); }; - } - - void decrementCount() { blocking_counter_.DecrementCount(); } - - private: - absl::BlockingCounter blocking_counter_; - }; - +public: ThreadLocalRealThreadsTestBase(uint32_t num_threads) - : num_threads_(num_threads), api_(Api::createApiForTest()), - thread_factory_(api_->threadFactory()), pool_(store_->symbolTable()) { - // This is the same order as InstanceImpl::initialize in source/server/server.cc. - thread_dispatchers_.resize(num_threads_); - { - BlockingBarrier blocking_barrier(num_threads_ + 1); - main_thread_ = thread_factory_.createThread( - [this, &blocking_barrier]() { mainThreadFn(blocking_barrier); }); - for (uint32_t i = 0; i < num_threads_; ++i) { - threads_.emplace_back(thread_factory_.createThread( - [this, i, &blocking_barrier]() { workerThreadFn(i, blocking_barrier); })); - } - } - - { - BlockingBarrier blocking_barrier(1); - main_dispatcher_->post(blocking_barrier.run([this]() { - tls_ = std::make_unique(); - tls_->registerThread(*main_dispatcher_, true); - for (Event::DispatcherPtr& dispatcher : thread_dispatchers_) { - // Worker threads must be registered from the main thread, per assert in registerThread(). - tls_->registerThread(*dispatcher, false); - } - store_->initializeThreading(*main_dispatcher_, *tls_); - })); - } + : RealThreadsTestHelper(num_threads), pool_(store_->symbolTable()) { + runOnMainBlocking([this]() { store_->initializeThreading(*main_dispatcher_, *tls_); }); } ~ThreadLocalRealThreadsTestBase() override { + // TODO(chaoqin-li1123): clean this up when we figure out how to free the threading resources in + // RealThreadsTestHelper. shutdownThreading(); + exitThreads(); + } + + void shutdownThreading() { + runOnMainBlocking([this]() { + if (!tls_->isShutdown()) { + tls_->shutdownGlobalThreading(); + } + store_->shutdownThreading(); + tls_->shutdownThread(); + }); + } + + void exitThreads() { for (Event::DispatcherPtr& dispatcher : thread_dispatchers_) { dispatcher->post([&dispatcher]() { dispatcher->exit(); }); } @@ -1650,52 +1610,6 @@ class ThreadLocalRealThreadsTestBase : public ThreadLocalStoreNoMocksTestBase { main_thread_->join(); } - void shutdownThreading() { - BlockingBarrier blocking_barrier(1); - main_dispatcher_->post(blocking_barrier.run([this]() { - if (!tls_->isShutdown()) { - tls_->shutdownGlobalThreading(); - } - store_->shutdownThreading(); - tls_->shutdownThread(); - })); - } - - void workerThreadFn(uint32_t thread_index, BlockingBarrier& blocking_barrier) { - thread_dispatchers_[thread_index] = - api_->allocateDispatcher(absl::StrCat("test_worker_", thread_index)); - blocking_barrier.decrementCount(); - thread_dispatchers_[thread_index]->run(Event::Dispatcher::RunType::RunUntilExit); - } - - void mainThreadFn(BlockingBarrier& blocking_barrier) { - main_dispatcher_ = api_->allocateDispatcher("test_main_thread"); - blocking_barrier.decrementCount(); - main_dispatcher_->run(Event::Dispatcher::RunType::RunUntilExit); - } - - void mainDispatchBlock() { - // To ensure all stats are freed we have to wait for a few posts() to clear. - // First, wait for the main-dispatcher to initiate the cross-thread TLS cleanup. - BlockingBarrier blocking_barrier(1); - main_dispatcher_->post(blocking_barrier.run([]() {})); - } - - void tlsBlock() { - BlockingBarrier blocking_barrier(num_threads_); - for (Event::DispatcherPtr& thread_dispatcher : thread_dispatchers_) { - thread_dispatcher->post(blocking_barrier.run([]() {})); - } - } - - const uint32_t num_threads_; - Api::ApiPtr api_; - Event::DispatcherPtr main_dispatcher_; - std::vector thread_dispatchers_; - Thread::ThreadFactory& thread_factory_; - ThreadLocal::InstanceImplPtr tls_; - Thread::ThreadPtr main_thread_; - std::vector threads_; StatNamePool pool_; }; @@ -1717,11 +1631,8 @@ class ClusterShutdownCleanupStarvationTest : public ThreadLocalRealThreadsTestBa } void createScopesIncCountersAndCleanupAllThreads() { - BlockingBarrier blocking_barrier(NumThreads); - for (Event::DispatcherPtr& thread_dispatcher : thread_dispatchers_) { - thread_dispatcher->post( - blocking_barrier.run([this]() { createScopesIncCountersAndCleanup(); })); - } + + runOnAllWorkersBlocking([this]() { createScopesIncCountersAndCleanup(); }); } std::chrono::seconds elapsedTime() { @@ -1795,7 +1706,7 @@ class HistogramThreadTest : public ThreadLocalRealThreadsTestBase { void mergeHistograms() { BlockingBarrier blocking_barrier(1); - main_dispatcher_->post([this, &blocking_barrier]() { + runOnMainBlocking([this, &blocking_barrier]() { store_->mergeHistograms(blocking_barrier.decrementCountFn()); }); } @@ -1804,7 +1715,7 @@ class HistogramThreadTest : public ThreadLocalRealThreadsTestBase { uint32_t num; { BlockingBarrier blocking_barrier(1); - main_dispatcher_->post([this, &num, &blocking_barrier]() { + runOnMainBlocking([this, &num, &blocking_barrier]() { ThreadLocalStoreTestingPeer::numTlsHistograms(*store_, [&num, &blocking_barrier](uint32_t num_hist) { num = num_hist; @@ -1817,10 +1728,7 @@ class HistogramThreadTest : public ThreadLocalRealThreadsTestBase { // Executes a function on every worker thread dispatcher. void foreachThread(const std::function& fn) { - BlockingBarrier blocking_barrier(NumThreads); - for (Event::DispatcherPtr& thread_dispatcher : thread_dispatchers_) { - thread_dispatcher->post(blocking_barrier.run(fn)); - } + runOnAllWorkersBlocking([&fn]() { fn(); }); } }; diff --git a/test/extensions/filters/http/ext_authz/BUILD b/test/extensions/filters/http/ext_authz/BUILD index 6076c5d7b3cfe..4bdae57469fe1 100644 --- a/test/extensions/filters/http/ext_authz/BUILD +++ b/test/extensions/filters/http/ext_authz/BUILD @@ -46,8 +46,12 @@ envoy_extension_cc_test( srcs = ["config_test.cc"], extension_names = ["envoy.filters.http.ext_authz"], deps = [ + "//source/common/grpc:async_client_manager_lib", + "//source/common/network:address_lib", + "//source/common/thread_local:thread_local_lib", "//source/extensions/filters/http/ext_authz:config", "//test/mocks/server:factory_context_mocks", + "//test/test_common:real_threads_test_helper_lib", "//test/test_common:test_runtime_lib", "@envoy_api//envoy/config/core/v3:pkg_cc_proto", "@envoy_api//envoy/extensions/filters/http/ext_authz/v3:pkg_cc_proto", diff --git a/test/extensions/filters/http/ext_authz/config_test.cc b/test/extensions/filters/http/ext_authz/config_test.cc index 47c6053064b9b..da59c41c5125e 100644 --- a/test/extensions/filters/http/ext_authz/config_test.cc +++ b/test/extensions/filters/http/ext_authz/config_test.cc @@ -3,10 +3,13 @@ #include "envoy/extensions/filters/http/ext_authz/v3/ext_authz.pb.validate.h" #include "envoy/stats/scope.h" +#include "source/common/grpc/async_client_manager_impl.h" +#include "source/common/network/address_impl.h" +#include "source/common/thread_local/thread_local_impl.h" #include "source/extensions/filters/http/ext_authz/config.h" #include "test/mocks/server/factory_context.h" -#include "test/test_common/test_runtime.h" +#include "test/test_common/real_threads_test_helper.h" #include "test/test_common/utility.h" #include "gmock/gmock.h" @@ -14,84 +17,104 @@ using testing::_; using testing::Invoke; +using testing::NiceMock; +using testing::StrictMock; namespace Envoy { namespace Extensions { namespace HttpFilters { namespace ExtAuthz { -namespace { - -void expectCorrectProtoGrpc(std::string const& grpc_service_yaml) { - ExtAuthzFilterConfig factory; - ProtobufTypes::MessagePtr proto_config = factory.createEmptyConfigProto(); - TestUtility::loadFromYaml(grpc_service_yaml, *proto_config); - - testing::StrictMock context; - testing::StrictMock server_context; - EXPECT_CALL(context, getServerFactoryContext()) - .WillRepeatedly(testing::ReturnRef(server_context)); - EXPECT_CALL(context, messageValidationVisitor()); - EXPECT_CALL(context, clusterManager()).Times(2); - EXPECT_CALL(context, runtime()); - EXPECT_CALL(context, scope()).Times(3); - - Http::FilterFactoryCb cb = factory.createFilterFactoryFromProto(*proto_config, "stats", context); - Http::MockFilterChainFactoryCallbacks filter_callback; - EXPECT_CALL(filter_callback, addStreamFilter(_)); - // Expect the raw async client to be created inside the callback. - // The creation of the filter callback is in main thread while the execution of callback is in - // worker thread. Because of the thread local cache of async client, it must be created in worker - // thread inside the callback. - EXPECT_CALL(context.cluster_manager_.async_client_manager_, getOrCreateRawAsyncClient(_, _, _, _)) - .WillOnce(Invoke( - [](const envoy::config::core::v3::GrpcService&, Stats::Scope&, bool, Grpc::CacheOption) { - return std::make_unique>(); - })); - cb(filter_callback); - - Thread::ThreadPtr thread = Thread::threadFactoryForTest().createThread([&context, cb]() { - Http::MockFilterChainFactoryCallbacks filter_callback; - EXPECT_CALL(filter_callback, addStreamFilter(_)); - // Execute the filter factory callback in another thread. - EXPECT_CALL(context.cluster_manager_.async_client_manager_, - getOrCreateRawAsyncClient(_, _, _, _)) - .WillOnce(Invoke( - [](const envoy::config::core::v3::GrpcService&, Stats::Scope&, bool, - Grpc::CacheOption) { return std::make_unique>(); })); - cb(filter_callback); - }); - thread->join(); -} -} // namespace +class TestAsyncClientManagerImpl : public Grpc::AsyncClientManagerImpl { +public: + TestAsyncClientManagerImpl(Upstream::ClusterManager& cm, ThreadLocal::Instance& tls, + TimeSource& time_source, Api::Api& api, + const Grpc::StatNames& stat_names) + : Grpc::AsyncClientManagerImpl(cm, tls, time_source, api, stat_names) {} + Grpc::AsyncClientFactoryPtr factoryForGrpcService(const envoy::config::core::v3::GrpcService&, + Stats::Scope&, bool) override { + return std::make_unique>(); + } +}; -TEST(HttpExtAuthzConfigTest, CorrectProtoGoogleGrpc) { - std::string google_grpc_service_yaml = R"EOF( - transport_api_version: V3 - grpc_service: - google_grpc: - target_uri: ext_authz_server - stat_prefix: google - failure_mode_allow: false - transport_api_version: V3 - )EOF"; - expectCorrectProtoGrpc(google_grpc_service_yaml); -} +class ExtAuthzFilterTest : public Event::TestUsingSimulatedTime, + public Thread::RealThreadsTestHelper, + public testing::Test { +public: + ExtAuthzFilterTest() : RealThreadsTestHelper(5), stat_names_(symbol_table_) { + runOnMainBlocking([&]() { + async_client_manager_ = std::make_unique( + context_.cluster_manager_, tls(), api().timeSource(), api(), stat_names_); + }); + } -TEST(HttpExtAuthzConfigTest, CorrectProtoEnvoyGrpc) { - std::string envoy_grpc_service_yaml = R"EOF( - transport_api_version: V3 - grpc_service: - envoy_grpc: - cluster_name: ext_authz_server - failure_mode_allow: false - transport_api_version: V3 - )EOF"; - expectCorrectProtoGrpc(envoy_grpc_service_yaml); -} + ~ExtAuthzFilterTest() override { + // Reset the async client manager before shutdown threading. + // Because its dtor will try to post to event loop to clear thread local slot. + runOnMainBlocking([&]() { async_client_manager_.reset(); }); + // TODO(chaoqin-li1123): clean this up when we figure out how to free the threading resources in + // RealThreadsTestHelper. + shutdownThreading(); + exitThreads(); + } + + Http::FilterFactoryCb createFilterFactory( + const envoy::extensions::filters::http::ext_authz::v3::ExtAuthz& ext_authz_config) { + // Delegate call to mock async client manager to real async client manager. + ON_CALL(context_, getServerFactoryContext()).WillByDefault(testing::ReturnRef(server_context_)); + ON_CALL(context_.cluster_manager_.async_client_manager_, getOrCreateRawAsyncClient(_, _, _, _)) + .WillByDefault(Invoke([&](const envoy::config::core::v3::GrpcService& config, + Stats::Scope& scope, bool skip_cluster_check, + Grpc::CacheOption cache_option) { + return async_client_manager_->getOrCreateRawAsyncClient(config, scope, skip_cluster_check, + cache_option); + })); + ExtAuthzFilterConfig factory; + return factory.createFilterFactoryFromProto(ext_authz_config, "stats", context_); + } + + Http::StreamFilterSharedPtr createFilterFromFilterFactory(Http::FilterFactoryCb filter_factory) { + StrictMock filter_callbacks; + + Http::StreamFilterSharedPtr filter; + EXPECT_CALL(filter_callbacks, addStreamFilter(_)).WillOnce(::testing::SaveArg<0>(&filter)); + filter_factory(filter_callbacks); + return filter; + } + +private: + NiceMock server_context_; + Stats::SymbolTableImpl symbol_table_; + Grpc::StatNames stat_names_; -TEST(HttpExtAuthzConfigTest, CorrectProtoHttp) { - std::string yaml = R"EOF( +protected: + NiceMock context_; + std::unique_ptr async_client_manager_; +}; + +class ExtAuthzFilterHttpTest : public ExtAuthzFilterTest { +public: + void testFilterFactory(const std::string& ext_authz_config_yaml) { + envoy::extensions::filters::http::ext_authz::v3::ExtAuthz ext_authz_config; + Http::FilterFactoryCb filter_factory; + // Load config and create filter factory in main thread. + runOnMainBlocking([&]() { + TestUtility::loadFromYaml(ext_authz_config_yaml, ext_authz_config); + filter_factory = createFilterFactory(ext_authz_config); + }); + + // Create filter from filter factory per thread. + for (int i = 0; i < 5; i++) { + runOnAllWorkersBlocking([&, filter_factory]() { + Http::StreamFilterSharedPtr filter = createFilterFromFilterFactory(filter_factory); + EXPECT_NE(filter, nullptr); + }); + } + } +}; + +TEST_F(ExtAuthzFilterHttpTest, ExtAuthzFilterFactoryTestHttp) { + const std::string ext_authz_config_yaml = R"EOF( stat_prefix: "wall" transport_api_version: V3 http_service: @@ -132,22 +155,90 @@ TEST(HttpExtAuthzConfigTest, CorrectProtoHttp) { max_request_bytes: 100 pack_as_bytes: true )EOF"; + testFilterFactory(ext_authz_config_yaml); +} + +class ExtAuthzFilterGrpcTest : public ExtAuthzFilterTest { +public: + void testFilterFactoryAndFilterWithGrpcClient(const std::string& ext_authz_config_yaml) { + envoy::extensions::filters::http::ext_authz::v3::ExtAuthz ext_authz_config; + Http::FilterFactoryCb filter_factory; + runOnMainBlocking([&]() { + TestUtility::loadFromYaml(ext_authz_config_yaml, ext_authz_config); + filter_factory = createFilterFactory(ext_authz_config); + }); + + int request_sent_per_thread = 5; + // Initialize address instance to prepare for grpc traffic. + initAddress(); + // Create filter from filter factory per thread and send grpc request. + for (int i = 0; i < request_sent_per_thread; i++) { + runOnAllWorkersBlocking([&, filter_factory]() { + Http::StreamFilterSharedPtr filter = createFilterFromFilterFactory(filter_factory); + testExtAuthzFilter(filter); + }); + } + runOnAllWorkersBlocking( + [&]() { expectGrpcClientSentRequest(ext_authz_config, request_sent_per_thread); }); + } + +private: + void initAddress() { + addr_ = std::make_shared("1.2.3.4", 1111); + connection_.stream_info_.downstream_connection_info_provider_->setRemoteAddress(addr_); + connection_.stream_info_.downstream_connection_info_provider_->setLocalAddress(addr_); + } + + void testExtAuthzFilter(Http::StreamFilterSharedPtr filter) { + EXPECT_NE(filter, nullptr); + Http::TestRequestHeaderMapImpl request_headers; + NiceMock decoder_callbacks; + ON_CALL(decoder_callbacks, connection()).WillByDefault(Return(&connection_)); + filter->setDecoderFilterCallbacks(decoder_callbacks); + EXPECT_EQ(Http::FilterHeadersStatus::StopAllIterationAndWatermark, + filter->decodeHeaders(request_headers, false)); + std::shared_ptr decoder_filter = filter; + decoder_filter->onDestroy(); + } + + void expectGrpcClientSentRequest( + const envoy::extensions::filters::http::ext_authz::v3::ExtAuthz& ext_authz_config, + int requests_sent_per_thread) { + Grpc::RawAsyncClientSharedPtr async_client = async_client_manager_->getOrCreateRawAsyncClient( + ext_authz_config.grpc_service(), context_.scope(), false, Grpc::CacheOption::AlwaysCache); + Grpc::MockAsyncClient* mock_async_client = + dynamic_cast(async_client.get()); + EXPECT_NE(mock_async_client, nullptr); + // All the request in this thread should be sent through the same async client because the async + // client is cached. + EXPECT_EQ(mock_async_client->send_count_, requests_sent_per_thread); + } + + Network::Address::InstanceConstSharedPtr addr_; + NiceMock connection_; +}; - ExtAuthzFilterConfig factory; - ProtobufTypes::MessagePtr proto_config = factory.createEmptyConfigProto(); - TestUtility::loadFromYaml(yaml, *proto_config); - testing::StrictMock context; - testing::StrictMock server_context; - EXPECT_CALL(context, getServerFactoryContext()) - .WillRepeatedly(testing::ReturnRef(server_context)); - EXPECT_CALL(context, messageValidationVisitor()); - EXPECT_CALL(context, clusterManager()); - EXPECT_CALL(context, runtime()); - EXPECT_CALL(context, scope()); - Http::FilterFactoryCb cb = factory.createFilterFactoryFromProto(*proto_config, "stats", context); - testing::StrictMock filter_callback; - EXPECT_CALL(filter_callback, addStreamFilter(_)); - cb(filter_callback); +TEST_F(ExtAuthzFilterGrpcTest, EnvoyGrpc) { + const std::string ext_authz_config_yaml = R"EOF( + transport_api_version: V3 + grpc_service: + envoy_grpc: + cluster_name: test_cluster + failure_mode_allow: false + )EOF"; + testFilterFactoryAndFilterWithGrpcClient(ext_authz_config_yaml); +} + +TEST_F(ExtAuthzFilterGrpcTest, GoogleGrpc) { + const std::string ext_authz_config_yaml = R"EOF( + transport_api_version: V3 + grpc_service: + google_grpc: + target_uri: ext_authz_server + stat_prefix: google + failure_mode_allow: false + )EOF"; + testFilterFactoryAndFilterWithGrpcClient(ext_authz_config_yaml); } // Test that the deprecated extension name is disabled by default. diff --git a/test/mocks/grpc/mocks.cc b/test/mocks/grpc/mocks.cc index 20605edb277e8..6b747a66611d4 100644 --- a/test/mocks/grpc/mocks.cc +++ b/test/mocks/grpc/mocks.cc @@ -7,7 +7,13 @@ namespace Grpc { MockAsyncClient::MockAsyncClient() { async_request_ = std::make_unique>(); - ON_CALL(*this, sendRaw(_, _, _, _, _, _)).WillByDefault(Return(async_request_.get())); + ON_CALL(*this, sendRaw(_, _, _, _, _, _)) + .WillByDefault(Invoke([this](absl::string_view, absl::string_view, Buffer::InstancePtr&&, + RawAsyncRequestCallbacks&, Tracing::Span&, + const Http::AsyncClient::RequestOptions&) { + send_count_++; + return async_request_.get(); + })); } MockAsyncClient::~MockAsyncClient() = default; diff --git a/test/mocks/grpc/mocks.h b/test/mocks/grpc/mocks.h index 76de0db3f2b92..cf4f244b9e5c4 100644 --- a/test/mocks/grpc/mocks.h +++ b/test/mocks/grpc/mocks.h @@ -90,6 +90,8 @@ class MockAsyncClient : public RawAsyncClient { const Http::AsyncClient::StreamOptions& options)); std::unique_ptr> async_request_; + // Keep track of the number of requests to detect potential race condition. + int send_count_{}; }; class MockAsyncClientFactory : public AsyncClientFactory { diff --git a/test/test_common/BUILD b/test/test_common/BUILD index 9e4ecbdeee869..a6df26363f633 100644 --- a/test/test_common/BUILD +++ b/test/test_common/BUILD @@ -55,6 +55,18 @@ envoy_cc_test_library( ], ) +envoy_cc_test_library( + name = "real_threads_test_helper_lib", + srcs = ["real_threads_test_helper.cc"], + hdrs = ["real_threads_test_helper.h"], + deps = [ + "utility_lib", + "//source/common/common:thread_lib", + "//source/common/event:dispatcher_lib", + "//source/common/thread_local:thread_local_lib", + ], +) + envoy_cc_test( name = "network_utility_test", srcs = ["network_utility_test.cc"], diff --git a/test/test_common/real_threads_test_helper.cc b/test/test_common/real_threads_test_helper.cc new file mode 100644 index 0000000000000..3408b0c173b6e --- /dev/null +++ b/test/test_common/real_threads_test_helper.cc @@ -0,0 +1,110 @@ +#include "real_threads_test_helper.h" + +#include "absl/synchronization/barrier.h" +#include "utility.h" + +namespace Envoy { +namespace Thread { + +RealThreadsTestHelper::RealThreadsTestHelper(uint32_t num_threads) + : api_(Api::createApiForTest()), num_threads_(num_threads), + thread_factory_(api_->threadFactory()) { + // This is the same order as InstanceImpl::initialize in source/server/server.cc. + thread_dispatchers_.resize(num_threads_); + { + BlockingBarrier blocking_barrier(num_threads_ + 1); + main_thread_ = thread_factory_.createThread( + [this, &blocking_barrier]() { mainThreadFn(blocking_barrier); }); + for (uint32_t i = 0; i < num_threads_; ++i) { + threads_.emplace_back(thread_factory_.createThread( + [this, i, &blocking_barrier]() { workerThreadFn(i, blocking_barrier); })); + } + } + runOnMainBlocking([this]() { + tls_ = std::make_unique(); + tls_->registerThread(*main_dispatcher_, true); + for (Event::DispatcherPtr& dispatcher : thread_dispatchers_) { + // Worker threads must be registered from the main thread, per assert in registerThread(). + tls_->registerThread(*dispatcher, false); + } + }); +} + +std::function RealThreadsTestHelper::BlockingBarrier::run(std::function f) { + return [this, f]() { + f(); + decrementCount(); + }; +} + +std::function RealThreadsTestHelper::BlockingBarrier::decrementCountFn() { + return [this] { decrementCount(); }; +} + +void RealThreadsTestHelper::shutdownThreading() { + runOnMainBlocking([this]() { + if (!tls_->isShutdown()) { + tls_->shutdownGlobalThreading(); + } + tls_->shutdownThread(); + }); +} + +void RealThreadsTestHelper::exitThreads() { + for (Event::DispatcherPtr& dispatcher : thread_dispatchers_) { + dispatcher->post([&dispatcher]() { dispatcher->exit(); }); + } + + for (ThreadPtr& thread : threads_) { + thread->join(); + } + + main_dispatcher_->post([this]() { + tls_.reset(); + main_dispatcher_->exit(); + }); + main_thread_->join(); +} + +void RealThreadsTestHelper::runOnAllWorkersBlocking(std::function work) { + absl::Barrier start_barrier(num_threads_); + BlockingBarrier blocking_barrier(num_threads_); + for (Event::DispatcherPtr& thread_dispatcher : thread_dispatchers_) { + thread_dispatcher->post(blocking_barrier.run([work, &start_barrier]() { + start_barrier.Block(); + work(); + })); + } +} + +void RealThreadsTestHelper::runOnMainBlocking(std::function work) { + BlockingBarrier blocking_barrier(1); + main_dispatcher_->post(blocking_barrier.run([work]() { work(); })); +} + +void RealThreadsTestHelper::mainDispatchBlock() { + // To ensure all stats are freed we have to wait for a few posts() to clear. + // First, wait for the main-dispatcher to initiate the cross-thread TLS cleanup. + runOnMainBlocking([]() {}); +} + +void RealThreadsTestHelper::tlsBlock() { + runOnAllWorkersBlocking([]() {}); +} + +void RealThreadsTestHelper::workerThreadFn(uint32_t thread_index, + BlockingBarrier& blocking_barrier) { + thread_dispatchers_[thread_index] = + api_->allocateDispatcher(absl::StrCat("test_worker_", thread_index)); + blocking_barrier.decrementCount(); + thread_dispatchers_[thread_index]->run(Event::Dispatcher::RunType::RunUntilExit); +} + +void RealThreadsTestHelper::mainThreadFn(BlockingBarrier& blocking_barrier) { + main_dispatcher_ = api_->allocateDispatcher("test_main_thread"); + blocking_barrier.decrementCount(); + main_dispatcher_->run(Event::Dispatcher::RunType::RunUntilExit); +} + +} // namespace Thread +} // namespace Envoy diff --git a/test/test_common/real_threads_test_helper.h b/test/test_common/real_threads_test_helper.h new file mode 100644 index 0000000000000..12444777efbd4 --- /dev/null +++ b/test/test_common/real_threads_test_helper.h @@ -0,0 +1,78 @@ +#include "source/common/event/dispatcher_impl.h" +#include "source/common/thread_local/thread_local_impl.h" + +#include "absl/synchronization/blocking_counter.h" + +namespace Envoy { +namespace Thread { + +class RealThreadsTestHelper { +protected: + // Helper class to block on a number of multi-threaded operations occurring. + class BlockingBarrier { + public: + explicit BlockingBarrier(uint32_t count) : blocking_counter_(count) {} + ~BlockingBarrier() { blocking_counter_.Wait(); } + + /** + * Returns a function that first executes 'f', and then decrements the count + * toward unblocking the scope. This is intended to be used as a post() callback. + * + * @param f the function to run prior to decrementing the count. + */ + std::function run(std::function f); + + /** + * @return a function that, when run, decrements the count, intended for passing to post(). + */ + std::function decrementCountFn(); + + void decrementCount() { blocking_counter_.DecrementCount(); } + + private: + absl::BlockingCounter blocking_counter_; + }; + + explicit RealThreadsTestHelper(uint32_t num_threads); + // TODO(chaoqin-li1123): Clean up threading resources from the destructor when we figure out how + // to handle different destruction orders of thread local object. + ~RealThreadsTestHelper() = default; + // Shutdown thread local instance. + void shutdownThreading(); + // Post exit signal and wait for main thread and worker threads to join. + void exitThreads(); + // Run the callback in all the workers, block until the callback has finished in all threads. + void runOnAllWorkersBlocking(std::function work); + // Run the callback in main thread, block until the callback has been executed in main thread. + void runOnMainBlocking(std::function work); + // Post an empty callback to main thread and block until all the previous callbacks have been + // executed. + void mainDispatchBlock(); + // Post an empty callback to worker threads and block until all the previous callbacks have been + // executed. + void tlsBlock(); + + ThreadLocal::Instance& tls() { return *tls_; } + + Api::Api& api() { return *api_; } + + // TODO(chaoqin-li1123): make these variables private when we figure out how to clean up the + // threading resources inside the helper class. + Api::ApiPtr api_; + Event::DispatcherPtr main_dispatcher_; + std::vector thread_dispatchers_; + ThreadLocal::InstanceImplPtr tls_; + ThreadPtr main_thread_; + std::vector threads_; + +private: + void workerThreadFn(uint32_t thread_index, BlockingBarrier& blocking_barrier); + + void mainThreadFn(BlockingBarrier& blocking_barrier); + + const uint32_t num_threads_; + ThreadFactory& thread_factory_; +}; + +} // namespace Thread +} // namespace Envoy