diff --git a/include/envoy/http/async_client.h b/include/envoy/http/async_client.h index a38555acb7ffb..9e95df1cc2f76 100644 --- a/include/envoy/http/async_client.h +++ b/include/envoy/http/async_client.h @@ -20,6 +20,19 @@ namespace Http { */ class AsyncClient { public: + /** + * An in-flight HTTP request. + */ + class Request { + public: + virtual ~Request() = default; + + /** + * Signals that the request should be cancelled. + */ + virtual void cancel() PURE; + }; + /** * Async Client failure reasons. */ @@ -30,6 +43,9 @@ class AsyncClient { /** * Notifies caller of async HTTP request status. + * + * To support a use case where a caller makes multiple requests in parallel, + * individual callback methods provide request context corresponding to that response. */ class Callbacks { public: @@ -37,14 +53,23 @@ class AsyncClient { /** * Called when the async HTTP request succeeds. + * @param request request handle. + * NOTE: request handle is passed for correlation purposes only, e.g. + * for client code to be able to exclude that handle from a list of + * requests in progress. * @param response the HTTP response */ - virtual void onSuccess(ResponseMessagePtr&& response) PURE; + virtual void onSuccess(const Request& request, ResponseMessagePtr&& response) PURE; /** * Called when the async HTTP request fails. + * @param request request handle. + * NOTE: request handle is passed for correlation purposes only, e.g. + * for client code to be able to exclude that handle from a list of + * requests in progress. + * @param reason failure reason */ - virtual void onFailure(FailureReason reason) PURE; + virtual void onFailure(const Request& request, FailureReason reason) PURE; }; /** @@ -92,19 +117,6 @@ class AsyncClient { virtual void onReset() PURE; }; - /** - * An in-flight HTTP request. - */ - class Request { - public: - virtual ~Request() = default; - - /** - * Signals that the request should be cancelled. - */ - virtual void cancel() PURE; - }; - /** * An in-flight HTTP stream. */ diff --git a/source/common/config/remote_data_fetcher.cc b/source/common/config/remote_data_fetcher.cc index 1123581a533e1..2572e00913893 100644 --- a/source/common/config/remote_data_fetcher.cc +++ b/source/common/config/remote_data_fetcher.cc @@ -39,7 +39,8 @@ void RemoteDataFetcher::fetch() { DurationUtil::durationToMilliseconds(uri_.timeout())))); } -void RemoteDataFetcher::onSuccess(Http::ResponseMessagePtr&& response) { +void RemoteDataFetcher::onSuccess(const Http::AsyncClient::Request&, + Http::ResponseMessagePtr&& response) { const uint64_t status_code = Http::Utility::getResponseStatus(response->headers()); if (status_code == enumToInt(Http::Code::OK)) { ENVOY_LOG(debug, "fetch remote data [uri = {}]: success", uri_.uri()); @@ -66,7 +67,8 @@ void RemoteDataFetcher::onSuccess(Http::ResponseMessagePtr&& response) { request_ = nullptr; } -void RemoteDataFetcher::onFailure(Http::AsyncClient::FailureReason reason) { +void RemoteDataFetcher::onFailure(const Http::AsyncClient::Request&, + Http::AsyncClient::FailureReason reason) { ENVOY_LOG(debug, "fetch remote data [uri = {}]: network error {}", uri_.uri(), enumToInt(reason)); request_ = nullptr; callback_.onFailure(FailureReason::Network); diff --git a/source/common/config/remote_data_fetcher.h b/source/common/config/remote_data_fetcher.h index ced327a9fa370..34a7863ff2f0e 100644 --- a/source/common/config/remote_data_fetcher.h +++ b/source/common/config/remote_data_fetcher.h @@ -50,8 +50,9 @@ class RemoteDataFetcher : public Logger::Loggable, ~RemoteDataFetcher() override; // Http::AsyncClient::Callbacks - void onSuccess(Http::ResponseMessagePtr&& response) override; - void onFailure(Http::AsyncClient::FailureReason reason) override; + void onSuccess(const Http::AsyncClient::Request&, Http::ResponseMessagePtr&& response) override; + void onFailure(const Http::AsyncClient::Request&, + Http::AsyncClient::FailureReason reason) override; /** * Fetch data from remote. diff --git a/source/common/http/async_client_impl.cc b/source/common/http/async_client_impl.cc index 42a755c1d1258..cc5659da7885a 100644 --- a/source/common/http/async_client_impl.cc +++ b/source/common/http/async_client_impl.cc @@ -267,7 +267,7 @@ void AsyncRequestImpl::onComplete() { response_->trailers(), streamInfo(), Tracing::EgressConfig::get()); - callbacks_.onSuccess(std::move(response_)); + callbacks_.onSuccess(*this, std::move(response_)); } void AsyncRequestImpl::onHeaders(ResponseHeaderMapPtr&& headers, bool) { @@ -302,7 +302,7 @@ void AsyncRequestImpl::onReset() { if (!cancelled_) { // In this case we don't have a valid response so we do need to raise a failure. - callbacks_.onFailure(AsyncClient::FailureReason::Reset); + callbacks_.onFailure(*this, AsyncClient::FailureReason::Reset); } } diff --git a/source/common/http/rest_api_fetcher.cc b/source/common/http/rest_api_fetcher.cc index 15444aeb81dc8..612fff3708a33 100644 --- a/source/common/http/rest_api_fetcher.cc +++ b/source/common/http/rest_api_fetcher.cc @@ -28,13 +28,14 @@ RestApiFetcher::~RestApiFetcher() { void RestApiFetcher::initialize() { refresh(); } -void RestApiFetcher::onSuccess(Http::ResponseMessagePtr&& response) { +void RestApiFetcher::onSuccess(const Http::AsyncClient::Request& request, + Http::ResponseMessagePtr&& response) { uint64_t response_code = Http::Utility::getResponseStatus(response->headers()); if (response_code == enumToInt(Http::Code::NotModified)) { requestComplete(); return; } else if (response_code != enumToInt(Http::Code::OK)) { - onFailure(Http::AsyncClient::FailureReason::Reset); + onFailure(request, Http::AsyncClient::FailureReason::Reset); return; } @@ -47,7 +48,8 @@ void RestApiFetcher::onSuccess(Http::ResponseMessagePtr&& response) { requestComplete(); } -void RestApiFetcher::onFailure(Http::AsyncClient::FailureReason reason) { +void RestApiFetcher::onFailure(const Http::AsyncClient::Request&, + Http::AsyncClient::FailureReason reason) { // Currently Http::AsyncClient::FailureReason only has one value: "Reset". ASSERT(reason == Http::AsyncClient::FailureReason::Reset); onFetchFailure(Config::ConfigUpdateFailureReason::ConnectionFailure, nullptr); diff --git a/source/common/http/rest_api_fetcher.h b/source/common/http/rest_api_fetcher.h index 0be0f53a27920..f7dfa76dcde33 100644 --- a/source/common/http/rest_api_fetcher.h +++ b/source/common/http/rest_api_fetcher.h @@ -62,8 +62,9 @@ class RestApiFetcher : public Http::AsyncClient::Callbacks { void requestComplete(); // Http::AsyncClient::Callbacks - void onSuccess(Http::ResponseMessagePtr&& response) override; - void onFailure(Http::AsyncClient::FailureReason reason) override; + void onSuccess(const Http::AsyncClient::Request&, Http::ResponseMessagePtr&& response) override; + void onFailure(const Http::AsyncClient::Request&, + Http::AsyncClient::FailureReason reason) override; Runtime::RandomGenerator& random_; const std::chrono::milliseconds refresh_interval_; diff --git a/source/common/router/shadow_writer_impl.h b/source/common/router/shadow_writer_impl.h index c59fd00daa585..2224912e88560 100644 --- a/source/common/router/shadow_writer_impl.h +++ b/source/common/router/shadow_writer_impl.h @@ -24,8 +24,8 @@ class ShadowWriterImpl : Logger::Loggable, const Http::AsyncClient::RequestOptions& options) override; // Http::AsyncClient::Callbacks - void onSuccess(Http::ResponseMessagePtr&&) override {} - void onFailure(Http::AsyncClient::FailureReason) override {} + void onSuccess(const Http::AsyncClient::Request&, Http::ResponseMessagePtr&&) override {} + void onFailure(const Http::AsyncClient::Request&, Http::AsyncClient::FailureReason) override {} private: Upstream::ClusterManager& cm_; diff --git a/source/extensions/filters/common/ext_authz/ext_authz_http_impl.cc b/source/extensions/filters/common/ext_authz/ext_authz_http_impl.cc index fa207cc411ea3..09b59bf0c97a3 100644 --- a/source/extensions/filters/common/ext_authz/ext_authz_http_impl.cc +++ b/source/extensions/filters/common/ext_authz/ext_authz_http_impl.cc @@ -286,14 +286,16 @@ void RawHttpClientImpl::check(RequestCallbacks& callbacks, } } -void RawHttpClientImpl::onSuccess(Http::ResponseMessagePtr&& message) { +void RawHttpClientImpl::onSuccess(const Http::AsyncClient::Request&, + Http::ResponseMessagePtr&& message) { callbacks_->onComplete(toResponse(std::move(message))); span_->finishSpan(); callbacks_ = nullptr; span_ = nullptr; } -void RawHttpClientImpl::onFailure(Http::AsyncClient::FailureReason reason) { +void RawHttpClientImpl::onFailure(const Http::AsyncClient::Request&, + Http::AsyncClient::FailureReason reason) { ASSERT(reason == Http::AsyncClient::FailureReason::Reset); callbacks_->onComplete(std::make_unique(errorResponse())); span_->setTag(Tracing::Tags::get().Error, Tracing::Tags::get().True); diff --git a/source/extensions/filters/common/ext_authz/ext_authz_http_impl.h b/source/extensions/filters/common/ext_authz/ext_authz_http_impl.h index 23a1e69aab764..7cd2732090083 100644 --- a/source/extensions/filters/common/ext_authz/ext_authz_http_impl.h +++ b/source/extensions/filters/common/ext_authz/ext_authz_http_impl.h @@ -155,8 +155,9 @@ class RawHttpClientImpl : public Client, Tracing::Span&) override; // Http::AsyncClient::Callbacks - void onSuccess(Http::ResponseMessagePtr&& message) override; - void onFailure(Http::AsyncClient::FailureReason reason) override; + void onSuccess(const Http::AsyncClient::Request&, Http::ResponseMessagePtr&& message) override; + void onFailure(const Http::AsyncClient::Request&, + Http::AsyncClient::FailureReason reason) override; private: ResponsePtr toResponse(Http::ResponseMessagePtr message); diff --git a/source/extensions/filters/http/common/jwks_fetcher.cc b/source/extensions/filters/http/common/jwks_fetcher.cc index 90bf7cf61733a..3406879727c75 100644 --- a/source/extensions/filters/http/common/jwks_fetcher.cc +++ b/source/extensions/filters/http/common/jwks_fetcher.cc @@ -63,7 +63,7 @@ class JwksFetcherImpl : public JwksFetcher, } // HTTP async receive methods - void onSuccess(Http::ResponseMessagePtr&& response) override { + void onSuccess(const Http::AsyncClient::Request&, Http::ResponseMessagePtr&& response) override { ENVOY_LOG(trace, "{}", __func__); complete_ = true; const uint64_t status_code = Http::Utility::getResponseStatus(response->headers()); @@ -93,7 +93,8 @@ class JwksFetcherImpl : public JwksFetcher, reset(); } - void onFailure(Http::AsyncClient::FailureReason reason) override { + void onFailure(const Http::AsyncClient::Request&, + Http::AsyncClient::FailureReason reason) override { ENVOY_LOG(debug, "{}: fetch pubkey [uri = {}]: network error {}", __func__, uri_->uri(), enumToInt(reason)); complete_ = true; diff --git a/source/extensions/filters/http/lua/lua_filter.cc b/source/extensions/filters/http/lua/lua_filter.cc index 2d94234b7a451..fc935fa8c33ba 100644 --- a/source/extensions/filters/http/lua/lua_filter.cc +++ b/source/extensions/filters/http/lua/lua_filter.cc @@ -295,7 +295,8 @@ int StreamHandleWrapper::luaHttpCallAsynchronous(lua_State* state) { return 0; } -void StreamHandleWrapper::onSuccess(Http::ResponseMessagePtr&& response) { +void StreamHandleWrapper::onSuccess(const Http::AsyncClient::Request&, + Http::ResponseMessagePtr&& response) { ASSERT(state_ == State::HttpCall || state_ == State::Running); ENVOY_LOG(debug, "async HTTP response complete"); http_request_ = nullptr; @@ -341,7 +342,8 @@ void StreamHandleWrapper::onSuccess(Http::ResponseMessagePtr&& response) { } } -void StreamHandleWrapper::onFailure(Http::AsyncClient::FailureReason) { +void StreamHandleWrapper::onFailure(const Http::AsyncClient::Request& request, + Http::AsyncClient::FailureReason) { ASSERT(state_ == State::HttpCall || state_ == State::Running); ENVOY_LOG(debug, "async HTTP failure"); @@ -351,7 +353,7 @@ void StreamHandleWrapper::onFailure(Http::AsyncClient::FailureReason) { {{Http::Headers::get().Status, std::to_string(enumToInt(Http::Code::ServiceUnavailable))}}))); response_message->body() = std::make_unique("upstream failure"); - onSuccess(std::move(response_message)); + onSuccess(request, std::move(response_message)); } int StreamHandleWrapper::luaHeaders(lua_State* state) { diff --git a/source/extensions/filters/http/lua/lua_filter.h b/source/extensions/filters/http/lua/lua_filter.h index ffab53fc67ab3..88725c50ef406 100644 --- a/source/extensions/filters/http/lua/lua_filter.h +++ b/source/extensions/filters/http/lua/lua_filter.h @@ -253,8 +253,8 @@ class StreamHandleWrapper : public Filters::Common::Lua::BaseLuaObject&& on_fail) : on_success_(on_success), on_fail_(on_fail) {} // Http::AsyncClient::Callbacks - void onSuccess(Http::ResponseMessagePtr&& m) override { + void onSuccess(const Http::AsyncClient::Request&, Http::ResponseMessagePtr&& m) override { on_success_(std::forward(m)); } - void onFailure(Http::AsyncClient::FailureReason f) override { on_fail_(f); } + void onFailure(const Http::AsyncClient::Request&, Http::AsyncClient::FailureReason f) override { + on_fail_(f); + } private: const std::function on_success_; diff --git a/source/extensions/tracers/datadog/datadog_tracer_impl.cc b/source/extensions/tracers/datadog/datadog_tracer_impl.cc index 4f596a8ba65b6..71206182a63f2 100644 --- a/source/extensions/tracers/datadog/datadog_tracer_impl.cc +++ b/source/extensions/tracers/datadog/datadog_tracer_impl.cc @@ -109,12 +109,13 @@ void TraceReporter::flushTraces() { } } -void TraceReporter::onFailure(Http::AsyncClient::FailureReason) { +void TraceReporter::onFailure(const Http::AsyncClient::Request&, Http::AsyncClient::FailureReason) { ENVOY_LOG(debug, "failure submitting traces to datadog agent"); driver_.tracerStats().reports_failed_.inc(); } -void TraceReporter::onSuccess(Http::ResponseMessagePtr&& http_response) { +void TraceReporter::onSuccess(const Http::AsyncClient::Request&, + Http::ResponseMessagePtr&& http_response) { uint64_t responseStatus = Http::Utility::getResponseStatus(http_response->headers()); if (responseStatus != enumToInt(Http::Code::OK)) { // TODO: Consider adding retries for failed submissions. diff --git a/source/extensions/tracers/datadog/datadog_tracer_impl.h b/source/extensions/tracers/datadog/datadog_tracer_impl.h index d48a8e0f13e7c..b36384b716b99 100644 --- a/source/extensions/tracers/datadog/datadog_tracer_impl.h +++ b/source/extensions/tracers/datadog/datadog_tracer_impl.h @@ -107,8 +107,8 @@ class TraceReporter : public Http::AsyncClient::Callbacks, TraceReporter(TraceEncoderSharedPtr encoder, Driver& driver, Event::Dispatcher& dispatcher); // Http::AsyncClient::Callbacks. - void onSuccess(Http::ResponseMessagePtr&&) override; - void onFailure(Http::AsyncClient::FailureReason) override; + void onSuccess(const Http::AsyncClient::Request&, Http::ResponseMessagePtr&&) override; + void onFailure(const Http::AsyncClient::Request&, Http::AsyncClient::FailureReason) override; private: /** diff --git a/source/extensions/tracers/lightstep/lightstep_tracer_impl.cc b/source/extensions/tracers/lightstep/lightstep_tracer_impl.cc index 2dd99f6f577e4..a7434a2d419f3 100644 --- a/source/extensions/tracers/lightstep/lightstep_tracer_impl.cc +++ b/source/extensions/tracers/lightstep/lightstep_tracer_impl.cc @@ -68,14 +68,15 @@ LightStepDriver::LightStepTransporter::~LightStepTransporter() { } } -void LightStepDriver::LightStepTransporter::onSuccess(Http::ResponseMessagePtr&& /*response*/) { +void LightStepDriver::LightStepTransporter::onSuccess(const Http::AsyncClient::Request&, + Http::ResponseMessagePtr&& /*response*/) { driver_.grpc_context_.chargeStat(*driver_.cluster(), driver_.request_names_, true); active_callback_->OnSuccess(*active_report_); reset(); } void LightStepDriver::LightStepTransporter::onFailure( - Http::AsyncClient::FailureReason /*failure_reason*/) { + const Http::AsyncClient::Request&, Http::AsyncClient::FailureReason /*failure_reason*/) { driver_.grpc_context_.chargeStat(*driver_.cluster(), driver_.request_names_, false); active_callback_->OnFailure(*active_report_); reset(); diff --git a/source/extensions/tracers/lightstep/lightstep_tracer_impl.h b/source/extensions/tracers/lightstep/lightstep_tracer_impl.h index 41ec6c944dd37..873e61f7da068 100644 --- a/source/extensions/tracers/lightstep/lightstep_tracer_impl.h +++ b/source/extensions/tracers/lightstep/lightstep_tracer_impl.h @@ -88,8 +88,9 @@ class LightStepDriver : public Common::Ot::OpenTracingDriver { Callback& callback) noexcept override; // Http::AsyncClient::Callbacks - void onSuccess(Http::ResponseMessagePtr&& response) override; - void onFailure(Http::AsyncClient::FailureReason failure_reason) override; + void onSuccess(const Http::AsyncClient::Request&, Http::ResponseMessagePtr&& response) override; + void onFailure(const Http::AsyncClient::Request&, + Http::AsyncClient::FailureReason failure_reason) override; private: std::unique_ptr active_report_; diff --git a/source/extensions/tracers/zipkin/zipkin_tracer_impl.cc b/source/extensions/tracers/zipkin/zipkin_tracer_impl.cc index cf710d172298b..be61fa5aaa0bc 100644 --- a/source/extensions/tracers/zipkin/zipkin_tracer_impl.cc +++ b/source/extensions/tracers/zipkin/zipkin_tracer_impl.cc @@ -197,11 +197,12 @@ void ReporterImpl::flushSpans() { } } -void ReporterImpl::onFailure(Http::AsyncClient::FailureReason) { +void ReporterImpl::onFailure(const Http::AsyncClient::Request&, Http::AsyncClient::FailureReason) { driver_.tracerStats().reports_failed_.inc(); } -void ReporterImpl::onSuccess(Http::ResponseMessagePtr&& http_response) { +void ReporterImpl::onSuccess(const Http::AsyncClient::Request&, + Http::ResponseMessagePtr&& http_response) { if (Http::Utility::getResponseStatus(http_response->headers()) != enumToInt(Http::Code::Accepted)) { driver_.tracerStats().reports_dropped_.inc(); diff --git a/source/extensions/tracers/zipkin/zipkin_tracer_impl.h b/source/extensions/tracers/zipkin/zipkin_tracer_impl.h index 7c8881d23b9aa..142dbf84cbbe1 100644 --- a/source/extensions/tracers/zipkin/zipkin_tracer_impl.h +++ b/source/extensions/tracers/zipkin/zipkin_tracer_impl.h @@ -194,8 +194,8 @@ class ReporterImpl : public Reporter, Http::AsyncClient::Callbacks { // Http::AsyncClient::Callbacks. // The callbacks below record Zipkin-span-related stats. - void onSuccess(Http::ResponseMessagePtr&&) override; - void onFailure(Http::AsyncClient::FailureReason) override; + void onSuccess(const Http::AsyncClient::Request&, Http::ResponseMessagePtr&&) override; + void onFailure(const Http::AsyncClient::Request&, Http::AsyncClient::FailureReason) override; /** * Creates a heap-allocated ZipkinReporter. diff --git a/test/common/config/datasource_test.cc b/test/common/config/datasource_test.cc index e76c19a43da18..1897c1a867cec 100644 --- a/test/common/config/datasource_test.cc +++ b/test/common/config/datasource_test.cc @@ -33,6 +33,7 @@ class AsyncDataSourceTest : public testing::Test { Event::MockDispatcher dispatcher_; Event::MockTimer* retry_timer_; Event::TimerCb retry_timer_cb_; + NiceMock request_{&cm_.async_client_}; Config::DataSource::LocalAsyncDataProviderPtr local_data_provider_; Config::DataSource::RemoteAsyncDataProviderPtr remote_data_provider_; @@ -115,7 +116,7 @@ TEST_F(AsyncDataSourceTest, LoadRemoteDataSourceReturnFailure) { initialize([&](Http::RequestMessagePtr&, Http::AsyncClient::Callbacks& callbacks, const Http::AsyncClient::RequestOptions&) -> Http::AsyncClient::Request* { - callbacks.onFailure(Envoy::Http::AsyncClient::FailureReason::Reset); + callbacks.onFailure(request_, Envoy::Http::AsyncClient::FailureReason::Reset); return nullptr; }); @@ -155,8 +156,9 @@ TEST_F(AsyncDataSourceTest, LoadRemoteDataSourceSuccessWith503) { initialize([&](Http::RequestMessagePtr&, Http::AsyncClient::Callbacks& callbacks, const Http::AsyncClient::RequestOptions&) -> Http::AsyncClient::Request* { - callbacks.onSuccess(Http::ResponseMessagePtr{new Http::ResponseMessageImpl( - Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "503"}}})}); + callbacks.onSuccess( + request_, Http::ResponseMessagePtr{new Http::ResponseMessageImpl(Http::ResponseHeaderMapPtr{ + new Http::TestResponseHeaderMapImpl{{":status", "503"}}})}); return nullptr; }); @@ -196,8 +198,9 @@ TEST_F(AsyncDataSourceTest, LoadRemoteDataSourceSuccessWithEmptyBody) { initialize([&](Http::RequestMessagePtr&, Http::AsyncClient::Callbacks& callbacks, const Http::AsyncClient::RequestOptions&) -> Http::AsyncClient::Request* { - callbacks.onSuccess(Http::ResponseMessagePtr{new Http::ResponseMessageImpl( - Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "200"}}})}); + callbacks.onSuccess( + request_, Http::ResponseMessagePtr{new Http::ResponseMessageImpl(Http::ResponseHeaderMapPtr{ + new Http::TestResponseHeaderMapImpl{{":status", "200"}}})}); return nullptr; }); @@ -243,7 +246,7 @@ TEST_F(AsyncDataSourceTest, LoadRemoteDataSourceSuccessIncorrectSha256) { Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "200"}}})); response->body() = std::make_unique(body); - callbacks.onSuccess(std::move(response)); + callbacks.onSuccess(request_, std::move(response)); return nullptr; }); @@ -288,7 +291,7 @@ TEST_F(AsyncDataSourceTest, LoadRemoteDataSourceSuccess) { Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "200"}}})); response->body() = std::make_unique(body); - callbacks.onSuccess(std::move(response)); + callbacks.onSuccess(request_, std::move(response)); return nullptr; }); @@ -325,8 +328,9 @@ TEST_F(AsyncDataSourceTest, LoadRemoteDataSourceDoNotAllowEmpty) { initialize([&](Http::RequestMessagePtr&, Http::AsyncClient::Callbacks& callbacks, const Http::AsyncClient::RequestOptions&) -> Http::AsyncClient::Request* { - callbacks.onSuccess(Http::ResponseMessagePtr{new Http::ResponseMessageImpl( - Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "503"}}})}); + callbacks.onSuccess( + request_, Http::ResponseMessagePtr{new Http::ResponseMessageImpl(Http::ResponseHeaderMapPtr{ + new Http::TestResponseHeaderMapImpl{{":status", "503"}}})}); return nullptr; }); @@ -369,7 +373,7 @@ TEST_F(AsyncDataSourceTest, DatasourceReleasedBeforeFetchingData) { Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "200"}}})); response->body() = std::make_unique(body); - callbacks.onSuccess(std::move(response)); + callbacks.onSuccess(request_, std::move(response)); return nullptr; }); @@ -413,8 +417,10 @@ TEST_F(AsyncDataSourceTest, LoadRemoteDataSourceWithRetry) { initialize( [&](Http::RequestMessagePtr&, Http::AsyncClient::Callbacks& callbacks, const Http::AsyncClient::RequestOptions&) -> Http::AsyncClient::Request* { - callbacks.onSuccess(Http::ResponseMessagePtr{new Http::ResponseMessageImpl( - Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "503"}}})}); + callbacks.onSuccess( + request_, + Http::ResponseMessagePtr{new Http::ResponseMessageImpl(Http::ResponseHeaderMapPtr{ + new Http::TestResponseHeaderMapImpl{{":status", "503"}}})}); return nullptr; }, num_retries); @@ -442,7 +448,7 @@ TEST_F(AsyncDataSourceTest, LoadRemoteDataSourceWithRetry) { new Http::TestResponseHeaderMapImpl{{":status", "200"}}})); response->body() = std::make_unique(body); - callbacks.onSuccess(std::move(response)); + callbacks.onSuccess(request_, std::move(response)); return nullptr; })); } diff --git a/test/common/config/http_subscription_impl_test.cc b/test/common/config/http_subscription_impl_test.cc index 9c3e8c4022a5f..d79884ef19158 100644 --- a/test/common/config/http_subscription_impl_test.cc +++ b/test/common/config/http_subscription_impl_test.cc @@ -18,7 +18,7 @@ TEST_F(HttpSubscriptionImplTest, OnRequestReset) { EXPECT_CALL(callbacks_, onConfigUpdateFailed(Envoy::Config::ConfigUpdateFailureReason::ConnectionFailure, _)) .Times(0); - http_callbacks_->onFailure(Http::AsyncClient::FailureReason::Reset); + http_callbacks_->onFailure(http_request_, Http::AsyncClient::FailureReason::Reset); EXPECT_TRUE(statsAre(1, 0, 0, 1, 0, 0, 0)); timerTick(); EXPECT_TRUE(statsAre(2, 0, 0, 1, 0, 0, 0)); @@ -37,7 +37,7 @@ TEST_F(HttpSubscriptionImplTest, BadJsonRecovery) { EXPECT_CALL(*timer_, enableTimer(_, _)); EXPECT_CALL(callbacks_, onConfigUpdateFailed(Envoy::Config::ConfigUpdateFailureReason::UpdateRejected, _)); - http_callbacks_->onSuccess(std::move(message)); + http_callbacks_->onSuccess(http_request_, std::move(message)); EXPECT_TRUE(statsAre(1, 0, 1, 0, 0, 0, 0)); request_in_progress_ = false; timerTick(); diff --git a/test/common/config/http_subscription_test_harness.h b/test/common/config/http_subscription_test_harness.h index c570086fa600c..af798a4efac8f 100644 --- a/test/common/config/http_subscription_test_harness.h +++ b/test/common/config/http_subscription_test_harness.h @@ -151,7 +151,7 @@ class HttpSubscriptionTestHarness : public SubscriptionTestHarness { } EXPECT_CALL(random_gen_, random()).WillOnce(Return(0)); EXPECT_CALL(*timer_, enableTimer(_, _)); - http_callbacks_->onSuccess(std::move(message)); + http_callbacks_->onSuccess(http_request_, std::move(message)); if (accept) { version_ = version; } diff --git a/test/common/http/async_client_impl_test.cc b/test/common/http/async_client_impl_test.cc index d17b09311c561..064f5b85b56a6 100644 --- a/test/common/http/async_client_impl_test.cc +++ b/test/common/http/async_client_impl_test.cc @@ -52,9 +52,13 @@ class AsyncClientImplTest : public testing::Test { .WillByDefault(ReturnRef(envoy::config::core::v3::Locality().default_instance())); } - void expectSuccess(uint64_t code) { - EXPECT_CALL(callbacks_, onSuccess_(_)) - .WillOnce(Invoke([code](ResponseMessage* response) -> void { + void expectSuccess(AsyncClient::Request* sent_request, uint64_t code) { + EXPECT_CALL(callbacks_, onSuccess_(_, _)) + .WillOnce(Invoke([sent_request, code](const AsyncClient::Request& request, + ResponseMessage* response) -> void { + // Verify that callback is called with the same request handle as returned by + // AsyncClient::send(). + EXPECT_EQ(sent_request, &request); EXPECT_EQ(code, Utility::getResponseStatus(response->headers())); })); } @@ -152,9 +156,11 @@ TEST_F(AsyncClientImplTest, Basic) { EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(©), false)); EXPECT_CALL(stream_encoder_, encodeData(BufferEqual(&data), true)); - expectSuccess(200); - client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + auto* request = client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + EXPECT_NE(request, nullptr); + + expectSuccess(request, 200); ResponseHeaderMapPtr response_headers(new TestResponseHeaderMapImpl{{":status", "200"}}); response_decoder_->decodeHeaders(std::move(response_headers), false); @@ -188,12 +194,15 @@ TEST_F(AsyncClientImplTracingTest, Basic) { EXPECT_CALL(parent_span_, spawnChild_(_, "async fake_cluster egress", _)) .WillOnce(Return(child_span)); - expectSuccess(200); AsyncClient::RequestOptions options = AsyncClient::RequestOptions().setParentSpan(parent_span_); EXPECT_CALL(*child_span, setSampled(true)); EXPECT_CALL(*child_span, injectContext(_)); - client_.send(std::move(message_), callbacks_, options); + + auto* request = client_.send(std::move(message_), callbacks_, options); + EXPECT_NE(request, nullptr); + + expectSuccess(request, 200); EXPECT_CALL(*child_span, setTag(Eq(Tracing::Tags::get().Component), Eq(Tracing::Tags::get().Proxy))); @@ -228,7 +237,6 @@ TEST_F(AsyncClientImplTracingTest, BasicNamedChildSpan) { copy.addCopy(":scheme", "http"); EXPECT_CALL(parent_span_, spawnChild_(_, child_span_name_, _)).WillOnce(Return(child_span)); - expectSuccess(200); AsyncClient::RequestOptions options = AsyncClient::RequestOptions() .setParentSpan(parent_span_) @@ -236,7 +244,11 @@ TEST_F(AsyncClientImplTracingTest, BasicNamedChildSpan) { .setSampled(false); EXPECT_CALL(*child_span, setSampled(false)); EXPECT_CALL(*child_span, injectContext(_)); - client_.send(std::move(message_), callbacks_, options); + + auto* request = client_.send(std::move(message_), callbacks_, options); + EXPECT_NE(request, nullptr); + + expectSuccess(request, 200); EXPECT_CALL(*child_span, setTag(Eq(Tracing::Tags::get().Component), Eq(Tracing::Tags::get().Proxy))); @@ -279,13 +291,16 @@ TEST_F(AsyncClientImplTest, BasicHashPolicy) { EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(©), false)); EXPECT_CALL(stream_encoder_, encodeData(BufferEqual(&data), true)); - expectSuccess(200); AsyncClient::RequestOptions options; Protobuf::RepeatedPtrField hash_policy; hash_policy.Add()->mutable_header()->set_header_name(":path"); options.setHashPolicy(hash_policy); - client_.send(std::move(message_), callbacks_, options); + + auto* request = client_.send(std::move(message_), callbacks_, options); + EXPECT_NE(request, nullptr); + + expectSuccess(request, 200); ResponseHeaderMapPtr response_headers(new TestResponseHeaderMapImpl{{":status", "200"}}); response_decoder_->decodeHeaders(std::move(response_headers), false); @@ -312,7 +327,9 @@ TEST_F(AsyncClientImplTest, Retry) { EXPECT_CALL(stream_encoder_, encodeData(BufferEqual(&data), true)); message_->headers().setReferenceEnvoyRetryOn(Headers::get().EnvoyRetryOnValues._5xx); - client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + + auto* request = client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + EXPECT_NE(request, nullptr); // Expect retry and retry timer create. timer_ = new NiceMock(&dispatcher_); @@ -333,7 +350,7 @@ TEST_F(AsyncClientImplTest, Retry) { timer_->invokeCallback(); // Normal response. - expectSuccess(200); + expectSuccess(request, 200); ResponseHeaderMapPtr response_headers2(new TestResponseHeaderMapImpl{{":status", "200"}}); response_decoder_->decodeHeaders(std::move(response_headers2), true); } @@ -462,7 +479,8 @@ TEST_F(AsyncClientImplTest, MultipleRequests) { EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(&message_->headers()), false)); EXPECT_CALL(stream_encoder_, encodeData(BufferEqual(&data), true)); - client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + auto* request1 = client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + EXPECT_NE(request1, nullptr); // Send request 2. RequestMessagePtr message2{new RequestMessageImpl()}; @@ -478,18 +496,57 @@ TEST_F(AsyncClientImplTest, MultipleRequests) { return nullptr; })); EXPECT_CALL(stream_encoder2, encodeHeaders(HeaderMapEqualRef(&message2->headers()), true)); - client_.send(std::move(message2), callbacks2, AsyncClient::RequestOptions()); + + auto* request2 = client_.send(std::move(message2), callbacks2, AsyncClient::RequestOptions()); + EXPECT_NE(request2, nullptr); + + // Send request 3. + RequestMessagePtr message3{new RequestMessageImpl()}; + HttpTestUtility::addDefaultHeaders(message3->headers()); + NiceMock stream_encoder3; + ResponseDecoder* response_decoder3{}; + MockAsyncClientCallbacks callbacks3; + EXPECT_CALL(cm_.conn_pool_, newStream(_, _)) + .WillOnce(Invoke([&](ResponseDecoder& decoder, + ConnectionPool::Callbacks& callbacks) -> ConnectionPool::Cancellable* { + callbacks.onPoolReady(stream_encoder3, cm_.conn_pool_.host_, stream_info_); + response_decoder3 = &decoder; + return nullptr; + })); + EXPECT_CALL(stream_encoder3, encodeHeaders(HeaderMapEqualRef(&message3->headers()), true)); + + auto* request3 = client_.send(std::move(message3), callbacks3, AsyncClient::RequestOptions()); + EXPECT_NE(request3, nullptr); // Finish request 2. ResponseHeaderMapPtr response_headers2(new TestResponseHeaderMapImpl{{":status", "503"}}); - EXPECT_CALL(callbacks2, onSuccess_(_)); + EXPECT_CALL(callbacks2, onSuccess_(_, _)) + .WillOnce(Invoke( + [request2](const AsyncClient::Request& request, ResponseMessage* response) -> void { + // Verify that callback is called with the same request handle as returned by + // AsyncClient::send(). + EXPECT_EQ(request2, &request); + EXPECT_EQ(503, Utility::getResponseStatus(response->headers())); + })); response_decoder2->decodeHeaders(std::move(response_headers2), true); // Finish request 1. ResponseHeaderMapPtr response_headers(new TestResponseHeaderMapImpl{{":status", "200"}}); response_decoder_->decodeHeaders(std::move(response_headers), false); - expectSuccess(200); + expectSuccess(request1, 200); response_decoder_->decodeData(data, true); + + // Finish request 3. + ResponseHeaderMapPtr response_headers3(new TestResponseHeaderMapImpl{{":status", "500"}}); + EXPECT_CALL(callbacks3, onSuccess_(_, _)) + .WillOnce(Invoke( + [request3](const AsyncClient::Request& request, ResponseMessage* response) -> void { + // Verify that callback is called with the same request handle as returned by + // AsyncClient::send(). + EXPECT_EQ(request3, &request); + EXPECT_EQ(500, Utility::getResponseStatus(response->headers())); + })); + response_decoder3->decodeHeaders(std::move(response_headers3), true); } TEST_F(AsyncClientImplTest, StreamAndRequest) { @@ -508,7 +565,8 @@ TEST_F(AsyncClientImplTest, StreamAndRequest) { EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(&message_->headers()), false)); EXPECT_CALL(stream_encoder_, encodeData(BufferEqual(&data), true)); - client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + auto* request = client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + EXPECT_NE(request, nullptr); // Start stream Buffer::InstancePtr body{new Buffer::OwnedImpl("test body")}; @@ -544,7 +602,7 @@ TEST_F(AsyncClientImplTest, StreamAndRequest) { // Finish request. ResponseHeaderMapPtr response_headers(new TestResponseHeaderMapImpl{{":status", "200"}}); response_decoder_->decodeHeaders(std::move(response_headers), false); - expectSuccess(200); + expectSuccess(request, 200); response_decoder_->decodeData(data, true); } @@ -598,9 +656,11 @@ TEST_F(AsyncClientImplTest, Trailers) { EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(&message_->headers()), false)); EXPECT_CALL(stream_encoder_, encodeData(BufferEqual(&data), true)); - expectSuccess(200); - client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + auto* request = client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + EXPECT_NE(request, nullptr); + + expectSuccess(request, 200); ResponseHeaderMapPtr response_headers(new TestResponseHeaderMapImpl{{":status", "200"}}); response_decoder_->decodeHeaders(std::move(response_headers), false); response_decoder_->decodeData(data, false); @@ -617,9 +677,11 @@ TEST_F(AsyncClientImplTest, ImmediateReset) { })); EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(&message_->headers()), true)); - expectSuccess(503); - client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + auto* request = client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + EXPECT_NE(request, nullptr); + + expectSuccess(request, 503); stream_encoder_.getStream().resetStream(StreamResetReason::RemoteReset); EXPECT_EQ( @@ -818,11 +880,20 @@ TEST_F(AsyncClientImplTest, ResetAfterResponseStart) { response_decoder_ = &decoder; return nullptr; })); - EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(&message_->headers()), true)); - EXPECT_CALL(callbacks_, onFailure(_)); - client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + auto* request = client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + EXPECT_NE(request, nullptr); + + EXPECT_CALL(callbacks_, onFailure(_, _)) + .WillOnce(Invoke([sent_request = request](const AsyncClient::Request& request, + AsyncClient::FailureReason reason) { + // Verify that callback is called with the same request handle as returned by + // AsyncClient::send(). + EXPECT_EQ(&request, sent_request); + EXPECT_EQ(reason, AsyncClient::FailureReason::Reset); + })); + ResponseHeaderMapPtr response_headers(new TestResponseHeaderMapImpl{{":status", "200"}}); response_decoder_->decodeHeaders(std::move(response_headers), false); stream_encoder_.getStream().resetStream(StreamResetReason::RemoteReset); @@ -915,11 +986,20 @@ TEST_F(AsyncClientImplTest, DestroyWithActiveRequest) { callbacks.onPoolReady(stream_encoder_, cm_.conn_pool_.host_, stream_info_); return nullptr; })); - EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(&message_->headers()), true)); + + auto* request = client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + EXPECT_NE(request, nullptr); + EXPECT_CALL(stream_encoder_.stream_, resetStream(_)); - EXPECT_CALL(callbacks_, onFailure(_)); - client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions()); + EXPECT_CALL(callbacks_, onFailure(_, _)) + .WillOnce(Invoke([sent_request = request](const AsyncClient::Request& request, + AsyncClient::FailureReason reason) { + // Verify that callback is called with the same request handle as returned by + // AsyncClient::send(). + EXPECT_EQ(&request, sent_request); + EXPECT_EQ(reason, AsyncClient::FailureReason::Reset); + })); } TEST_F(AsyncClientImplTracingTest, DestroyWithActiveRequest) { @@ -937,9 +1017,18 @@ TEST_F(AsyncClientImplTracingTest, DestroyWithActiveRequest) { AsyncClient::RequestOptions options = AsyncClient::RequestOptions().setParentSpan(parent_span_); EXPECT_CALL(*child_span, setSampled(true)); EXPECT_CALL(*child_span, injectContext(_)); - client_.send(std::move(message_), callbacks_, options); - EXPECT_CALL(callbacks_, onFailure(_)); + auto* request = client_.send(std::move(message_), callbacks_, options); + EXPECT_NE(request, nullptr); + + EXPECT_CALL(callbacks_, onFailure(_, _)) + .WillOnce(Invoke([sent_request = request](const AsyncClient::Request& request, + AsyncClient::FailureReason reason) { + // Verify that callback is called with the same request handle as returned by + // AsyncClient::send(). + EXPECT_EQ(&request, sent_request); + EXPECT_EQ(reason, AsyncClient::FailureReason::Reset); + })); EXPECT_CALL(*child_span, setTag(Eq(Tracing::Tags::get().Component), Eq(Tracing::Tags::get().Proxy))); EXPECT_CALL(*child_span, setTag(Eq(Tracing::Tags::get().HttpProtocol), Eq("HTTP/1.1"))); @@ -962,7 +1051,14 @@ TEST_F(AsyncClientImplTest, PoolFailure) { return nullptr; })); - expectSuccess(503); + EXPECT_CALL(callbacks_, onSuccess_(_, _)) + .WillOnce(Invoke([](const AsyncClient::Request& request, ResponseMessage* response) -> void { + // The callback gets called before AsyncClient::send() completes, which means that we don't + // have a request handle to compare to. + EXPECT_NE(nullptr, &request); + EXPECT_EQ(503, Utility::getResponseStatus(response->headers())); + })); + EXPECT_EQ(nullptr, client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions())); EXPECT_EQ( @@ -979,7 +1075,13 @@ TEST_F(AsyncClientImplTest, PoolFailureWithBody) { return nullptr; })); - expectSuccess(503); + EXPECT_CALL(callbacks_, onSuccess_(_, _)) + .WillOnce(Invoke([](const AsyncClient::Request& request, ResponseMessage* response) -> void { + // The callback gets called before AsyncClient::send() completes, which means that we don't + // have a request handle to compare to. + EXPECT_NE(nullptr, &request); + EXPECT_EQ(503, Utility::getResponseStatus(response->headers())); + })); message_->body() = std::make_unique("hello"); EXPECT_EQ(nullptr, client_.send(std::move(message_), callbacks_, AsyncClient::RequestOptions())); @@ -1056,12 +1158,16 @@ TEST_F(AsyncClientImplTest, RequestTimeout) { })); EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(&message_->headers()), true)); - expectSuccess(504); timer_ = new NiceMock(&dispatcher_); EXPECT_CALL(*timer_, enableTimer(std::chrono::milliseconds(40), _)); EXPECT_CALL(stream_encoder_.stream_, resetStream(_)); - client_.send(std::move(message_), callbacks_, - AsyncClient::RequestOptions().setTimeout(std::chrono::milliseconds(40))); + + auto* request = + client_.send(std::move(message_), callbacks_, + AsyncClient::RequestOptions().setTimeout(std::chrono::milliseconds(40))); + EXPECT_NE(request, nullptr); + + expectSuccess(request, 504); timer_->invokeCallback(); EXPECT_EQ(1UL, @@ -1083,7 +1189,6 @@ TEST_F(AsyncClientImplTracingTest, RequestTimeout) { })); EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(&message_->headers()), true)); - expectSuccess(504); timer_ = new NiceMock(&dispatcher_); EXPECT_CALL(*timer_, enableTimer(std::chrono::milliseconds(40), _)); @@ -1096,7 +1201,11 @@ TEST_F(AsyncClientImplTracingTest, RequestTimeout) { .setTimeout(std::chrono::milliseconds(40)); EXPECT_CALL(*child_span, setSampled(true)); EXPECT_CALL(*child_span, injectContext(_)); - client_.send(std::move(message_), callbacks_, options); + + auto* request = client_.send(std::move(message_), callbacks_, options); + EXPECT_NE(request, nullptr); + + expectSuccess(request, 504); EXPECT_CALL(*child_span, setTag(Eq(Tracing::Tags::get().Component), Eq(Tracing::Tags::get().Proxy))); diff --git a/test/common/router/shadow_writer_impl_test.cc b/test/common/router/shadow_writer_impl_test.cc index 64a3984369a83..d95ae08565b42 100644 --- a/test/common/router/shadow_writer_impl_test.cc +++ b/test/common/router/shadow_writer_impl_test.cc @@ -27,7 +27,6 @@ class ShadowWriterImplTest : public testing::Test { message->headers().setHost(host); EXPECT_CALL(cm_, get(Eq("foo"))); EXPECT_CALL(cm_, httpAsyncClientForCluster("foo")).WillOnce(ReturnRef(cm_.async_client_)); - Http::MockAsyncClientRequest request(&cm_.async_client_); auto options = Http::AsyncClient::RequestOptions().setTimeout(std::chrono::milliseconds(5)); EXPECT_CALL(cm_.async_client_, send_(_, _, options)) .WillOnce(Invoke( @@ -36,13 +35,14 @@ class ShadowWriterImplTest : public testing::Test { EXPECT_EQ(message, inner_message); EXPECT_EQ(shadowed_host, message->headers().Host()->value().getStringView()); callback_ = &callbacks; - return &request; + return &request_; })); writer_.shadow("foo", std::move(message), options); } Upstream::MockClusterManager cm_; ShadowWriterImpl writer_{cm_}; + Http::MockAsyncClientRequest request_{&cm_.async_client_}; Http::AsyncClient::Callbacks* callback_{}; }; @@ -51,14 +51,14 @@ TEST_F(ShadowWriterImplTest, Success) { expectShadowWriter("cluster1", "cluster1-shadow"); Http::ResponseMessagePtr response(new Http::ResponseMessageImpl()); - callback_->onSuccess(std::move(response)); + callback_->onSuccess(request_, std::move(response)); } TEST_F(ShadowWriterImplTest, Failure) { InSequence s; expectShadowWriter("cluster1:8000", "cluster1-shadow:8000"); - callback_->onFailure(Http::AsyncClient::FailureReason::Reset); + callback_->onFailure(request_, Http::AsyncClient::FailureReason::Reset); } TEST_F(ShadowWriterImplTest, NoCluster) { diff --git a/test/extensions/filters/common/ext_authz/ext_authz_http_impl_test.cc b/test/extensions/filters/common/ext_authz/ext_authz_http_impl_test.cc index e102d2923f00d..fbbce3cb67281 100644 --- a/test/extensions/filters/common/ext_authz/ext_authz_http_impl_test.cc +++ b/test/extensions/filters/common/ext_authz/ext_authz_http_impl_test.cc @@ -119,7 +119,7 @@ class ExtAuthzHttpClientTest : public testing::Test { client_->check(request_callbacks_, request, Tracing::NullSpan::instance()); EXPECT_CALL(request_callbacks_, onComplete_(WhenDynamicCastTo(AuthzOkResponse(authz_response)))); - client_->onSuccess(std::move(check_response)); + client_->onSuccess(async_request_, std::move(check_response)); return message_ptr; } @@ -300,7 +300,7 @@ TEST_F(ExtAuthzHttpClientTest, AuthorizationOk) { EXPECT_CALL(*child_span, setTag(Eq("ext_authz_status"), Eq("ext_authz_ok"))); EXPECT_CALL(*child_span, setTag(Eq("ext_authz_http_status"), Eq("OK"))); EXPECT_CALL(*child_span, finishSpan()); - client_->onSuccess(std::move(check_response)); + client_->onSuccess(async_request_, std::move(check_response)); } // Verify client response headers when authorization_headers_to_add is configured. @@ -329,7 +329,7 @@ TEST_F(ExtAuthzHttpClientTest, AuthorizationOkWithAddedAuthzHeaders) { EXPECT_CALL(*child_span, setTag(Eq("ext_authz_status"), Eq("ext_authz_ok"))); EXPECT_CALL(*child_span, setTag(Eq("ext_authz_http_status"), Eq("OK"))); EXPECT_CALL(*child_span, finishSpan()); - client_->onSuccess(std::move(check_response)); + client_->onSuccess(async_request_, std::move(check_response)); } // Verify client response headers when allow_upstream_headers is configured. @@ -363,7 +363,7 @@ TEST_F(ExtAuthzHttpClientTest, AuthorizationOkWithAllowHeader) { EXPECT_CALL(*child_span, setTag(Eq("ext_authz_http_status"), Eq("OK"))); EXPECT_CALL(*child_span, finishSpan()); auto message_response = TestCommon::makeMessageResponse(check_response_headers); - client_->onSuccess(std::move(message_response)); + client_->onSuccess(async_request_, std::move(message_response)); } // Test the client when a denied response is received. @@ -385,7 +385,7 @@ TEST_F(ExtAuthzHttpClientTest, AuthorizationDenied) { EXPECT_CALL(*child_span, finishSpan()); EXPECT_CALL(request_callbacks_, onComplete_(WhenDynamicCastTo(AuthzDeniedResponse(authz_response)))); - client_->onSuccess(TestCommon::makeMessageResponse(expected_headers)); + client_->onSuccess(async_request_, TestCommon::makeMessageResponse(expected_headers)); } // Verify client response headers and body when the authorization server denies the request. @@ -410,7 +410,8 @@ TEST_F(ExtAuthzHttpClientTest, AuthorizationDeniedWithAllAttributes) { EXPECT_CALL(*child_span, finishSpan()); EXPECT_CALL(request_callbacks_, onComplete_(WhenDynamicCastTo(AuthzDeniedResponse(authz_response)))); - client_->onSuccess(TestCommon::makeMessageResponse(expected_headers, expected_body)); + client_->onSuccess(async_request_, + TestCommon::makeMessageResponse(expected_headers, expected_body)); } // Verify client response headers when the authorization server denies the request and @@ -439,7 +440,8 @@ TEST_F(ExtAuthzHttpClientTest, AuthorizationDeniedAndAllowedClientHeaders) { {"x-foo", "bar", false}, {":status", "401", false}, {"foo", "bar", false}}); - client_->onSuccess(TestCommon::makeMessageResponse(check_response_headers, expected_body)); + client_->onSuccess(async_request_, + TestCommon::makeMessageResponse(check_response_headers, expected_body)); } // Test the client when an unknown error occurs. @@ -458,7 +460,7 @@ TEST_F(ExtAuthzHttpClientTest, AuthorizationRequestError) { onComplete_(WhenDynamicCastTo(AuthzErrorResponse(CheckStatus::Error)))); EXPECT_CALL(*child_span, setTag(Eq(Tracing::Tags::get().Error), Eq(Tracing::Tags::get().True))); EXPECT_CALL(*child_span, finishSpan()); - client_->onFailure(Http::AsyncClient::FailureReason::Reset); + client_->onFailure(async_request_, Http::AsyncClient::FailureReason::Reset); } // Test the client when a call to authorization server returns a 5xx error status. @@ -479,7 +481,7 @@ TEST_F(ExtAuthzHttpClientTest, AuthorizationRequest5xxError) { onComplete_(WhenDynamicCastTo(AuthzErrorResponse(CheckStatus::Error)))); EXPECT_CALL(*child_span, setTag(Eq("ext_authz_http_status"), Eq("Service Unavailable"))); EXPECT_CALL(*child_span, finishSpan()); - client_->onSuccess(std::move(check_response)); + client_->onSuccess(async_request_, std::move(check_response)); } // Test the client when a call to authorization server returns a status code that cannot be @@ -501,7 +503,7 @@ TEST_F(ExtAuthzHttpClientTest, AuthorizationRequestErrorParsingStatusCode) { onComplete_(WhenDynamicCastTo(AuthzErrorResponse(CheckStatus::Error)))); EXPECT_CALL(*child_span, setTag(Eq(Tracing::Tags::get().Error), Eq(Tracing::Tags::get().True))); EXPECT_CALL(*child_span, finishSpan()); - client_->onSuccess(std::move(check_response)); + client_->onSuccess(async_request_, std::move(check_response)); } // Test the client when the request is canceled. diff --git a/test/extensions/filters/http/common/mock.cc b/test/extensions/filters/http/common/mock.cc index 8244cbd51b1d7..45129c0edc5ca 100644 --- a/test/extensions/filters/http/common/mock.cc +++ b/test/extensions/filters/http/common/mock.cc @@ -21,7 +21,7 @@ MockUpstream::MockUpstream(Upstream::MockClusterManager& mock_cm, const std::str } else { response_message->body().reset(nullptr); } - cb.onSuccess(std::move(response_message)); + cb.onSuccess(request_, std::move(response_message)); return &request_; })); } @@ -33,7 +33,7 @@ MockUpstream::MockUpstream(Upstream::MockClusterManager& mock_cm, .WillByDefault(testing::Invoke( [this, reason](Http::RequestMessagePtr&, Http::AsyncClient::Callbacks& cb, const Http::AsyncClient::RequestOptions&) -> Http::AsyncClient::Request* { - cb.onFailure(reason); + cb.onFailure(request_, reason); return &request_; })); } diff --git a/test/extensions/filters/http/jwt_authn/mock.h b/test/extensions/filters/http/jwt_authn/mock.h index d0f8321a9d4d7..38ec192d19f6d 100644 --- a/test/extensions/filters/http/jwt_authn/mock.h +++ b/test/extensions/filters/http/jwt_authn/mock.h @@ -71,7 +71,7 @@ class MockUpstream { new Http::ResponseMessageImpl(Http::ResponseHeaderMapPtr{ new Http::TestResponseHeaderMapImpl{{":status", "200"}}})); response_message->body() = std::make_unique(response_body_); - cb.onSuccess(std::move(response_message)); + cb.onSuccess(request_, std::move(response_message)); called_count_++; return &request_; })); diff --git a/test/extensions/filters/http/lua/lua_filter_test.cc b/test/extensions/filters/http/lua/lua_filter_test.cc index a0fa4184d9152..8de472b120ab0 100644 --- a/test/extensions/filters/http/lua/lua_filter_test.cc +++ b/test/extensions/filters/http/lua/lua_filter_test.cc @@ -797,7 +797,7 @@ TEST_F(LuaHttpFilterTest, HttpCall) { EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq(":status 200"))); EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq("response"))); EXPECT_CALL(decoder_callbacks_, continueDecoding()); - callbacks->onSuccess(std::move(response_message)); + callbacks->onSuccess(request, std::move(response_message)); } // Basic HTTP request flow. Asynchronous flag set to false. @@ -860,7 +860,7 @@ TEST_F(LuaHttpFilterTest, HttpCallAsyncFalse) { EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq(":status 200"))); EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq("response"))); EXPECT_CALL(decoder_callbacks_, continueDecoding()); - callbacks->onSuccess(std::move(response_message)); + callbacks->onSuccess(request, std::move(response_message)); } // Basic asynchronous, fire-and-forget HTTP request flow. @@ -990,14 +990,14 @@ TEST_F(LuaHttpFilterTest, DoubleHttpCall) { callbacks = &cb; return &request; })); - callbacks->onSuccess(std::move(response_message)); + callbacks->onSuccess(request, std::move(response_message)); response_message = std::make_unique( Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "403"}}}); EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq(":status 403"))); EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq("no body"))); EXPECT_CALL(decoder_callbacks_, continueDecoding()); - callbacks->onSuccess(std::move(response_message)); + callbacks->onSuccess(request, std::move(response_message)); Buffer::OwnedImpl data("hello"); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data, false)); @@ -1061,7 +1061,7 @@ TEST_F(LuaHttpFilterTest, HttpCallNoBody) { EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq(":status 200"))); EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq("no body"))); EXPECT_CALL(decoder_callbacks_, continueDecoding()); - callbacks->onSuccess(std::move(response_message)); + callbacks->onSuccess(request, std::move(response_message)); } // HTTP call followed by immediate response. @@ -1114,7 +1114,7 @@ TEST_F(LuaHttpFilterTest, HttpCallImmediateResponse) { {"set-cookie", "flavor=chocolate; Path=/"}, {"set-cookie", "variant=chewy; Path=/"}}; EXPECT_CALL(decoder_callbacks_, encodeHeaders_(HeaderMapEqualRef(&expected_headers), true)); - callbacks->onSuccess(std::move(response_message)); + callbacks->onSuccess(request, std::move(response_message)); } // HTTP call with script error after resume. @@ -1162,7 +1162,7 @@ TEST_F(LuaHttpFilterTest, HttpCallErrorAfterResumeSuccess) { scriptLog(spdlog::level::err, StrEq("[string \"...\"]:14: attempt to index local 'foo' (a nil value)"))); EXPECT_CALL(decoder_callbacks_, continueDecoding()); - callbacks->onSuccess(std::move(response_message)); + callbacks->onSuccess(request, std::move(response_message)); } // HTTP call failure. @@ -1207,7 +1207,7 @@ TEST_F(LuaHttpFilterTest, HttpCallFailure) { EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq(":status 503"))); EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq("upstream failure"))); EXPECT_CALL(decoder_callbacks_, continueDecoding()); - callbacks->onFailure(Http::AsyncClient::FailureReason::Reset); + callbacks->onFailure(request, Http::AsyncClient::FailureReason::Reset); } // HTTP call reset. @@ -1283,7 +1283,9 @@ TEST_F(LuaHttpFilterTest, HttpCallImmediateFailure) { .WillOnce( Invoke([&](Http::RequestMessagePtr&, Http::AsyncClient::Callbacks& cb, const Http::AsyncClient::RequestOptions&) -> Http::AsyncClient::Request* { - cb.onFailure(Http::AsyncClient::FailureReason::Reset); + cb.onFailure(request, Http::AsyncClient::FailureReason::Reset); + // Intentionally return nullptr (instead of request handle) to trigger a particular + // code path. return nullptr; })); diff --git a/test/extensions/filters/http/squash/squash_filter_test.cc b/test/extensions/filters/http/squash/squash_filter_test.cc index c4769c97d8589..d2cb53bf52b63 100644 --- a/test/extensions/filters/http/squash/squash_filter_test.cc +++ b/test/extensions/filters/http/squash/squash_filter_test.cc @@ -223,7 +223,7 @@ class SquashFilterTest : public testing::Test { Http::ResponseMessagePtr msg(new Http::ResponseMessageImpl( Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", status}}})); msg->body() = std::make_unique(body); - popPendingCallback()->onSuccess(std::move(msg)); + popPendingCallback()->onSuccess(request_, std::move(msg)); } void completeCreateRequest() { @@ -265,7 +265,9 @@ TEST_F(SquashFilterTest, DecodeHeaderContinuesOnClientFail) { .WillOnce(Invoke( [&](Envoy::Http::RequestMessagePtr&, Envoy::Http::AsyncClient::Callbacks& callbacks, const Http::AsyncClient::RequestOptions&) -> Envoy::Http::AsyncClient::Request* { - callbacks.onFailure(Envoy::Http::AsyncClient::FailureReason::Reset); + callbacks.onFailure(request_, Envoy::Http::AsyncClient::FailureReason::Reset); + // Intentionally return nullptr (instead of request handle) to trigger a particular + // code path. return nullptr; })); @@ -286,7 +288,7 @@ TEST_F(SquashFilterTest, DecodeContinuesOnCreateAttachmentFail) { EXPECT_CALL(filter_callbacks_, continueDecoding()); EXPECT_CALL(*attachmentTimeout_timer_, disableTimer()); - popPendingCallback()->onFailure(Envoy::Http::AsyncClient::FailureReason::Reset); + popPendingCallback()->onFailure(request_, Envoy::Http::AsyncClient::FailureReason::Reset); Envoy::Buffer::OwnedImpl data("nothing here"); EXPECT_EQ(Envoy::Http::FilterDataStatus::Continue, filter_->decodeData(data, false)); @@ -365,7 +367,7 @@ TEST_F(SquashFilterTest, CheckRetryPollingAttachmentOnFailure) { auto retry_timer = new NiceMock(&filter_callbacks_.dispatcher_); EXPECT_CALL(*retry_timer, enableTimer(config_->attachmentPollPeriod(), _)); - popPendingCallback()->onFailure(Envoy::Http::AsyncClient::FailureReason::Reset); + popPendingCallback()->onFailure(request_, Envoy::Http::AsyncClient::FailureReason::Reset); // Expect the second get attachment request expectAsyncClientSend(); diff --git a/test/extensions/filters/network/client_ssl_auth/client_ssl_auth_test.cc b/test/extensions/filters/network/client_ssl_auth/client_ssl_auth_test.cc index b4b522f7c274d..75aa187a26ba9 100644 --- a/test/extensions/filters/network/client_ssl_auth/client_ssl_auth_test.cc +++ b/test/extensions/filters/network/client_ssl_auth/client_ssl_auth_test.cc @@ -173,7 +173,7 @@ TEST_F(ClientSslAuthFilterTest, Ssl) { message->body() = std::make_unique( api_->fileSystem().fileReadToEnd(TestEnvironment::runfilesPath( "test/extensions/filters/network/client_ssl_auth/test_data/vpn_response_1.json"))); - callbacks_->onSuccess(std::move(message)); + callbacks_->onSuccess(request_, std::move(message)); EXPECT_EQ(1U, stats_store_ .gauge("auth.clientssl.vpn.total_principals", Stats::Gauge::ImportMode::NeverImport) @@ -227,7 +227,7 @@ TEST_F(ClientSslAuthFilterTest, Ssl) { EXPECT_CALL(*interval_timer_, enableTimer(_, _)); message = std::make_unique( Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "503"}}}); - callbacks_->onSuccess(std::move(message)); + callbacks_->onSuccess(request_, std::move(message)); // Interval timer fires. setupRequest(); @@ -238,7 +238,7 @@ TEST_F(ClientSslAuthFilterTest, Ssl) { message = std::make_unique( Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "200"}}}); message->body() = std::make_unique("bad_json"); - callbacks_->onSuccess(std::move(message)); + callbacks_->onSuccess(request_, std::move(message)); // Interval timer fires. setupRequest(); @@ -246,7 +246,7 @@ TEST_F(ClientSslAuthFilterTest, Ssl) { // No response failure. EXPECT_CALL(*interval_timer_, enableTimer(_, _)); - callbacks_->onFailure(Http::AsyncClient::FailureReason::Reset); + callbacks_->onFailure(request_, Http::AsyncClient::FailureReason::Reset); // Interval timer fires, cannot obtain async client. EXPECT_CALL(cm_, httpAsyncClientForCluster("vpn")).WillOnce(ReturnRef(cm_.async_client_)); @@ -255,8 +255,11 @@ TEST_F(ClientSslAuthFilterTest, Ssl) { Invoke([&](Http::RequestMessagePtr&, Http::AsyncClient::Callbacks& callbacks, const Http::AsyncClient::RequestOptions&) -> Http::AsyncClient::Request* { callbacks.onSuccess( + request_, Http::ResponseMessagePtr{new Http::ResponseMessageImpl(Http::ResponseHeaderMapPtr{ new Http::TestResponseHeaderMapImpl{{":status", "503"}}})}); + // Intentionally return nullptr (instead of request handle) to trigger a particular + // code path. return nullptr; })); EXPECT_CALL(*interval_timer_, enableTimer(_, _)); diff --git a/test/extensions/tracers/datadog/datadog_tracer_impl_test.cc b/test/extensions/tracers/datadog/datadog_tracer_impl_test.cc index 659d9f424309e..0a0e28d2b6285 100644 --- a/test/extensions/tracers/datadog/datadog_tracer_impl_test.cc +++ b/test/extensions/tracers/datadog/datadog_tracer_impl_test.cc @@ -158,7 +158,7 @@ TEST_F(DatadogDriverTest, FlushSpansTimer) { Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "200"}}})); msg->body() = std::make_unique(""); - callback->onSuccess(std::move(msg)); + callback->onSuccess(request, std::move(msg)); EXPECT_EQ(1U, stats_.counter("tracing.datadog.reports_sent").value()); EXPECT_EQ(0U, stats_.counter("tracing.datadog.reports_dropped").value()); diff --git a/test/extensions/tracers/lightstep/lightstep_tracer_impl_test.cc b/test/extensions/tracers/lightstep/lightstep_tracer_impl_test.cc index 9a2e6b0130589..0234b1a579b27 100644 --- a/test/extensions/tracers/lightstep/lightstep_tracer_impl_test.cc +++ b/test/extensions/tracers/lightstep/lightstep_tracer_impl_test.cc @@ -229,7 +229,7 @@ TEST_F(LightStepDriverTest, FlushSeveralSpans) { start_time_, {Tracing::Reason::Sampling, true}); third_span->finishSpan(); - callback->onSuccess(makeSuccessResponse()); + callback->onSuccess(request, makeSuccessResponse()); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_ .counter("grpc.lightstep.collector.CollectorService.Report.success") @@ -277,7 +277,7 @@ TEST_F(LightStepDriverTest, FlushOneFailure) { second_span->finishSpan(); - callback->onFailure(Http::AsyncClient::FailureReason::Reset); + callback->onFailure(request, Http::AsyncClient::FailureReason::Reset); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_ .counter("grpc.lightstep.collector.CollectorService.Report.failure") @@ -411,7 +411,7 @@ TEST_F(LightStepDriverTest, FlushSpansTimer) { timer_->invokeCallback(); - callback->onSuccess(makeSuccessResponse()); + callback->onSuccess(request, makeSuccessResponse()); EXPECT_EQ(1U, stats_.counter("tracing.lightstep.timer_flushed").value()); EXPECT_EQ(1U, stats_.counter("tracing.lightstep.spans_sent").value()); diff --git a/test/extensions/tracers/zipkin/zipkin_tracer_impl_test.cc b/test/extensions/tracers/zipkin/zipkin_tracer_impl_test.cc index a898fbb2e7e99..7356c897b736d 100644 --- a/test/extensions/tracers/zipkin/zipkin_tracer_impl_test.cc +++ b/test/extensions/tracers/zipkin/zipkin_tracer_impl_test.cc @@ -111,14 +111,14 @@ class ZipkinDriverTest : public testing::Test { Http::ResponseMessagePtr msg(new Http::ResponseMessageImpl( Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "202"}}})); - callback->onSuccess(std::move(msg)); + callback->onSuccess(request, std::move(msg)); EXPECT_EQ(2U, stats_.counter("tracing.zipkin.spans_sent").value()); EXPECT_EQ(1U, stats_.counter("tracing.zipkin.reports_sent").value()); EXPECT_EQ(0U, stats_.counter("tracing.zipkin.reports_dropped").value()); EXPECT_EQ(0U, stats_.counter("tracing.zipkin.reports_failed").value()); - callback->onFailure(Http::AsyncClient::FailureReason::Reset); + callback->onFailure(request, Http::AsyncClient::FailureReason::Reset); EXPECT_EQ(1U, stats_.counter("tracing.zipkin.reports_failed").value()); } @@ -236,7 +236,7 @@ TEST_F(ZipkinDriverTest, FlushOneSpanReportFailure) { Http::ResponseHeaderMapPtr{new Http::TestResponseHeaderMapImpl{{":status", "404"}}})); // AsyncClient can fail with valid HTTP headers - callback->onSuccess(std::move(msg)); + callback->onSuccess(request, std::move(msg)); EXPECT_EQ(1U, stats_.counter("tracing.zipkin.spans_sent").value()); EXPECT_EQ(0U, stats_.counter("tracing.zipkin.reports_sent").value()); diff --git a/test/mocks/http/mocks.h b/test/mocks/http/mocks.h index cc4fdbd2be467..5d030de2080e8 100644 --- a/test/mocks/http/mocks.h +++ b/test/mocks/http/mocks.h @@ -334,11 +334,15 @@ class MockAsyncClientCallbacks : public AsyncClient::Callbacks { MockAsyncClientCallbacks(); ~MockAsyncClientCallbacks() override; - void onSuccess(ResponseMessagePtr&& response) override { onSuccess_(response.get()); } + void onSuccess(const Http::AsyncClient::Request& request, + ResponseMessagePtr&& response) override { + onSuccess_(request, response.get()); + } // Http::AsyncClient::Callbacks - MOCK_METHOD(void, onSuccess_, (ResponseMessage * response)); - MOCK_METHOD(void, onFailure, (Http::AsyncClient::FailureReason reason)); + MOCK_METHOD(void, onSuccess_, (const Http::AsyncClient::Request&, ResponseMessage*)); + MOCK_METHOD(void, onFailure, + (const Http::AsyncClient::Request&, Http::AsyncClient::FailureReason)); }; class MockAsyncClientStreamCallbacks : public AsyncClient::StreamCallbacks {