Skip to content
5 changes: 5 additions & 0 deletions envoy/grpc/async_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ class RawAsyncClient {
absl::string_view method_name,
RawAsyncStreamCallbacks& callbacks,
const Http::AsyncClient::StreamOptions& options) PURE;

protected:
// The lifetime of RawAsyncClient must be in the same thread.
bool isThreadSafe() { return thread_id_ == std::this_thread::get_id(); }
std::thread::id thread_id_{std::this_thread::get_id()};
};

using RawAsyncClientPtr = std::unique_ptr<RawAsyncClient>;
Expand Down
3 changes: 3 additions & 0 deletions source/common/grpc/async_client_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ AsyncClientImpl::AsyncClientImpl(Upstream::ClusterManager& cm,
Router::HeaderParser::configure(config.initial_metadata(), /*append=*/false)) {}

AsyncClientImpl::~AsyncClientImpl() {
ASSERT(isThreadSafe());
while (!active_streams_.empty()) {
active_streams_.front()->resetStream();
}
Expand All @@ -31,6 +32,7 @@ AsyncRequest* AsyncClientImpl::sendRaw(absl::string_view service_full_name,
RawAsyncRequestCallbacks& callbacks,
Tracing::Span& parent_span,
const Http::AsyncClient::RequestOptions& options) {
ASSERT(isThreadSafe());
auto* const async_request = new AsyncRequestImpl(
*this, service_full_name, method_name, std::move(request), callbacks, parent_span, options);
AsyncStreamImplPtr grpc_stream{async_request};
Expand All @@ -48,6 +50,7 @@ RawAsyncStream* AsyncClientImpl::startRaw(absl::string_view service_full_name,
absl::string_view method_name,
RawAsyncStreamCallbacks& callbacks,
const Http::AsyncClient::StreamOptions& options) {
ASSERT(isThreadSafe());
auto grpc_stream =
std::make_unique<AsyncStreamImpl>(*this, service_full_name, method_name, callbacks, options);

Expand Down
3 changes: 3 additions & 0 deletions source/common/grpc/google_async_client_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ GoogleAsyncClientImpl::GoogleAsyncClientImpl(Event::Dispatcher& dispatcher,
}

GoogleAsyncClientImpl::~GoogleAsyncClientImpl() {
ASSERT(isThreadSafe());
ENVOY_LOG(debug, "Client teardown, resetting streams");
while (!active_streams_.empty()) {
active_streams_.front()->resetStream();
Expand All @@ -120,6 +121,7 @@ AsyncRequest* GoogleAsyncClientImpl::sendRaw(absl::string_view service_full_name
RawAsyncRequestCallbacks& callbacks,
Tracing::Span& parent_span,
const Http::AsyncClient::RequestOptions& options) {
ASSERT(isThreadSafe());
auto* const async_request = new GoogleAsyncRequestImpl(
*this, service_full_name, method_name, std::move(request), callbacks, parent_span, options);
GoogleAsyncStreamImplPtr grpc_stream{async_request};
Expand All @@ -137,6 +139,7 @@ RawAsyncStream* GoogleAsyncClientImpl::startRaw(absl::string_view service_full_n
absl::string_view method_name,
RawAsyncStreamCallbacks& callbacks,
const Http::AsyncClient::StreamOptions& options) {
ASSERT(isThreadSafe());
auto grpc_stream = std::make_unique<GoogleAsyncStreamImpl>(*this, service_full_name, method_name,
callbacks, options);

Expand Down
12 changes: 6 additions & 6 deletions source/extensions/filters/http/ext_authz/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Http::FilterFactoryCb ExtAuthzFilterConfig::createFilterFactoryFromProtoTyped(
const auto filter_config = std::make_shared<FilterConfig>(
proto_config, context.scope(), context.runtime(), context.httpContext(), stats_prefix,
context.getServerFactoryContext().bootstrap());
// The callback is created in main thread and executed in worker thread, variables except factory
// context must be captured by value into the callback.
Http::FilterFactoryCb callback;

if (proto_config.has_http_service()) {
Expand All @@ -42,7 +44,6 @@ Http::FilterFactoryCb ExtAuthzFilterConfig::createFilterFactoryFromProtoTyped(
};
} else if (proto_config.grpc_service().has_google_grpc()) {
// Google gRPC client.

const uint32_t timeout_ms =
PROTOBUF_GET_MS_OR_DEFAULT(proto_config.grpc_service(), timeout, DefaultTimeout);

Expand All @@ -57,15 +58,14 @@ Http::FilterFactoryCb ExtAuthzFilterConfig::createFilterFactoryFromProtoTyped(
};
} else {
// Envoy gRPC client.

Grpc::RawAsyncClientSharedPtr raw_client =
context.clusterManager().grpcAsyncClientManager().getOrCreateRawAsyncClient(
proto_config.grpc_service(), context.scope(), true, Grpc::CacheOption::AlwaysCache);
const uint32_t timeout_ms =
PROTOBUF_GET_MS_OR_DEFAULT(proto_config.grpc_service(), timeout, DefaultTimeout);
callback = [raw_client, filter_config, timeout_ms,
callback = [grpc_service = proto_config.grpc_service(), &context, filter_config, timeout_ms,
transport_api_version = Config::Utility::getAndCheckTransportVersion(proto_config)](
Http::FilterChainFactoryCallbacks& callbacks) {
Grpc::RawAsyncClientSharedPtr raw_client =
context.clusterManager().grpcAsyncClientManager().getOrCreateRawAsyncClient(
grpc_service, context.scope(), true, Grpc::CacheOption::AlwaysCache);
auto client = std::make_unique<Filters::Common::ExtAuthz::GrpcClientImpl>(
raw_client, std::chrono::milliseconds(timeout_ms), transport_api_version);
callbacks.addStreamFilter(std::make_shared<Filter>(filter_config, std::move(client)));
Expand Down
15 changes: 14 additions & 1 deletion test/common/grpc/async_client_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class EnvoyAsyncClientImplTest : public testing::Test {
public:
EnvoyAsyncClientImplTest()
: method_descriptor_(helloworld::Greeter::descriptor()->FindMethodByName("SayHello")) {
envoy::config::core::v3::GrpcService config;

config.mutable_envoy_grpc()->set_cluster_name("test_cluster");

auto& initial_metadata_entry = *config.mutable_initial_metadata()->Add();
Expand All @@ -39,13 +39,26 @@ class EnvoyAsyncClientImplTest : public testing::Test {
ON_CALL(cm_.thread_local_cluster_, httpAsyncClient()).WillByDefault(ReturnRef(http_client_));
}

envoy::config::core::v3::GrpcService config;
const Protobuf::MethodDescriptor* method_descriptor_;
NiceMock<Http::MockAsyncClient> http_client_;
NiceMock<Upstream::MockClusterManager> cm_;
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> grpc_client_;
DangerousDeprecatedTestTime test_time_;
};

TEST_F(EnvoyAsyncClientImplTest, ThreadSafe) {
NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> grpc_callbacks;

Thread::ThreadPtr thread = Thread::threadFactoryForTest().createThread([&]() {
// Verify that using the grpc client in a different thread cause assertion failure.
EXPECT_DEBUG_DEATH(grpc_client_->start(*method_descriptor_, grpc_callbacks,
Http::AsyncClient::StreamOptions()),
"isThreadSafe");
});
thread->join();
}

// Validate that the host header is the cluster name in grpc config.
TEST_F(EnvoyAsyncClientImplTest, HostIsClusterNameByDefault) {
NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> grpc_callbacks;
Expand Down
16 changes: 15 additions & 1 deletion test/common/grpc/google_async_client_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class MockStubFactory : public GoogleStubFactory {
return shared_stub_;
}

MockGenericStub* stub_ = new MockGenericStub();
NiceMock<MockGenericStub>* stub_ = new NiceMock<MockGenericStub>();
GoogleStubSharedPtr shared_stub_{stub_};
};

Expand Down Expand Up @@ -86,6 +86,20 @@ class EnvoyGoogleAsyncClientImplTest : public testing::Test {
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> grpc_client_;
};

// Verify that grpc client check for thread consistency.
TEST_F(EnvoyGoogleAsyncClientImplTest, ThreadSafe) {
initialize();
ON_CALL(*stub_factory_.stub_, PrepareCall_(_, _, _)).WillByDefault(Return(nullptr));
Thread::ThreadPtr thread = Thread::threadFactoryForTest().createThread([&]() {
NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> grpc_callbacks;
// Verify that using the grpc client in a different thread cause assertion failure.
EXPECT_DEBUG_DEATH(grpc_client_->start(*method_descriptor_, grpc_callbacks,
Http::AsyncClient::StreamOptions()),
"isThreadSafe");
});
thread->join();
}

// Validate that a failure in gRPC stub call creation returns immediately with
// status UNAVAILABLE.
TEST_F(EnvoyGoogleAsyncClientImplTest, StreamHttpStartFail) {
Expand Down
77 changes: 56 additions & 21 deletions test/extensions/filters/http/ext_authz/config_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,54 +21,89 @@ namespace HttpFilters {
namespace ExtAuthz {
namespace {

void expectCorrectProtoGrpc(envoy::config::core::v3::ApiVersion api_version) {
void expectCorrectProtoGrpc(envoy::config::core::v3::ApiVersion api_version,
std::string const& grpc_service_yaml) {
std::unique_ptr<TestDeprecatedV2Api> _deprecated_v2_api;
if (api_version != envoy::config::core::v3::ApiVersion::V3) {
_deprecated_v2_api = std::make_unique<TestDeprecatedV2Api>();
}
std::string yaml = R"EOF(
transport_api_version: V3
grpc_service:
google_grpc:
target_uri: ext_authz_server
stat_prefix: google
failure_mode_allow: false
transport_api_version: {}
)EOF";

ExtAuthzFilterConfig factory;
ProtobufTypes::MessagePtr proto_config = factory.createEmptyConfigProto();
TestUtility::loadFromYaml(
fmt::format(yaml, TestUtility::getVersionStringFromApiVersion(api_version)), *proto_config);
fmt::format(grpc_service_yaml, TestUtility::getVersionStringFromApiVersion(api_version)),
*proto_config);

testing::StrictMock<Server::Configuration::MockFactoryContext> context;
testing::StrictMock<Server::Configuration::MockServerFactoryContext> server_context;
EXPECT_CALL(context, getServerFactoryContext())
.WillRepeatedly(testing::ReturnRef(server_context));
EXPECT_CALL(context, messageValidationVisitor());
EXPECT_CALL(context, clusterManager());
EXPECT_CALL(context, clusterManager()).Times(2);
EXPECT_CALL(context, runtime());
EXPECT_CALL(context, scope()).Times(2);
EXPECT_CALL(context, scope()).Times(3);

Http::FilterFactoryCb cb = factory.createFilterFactoryFromProto(*proto_config, "stats", context);
Http::MockFilterChainFactoryCallbacks filter_callback;
EXPECT_CALL(filter_callback, addStreamFilter(_));
// Expect the raw async client to be created inside the callback.
// The creation of the filter callback is in main thread while the execution of callback is in
// worker thread. Because of the thread local cache of async client, it must be created in worker
// thread inside the callback.
EXPECT_CALL(context.cluster_manager_.async_client_manager_, getOrCreateRawAsyncClient(_, _, _, _))
.WillOnce(Invoke(
[](const envoy::config::core::v3::GrpcService&, Stats::Scope&, bool, Grpc::CacheOption) {
return std::make_unique<NiceMock<Grpc::MockAsyncClient>>();
}));

Http::FilterFactoryCb cb = factory.createFilterFactoryFromProto(*proto_config, "stats", context);
Http::MockFilterChainFactoryCallbacks filter_callback;
EXPECT_CALL(filter_callback, addStreamFilter(_));
cb(filter_callback);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it may be useful to add test that calls cb on a separate thread and verifies that getOrCreateRawAsyncClient is called on that separate thread.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, test added.


Thread::ThreadPtr thread = Thread::threadFactoryForTest().createThread([&context, cb]() {
Http::MockFilterChainFactoryCallbacks filter_callback;
EXPECT_CALL(filter_callback, addStreamFilter(_));
// Execute the filter factory callback in another thread.
EXPECT_CALL(context.cluster_manager_.async_client_manager_,
getOrCreateRawAsyncClient(_, _, _, _))
.WillOnce(Invoke(
[](const envoy::config::core::v3::GrpcService&, Stats::Scope&, bool,
Grpc::CacheOption) { return std::make_unique<NiceMock<Grpc::MockAsyncClient>>(); }));
cb(filter_callback);
});
thread->join();
}

} // namespace

TEST(HttpExtAuthzConfigTest, CorrectProtoGrpc) {
TEST(HttpExtAuthzConfigTest, CorrectProtoGoogleGrpc) {
std::string google_grpc_service_yaml = R"EOF(
transport_api_version: V3
grpc_service:
google_grpc:
target_uri: ext_authz_server
stat_prefix: google
failure_mode_allow: false
transport_api_version: {}
)EOF";
#ifndef ENVOY_DISABLE_DEPRECATED_FEATURES
// TODO(chaoqin-li1123): clean this up when we move AUTO to V3 by default.
expectCorrectProtoGrpc(envoy::config::core::v3::ApiVersion::AUTO, google_grpc_service_yaml);
#endif
expectCorrectProtoGrpc(envoy::config::core::v3::ApiVersion::V3, google_grpc_service_yaml);
}

TEST(HttpExtAuthzConfigTest, CorrectProtoEnvoyGrpc) {
std::string envoy_grpc_service_yaml = R"EOF(
transport_api_version: V3
grpc_service:
envoy_grpc:
cluster_name: ext_authz_server
failure_mode_allow: false
transport_api_version: {}
)EOF";
#ifndef ENVOY_DISABLE_DEPRECATED_FEATURES
expectCorrectProtoGrpc(envoy::config::core::v3::ApiVersion::AUTO);
expectCorrectProtoGrpc(envoy::config::core::v3::ApiVersion::V2);
// TODO(chaoqin-li1123): clean this up when we move AUTO to V3 by default.
expectCorrectProtoGrpc(envoy::config::core::v3::ApiVersion::AUTO, envoy_grpc_service_yaml);
#endif
expectCorrectProtoGrpc(envoy::config::core::v3::ApiVersion::V3);
expectCorrectProtoGrpc(envoy::config::core::v3::ApiVersion::V3, envoy_grpc_service_yaml);
}

TEST(HttpExtAuthzConfigTest, CorrectProtoHttp) {
Expand Down