diff --git a/include/envoy/http/async_client.h b/include/envoy/http/async_client.h index b2d559596c12..59d6fef6c301 100644 --- a/include/envoy/http/async_client.h +++ b/include/envoy/http/async_client.h @@ -52,19 +52,18 @@ class AsyncClient { virtual void cancel() PURE; }; - typedef std::unique_ptr RequestPtr; - virtual ~AsyncClient() {} /** * Send an HTTP request asynchronously - * @param request the request to send - * @param callbacks the callbacks to be notified of request status + * @param request the request to send. + * @param callbacks the callbacks to be notified of request status. * @return a request handle or nullptr if no request could be created. NOTE: In this case - * onFailure() has already been called inline. + * onFailure() has already been called inline. The client owns the request and the + * handle should just be used to cancel. */ - virtual RequestPtr send(MessagePtr&& request, Callbacks& callbacks, - const Optional& timeout) PURE; + virtual Request* send(MessagePtr&& request, Callbacks& callbacks, + const Optional& timeout) PURE; }; typedef std::unique_ptr AsyncClientPtr; diff --git a/include/envoy/upstream/cluster_manager.h b/include/envoy/upstream/cluster_manager.h index ae03f2a30f13..002c025fa0cb 100644 --- a/include/envoy/upstream/cluster_manager.h +++ b/include/envoy/upstream/cluster_manager.h @@ -31,18 +31,12 @@ class ClusterManager { */ virtual const Cluster* get(const std::string& cluster) PURE; - /** - * @return whether the cluster manager knows about a particular cluster by name. - */ - virtual bool has(const std::string& cluster) PURE; - /** * Allocate a load balanced HTTP connection pool for a cluster. This is *per-thread* so that * callers do not need to worry about per thread synchronization. The load balancing policy that * is used is the one defined on the cluster when it was created. * - * Can return nullptr if there is no host available in the cluster or the cluster name is not - * valid. + * Can return nullptr if there is no host available in the cluster. */ virtual Http::ConnectionPool::Instance* httpConnPoolForCluster(const std::string& cluster) PURE; @@ -52,15 +46,16 @@ class ClusterManager { * load balancing policy that is used is the one defined on the cluster when it was created. * * Returns both a connection and the host that backs the connection. Both can be nullptr if there - * is no host available in the cluster or the cluster name is not valid. + * is no host available in the cluster. */ virtual Host::CreateConnectionData tcpConnForCluster(const std::string& cluster) PURE; /** - * Returns a client that can be used to make async HTTP calls against the given cluster. The - * client may be backed by a connection pool or by a multiplexed connection. + * Returns a client that can be used to make async HTTP calls against the given cluster. The + * client may be backed by a connection pool or by a multiplexed connection. The cluster manager + * owns the client. */ - virtual Http::AsyncClientPtr httpAsyncClientForCluster(const std::string& cluster) PURE; + virtual Http::AsyncClient& httpAsyncClientForCluster(const std::string& cluster) PURE; /** * Shutdown the cluster prior to destroying connection pools and other thread local data. diff --git a/include/envoy/upstream/upstream.h b/include/envoy/upstream/upstream.h index 7a89bc26863f..75d5cc3d2b37 100644 --- a/include/envoy/upstream/upstream.h +++ b/include/envoy/upstream/upstream.h @@ -238,8 +238,7 @@ class Cluster : public virtual HostSet { virtual ResourceManager& resourceManager() const PURE; /** - * Shutdown the cluster manager prior to destroying connection pools and other thread local - * data. + * Shutdown the cluster prior to destroying connection pools and other thread local data. */ virtual void shutdown() PURE; diff --git a/source/common/common/linked_object.h b/source/common/common/linked_object.h index 127ffe431c04..c99ed7eadb99 100644 --- a/source/common/common/linked_object.h +++ b/source/common/common/linked_object.h @@ -18,6 +18,11 @@ template class LinkedObject { return entry_; } + /** + * @return whether the object is currently inserted into a list. + */ + bool inserted() { return inserted_; } + /** * Move a linked item between 2 lists. * @param list1 supplies the first list. diff --git a/source/common/filter/auth/client_ssl.cc b/source/common/filter/auth/client_ssl.cc index 5bce5fa01e8a..edce72ad5fcf 100644 --- a/source/common/filter/auth/client_ssl.cc +++ b/source/common/filter/auth/client_ssl.cc @@ -21,7 +21,7 @@ Config::Config(const Json::Object& config, ThreadLocal::Instance& tls, Upstream: ip_white_list_(config), stats_(generateStats(stats_store, config.getString("stat_prefix"))), runtime_(runtime), local_address_(local_address) { - if (!cm_.has(auth_api_cluster_)) { + if (!cm_.get(auth_api_cluster_)) { throw EnvoyException( fmt::format("unknown cluster '{}' in client ssl auth config", auth_api_cluster_)); } @@ -83,29 +83,19 @@ void Config::onFailure(Http::AsyncClient::FailureReason) { } void Config::refreshPrincipals() { - ASSERT(!active_request_); - active_request_.reset(new ActiveRequest()); - active_request_->client_ = cm_.httpAsyncClientForCluster(auth_api_cluster_); - if (!active_request_->client_) { - onFailure(Http::AsyncClient::FailureReason::Reset); - return; - } - Http::MessagePtr message(new Http::RequestMessageImpl()); message->headers().addViaMoveValue(Http::Headers::get().Scheme, "http"); message->headers().addViaMoveValue(Http::Headers::get().Method, "GET"); message->headers().addViaMoveValue(Http::Headers::get().Path, "/v1/certs/list/approved"); message->headers().addViaCopy(Http::Headers::get().Host, auth_api_cluster_); message->headers().addViaCopy(Http::Headers::get().ForwardedFor, local_address_); - active_request_->request_ = active_request_->client_->send(std::move(message), *this, - Optional()); + cm_.httpAsyncClientForCluster(auth_api_cluster_) + .send(std::move(message), *this, Optional()); } void Config::requestComplete() { std::chrono::milliseconds interval( runtime_.snapshot().getInteger("auth.clientssl.refresh_interval_ms", 60000)); - - active_request_.reset(); interval_timer_->enableTimer(interval); } diff --git a/source/common/filter/auth/client_ssl.h b/source/common/filter/auth/client_ssl.h index 9ad2966034f2..5d83558ab85c 100644 --- a/source/common/filter/auth/client_ssl.h +++ b/source/common/filter/auth/client_ssl.h @@ -78,13 +78,6 @@ class Config : public Http::AsyncClient::Callbacks { void onFailure(Http::AsyncClient::FailureReason reason) override; private: - struct ActiveRequest { - Http::AsyncClientPtr client_; - Http::AsyncClient::RequestPtr request_; - }; - - typedef std::unique_ptr ActiveRequestPtr; - static GlobalStats generateStats(Stats::Store& store, const std::string& prefix); AllowedPrincipalsPtr parseAuthResponse(Http::Message& message); void refreshPrincipals(); @@ -94,7 +87,6 @@ class Config : public Http::AsyncClient::Callbacks { uint32_t tls_slot_; Upstream::ClusterManager& cm_; const std::string auth_api_cluster_; - ActiveRequestPtr active_request_; Event::TimerPtr interval_timer_; Network::IpWhiteList ip_white_list_; GlobalStats stats_; diff --git a/source/common/filter/tcp_proxy.cc b/source/common/filter/tcp_proxy.cc index 7bd914df4c58..0d83d437db43 100644 --- a/source/common/filter/tcp_proxy.cc +++ b/source/common/filter/tcp_proxy.cc @@ -16,7 +16,7 @@ TcpProxyConfig::TcpProxyConfig(const Json::Object& config, Upstream::ClusterManager& cluster_manager, Stats::Store& stats_store) : cluster_name_(config.getString("cluster")), stats_(generateStats(config.getString("stat_prefix"), stats_store)) { - if (!cluster_manager.has(cluster_name_)) { + if (!cluster_manager.get(cluster_name_)) { throw EnvoyException(fmt::format("tcp proxy: unknown cluster '{}'", cluster_name_)); } } diff --git a/source/common/grpc/rpc_channel_impl.cc b/source/common/grpc/rpc_channel_impl.cc index e30f91855e28..99f1bb418bf3 100644 --- a/source/common/grpc/rpc_channel_impl.cc +++ b/source/common/grpc/rpc_channel_impl.cc @@ -30,12 +30,6 @@ void RpcChannelImpl::CallMethod(const proto::MethodDescriptor* method, proto::Rp // here for clarity. ASSERT(cm_.get(cluster_)->features() & Upstream::Cluster::Features::HTTP2); - client_ = cm_.httpAsyncClientForCluster(cluster_); - if (!client_) { - onFailureWorker(Optional(), "http request failure"); - return; - } - Http::MessagePtr message(new Http::RequestMessageImpl()); message->headers().addViaMoveValue(Http::Headers::get().Scheme, "http"); message->headers().addViaMoveValue(Http::Headers::get().Method, "POST"); @@ -46,10 +40,7 @@ void RpcChannelImpl::CallMethod(const proto::MethodDescriptor* method, proto::Rp message->headers().addViaCopy(Http::Headers::get().ContentType, Common::GRPC_CONTENT_TYPE); message->body(serializeBody(*grpc_request)); - http_request_ = client_->send(std::move(message), *this, timeout_); - if (!http_request_) { - onFailureWorker(Optional(), "http request failure"); - } + http_request_ = cm_.httpAsyncClientForCluster(cluster_).send(std::move(message), *this, timeout_); } void RpcChannelImpl::incStat(bool success) { diff --git a/source/common/grpc/rpc_channel_impl.h b/source/common/grpc/rpc_channel_impl.h index 194ca46745f0..72b74e0660a6 100644 --- a/source/common/grpc/rpc_channel_impl.h +++ b/source/common/grpc/rpc_channel_impl.h @@ -62,8 +62,7 @@ class RpcChannelImpl : public RpcChannel, public Http::AsyncClient::Callbacks { Upstream::ClusterManager& cm_; const std::string cluster_; - Http::AsyncClientPtr client_; - Http::AsyncClient::RequestPtr http_request_; + Http::AsyncClient::Request* http_request_{}; const proto::MethodDescriptor* grpc_method_{}; proto::Message* grpc_response_{}; RpcChannelCallbacks& callbacks_; diff --git a/source/common/http/async_client_impl.cc b/source/common/http/async_client_impl.cc index 5d2911b4f6c5..b1e51a2a3221 100644 --- a/source/common/http/async_client_impl.cc +++ b/source/common/http/async_client_impl.cc @@ -7,26 +7,35 @@ namespace Http { -const Http::HeaderMapImpl AsyncRequestImpl::SERVICE_UNAVAILABLE_HEADER{ - {Http::Headers::get().Status, std::to_string(enumToInt(Http::Code::ServiceUnavailable))}}; +const HeaderMapImpl AsyncRequestImpl::SERVICE_UNAVAILABLE_HEADER{ + {Headers::get().Status, std::to_string(enumToInt(Code::ServiceUnavailable))}}; -const Http::HeaderMapImpl AsyncRequestImpl::REQUEST_TIMEOUT_HEADER{ - {Http::Headers::get().Status, std::to_string(enumToInt(Http::Code::GatewayTimeout))}}; +const HeaderMapImpl AsyncRequestImpl::REQUEST_TIMEOUT_HEADER{ + {Headers::get().Status, std::to_string(enumToInt(Code::GatewayTimeout))}}; -AsyncClientImpl::AsyncClientImpl(ConnectionPool::Instance& conn_pool, const std::string& cluster, - Stats::Store& stats_store, Event::Dispatcher& dispatcher) - : conn_pool_(conn_pool), stat_prefix_(fmt::format("cluster.{}.", cluster)), - stats_store_(stats_store), dispatcher_(dispatcher) {} +AsyncClientImpl::AsyncClientImpl(const Upstream::Cluster& cluster, + AsyncClientConnPoolFactory& factory, Stats::Store& stats_store, + Event::Dispatcher& dispatcher) + : cluster_(cluster), factory_(factory), stats_store_(stats_store), dispatcher_(dispatcher), + stat_prefix_(fmt::format("cluster.{}.", cluster.name())) {} + +AsyncClientImpl::~AsyncClientImpl() { ASSERT(active_requests_.empty()); } + +AsyncClient::Request* AsyncClientImpl::send(MessagePtr&& request, AsyncClient::Callbacks& callbacks, + const Optional& timeout) { + ConnectionPool::Instance* conn_pool = factory_.connPool(); + if (!conn_pool) { + callbacks.onFailure(AsyncClient::FailureReason::Reset); + return nullptr; + } -AsyncClient::RequestPtr AsyncClientImpl::send(MessagePtr&& request, - AsyncClient::Callbacks& callbacks, - const Optional& timeout) { std::unique_ptr new_request{ - new AsyncRequestImpl(std::move(request), *this, callbacks, dispatcher_, timeout)}; + new AsyncRequestImpl(std::move(request), *this, callbacks, dispatcher_, *conn_pool, timeout)}; // The request may get immediately failed. If so, we will return nullptr. if (new_request->stream_encoder_) { - return std::move(new_request); + new_request->moveIntoList(std::move(new_request), active_requests_); + return active_requests_.front().get(); } else { return nullptr; } @@ -34,10 +43,11 @@ AsyncClient::RequestPtr AsyncClientImpl::send(MessagePtr&& request, AsyncRequestImpl::AsyncRequestImpl(MessagePtr&& request, AsyncClientImpl& parent, AsyncClient::Callbacks& callbacks, Event::Dispatcher& dispatcher, + ConnectionPool::Instance& conn_pool, const Optional& timeout) : request_(std::move(request)), parent_(parent), callbacks_(callbacks) { - stream_encoder_.reset(new PooledStreamEncoder(parent_.conn_pool_, *this, *this, 0, 0, *this)); + stream_encoder_.reset(new PooledStreamEncoder(conn_pool, *this, *this, 0, 0, *this)); stream_encoder_->encodeHeaders(request_->headers(), !request_->body()); // We might have been immediately failed. @@ -66,9 +76,9 @@ void AsyncRequestImpl::decodeHeaders(HeaderMapPtr&& headers, bool end_stream) { -> void { log_debug(" '{}':'{}'", key.get(), value); }); #endif - Http::CodeUtility::ResponseStatInfo info{parent_.stats_store_, parent_.stat_prefix_, - response_->headers(), true, EMPTY_STRING, EMPTY_STRING}; - Http::CodeUtility::chargeResponseStat(info); + CodeUtility::ResponseStatInfo info{parent_.stats_store_, parent_.stat_prefix_, + response_->headers(), true, EMPTY_STRING, EMPTY_STRING}; + CodeUtility::chargeResponseStat(info); if (end_stream) { onComplete(); @@ -102,34 +112,32 @@ void AsyncRequestImpl::decodeTrailers(HeaderMapPtr&& trailers) { void AsyncRequestImpl::onComplete() { // TODO: Check host's canary status in addition to canary header. - Http::CodeUtility::ResponseTimingInfo info{ + CodeUtility::ResponseTimingInfo info{ parent_.stats_store_, parent_.stat_prefix_, stream_encoder_->requestCompleteTime(), - response_->headers().get(Http::Headers::get().EnvoyUpstreamCanary) == "true", true, - EMPTY_STRING, EMPTY_STRING}; - Http::CodeUtility::chargeResponseTiming(info); + response_->headers().get(Headers::get().EnvoyUpstreamCanary) == "true", true, EMPTY_STRING, + EMPTY_STRING}; + CodeUtility::chargeResponseTiming(info); - cleanup(); callbacks_.onSuccess(std::move(response_)); + cleanup(); } void AsyncRequestImpl::onResetStream(StreamResetReason) { - Http::CodeUtility::ResponseStatInfo info{parent_.stats_store_, parent_.stat_prefix_, - SERVICE_UNAVAILABLE_HEADER, true, EMPTY_STRING, - EMPTY_STRING}; - Http::CodeUtility::chargeResponseStat(info); + CodeUtility::ResponseStatInfo info{parent_.stats_store_, parent_.stat_prefix_, + SERVICE_UNAVAILABLE_HEADER, true, EMPTY_STRING, EMPTY_STRING}; + CodeUtility::chargeResponseStat(info); + callbacks_.onFailure(AsyncClient::FailureReason::Reset); cleanup(); - callbacks_.onFailure(Http::AsyncClient::FailureReason::Reset); } void AsyncRequestImpl::onRequestTimeout() { - Http::CodeUtility::ResponseStatInfo info{parent_.stats_store_, parent_.stat_prefix_, - REQUEST_TIMEOUT_HEADER, true, EMPTY_STRING, - EMPTY_STRING}; - Http::CodeUtility::chargeResponseStat(info); - parent_.stats_store_.counter(fmt::format("{}upstream_rq_timeout", parent_.stat_prefix_)).inc(); + CodeUtility::ResponseStatInfo info{parent_.stats_store_, parent_.stat_prefix_, + REQUEST_TIMEOUT_HEADER, true, EMPTY_STRING, EMPTY_STRING}; + CodeUtility::chargeResponseStat(info); + parent_.cluster_.stats().upstream_rq_timeout_.inc(); stream_encoder_->resetStream(); + callbacks_.onFailure(AsyncClient::FailureReason::RequestTimemout); cleanup(); - callbacks_.onFailure(Http::AsyncClient::FailureReason::RequestTimemout); } void AsyncRequestImpl::cleanup() { @@ -137,5 +145,12 @@ void AsyncRequestImpl::cleanup() { if (request_timeout_) { request_timeout_->disableTimer(); } + + // This will destroy us, but only do so if we are actually in a list. This does not happen in + // the immediate failure case. + if (inserted()) { + removeFromList(parent_.active_requests_); + } } + } // Http diff --git a/source/common/http/async_client_impl.h b/source/common/http/async_client_impl.h index ad42a8d63d30..afc1b693afda 100644 --- a/source/common/http/async_client_impl.h +++ b/source/common/http/async_client_impl.h @@ -10,24 +10,43 @@ #include "envoy/http/message.h" #include "common/common/assert.h" +#include "common/common/linked_object.h" #include "common/http/header_map_impl.h" namespace Http { +/** + * Factory for obtaining a connection pool. + */ +class AsyncClientConnPoolFactory { +public: + virtual ~AsyncClientConnPoolFactory() {} + + /** + * Return a connection pool or nullptr if there is no healthy upstream host. + */ + virtual ConnectionPool::Instance* connPool() PURE; +}; + +class AsyncRequestImpl; + class AsyncClientImpl final : public AsyncClient { public: - AsyncClientImpl(ConnectionPool::Instance& conn_pool, const std::string& cluster, + AsyncClientImpl(const Upstream::Cluster& cluster, AsyncClientConnPoolFactory& factory, Stats::Store& stats_store, Event::Dispatcher& dispatcher); + ~AsyncClientImpl(); // Http::AsyncClient - RequestPtr send(MessagePtr&& request, Callbacks& callbacks, - const Optional& timeout) override; + Request* send(MessagePtr&& request, Callbacks& callbacks, + const Optional& timeout) override; private: - ConnectionPool::Instance& conn_pool_; - const std::string stat_prefix_; + const Upstream::Cluster& cluster_; + AsyncClientConnPoolFactory& factory_; Stats::Store& stats_store_; Event::Dispatcher& dispatcher_; + const std::string stat_prefix_; + std::list> active_requests_; friend class AsyncRequestImpl; }; @@ -40,10 +59,11 @@ class AsyncRequestImpl final : public AsyncClient::Request, StreamDecoder, StreamCallbacks, PooledStreamEncoderCallbacks, - Logger::Loggable { + Logger::Loggable, + LinkedObject { public: AsyncRequestImpl(MessagePtr&& request, AsyncClientImpl& parent, AsyncClient::Callbacks& callbacks, - Event::Dispatcher& dispatcher, + Event::Dispatcher& dispatcher, ConnectionPool::Instance& conn_pool, const Optional& timeout); ~AsyncRequestImpl(); @@ -74,8 +94,8 @@ class AsyncRequestImpl final : public AsyncClient::Request, std::unique_ptr response_; PooledStreamEncoderPtr stream_encoder_; - static const Http::HeaderMapImpl SERVICE_UNAVAILABLE_HEADER; - static const Http::HeaderMapImpl REQUEST_TIMEOUT_HEADER; + static const HeaderMapImpl SERVICE_UNAVAILABLE_HEADER; + static const HeaderMapImpl REQUEST_TIMEOUT_HEADER; friend class AsyncClientImpl; }; diff --git a/source/common/ratelimit/ratelimit_impl.cc b/source/common/ratelimit/ratelimit_impl.cc index f20aa9416a42..48733103c3d7 100644 --- a/source/common/ratelimit/ratelimit_impl.cc +++ b/source/common/ratelimit/ratelimit_impl.cc @@ -60,7 +60,7 @@ void GrpcClientImpl::onFailure(const Optional&, const std::string&) { GrpcFactoryImpl::GrpcFactoryImpl(const Json::Object& config, Upstream::ClusterManager& cm, Stats::Store& stats_store) : cluster_name_(config.getString("cluster_name")), cm_(cm), stats_store_(stats_store) { - if (!cm_.has(cluster_name_)) { + if (!cm_.get(cluster_name_)) { throw EnvoyException(fmt::format("unknown rate limit service cluster '{}'", cluster_name_)); } } diff --git a/source/common/router/config_impl.cc b/source/common/router/config_impl.cc index c3443ff64c61..721eb720040e 100644 --- a/source/common/router/config_impl.cc +++ b/source/common/router/config_impl.cc @@ -194,7 +194,7 @@ VirtualHost::VirtualHost(const Json::Object& virtual_host, Runtime::Loader& runt } if (!routes_.back()->isRedirect()) { - if (!cm.has(routes_.back()->clusterName())) { + if (!cm.get(routes_.back()->clusterName())) { throw EnvoyException( fmt::format("route: unknown cluster '{}'", routes_.back()->clusterName())); } diff --git a/source/common/stats/statsd.cc b/source/common/stats/statsd.cc index e6c054eafeb8..45536f1f5325 100644 --- a/source/common/stats/statsd.cc +++ b/source/common/stats/statsd.cc @@ -60,7 +60,7 @@ TcpStatsdSink::TcpStatsdSink(const std::string& stat_cluster, const std::string& : stat_cluster_(stat_cluster), stat_host_(stat_host), cluster_name_(cluster_name), tls_(tls), tls_slot_(tls.allocateSlot()), cluster_manager_(cluster_manager) { - if (!cluster_manager.has(cluster_name)) { + if (!cluster_manager.get(cluster_name)) { throw EnvoyException(fmt::format("unknown TCP statsd upstream cluster: {}", cluster_name)); } diff --git a/source/common/tracing/http_tracer_impl.cc b/source/common/tracing/http_tracer_impl.cc index ba86a684fadb..a95f1eab2841 100644 --- a/source/common/tracing/http_tracer_impl.cc +++ b/source/common/tracing/http_tracer_impl.cc @@ -258,90 +258,39 @@ std::string LightStepUtility::buildJsonBody(const Http::HeaderMap& request_heade } LightStepSink::LightStepSink(const Json::Object& config, Upstream::ClusterManager& cluster_manager, - ThreadLocal::Instance& tls, const std::string& stat_prefix, - Stats::Store& stats, Runtime::RandomGenerator& random, + const std::string& stat_prefix, Stats::Store& stats, + Runtime::RandomGenerator& random, const std::string& local_service_cluster, const std::string& service_node, const std::string& access_token) - : cm_(cluster_manager), tls_(tls), tls_slot_(tls.allocateSlot()) { - collector_cluster_ = config.getString("collector_cluster"); - if (!cm_.has(collector_cluster_)) { - throw EnvoyException(fmt::format("{} collector cluster is not defined on cluster manager level", - collector_cluster_)); - } - - tls.set(tls_slot_, - [this, stat_prefix, &stats, &random, local_service_cluster, service_node, access_token]( - Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectPtr { - return ThreadLocal::ThreadLocalObjectPtr{new TlsSink(*this, stat_prefix, stats, random, - local_service_cluster, - service_node, access_token)}; - }); -} - -LightStepSink::TlsSink::TlsSink(LightStepSink& parent, const std::string& stat_prefix, - Stats::Store& stats, Runtime::RandomGenerator& random, - const std::string& service_cluster, const std::string& service_node, - const std::string& access_token) - : parent_(parent), + : collector_cluster_(config.getString("collector_cluster")), cm_(cluster_manager), stats_{LIGHTSTEP_STATS(POOL_COUNTER_PREFIX(stats, stat_prefix + "tracing.lightstep."))}, - random_(random), local_service_cluster_(service_cluster), service_node_(service_node), + random_(random), local_service_cluster_(local_service_cluster), service_node_(service_node), access_token_(access_token) { - shutdown_ = false; -} - -void LightStepSink::TlsSink::shutdown() { - shutdown_ = true; - - for (auto& active_request : active_requests_) { - active_request->request_->cancel(); + if (!cm_.get(collector_cluster_)) { + throw EnvoyException(fmt::format("{} collector cluster is not defined on cluster manager level", + collector_cluster_)); } } -void LightStepSink::TlsSink::flushTrace(const Http::HeaderMap& request_headers, - const Http::HeaderMap& response_headers, - const Http::AccessLog::RequestInfo& request_info) { - if (shutdown_) { - return; - } - - Http::AsyncClientPtr client = parent_.cm_.httpAsyncClientForCluster(parent_.collector_cluster_); - - if (!client) { - stats_.client_failed_.inc(); - return; - } - +void LightStepSink::flushTrace(const Http::HeaderMap& request_headers, + const Http::HeaderMap& response_headers, + const Http::AccessLog::RequestInfo& request_info) { Http::MessagePtr msg = LightStepUtility::buildHeaders(access_token_); Buffer::InstancePtr buffer(new Buffer::OwnedImpl( LightStepUtility::buildJsonBody(request_headers, response_headers, request_info, random_, local_service_cluster_, service_node_))); msg->body(std::move(buffer)); - - executeRequest(std::move(client), std::move(msg)); + executeRequest(std::move(msg)); } -void LightStepSink::TlsSink::executeRequest(Http::AsyncClientPtr&& client, Http::MessagePtr&& msg) { - ActiveRequestPtr active_request(new LightStepSink::ActiveRequest(*this)); - Http::AsyncClient::RequestPtr request = - client->send(std::move(msg), *active_request, std::chrono::milliseconds(5000)); - if (request) { - active_request->request_ = std::move(request); - active_request->client_ = std::move(client); - active_request->moveIntoListBack(std::move(active_request), active_requests_); - } +void LightStepSink::executeRequest(Http::MessagePtr&& msg) { + cm_.httpAsyncClientForCluster(collector_cluster_) + .send(std::move(msg), *this, std::chrono::milliseconds(5000)); } -void LightStepSink::ActiveRequest::onFailure(Http::AsyncClient::FailureReason) { - parent_.stats_.collector_failed_.inc(); - clean(); -} - -void LightStepSink::ActiveRequest::onSuccess(Http::MessagePtr&&) { - parent_.stats_.collector_success_.inc(); - clean(); -} +void LightStepSink::onFailure(Http::AsyncClient::FailureReason) { stats_.collector_failed_.inc(); } -void LightStepSink::ActiveRequest::clean() { removeFromList(parent_.active_requests_); } +void LightStepSink::onSuccess(Http::MessagePtr&&) { stats_.collector_success_.inc(); } } // Tracing \ No newline at end of file diff --git a/source/common/tracing/http_tracer_impl.h b/source/common/tracing/http_tracer_impl.h index 9a3135032594..ca281e09e825 100644 --- a/source/common/tracing/http_tracer_impl.h +++ b/source/common/tracing/http_tracer_impl.h @@ -1,11 +1,9 @@ #pragma once #include "envoy/runtime/runtime.h" -#include "envoy/thread_local/thread_local.h" #include "envoy/tracing/http_tracer.h" #include "envoy/upstream/cluster_manager.h" -#include "common/common/linked_object.h" #include "common/http/header_map_impl.h" #include "common/json/json_loader.h" @@ -13,8 +11,7 @@ namespace Tracing { #define LIGHTSTEP_STATS(COUNTER) \ COUNTER(collector_failed) \ - COUNTER(collector_success) \ - COUNTER(client_failed) + COUNTER(collector_success) struct LightStepStats { LIGHTSTEP_STATS(GENERATE_COUNTER_STRUCT) @@ -136,64 +133,31 @@ class LightStepUtility { * * LightStepSink is for flushing data to LightStep collectors. */ -class LightStepSink : public HttpSink { +class LightStepSink : public HttpSink, public Http::AsyncClient::Callbacks { public: LightStepSink(const Json::Object& config, Upstream::ClusterManager& cluster_manager, - ThreadLocal::Instance& tls, const std::string& stat_prefix, Stats::Store& stats, + const std::string& stat_prefix, Stats::Store& stats, Runtime::RandomGenerator& random, const std::string& local_service_cluster, const std::string& service_node, const std::string& access_token); // Tracer::HttpSink void flushTrace(const Http::HeaderMap& request_headers, const Http::HeaderMap& response_headers, - const Http::AccessLog::RequestInfo& request_info) override { - tls_.getTyped(tls_slot_).flushTrace(request_headers, response_headers, request_info); - } + const Http::AccessLog::RequestInfo& request_info) override; + + // Http::AsyncClient::Callbacks + void onSuccess(Http::MessagePtr&&) override; + void onFailure(Http::AsyncClient::FailureReason reason) override; private: - struct ActiveRequest; - typedef std::unique_ptr ActiveRequestPtr; - - struct TlsSink : public ThreadLocal::ThreadLocalObject { - TlsSink(LightStepSink& parent, const std::string& stat_prefix, Stats::Store& stats, - Runtime::RandomGenerator& random, const std::string& service_cluster, - const std::string& service_node, const std::string& access_token); - ~TlsSink() {} - - void flushTrace(const Http::HeaderMap& request_headers, const Http::HeaderMap& response_headers, - const Http::AccessLog::RequestInfo& request_info); - void executeRequest(Http::AsyncClientPtr&& client, Http::MessagePtr&& msg); - - // ThreadLocal::ThreadLocalObject - void shutdown() override; - - LightStepSink& parent_; - bool shutdown_{}; - LightStepStats stats_; - Runtime::RandomGenerator& random_; - std::string local_service_cluster_; - std::string service_node_; - std::string access_token_; - std::list active_requests_; - }; - - struct ActiveRequest : public Http::AsyncClient::Callbacks, LinkedObject { - ActiveRequest(TlsSink& parent) : parent_(parent) {} - - // Http::AsyncClient::Callbacks - void onSuccess(Http::MessagePtr&&) override; - void onFailure(Http::AsyncClient::FailureReason reason) override; - - void clean(); - - TlsSink& parent_; - Http::AsyncClientPtr client_; - Http::AsyncClient::RequestPtr request_; - }; - - std::string collector_cluster_; + void executeRequest(Http::MessagePtr&& msg); + + const std::string collector_cluster_; Upstream::ClusterManager& cm_; - ThreadLocal::Instance& tls_; - const uint32_t tls_slot_; + LightStepStats stats_; + Runtime::RandomGenerator& random_; + const std::string local_service_cluster_; + const std::string service_node_; + const std::string access_token_; }; } // Tracing \ No newline at end of file diff --git a/source/common/upstream/cluster_manager_impl.cc b/source/common/upstream/cluster_manager_impl.cc index 2f2badb83a0c..8b78bd22ed59 100644 --- a/source/common/upstream/cluster_manager_impl.cc +++ b/source/common/upstream/cluster_manager_impl.cc @@ -145,19 +145,11 @@ ClusterManagerImpl::httpConnPoolForCluster(const std::string& cluster) { // Select a host and create a connection pool for it if it does not already exist. auto entry = cluster_manager.thread_local_clusters_.find(cluster); - ConstHostPtr host = entry->second->lb_->chooseHost(); - if (!host) { - entry->second->primary_cluster_.stats().upstream_cx_none_healthy_.inc(); - return nullptr; - } - - if (cluster_manager.host_http_conn_pool_map_.find(host) == - cluster_manager.host_http_conn_pool_map_.end()) { - cluster_manager.host_http_conn_pool_map_[host] = - allocateConnPool(cluster_manager.dispatcher_, host, stats_); + if (entry == cluster_manager.thread_local_clusters_.end()) { + throw EnvoyException(fmt::format("unknown cluster '{}'", cluster)); } - return cluster_manager.host_http_conn_pool_map_[host].get(); + return entry->second->connPool(); } void ClusterManagerImpl::postThreadLocalClusterUpdate(const ClusterImplBase& primary_cluster, @@ -184,6 +176,10 @@ Host::CreateConnectionData ClusterManagerImpl::tcpConnForCluster(const std::stri tls_.getTyped(thread_local_slot_); auto entry = cluster_manager.thread_local_clusters_.find(cluster); + if (entry == cluster_manager.thread_local_clusters_.end()) { + throw EnvoyException(fmt::format("unknown cluster '{}'", cluster)); + } + ConstHostPtr logical_host = entry->second->lb_->chooseHost(); if (logical_host) { return logical_host->createConnection(cluster_manager.dispatcher_); @@ -193,28 +189,28 @@ Host::CreateConnectionData ClusterManagerImpl::tcpConnForCluster(const std::stri } } -Http::AsyncClientPtr ClusterManagerImpl::httpAsyncClientForCluster(const std::string& cluster) { - Http::ConnectionPool::Instance* conn_pool = httpConnPoolForCluster(cluster); +Http::AsyncClient& ClusterManagerImpl::httpAsyncClientForCluster(const std::string& cluster) { ThreadLocalClusterManagerImpl& cluster_manager = tls_.getTyped(thread_local_slot_); - if (conn_pool) { - return Http::AsyncClientPtr{ - new Http::AsyncClientImpl(*conn_pool, cluster, stats_, cluster_manager.dispatcher_)}; + auto entry = cluster_manager.thread_local_clusters_.find(cluster); + if (entry != cluster_manager.thread_local_clusters_.end()) { + return entry->second->http_async_client_; } else { - return nullptr; + throw EnvoyException(fmt::format("unknown cluster '{}'", cluster)); } } ClusterManagerImpl::ThreadLocalClusterManagerImpl::ThreadLocalClusterManagerImpl( ClusterManagerImpl& parent, Event::Dispatcher& dispatcher, Runtime::Loader& runtime, Runtime::RandomGenerator& random) - : dispatcher_(dispatcher) { + : parent_(parent), dispatcher_(dispatcher) { for (auto& cluster : parent.primary_clusters_) { - thread_local_clusters_[cluster.first].reset(new ClusterEntry(*cluster.second, runtime, random)); + thread_local_clusters_[cluster.first].reset( + new ClusterEntry(*this, *cluster.second, runtime, random, parent.stats_, dispatcher)); } for (auto& cluster : thread_local_clusters_) { - cluster.second->host_set_->addMemberUpdateCb( + cluster.second->host_set_.addMemberUpdateCb( [this](const std::vector&, const std::vector& hosts_removed) -> void { // We need to go through and purge any connection pools for hosts that got deleted. // Right now hosts are specific to clusters, so even if two hosts actually point @@ -245,7 +241,7 @@ void ClusterManagerImpl::ThreadLocalClusterManagerImpl::updateClusterMembership( tls.getTyped(thead_local_slot); ASSERT(config.thread_local_clusters_.find(name) != config.thread_local_clusters_.end()); - config.thread_local_clusters_[name]->host_set_->updateHosts( + config.thread_local_clusters_[name]->host_set_.updateHosts( hosts, healthy_hosts, local_zone_hosts, local_zone_healthy_hosts, hosts_added, hosts_removed); } @@ -254,25 +250,43 @@ void ClusterManagerImpl::ThreadLocalClusterManagerImpl::shutdown() { } ClusterManagerImpl::ThreadLocalClusterManagerImpl::ClusterEntry::ClusterEntry( - const Cluster& parent, Runtime::Loader& runtime, Runtime::RandomGenerator& random) - : host_set_(new HostSetImpl()), primary_cluster_(parent) { + ThreadLocalClusterManagerImpl& parent, const Cluster& cluster, Runtime::Loader& runtime, + Runtime::RandomGenerator& random, Stats::Store& stats_store, Event::Dispatcher& dispatcher) + : parent_(parent), primary_cluster_(cluster), + http_async_client_(cluster, *this, stats_store, dispatcher) { - switch (parent.lbType()) { + switch (cluster.lbType()) { case LoadBalancerType::LeastRequest: { - lb_.reset(new LeastRequestLoadBalancer(*host_set_, parent.stats(), runtime, random)); + lb_.reset(new LeastRequestLoadBalancer(host_set_, cluster.stats(), runtime, random)); break; } case LoadBalancerType::Random: { - lb_.reset(new RandomLoadBalancer(*host_set_, parent.stats(), runtime, random)); + lb_.reset(new RandomLoadBalancer(host_set_, cluster.stats(), runtime, random)); break; } case LoadBalancerType::RoundRobin: { - lb_.reset(new RoundRobinLoadBalancer(*host_set_, parent.stats(), runtime)); + lb_.reset(new RoundRobinLoadBalancer(host_set_, cluster.stats(), runtime)); break; } } } +Http::ConnectionPool::Instance* +ClusterManagerImpl::ThreadLocalClusterManagerImpl::ClusterEntry::connPool() { + ConstHostPtr host = lb_->chooseHost(); + if (!host) { + primary_cluster_.stats().upstream_cx_none_healthy_.inc(); + return nullptr; + } + + if (parent_.host_http_conn_pool_map_.find(host) == parent_.host_http_conn_pool_map_.end()) { + parent_.host_http_conn_pool_map_[host] = + parent_.parent_.allocateConnPool(parent_.dispatcher_, host, parent_.parent_.stats_); + } + + return parent_.host_http_conn_pool_map_[host].get(); +} + Http::ConnectionPool::InstancePtr ProdClusterManagerImpl::allocateConnPool(Event::Dispatcher& dispatcher, ConstHostPtr host, Stats::Store& store) { diff --git a/source/common/upstream/cluster_manager_impl.h b/source/common/upstream/cluster_manager_impl.h index 8fcbbf554ece..4b00de9c0012 100644 --- a/source/common/upstream/cluster_manager_impl.h +++ b/source/common/upstream/cluster_manager_impl.h @@ -7,6 +7,7 @@ #include "envoy/thread_local/thread_local.h" #include "envoy/upstream/cluster_manager.h" +#include "common/http/async_client_impl.h" #include "common/json/json_loader.h" namespace Upstream { @@ -41,10 +42,9 @@ class ClusterManagerImpl : public ClusterManager { } const Cluster* get(const std::string& cluster) override; - bool has(const std::string& cluster) override { return primary_clusters_.count(cluster); } Http::ConnectionPool::Instance* httpConnPoolForCluster(const std::string& cluster) override; Host::CreateConnectionData tcpConnForCluster(const std::string& cluster) override; - Http::AsyncClientPtr httpAsyncClientForCluster(const std::string& cluster) override; + Http::AsyncClient& httpAsyncClientForCluster(const std::string& cluster) override; void shutdown() override { for (auto& cluster : primary_clusters_) { @@ -62,13 +62,19 @@ class ClusterManagerImpl : public ClusterManager { * connection pools. */ struct ThreadLocalClusterManagerImpl : public ThreadLocal::ThreadLocalObject { - struct ClusterEntry { - ClusterEntry(const Cluster& parent, Runtime::Loader& runtime, - Runtime::RandomGenerator& random); + struct ClusterEntry : public Http::AsyncClientConnPoolFactory { + ClusterEntry(ThreadLocalClusterManagerImpl& parent, const Cluster& cluster, + Runtime::Loader& runtime, Runtime::RandomGenerator& random, + Stats::Store& stats_store, Event::Dispatcher& dispatcher); - HostSetImplPtr host_set_; + // Http::AsyncClientConnPoolFactory + Http::ConnectionPool::Instance* connPool() override; + + ThreadLocalClusterManagerImpl& parent_; + HostSetImpl host_set_; LoadBalancerPtr lb_; const Cluster& primary_cluster_; + Http::AsyncClientImpl http_async_client_; }; typedef std::unique_ptr ClusterEntryPtr; @@ -87,6 +93,7 @@ class ClusterManagerImpl : public ClusterManager { // ThreadLocal::ThreadLocalObject void shutdown() override; + ClusterManagerImpl& parent_; Event::Dispatcher& dispatcher_; std::unordered_map thread_local_clusters_; std::unordered_map host_http_conn_pool_map_; diff --git a/source/common/upstream/sds.cc b/source/common/upstream/sds.cc index 82e9fc43bc61..48bf721ebf82 100644 --- a/source/common/upstream/sds.cc +++ b/source/common/upstream/sds.cc @@ -101,11 +101,6 @@ void SdsClusterImpl::parseSdsResponse(Http::Message& response) { void SdsClusterImpl::refreshHosts() { log_debug("starting sds refresh for cluster: {}", name_); stats_.update_attempt_.inc(); - client_ = cm_.httpAsyncClientForCluster(sds_config_.sds_cluster_name_); - if (!client_) { - onFailure(Http::AsyncClient::FailureReason::Reset); - return; - } Http::MessagePtr message(new Http::RequestMessageImpl()); message->headers().addViaMoveValue(Http::Headers::get().Scheme, "http"); @@ -113,7 +108,8 @@ void SdsClusterImpl::refreshHosts() { message->headers().addViaMoveValue(Http::Headers::get().Path, "/v1/registration/" + service_name_); message->headers().addViaMoveValue(Http::Headers::get().Host, "sds"); - active_request_ = client_->send(std::move(message), *this, Optional()); + active_request_ = cm_.httpAsyncClientForCluster(sds_config_.sds_cluster_name_) + .send(std::move(message), *this, Optional()); } void SdsClusterImpl::requestComplete() { @@ -125,8 +121,7 @@ void SdsClusterImpl::requestComplete() { initialize_callback_ = nullptr; } - active_request_.reset(); - client_.reset(); + active_request_ = nullptr; // Add refresh jitter based on the configured interval. std::chrono::milliseconds final_delay = @@ -139,8 +134,7 @@ void SdsClusterImpl::requestComplete() { void SdsClusterImpl::shutdown() { if (active_request_) { active_request_->cancel(); - active_request_.reset(); - client_.reset(); + active_request_ = nullptr; } refresh_timer_.reset(); diff --git a/source/common/upstream/sds.h b/source/common/upstream/sds.h index 94e09261d81c..e465dec280e4 100644 --- a/source/common/upstream/sds.h +++ b/source/common/upstream/sds.h @@ -54,8 +54,7 @@ class SdsClusterImpl : public BaseDynamicClusterImpl, public Http::AsyncClient:: const std::string service_name_; Runtime::RandomGenerator& random_; Event::TimerPtr refresh_timer_; - Http::AsyncClientPtr client_; - Http::AsyncClient::RequestPtr active_request_; + Http::AsyncClient::Request* active_request_{}; uint64_t pending_health_checks_{}; }; diff --git a/source/server/configuration_impl.cc b/source/server/configuration_impl.cc index 8c7a4ddf5eba..76bda817e681 100644 --- a/source/server/configuration_impl.cc +++ b/source/server/configuration_impl.cc @@ -82,9 +82,9 @@ void MainImpl::initializeTracers(const Json::Object& tracing_configuration_) { StringUtil::rtrim(access_token); http_tracer_->addSink(Tracing::HttpSinkPtr{new Tracing::LightStepSink( - sink.getObject("config"), *cluster_manager_, server_.threadLocal(), "", - server_.stats(), server_.random(), server_.options().serviceClusterName(), - server_.options().serviceNodeName(), access_token)}); + sink.getObject("config"), *cluster_manager_, "", server_.stats(), server_.random(), + server_.options().serviceClusterName(), server_.options().serviceNodeName(), + access_token)}); } else { throw EnvoyException(fmt::format("Unsupported sink type: '{}'", type)); } diff --git a/test/common/filter/auth/client_ssl_test.cc b/test/common/filter/auth/client_ssl_test.cc index d124b95b4cdb..401ac1281ee8 100644 --- a/test/common/filter/auth/client_ssl_test.cc +++ b/test/common/filter/auth/client_ssl_test.cc @@ -26,7 +26,8 @@ TEST(ClientSslAuthAllowedPrincipalsTest, EmptyString) { class ClientSslAuthFilterTest : public testing::Test { public: - ClientSslAuthFilterTest() : interval_timer_(new Event::MockTimer(&dispatcher_)) {} + ClientSslAuthFilterTest() + : interval_timer_(new Event::MockTimer(&dispatcher_)), request_(&cm_.async_client_) {} ~ClientSslAuthFilterTest() { tls_.shutdownThread(); } void setup() { @@ -39,7 +40,7 @@ class ClientSslAuthFilterTest : public testing::Test { )EOF"; Json::StringLoader loader(json); - EXPECT_CALL(cm_, has("vpn")).WillOnce(Return(true)); + EXPECT_CALL(cm_, get("vpn")); setupRequest(); config_.reset(new Config(loader, tls_, cm_, dispatcher_, stats_store_, runtime_, "127.0.0.1")); @@ -52,15 +53,14 @@ class ClientSslAuthFilterTest : public testing::Test { } void setupRequest() { - client_ = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("vpn")).WillOnce(Return(client_)); - EXPECT_CALL(*client_, send_(_, _, _)) + EXPECT_CALL(cm_, httpAsyncClientForCluster("vpn")).WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) .WillOnce( Invoke([this](Http::MessagePtr& request, Http::AsyncClient::Callbacks& callbacks, Optional) -> Http::AsyncClient::Request* { EXPECT_EQ("127.0.0.1", request->headers().get("x-forwarded-for")); callbacks_ = &callbacks; - return new Http::MockAsyncClientRequest(client_); + return &request_; })); } @@ -71,11 +71,11 @@ class ClientSslAuthFilterTest : public testing::Test { NiceMock filter_callbacks_; std::unique_ptr instance_; Event::MockTimer* interval_timer_; - Http::MockAsyncClient* client_; Http::AsyncClient::Callbacks* callbacks_; Ssl::MockConnection ssl_; Stats::IsolatedStoreImpl stats_store_; NiceMock runtime_; + Http::MockAsyncClientRequest request_; }; TEST_F(ClientSslAuthFilterTest, NoCluster) { @@ -87,7 +87,7 @@ TEST_F(ClientSslAuthFilterTest, NoCluster) { )EOF"; Json::StringLoader loader(json); - EXPECT_CALL(cm_, has("bad_cluster")).WillOnce(Return(false)); + EXPECT_CALL(cm_, get("bad_cluster")).WillOnce(Return(nullptr)); EXPECT_THROW(new Config(loader, tls_, cm_, dispatcher_, stats_store_, runtime_, "127.0.0.1"), EnvoyException); } @@ -174,7 +174,14 @@ TEST_F(ClientSslAuthFilterTest, Basic) { // Interval timer fires, cannot obtain async client. EXPECT_CALL(*interval_timer_, enableTimer(_)); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("vpn")).WillOnce(Return(nullptr)); + EXPECT_CALL(cm_, httpAsyncClientForCluster("vpn")).WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) + .WillOnce( + Invoke([&](Http::MessagePtr&, Http::AsyncClient::Callbacks& callbacks, + const Optional&) -> Http::AsyncClient::Request* { + callbacks.onFailure(Http::AsyncClient::FailureReason::Reset); + return nullptr; + })); interval_timer_->callback_(); EXPECT_EQ(4U, stats_store_.counter("auth.clientssl.vpn.update_failure").value()); diff --git a/test/common/filter/tcp_proxy_test.cc b/test/common/filter/tcp_proxy_test.cc index d6765af3f087..623473c395ab 100644 --- a/test/common/filter/tcp_proxy_test.cc +++ b/test/common/filter/tcp_proxy_test.cc @@ -25,7 +25,7 @@ TEST(TcpProxyConfigTest, NoCluster) { Json::StringLoader config(json); NiceMock cluster_manager; - EXPECT_CALL(cluster_manager, has("fake_cluster")).WillOnce(Return(false)); + EXPECT_CALL(cluster_manager, get("fake_cluster")).WillOnce(Return(nullptr)); EXPECT_THROW(TcpProxyConfig(config, cluster_manager, cluster_manager.cluster_.stats_store_), EnvoyException); } @@ -41,7 +41,6 @@ class TcpProxyTest : public testing::Test { )EOF"; Json::StringLoader config(json); - EXPECT_CALL(cluster_manager_, has("fake_cluster")).WillOnce(Return(true)); config_.reset( new TcpProxyConfig(config, cluster_manager_, cluster_manager_.cluster_.stats_store_)); } diff --git a/test/common/grpc/rpc_channel_impl_test.cc b/test/common/grpc/rpc_channel_impl_test.cc index 4eb92ed09974..15e9ea6bbbff 100644 --- a/test/common/grpc/rpc_channel_impl_test.cc +++ b/test/common/grpc/rpc_channel_impl_test.cc @@ -12,20 +12,19 @@ namespace Grpc { class GrpcRequestImplTest : public testing::Test { public: - GrpcRequestImplTest() { + GrpcRequestImplTest() : http_async_client_request_(&cm_.async_client_) { ON_CALL(cm_.cluster_, features()).WillByDefault(Return(Upstream::Cluster::Features::HTTP2)); } - void expectNormalRequest() { - http_async_client_ = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("cluster")).WillOnce(Return(http_async_client_)); - EXPECT_CALL(*http_async_client_, send_(_, _, _)) + void expectNormalRequest( + const Optional timeout = Optional()) { + EXPECT_CALL(cm_, httpAsyncClientForCluster("cluster")).WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, timeout)) .WillOnce(Invoke([&](Http::MessagePtr& request, Http::AsyncClient::Callbacks& callbacks, Optional) -> Http::AsyncClient::Request* { http_request_ = std::move(request); http_callbacks_ = &callbacks; - http_async_client_request_ = new Http::MockAsyncClientRequest(http_async_client_); - return http_async_client_request_; + return &http_async_client_request_; })); } @@ -34,8 +33,7 @@ class GrpcRequestImplTest : public testing::Test { RpcChannelImpl grpc_request_{cm_, "cluster", grpc_callbacks_, cm_.cluster_.stats_store_, Optional()}; helloworld::Greeter::Stub service_{&grpc_request_}; - Http::MockAsyncClient* http_async_client_{}; - Http::MockAsyncClientRequest* http_async_client_request_{}; + Http::MockAsyncClientRequest http_async_client_request_; Http::MessagePtr http_request_; Http::AsyncClient::Callbacks* http_callbacks_{}; }; @@ -231,21 +229,16 @@ TEST_F(GrpcRequestImplTest, HttpAsyncRequestTimeout) { http_callbacks_->onFailure(Http::AsyncClient::FailureReason::RequestTimemout); } -TEST_F(GrpcRequestImplTest, NoHttpAsyncClient) { - EXPECT_CALL(cm_, httpAsyncClientForCluster_("cluster")).WillOnce(Return(nullptr)); - EXPECT_CALL(grpc_callbacks_, onFailure(Optional(), "http request failure")); - - helloworld::HelloRequest request; - request.set_name("a name"); - helloworld::HelloReply response; - service_.SayHello(nullptr, &request, &response, nullptr); -} - TEST_F(GrpcRequestImplTest, NoHttpAsyncRequest) { - Http::MockAsyncClient* http_async_client = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("cluster")).WillOnce(Return(http_async_client)); - EXPECT_CALL(*http_async_client, send_(_, _, _)).WillOnce(Return(nullptr)); - EXPECT_CALL(grpc_callbacks_, onFailure(Optional(), "http request failure")); + EXPECT_CALL(cm_, httpAsyncClientForCluster("cluster")).WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) + .WillOnce( + Invoke([&](Http::MessagePtr&, Http::AsyncClient::Callbacks& callbacks, + const Optional&) -> Http::AsyncClient::Request* { + callbacks.onFailure(Http::AsyncClient::FailureReason::Reset); + return nullptr; + })); + EXPECT_CALL(grpc_callbacks_, onFailure(Optional(), "stream reset")); helloworld::HelloRequest request; request.set_name("a name"); @@ -261,7 +254,7 @@ TEST_F(GrpcRequestImplTest, Cancel) { helloworld::HelloReply response; service_.SayHello(nullptr, &request, &response, nullptr); - EXPECT_CALL(*http_async_client_request_, cancel()); + EXPECT_CALL(http_async_client_request_, cancel()); grpc_request_.cancel(); } @@ -270,17 +263,7 @@ TEST_F(GrpcRequestImplTest, RequestTimeoutSet) { RpcChannelImpl grpc_request_timeout{cm_, "cluster", grpc_callbacks_, cm_.cluster_.stats_store_, timeout}; helloworld::Greeter::Stub service_timeout{&grpc_request_timeout}; - http_async_client_ = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("cluster")).WillOnce(Return(http_async_client_)); - EXPECT_CALL(*http_async_client_, send_(_, _, timeout)) - .WillOnce( - Invoke([&](Http::MessagePtr& request, Http::AsyncClient::Callbacks& callbacks, - const Optional&) -> Http::AsyncClient::Request* { - http_request_ = std::move(request); - http_callbacks_ = &callbacks; - http_async_client_request_ = new Http::MockAsyncClientRequest(http_async_client_); - return http_async_client_request_; - })); + expectNormalRequest(timeout); helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; diff --git a/test/common/http/async_client_impl_test.cc b/test/common/http/async_client_impl_test.cc index f306c34c9b4a..ff4c20ad8b0a 100644 --- a/test/common/http/async_client_impl_test.cc +++ b/test/common/http/async_client_impl_test.cc @@ -6,6 +6,7 @@ #include "test/mocks/common.h" #include "test/mocks/http/mocks.h" #include "test/mocks/stats/mocks.h" +#include "test/mocks/upstream/mocks.h" using testing::_; using testing::ByRef; @@ -15,10 +16,13 @@ using testing::Ref; namespace Http { -class AsyncClientImplTest : public testing::Test { +class AsyncClientImplTest : public testing::Test, public AsyncClientConnPoolFactory { public: AsyncClientImplTest() { HttpTestUtility::addDefaultHeaders(message_->headers()); } + // Http::AsyncClientConnPoolFactory + Http::ConnectionPool::Instance* connPool() override { return &conn_pool_; } + MessagePtr message_{new RequestMessageImpl()}; MockAsyncClientCallbacks callbacks_; ConnectionPool::MockInstance conn_pool_; @@ -27,6 +31,7 @@ class AsyncClientImplTest : public testing::Test { NiceMock stats_store_; NiceMock* timer_; NiceMock dispatcher_; + NiceMock cluster_; }; TEST_F(AsyncClientImplTest, Basic) { @@ -45,9 +50,8 @@ TEST_F(AsyncClientImplTest, Basic) { EXPECT_CALL(stream_encoder_, encodeData(BufferEqual(&data), true)); EXPECT_CALL(callbacks_, onSuccess_(_)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); - AsyncClient::RequestPtr request = - client.send(std::move(message_), callbacks_, Optional()); + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + client.send(std::move(message_), callbacks_, Optional()); EXPECT_CALL(stats_store_, counter("cluster.fake_cluster.upstream_rq_2xx")); EXPECT_CALL(stats_store_, counter("cluster.fake_cluster.upstream_rq_200")); @@ -62,6 +66,53 @@ TEST_F(AsyncClientImplTest, Basic) { response_decoder_->decodeData(data, true); } +TEST_F(AsyncClientImplTest, MultipleRequests) { + // Send request 1 + message_->body(Buffer::InstancePtr{new Buffer::OwnedImpl("test body")}); + Buffer::Instance& data = *message_->body(); + + EXPECT_CALL(conn_pool_, newStream(_, _)) + .WillOnce(Invoke([&](StreamDecoder& decoder, ConnectionPool::Callbacks& callbacks) + -> ConnectionPool::Cancellable* { + callbacks.onPoolReady(stream_encoder_, conn_pool_.host_); + response_decoder_ = &decoder; + return nullptr; + })); + + EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(ByRef(message_->headers())), false)); + EXPECT_CALL(stream_encoder_, encodeData(BufferEqual(&data), true)); + + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + client.send(std::move(message_), callbacks_, Optional()); + + // Send request 2. + MessagePtr message2{new RequestMessageImpl()}; + HttpTestUtility::addDefaultHeaders(message2->headers()); + NiceMock stream_encoder2; + StreamDecoder* response_decoder2{}; + MockAsyncClientCallbacks callbacks2; + EXPECT_CALL(conn_pool_, newStream(_, _)) + .WillOnce(Invoke([&](StreamDecoder& decoder, ConnectionPool::Callbacks& callbacks) + -> ConnectionPool::Cancellable* { + callbacks.onPoolReady(stream_encoder2, conn_pool_.host_); + response_decoder2 = &decoder; + return nullptr; + })); + EXPECT_CALL(stream_encoder2, encodeHeaders(HeaderMapEqualRef(ByRef(message2->headers())), true)); + client.send(std::move(message2), callbacks2, Optional()); + + // Finish request 2. + HeaderMapPtr response_headers2(new HeaderMapImpl{{":status", "503"}}); + EXPECT_CALL(callbacks2, onSuccess_(_)); + response_decoder2->decodeHeaders(std::move(response_headers2), true); + + // Finish request 1. + HeaderMapPtr response_headers(new HeaderMapImpl{{":status", "200"}}); + response_decoder_->decodeHeaders(std::move(response_headers), false); + EXPECT_CALL(callbacks_, onSuccess_(_)); + response_decoder_->decodeData(data, true); +} + TEST_F(AsyncClientImplTest, Trailers) { message_->body(Buffer::InstancePtr{new Buffer::OwnedImpl("test body")}); Buffer::Instance& data = *message_->body(); @@ -78,9 +129,8 @@ TEST_F(AsyncClientImplTest, Trailers) { EXPECT_CALL(stream_encoder_, encodeData(BufferEqual(&data), true)); EXPECT_CALL(callbacks_, onSuccess_(_)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); - AsyncClient::RequestPtr request = - client.send(std::move(message_), callbacks_, Optional()); + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + client.send(std::move(message_), callbacks_, Optional()); HeaderMapPtr response_headers(new HeaderMapImpl{{":status", "200"}}); response_decoder_->decodeHeaders(std::move(response_headers), false); response_decoder_->decodeData(data, false); @@ -103,9 +153,8 @@ TEST_F(AsyncClientImplTest, FailRequest) { EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(ByRef(message_->headers())), true)); EXPECT_CALL(callbacks_, onFailure(Http::AsyncClient::FailureReason::Reset)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); - AsyncClient::RequestPtr request = - client.send(std::move(message_), callbacks_, Optional()); + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + client.send(std::move(message_), callbacks_, Optional()); stream_encoder_.getStream().resetStream(StreamResetReason::RemoteReset); } @@ -120,8 +169,8 @@ TEST_F(AsyncClientImplTest, CancelRequest) { EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(ByRef(message_->headers())), true)); EXPECT_CALL(stream_encoder_.stream_, resetStream(_)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); - AsyncClient::RequestPtr request = + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + AsyncClient::Request* request = client.send(std::move(message_), callbacks_, Optional()); request->cancel(); } @@ -141,7 +190,7 @@ TEST_F(AsyncClientImplTest, PoolFailure) { })); EXPECT_CALL(callbacks_, onFailure(Http::AsyncClient::FailureReason::Reset)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); EXPECT_EQ(nullptr, client.send(std::move(message_), callbacks_, Optional())); } @@ -151,7 +200,6 @@ TEST_F(AsyncClientImplTest, RequestTimeout) { EXPECT_CALL(stats_store_, counter("cluster.fake_cluster.upstream_rq_504")); EXPECT_CALL(stats_store_, counter("cluster.fake_cluster.internal.upstream_rq_5xx")); EXPECT_CALL(stats_store_, counter("cluster.fake_cluster.internal.upstream_rq_504")); - EXPECT_CALL(stats_store_, counter("cluster.fake_cluster.upstream_rq_timeout")); EXPECT_CALL(conn_pool_, newStream(_, _)) .WillOnce(Invoke([&](StreamDecoder&, ConnectionPool::Callbacks& callbacks) -> ConnectionPool::Cancellable* { @@ -164,10 +212,11 @@ TEST_F(AsyncClientImplTest, RequestTimeout) { timer_ = new NiceMock(&dispatcher_); EXPECT_CALL(*timer_, enableTimer(std::chrono::milliseconds(40))); EXPECT_CALL(stream_encoder_.stream_, resetStream(_)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); - AsyncClient::RequestPtr request = - client.send(std::move(message_), callbacks_, std::chrono::milliseconds(40)); + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + client.send(std::move(message_), callbacks_, std::chrono::milliseconds(40)); timer_->callback_(); + + EXPECT_EQ(1UL, cluster_.stats_store_.counter("cluster.fake_cluster.upstream_rq_timeout").value()); } TEST_F(AsyncClientImplTest, DisableTimer) { @@ -183,8 +232,8 @@ TEST_F(AsyncClientImplTest, DisableTimer) { EXPECT_CALL(*timer_, enableTimer(std::chrono::milliseconds(200))); EXPECT_CALL(*timer_, disableTimer()); EXPECT_CALL(stream_encoder_.stream_, resetStream(_)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); - AsyncClient::RequestPtr request = + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + AsyncClient::Request* request = client.send(std::move(message_), callbacks_, std::chrono::milliseconds(200)); request->cancel(); } diff --git a/test/common/ratelimit/ratelimit_impl_test.cc b/test/common/ratelimit/ratelimit_impl_test.cc index a5773a2e93f6..92862b97b2a6 100644 --- a/test/common/ratelimit/ratelimit_impl_test.cc +++ b/test/common/ratelimit/ratelimit_impl_test.cc @@ -112,7 +112,7 @@ TEST(RateLimitGrpcFactoryTest, NoCluster) { Upstream::MockClusterManager cm; Stats::IsolatedStoreImpl stats_store; - EXPECT_CALL(cm, has("foo")).WillOnce(Return(false)); + EXPECT_CALL(cm, get("foo")).WillOnce(Return(nullptr)); EXPECT_THROW(GrpcFactoryImpl(config, cm, stats_store), EnvoyException); } @@ -127,7 +127,7 @@ TEST(RateLimitGrpcFactoryTest, Create) { Upstream::MockClusterManager cm; Stats::IsolatedStoreImpl stats_store; - EXPECT_CALL(cm, has("foo")).WillOnce(Return(true)); + EXPECT_CALL(cm, get("foo")); GrpcFactoryImpl factory(config, cm, stats_store); factory.create(Optional()); } diff --git a/test/common/router/config_impl_test.cc b/test/common/router/config_impl_test.cc index b68d0b080cc1..212cd6dfd75c 100644 --- a/test/common/router/config_impl_test.cc +++ b/test/common/router/config_impl_test.cc @@ -146,7 +146,6 @@ TEST(RouteMatcherTest, TestRoutes) { Json::StringLoader loader(json); NiceMock runtime; NiceMock cm; - ON_CALL(cm, has(_)).WillByDefault(Return(true)); ConfigImpl config(loader, runtime, cm); // Base routing testing. @@ -351,7 +350,6 @@ TEST(RouteMatcherTest, ContentType) { Json::StringLoader loader(json); NiceMock runtime; NiceMock cm; - ON_CALL(cm, has(_)).WillByDefault(Return(true)); ConfigImpl config(loader, runtime, cm); { @@ -403,7 +401,6 @@ TEST(RouteMatcherTest, Runtime) { NiceMock cm; Runtime::MockSnapshot snapshot; - ON_CALL(cm, has(_)).WillByDefault(Return(true)); ON_CALL(runtime, snapshot()).WillByDefault(ReturnRef(snapshot)); ConfigImpl config(loader, runtime, cm); @@ -445,7 +442,6 @@ TEST(RouteMatcherTest, RateLimit) { Json::StringLoader loader(json); NiceMock runtime; NiceMock cm; - ON_CALL(cm, has(_)).WillByDefault(Return(true)); ConfigImpl config(loader, runtime, cm); EXPECT_TRUE(config.routeForRequest(genHeaders("www.lyft.com", "/foo", "GET"), 0) @@ -492,7 +488,6 @@ TEST(RouteMatcherTest, Retry) { Json::StringLoader loader(json); NiceMock runtime; NiceMock cm; - ON_CALL(cm, has(_)).WillByDefault(Return(true)); ConfigImpl config(loader, runtime, cm); EXPECT_EQ(1U, config.routeForRequest(genHeaders("www.lyft.com", "/foo", "GET"), 0) @@ -619,7 +614,6 @@ TEST(RouteMatcherTest, Redirect) { Json::StringLoader loader(json); NiceMock runtime; NiceMock cm; - ON_CALL(cm, has(StrNe(""))).WillByDefault(Return(true)); ConfigImpl config(loader, runtime, cm); EXPECT_EQ(nullptr, diff --git a/test/common/stats/statsd_test.cc b/test/common/stats/statsd_test.cc index e4dca82ae755..11f791c32003 100644 --- a/test/common/stats/statsd_test.cc +++ b/test/common/stats/statsd_test.cc @@ -16,7 +16,7 @@ namespace Statsd { class TcpStatsdSinkTest : public testing::Test { public: TcpStatsdSinkTest() { - EXPECT_CALL(cluster_manager_, has(_)).WillOnce(Return(true)); + EXPECT_CALL(cluster_manager_, get("statsd")); sink_.reset(new TcpStatsdSink("cluster", "host", "statsd", tls_, cluster_manager_)); } diff --git a/test/common/tracing/http_tracer_impl_test.cc b/test/common/tracing/http_tracer_impl_test.cc index e4556aed7fb4..7eb17b6f2c4e 100644 --- a/test/common/tracing/http_tracer_impl_test.cc +++ b/test/common/tracing/http_tracer_impl_test.cc @@ -246,12 +246,12 @@ class LightStepSinkTest : public Test { : stats_{LIGHTSTEP_STATS(POOL_COUNTER_PREFIX(fake_stats_, "prefix.tracing.lightstep."))} {} void setup(Json::Object& config) { - sink_.reset(new LightStepSink(config, cm_, tls_, "prefix.", fake_stats_, random_, - "service_cluster", "service_node", "token")); + sink_.reset(new LightStepSink(config, cm_, "prefix.", fake_stats_, random_, "service_cluster", + "service_node", "token")); } void setupValidSink() { - EXPECT_CALL(cm_, has("lightstep_saas")).WillOnce(Return(true)); + EXPECT_CALL(cm_, get("lightstep_saas")); std::string valid_config = R"EOF( {"collector_cluster": "lightstep_saas"} @@ -266,7 +266,6 @@ class LightStepSinkTest : public Test { Stats::IsolatedStoreImpl fake_stats_; LightStepStats stats_; NiceMock cm_; - NiceMock tls_; NiceMock random_; std::unique_ptr sink_; }; @@ -290,7 +289,7 @@ TEST_F(LightStepSinkTest, InitializeSink) { { // Valid config but not valid cluster - EXPECT_CALL(cm_, has("lightstep_saas")).WillOnce(Return(false)); + EXPECT_CALL(cm_, get("lightstep_saas")).WillOnce(Return(nullptr)); std::string valid_config = R"EOF( {"collector_cluster": "lightstep_saas"} @@ -301,7 +300,7 @@ TEST_F(LightStepSinkTest, InitializeSink) { } { - EXPECT_CALL(cm_, has("lightstep_saas")).WillOnce(Return(true)); + EXPECT_CALL(cm_, get("lightstep_saas")); std::string valid_config = R"EOF( {"collector_cluster": "lightstep_saas"} @@ -317,19 +316,19 @@ TEST_F(LightStepSinkTest, CallbacksCalled) { NiceMock request_info; - NiceMock* client_1 = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("lightstep_saas")).WillOnce(Return(client_1)); + EXPECT_CALL(cm_, httpAsyncClientForCluster("lightstep_saas")) + .WillOnce(ReturnRef(cm_.async_client_)); - Http::MockAsyncClientRequest* request_1 = new Http::MockAsyncClientRequest(client_1); + Http::MockAsyncClientRequest request_1(&cm_.async_client_); Http::AsyncClient::Callbacks* callback_1; const Optional timeout(std::chrono::seconds(5)); - EXPECT_CALL(*client_1, send_(_, _, timeout)) + EXPECT_CALL(cm_.async_client_, send_(_, _, timeout)) .WillOnce( Invoke([&](Http::MessagePtr&, Http::AsyncClient::Callbacks& callbacks, const Optional&) -> Http::AsyncClient::Request* { callback_1 = &callbacks; - return request_1; + return &request_1; })); EXPECT_CALL(random_, uuid()).WillOnce(Return("1")).WillOnce(Return("2")); SystemTime start_time_1; @@ -343,16 +342,16 @@ TEST_F(LightStepSinkTest, CallbacksCalled) { sink_->flushTrace(empty_header_, empty_header_, request_info); - NiceMock* client_2 = new NiceMock(); - Http::MockAsyncClientRequest* request_2 = new Http::MockAsyncClientRequest(client_2); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("lightstep_saas")).WillOnce(Return(client_2)); + Http::MockAsyncClientRequest request_2(&cm_.async_client_); + EXPECT_CALL(cm_, httpAsyncClientForCluster("lightstep_saas")) + .WillOnce(ReturnRef(cm_.async_client_)); Http::AsyncClient::Callbacks* callback_2; - EXPECT_CALL(*client_2, send_(_, _, _)) + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) .WillOnce(Invoke([&](Http::MessagePtr&, Http::AsyncClient::Callbacks& callbacks, Optional) -> Http::AsyncClient::Request* { callback_2 = &callbacks; - return request_2; + return &request_2; })); EXPECT_CALL(random_, uuid()).WillOnce(Return("3")).WillOnce(Return("4")); SystemTime start_time_2; @@ -372,12 +371,6 @@ TEST_F(LightStepSinkTest, CallbacksCalled) { callback_1->onSuccess(std::move(msg)); EXPECT_EQ(1UL, stats_.collector_failed_.value()); EXPECT_EQ(1UL, stats_.collector_success_.value()); - - // Shutdown sink and try to make trace - tls_.shutdownThread_(); - - EXPECT_CALL(cm_, httpAsyncClientForCluster_("lightstep_saas")).Times(0); - sink_->flushTrace(empty_header_, empty_header_, request_info); } TEST_F(LightStepSinkTest, ClientNotAvailable) { @@ -385,21 +378,36 @@ TEST_F(LightStepSinkTest, ClientNotAvailable) { NiceMock request_info; - EXPECT_CALL(cm_, httpAsyncClientForCluster_("lightstep_saas")).WillOnce(Return(nullptr)); + EXPECT_CALL(cm_, httpAsyncClientForCluster("lightstep_saas")) + .WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) + .WillOnce( + Invoke([&](Http::MessagePtr&, Http::AsyncClient::Callbacks& callbacks, + const Optional&) -> Http::AsyncClient::Request* { + callbacks.onFailure(Http::AsyncClient::FailureReason::Reset); + return nullptr; + })); + SystemTime start_time_1; + EXPECT_CALL(request_info, startTime()).WillOnce(Return(start_time_1)); + std::chrono::seconds duration_1(1); + EXPECT_CALL(request_info, duration()).WillOnce(Return(duration_1)); + const std::string protocol = "http/1"; + EXPECT_CALL(request_info, protocol()).WillRepeatedly(ReturnRef(protocol)); + Optional code_1(200); + EXPECT_CALL(request_info, responseCode()).WillRepeatedly(ReturnRef(code_1)); sink_->flushTrace(empty_header_, empty_header_, request_info); - EXPECT_EQ(1UL, stats_.client_failed_.value()); - EXPECT_EQ(0UL, stats_.collector_failed_.value()); + EXPECT_EQ(1UL, stats_.collector_failed_.value()); EXPECT_EQ(0UL, stats_.collector_success_.value()); } TEST_F(LightStepSinkTest, ShutdownWhenActiveRequests) { setupValidSink(); - NiceMock* client = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("lightstep_saas")).WillOnce(Return(client)); + EXPECT_CALL(cm_, httpAsyncClientForCluster("lightstep_saas")) + .WillOnce(ReturnRef(cm_.async_client_)); - Http::MockAsyncClientRequest* request = new Http::MockAsyncClientRequest(client); + Http::MockAsyncClientRequest request(&cm_.async_client_); NiceMock request_info; const std::string protocol = "http/1"; @@ -467,19 +475,16 @@ TEST_F(LightStepSinkTest, ShutdownWhenActiveRequests) { )EOF"; Http::AsyncClient::Callbacks* callback; - EXPECT_CALL(*client, send_(_, _, _)) + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) .WillOnce(Invoke([&](Http::MessagePtr& msg, Http::AsyncClient::Callbacks& callbacks, Optional) -> Http::AsyncClient::Request* { callback = &callbacks; EXPECT_EQ(expected_json, msg->bodyAsString()); EXPECT_EQ("token", msg->headers().get("LightStep-Access-Token")); - return request; + return &request; })); sink_->flushTrace(request_header, empty_header_, request_info); - - EXPECT_CALL(*request, cancel()); - tls_.shutdownThread_(); } TEST(LightStepUtilityTest, HeadersNotSet) { diff --git a/test/common/upstream/cluster_manager_impl_test.cc b/test/common/upstream/cluster_manager_impl_test.cc index 1ac2f68afc47..0bc5694e7ac1 100644 --- a/test/common/upstream/cluster_manager_impl_test.cc +++ b/test/common/upstream/cluster_manager_impl_test.cc @@ -11,6 +11,7 @@ using testing::_; using testing::NiceMock; +using testing::Return; using testing::ReturnNew; using testing::SaveArg; @@ -28,46 +29,188 @@ class ClusterManagerImplForTest : public ClusterManagerImpl { MOCK_METHOD1(allocateConnPool_, Http::ConnectionPool::Instance*(ConstHostPtr host)); }; -TEST(ClusterManagerImplTest, DynamicHostRemove) { +class ClusterManagerImplTest : public testing::Test { +public: + void create(const Json::Object& config) { + cluster_manager_.reset(new ClusterManagerImplForTest(config, stats_, tls_, dns_resolver_, + ssl_context_manager_, runtime_, random_, + "us-east-1d")); + } + + Stats::IsolatedStoreImpl stats_; + NiceMock tls_; + NiceMock dns_resolver_; + NiceMock runtime_; + NiceMock random_; + Ssl::ContextManagerImpl ssl_context_manager_{runtime_}; + std::unique_ptr cluster_manager_; +}; + +TEST_F(ClusterManagerImplTest, NoSdsConfig) { std::string json = R"EOF( { - "cluster_manager": { - "clusters": [ - { - "name": "cluster_1", - "connect_timeout_ms": 250, - "type": "strict_dns", - "lb_type": "round_robin", - "hosts": [{"url": "tcp://localhost:11001"}] - }] - } + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "sds", + "lb_type": "round_robin" + }] } )EOF"; Json::StringLoader loader(json); + EXPECT_THROW(create(loader), EnvoyException); +} - Stats::IsolatedStoreImpl stats; - NiceMock tls; - NiceMock dns_resolver; - Network::DnsResolver::ResolveCb dns_callback; - Event::MockTimer* dns_timer_ = new NiceMock(&dns_resolver.dispatcher_); - NiceMock runtime; - NiceMock random; - Ssl::ContextManagerImpl ssl_context_manager(runtime); - EXPECT_CALL(dns_resolver, resolve(_, _)).WillRepeatedly(SaveArg<1>(&dns_callback)); +TEST_F(ClusterManagerImplTest, UnknownClusterType) { + std::string json = R"EOF( + { + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "foo", + "lb_type": "round_robin" + }] + } + )EOF"; + + Json::StringLoader loader(json); + EXPECT_THROW(create(loader), EnvoyException); +} + +TEST_F(ClusterManagerImplTest, DuplicateCluster) { + std::string json = R"EOF( + { + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "static", + "lb_type": "round_robin", + "hosts": [{"url": "tcp://127.0.0.1:11001"}] + }, + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "static", + "lb_type": "round_robin", + "hosts": [{"url": "tcp://127.0.0.1:11001"}] + }] + } + )EOF"; + + Json::StringLoader loader(json); + EXPECT_THROW(create(loader), EnvoyException); +} + +TEST_F(ClusterManagerImplTest, UnknownHcType) { + std::string json = R"EOF( + { + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "static", + "lb_type": "round_robin", + "hosts": [{"url": "tcp://127.0.0.1:11001"}], + "health_check": { + "type": "foo" + } + }] + } + )EOF"; - ClusterManagerImplForTest cluster_manager(loader.getObject("cluster_manager"), stats, tls, - dns_resolver, ssl_context_manager, runtime, random, - "us-east-1d"); + Json::StringLoader loader(json); + EXPECT_THROW(create(loader), EnvoyException); +} + +TEST_F(ClusterManagerImplTest, TcpHealthChecker) { + std::string json = R"EOF( + { + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "static", + "lb_type": "round_robin", + "hosts": [{"url": "tcp://127.0.0.1:11001"}], + "health_check": { + "type": "tcp", + "timeout_ms": 1000, + "interval_ms": 1000, + "unhealthy_threshold": 2, + "healthy_threshold": 2, + "send": [ + {"binary": "01"} + ], + "receive": [ + {"binary": "02"} + ] + } + }] + } + )EOF"; + + Json::StringLoader loader(json); + Network::MockClientConnection* connection = new NiceMock(); + EXPECT_CALL(dns_resolver_.dispatcher_, createClientConnection_("tcp://127.0.0.1:11001")) + .WillOnce(Return(connection)); + create(loader); +} + +TEST_F(ClusterManagerImplTest, UnknownCluster) { + std::string json = R"EOF( + { + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "static", + "lb_type": "round_robin", + "hosts": [{"url": "tcp://127.0.0.1:11001"}] + }] + } + )EOF"; + + Json::StringLoader loader(json); + create(loader); + EXPECT_EQ(nullptr, cluster_manager_->get("hello")); + EXPECT_THROW(cluster_manager_->httpConnPoolForCluster("hello"), EnvoyException); + EXPECT_THROW(cluster_manager_->tcpConnForCluster("hello"), EnvoyException); + EXPECT_THROW(cluster_manager_->httpAsyncClientForCluster("hello"), EnvoyException); +} + +TEST_F(ClusterManagerImplTest, DynamicHostRemove) { + std::string json = R"EOF( + { + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "strict_dns", + "lb_type": "round_robin", + "hosts": [{"url": "tcp://localhost:11001"}] + }] + } + )EOF"; + + Json::StringLoader loader(json); + + Network::DnsResolver::ResolveCb dns_callback; + Event::MockTimer* dns_timer_ = new NiceMock(&dns_resolver_.dispatcher_); + EXPECT_CALL(dns_resolver_, resolve(_, _)).WillRepeatedly(SaveArg<1>(&dns_callback)); + create(loader); // Test for no hosts returning the correct values before we have hosts. - EXPECT_EQ(nullptr, cluster_manager.httpConnPoolForCluster("cluster_1")); - EXPECT_EQ(nullptr, cluster_manager.tcpConnForCluster("cluster_1").connection_); - EXPECT_EQ(2UL, stats.counter("cluster.cluster_1.upstream_cx_none_healthy").value()); + EXPECT_EQ(nullptr, cluster_manager_->httpConnPoolForCluster("cluster_1")); + EXPECT_EQ(nullptr, cluster_manager_->tcpConnForCluster("cluster_1").connection_); + EXPECT_EQ(2UL, stats_.counter("cluster.cluster_1.upstream_cx_none_healthy").value()); // Set up for an initialize callback. ReadyWatcher initialized; - cluster_manager.setInitializedCb([&]() -> void { initialized.ready(); }); + cluster_manager_->setInitializedCb([&]() -> void { initialized.ready(); }); EXPECT_CALL(initialized, ready()); dns_callback({"127.0.0.1", "127.0.0.2"}); @@ -75,17 +218,17 @@ TEST(ClusterManagerImplTest, DynamicHostRemove) { // After we are initialized, we should immediately get called back if someone asks for an // initialize callback. EXPECT_CALL(initialized, ready()); - cluster_manager.setInitializedCb([&]() -> void { initialized.ready(); }); + cluster_manager_->setInitializedCb([&]() -> void { initialized.ready(); }); - EXPECT_CALL(cluster_manager, allocateConnPool_(_)) + EXPECT_CALL(*cluster_manager_, allocateConnPool_(_)) .Times(2) .WillRepeatedly(ReturnNew()); // This should provide us a CP for each of the above hosts. Http::ConnectionPool::MockInstance* cp1 = dynamic_cast( - cluster_manager.httpConnPoolForCluster("cluster_1")); + cluster_manager_->httpConnPoolForCluster("cluster_1")); Http::ConnectionPool::MockInstance* cp2 = dynamic_cast( - cluster_manager.httpConnPoolForCluster("cluster_1")); + cluster_manager_->httpConnPoolForCluster("cluster_1")); EXPECT_NE(cp1, cp2); @@ -100,7 +243,7 @@ TEST(ClusterManagerImplTest, DynamicHostRemove) { // Make sure we get back the same connection pool for the 2nd host as we did before the change. Http::ConnectionPool::MockInstance* cp3 = dynamic_cast( - cluster_manager.httpConnPoolForCluster("cluster_1")); + cluster_manager_->httpConnPoolForCluster("cluster_1")); EXPECT_EQ(cp2, cp3); // Now add and remove a host that we never have a conn pool to. This should not lead to any diff --git a/test/common/upstream/sds_test.cc b/test/common/upstream/sds_test.cc index 3979e3d3f475..8e16e3fb42d1 100644 --- a/test/common/upstream/sds_test.cc +++ b/test/common/upstream/sds_test.cc @@ -12,7 +12,6 @@ using testing::DoAll; using testing::Invoke; using testing::NiceMock; using testing::Return; -using testing::ReturnNew; using testing::SaveArg; using testing::WithArg; @@ -20,7 +19,9 @@ namespace Upstream { class SdsTest : public testing::Test { protected: - SdsTest() : sds_config_{"us-east-1a", "sds", std::chrono::milliseconds(30000)} { + SdsTest() + : sds_config_{"us-east-1a", "sds", std::chrono::milliseconds(30000)}, + request_(&cm_.async_client_) { std::string raw_config = R"EOF( { "name": "name", @@ -60,9 +61,8 @@ class SdsTest : public testing::Test { } void setupPoolFailure() { - client_ = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("sds")).WillOnce(Return(client_)); - EXPECT_CALL(*client_, send_(_, _, _)) + EXPECT_CALL(cm_, httpAsyncClientForCluster("sds")).WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) .WillOnce(Invoke([](Http::MessagePtr&, Http::AsyncClient::Callbacks& callbacks, Optional) -> Http::AsyncClient::Request* { callbacks.onFailure(Http::AsyncClient::FailureReason::Reset); @@ -71,11 +71,9 @@ class SdsTest : public testing::Test { } void setupRequest() { - client_ = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("sds")).WillOnce(Return(client_)); - EXPECT_CALL(*client_, send_(_, _, _)) - .WillOnce(DoAll(WithArg<1>(SaveArgAddress(&callbacks_)), - ReturnNew>(client_))); + EXPECT_CALL(cm_, httpAsyncClientForCluster("sds")).WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) + .WillOnce(DoAll(WithArg<1>(SaveArgAddress(&callbacks_)), Return(&request_))); } Stats::IsolatedStoreImpl stats_; @@ -85,15 +83,16 @@ class SdsTest : public testing::Test { Event::MockDispatcher dispatcher_; std::unique_ptr cluster_; Event::MockTimer* timer_; - Http::MockAsyncClient* client_; Http::AsyncClient::Callbacks* callbacks_; ReadyWatcher membership_updated_; NiceMock random_; + Http::MockAsyncClientRequest request_; }; TEST_F(SdsTest, Shutdown) { setupRequest(); cluster_->initialize(); + EXPECT_CALL(request_, cancel()); cluster_->shutdown(); } @@ -167,11 +166,6 @@ TEST_F(SdsTest, NoHealthChecker) { EXPECT_EQ(13UL, cluster_->hosts().size()); EXPECT_EQ(50U, canary_host->weight()); EXPECT_EQ(50UL, cluster_->stats().max_host_weight_.value()); - - // No healthy SDS hosts. - EXPECT_CALL(cm_, httpAsyncClientForCluster_("sds")).WillOnce(Return(nullptr)); - EXPECT_CALL(*timer_, enableTimer(_)); - timer_->callback_(); } TEST_F(SdsTest, HealthChecker) { diff --git a/test/mocks/http/mocks.h b/test/mocks/http/mocks.h index 85735f06a56f..6de18aa3d0de 100644 --- a/test/mocks/http/mocks.h +++ b/test/mocks/http/mocks.h @@ -281,9 +281,9 @@ class MockAsyncClient : public AsyncClient { MOCK_METHOD0(onRequestDestroy, void()); // Http::AsyncClient - RequestPtr send(MessagePtr&& request, Callbacks& callbacks, - const Optional& timeout) override { - return RequestPtr{send_(request, callbacks, timeout)}; + Request* send(MessagePtr&& request, Callbacks& callbacks, + const Optional& timeout) override { + return send_(request, callbacks, timeout); } MOCK_METHOD3(send_, Request*(MessagePtr& request, Callbacks& callbacks, diff --git a/test/mocks/upstream/mocks.h b/test/mocks/upstream/mocks.h index 31b91d8e8f38..a016feae258a 100644 --- a/test/mocks/upstream/mocks.h +++ b/test/mocks/upstream/mocks.h @@ -66,22 +66,18 @@ class MockClusterManager : public ClusterManager { return {Network::ClientConnectionPtr{data.connection_}, data.host_}; } - Http::AsyncClientPtr httpAsyncClientForCluster(const std::string& cluster) override { - return Http::AsyncClientPtr{httpAsyncClientForCluster_(cluster)}; - } - // Upstream::ClusterManager MOCK_METHOD1(setInitializedCb, void(std::function)); MOCK_METHOD0(clusters, std::unordered_map()); MOCK_METHOD1(get, const Cluster*(const std::string& cluster)); - MOCK_METHOD1(has, bool(const std::string& cluster)); MOCK_METHOD1(httpConnPoolForCluster, Http::ConnectionPool::Instance*(const std::string& cluster)); MOCK_METHOD1(tcpConnForCluster_, MockHost::MockCreateConnectionData(const std::string& cluster)); - MOCK_METHOD1(httpAsyncClientForCluster_, Http::AsyncClient*(const std::string& cluster)); + MOCK_METHOD1(httpAsyncClientForCluster, Http::AsyncClient&(const std::string& cluster)); MOCK_METHOD0(shutdown, void()); NiceMock conn_pool_; NiceMock cluster_; + NiceMock async_client_; }; class MockHealthChecker : public HealthChecker {