diff --git a/DEPRECATED.md b/DEPRECATED.md index c9b1e285bd56b..8e08bc7c8f14a 100644 --- a/DEPRECATED.md +++ b/DEPRECATED.md @@ -14,7 +14,8 @@ The following features have been DEPRECATED and will be removed in the specified instead. * gRPC service configuration via the `cluster_names` field in `ApiConfigSource` is deprecated. Use `grpc_services` instead. - +* 'use_original_dst' field in the v2 LDS API is deprecated. Use listerner filters and filter chain + matching instead. ## Version 1.5.0 diff --git a/bazel/repository_locations.bzl b/bazel/repository_locations.bzl index 32b921c98e8cc..ec2830e3f6e86 100644 --- a/bazel/repository_locations.bzl +++ b/bazel/repository_locations.bzl @@ -71,7 +71,7 @@ REPOSITORY_LOCATIONS = dict( urls = ["https://github.com/google/protobuf/archive/v3.5.0.tar.gz"], ), envoy_api = dict( - commit = "040b29a717eb5180c4a6797bb72f5a6ce2731363", + commit = "fd1a8c4269910caa2d99bf919c0ad13fb3d70f4f", remote = "https://github.com/envoyproxy/data-plane-api", ), grpc_httpjson_transcoding = dict( diff --git a/include/envoy/event/dispatcher.h b/include/envoy/event/dispatcher.h index 99a563a95aba8..5e0587eca9158 100644 --- a/include/envoy/event/dispatcher.h +++ b/include/envoy/event/dispatcher.h @@ -38,6 +38,16 @@ class Dispatcher { */ virtual void clearDeferredDeleteList() PURE; + /** + * Create a server connection. + * @param socket supplies an open file descriptor and connection metadata to use for the + * connection. Takes ownership of the socket. + * @param ssl_ctx supplies the SSL context to use, if not nullptr. + * @return Network::ConnectionPtr a server connection that is owned by the caller. + */ + virtual Network::ConnectionPtr createServerConnection(Network::ConnectionSocketPtr&& socket, + Ssl::Context* ssl_ctx) PURE; + /** * Create a client connection. * @param address supplies the address to connect to. @@ -78,32 +88,16 @@ class Dispatcher { /** * Create a listener on a specific port. - * @param conn_handler supplies the handler for connections received by the listener - * @param socket supplies the socket to listen on. - * @param cb supplies the callbacks to invoke for listener events. - * @param scope supplies the Stats::Scope to use. - * @param listener_options listener configuration options. - * @return Network::ListenerPtr a new listener that is owned by the caller. - */ - virtual Network::ListenerPtr - createListener(Network::ConnectionHandler& conn_handler, Network::ListenSocket& socket, - Network::ListenerCallbacks& cb, Stats::Scope& scope, - const Network::ListenerOptions& listener_options) PURE; - - /** - * Create a listener on a specific port. - * @param conn_handler supplies the handler for connections received by the listener - * @param ssl_ctx supplies the SSL context to use. * @param socket supplies the socket to listen on. * @param cb supplies the callbacks to invoke for listener events. - * @param scope supplies the Stats::Scope to use. - * @param listener_options listener configuration options. + * @param bind_to_port controls whether the listener binds to a transport port or not. + * @param hand_off_restored_destination_connections controls whether the listener searches for + * another listener after restoring the destination address of a new connection. * @return Network::ListenerPtr a new listener that is owned by the caller. */ - virtual Network::ListenerPtr - createSslListener(Network::ConnectionHandler& conn_handler, Ssl::ServerContext& ssl_ctx, - Network::ListenSocket& socket, Network::ListenerCallbacks& cb, - Stats::Scope& scope, const Network::ListenerOptions& listener_options) PURE; + virtual Network::ListenerPtr createListener(Network::ListenSocket& socket, + Network::ListenerCallbacks& cb, bool bind_to_port, + bool hand_off_restored_destination_connections) PURE; /** * Allocate a timer. @see Event::Timer for docs on how to use the timer. diff --git a/include/envoy/network/BUILD b/include/envoy/network/BUILD index 57e1c57d564e2..bf2e6982c04c9 100644 --- a/include/envoy/network/BUILD +++ b/include/envoy/network/BUILD @@ -73,6 +73,7 @@ envoy_cc_library( envoy_cc_library( name = "listener_interface", hdrs = ["listener.h"], + deps = ["//include/envoy/network:listen_socket_interface"], ) envoy_cc_library( diff --git a/include/envoy/network/connection.h b/include/envoy/network/connection.h index 35ce5ca4f44e9..f7e5d074b8b90 100644 --- a/include/envoy/network/connection.h +++ b/include/envoy/network/connection.h @@ -207,10 +207,10 @@ class Connection : public Event::DeferredDeletable, public FilterManager { virtual uint32_t bufferLimit() const PURE; /** - * @return boolean telling if the connection's local address is an original destination address, - * rather than the listener's address. + * @return boolean telling if the connection's local address has been restored to an original + * destination address, rather than the address the connection was accepted at. */ - virtual bool usingOriginalDst() const PURE; + virtual bool localAddressRestored() const PURE; /** * @return boolean telling if the connection is currently above the high watermark. diff --git a/include/envoy/network/connection_handler.h b/include/envoy/network/connection_handler.h index 354304baf2905..9d3f654d17ea9 100644 --- a/include/envoy/network/connection_handler.h +++ b/include/envoy/network/connection_handler.h @@ -26,28 +26,9 @@ class ConnectionHandler { /** * Adds listener to the handler. - * @param factory supplies the configuration factory for new connections. - * @param socket supplies the already bound socket to listen on. - * @param scope supplies the stats scope to use for listener specific stats. - * @param listener_tag supplies an opaque tag that can be used to stop or remove the listener. - * @param listener_options listener configuration options. + * @param config listener configuration options. */ - virtual void addListener(Network::FilterChainFactory& factory, Network::ListenSocket& socket, - Stats::Scope& scope, uint64_t listener_tag, - const Network::ListenerOptions& listener_options) PURE; - - /** - * Adds listener to the handler. - * @param factory supplies the configuration factory for new connections. - * @param socket supplies the already bound socket to listen on. - * @param scope supplies the stats scope to use for listener specific stats. - * @param listener_tag supplies an opaque tag that can be used to stop or remove the listener. - * @param listener_options listener configuration options. - */ - virtual void addSslListener(Network::FilterChainFactory& factory, Ssl::ServerContext& ssl_ctx, - Network::ListenSocket& socket, Stats::Scope& scope, - uint64_t listener_tag, - const Network::ListenerOptions& listener_options) PURE; + virtual void addListener(ListenerConfig& config) PURE; /** * Find a listener based on the provided listener address value. diff --git a/include/envoy/network/filter.h b/include/envoy/network/filter.h index 1114f77bdb4dc..202b16b51804f 100644 --- a/include/envoy/network/filter.h +++ b/include/envoy/network/filter.h @@ -9,6 +9,7 @@ namespace Envoy { namespace Network { class Connection; +class ConnectionSocket; /** * Status codes returned by filters that can cause future filters to not get iterated to. @@ -147,6 +148,67 @@ class FilterManager { virtual bool initializeReadFilters() PURE; }; +/** + * Callbacks used by individual listener filter instances to communicate with the listener filter + * manager. + */ +class ListenerFilterCallbacks { +public: + virtual ~ListenerFilterCallbacks() {} + + /** + * @return ConnectionSocket the socket the filter is operating on. + */ + virtual ConnectionSocket& socket() PURE; + + /** + * @return the Dispatcher for issuing events. + */ + virtual Event::Dispatcher& dispatcher() PURE; + + /** + * If a filter stopped filter iteration by returning FilterStatus::StopIteration, + * the filter should call continueFilterChain(true) when complete to continue the filter chain, + * or continueFilterChain(false) if the filter execution failed and the connection must be + * closed. + * @param success boolean telling whether the filter execution was successful or not. + */ + virtual void continueFilterChain(bool success) PURE; +}; + +/** + * Listener Filter + */ +class ListenerFilter { +public: + virtual ~ListenerFilter() {} + + /** + * Called when a new connection is accepted, but before a Connection is created. + * Filter chain iteration can be stopped if needed. + * @param cb the callbacks the filter instance can use to communicate with the filter chain. + * @return status used by the filter manager to manage further filter iteration. + */ + virtual FilterStatus onAccept(ListenerFilterCallbacks& cb) PURE; +}; + +typedef std::unique_ptr ListenerFilterPtr; + +/** + * Interface for filter callbacks and adding listener filters to a manager. + */ +class ListenerFilterManager { +public: + virtual ~ListenerFilterManager() {} + + /** + * Add a filter to the listener. Filters are invoked in FIFO order (the filter added + * first is called first). + * @param filter supplies the filter being added. + */ + virtual void addAcceptFilter(ListenerFilterPtr&& filter) PURE; +}; + /** * Creates a chain of network filters for a new connection. */ @@ -155,12 +217,19 @@ class FilterChainFactory { virtual ~FilterChainFactory() {} /** - * Called to create the filter chain. + * Called to create the network filter chain. * @param connection supplies the connection to create the chain on. * @return true if filter chain was created successfully. Otherwise * false, e.g. filter chain is empty. */ - virtual bool createFilterChain(Connection& connection) PURE; + virtual bool createNetworkFilterChain(Connection& connection) PURE; + + /** + * Called to create the listener filter chain. + * @param listener supplies the listener to create the chain on. + * @return true if filter chain was created successfully. Otherwise false. + */ + virtual bool createListenerFilterChain(ListenerFilterManager& listener) PURE; }; } // namespace Network diff --git a/include/envoy/network/listen_socket.h b/include/envoy/network/listen_socket.h index 2d0cb689f79ca..9ccd3b9124c53 100644 --- a/include/envoy/network/listen_socket.h +++ b/include/envoy/network/listen_socket.h @@ -34,5 +34,65 @@ class ListenSocket { typedef std::unique_ptr ListenSocketPtr; typedef std::shared_ptr ListenSocketSharedPtr; +/** + * A socket passed to a connection. For server connections this represents the accepted socket, and + * for client connections this represents the socket being connected to a remote address. + * + * TODO(jrajahalme): Hide internals (e.g., fd) from listener filters by providing callbacks filters + * may need (set/getsockopt(), peek(), recv(), etc.) + */ +class ConnectionSocket { +public: + virtual ~ConnectionSocket() {} + + /** + * @return the local address of the socket. + */ + virtual const Address::InstanceConstSharedPtr& localAddress() const PURE; + + /** + * @return the remote address of the socket. + */ + virtual const Address::InstanceConstSharedPtr& remoteAddress() const PURE; + + /** + * Set the local address of the socket. On accepted sockets the local address defaults to the + * one at which the connection was received at, which is the same as the listener's address, if + * the listener is bound to a specific address. + * + * @param local_address the new local address. + * @param restored a flag marking the local address as being restored to a value that is + * different from the one the socket was initially accepted at. This should only be set + * to 'true' when restoring the original destination address of a connection redirected + * by iptables REDIRECT. The caller is responsible for making sure the new address is + * actually different when passing restored as 'true'. + */ + virtual void setLocalAddress(const Address::InstanceConstSharedPtr& local_address, + bool restored = false) PURE; + + /** + * Set the remote address of the socket. + */ + virtual void setRemoteAddress(const Address::InstanceConstSharedPtr& remote_address) PURE; + + /** + * @return true if the local address has been restored to a value that is different from the + * address the socket was initially accepted at. + */ + virtual bool localAddressRestored() const PURE; + + /** + * @return fd the socket's file descriptor. + */ + virtual int fd() const PURE; + + /** + * Close the underlying socket. + */ + virtual void close() PURE; +}; + +typedef std::unique_ptr ConnectionSocketPtr; + } // namespace Network } // namespace Envoy diff --git a/include/envoy/network/listener.h b/include/envoy/network/listener.h index 33e57ae0f1cf5..16e1327840e09 100644 --- a/include/envoy/network/listener.h +++ b/include/envoy/network/listener.h @@ -6,39 +6,12 @@ #include "envoy/common/exception.h" #include "envoy/network/connection.h" +#include "envoy/network/listen_socket.h" #include "envoy/ssl/context.h" namespace Envoy { namespace Network { -/** - * Listener configurations options. - */ -struct ListenerOptions { - // Specifies if the listener should actually bind to the port. A listener that doesn't bind can - // only receive connections redirected from other listeners that set use_origin_dst_ to true. - bool bind_to_port_; - // Whether to use the PROXY Protocol V1 - // (http://www.haproxy.org/download/1.5/doc/proxy-protocol.txt) - bool use_proxy_proto_; - // If a connection was redirected to this port using iptables, allow the listener to hand it off - // to the listener associated to the original port. - bool use_original_dst_; - // Soft limit on size of the listener's new connection read and write buffers. - uint32_t per_connection_buffer_limit_bytes_; - - /** - * Factory for ListenerOptions with bind_to_port_ set. - * @return ListenerOptions object initialized with bind_to_port_ set. - */ - static ListenerOptions listenerOptionsWithBindToPort() { - return {.bind_to_port_ = true, - .use_proxy_proto_ = false, - .use_original_dst_ = false, - .per_connection_buffer_limit_bytes_ = 0}; - } -}; - /** * A configuration for an individual listener. */ @@ -63,12 +36,6 @@ class ListenerConfig { */ virtual Ssl::ServerContext* defaultSslContext() PURE; - /** - * @return bool whether to use the PROXY Protocol V1 - * (http://www.haproxy.org/download/1.5/doc/proxy-protocol.txt) - */ - virtual bool useProxyProto() PURE; - /** * @return bool specifies whether the listener should actually listen on the port. * A listener that doesn't listen on a port can only receive connections @@ -77,10 +44,12 @@ class ListenerConfig { virtual bool bindToPort() PURE; /** - * @return bool if a connection was redirected to this listener address using iptables, - * allow the listener to hand it off to the listener associated to the original address + * @return bool if a connection should be handed off to another Listener after the original + * destination address has been restored. 'true' when 'use_original_dst' flag in listener + * configuration is set, false otherwise. Note that this flag is deprecated and will be + * removed from the v2 API. */ - virtual bool useOriginalDst() PURE; + virtual bool handOffRestoredDestinationConnections() const PURE; /** * @return uint32_t providing a soft limit on size of the listener's new connection read and write @@ -96,7 +65,7 @@ class ListenerConfig { /** * @return uint64_t the tag the listener should use for connection handler tracking. */ - virtual uint64_t listenerTag() PURE; + virtual uint64_t listenerTag() const PURE; /** * @return const std::string& the listener's name. @@ -111,6 +80,16 @@ class ListenerCallbacks { public: virtual ~ListenerCallbacks() {} + /** + * Called when a new connection is accepted. + * @param socket supplies the socket that is moved into the callee. + * @param redirected is true when the socket was first accepted by another listener + * and is redirected to a new listener. The recipient should not redirect + * the socket any further. + */ + virtual void onAccept(ConnectionSocketPtr&& socket, + bool hand_off_restored_destination_connections = true) PURE; + /** * Called when a new connection is accepted. * @param new_connection supplies the new connection that is moved into the callee. diff --git a/include/envoy/network/transport_socket.h b/include/envoy/network/transport_socket.h index e7c5438d2e33c..7a2216c6d82f4 100644 --- a/include/envoy/network/transport_socket.h +++ b/include/envoy/network/transport_socket.h @@ -40,7 +40,7 @@ class TransportSocketCallbacks { /** * @return int the file descriptor associated with the connection. */ - virtual int fd() PURE; + virtual int fd() const PURE; /** * @return Network::Connection& the connection interface. diff --git a/include/envoy/server/filter_config.h b/include/envoy/server/filter_config.h index 459c668c98266..c3c09b28af9b6 100644 --- a/include/envoy/server/filter_config.h +++ b/include/envoy/server/filter_config.h @@ -123,6 +123,51 @@ class FactoryContext { virtual Stats::Scope& listenerScope() PURE; }; +/** + * This function is used to wrap the creation of a listener filter chain for new sockets as they are + * created. Filter factories create the lambda at configuration initialization time, and then they + * are used at runtime. + * @param filter_manager supplies the filter manager for the listener to install filters to. + * Typically the function will install a single filter, but it's technically possibly to install + * more than one if desired. + */ +typedef std::function ListenerFilterFactoryCb; + +/** + * Implemented by each listener filter and registered via Registry::registerFactory() + * or the convenience class RegisterFactory. + */ +class NamedListenerFilterConfigFactory { +public: + virtual ~NamedListenerFilterConfigFactory() {} + + /** + * Create a particular listener filter factory implementation. If the implementation is unable to + * produce a factory with the provided parameters, it should throw an EnvoyException in the case + * of general error or a Json::Exception if the json configuration is erroneous. The returned + * callback should always be initialized. + * @param config supplies the general protobuf configuration for the filter + * @param context supplies the filter's context. + * @return ListenerFilterFactoryCb the factory creation function. + */ + virtual ListenerFilterFactoryCb createFilterFactoryFromProto(const Protobuf::Message& config, + FactoryContext& context) PURE; + + /** + * @return ProtobufTypes::MessagePtr create empty config proto message for v2. The filter + * config, which arrives in an opaque google.protobuf.Struct message, will be converted to + * JSON and then parsed into this empty proto. Optional today, will be compulsory when v1 + * is deprecated. + */ + virtual ProtobufTypes::MessagePtr createEmptyConfigProto() PURE; + + /** + * @return std::string the identifying name for a particular implementation of a listener filter + * produced by the factory. + */ + virtual std::string name() PURE; +}; + /** * This function is used to wrap the creation of a network filter chain for new connections as * they come in. Filter factories create the lambda at configuration initialization time, and then diff --git a/include/envoy/server/listener_manager.h b/include/envoy/server/listener_manager.h index e1ae5645a2be4..fccafe7ed2aed 100644 --- a/include/envoy/server/listener_manager.h +++ b/include/envoy/server/listener_manager.h @@ -37,8 +37,18 @@ class ListenerComponentFactory { * @return std::vector the list of filter factories. */ virtual std::vector - createFilterFactoryList(const Protobuf::RepeatedPtrField& filters, - Configuration::FactoryContext& context) PURE; + createNetworkFilterFactoryList(const Protobuf::RepeatedPtrField& filters, + Configuration::FactoryContext& context) PURE; + + /** + * Creates a list of listener filter factories. + * @param filters supplies the JSON configuration. + * @param context supplies the factory creation context. + * @return std::vector the list of filter factories. + */ + virtual std::vector createListenerFilterFactoryList( + const Protobuf::RepeatedPtrField& filters, + Configuration::FactoryContext& context) PURE; /** * @return DrainManagerPtr a new drain manager. diff --git a/source/common/config/well_known_names.h b/source/common/config/well_known_names.h index c025fadee9087..d5e34f2532a0c 100644 --- a/source/common/config/well_known_names.h +++ b/source/common/config/well_known_names.h @@ -47,6 +47,19 @@ class V1Converter { std::unordered_map v1_to_v2_names_; }; +/** + * Well-known listener filter names. + */ +class ListenerFilterNameValues { +public: + // Original destination listener filter + const std::string ORIGINAL_DST = "envoy.listener.original_dst"; + // Proxy Protocol listener filter + const std::string PROXY_PROTOCOL = "envoy.listener.proxy_protocol"; +}; + +typedef ConstSingleton ListenerFilterNames; + /** * Well-known network filter names. */ diff --git a/source/common/event/dispatcher_impl.cc b/source/common/event/dispatcher_impl.cc index abea165f7212d..46377f5794baa 100644 --- a/source/common/event/dispatcher_impl.cc +++ b/source/common/event/dispatcher_impl.cc @@ -73,6 +73,15 @@ void DispatcherImpl::clearDeferredDeleteList() { deferred_deleting_ = false; } +Network::ConnectionPtr DispatcherImpl::createServerConnection(Network::ConnectionSocketPtr&& socket, + Ssl::Context* ssl_ctx) { + ASSERT(isThreadSafe()); + return Network::ConnectionPtr{ssl_ctx + ? new Ssl::ConnectionImpl(*this, std::move(socket), true, + *ssl_ctx, Ssl::InitialState::Server) + : new Network::ConnectionImpl(*this, std::move(socket), true)}; +} + Network::ClientConnectionPtr DispatcherImpl::createClientConnection(Network::Address::InstanceConstSharedPtr address, Network::Address::InstanceConstSharedPtr source_address, @@ -100,23 +109,11 @@ Filesystem::WatcherPtr DispatcherImpl::createFilesystemWatcher() { } Network::ListenerPtr -DispatcherImpl::createListener(Network::ConnectionHandler& conn_handler, - Network::ListenSocket& socket, Network::ListenerCallbacks& cb, - Stats::Scope& scope, - const Network::ListenerOptions& listener_options) { - ASSERT(isThreadSafe()); - return Network::ListenerPtr{ - new Network::ListenerImpl(conn_handler, *this, socket, cb, scope, listener_options)}; -} - -Network::ListenerPtr -DispatcherImpl::createSslListener(Network::ConnectionHandler& conn_handler, - Ssl::ServerContext& ssl_ctx, Network::ListenSocket& socket, - Network::ListenerCallbacks& cb, Stats::Scope& scope, - const Network::ListenerOptions& listener_options) { +DispatcherImpl::createListener(Network::ListenSocket& socket, Network::ListenerCallbacks& cb, + bool bind_to_port, bool hand_off_restored_destination_connections) { ASSERT(isThreadSafe()); - return Network::ListenerPtr{new Network::SslListenerImpl(conn_handler, *this, ssl_ctx, socket, cb, - scope, listener_options)}; + return Network::ListenerPtr{new Network::ListenerImpl(*this, socket, cb, bind_to_port, + hand_off_restored_destination_connections)}; } TimerPtr DispatcherImpl::createTimer(TimerCb cb) { diff --git a/source/common/event/dispatcher_impl.h b/source/common/event/dispatcher_impl.h index eeb2b08a25a98..6aa4342aece43 100644 --- a/source/common/event/dispatcher_impl.h +++ b/source/common/event/dispatcher_impl.h @@ -33,6 +33,8 @@ class DispatcherImpl : Logger::Loggable, public Dispatcher { // Event::Dispatcher void clearDeferredDeleteList() override; + Network::ConnectionPtr createServerConnection(Network::ConnectionSocketPtr&& socket, + Ssl::Context* ssl_ctx) override; Network::ClientConnectionPtr createClientConnection(Network::Address::InstanceConstSharedPtr address, Network::Address::InstanceConstSharedPtr source_address, @@ -42,14 +44,9 @@ class DispatcherImpl : Logger::Loggable, public Dispatcher { FileEventPtr createFileEvent(int fd, FileReadyCb cb, FileTriggerType trigger, uint32_t events) override; Filesystem::WatcherPtr createFilesystemWatcher() override; - Network::ListenerPtr createListener(Network::ConnectionHandler& conn_handler, - Network::ListenSocket& socket, Network::ListenerCallbacks& cb, - Stats::Scope& scope, - const Network::ListenerOptions& listener_options) override; - Network::ListenerPtr createSslListener(Network::ConnectionHandler& conn_handler, - Ssl::ServerContext& ssl_ctx, Network::ListenSocket& socket, - Network::ListenerCallbacks& cb, Stats::Scope& scope, - const Network::ListenerOptions& listener_options) override; + Network::ListenerPtr createListener(Network::ListenSocket& socket, Network::ListenerCallbacks& cb, + bool bind_to_port, + bool hand_off_restored_destination_connections) override; TimerPtr createTimer(TimerCb cb) override; void deferredDelete(DeferredDeletablePtr&& to_delete) override; void exit() override; diff --git a/source/common/filter/listener/BUILD b/source/common/filter/listener/BUILD new file mode 100644 index 0000000000000..059df357bc086 --- /dev/null +++ b/source/common/filter/listener/BUILD @@ -0,0 +1,39 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "original_dst_lib", + srcs = ["original_dst.cc"], + hdrs = ["original_dst.h"], + deps = [ + "//include/envoy/network:filter_interface", + "//include/envoy/network:listen_socket_interface", + "//source/common/common:assert_lib", + "//source/common/common:logger_lib", + "//source/common/network:utility_lib", + ], +) + +envoy_cc_library( + name = "proxy_protocol_lib", + srcs = ["proxy_protocol.cc"], + hdrs = ["proxy_protocol.h"], + deps = [ + "//include/envoy/event:dispatcher_interface", + "//include/envoy/network:filter_interface", + "//include/envoy/network:listen_socket_interface", + "//source/common/common:assert_lib", + "//source/common/common:empty_string", + "//source/common/common:logger_lib", + "//source/common/common:utility_lib", + "//source/common/network:address_lib", + "//source/common/network:utility_lib", + ], +) diff --git a/source/common/filter/listener/original_dst.cc b/source/common/filter/listener/original_dst.cc new file mode 100644 index 0000000000000..70c7e90b895d2 --- /dev/null +++ b/source/common/filter/listener/original_dst.cc @@ -0,0 +1,39 @@ +#include "common/filter/listener/original_dst.h" + +#include "envoy/network/listen_socket.h" + +#include "common/common/assert.h" +#include "common/network/utility.h" + +namespace Envoy { +namespace Filter { +namespace Listener { + +Network::Address::InstanceConstSharedPtr OriginalDst::getOriginalDst(int fd) { + return Network::Utility::getOriginalDst(fd); +} + +Network::FilterStatus OriginalDst::onAccept(Network::ListenerFilterCallbacks& cb) { + ENVOY_LOG(debug, "original_dst: New connection accepted"); + Network::ConnectionSocket& socket = cb.socket(); + const Network::Address::Instance& local_address = *socket.localAddress(); + + if (local_address.type() == Network::Address::Type::Ip) { + Network::Address::InstanceConstSharedPtr original_local_address = getOriginalDst(socket.fd()); + + // A listener that has the use_original_dst flag set to true can still receive + // connections that are NOT redirected using iptables. If a connection was not redirected, + // the address returned by getOriginalDst() matches the local address of the new socket. + // In this case the listener handles the connection directly and does not hand it off. + if (original_local_address && (*original_local_address != local_address)) { + // Restore the local address to the original one. + socket.setLocalAddress(original_local_address, true); + } + } + + return Network::FilterStatus::Continue; +} + +} // namespace Listener +} // namespace Filter +} // namespace Envoy diff --git a/source/common/filter/listener/original_dst.h b/source/common/filter/listener/original_dst.h new file mode 100644 index 0000000000000..0cd83157ab6e2 --- /dev/null +++ b/source/common/filter/listener/original_dst.h @@ -0,0 +1,24 @@ +#pragma once + +#include "envoy/network/filter.h" + +#include "common/common/logger.h" + +namespace Envoy { +namespace Filter { +namespace Listener { + +/** + * Implementation of an original destination listener filter. + */ +class OriginalDst : public Network::ListenerFilter, Logger::Loggable { +public: + virtual Network::Address::InstanceConstSharedPtr getOriginalDst(int fd); + + // Network::ListenerFilter + Network::FilterStatus onAccept(Network::ListenerFilterCallbacks& cb) override; +}; + +} // namespace Listener +} // namespace Filter +} // namespace Envoy diff --git a/source/common/filter/listener/proxy_protocol.cc b/source/common/filter/listener/proxy_protocol.cc new file mode 100644 index 0000000000000..d10415e3e7694 --- /dev/null +++ b/source/common/filter/listener/proxy_protocol.cc @@ -0,0 +1,162 @@ +#include "common/filter/listener/proxy_protocol.h" + +#include + +#include +#include +#include + +#include "envoy/common/exception.h" +#include "envoy/event/dispatcher.h" +#include "envoy/network/listen_socket.h" +#include "envoy/stats/stats.h" + +#include "common/common/assert.h" +#include "common/common/empty_string.h" +#include "common/common/utility.h" +#include "common/network/address_impl.h" +#include "common/network/utility.h" + +namespace Envoy { +namespace Filter { +namespace Listener { +namespace ProxyProtocol { + +Config::Config(Stats::Scope& scope) : stats_{ALL_PROXY_PROTOCOL_STATS(POOL_COUNTER(scope))} {} + +Network::FilterStatus Instance::onAccept(Network::ListenerFilterCallbacks& cb) { + ENVOY_LOG(debug, "proxy_protocol: New connection accepted"); + Network::ConnectionSocket& socket = cb.socket(); + ASSERT(file_event_.get() == nullptr); + file_event_ = + cb.dispatcher().createFileEvent(socket.fd(), + [this](uint32_t events) { + ASSERT(events == Event::FileReadyType::Read); + UNREFERENCED_PARAMETER(events); + onRead(); + }, + Event::FileTriggerType::Edge, Event::FileReadyType::Read); + cb_ = &cb; + return Network::FilterStatus::StopIteration; +} + +void Instance::onRead() { + try { + onReadWorker(); + } catch (const EnvoyException& ee) { + config_->stats_.downstream_cx_proxy_proto_error_.inc(); + cb_->continueFilterChain(false); + } +} + +void Instance::onReadWorker() { + Network::ConnectionSocket& socket = cb_->socket(); + std::string proxy_line; + if (!readLine(socket.fd(), proxy_line)) { + return; + } + + const auto trimmed_proxy_line = StringUtil::rtrim(proxy_line); + + // Parse proxy protocol line with format: PROXY TCP4/TCP6/UNKNOWN SOURCE_ADDRESS + // DESTINATION_ADDRESS SOURCE_PORT DESTINATION_PORT. + const auto line_parts = StringUtil::splitToken(trimmed_proxy_line, " ", true); + if (line_parts.size() < 2 || line_parts[0] != "PROXY") { + throw EnvoyException("failed to read proxy protocol"); + } + + // If the line starts with UNKNOWN we know it's a proxy protocol line, so we can remove it from + // the socket and continue. According to spec "real connection's parameters" should be used, so + // we should NOT restore the addresses in this case. + if (line_parts[1] != "UNKNOWN") { + // If protocol not UNKNOWN, src and dst addresses have to be present. + if (line_parts.size() != 6) { + throw EnvoyException("failed to read proxy protocol"); + } + + Network::Address::IpVersion protocol_version; + Network::Address::InstanceConstSharedPtr remote_address; + Network::Address::InstanceConstSharedPtr local_address; + + // TODO(gsagula): parseInternetAddressAndPort() could be modified to take two string_view + // arguments, so we can eliminate allocation here. + if (line_parts[1] == "TCP4") { + protocol_version = Network::Address::IpVersion::v4; + remote_address = Network::Utility::parseInternetAddressAndPort( + std::string{line_parts[2]} + ":" + std::string{line_parts[4]}); + local_address = Network::Utility::parseInternetAddressAndPort( + std::string{line_parts[3]} + ":" + std::string{line_parts[5]}); + } else if (line_parts[1] == "TCP6") { + protocol_version = Network::Address::IpVersion::v6; + remote_address = Network::Utility::parseInternetAddressAndPort( + "[" + std::string{line_parts[2]} + "]:" + std::string{line_parts[4]}); + local_address = Network::Utility::parseInternetAddressAndPort( + "[" + std::string{line_parts[3]} + "]:" + std::string{line_parts[5]}); + } else { + throw EnvoyException("failed to read proxy protocol"); + } + + // Error check the source and destination fields. Most errors are caught by the address + // parsing above, but a malformed IPv6 address may combine with a malformed port and parse as + // an IPv6 address when parsing for an IPv4 address. Remote address refers to the source + // address. + const auto remote_version = remote_address->ip()->version(); + const auto local_version = local_address->ip()->version(); + if (remote_version != protocol_version || local_version != protocol_version) { + throw EnvoyException("failed to read proxy protocol"); + } + // Check that both addresses are valid unicast addresses, as required for TCP + if (!remote_address->ip()->isUnicastAddress() || !local_address->ip()->isUnicastAddress()) { + throw EnvoyException("failed to read proxy protocol"); + } + + socket.setLocalAddress(local_address); + socket.setRemoteAddress(remote_address); + } + + // Release the file event so that we do not interfere with the connection read events. + file_event_.reset(); + cb_->continueFilterChain(true); +} + +bool Instance::readLine(int fd, std::string& s) { + while (buf_off_ < MAX_PROXY_PROTO_LEN) { + ssize_t nread = recv(fd, buf_ + buf_off_, MAX_PROXY_PROTO_LEN - buf_off_, MSG_PEEK); + + if (nread == -1 && errno == EAGAIN) { + return false; + } else if (nread < 1) { + throw EnvoyException("failed to read proxy protocol"); + } + + bool found = false; + // continue searching buf_ from where we left off + for (; search_index_ < buf_off_ + nread; search_index_++) { + if (buf_[search_index_] == '\n' && buf_[search_index_ - 1] == '\r') { + search_index_++; + found = true; + break; + } + } + + // Read the data upto and including the line feed, if available, but not past it. + // This should never fail, as search_index_ - buf_off_ <= nread, so we're asking + // only for bytes we have already seen. + nread = recv(fd, buf_ + buf_off_, search_index_ - buf_off_, 0); + ASSERT(size_t(nread) == search_index_ - buf_off_); + + buf_off_ += nread; + + if (found) { + s.assign(buf_, buf_off_); + return true; + } + } + + throw EnvoyException("failed to read proxy protocol"); +} + +} // namespace ProxyProtocol +} // namespace Listener +} // namespace Filter +} // namespace Envoy diff --git a/source/common/filter/listener/proxy_protocol.h b/source/common/filter/listener/proxy_protocol.h new file mode 100644 index 0000000000000..98a4b6defbfc3 --- /dev/null +++ b/source/common/filter/listener/proxy_protocol.h @@ -0,0 +1,83 @@ +#pragma once + +#include "envoy/event/file_event.h" +#include "envoy/network/filter.h" +#include "envoy/stats/stats_macros.h" + +#include "common/common/logger.h" + +namespace Envoy { +namespace Filter { +namespace Listener { +namespace ProxyProtocol { + +/** + * All stats for the proxy protocol. @see stats_macros.h + */ +// clang-format off +#define ALL_PROXY_PROTOCOL_STATS(COUNTER) \ + COUNTER(downstream_cx_proxy_proto_error) +// clang-format on + +/** + * Definition of all stats for the proxy protocol. @see stats_macros.h + */ +struct ProxyProtocolStats { + ALL_PROXY_PROTOCOL_STATS(GENERATE_COUNTER_STRUCT) +}; + +/** + * Global configuration for Proxy Protocol listener filter. + */ +class Config { +public: + Config(Stats::Scope& scope); + + ProxyProtocolStats stats_; +}; + +typedef std::shared_ptr ConfigSharedPtr; + +/** + * Implementation the PROXY Protocol V1 listener filter + * (http://www.haproxy.org/download/1.5/doc/proxy-protocol.txt) + */ +class Instance : public Network::ListenerFilter, Logger::Loggable { +public: + Instance(const ConfigSharedPtr& config) : config_(config) {} + + // Network::ListenerFilter + Network::FilterStatus onAccept(Network::ListenerFilterCallbacks& cb) override; + +private: + static const size_t MAX_PROXY_PROTO_LEN = 108; + + void onRead(); + void onReadWorker(); + + /** + * Helper function that attempts to read a line (delimited by '\r\n') from the socket. + * throws EnvoyException on any socket errors. + * @return bool true if a line should be read, false if more data is needed. + */ + bool readLine(int fd, std::string& s); + + Network::ListenerFilterCallbacks* cb_{}; + Event::FileEventPtr file_event_; + + // The offset in buf_ that has been fully read + size_t buf_off_{}; + + // The index in buf_ where the search for '\r\n' should continue from + size_t search_index_{1}; + + // Stores the portion of the first line that has been read so far. + char buf_[MAX_PROXY_PROTO_LEN]; + + ConfigSharedPtr config_; +}; + +} // namespace ProxyProtocol +} // namespace Listener +} // namespace Filter +} // namespace Envoy diff --git a/source/common/network/BUILD b/source/common/network/BUILD index ed4961b6ef87d..607a3b248b255 100644 --- a/source/common/network/BUILD +++ b/source/common/network/BUILD @@ -57,6 +57,7 @@ envoy_cc_library( "//source/common/common:enum_to_int", "//source/common/common:logger_lib", "//source/common/event:libevent_lib", + "//source/common/network:listen_socket_lib", "//source/common/ssl:ssl_socket_lib", ], ) @@ -111,29 +112,23 @@ envoy_cc_library( name = "listener_lib", srcs = [ "listener_impl.cc", - "proxy_protocol.cc", ], hdrs = [ "listener_impl.h", - "proxy_protocol.h", ], deps = [ ":address_lib", - ":connection_lib", ":listen_socket_lib", - ":utility_lib", "//include/envoy/event:dispatcher_interface", "//include/envoy/event:file_event_interface", - "//include/envoy/network:connection_handler_interface", "//include/envoy/network:listener_interface", "//include/envoy/stats:stats_interface", "//include/envoy/stats:stats_macros", + "//source/common/common:assert_lib", "//source/common/common:empty_string", "//source/common/common:linked_object", - "//source/common/common:utility_lib", "//source/common/event:dispatcher_includes", "//source/common/event:libevent_lib", - "//source/common/ssl:connection_lib", ], ) diff --git a/source/common/network/connection_impl.cc b/source/common/network/connection_impl.cc index 5e69e56626a74..c375eebb2d06f 100644 --- a/source/common/network/connection_impl.cc +++ b/source/common/network/connection_impl.cc @@ -16,22 +16,13 @@ #include "common/common/empty_string.h" #include "common/common/enum_to_int.h" #include "common/network/address_impl.h" +#include "common/network/listen_socket_impl.h" #include "common/network/raw_buffer_socket.h" #include "common/network/utility.h" namespace Envoy { namespace Network { -namespace { -Address::InstanceConstSharedPtr getNullLocalAddress(const Address::Instance& address) { - if (address.type() == Address::Type::Ip && address.ip()->version() == Address::IpVersion::v6) { - return Utility::getIpv6AnyAddress(); - } - // Default to IPv4 any address. - return Utility::getIpv4AnyAddress(); -} -} // namespace - void ConnectionImplUtility::updateBufferStats(uint64_t delta, uint64_t new_total, uint64_t& previous_total, Stats::Counter& stat_total, Stats::Gauge& stat_current) { @@ -52,34 +43,22 @@ void ConnectionImplUtility::updateBufferStats(uint64_t delta, uint64_t new_total std::atomic ConnectionImpl::next_global_id_; -ConnectionImpl::ConnectionImpl(Event::Dispatcher& dispatcher, int fd, - Address::InstanceConstSharedPtr remote_address, - Address::InstanceConstSharedPtr local_address, - Address::InstanceConstSharedPtr bind_to_address, - bool using_original_dst, bool connected) - : ConnectionImpl(dispatcher, fd, remote_address, local_address, bind_to_address, - TransportSocketPtr{new RawBufferSocket}, using_original_dst, connected) {} - -ConnectionImpl::ConnectionImpl(Event::Dispatcher& dispatcher, int fd, - Address::InstanceConstSharedPtr remote_address, - Address::InstanceConstSharedPtr local_address, - Address::InstanceConstSharedPtr bind_to_address, - TransportSocketPtr&& transport_socket, bool using_original_dst, +ConnectionImpl::ConnectionImpl(Event::Dispatcher& dispatcher, ConnectionSocketPtr&& socket, bool connected) - : transport_socket_(std::move(transport_socket)), filter_manager_(*this, *this), - remote_address_(remote_address), - local_address_((local_address == nullptr) ? getNullLocalAddress(*remote_address) - : local_address), + : ConnectionImpl(dispatcher, std::move(socket), std::make_unique(), + connected) {} - write_buffer_( - dispatcher.getWatermarkFactory().create([this]() -> void { this->onLowWatermark(); }, - [this]() -> void { this->onHighWatermark(); })), - dispatcher_(dispatcher), fd_(fd), id_(++next_global_id_), - using_original_dst_(using_original_dst) { +ConnectionImpl::ConnectionImpl(Event::Dispatcher& dispatcher, ConnectionSocketPtr&& socket, + TransportSocketPtr&& transport_socket, bool connected) + : transport_socket_(std::move(transport_socket)), filter_manager_(*this, *this), + socket_(std::move(socket)), write_buffer_(dispatcher.getWatermarkFactory().create( + [this]() -> void { this->onLowWatermark(); }, + [this]() -> void { this->onHighWatermark(); })), + dispatcher_(dispatcher), id_(++next_global_id_) { // Treat the lack of a valid fd (which in practice only happens if we run out of FDs) as an OOM // condition and just crash. - RELEASE_ASSERT(fd_ != -1); + RELEASE_ASSERT(fd() != -1); if (!connected) { connecting_ = true; @@ -88,28 +67,14 @@ ConnectionImpl::ConnectionImpl(Event::Dispatcher& dispatcher, int fd, // We never ask for both early close and read at the same time. If we are reading, we want to // consume all available data. file_event_ = dispatcher_.createFileEvent( - fd_, [this](uint32_t events) -> void { onFileEvent(events); }, Event::FileTriggerType::Edge, + fd(), [this](uint32_t events) -> void { onFileEvent(events); }, Event::FileTriggerType::Edge, Event::FileReadyType::Read | Event::FileReadyType::Write); - if (bind_to_address != nullptr) { - int rc = bind_to_address->bind(fd); - if (rc < 0) { - ENVOY_LOG_MISC(debug, "Bind failure. Failed to bind to {}: {}", bind_to_address->asString(), - strerror(errno)); - // Set a special error state to ensure asynchronous close to give the owner of the - // ConnectionImpl a chance to add callbacks and detect the "disconnect" - bind_error_ = true; - - // Trigger a write event to close this connection out-of-band. - file_event_->activate(Event::FileReadyType::Write); - } - } - transport_socket_->setTransportSocketCallbacks(*this); } ConnectionImpl::~ConnectionImpl() { - ASSERT(fd_ == -1); + ASSERT(fd() == -1); // In general we assume that owning code has called close() previously to the destructor being // run. This generally must be done so that callbacks run in the correct context (vs. deferred @@ -131,7 +96,7 @@ void ConnectionImpl::addReadFilter(ReadFilterSharedPtr filter) { bool ConnectionImpl::initializeReadFilters() { return filter_manager_.initializeReadFilters(); } void ConnectionImpl::close(ConnectionCloseType type) { - if (fd_ == -1) { + if (fd() == -1) { return; } @@ -156,7 +121,7 @@ void ConnectionImpl::close(ConnectionCloseType type) { } Connection::State ConnectionImpl::state() const { - if (fd_ == -1) { + if (fd() == -1) { return State::Closed; } else if (close_with_flush_) { return State::Closing; @@ -166,7 +131,7 @@ Connection::State ConnectionImpl::state() const { } void ConnectionImpl::closeSocket(ConnectionEvent close_type) { - if (fd_ == -1) { + if (fd() == -1) { return; } @@ -179,8 +144,7 @@ void ConnectionImpl::closeSocket(ConnectionEvent close_type) { connection_stats_.reset(); file_event_.reset(); - ::close(fd_); - fd_ = -1; + socket_->close(); raiseEvent(close_type); } @@ -195,14 +159,14 @@ void ConnectionImpl::noDelay(bool enable) { // invalid. For this call instead of plumbing through logic that will immediately indicate that a // connect failed, we will just ignore the noDelay() call if the socket is invalid since error is // going to be raised shortly anyway and it makes the calling code simpler. - if (fd_ == -1) { + if (fd() == -1) { return; } // Don't set NODELAY for unix domain sockets sockaddr addr; socklen_t len = sizeof(addr); - int rc = getsockname(fd_, &addr, &len); + int rc = getsockname(fd(), &addr, &len); RELEASE_ASSERT(rc == 0); if (addr.sa_family == AF_UNIX) { @@ -211,7 +175,7 @@ void ConnectionImpl::noDelay(bool enable) { // Set NODELAY int new_value = enable; - rc = setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &new_value, sizeof(new_value)); + rc = setsockopt(fd(), IPPROTO_TCP, TCP_NODELAY, &new_value, sizeof(new_value)); #ifdef __APPLE__ if (-1 == rc && errno == EINVAL) { // Sometimes occurs when the connection is not yet fully formed. Empirically, TCP_NODELAY is @@ -409,7 +373,7 @@ void ConnectionImpl::onFileEvent(uint32_t events) { // It's possible for a write event callback to close the socket (which will cause fd_ to be -1). // In this case ignore write event processing. - if (fd_ != -1 && (events & Event::FileReadyType::Read)) { + if (fd() != -1 && (events & Event::FileReadyType::Read)) { onReadReady(); } } @@ -441,7 +405,7 @@ void ConnectionImpl::onWriteReady() { if (connecting_) { int error; socklen_t error_size = sizeof(error); - int rc = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &error_size); + int rc = getsockopt(fd(), SOL_SOCKET, SO_ERROR, &error, &error_size); ASSERT(0 == rc); UNREFERENCED_PARAMETER(rc); @@ -478,39 +442,13 @@ void ConnectionImpl::onWriteReady() { cb(result.bytes_processed_); // If a callback closes the socket, stop iterating. - if (fd_ == -1) { + if (fd() == -1) { return; } } } } -void ConnectionImpl::doConnect() { - ENVOY_CONN_LOG(debug, "connecting to {}", *this, remote_address_->asString()); - int rc = remote_address_->connect(fd_); - if (rc == 0) { - // write will become ready. - ASSERT(connecting_); - } else { - ASSERT(rc == -1); - if (errno == EINPROGRESS) { - ASSERT(connecting_); - ENVOY_CONN_LOG(debug, "connection in progress", *this); - } else { - // read/write will become ready. - immediate_connection_error_ = true; - connecting_ = false; - ENVOY_CONN_LOG(debug, "immediate connection error: {}", *this, errno); - } - } - - // The local address can only be retrieved for IP connections. Other - // types, such as UDS, don't have a notion of a local address. - if (remote_address_->type() == Address::Type::Ip) { - local_address_ = Address::addressFromFd(fd_); - } -} - void ConnectionImpl::setConnectionStats(const ConnectionStats& stats) { ASSERT(!connection_stats_); connection_stats_.reset(new ConnectionStats(stats)); @@ -540,9 +478,48 @@ ClientConnectionImpl::ClientConnectionImpl( Event::Dispatcher& dispatcher, const Address::InstanceConstSharedPtr& remote_address, const Network::Address::InstanceConstSharedPtr& source_address, Network::TransportSocketPtr&& transport_socket) - : ConnectionImpl(dispatcher, remote_address->socket(Address::SocketType::Stream), - remote_address, nullptr, source_address, std::move(transport_socket), false, - false) {} + : ConnectionImpl(dispatcher, std::make_unique(remote_address), + std::move(transport_socket), false) { + if (source_address != nullptr) { + const int rc = source_address->bind(fd()); + if (rc < 0) { + ENVOY_LOG_MISC(debug, "Bind failure. Failed to bind to {}: {}", source_address->asString(), + strerror(errno)); + // Set a special error state to ensure asynchronous close to give the owner of the + // ConnectionImpl a chance to add callbacks and detect the "disconnect" + bind_error_ = true; + + // Trigger a write event to close this connection out-of-band. + file_event_->activate(Event::FileReadyType::Write); + } + } +} + +void ClientConnectionImpl::connect() { + ENVOY_CONN_LOG(debug, "connecting to {}", *this, socket_->remoteAddress()->asString()); + const int rc = socket_->remoteAddress()->connect(fd()); + if (rc == 0) { + // write will become ready. + ASSERT(connecting_); + } else { + ASSERT(rc == -1); + if (errno == EINPROGRESS) { + ASSERT(connecting_); + ENVOY_CONN_LOG(debug, "connection in progress", *this); + } else { + // read/write will become ready. + immediate_connection_error_ = true; + connecting_ = false; + ENVOY_CONN_LOG(debug, "immediate connection error: {}", *this, errno); + } + } + + // The local address can only be retrieved for IP connections. Other + // types, such as UDS, don't have a notion of a local address. + if (socket_->remoteAddress()->type() == Address::Type::Ip) { + socket_->setLocalAddress(Address::addressFromFd(fd())); + } +} } // namespace Network } // namespace Envoy diff --git a/source/common/network/connection_impl.h b/source/common/network/connection_impl.h index b9c0963f411a1..7cbd9f500e716 100644 --- a/source/common/network/connection_impl.h +++ b/source/common/network/connection_impl.h @@ -48,17 +48,10 @@ class ConnectionImpl : public virtual Connection, protected Logger::Loggable { public: // TODO(lizan): Remove the old style constructor when factory is ready. - ConnectionImpl(Event::Dispatcher& dispatcher, int fd, - Address::InstanceConstSharedPtr remote_address, - Address::InstanceConstSharedPtr local_address, - Address::InstanceConstSharedPtr bind_to_address, bool using_original_dst, - bool connected); - - ConnectionImpl(Event::Dispatcher& dispatcher, int fd, - Address::InstanceConstSharedPtr remote_address, - Address::InstanceConstSharedPtr local_address, - Address::InstanceConstSharedPtr bind_to_address, - TransportSocketPtr&& transport_socket, bool using_original_dst, bool connected); + ConnectionImpl(Event::Dispatcher& dispatcher, ConnectionSocketPtr&& socket, bool connected); + + ConnectionImpl(Event::Dispatcher& dispatcher, ConnectionSocketPtr&& socket, + TransportSocketPtr&& transport_socket, bool connected); ~ConnectionImpl(); @@ -79,8 +72,12 @@ class ConnectionImpl : public virtual Connection, void readDisable(bool disable) override; void detectEarlyCloseWhenReadDisabled(bool value) override { detect_early_close_ = value; } bool readEnabled() const override; - const Address::InstanceConstSharedPtr& remoteAddress() const override { return remote_address_; } - const Address::InstanceConstSharedPtr& localAddress() const override { return local_address_; } + const Address::InstanceConstSharedPtr& remoteAddress() const override { + return socket_->remoteAddress(); + } + const Address::InstanceConstSharedPtr& localAddress() const override { + return socket_->localAddress(); + } void setConnectionStats(const ConnectionStats& stats) override; Ssl::Connection* ssl() override { return transport_socket_->ssl(); } const Ssl::Connection* ssl() const override { return transport_socket_->ssl(); } @@ -88,7 +85,7 @@ class ConnectionImpl : public virtual Connection, void write(Buffer::Instance& data) override; void setBufferLimits(uint32_t limit) override; uint32_t bufferLimit() const override { return read_buffer_limit_; } - bool usingOriginalDst() const override { return using_original_dst_; } + bool localAddressRestored() const override { return socket_->localAddressRestored(); } bool aboveHighWatermark() const override { return above_high_watermark_; } // Network::BufferSource @@ -96,7 +93,7 @@ class ConnectionImpl : public virtual Connection, Buffer::Instance& getWriteBuffer() override { return *current_write_buffer_; } // Network::TransportSocketCallbacks - int fd() override { return fd_; } + int fd() const override { return socket_->fd(); } Connection& connection() override { return *this; } void raiseEvent(ConnectionEvent event) override; // Should the read buffer be drained? @@ -112,21 +109,25 @@ class ConnectionImpl : public virtual Connection, protected: void closeSocket(ConnectionEvent close_type); - void doConnect(); void onLowWatermark(); void onHighWatermark(); TransportSocketPtr transport_socket_; FilterManagerImpl filter_manager_; - Address::InstanceConstSharedPtr remote_address_; - Address::InstanceConstSharedPtr local_address_; + ConnectionSocketPtr socket_; Buffer::OwnedImpl read_buffer_; // This must be a WatermarkBuffer, but as it is created by a factory the ConnectionImpl only has // a generic pointer. Buffer::InstancePtr write_buffer_; uint32_t read_buffer_limit_ = 0; +protected: + bool connecting_{false}; + bool immediate_connection_error_{false}; + bool bind_error_{false}; + Event::FileEventPtr file_event_; + private: void onFileEvent(uint32_t events); void onRead(uint64_t read_buffer_size); @@ -138,17 +139,11 @@ class ConnectionImpl : public virtual Connection, static std::atomic next_global_id_; Event::Dispatcher& dispatcher_; - int fd_{-1}; - Event::FileEventPtr file_event_; const uint64_t id_; std::list callbacks_; std::list bytes_sent_callbacks_; bool read_enabled_{true}; - bool connecting_{false}; bool close_with_flush_{false}; - bool immediate_connection_error_{false}; - bool bind_error_{false}; - const bool using_original_dst_; bool above_high_watermark_{false}; bool detect_early_close_{true}; Buffer::Instance* current_write_buffer_{}; @@ -172,7 +167,7 @@ class ClientConnectionImpl : public ConnectionImpl, virtual public ClientConnect Network::TransportSocketPtr&& transport_socket); // Network::ClientConnection - void connect() override { doConnect(); } + void connect() override; }; } // namespace Network diff --git a/source/common/network/listen_socket_impl.h b/source/common/network/listen_socket_impl.h index e525b17e0c872..660f674b7ebb7 100644 --- a/source/common/network/listen_socket_impl.h +++ b/source/common/network/listen_socket_impl.h @@ -7,6 +7,7 @@ #include "envoy/network/listen_socket.h" +#include "common/common/assert.h" #include "common/ssl/context_impl.h" namespace Envoy { @@ -50,5 +51,56 @@ class UdsListenSocket : public ListenSocketImpl { UdsListenSocket(const std::string& uds_path); }; +class ConnectionSocketImpl : virtual public ConnectionSocket { +public: + ConnectionSocketImpl(int fd, const Address::InstanceConstSharedPtr& local_address, + const Address::InstanceConstSharedPtr& remote_address) + : fd_(fd), local_address_(local_address), remote_address_(remote_address) {} + ~ConnectionSocketImpl() { close(); } + + // Network::ConnectionSocket + const Address::InstanceConstSharedPtr& localAddress() const override { return local_address_; } + const Address::InstanceConstSharedPtr& remoteAddress() const override { return remote_address_; } + void setLocalAddress(const Address::InstanceConstSharedPtr& local_address, + bool restored) override { + ASSERT(!restored || *local_address != *local_address_); + local_address_ = local_address; + local_address_restored_ = restored; + } + void setRemoteAddress(const Address::InstanceConstSharedPtr& remote_address) override { + remote_address_ = remote_address; + } + bool localAddressRestored() const override { return local_address_restored_; } + int fd() const override { return fd_; } + void close() override { + if (fd_ != -1) { + ::close(fd_); + fd_ = -1; + } + } + +protected: + int fd_; + Address::InstanceConstSharedPtr local_address_; + Address::InstanceConstSharedPtr remote_address_; + bool local_address_restored_{false}; +}; + +// ConnectionSocket used with server connections. +class AcceptedSocketImpl : public ConnectionSocketImpl { +public: + AcceptedSocketImpl(int fd, const Address::InstanceConstSharedPtr& local_address, + const Address::InstanceConstSharedPtr& remote_address) + : ConnectionSocketImpl(fd, local_address, remote_address) {} +}; + +// ConnectionSocket used with client connections. +class ClientSocketImpl : public ConnectionSocketImpl { +public: + ClientSocketImpl(const Address::InstanceConstSharedPtr& remote_address) + : ConnectionSocketImpl(remote_address->socket(Address::SocketType::Stream), nullptr, + remote_address) {} +}; + } // namespace Network } // namespace Envoy diff --git a/source/common/network/listener_impl.cc b/source/common/network/listener_impl.cc index a3a9e395c4456..f5fa515d31625 100644 --- a/source/common/network/listener_impl.cc +++ b/source/common/network/listener_impl.cc @@ -3,15 +3,12 @@ #include #include "envoy/common/exception.h" -#include "envoy/network/connection_handler.h" +#include "common/common/assert.h" #include "common/common/empty_string.h" #include "common/event/dispatcher_impl.h" #include "common/event/file_event_impl.h" #include "common/network/address_impl.h" -#include "common/network/connection_impl.h" -#include "common/network/utility.h" -#include "common/ssl/connection_impl.h" #include "event2/listener.h" #include "fmt/format.h" @@ -23,77 +20,42 @@ Address::InstanceConstSharedPtr ListenerImpl::getLocalAddress(int fd) { return Address::addressFromFd(fd); } -Address::InstanceConstSharedPtr ListenerImpl::getOriginalDst(int fd) { - return Utility::getOriginalDst(fd); -} - void ListenerImpl::listenCallback(evconnlistener*, evutil_socket_t fd, sockaddr* remote_addr, int remote_addr_len, void* arg) { ListenerImpl* listener = static_cast(arg); - Address::InstanceConstSharedPtr final_local_address = listener->socket_.localAddress(); - bool using_original_dst = false; - - // Get the local address from the new socket if the listener is listening on the all hosts - // address (e.g., 0.0.0.0 for IPv4). - auto ip = final_local_address->ip(); - if (ip && ip->isAnyAddress()) { - final_local_address = listener->getLocalAddress(fd); - } - - if (listener->options_.use_original_dst_ && final_local_address->type() == Address::Type::Ip) { - Address::InstanceConstSharedPtr original_local_address = listener->getOriginalDst(fd); - - // A listener that has the use_original_dst flag set to true can still receive - // connections that are NOT redirected using iptables. If a connection was not redirected, - // the address returned by getOriginalDst() matches the local address of the new socket. - // In this case the listener handles the connection directly and does not hand it off. - if (original_local_address && (*original_local_address != *final_local_address)) { - final_local_address = original_local_address; - using_original_dst = true; - - // Hands off redirected connections (from iptables) to the listener associated with the - // original destination address. If there is no listener associated with the original - // destination address, the connection is handled by the listener that receives it. - ListenerImpl* new_listener = dynamic_cast( - listener->connection_handler_.findListenerByAddress(*original_local_address)); - - if (new_listener != nullptr) { - listener = new_listener; - } - } - } - - if (listener->options_.use_proxy_proto_) { - listener->proxy_protocol_.newConnection(listener->dispatcher_, fd, *listener); - } else { - Address::InstanceConstSharedPtr final_remote_address; - if (remote_addr->sa_family == AF_UNIX) { + ConnectionSocketPtr socket(new AcceptedSocketImpl( + fd, + // Get the local address from the new socket if the listener is listening on IP ANY + // (e.g., 0.0.0.0 for IPv4) (local_address_ is nullptr in this case). + !listener->local_address_ ? listener->getLocalAddress(fd) : listener->local_address_, // The accept() call that filled in remote_addr doesn't fill in more than the sa_family field // for Unix domain sockets; apparently there isn't a mechanism in the kernel to get the // sockaddr_un associated with the client socket when starting from the server socket. // We work around this by using our own name for the socket in this case. - final_remote_address = Address::peerAddressFromFd(fd); - } else { - final_remote_address = Address::addressFromSockAddr( - *reinterpret_cast(remote_addr), remote_addr_len); - } - // TODO(jamessynge): We need to keep per-family stats. BUT, should it be based on the original - // family or the local family? Probably local family, as the original proxy can take care of - // stats for the original family. - listener->newConnection(fd, final_remote_address, final_local_address, using_original_dst); - } + (remote_addr->sa_family == AF_UNIX) + ? Address::peerAddressFromFd(fd) + : Address::addressFromSockAddr(*reinterpret_cast(remote_addr), + remote_addr_len))); + listener->cb_.onAccept(std::move(socket), listener->hand_off_restored_destination_connections_); } -ListenerImpl::ListenerImpl(Network::ConnectionHandler& conn_handler, - Event::DispatcherImpl& dispatcher, ListenSocket& socket, - ListenerCallbacks& cb, Stats::Scope& scope, - const Network::ListenerOptions& listener_options) - : connection_handler_(conn_handler), dispatcher_(dispatcher), socket_(socket), cb_(cb), - proxy_protocol_(scope), options_(listener_options), listener_(nullptr) { +ListenerImpl::ListenerImpl(Event::DispatcherImpl& dispatcher, ListenSocket& socket, + ListenerCallbacks& cb, bool bind_to_port, + bool hand_off_restored_destination_connections) + : local_address_(nullptr), cb_(cb), + hand_off_restored_destination_connections_(hand_off_restored_destination_connections), + listener_(nullptr) { + const auto ip = socket.localAddress()->ip(); + + // Only use the listen socket's local address for new connections if it is not the all hosts + // address (e.g., 0.0.0.0 for IPv4). + if (!(ip && ip->isAnyAddress())) { + local_address_ = socket.localAddress(); + } - if (options_.bind_to_port_) { + if (bind_to_port) { listener_.reset( - evconnlistener_new(&dispatcher_.base(), listenCallback, this, 0, -1, socket.fd())); + evconnlistener_new(&dispatcher.base(), listenCallback, this, 0, -1, socket.fd())); if (!listener_) { throw CreateListenerException( @@ -110,25 +72,5 @@ void ListenerImpl::errorCallback(evconnlistener*, void*) { PANIC(fmt::format("listener accept failure: {}", strerror(errno))); } -void ListenerImpl::newConnection(int fd, Address::InstanceConstSharedPtr remote_address, - Address::InstanceConstSharedPtr local_address, - bool using_original_dst) { - ConnectionPtr new_connection(new ConnectionImpl(dispatcher_, fd, remote_address, local_address, - Network::Address::InstanceConstSharedPtr(), - using_original_dst, true)); - new_connection->setBufferLimits(options_.per_connection_buffer_limit_bytes_); - cb_.onNewConnection(std::move(new_connection)); -} - -void SslListenerImpl::newConnection(int fd, Address::InstanceConstSharedPtr remote_address, - Address::InstanceConstSharedPtr local_address, - bool using_original_dst) { - ConnectionPtr new_connection(new Ssl::ConnectionImpl( - dispatcher_, fd, remote_address, local_address, Network::Address::InstanceConstSharedPtr(), - using_original_dst, true, ssl_ctx_, Ssl::InitialState::Server)); - new_connection->setBufferLimits(options_.per_connection_buffer_limit_bytes_); - cb_.onNewConnection(std::move(new_connection)); -} - } // namespace Network } // namespace Envoy diff --git a/source/common/network/listener_impl.h b/source/common/network/listener_impl.h index e46dfcd99e341..8ff0bcdbccb4e 100644 --- a/source/common/network/listener_impl.h +++ b/source/common/network/listener_impl.h @@ -1,12 +1,10 @@ #pragma once -#include "envoy/network/connection_handler.h" #include "envoy/network/listener.h" #include "common/event/dispatcher_impl.h" #include "common/event/libevent.h" #include "common/network/listen_socket_impl.h" -#include "common/network/proxy_protocol.h" #include "event2/event.h" @@ -18,35 +16,15 @@ namespace Network { */ class ListenerImpl : public Listener { public: - ListenerImpl(Network::ConnectionHandler& conn_handler, Event::DispatcherImpl& dispatcher, - ListenSocket& socket, ListenerCallbacks& cb, Stats::Scope& scope, - const ListenerOptions& listener_options); - - /** - * Accept/process a new connection. - * @param fd supplies the new connection's fd. - * @param remote_address supplies the remote address for the new connection. - * @param local_address supplies the local address for the new connection. - */ - virtual void newConnection(int fd, Address::InstanceConstSharedPtr remote_address, - Address::InstanceConstSharedPtr local_address, - bool using_original_dst); - - /** - * @return the socket supplied to the listener at construction time - */ - ListenSocket& socket() { return socket_; } + ListenerImpl(Event::DispatcherImpl& dispatcher, ListenSocket& socket, ListenerCallbacks& cb, + bool bind_to_port, bool hand_off_restored_destination_connections); protected: virtual Address::InstanceConstSharedPtr getLocalAddress(int fd); - virtual Address::InstanceConstSharedPtr getOriginalDst(int fd); - Network::ConnectionHandler& connection_handler_; - Event::DispatcherImpl& dispatcher_; - ListenSocket& socket_; + Address::InstanceConstSharedPtr local_address_; ListenerCallbacks& cb_; - ProxyProtocol proxy_protocol_; - const ListenerOptions options_; + const bool hand_off_restored_destination_connections_; private: static void errorCallback(evconnlistener* listener, void* context); @@ -56,22 +34,5 @@ class ListenerImpl : public Listener { Event::Libevent::ListenerPtr listener_; }; -class SslListenerImpl : public ListenerImpl { -public: - SslListenerImpl(Network::ConnectionHandler& conn_handler, Event::DispatcherImpl& dispatcher, - Ssl::Context& ssl_ctx, ListenSocket& socket, ListenerCallbacks& cb, - Stats::Scope& scope, const Network::ListenerOptions& listener_options) - : ListenerImpl(conn_handler, dispatcher, socket, cb, scope, listener_options), - ssl_ctx_(ssl_ctx) {} - - // ListenerImpl - void newConnection(int fd, Address::InstanceConstSharedPtr remote_address, - Address::InstanceConstSharedPtr local_address, - bool using_original_dst) override; - -private: - Ssl::Context& ssl_ctx_; -}; - } // namespace Network } // namespace Envoy diff --git a/source/common/network/proxy_protocol.cc b/source/common/network/proxy_protocol.cc deleted file mode 100644 index eaf8e433d150c..0000000000000 --- a/source/common/network/proxy_protocol.cc +++ /dev/null @@ -1,190 +0,0 @@ -#include "common/network/proxy_protocol.h" - -#include - -#include -#include -#include - -#include "envoy/common/exception.h" -#include "envoy/event/dispatcher.h" -#include "envoy/event/file_event.h" -#include "envoy/stats/stats.h" - -#include "common/common/empty_string.h" -#include "common/common/utility.h" -#include "common/network/address_impl.h" -#include "common/network/listener_impl.h" -#include "common/network/utility.h" - -namespace Envoy { -namespace Network { - -ProxyProtocol::ProxyProtocol(Stats::Scope& scope) - : stats_{ALL_PROXY_PROTOCOL_STATS(POOL_COUNTER(scope))} {} - -void ProxyProtocol::newConnection(Event::Dispatcher& dispatcher, int fd, ListenerImpl& listener) { - std::unique_ptr p{new ActiveConnection(*this, dispatcher, fd, listener)}; - p->moveIntoList(std::move(p), connections_); -} - -ProxyProtocol::ActiveConnection::ActiveConnection(ProxyProtocol& parent, - Event::Dispatcher& dispatcher, int fd, - ListenerImpl& listener) - : parent_(parent), fd_(fd), listener_(listener), search_index_(1) { - file_event_ = - dispatcher.createFileEvent(fd, - [this](uint32_t events) { - ASSERT(events == Event::FileReadyType::Read); - UNREFERENCED_PARAMETER(events); - onRead(); - }, - Event::FileTriggerType::Edge, Event::FileReadyType::Read); -} - -ProxyProtocol::ActiveConnection::~ActiveConnection() { - if (fd_ != -1) { - ::close(fd_); - } -} - -void ProxyProtocol::ActiveConnection::onRead() { - try { - onReadWorker(); - } catch (const EnvoyException& ee) { - parent_.stats_.downstream_cx_proxy_proto_error_.inc(); - close(); - } -} - -void ProxyProtocol::ActiveConnection::onReadWorker() { - std::string proxy_line; - if (!readLine(fd_, proxy_line)) { - return; - } - - const auto trimmed_proxy_line = StringUtil::rtrim(proxy_line); - - // Parse proxy protocol line with format: PROXY TCP4/TCP6/UNKNOWN SOURCE_ADDRESS - // DESTINATION_ADDRESS SOURCE_PORT DESTINATION_PORT. - const auto line_parts = StringUtil::splitToken(trimmed_proxy_line, " ", true); - if (line_parts.size() < 2 || line_parts[0] != "PROXY") { - throw EnvoyException("failed to read proxy protocol"); - } - - if (line_parts[1] == "UNKNOWN") { - // At this point we know it's a proxy protocol line, so we can remove it from the socket - // and continue. - Address::InstanceConstSharedPtr local_address = Envoy::Network::Address::addressFromFd(fd_); - Address::InstanceConstSharedPtr remote_address; - // The remote address not known. - if (local_address->ip()->version() == Address::IpVersion::v4) { - remote_address = std::make_shared(Address::Ipv4Instance("0.0.0.0")); - } else { - remote_address = std::make_shared(Address::Ipv6Instance("::")); - } - finishConnection(remote_address, local_address); - return; - } - - // If protocol not UNKNOWN, src and dst addresses have to be present. - if (line_parts.size() != 6) { - throw EnvoyException("failed to read proxy protocol"); - } - - Address::IpVersion protocol_version; - Address::InstanceConstSharedPtr remote_address; - Address::InstanceConstSharedPtr local_address; - - // TODO(gsagula): parseInternetAddressAndPort() could be modified to take two string_view - // arguments, so we can eliminate allocation here. - if (line_parts[1] == "TCP4") { - protocol_version = Address::IpVersion::v4; - remote_address = Utility::parseInternetAddressAndPort(std::string{line_parts[2]} + ":" + - std::string{line_parts[4]}); - local_address = Utility::parseInternetAddressAndPort(std::string{line_parts[3]} + ":" + - std::string{line_parts[5]}); - } else if (line_parts[1] == "TCP6") { - protocol_version = Address::IpVersion::v6; - remote_address = Utility::parseInternetAddressAndPort("[" + std::string{line_parts[2]} + - "]:" + std::string{line_parts[4]}); - local_address = Utility::parseInternetAddressAndPort("[" + std::string{line_parts[3]} + - "]:" + std::string{line_parts[5]}); - } else { - throw EnvoyException("failed to read proxy protocol"); - } - - // Error check the source and destination fields. Most errors are caught by the address - // parsing above, but a malformed IPv6 address may combine with a malformed port and parse as - // an IPv6 address when parsing for an IPv4 address. Remote address refers to the source - // address. - const auto remote_version = remote_address->ip()->version(); - const auto local_version = local_address->ip()->version(); - if (remote_version != protocol_version || local_version != protocol_version) { - throw EnvoyException("failed to read proxy protocol"); - } - // Check that both addresses are valid unicast addresses, as required for TCP - if (!remote_address->ip()->isUnicastAddress() || !local_address->ip()->isUnicastAddress()) { - throw EnvoyException("failed to read proxy protocol"); - } - - finishConnection(remote_address, local_address); -} - -void ProxyProtocol::ActiveConnection::finishConnection( - Address::InstanceConstSharedPtr remote_address, Address::InstanceConstSharedPtr local_address) { - - ListenerImpl& listener = listener_; - int fd = fd_; - fd_ = -1; - - removeFromList(parent_.connections_); - - listener.newConnection(fd, remote_address, local_address, true); -} - -void ProxyProtocol::ActiveConnection::close() { - ::close(fd_); - fd_ = -1; - removeFromList(parent_.connections_); -} - -bool ProxyProtocol::ActiveConnection::readLine(int fd, std::string& s) { - while (buf_off_ < MAX_PROXY_PROTO_LEN) { - ssize_t nread = recv(fd, buf_ + buf_off_, MAX_PROXY_PROTO_LEN - buf_off_, MSG_PEEK); - - if (nread == -1 && errno == EAGAIN) { - return false; - } else if (nread < 1) { - throw EnvoyException("failed to read proxy protocol"); - } - - bool found = false; - // continue searching buf_ from where we left off - for (; search_index_ < buf_off_ + nread; search_index_++) { - if (buf_[search_index_] == '\n' && buf_[search_index_ - 1] == '\r') { - search_index_++; - found = true; - break; - } - } - - // Read the data upto and including the line feed, if available, but not past it. - // This should never fail, as search_index_ - buf_off_ <= nread, so we're asking - // only for bytes we have already seen. - nread = recv(fd, buf_ + buf_off_, search_index_ - buf_off_, 0); - ASSERT(size_t(nread) == search_index_ - buf_off_); - - buf_off_ += nread; - - if (found) { - s.assign(buf_, buf_off_); - return true; - } - } - - throw EnvoyException("failed to read proxy protocol"); -} - -} // namespace Network -} // namespace Envoy diff --git a/source/common/network/proxy_protocol.h b/source/common/network/proxy_protocol.h deleted file mode 100644 index 1780f792c19d3..0000000000000 --- a/source/common/network/proxy_protocol.h +++ /dev/null @@ -1,89 +0,0 @@ -#pragma once - -#include -#include -#include - -#include "envoy/event/dispatcher.h" -#include "envoy/stats/stats_macros.h" - -#include "common/common/linked_object.h" - -namespace Envoy { -namespace Network { - -class ListenerImpl; - -/** - * All stats for the proxy protocol. @see stats_macros.h - */ -// clang-format off -#define ALL_PROXY_PROTOCOL_STATS(COUNTER) \ - COUNTER(downstream_cx_proxy_proto_error) -// clang-format on - -/** - * Definition of all stats for the proxy protocol. @see stats_macros.h - */ -struct ProxyProtocolStats { - ALL_PROXY_PROTOCOL_STATS(GENERATE_COUNTER_STRUCT) -}; - -/** - * Implementation the PROXY Protocol V1 - * (http://www.haproxy.org/download/1.5/doc/proxy-protocol.txt) - */ -class ProxyProtocol { -public: - class ActiveConnection : public LinkedObject { - public: - ActiveConnection(ProxyProtocol& parent, Event::Dispatcher& dispatcher, int fd, - ListenerImpl& listener); - ~ActiveConnection(); - - private: - static const size_t MAX_PROXY_PROTO_LEN = 108; - - void onRead(); - void onReadWorker(); - - /** - * Helper function that attempts to read a line (delimited by '\r\n') from the socket. - * throws EnvoyException on any socket errors. - * @return bool true if a line should be read, false if more data is needed. - */ - bool readLine(int fd, std::string& s); - void close(); - - /** - * Helper function that replaces the current connection with the specified one. - */ - void finishConnection(Address::InstanceConstSharedPtr remote_address, - Address::InstanceConstSharedPtr local_address); - - ProxyProtocol& parent_; - int fd_; - ListenerImpl& listener_; - Event::FileEventPtr file_event_; - - // The offset in buf_ that has been fully read - size_t buf_off_{}; - - // The index in buf_ where the search for '\r\n' should continue from - size_t search_index_; - - // Stores the portion of the first line that has been read so far. - char buf_[MAX_PROXY_PROTO_LEN]; - }; - - ProxyProtocol(Stats::Scope& scope); - - void newConnection(Event::Dispatcher& dispatcher, int fd, ListenerImpl& listener); - -private: - ProxyProtocolStats stats_; - std::list> connections_; -}; - -} // namespace Network -} // namespace Envoy diff --git a/source/common/ssl/connection_impl.cc b/source/common/ssl/connection_impl.cc index b0eb524fb7b41..7df4f5d2a8f7a 100644 --- a/source/common/ssl/connection_impl.cc +++ b/source/common/ssl/connection_impl.cc @@ -3,15 +3,10 @@ namespace Envoy { namespace Ssl { -ConnectionImpl::ConnectionImpl(Event::Dispatcher& dispatcher, int fd, - Network::Address::InstanceConstSharedPtr remote_address, - Network::Address::InstanceConstSharedPtr local_address, - Network::Address::InstanceConstSharedPtr bind_to_address, - bool using_original_dst, bool connected, Context& ctx, - InitialState state) - : Network::ConnectionImpl(dispatcher, fd, remote_address, local_address, bind_to_address, - Network::TransportSocketPtr{new SslSocket(ctx, state)}, - using_original_dst, connected) {} +ConnectionImpl::ConnectionImpl(Event::Dispatcher& dispatcher, Network::ConnectionSocketPtr&& socket, + bool connected, Context& ctx, InitialState state) + : Network::ConnectionImpl(dispatcher, std::move(socket), + std::make_unique(ctx, state), connected) {} } // namespace Ssl } // namespace Envoy diff --git a/source/common/ssl/connection_impl.h b/source/common/ssl/connection_impl.h index 28e6d54b0f925..58052d60987fb 100644 --- a/source/common/ssl/connection_impl.h +++ b/source/common/ssl/connection_impl.h @@ -17,10 +17,7 @@ namespace Ssl { // TODO(lizan): Remove Ssl::ConnectionImpl entirely when factory of TransportSocket is ready. class ConnectionImpl : public Network::ConnectionImpl { public: - ConnectionImpl(Event::Dispatcher& dispatcher, int fd, - Network::Address::InstanceConstSharedPtr remote_address, - Network::Address::InstanceConstSharedPtr local_address, - Network::Address::InstanceConstSharedPtr bind_to_address, bool using_original_dst, + ConnectionImpl(Event::Dispatcher& dispatcher, Network::ConnectionSocketPtr&& socket, bool connected, Context& ctx, InitialState state); }; diff --git a/source/common/upstream/original_dst_cluster.cc b/source/common/upstream/original_dst_cluster.cc index 93f41b28e80b0..0c8ae80f7cad3 100644 --- a/source/common/upstream/original_dst_cluster.cc +++ b/source/common/upstream/original_dst_cluster.cc @@ -43,8 +43,8 @@ HostConstSharedPtr OriginalDstCluster::LoadBalancer::chooseHost(LoadBalancerCont const Network::Connection* connection = context->downstreamConnection(); // The local address of the downstream connection is the original destination address, - // if usingOriginalDst() returns 'true'. - if (connection && connection->usingOriginalDst()) { + // if localAddressRestored() returns 'true'. + if (connection && connection->localAddressRestored()) { const Network::Address::Instance& dst_addr = *connection->localAddress(); // Check if a host with the destination address is already in the host set. diff --git a/source/exe/BUILD b/source/exe/BUILD index cab17ba0013ff..8747827d67409 100644 --- a/source/exe/BUILD +++ b/source/exe/BUILD @@ -45,6 +45,8 @@ envoy_cc_library( "//source/server/config/http:lua_lib", "//source/server/config/http:ratelimit_lib", "//source/server/config/http:router_lib", + "//source/server/config/listener:original_dst_lib", + "//source/server/config/listener:proxy_protocol_lib", "//source/server/config/network:client_ssl_auth_lib", "//source/server/config/network:echo_lib", "//source/server/config/network:http_connection_manager_lib", diff --git a/source/server/BUILD b/source/server/BUILD index b18b94c6a94e0..5e73eafb5eae2 100644 --- a/source/server/BUILD +++ b/source/server/BUILD @@ -63,9 +63,11 @@ envoy_cc_library( "//include/envoy/network:filter_interface", "//include/envoy/network:listen_socket_interface", "//include/envoy/network:listener_interface", + "//include/envoy/server:listener_manager_interface", "//include/envoy/stats:timespan", "//source/common/common:linked_object", "//source/common/common:non_copyable", + "//source/common/network:connection_lib", ], ) @@ -202,10 +204,13 @@ envoy_cc_library( "//include/envoy/server:listener_manager_interface", "//include/envoy/server:worker_interface", "//source/common/config:utility_lib", + "//source/common/config:well_known_names", "//source/common/network:listen_socket_lib", "//source/common/network:utility_lib", "//source/common/protobuf:utility_lib", "//source/common/ssl:context_config_lib", + "//source/server/config/listener:original_dst_lib", + "//source/server/config/listener:proxy_protocol_lib", ], ) diff --git a/source/server/config/listener/BUILD b/source/server/config/listener/BUILD new file mode 100644 index 0000000000000..2409d444272f8 --- /dev/null +++ b/source/server/config/listener/BUILD @@ -0,0 +1,31 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "original_dst_lib", + srcs = ["original_dst.cc"], + deps = [ + "//include/envoy/registry", + "//include/envoy/server:filter_config_interface", + "//source/common/config:well_known_names", + "//source/common/filter/listener:original_dst_lib", + ], +) + +envoy_cc_library( + name = "proxy_protocol_lib", + srcs = ["proxy_protocol.cc"], + deps = [ + "//include/envoy/registry", + "//include/envoy/server:filter_config_interface", + "//source/common/config:well_known_names", + "//source/common/filter/listener:proxy_protocol_lib", + ], +) diff --git a/source/server/config/listener/original_dst.cc b/source/server/config/listener/original_dst.cc new file mode 100644 index 0000000000000..33adae73253fa --- /dev/null +++ b/source/server/config/listener/original_dst.cc @@ -0,0 +1,41 @@ +#include + +#include "envoy/registry/registry.h" +#include "envoy/server/filter_config.h" + +#include "common/config/well_known_names.h" +#include "common/filter/listener/original_dst.h" + +namespace Envoy { +namespace Server { +namespace Configuration { + +/** + * Config registration for the original dst filter. @see NamedNetworkFilterConfigFactory. + */ +class OriginalDstConfigFactory : public NamedListenerFilterConfigFactory { +public: + // NamedListenerFilterConfigFactory + ListenerFilterFactoryCb createFilterFactoryFromProto(const Protobuf::Message&, + FactoryContext&) override { + return [](Network::ListenerFilterManager& filter_manager) -> void { + filter_manager.addAcceptFilter(std::make_unique()); + }; + } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return std::make_unique(); + } + + std::string name() override { return Config::ListenerFilterNames::get().ORIGINAL_DST; } +}; + +/** + * Static registration for the original dst filter. @see RegisterFactory. + */ +static Registry::RegisterFactory + registered_; + +} // namespace Configuration +} // namespace Server +} // namespace Envoy diff --git a/source/server/config/listener/proxy_protocol.cc b/source/server/config/listener/proxy_protocol.cc new file mode 100644 index 0000000000000..3c983ab9b2109 --- /dev/null +++ b/source/server/config/listener/proxy_protocol.cc @@ -0,0 +1,44 @@ +#include + +#include "envoy/registry/registry.h" +#include "envoy/server/filter_config.h" + +#include "common/config/well_known_names.h" +#include "common/filter/listener/proxy_protocol.h" + +namespace Envoy { +namespace Server { +namespace Configuration { + +/** + * Config registration for the proxy protocol filter. @see NamedNetworkFilterConfigFactory. + */ +class ProxyProtocolConfigFactory : public NamedListenerFilterConfigFactory { +public: + // NamedListenerFilterConfigFactory + ListenerFilterFactoryCb createFilterFactoryFromProto(const Protobuf::Message&, + FactoryContext& context) override { + Filter::Listener::ProxyProtocol::ConfigSharedPtr config( + new Filter::Listener::ProxyProtocol::Config(context.scope())); + return [config](Network::ListenerFilterManager& filter_manager) -> void { + filter_manager.addAcceptFilter( + std::make_unique(config)); + }; + } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return std::make_unique(); + } + + std::string name() override { return Config::ListenerFilterNames::get().PROXY_PROTOCOL; } +}; + +/** + * Static registration for the proxy protocol filter. @see RegisterFactory. + */ +static Registry::RegisterFactory + registered_; + +} // namespace Configuration +} // namespace Server +} // namespace Envoy diff --git a/source/server/config_validation/dispatcher.cc b/source/server/config_validation/dispatcher.cc index 08920e8184a54..d4c433a836f27 100644 --- a/source/server/config_validation/dispatcher.cc +++ b/source/server/config_validation/dispatcher.cc @@ -17,18 +17,8 @@ Network::DnsResolverSharedPtr ValidationDispatcher::createDnsResolver( NOT_IMPLEMENTED; } -Network::ListenerPtr ValidationDispatcher::createListener(Network::ConnectionHandler&, - Network::ListenSocket&, - Network::ListenerCallbacks&, - Stats::Scope&, - const Network::ListenerOptions&) { - NOT_IMPLEMENTED; -} - -Network::ListenerPtr -ValidationDispatcher::createSslListener(Network::ConnectionHandler&, Ssl::ServerContext&, - Network::ListenSocket&, Network::ListenerCallbacks&, - Stats::Scope&, const Network::ListenerOptions&) { +Network::ListenerPtr ValidationDispatcher::createListener(Network::ListenSocket&, + Network::ListenerCallbacks&, bool, bool) { NOT_IMPLEMENTED; } diff --git a/source/server/config_validation/dispatcher.h b/source/server/config_validation/dispatcher.h index b9639b0a9ade4..eeb647460bf23 100644 --- a/source/server/config_validation/dispatcher.h +++ b/source/server/config_validation/dispatcher.h @@ -19,12 +19,9 @@ class ValidationDispatcher : public DispatcherImpl { Network::TransportSocketPtr&&) override; Network::DnsResolverSharedPtr createDnsResolver( const std::vector& resolvers) override; - Network::ListenerPtr createListener(Network::ConnectionHandler&, Network::ListenSocket&, - Network::ListenerCallbacks&, Stats::Scope&, - const Network::ListenerOptions&) override; - Network::ListenerPtr createSslListener(Network::ConnectionHandler&, Ssl::ServerContext&, - Network::ListenSocket&, Network::ListenerCallbacks&, - Stats::Scope&, const Network::ListenerOptions&) override; + Network::ListenerPtr createListener(Network::ListenSocket&, Network::ListenerCallbacks&, + bool bind_to_port, + bool hand_off_restored_destination_connections) override; }; } // namespace Event diff --git a/source/server/config_validation/server.h b/source/server/config_validation/server.h index 7ab175aa45bc3..e09ffa24ad405 100644 --- a/source/server/config_validation/server.h +++ b/source/server/config_validation/server.h @@ -90,9 +90,14 @@ class ValidationInstance : Logger::Loggable, // Server::ListenerComponentFactory std::vector - createFilterFactoryList(const Protobuf::RepeatedPtrField& filters, - Configuration::FactoryContext& context) override { - return ProdListenerComponentFactory::createFilterFactoryList_(filters, context); + createNetworkFilterFactoryList(const Protobuf::RepeatedPtrField& filters, + Configuration::FactoryContext& context) override { + return ProdListenerComponentFactory::createNetworkFilterFactoryList_(filters, context); + } + std::vector createListenerFilterFactoryList( + const Protobuf::RepeatedPtrField& filters, + Configuration::FactoryContext& context) override { + return ProdListenerComponentFactory::createListenerFilterFactoryList_(filters, context); } Network::ListenSocketSharedPtr createListenSocket(Network::Address::InstanceConstSharedPtr, bool) override { diff --git a/source/server/configuration_impl.cc b/source/server/configuration_impl.cc index 933d845b64bb8..4d3933a777d97 100644 --- a/source/server/configuration_impl.cc +++ b/source/server/configuration_impl.cc @@ -36,6 +36,15 @@ bool FilterChainUtility::buildFilterChain(Network::FilterManager& filter_manager return filter_manager.initializeReadFilters(); } +bool FilterChainUtility::buildFilterChain(Network::ListenerFilterManager& filter_manager, + const std::vector& factories) { + for (const ListenerFilterFactoryCb& factory : factories) { + factory(filter_manager); + } + + return true; +} + void MainImpl::initialize(const envoy::api::v2::Bootstrap& bootstrap, Instance& server, Upstream::ClusterManagerFactory& cluster_manager_factory) { cluster_manager_ = cluster_manager_factory.clusterManagerFromProto( diff --git a/source/server/configuration_impl.h b/source/server/configuration_impl.h index 3dafe282a6b3e..2ba943f429754 100644 --- a/source/server/configuration_impl.h +++ b/source/server/configuration_impl.h @@ -99,6 +99,13 @@ class FilterChainUtility { */ static bool buildFilterChain(Network::FilterManager& filter_manager, const std::vector& factories); + + /** + * Given a ListenerFilterManager and a list of factories, create a new filter chain. Chain + * creation will exit early if any filters immediately close the connection. + */ + static bool buildFilterChain(Network::ListenerFilterManager& filter_manager, + const std::vector& factories); }; /** diff --git a/source/server/connection_handler_impl.cc b/source/server/connection_handler_impl.cc index e6b81ec03315e..444bf40fae606 100644 --- a/source/server/connection_handler_impl.cc +++ b/source/server/connection_handler_impl.cc @@ -5,6 +5,9 @@ #include "envoy/network/filter.h" #include "envoy/stats/timespan.h" +#include "common/network/connection_impl.h" +#include "common/network/utility.h" + namespace Envoy { namespace Server { @@ -17,23 +20,9 @@ ConnectionHandlerImpl::ConnectionHandlerImpl(spdlog::logger& logger, Event::Disp UNREFERENCED_PARAMETER(logger); } -void ConnectionHandlerImpl::addListener(Network::FilterChainFactory& factory, - Network::ListenSocket& socket, Stats::Scope& scope, - uint64_t listener_tag, - const Network::ListenerOptions& listener_options) { - ActiveListenerPtr l( - new ActiveListener(*this, socket, factory, scope, listener_tag, listener_options)); - listeners_.emplace_back(socket.localAddress(), std::move(l)); -} - -void ConnectionHandlerImpl::addSslListener(Network::FilterChainFactory& factory, - Ssl::ServerContext& ssl_ctx, - Network::ListenSocket& socket, Stats::Scope& scope, - uint64_t listener_tag, - const Network::ListenerOptions& listener_options) { - ActiveListenerPtr l(new SslActiveListener(*this, ssl_ctx, socket, factory, scope, listener_tag, - listener_options)); - listeners_.emplace_back(socket.localAddress(), std::move(l)); +void ConnectionHandlerImpl::addListener(Network::ListenerConfig& config) { + ActiveListenerPtr l(new ActiveListener(*this, config)); + listeners_.emplace_back(config.socket().localAddress(), std::move(l)); } void ConnectionHandlerImpl::removeListeners(uint64_t listener_tag) { @@ -69,22 +58,29 @@ void ConnectionHandlerImpl::ActiveListener::removeConnection(ActiveConnection& c parent_.num_connections_--; } -ConnectionHandlerImpl::ActiveListener::ActiveListener( - ConnectionHandlerImpl& parent, Network::ListenSocket& socket, - Network::FilterChainFactory& factory, Stats::Scope& scope, uint64_t listener_tag, - const Network::ListenerOptions& listener_options) +ConnectionHandlerImpl::ActiveListener::ActiveListener(ConnectionHandlerImpl& parent, + Network::ListenerConfig& config) : ActiveListener( - parent, parent.dispatcher_.createListener(parent, socket, *this, scope, listener_options), - factory, scope, listener_tag) {} + parent, + parent.dispatcher_.createListener(config.socket(), *this, config.bindToPort(), + config.handOffRestoredDestinationConnections()), + config) {} ConnectionHandlerImpl::ActiveListener::ActiveListener(ConnectionHandlerImpl& parent, Network::ListenerPtr&& listener, - Network::FilterChainFactory& factory, - Stats::Scope& scope, uint64_t listener_tag) - : parent_(parent), factory_(factory), listener_(std::move(listener)), - stats_(generateStats(scope)), listener_tag_(listener_tag) {} + Network::ListenerConfig& config) + : parent_(parent), listener_(std::move(listener)), + stats_(generateStats(config.listenerScope())), listener_tag_(config.listenerTag()), + config_(config) {} ConnectionHandlerImpl::ActiveListener::~ActiveListener() { + // Purge sockets that have not progressed to connections. This should only happen when + // a listener filter stops iteration and never resumes. + while (!sockets_.empty()) { + ActiveSocketPtr removed = sockets_.front()->removeFromList(sockets_); + parent_.dispatcher_.deferredDelete(std::move(removed)); + } + while (!connections_.empty()) { connections_.front()->connection_->close(Network::ConnectionCloseType::NoFlush); } @@ -92,21 +88,18 @@ ConnectionHandlerImpl::ActiveListener::~ActiveListener() { parent_.dispatcher_.clearDeferredDeleteList(); } -ConnectionHandlerImpl::SslActiveListener::SslActiveListener( - ConnectionHandlerImpl& parent, Ssl::ServerContext& ssl_ctx, Network::ListenSocket& socket, - Network::FilterChainFactory& factory, Stats::Scope& scope, uint64_t listener_tag, - const Network::ListenerOptions& listener_options) - : ActiveListener(parent, - parent.dispatcher_.createSslListener(parent, ssl_ctx, socket, *this, scope, - listener_options), - factory, scope, listener_tag) {} - Network::Listener* ConnectionHandlerImpl::findListenerByAddress(const Network::Address::Instance& address) { + ActiveListener* listener = findActiveListenerByAddress(address); + return listener ? listener->listener_.get() : nullptr; +} + +ConnectionHandlerImpl::ActiveListener* +ConnectionHandlerImpl::findActiveListenerByAddress(const Network::Address::Instance& address) { // This is a linear operation, may need to add a map to improve performance. // However, linear performance might be adequate since the number of listeners is small. // We do not return stopped listeners. - auto listener = std::find_if( + auto listener_it = std::find_if( listeners_.begin(), listeners_.end(), [&address](const std::pair& p) { return p.second->listener_ != nullptr && p.first->type() == Network::Address::Type::Ip && @@ -114,26 +107,93 @@ ConnectionHandlerImpl::findListenerByAddress(const Network::Address::Instance& a }); // If there is exact address match, return the corresponding listener. - if (listener != listeners_.end()) { - return listener->second->listener_.get(); + if (listener_it != listeners_.end()) { + return listener_it->second.get(); } // Otherwise, we need to look for the wild card match, i.e., 0.0.0.0:[address_port]. // We do not return stopped listeners. // TODO(wattli): consolidate with previous search for more efficiency. - listener = std::find_if( + listener_it = std::find_if( listeners_.begin(), listeners_.end(), [&address](const std::pair& p) { return p.second->listener_ != nullptr && p.first->type() == Network::Address::Type::Ip && p.first->ip()->port() == address.ip()->port() && p.first->ip()->isAnyAddress(); }); - return (listener != listeners_.end()) ? listener->second->listener_.get() : nullptr; + return (listener_it != listeners_.end()) ? listener_it->second.get() : nullptr; +} + +void ConnectionHandlerImpl::ActiveSocket::continueFilterChain(bool success) { + if (success) { + if (iter_ == accept_filters_.end()) { + iter_ = accept_filters_.begin(); + } else { + iter_ = std::next(iter_); + } + + for (; iter_ != accept_filters_.end(); iter_++) { + Network::FilterStatus status = (*iter_)->onAccept(*this); + if (status == Network::FilterStatus::StopIteration) { + // The filter is responsible for calling us again at a later time to continue the filter + // chain from the next filter. + return; + } + } + // Successfully ran all the accept filters. + + // Check if the socket may need to be redirected to another listener. + ActiveListener* new_listener = nullptr; + + if (hand_off_restored_destination_connections_ && socket_->localAddressRestored()) { + // Find a listener associated with the original destination address. + new_listener = listener_.parent_.findActiveListenerByAddress(*socket_->localAddress()); + } + if (new_listener != nullptr) { + // Hands off connections redirected by iptables to the listener associated with the + // original destination address. Pass 'hand_off_restored_destionations' as false to + // prevent further redirection. + new_listener->onAccept(std::move(socket_), false); + } else { + // Create a new connection on this listener. + listener_.newConnection(std::move(socket_)); + } + } + + // Filter execution concluded, unlink and delete this ActiveSocket if it was linked. + if (inserted()) { + ActiveSocketPtr removed = removeFromList(listener_.sockets_); + listener_.parent_.dispatcher_.deferredDelete(std::move(removed)); + } +} + +void ConnectionHandlerImpl::ActiveListener::onAccept( + Network::ConnectionSocketPtr&& socket, bool hand_off_restored_destination_connections) { + Network::Address::InstanceConstSharedPtr local_address = socket->localAddress(); + auto active_socket = std::make_unique(*this, std::move(socket), + hand_off_restored_destination_connections); + + // Create and run the filters + config_.filterChainFactory().createListenerFilterChain(*active_socket); + active_socket->continueFilterChain(true); + + // Move active_socket to the sockets_ list if filter iteration needs to continue later. + // Otherwise we let active_socket be destructed when it goes out of scope. + if (active_socket->iter_ != active_socket->accept_filters_.end()) { + active_socket->moveIntoListBack(std::move(active_socket), sockets_); + } +} + +void ConnectionHandlerImpl::ActiveListener::newConnection(Network::ConnectionSocketPtr&& socket) { + Network::ConnectionPtr new_connection = + parent_.dispatcher_.createServerConnection(std::move(socket), config_.defaultSslContext()); + new_connection->setBufferLimits(config_.perConnectionBufferLimitBytes()); + onNewConnection(std::move(new_connection)); } void ConnectionHandlerImpl::ActiveListener::onNewConnection( Network::ConnectionPtr&& new_connection) { ENVOY_CONN_LOG_TO_LOGGER(parent_.logger_, debug, "new connection", *new_connection); - bool empty_filter_chain = !factory_.createFilterChain(*new_connection); + bool empty_filter_chain = !config_.filterChainFactory().createNetworkFilterChain(*new_connection); // If the connection is already closed, we can just let this connection immediately die. if (new_connection->state() != Network::Connection::State::Closed) { diff --git a/source/server/connection_handler_impl.h b/source/server/connection_handler_impl.h index 8a008a4a68b57..86be964968ecb 100644 --- a/source/server/connection_handler_impl.h +++ b/source/server/connection_handler_impl.h @@ -12,6 +12,7 @@ #include "envoy/network/filter.h" #include "envoy/network/listen_socket.h" #include "envoy/network/listener.h" +#include "envoy/server/listener_manager.h" #include "envoy/stats/timespan.h" #include "common/common/linked_object.h" @@ -47,39 +48,36 @@ class ConnectionHandlerImpl : public Network::ConnectionHandler, NonCopyable { // Network::ConnectionHandler uint64_t numConnections() override { return num_connections_; } - void addListener(Network::FilterChainFactory& factory, Network::ListenSocket& socket, - Stats::Scope& scope, uint64_t listener_tag, - const Network::ListenerOptions& listener_options) override; - void addSslListener(Network::FilterChainFactory& factory, Ssl::ServerContext& ssl_ctx, - Network::ListenSocket& socket, Stats::Scope& scope, uint64_t listener_tag, - const Network::ListenerOptions& listener_options) override; - Network::Listener* findListenerByAddress(const Network::Address::Instance& address) override; + void addListener(Network::ListenerConfig& config) override; void removeListeners(uint64_t listener_tag) override; void stopListeners(uint64_t listener_tag) override; void stopListeners() override; + Network::Listener* findListenerByAddress(const Network::Address::Instance& address) override; + private: + struct ActiveListener; + ActiveListener* findActiveListenerByAddress(const Network::Address::Instance& address); + struct ActiveConnection; typedef std::unique_ptr ActiveConnectionPtr; + struct ActiveSocket; + typedef std::unique_ptr ActiveSocketPtr; /** * Wrapper for an active listener owned by this handler. */ struct ActiveListener : public Network::ListenerCallbacks { - ActiveListener(ConnectionHandlerImpl& parent, Network::ListenSocket& socket, - Network::FilterChainFactory& factory, Stats::Scope& scope, uint64_t listener_tag, - const Network::ListenerOptions& listener_options); + ActiveListener(ConnectionHandlerImpl& parent, Network::ListenerConfig& config); ActiveListener(ConnectionHandlerImpl& parent, Network::ListenerPtr&& listener, - Network::FilterChainFactory& factory, Stats::Scope& scope, - uint64_t listener_tag); + Network::ListenerConfig& config); ~ActiveListener(); - /** - * Fires when a new connection is received from the listener. - * @param new_connection supplies the connection to take control of. - */ + // Network::ListenerCallbacks + void onAccept(Network::ConnectionSocketPtr&& socket, + bool hand_off_restored_destination_connections) override; void onNewConnection(Network::ConnectionPtr&& new_connection) override; /** @@ -88,19 +86,18 @@ class ConnectionHandlerImpl : public Network::ConnectionHandler, NonCopyable { */ void removeConnection(ActiveConnection& connection); + /** + * Create a new connection from a socket accepted by the listener. + */ + void newConnection(Network::ConnectionSocketPtr&& socket); + ConnectionHandlerImpl& parent_; - Network::FilterChainFactory& factory_; Network::ListenerPtr listener_; ListenerStats stats_; + std::list sockets_; std::list connections_; const uint64_t listener_tag_; - }; - - struct SslActiveListener : public ActiveListener { - SslActiveListener(ConnectionHandlerImpl& parent, Ssl::ServerContext& ssl_ctx, - Network::ListenSocket& socket, Network::FilterChainFactory& factory, - Stats::Scope& scope, uint64_t listener_tag, - const Network::ListenerOptions& listener_options); + Network::ListenerConfig& config_; }; typedef std::unique_ptr ActiveListenerPtr; @@ -130,6 +127,37 @@ class ConnectionHandlerImpl : public Network::ConnectionHandler, NonCopyable { Stats::TimespanPtr conn_length_; }; + /** + * Wrapper for an active accepted socket owned by this handler. + */ + struct ActiveSocket : public Network::ListenerFilterManager, + public Network::ListenerFilterCallbacks, + LinkedObject, + public Event::DeferredDeletable { + ActiveSocket(ActiveListener& listener, Network::ConnectionSocketPtr&& socket, + bool hand_off_restored_destination_connections) + : listener_(listener), socket_(std::move(socket)), + hand_off_restored_destination_connections_(hand_off_restored_destination_connections), + iter_(accept_filters_.end()) {} + ~ActiveSocket() { accept_filters_.clear(); } + + // Network::ListenerFilterManager + void addAcceptFilter(Network::ListenerFilterPtr&& filter) override { + accept_filters_.emplace_back(std::move(filter)); + } + + // Network::ListenerFilterCallbacks + Network::ConnectionSocket& socket() override { return *socket_.get(); } + Event::Dispatcher& dispatcher() override { return listener_.parent_.dispatcher_; } + void continueFilterChain(bool success) override; + + ActiveListener& listener_; + Network::ConnectionSocketPtr socket_; + const bool hand_off_restored_destination_connections_; + std::list accept_filters_; + std::list::iterator iter_; + }; + static ListenerStats generateStats(Stats::Scope& scope); #ifndef NVLOG diff --git a/source/server/http/BUILD b/source/server/http/BUILD index 633778e4a9f17..cf9e834745c3d 100644 --- a/source/server/http/BUILD +++ b/source/server/http/BUILD @@ -20,6 +20,7 @@ envoy_cc_library( "//include/envoy/server:admin_interface", "//include/envoy/server:hot_restart_interface", "//include/envoy/server:instance_interface", + "//include/envoy/server:listener_manager_interface", "//include/envoy/server:options_interface", "//include/envoy/stats:stats_interface", "//include/envoy/upstream:cluster_manager_interface", diff --git a/source/server/http/admin.cc b/source/server/http/admin.cc index 3a9a70f544a67..08c732b9ee09d 100644 --- a/source/server/http/admin.cc +++ b/source/server/http/admin.cc @@ -609,7 +609,7 @@ AdminImpl::NullRouteConfigProvider::NullRouteConfigProvider() AdminImpl::AdminImpl(const std::string& access_log_path, const std::string& profile_path, const std::string& address_out_path, Network::Address::InstanceConstSharedPtr address, Server::Instance& server, - Stats::Scope& listener_scope) + Stats::ScopePtr&& listener_scope) : server_(server), profile_path_(profile_path), socket_(new Network::TcpListenSocket(address, true)), stats_(Http::ConnectionManagerImpl::generateStats("http.admin.", server_.stats())), @@ -642,8 +642,7 @@ AdminImpl::AdminImpl(const std::string& access_log_path, const std::string& prof {"/listeners", "print listener addresses", MAKE_ADMIN_HANDLER(handlerListenerInfo), false, false}, {"/runtime", "print runtime values", MAKE_ADMIN_HANDLER(handlerRuntime), false, false}}, - listener_stats_( - Http::ConnectionManagerImpl::generateListenerStats("http.admin.", listener_scope)) { + listener_(*this, std::move(listener_scope)) { if (!address_out_path.empty()) { std::ofstream address_out_file(address_out_path); @@ -667,7 +666,7 @@ Http::ServerConnectionPtr AdminImpl::createCodec(Network::Connection& connection new Http::Http1::ServerConnectionImpl(connection, callbacks, Http::Http1Settings())}; } -bool AdminImpl::createFilterChain(Network::Connection& connection) { +bool AdminImpl::createNetworkFilterChain(Network::Connection& connection) { connection.addReadFilter(Network::ReadFilterSharedPtr{new Http::ConnectionManagerImpl( *this, server_.drainManager(), server_.random(), server_.httpTracer(), server_.runtime(), server_.localInfo(), server_.clusterManager())}); diff --git a/source/server/http/admin.h b/source/server/http/admin.h index 289f73950d5f5..c329993faae0b 100644 --- a/source/server/http/admin.h +++ b/source/server/http/admin.h @@ -12,6 +12,7 @@ #include "envoy/runtime/runtime.h" #include "envoy/server/admin.h" #include "envoy/server/instance.h" +#include "envoy/server/listener_manager.h" #include "envoy/upstream/outlier_detection.h" #include "envoy/upstream/resource_manager.h" @@ -37,12 +38,13 @@ class AdminImpl : public Admin, public: AdminImpl(const std::string& access_log_path, const std::string& profiler_path, const std::string& address_out_path, Network::Address::InstanceConstSharedPtr address, - Server::Instance& server, Stats::Scope& listener_scope); + Server::Instance& server, Stats::ScopePtr&& listener_scope); Http::Code runCallback(const std::string& path_and_query, Http::HeaderMap& response_headers, Buffer::Instance& response); const Network::ListenSocket& socket() override { return *socket_; } Network::ListenSocket& mutable_socket() { return *socket_; } + Network::ListenerConfig& listener() { return listener_; } // Server::Admin bool addHandler(const std::string& prefix, const std::string& help_text, HandlerCb callback, @@ -50,7 +52,8 @@ class AdminImpl : public Admin, bool removeHandler(const std::string& prefix) override; // Network::FilterChainFactory - bool createFilterChain(Network::Connection& connection) override; + bool createNetworkFilterChain(Network::Connection& connection) override; + bool createListenerFilterChain(Network::ListenerFilterManager&) override { return true; } // Http::FilterChainFactory void createFilterChain(Http::FilterChainFactoryCallbacks& callbacks) override; @@ -81,7 +84,7 @@ class AdminImpl : public Admin, const Network::Address::Instance& localAddress() override; const Optional& userAgent() override { return user_agent_; } const Http::TracingConnectionManagerConfig* tracingConfig() override { return nullptr; } - Http::ConnectionManagerListenerStats& listenerStats() override { return listener_stats_; } + Http::ConnectionManagerListenerStats& listenerStats() override { return listener_.stats_; } private: /** @@ -162,6 +165,29 @@ class AdminImpl : public Admin, Http::Code handlerRuntime(const std::string& path_and_query, Http::HeaderMap& response_headers, Buffer::Instance& response); + class AdminListener : public Network::ListenerConfig { + public: + AdminListener(AdminImpl& parent, Stats::ScopePtr&& listener_scope) + : parent_(parent), name_("admin"), scope_(std::move(listener_scope)), + stats_(Http::ConnectionManagerImpl::generateListenerStats("http.admin.", *scope_)) {} + + // Network::ListenerConfig + Network::FilterChainFactory& filterChainFactory() override { return parent_; } + Network::ListenSocket& socket() override { return parent_.mutable_socket(); } + Ssl::ServerContext* defaultSslContext() override { return nullptr; } + bool bindToPort() override { return true; } + bool handOffRestoredDestinationConnections() const override { return false; } + uint32_t perConnectionBufferLimitBytes() override { return 0; } + Stats::Scope& listenerScope() override { return *scope_; } + uint64_t listenerTag() const override { return 0; } + const std::string& name() const override { return name_; } + + AdminImpl& parent_; + const std::string name_; + Stats::ScopePtr scope_; + Http::ConnectionManagerListenerStats stats_; + }; + Server::Instance& server_; std::list access_logs_; const std::string profile_path_; @@ -174,7 +200,7 @@ class AdminImpl : public Admin, Optional user_agent_; Http::SlowDateProviderImpl date_provider_; std::vector set_current_client_cert_details_; - Http::ConnectionManagerListenerStats listener_stats_; + AdminListener listener_; }; /** diff --git a/source/server/listener_manager_impl.cc b/source/server/listener_manager_impl.cc index 25e06dda855dd..736681d04fd53 100644 --- a/source/server/listener_manager_impl.cc +++ b/source/server/listener_manager_impl.cc @@ -4,6 +4,7 @@ #include "common/common/assert.h" #include "common/config/utility.h" +#include "common/config/well_known_names.h" #include "common/network/listen_socket_impl.h" #include "common/network/utility.h" #include "common/protobuf/utility.h" @@ -18,7 +19,7 @@ namespace Envoy { namespace Server { std::vector -ProdListenerComponentFactory::createFilterFactoryList_( +ProdListenerComponentFactory::createNetworkFilterFactoryList_( const Protobuf::RepeatedPtrField& filters, Configuration::FactoryContext& context) { std::vector ret; @@ -48,6 +49,30 @@ ProdListenerComponentFactory::createFilterFactoryList_( return ret; } +std::vector +ProdListenerComponentFactory::createListenerFilterFactoryList_( + const Protobuf::RepeatedPtrField& filters, + Configuration::FactoryContext& context) { + std::vector ret; + for (ssize_t i = 0; i < filters.size(); i++) { + const auto& proto_config = filters[i]; + const ProtobufTypes::String string_name = proto_config.name(); + ENVOY_LOG(debug, " filter #{}:", i); + ENVOY_LOG(debug, " name: {}", string_name); + const Json::ObjectSharedPtr filter_config = + MessageUtil::getJsonObjectFromMessage(proto_config.config()); + ENVOY_LOG(debug, " config: {}", filter_config->asJsonString()); + + // Now see if there is a factory that will accept the config. + auto& factory = + Config::Utility::getAndCheckFactory( + string_name); + auto message = Config::Utility::translateToFactoryConfig(proto_config, factory); + ret.push_back(factory.createFilterFactoryFromProto(*message, context)); + } + return ret; +} + Network::ListenSocketSharedPtr ProdListenerComponentFactory::createListenSocket(Network::Address::InstanceConstSharedPtr address, bool bind_to_port) { @@ -84,9 +109,8 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, ListenerManag listener_scope_( parent_.server_.stats().createScope(fmt::format("listener.{}.", address_->asString()))), bind_to_port_(PROTOBUF_GET_WRAPPED_OR_DEFAULT(config.deprecated_v1(), bind_to_port, true)), - use_proxy_proto_( - PROTOBUF_GET_WRAPPED_OR_DEFAULT(config.filter_chains()[0], use_proxy_proto, false)), - use_original_dst_(PROTOBUF_GET_WRAPPED_OR_DEFAULT(config, use_original_dst, false)), + hand_off_restored_destination_connections_( + PROTOBUF_GET_WRAPPED_OR_DEFAULT(config, use_original_dst, false)), per_connection_buffer_limit_bytes_( PROTOBUF_GET_WRAPPED_OR_DEFAULT(config, per_connection_buffer_limit_bytes, 1024 * 1024)), listener_tag_(parent_.factory_.nextListenerTag()), name_(name), modifiable_(modifiable), @@ -96,6 +120,30 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, ListenerManag // filter chain #1308. ASSERT(config.filter_chains().size() >= 1); + if (!config.listener_filters().empty()) { + listener_filter_factories_ = + parent_.factory_.createListenerFilterFactoryList(config.listener_filters(), *this); + } + // Add original dst listener filter if 'use_original_dst' flag is set. + if (PROTOBUF_GET_WRAPPED_OR_DEFAULT(config, use_original_dst, false)) { + auto& factory = + Config::Utility::getAndCheckFactory( + Config::ListenerFilterNames::get().ORIGINAL_DST); + listener_filter_factories_.push_back( + factory.createFilterFactoryFromProto(Envoy::ProtobufWkt::Empty(), *this)); + } + // Add proxy protocol listener filter if 'use_proxy_proto' flag is set. + // TODO(jrajahalme): This is the last listener filter on purpose. When filter chain matching + // is implemented, this needs to be run after the filter chain has been + // selected. + if (PROTOBUF_GET_WRAPPED_OR_DEFAULT(config.filter_chains()[0], use_proxy_proto, false)) { + auto& factory = + Config::Utility::getAndCheckFactory( + Config::ListenerFilterNames::get().PROXY_PROTOCOL); + listener_filter_factories_.push_back( + factory.createFilterFactoryFromProto(Envoy::ProtobufWkt::Empty(), *this)); + } + // Skip lookup and update of the SSL Context if there is only one filter chain // and it doesn't enforce any SNI restrictions. const bool skip_context_update = @@ -110,7 +158,8 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, ListenerManag filter_chain.filter_chain_match().sni_domains().end()); if (!filters_hash.valid()) { filters_hash.value(RepeatedPtrUtil::hash(filter_chain.filters())); - filter_factories_ = parent_.factory_.createFilterFactoryList(filter_chain.filters(), *this); + filter_factories_ = + parent_.factory_.createNetworkFilterFactoryList(filter_chain.filters(), *this); } else if (filters_hash.value() != RepeatedPtrUtil::hash(filter_chain.filters())) { throw EnvoyException(fmt::format("error adding listener '{}': use of different filter chains " "is currently not supported", @@ -149,10 +198,14 @@ ListenerImpl::~ListenerImpl() { filter_factories_.clear(); } -bool ListenerImpl::createFilterChain(Network::Connection& connection) { +bool ListenerImpl::createNetworkFilterChain(Network::Connection& connection) { return Configuration::FilterChainUtility::buildFilterChain(connection, filter_factories_); } +bool ListenerImpl::createListenerFilterChain(Network::ListenerFilterManager& manager) { + return Configuration::FilterChainUtility::buildFilterChain(manager, listener_filter_factories_); +} + bool ListenerImpl::drainClose() const { // When a listener is draining, the "drain close" decision is the union of the per-listener drain // manager and the server wide drain manager. This allows individual listeners to be drained and diff --git a/source/server/listener_manager_impl.h b/source/server/listener_manager_impl.h index f54028f88d9c1..e3e5c78e7d732 100644 --- a/source/server/listener_manager_impl.h +++ b/source/server/listener_manager_impl.h @@ -25,18 +25,30 @@ class ProdListenerComponentFactory : public ListenerComponentFactory, ProdListenerComponentFactory(Instance& server) : server_(server) {} /** - * Static worker for createFilterFactoryList() that can be used directly in tests. + * Static worker for createNetworkFilterFactoryList() that can be used directly in tests. */ static std::vector - createFilterFactoryList_(const Protobuf::RepeatedPtrField& filters, - Configuration::FactoryContext& context); + createNetworkFilterFactoryList_(const Protobuf::RepeatedPtrField& filters, + Configuration::FactoryContext& context); + /** + * Static worker for createListenerFilterFactoryList() that can be used directly in tests. + */ + static std::vector createListenerFilterFactoryList_( + const Protobuf::RepeatedPtrField& filters, + Configuration::FactoryContext& context); - // Server::ListenSocketFactory + // Server::ListenerComponentFactory std::vector - createFilterFactoryList(const Protobuf::RepeatedPtrField& filters, - Configuration::FactoryContext& context) override { - return createFilterFactoryList_(filters, context); + createNetworkFilterFactoryList(const Protobuf::RepeatedPtrField& filters, + Configuration::FactoryContext& context) override { + return createNetworkFilterFactoryList_(filters, context); + } + std::vector createListenerFilterFactoryList( + const Protobuf::RepeatedPtrField& filters, + Configuration::FactoryContext& context) override { + return createListenerFilterFactoryList_(filters, context); } + Network::ListenSocketSharedPtr createListenSocket(Network::Address::InstanceConstSharedPtr address, bool bind_to_port) override; DrainManagerPtr createDrainManager(envoy::api::v2::Listener::DrainType drain_type) override; @@ -200,14 +212,15 @@ class ListenerImpl : public Network::ListenerConfig, Network::FilterChainFactory& filterChainFactory() override { return *this; } Network::ListenSocket& socket() override { return *socket_; } bool bindToPort() override { return bind_to_port_; } + bool handOffRestoredDestinationConnections() const override { + return hand_off_restored_destination_connections_; + } Ssl::ServerContext* defaultSslContext() override { return tls_contexts_.empty() ? nullptr : tls_contexts_[0].get(); } - bool useProxyProto() override { return use_proxy_proto_; } - bool useOriginalDst() override { return use_original_dst_; } uint32_t perConnectionBufferLimitBytes() override { return per_connection_buffer_limit_bytes_; } Stats::Scope& listenerScope() override { return *listener_scope_; } - uint64_t listenerTag() override { return listener_tag_; } + uint64_t listenerTag() const override { return listener_tag_; } const std::string& name() const override { return name_; } // Server::Configuration::FactoryContext @@ -236,7 +249,8 @@ class ListenerImpl : public Network::ListenerConfig, bool drainClose() const override; // Network::FilterChainFactory - bool createFilterChain(Network::Connection& connection) override; + bool createNetworkFilterChain(Network::Connection& connection) override; + bool createListenerFilterChain(Network::ListenerFilterManager& manager) override; private: ListenerManagerImpl& parent_; @@ -246,8 +260,7 @@ class ListenerImpl : public Network::ListenerConfig, Stats::ScopePtr listener_scope_; // Stats with listener named scope. std::vector tls_contexts_; const bool bind_to_port_; - const bool use_proxy_proto_; - const bool use_original_dst_; + const bool hand_off_restored_destination_connections_; const uint32_t per_connection_buffer_limit_bytes_; const uint64_t listener_tag_; const std::string name_; @@ -257,6 +270,7 @@ class ListenerImpl : public Network::ListenerConfig, InitManagerImpl dynamic_init_manager_; bool initialize_canceled_{}; std::vector filter_factories_; + std::vector listener_filter_factories_; DrainManagerPtr local_drain_manager_; bool saw_listener_create_failure_{}; }; diff --git a/source/server/server.cc b/source/server/server.cc index d6c616b08c065..448a9c23cbcd6 100644 --- a/source/server/server.cc +++ b/source/server/server.cc @@ -216,13 +216,11 @@ void InstanceImpl::initialize(Options& options, info.original_start_time_ = original_start_time_; restarter_.shutdownParentAdmin(info); original_start_time_ = info.original_start_time_; - admin_scope_ = stats_store_.createScope("listener.admin."); admin_.reset(new AdminImpl(initial_config.admin().accessLogPath(), initial_config.admin().profilePath(), options.adminAddressPath(), - initial_config.admin().address(), *this, *admin_scope_)); - - handler_->addListener(*admin_, admin_->mutable_socket(), *admin_scope_, 0, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + initial_config.admin().address(), *this, + stats_store_.createScope("listener.admin."))); + handler_->addListener(admin_->listener()); loadServerFlags(initial_config.flagsPath()); diff --git a/source/server/server.h b/source/server/server.h index 3357999c91bbb..02c8840d61d06 100644 --- a/source/server/server.h +++ b/source/server/server.h @@ -191,7 +191,6 @@ class InstanceImpl : Logger::Loggable, public Instance { ProdWorkerFactory worker_factory_; std::unique_ptr listener_manager_; std::unique_ptr config_; - Stats::ScopePtr admin_scope_; Network::DnsResolverSharedPtr dns_resolver_; Event::TimerPtr stat_flush_timer_; LocalInfo::LocalInfoPtr local_info_; diff --git a/source/server/worker_impl.cc b/source/server/worker_impl.cc index cfbba89688fb6..6c18db816944d 100644 --- a/source/server/worker_impl.cc +++ b/source/server/worker_impl.cc @@ -34,7 +34,8 @@ void WorkerImpl::addListener(Network::ListenerConfig& listener, AddListenerCompl // to surface this. dispatcher_->post([this, &listener, completion]() -> void { try { - addListenerWorker(listener); + handler_->addListener(listener); + hooks_.onWorkerListenerAdded(); completion(true); } catch (const Network::CreateListenerException& e) { completion(false); @@ -42,24 +43,6 @@ void WorkerImpl::addListener(Network::ListenerConfig& listener, AddListenerCompl }); } -void WorkerImpl::addListenerWorker(Network::ListenerConfig& listener) { - const Network::ListenerOptions listener_options = {.bind_to_port_ = listener.bindToPort(), - .use_proxy_proto_ = listener.useProxyProto(), - .use_original_dst_ = listener.useOriginalDst(), - .per_connection_buffer_limit_bytes_ = - listener.perConnectionBufferLimitBytes()}; - if (listener.defaultSslContext()) { - handler_->addSslListener(listener.filterChainFactory(), *listener.defaultSslContext(), - listener.socket(), listener.listenerScope(), listener.listenerTag(), - listener_options); - } else { - handler_->addListener(listener.filterChainFactory(), listener.socket(), - listener.listenerScope(), listener.listenerTag(), listener_options); - } - - hooks_.onWorkerListenerAdded(); -} - uint64_t WorkerImpl::numConnections() { uint64_t ret = 0; if (handler_) { diff --git a/source/server/worker_impl.h b/source/server/worker_impl.h index 79da2ae722ace..ea018c3b73d59 100644 --- a/source/server/worker_impl.h +++ b/source/server/worker_impl.h @@ -50,7 +50,6 @@ class WorkerImpl : public Worker, Logger::Loggable { void stopListeners() override; private: - void addListenerWorker(Network::ListenerConfig& listener); void threadRoutine(GuardDog& guard_dog); ThreadLocal::Instance& tls_; diff --git a/test/common/http/codec_client_test.cc b/test/common/http/codec_client_test.cc index a774420349ad9..20014cb3dfa6f 100644 --- a/test/common/http/codec_client_test.cc +++ b/test/common/http/codec_client_test.cc @@ -179,9 +179,7 @@ class CodecNetworkTest : public testing::TestWithParamcreateListener(connection_handler_, socket_, listener_callbacks_, stats_store_, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + upstream_listener_ = dispatcher_->createListener(socket_, listener_callbacks_, true, false); Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection( socket_.localAddress(), source_address_, Network::Test::createRawBufferSocket()); client_connection_ = client_connection.get(); @@ -190,6 +188,13 @@ class CodecNetworkTest : public testing::TestWithParam void { + Network::ConnectionPtr new_connection = + dispatcher_->createServerConnection(std::move(socket), nullptr); + listener_callbacks_.onNewConnection(std::move(new_connection)); + })); + int expected_callbacks = 2; EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) diff --git a/test/common/network/BUILD b/test/common/network/BUILD index ca42e6c7149ab..bd69afa4c05d5 100644 --- a/test/common/network/BUILD +++ b/test/common/network/BUILD @@ -130,9 +130,11 @@ envoy_cc_test( "//source/common/buffer:buffer_lib", "//source/common/event:dispatcher_includes", "//source/common/event:dispatcher_lib", + "//source/common/filter/listener:proxy_protocol_lib", "//source/common/network:listener_lib", "//source/common/network:utility_lib", "//source/common/stats:stats_lib", + "//source/server:connection_handler_lib", "//test/mocks/buffer:buffer_mocks", "//test/mocks/network:network_mocks", "//test/mocks/server:server_mocks", diff --git a/test/common/network/connection_impl_test.cc b/test/common/network/connection_impl_test.cc index f168e32bb0129..ae869ca079ace 100644 --- a/test/common/network/connection_impl_test.cc +++ b/test/common/network/connection_impl_test.cc @@ -70,11 +70,9 @@ INSTANTIATE_TEST_CASE_P(IpVersions, ConnectionImplDeathTest, TEST_P(ConnectionImplDeathTest, BadFd) { Event::DispatcherImpl dispatcher; - EXPECT_DEATH(ConnectionImpl(dispatcher, -1, - Network::Test::getCanonicalLoopbackAddress(GetParam()), - Network::Test::getCanonicalLoopbackAddress(GetParam()), - Network::Address::InstanceConstSharedPtr(), false, false), - ".*assert failure: fd_ != -1.*"); + EXPECT_DEATH(ConnectionImpl(dispatcher, + std::make_unique(-1, nullptr, nullptr), false), + ".*assert failure: fd\\(\\) != -1.*"); } class ConnectionImplTest : public testing::TestWithParam { @@ -83,9 +81,7 @@ class ConnectionImplTest : public testing::TestWithParam { if (dispatcher_.get() == nullptr) { dispatcher_.reset(new Event::DispatcherImpl); } - listener_ = - dispatcher_->createListener(connection_handler_, socket_, listener_callbacks_, stats_store_, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + listener_ = dispatcher_->createListener(socket_, listener_callbacks_, true, false); client_connection_ = dispatcher_->createClientConnection( socket_.localAddress(), source_address_, Network::Test::createRawBufferSocket()); @@ -93,13 +89,19 @@ class ConnectionImplTest : public testing::TestWithParam { EXPECT_EQ(nullptr, client_connection_->ssl()); const Network::ClientConnection& const_connection = *client_connection_; EXPECT_EQ(nullptr, const_connection.ssl()); - EXPECT_FALSE(client_connection_->usingOriginalDst()); + EXPECT_FALSE(client_connection_->localAddressRestored()); } void connect() { int expected_callbacks = 2; client_connection_->connect(); read_filter_.reset(new NiceMock()); + EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher_->createServerConnection(std::move(socket), nullptr); + listener_callbacks_.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection_ = std::move(conn); @@ -188,6 +190,13 @@ TEST_P(ConnectionImplTest, CloseDuringConnectCallback) { EXPECT_CALL(client_callbacks_, onEvent(ConnectionEvent::LocalClose)); read_filter_.reset(new NiceMock()); + + EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher_->createServerConnection(std::move(socket), nullptr); + listener_callbacks_.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection_ = std::move(conn); @@ -236,6 +245,12 @@ TEST_P(ConnectionImplTest, ConnectionStats) { read_filter_.reset(new NiceMock()); MockConnectionStats server_connection_stats; + EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher_->createServerConnection(std::move(socket), nullptr); + listener_callbacks_.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection_ = std::move(conn); @@ -561,9 +576,7 @@ TEST_P(ConnectionImplTest, BindFailureTest) { new Network::Address::Ipv6Instance(address_string, 0)}; } dispatcher_.reset(new Event::DispatcherImpl); - listener_ = - dispatcher_->createListener(connection_handler_, socket_, listener_callbacks_, stats_store_, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + listener_ = dispatcher_->createListener(socket_, listener_callbacks_, true, false); client_connection_ = dispatcher_->createClientConnection(socket_.localAddress(), source_address_, Network::Test::createRawBufferSocket()); @@ -639,11 +652,9 @@ class ConnectionImplBytesSentTest : public testing::Test { EXPECT_CALL(dispatcher_, createFileEvent_(0, _, _, _)) .WillOnce(DoAll(SaveArg<1>(&file_ready_cb_), Return(new Event::MockFileEvent))); transport_socket_ = new NiceMock; - connection_.reset(new ConnectionImpl( - dispatcher_, 0, Network::Test::getCanonicalLoopbackAddress(Address::IpVersion::v4), - Network::Test::getCanonicalLoopbackAddress(Address::IpVersion::v4), - Network::Address::InstanceConstSharedPtr(), TransportSocketPtr(transport_socket_), false, - true)); + connection_.reset( + new ConnectionImpl(dispatcher_, std::make_unique(0, nullptr, nullptr), + TransportSocketPtr(transport_socket_), true)); connection_->addConnectionCallbacks(callbacks_); } @@ -740,12 +751,7 @@ class ReadBufferLimitTest : public ConnectionImplTest { void readBufferLimitTest(uint32_t read_buffer_limit, uint32_t expected_chunk_size) { const uint32_t buffer_size = 256 * 1024; dispatcher_.reset(new Event::DispatcherImpl); - listener_ = - dispatcher_->createListener(connection_handler_, socket_, listener_callbacks_, stats_store_, - {.bind_to_port_ = true, - .use_proxy_proto_ = false, - .use_original_dst_ = false, - .per_connection_buffer_limit_bytes_ = read_buffer_limit}); + listener_ = dispatcher_->createListener(socket_, listener_callbacks_, true, false); client_connection_ = dispatcher_->createClientConnection( socket_.localAddress(), Network::Address::InstanceConstSharedPtr(), @@ -754,6 +760,13 @@ class ReadBufferLimitTest : public ConnectionImplTest { client_connection_->connect(); read_filter_.reset(new NiceMock()); + EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher_->createServerConnection(std::move(socket), nullptr); + new_connection->setBufferLimits(read_buffer_limit); + listener_callbacks_.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection_ = std::move(conn); diff --git a/test/common/network/dns_impl_test.cc b/test/common/network/dns_impl_test.cc index 8f775eef0fe77..28accb4de76bc 100644 --- a/test/common/network/dns_impl_test.cc +++ b/test/common/network/dns_impl_test.cc @@ -198,6 +198,14 @@ class TestDnsServerQuery { class TestDnsServer : public ListenerCallbacks { public: + TestDnsServer(Event::DispatcherImpl& dispatcher) : dispatcher_(dispatcher) {} + + void onAccept(ConnectionSocketPtr&& socket, bool) override { + Network::ConnectionPtr new_connection = + dispatcher_.createServerConnection(std::move(socket), nullptr); + onNewConnection(std::move(new_connection)); + } + void onNewConnection(ConnectionPtr&& new_connection) override { TestDnsServerQuery* query = new TestDnsServerQuery(std::move(new_connection), hosts_A_, hosts_AAAA_); @@ -213,6 +221,8 @@ class TestDnsServer : public ListenerCallbacks { } private: + Event::DispatcherImpl& dispatcher_; + HostMap hosts_A_; HostMap hosts_AAAA_; // All queries are tracked so we can do resource reclamation when the test is @@ -276,14 +286,10 @@ class DnsImplTest : public testing::TestWithParam { resolver_ = dispatcher_.createDnsResolver({}); // Instantiate TestDnsServer and listen on a random port on the loopback address. - server_.reset(new TestDnsServer()); + server_.reset(new TestDnsServer(dispatcher_)); socket_.reset( new Network::TcpListenSocket(Network::Test::getCanonicalLoopbackAddress(GetParam()), true)); - listener_ = dispatcher_.createListener(connection_handler_, *socket_, *server_, stats_store_, - {.bind_to_port_ = true, - .use_proxy_proto_ = false, - .use_original_dst_ = false, - .per_connection_buffer_limit_bytes_ = 0}); + listener_ = dispatcher_.createListener(*socket_, *server_, true, false); // Point c-ares at the listener with no search domains and TCP-only. peer_.reset(new DnsResolverImplPeer(dynamic_cast(resolver_.get()))); diff --git a/test/common/network/listener_impl_test.cc b/test/common/network/listener_impl_test.cc index ad3f3a65873ac..7048954f216a2 100644 --- a/test/common/network/listener_impl_test.cc +++ b/test/common/network/listener_impl_test.cc @@ -11,8 +11,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -using testing::ByRef; -using testing::Eq; using testing::Invoke; using testing::Return; using testing::_; @@ -30,17 +28,19 @@ static void errorCallbackTest(Address::IpVersion version) { Network::MockListenerCallbacks listener_callbacks; Network::MockConnectionHandler connection_handler; Network::ListenerPtr listener = - dispatcher.createListener(connection_handler, socket, listener_callbacks, stats_store, - {.bind_to_port_ = true, - .use_proxy_proto_ = false, - .use_original_dst_ = false, - .per_connection_buffer_limit_bytes_ = 0}); + dispatcher.createListener(socket, listener_callbacks, true, false); Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( socket.localAddress(), Network::Address::InstanceConstSharedPtr(), Network::Test::createRawBufferSocket()); client_connection->connect(); + EXPECT_CALL(listener_callbacks, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher.createServerConnection(std::move(socket), nullptr); + listener_callbacks.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(listener_callbacks, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { client_connection->close(ConnectionCloseType::NoFlush); @@ -61,25 +61,12 @@ TEST_P(ListenerImplDeathTest, ErrorCallback) { class TestListenerImpl : public ListenerImpl { public: - TestListenerImpl(Network::ConnectionHandler& conn_handler, Event::DispatcherImpl& dispatcher, - ListenSocket& socket, ListenerCallbacks& cb, Stats::Store& stats_store, - const Network::ListenerOptions& listener_options) - : ListenerImpl(conn_handler, dispatcher, socket, cb, stats_store, listener_options) { - ON_CALL(*this, newConnection(_, _, _, _)) - .WillByDefault(Invoke( - [this](int fd, Address::InstanceConstSharedPtr remote_address, - Address::InstanceConstSharedPtr local_address, bool using_original_dst) -> void { - ListenerImpl::newConnection(fd, remote_address, local_address, using_original_dst); - } - - )); - } + TestListenerImpl(Event::DispatcherImpl& dispatcher, ListenSocket& socket, ListenerCallbacks& cb, + bool bind_to_port, bool hand_off_restored_destination_connections) + : ListenerImpl(dispatcher, socket, cb, bind_to_port, + hand_off_restored_destination_connections) {} MOCK_METHOD1(getLocalAddress, Address::InstanceConstSharedPtr(int fd)); - MOCK_METHOD1(getOriginalDst, Address::InstanceConstSharedPtr(int fd)); - MOCK_METHOD4(newConnection, - void(int fd, Address::InstanceConstSharedPtr remote_address, - Address::InstanceConstSharedPtr local_address, bool using_original_dst)); }; class ListenerImplTest : public testing::TestWithParam { @@ -95,167 +82,6 @@ class ListenerImplTest : public testing::TestWithParam { INSTANTIATE_TEST_CASE_P(IpVersions, ListenerImplTest, testing::ValuesIn(TestEnvironment::getIpVersionsForTest())); -TEST_P(ListenerImplTest, NormalRedirect) { - Stats::IsolatedStoreImpl stats_store; - Event::DispatcherImpl dispatcher; - Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(version_), true); - Network::TcpListenSocket socketDst(alt_address_, false); - Network::MockListenerCallbacks listener_callbacks1; - Network::MockConnectionHandler connection_handler; - // The traffic should redirect from binding listener to the virtual listener. - Network::TestListenerImpl listener(connection_handler, dispatcher, socket, listener_callbacks1, - stats_store, - {.bind_to_port_ = true, - .use_proxy_proto_ = false, - .use_original_dst_ = true, - .per_connection_buffer_limit_bytes_ = 0}); - Network::MockListenerCallbacks listener_callbacks2; - Network::TestListenerImpl listenerDst(connection_handler, dispatcher, socketDst, - listener_callbacks2, stats_store, - Network::ListenerOptions()); - - Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( - socket.localAddress(), Network::Address::InstanceConstSharedPtr(), - Network::Test::createRawBufferSocket()); - client_connection->connect(); - - EXPECT_CALL(listener, getLocalAddress(_)).Times(0); - EXPECT_CALL(listener, getOriginalDst(_)).WillOnce(Return(alt_address_)); - EXPECT_CALL(connection_handler, findListenerByAddress(Eq(ByRef(*alt_address_)))) - .WillOnce(Return(&listenerDst)); - - EXPECT_CALL(listener, newConnection(_, _, _, _)).Times(0); - EXPECT_CALL(listenerDst, newConnection(_, _, _, _)); - EXPECT_CALL(listener_callbacks2, onNewConnection_(_)) - .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { - EXPECT_EQ(*alt_address_, *conn->localAddress()); - client_connection->close(ConnectionCloseType::NoFlush); - conn->close(ConnectionCloseType::NoFlush); - dispatcher.exit(); - })); - - dispatcher.run(Event::Dispatcher::RunType::Block); -} - -TEST_P(ListenerImplTest, FallbackToWildcardListener) { - Stats::IsolatedStoreImpl stats_store; - Event::DispatcherImpl dispatcher; - Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(version_), true); - Network::TcpListenSocket socketDst(Network::Test::getAnyAddress(version_), false); - Network::MockListenerCallbacks listener_callbacks1; - Network::MockConnectionHandler connection_handler; - // The virtual listener of exact address does not exist, fall back to wild card virtual listener. - Network::TestListenerImpl listener(connection_handler, dispatcher, socket, listener_callbacks1, - stats_store, - {.bind_to_port_ = true, - .use_proxy_proto_ = false, - .use_original_dst_ = true, - .per_connection_buffer_limit_bytes_ = 0}); - Network::MockListenerCallbacks listener_callbacks2; - Network::TestListenerImpl listenerDst(connection_handler, dispatcher, socketDst, - listener_callbacks2, stats_store, - Network::ListenerOptions()); - - Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( - socket.localAddress(), Network::Address::InstanceConstSharedPtr(), - Network::Test::createRawBufferSocket()); - client_connection->connect(); - - EXPECT_CALL(listener, getLocalAddress(_)).Times(0); - EXPECT_CALL(listener, getOriginalDst(_)).WillOnce(Return(alt_address_)); - EXPECT_CALL(connection_handler, findListenerByAddress(Eq(ByRef(*alt_address_)))) - .WillOnce(Return(&listenerDst)); - - EXPECT_CALL(listener, newConnection(_, _, _, _)).Times(0); - EXPECT_CALL(listenerDst, newConnection(_, _, _, _)); - EXPECT_CALL(listener_callbacks2, onNewConnection_(_)) - .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { - EXPECT_EQ(*alt_address_, *conn->localAddress()); - EXPECT_FALSE(*socketDst.localAddress() == *conn->localAddress()); - client_connection->close(ConnectionCloseType::NoFlush); - conn->close(ConnectionCloseType::NoFlush); - dispatcher.exit(); - })); - - dispatcher.run(Event::Dispatcher::RunType::Block); -} - -TEST_P(ListenerImplTest, WildcardListenerWithOriginalDst) { - Stats::IsolatedStoreImpl stats_store; - Event::DispatcherImpl dispatcher; - Network::TcpListenSocket socket(Network::Test::getAnyAddress(version_), true); - Network::MockListenerCallbacks listener_callbacks; - Network::MockConnectionHandler connection_handler; - // The virtual listener of exact address does not exist, fall back to the wild card listener. - Network::TestListenerImpl listener(connection_handler, dispatcher, socket, listener_callbacks, - stats_store, - {.bind_to_port_ = true, - .use_proxy_proto_ = false, - .use_original_dst_ = true, - .per_connection_buffer_limit_bytes_ = 0}); - - auto local_dst_address = Network::Utility::getAddressWithPort( - *Network::Test::getCanonicalLoopbackAddress(version_), socket.localAddress()->ip()->port()); - Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( - local_dst_address, Network::Address::InstanceConstSharedPtr(), - Network::Test::createRawBufferSocket()); - client_connection->connect(); - - EXPECT_CALL(listener, getLocalAddress(_)).WillOnce(Return(local_dst_address)); - EXPECT_CALL(listener, getOriginalDst(_)).WillOnce(Return(alt_address_)); - EXPECT_CALL(connection_handler, findListenerByAddress(Eq(ByRef(*alt_address_)))) - .WillOnce(Return(&listener)); - - EXPECT_CALL(listener, newConnection(_, _, _, _)); - EXPECT_CALL(listener_callbacks, onNewConnection_(_)) - .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { - EXPECT_EQ(*conn->localAddress(), *alt_address_); - client_connection->close(ConnectionCloseType::NoFlush); - conn->close(ConnectionCloseType::NoFlush); - dispatcher.exit(); - })); - - dispatcher.run(Event::Dispatcher::RunType::Block); -} - -TEST_P(ListenerImplTest, WildcardListenerNoOriginalDst) { - Stats::IsolatedStoreImpl stats_store; - Event::DispatcherImpl dispatcher; - Network::TcpListenSocket socket(Network::Test::getAnyAddress(version_), true); - Network::MockListenerCallbacks listener_callbacks; - Network::MockConnectionHandler connection_handler; - // The virtual listener of exact address does not exist, fall back to the wild card listener. - Network::TestListenerImpl listener(connection_handler, dispatcher, socket, listener_callbacks, - stats_store, - {.bind_to_port_ = true, - .use_proxy_proto_ = false, - .use_original_dst_ = true, - .per_connection_buffer_limit_bytes_ = 0}); - - auto local_dst_address = Network::Utility::getAddressWithPort( - *Network::Test::getCanonicalLoopbackAddress(version_), socket.localAddress()->ip()->port()); - Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( - local_dst_address, Network::Address::InstanceConstSharedPtr(), - Network::Test::createRawBufferSocket()); - client_connection->connect(); - - EXPECT_CALL(listener, getLocalAddress(_)).WillOnce(Return(local_dst_address)); - // getOriginalDst() returns the same address as the connections destination. - EXPECT_CALL(listener, getOriginalDst(_)).WillOnce(Return(local_dst_address)); - EXPECT_CALL(connection_handler, findListenerByAddress(_)).Times(0); - - EXPECT_CALL(listener, newConnection(_, _, _, _)); - EXPECT_CALL(listener_callbacks, onNewConnection_(_)) - .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { - EXPECT_EQ(*conn->localAddress(), *local_dst_address); - client_connection->close(ConnectionCloseType::NoFlush); - conn->close(ConnectionCloseType::NoFlush); - dispatcher.exit(); - })); - - dispatcher.run(Event::Dispatcher::RunType::Block); -} - TEST_P(ListenerImplTest, UseActualDst) { Stats::IsolatedStoreImpl stats_store; Event::DispatcherImpl dispatcher; @@ -264,16 +90,9 @@ TEST_P(ListenerImplTest, UseActualDst) { Network::MockListenerCallbacks listener_callbacks1; Network::MockConnectionHandler connection_handler; // Do not redirect since use_original_dst is false. - Network::TestListenerImpl listener(connection_handler, dispatcher, socket, listener_callbacks1, - stats_store, - {.bind_to_port_ = true, - .use_proxy_proto_ = false, - .use_original_dst_ = false, - .per_connection_buffer_limit_bytes_ = 0}); + Network::TestListenerImpl listener(dispatcher, socket, listener_callbacks1, true, true); Network::MockListenerCallbacks listener_callbacks2; - Network::TestListenerImpl listenerDst(connection_handler, dispatcher, socketDst, - listener_callbacks2, stats_store, - Network::ListenerOptions()); + Network::TestListenerImpl listenerDst(dispatcher, socketDst, listener_callbacks2, false, false); Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( socket.localAddress(), Network::Address::InstanceConstSharedPtr(), @@ -281,11 +100,14 @@ TEST_P(ListenerImplTest, UseActualDst) { client_connection->connect(); EXPECT_CALL(listener, getLocalAddress(_)).Times(0); - EXPECT_CALL(listener, getOriginalDst(_)).Times(0); - EXPECT_CALL(connection_handler, findListenerByAddress(_)).Times(0); - EXPECT_CALL(listener, newConnection(_, _, _, _)).Times(1); - EXPECT_CALL(listenerDst, newConnection(_, _, _, _)).Times(0); + EXPECT_CALL(listener_callbacks2, onAccept_(_, _)).Times(0); + EXPECT_CALL(listener_callbacks1, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher.createServerConnection(std::move(socket), nullptr); + listener_callbacks1.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(listener_callbacks1, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { EXPECT_EQ(*conn->localAddress(), *socket.localAddress()); @@ -304,12 +126,7 @@ TEST_P(ListenerImplTest, WildcardListenerUseActualDst) { Network::MockListenerCallbacks listener_callbacks; Network::MockConnectionHandler connection_handler; // Do not redirect since use_original_dst is false. - Network::TestListenerImpl listener(connection_handler, dispatcher, socket, listener_callbacks, - stats_store, - {.bind_to_port_ = true, - .use_proxy_proto_ = false, - .use_original_dst_ = false, - .per_connection_buffer_limit_bytes_ = 0}); + Network::TestListenerImpl listener(dispatcher, socket, listener_callbacks, true, true); auto local_dst_address = Network::Utility::getAddressWithPort( *Network::Test::getCanonicalLoopbackAddress(version_), socket.localAddress()->ip()->port()); @@ -319,10 +136,13 @@ TEST_P(ListenerImplTest, WildcardListenerUseActualDst) { client_connection->connect(); EXPECT_CALL(listener, getLocalAddress(_)).WillOnce(Return(local_dst_address)); - EXPECT_CALL(listener, getOriginalDst(_)).Times(0); - EXPECT_CALL(connection_handler, findListenerByAddress(_)).Times(0); - EXPECT_CALL(listener, newConnection(_, _, _, _)).Times(1); + EXPECT_CALL(listener_callbacks, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher.createServerConnection(std::move(socket), nullptr); + listener_callbacks.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(listener_callbacks, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { EXPECT_EQ(*conn->localAddress(), *local_dst_address); diff --git a/test/common/network/proxy_protocol_test.cc b/test/common/network/proxy_protocol_test.cc index 9d3574119ddfd..040d704fceca8 100644 --- a/test/common/network/proxy_protocol_test.cc +++ b/test/common/network/proxy_protocol_test.cc @@ -4,11 +4,14 @@ #include "common/buffer/buffer_impl.h" #include "common/event/dispatcher_impl.h" +#include "common/filter/listener/proxy_protocol.h" #include "common/network/listen_socket_impl.h" #include "common/network/listener_impl.h" #include "common/network/utility.h" #include "common/stats/stats_impl.h" +#include "server/connection_handler_impl.h" + #include "test/mocks/buffer/mocks.h" #include "test/mocks/network/mocks.h" #include "test/mocks/server/mocks.h" @@ -19,45 +22,63 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +using testing::AtLeast; using testing::Invoke; using testing::NiceMock; +using testing::Return; using testing::_; namespace Envoy { namespace Network { -class ProxyProtocolTest : public testing::TestWithParam { +// Build again on the basis of the connection_handler_test.cc + +class ProxyProtocolTest : public testing::TestWithParam, + public Network::ListenerConfig, + protected Logger::Loggable { public: ProxyProtocolTest() : socket_(Network::Test::getCanonicalLoopbackAddress(GetParam()), true), - listener_(dispatcher_.createListener(connection_handler_, socket_, callbacks_, stats_store_, - {.bind_to_port_ = true, - .use_proxy_proto_ = true, - .use_original_dst_ = false, - .per_connection_buffer_limit_bytes_ = 0})) { + connection_handler_(new Server::ConnectionHandlerImpl(ENVOY_LOGGER(), dispatcher_)), + name_("proxy") { + connection_handler_->addListener(*this); conn_ = dispatcher_.createClientConnection(socket_.localAddress(), Network::Address::InstanceConstSharedPtr(), Network::Test::createRawBufferSocket()); conn_->addConnectionCallbacks(connection_callbacks_); } - void connect() { - conn_->connect(); - read_filter_.reset(new NiceMock()); - EXPECT_CALL(callbacks_, onNewConnection_(_)) - .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { - server_connection_ = std::move(conn); - server_connection_->addConnectionCallbacks(server_callbacks_); - server_connection_->addReadFilter(read_filter_); + // Listener + Network::FilterChainFactory& filterChainFactory() override { return factory_; } + Network::ListenSocket& socket() override { return socket_; } + Ssl::ServerContext* defaultSslContext() override { return nullptr; } + bool bindToPort() override { return true; } + bool handOffRestoredDestinationConnections() const override { return false; } + uint32_t perConnectionBufferLimitBytes() override { return 0; } + Stats::Scope& listenerScope() override { return stats_store_; } + uint64_t listenerTag() const override { return 1; } + const std::string& name() const override { return name_; } + + void connect(bool read = true) { + EXPECT_CALL(factory_, createListenerFilterChain(_)) + .WillOnce(Invoke([&](ListenerFilterManager& filter_manager) -> bool { + filter_manager.addAcceptFilter( + std::make_unique( + std::make_shared( + listenerScope()))); + return true; })); - EXPECT_CALL(connection_callbacks_, onEvent(ConnectionEvent::Connected)) - .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { dispatcher_.exit(); })); - - dispatcher_.run(Event::Dispatcher::RunType::Block); - } - - void connectNoRead() { conn_->connect(); + if (read) { + read_filter_.reset(new NiceMock()); + EXPECT_CALL(factory_, createNetworkFilterChain(_)) + .WillOnce(Invoke([&](Connection& connection) -> bool { + server_connection_ = &connection; + connection.addConnectionCallbacks(server_callbacks_); + connection.addReadFilter(read_filter_); + return true; + })); + } EXPECT_CALL(connection_callbacks_, onEvent(ConnectionEvent::Connected)) .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { dispatcher_.exit(); })); dispatcher_.run(Event::Dispatcher::RunType::Block); @@ -68,6 +89,19 @@ class ProxyProtocolTest : public testing::TestWithParam { conn_->write(buf); } + void expectData(std::string expected) { + EXPECT_CALL(*read_filter_, onNewConnection()); + EXPECT_CALL(*read_filter_, onData(_)) + .WillOnce(Invoke([&](Buffer::Instance& buffer) -> FilterStatus { + EXPECT_EQ(TestUtility::bufferToString(buffer), expected); + buffer.drain(expected.length()); + dispatcher_.exit(); + return Network::FilterStatus::Continue; + })); + + dispatcher_.run(Event::Dispatcher::RunType::Block); + } + void disconnect() { EXPECT_CALL(connection_callbacks_, onEvent(ConnectionEvent::LocalClose)); EXPECT_CALL(server_callbacks_, onEvent(ConnectionEvent::RemoteClose)) @@ -90,14 +124,14 @@ class ProxyProtocolTest : public testing::TestWithParam { Event::DispatcherImpl dispatcher_; TcpListenSocket socket_; Stats::IsolatedStoreImpl stats_store_; - MockListenerCallbacks callbacks_; - Network::MockConnectionHandler connection_handler_; - Network::ListenerPtr listener_; + Network::ConnectionHandlerPtr connection_handler_; + Network::MockFilterChainFactory factory_; ClientConnectionPtr conn_; NiceMock connection_callbacks_; - Network::ConnectionPtr server_connection_; + Network::Connection* server_connection_; Network::MockConnectionCallbacks server_callbacks_; std::shared_ptr read_filter_; + std::string name_; }; // Parameterize the listener socket address version. @@ -108,17 +142,9 @@ TEST_P(ProxyProtocolTest, Basic) { connect(); write("PROXY TCP4 1.2.3.4 253.253.253.253 65535 1234\r\nmore data"); - EXPECT_CALL(*read_filter_, onNewConnection()); - EXPECT_CALL(*read_filter_, onData(_)) - .WillOnce(Invoke([&](Buffer::Instance& buffer) -> FilterStatus { - EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "1.2.3.4"); + expectData("more data"); - EXPECT_EQ(TestUtility::bufferToString(buffer), "more data"); - buffer.drain(9); - return Network::FilterStatus::Continue; - })); - - dispatcher_.run(Event::Dispatcher::RunType::NonBlock); + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "1.2.3.4"); disconnect(); } @@ -127,17 +153,9 @@ TEST_P(ProxyProtocolTest, BasicV6) { connect(); write("PROXY TCP6 1:2:3::4 5:6::7:8 65535 1234\r\nmore data"); - EXPECT_CALL(*read_filter_, onNewConnection()); - EXPECT_CALL(*read_filter_, onData(_)) - .WillOnce(Invoke([&](Buffer::Instance& buffer) -> FilterStatus { - EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "1:2:3::4"); + expectData("more data"); - EXPECT_EQ(TestUtility::bufferToString(buffer), "more data"); - buffer.drain(9); - return Network::FilterStatus::Continue; - })); - - dispatcher_.run(Event::Dispatcher::RunType::NonBlock); + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "1:2:3::4"); disconnect(); } @@ -148,13 +166,17 @@ TEST_P(ProxyProtocolTest, Fragmented) { write(" 254.254.2"); write("54.254 1.2"); write(".3.4 65535"); - write(" 1234\r\n"); + write(" 1234\r\n..."); - dispatcher_.run(Event::Dispatcher::RunType::NonBlock); - - disconnect(); + // If there is no data after the PROXY line, the read filter does not receive even the + // onNewConnection() callback. We need this in order to run the dispatcher in blocking + // mode to make sure that proxy protocol processing is completed before we start testing + // the results. Since we must have data we might as well check that we get it. + expectData("..."); EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "254.254.254.254"); + + disconnect(); } TEST_P(ProxyProtocolTest, PartialRead) { @@ -167,17 +189,17 @@ TEST_P(ProxyProtocolTest, PartialRead) { write("54.254 1.2"); write(".3.4 65535"); - write(" 1234\r\n"); - - dispatcher_.run(Event::Dispatcher::RunType::NonBlock); + write(" 1234\r\n..."); - disconnect(); + expectData("..."); EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "254.254.254.254"); + + disconnect(); } TEST_P(ProxyProtocolTest, MalformedProxyLine) { - connectNoRead(); + connect(false); write("BOGUS\r"); dispatcher_.run(Event::Dispatcher::RunType::NonBlock); @@ -187,74 +209,74 @@ TEST_P(ProxyProtocolTest, MalformedProxyLine) { } TEST_P(ProxyProtocolTest, ProxyLineTooLarge) { - connectNoRead(); + connect(false); write("012345678901234567890123456789012345678901234567890123456789" "012345678901234567890123456789012345678901234567890123456789"); expectProxyProtoError(); } TEST_P(ProxyProtocolTest, NotEnoughFields) { - connectNoRead(); + connect(false); write("PROXY TCP6 1:2:3::4 5:6::7:8 1234\r\nmore data"); expectProxyProtoError(); } TEST_P(ProxyProtocolTest, UnsupportedProto) { - connectNoRead(); + connect(false); write("PROXY UDP6 1:2:3::4 5:6::7:8 1234 5678\r\nmore data"); expectProxyProtoError(); } TEST_P(ProxyProtocolTest, InvalidSrcAddress) { - connectNoRead(); + connect(false); write("PROXY TCP4 230.0.0.1 10.1.1.3 1234 5678\r\nmore data"); expectProxyProtoError(); } TEST_P(ProxyProtocolTest, InvalidDstAddress) { - connectNoRead(); + connect(false); write("PROXY TCP4 10.1.1.2 0.0.0.0 1234 5678\r\nmore data"); expectProxyProtoError(); } TEST_P(ProxyProtocolTest, BadPort) { - connectNoRead(); + connect(false); write("PROXY TCP6 1:2:3::4 5:6::7:8 1234 abc\r\nmore data"); expectProxyProtoError(); } TEST_P(ProxyProtocolTest, NegativePort) { - connectNoRead(); + connect(false); write("PROXY TCP6 1:2:3::4 5:6::7:8 -1 1234\r\nmore data"); expectProxyProtoError(); } TEST_P(ProxyProtocolTest, PortOutOfRange) { - connectNoRead(); + connect(false); write("PROXY TCP6 1:2:3::4 5:6::7:8 66776 1234\r\nmore data"); expectProxyProtoError(); } TEST_P(ProxyProtocolTest, BadAddress) { - connectNoRead(); + connect(false); write("PROXY TCP6 1::2:3::4 5:6::7:8 1234 5678\r\nmore data"); expectProxyProtoError(); } TEST_P(ProxyProtocolTest, AddressVersionsNotMatch) { - connectNoRead(); + connect(false); write("PROXY TCP4 [1:2:3::4] 1.2.3.4 1234 5678\r\nmore data"); expectProxyProtoError(); } TEST_P(ProxyProtocolTest, AddressVersionsNotMatch2) { - connectNoRead(); + connect(false); write("PROXY TCP4 1.2.3.4 [1:2:3: 1234 4]:5678\r\nmore data"); expectProxyProtoError(); } TEST_P(ProxyProtocolTest, Truncated) { - connectNoRead(); + connect(false); write("PROXY TCP4 1.2.3.4 5.6.7.8 1234 5678"); dispatcher_.run(Event::Dispatcher::RunType::NonBlock); @@ -266,7 +288,7 @@ TEST_P(ProxyProtocolTest, Truncated) { } TEST_P(ProxyProtocolTest, Closed) { - connectNoRead(); + connect(false); write("PROXY TCP4 1.2.3"); dispatcher_.run(Event::Dispatcher::RunType::NonBlock); @@ -278,37 +300,61 @@ TEST_P(ProxyProtocolTest, Closed) { } TEST_P(ProxyProtocolTest, ClosedEmpty) { + // We may or may not get these, depending on the operating system timing. + EXPECT_CALL(factory_, createListenerFilterChain(_)).Times(AtLeast(0)); + EXPECT_CALL(factory_, createNetworkFilterChain(_)).Times(AtLeast(0)); conn_->connect(); conn_->close(ConnectionCloseType::NoFlush); dispatcher_.run(Event::Dispatcher::RunType::NonBlock); } -class WildcardProxyProtocolTest : public testing::TestWithParam { +class WildcardProxyProtocolTest : public testing::TestWithParam, + public Network::ListenerConfig, + protected Logger::Loggable { public: WildcardProxyProtocolTest() : socket_(Network::Test::getAnyAddress(GetParam()), true), local_dst_address_(Network::Utility::getAddressWithPort( *Network::Test::getCanonicalLoopbackAddress(GetParam()), socket_.localAddress()->ip()->port())), - listener_(dispatcher_.createListener(connection_handler_, socket_, callbacks_, stats_store_, - {.bind_to_port_ = true, - .use_proxy_proto_ = true, - .use_original_dst_ = false, - .per_connection_buffer_limit_bytes_ = 0})) { + connection_handler_(new Server::ConnectionHandlerImpl(ENVOY_LOGGER(), dispatcher_)), + name_("proxy") { + connection_handler_->addListener(*this); conn_ = dispatcher_.createClientConnection(local_dst_address_, Network::Address::InstanceConstSharedPtr(), Network::Test::createRawBufferSocket()); conn_->addConnectionCallbacks(connection_callbacks_); + + EXPECT_CALL(factory_, createListenerFilterChain(_)) + .WillOnce(Invoke([&](ListenerFilterManager& filter_manager) -> bool { + filter_manager.addAcceptFilter( + std::make_unique( + std::make_shared( + listenerScope()))); + return true; + })); } + // Network::ListenerConfig + Network::FilterChainFactory& filterChainFactory() override { return factory_; } + Network::ListenSocket& socket() override { return socket_; } + Ssl::ServerContext* defaultSslContext() override { return nullptr; } + bool bindToPort() override { return true; } + bool handOffRestoredDestinationConnections() const override { return false; } + uint32_t perConnectionBufferLimitBytes() override { return 0; } + Stats::Scope& listenerScope() override { return stats_store_; } + uint64_t listenerTag() const override { return 1; } + const std::string& name() const override { return name_; } + void connect() { conn_->connect(); read_filter_.reset(new NiceMock()); - EXPECT_CALL(callbacks_, onNewConnection_(_)) - .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { - server_connection_ = std::move(conn); - server_connection_->addConnectionCallbacks(server_callbacks_); - server_connection_->addReadFilter(read_filter_); + EXPECT_CALL(factory_, createNetworkFilterChain(_)) + .WillOnce(Invoke([&](Connection& connection) -> bool { + server_connection_ = &connection; + connection.addConnectionCallbacks(server_callbacks_); + connection.addReadFilter(read_filter_); + return true; })); EXPECT_CALL(connection_callbacks_, onEvent(ConnectionEvent::Connected)) .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { dispatcher_.exit(); })); @@ -320,6 +366,19 @@ class WildcardProxyProtocolTest : public testing::TestWithParamwrite(buf); } + void expectData(std::string expected) { + EXPECT_CALL(*read_filter_, onNewConnection()); + EXPECT_CALL(*read_filter_, onData(_)) + .WillOnce(Invoke([&](Buffer::Instance& buffer) -> FilterStatus { + EXPECT_EQ(TestUtility::bufferToString(buffer), expected); + buffer.drain(expected.length()); + dispatcher_.exit(); + return Network::FilterStatus::Continue; + })); + + dispatcher_.run(Event::Dispatcher::RunType::Block); + } + void disconnect() { EXPECT_CALL(connection_callbacks_, onEvent(ConnectionEvent::LocalClose)); conn_->close(ConnectionCloseType::NoFlush); @@ -333,14 +392,14 @@ class WildcardProxyProtocolTest : public testing::TestWithParam connection_callbacks_; - Network::ConnectionPtr server_connection_; + Network::Connection* server_connection_; Network::MockConnectionCallbacks server_callbacks_; std::shared_ptr read_filter_; + std::string name_; }; // Parameterize the listener socket address version. @@ -351,17 +410,11 @@ TEST_P(WildcardProxyProtocolTest, Basic) { connect(); write("PROXY TCP4 1.2.3.4 254.254.254.254 65535 1234\r\nmore data"); - EXPECT_CALL(*read_filter_, onNewConnection()); - EXPECT_CALL(*read_filter_, onData(_)) - .WillOnce(Invoke([&](Buffer::Instance& buffer) -> FilterStatus { - EXPECT_EQ(server_connection_->remoteAddress()->asString(), "1.2.3.4:65535"); - EXPECT_EQ(server_connection_->localAddress()->asString(), "254.254.254.254:1234"); + expectData("more data"); + + EXPECT_EQ(server_connection_->remoteAddress()->asString(), "1.2.3.4:65535"); + EXPECT_EQ(server_connection_->localAddress()->asString(), "254.254.254.254:1234"); - EXPECT_EQ(TestUtility::bufferToString(buffer), "more data"); - buffer.drain(9); - return Network::FilterStatus::Continue; - })); - dispatcher_.run(Event::Dispatcher::RunType::NonBlock); disconnect(); } @@ -369,18 +422,11 @@ TEST_P(WildcardProxyProtocolTest, BasicV6) { connect(); write("PROXY TCP6 1:2:3::4 5:6::7:8 65535 1234\r\nmore data"); - EXPECT_CALL(*read_filter_, onNewConnection()); - EXPECT_CALL(*read_filter_, onData(_)) - .WillOnce(Invoke([&](Buffer::Instance& buffer) -> FilterStatus { - EXPECT_EQ(server_connection_->remoteAddress()->asString(), "[1:2:3::4]:65535"); - EXPECT_EQ(server_connection_->localAddress()->asString(), "[5:6::7:8]:1234"); + expectData("more data"); - EXPECT_EQ(TestUtility::bufferToString(buffer), "more data"); - buffer.drain(9); - return Network::FilterStatus::Continue; - })); + EXPECT_EQ(server_connection_->remoteAddress()->asString(), "[1:2:3::4]:65535"); + EXPECT_EQ(server_connection_->localAddress()->asString(), "[5:6::7:8]:1234"); - dispatcher_.run(Event::Dispatcher::RunType::NonBlock); disconnect(); } diff --git a/test/common/ssl/ssl_socket_test.cc b/test/common/ssl/ssl_socket_test.cc index d5d2f9de471ad..5a87edd4230b3 100644 --- a/test/common/ssl/ssl_socket_test.cc +++ b/test/common/ssl/ssl_socket_test.cc @@ -57,9 +57,7 @@ void testUtil(const std::string& client_ctx_json, const std::string& server_ctx_ Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(version), true); Network::MockListenerCallbacks callbacks; Network::MockConnectionHandler connection_handler; - Network::ListenerPtr listener = - dispatcher.createSslListener(connection_handler, *server_ctx, socket, callbacks, stats_store, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + Network::ListenerPtr listener = dispatcher.createListener(socket, callbacks, true, false); Json::ObjectSharedPtr client_ctx_loader = TestEnvironment::jsonLoadFromString(client_ctx_json); ClientContextConfigImpl client_ctx_config(*client_ctx_loader); @@ -71,6 +69,12 @@ void testUtil(const std::string& client_ctx_json, const std::string& server_ctx_ Network::ConnectionPtr server_connection; Network::MockConnectionCallbacks server_connection_callbacks; + EXPECT_CALL(callbacks, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher.createServerConnection(std::move(socket), server_ctx.get()); + callbacks.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(callbacks, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection = std::move(conn); @@ -143,9 +147,7 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto, Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(version), true); NiceMock callbacks; Network::MockConnectionHandler connection_handler; - Network::ListenerPtr listener = dispatcher.createSslListener( - connection_handler, *server_contexts[0], socket, callbacks, stats_store, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + Network::ListenerPtr listener = dispatcher.createListener(socket, callbacks, true, false); ClientContextConfigImpl client_ctx_config(client_ctx_proto); ClientSslSocketFactory ssl_socket_factory(client_ctx_config, manager, stats_store); @@ -169,6 +171,12 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto, Network::ConnectionPtr server_connection; Network::MockConnectionCallbacks server_connection_callbacks; + EXPECT_CALL(callbacks, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher.createServerConnection(std::move(socket), server_contexts[0].get()); + callbacks.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(callbacks, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection = std::move(conn); @@ -677,9 +685,7 @@ TEST_P(SslSocketTest, FlushCloseDuringHandshake) { Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(GetParam()), true); Network::MockListenerCallbacks callbacks; Network::MockConnectionHandler connection_handler; - Network::ListenerPtr listener = - dispatcher.createSslListener(connection_handler, *server_ctx, socket, callbacks, stats_store, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + Network::ListenerPtr listener = dispatcher.createListener(socket, callbacks, true, false); Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( socket.localAddress(), Network::Address::InstanceConstSharedPtr(), @@ -690,6 +696,12 @@ TEST_P(SslSocketTest, FlushCloseDuringHandshake) { Network::ConnectionPtr server_connection; Network::MockConnectionCallbacks server_connection_callbacks; + EXPECT_CALL(callbacks, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher.createServerConnection(std::move(socket), server_ctx.get()); + callbacks.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(callbacks, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection = std::move(conn); @@ -729,9 +741,7 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) { Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(GetParam()), true); Network::MockListenerCallbacks callbacks; Network::MockConnectionHandler connection_handler; - Network::ListenerPtr listener = - dispatcher.createSslListener(connection_handler, *server_ctx, socket, callbacks, stats_store, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + Network::ListenerPtr listener = dispatcher.createListener(socket, callbacks, true, false); std::string client_ctx_json = R"EOF( { @@ -762,6 +772,12 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) { Network::ConnectionPtr server_connection; Network::MockConnectionCallbacks server_connection_callbacks; + EXPECT_CALL(callbacks, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher.createServerConnection(std::move(socket), server_ctx.get()); + callbacks.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(callbacks, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection = std::move(conn); @@ -806,12 +822,8 @@ void testTicketSessionResumption(const std::string& server_ctx_json1, Network::TcpListenSocket socket2(Network::Test::getCanonicalLoopbackAddress(ip_version), true); NiceMock callbacks; Network::MockConnectionHandler connection_handler; - Network::ListenerPtr listener1 = dispatcher.createSslListener( - connection_handler, *server_ctx1, socket1, callbacks, stats_store, - Network::ListenerOptions::listenerOptionsWithBindToPort()); - Network::ListenerPtr listener2 = dispatcher.createSslListener( - connection_handler, *server_ctx2, socket2, callbacks, stats_store, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + Network::ListenerPtr listener1 = dispatcher.createListener(socket1, callbacks, true, false); + Network::ListenerPtr listener2 = dispatcher.createListener(socket2, callbacks, true, false); Json::ObjectSharedPtr client_ctx_loader = TestEnvironment::jsonLoadFromString(client_ctx_json); ClientContextConfigImpl client_ctx_config(*client_ctx_loader); @@ -826,6 +838,14 @@ void testTicketSessionResumption(const std::string& server_ctx_json1, SSL_SESSION* ssl_session = nullptr; Network::ConnectionPtr server_connection; + EXPECT_CALL(callbacks, onAccept_(_, _)) + .WillRepeatedly(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + ServerContext* ctx = socket->localAddress() == socket1.localAddress() ? server_ctx1.get() + : server_ctx2.get(); + Network::ConnectionPtr new_connection = + dispatcher.createServerConnection(std::move(socket), ctx); + callbacks.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(callbacks, onNewConnection_(_)) .WillOnce(Invoke( [&](Network::ConnectionPtr& conn) -> void { server_connection = std::move(conn); })); @@ -1123,12 +1143,8 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { Network::TcpListenSocket socket2(Network::Test::getCanonicalLoopbackAddress(GetParam()), true); Network::MockListenerCallbacks callbacks; Network::MockConnectionHandler connection_handler; - Network::ListenerPtr listener = - dispatcher.createSslListener(connection_handler, *server_ctx, socket, callbacks, stats_store, - Network::ListenerOptions::listenerOptionsWithBindToPort()); - Network::ListenerPtr listener2 = dispatcher.createSslListener( - connection_handler, *server2_ctx, socket2, callbacks, stats_store, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + Network::ListenerPtr listener = dispatcher.createListener(socket, callbacks, true, false); + Network::ListenerPtr listener2 = dispatcher.createListener(socket2, callbacks, true, false); std::string client_ctx_json = R"EOF( { @@ -1151,6 +1167,15 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { SSL_SESSION* ssl_session = nullptr; Network::ConnectionPtr server_connection; Network::MockConnectionCallbacks server_connection_callbacks; + EXPECT_CALL(callbacks, onAccept_(_, _)) + .WillRepeatedly(Invoke([&](Network::ConnectionSocketPtr& accepted_socket, bool) -> void { + ServerContext* ctx = accepted_socket->localAddress() == socket.localAddress() + ? server_ctx.get() + : server2_ctx.get(); + Network::ConnectionPtr new_connection = + dispatcher.createServerConnection(std::move(accepted_socket), ctx); + callbacks.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(callbacks, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection = std::move(conn); @@ -1223,9 +1248,7 @@ TEST_P(SslSocketTest, SslError) { Network::TcpListenSocket socket(Network::Test::getCanonicalLoopbackAddress(GetParam()), true); Network::MockListenerCallbacks callbacks; Network::MockConnectionHandler connection_handler; - Network::ListenerPtr listener = - dispatcher.createSslListener(connection_handler, *server_ctx, socket, callbacks, stats_store, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + Network::ListenerPtr listener = dispatcher.createListener(socket, callbacks, true, false); Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( socket.localAddress(), Network::Address::InstanceConstSharedPtr(), @@ -1236,6 +1259,12 @@ TEST_P(SslSocketTest, SslError) { Network::ConnectionPtr server_connection; Network::MockConnectionCallbacks server_connection_callbacks; + EXPECT_CALL(callbacks, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher.createServerConnection(std::move(socket), server_ctx.get()); + callbacks.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(callbacks, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection = std::move(conn); @@ -1816,18 +1845,13 @@ TEST_P(SslSocketTest, SniProtocolVersions) { class SslReadBufferLimitTest : public SslCertsTest, public testing::WithParamInterface { public: - void initialize(uint32_t read_buffer_limit) { + void initialize() { server_ctx_loader_ = TestEnvironment::jsonLoadFromString(server_ctx_json_); server_ctx_config_.reset(new ServerContextConfigImpl(*server_ctx_loader_)); manager_.reset(new ContextManagerImpl(runtime_)); server_ctx_ = manager_->createSslServerContext("", {}, stats_store_, *server_ctx_config_, true); - listener_ = dispatcher_->createSslListener( - connection_handler_, *server_ctx_, socket_, listener_callbacks_, stats_store_, - {.bind_to_port_ = true, - .use_proxy_proto_ = false, - .use_original_dst_ = false, - .per_connection_buffer_limit_bytes_ = read_buffer_limit}); + listener_ = dispatcher_->createListener(socket_, listener_callbacks_, true, false); client_ctx_loader_ = TestEnvironment::jsonLoadFromString(client_ctx_json_); client_ctx_config_.reset(new ClientContextConfigImpl(*client_ctx_loader_)); @@ -1844,8 +1868,15 @@ class SslReadBufferLimitTest : public SslCertsTest, void readBufferLimitTest(uint32_t read_buffer_limit, uint32_t expected_chunk_size, uint32_t write_size, uint32_t num_writes, bool reserve_write_space) { - initialize(read_buffer_limit); - + initialize(); + + EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher_->createServerConnection(std::move(socket), server_ctx_.get()); + new_connection->setBufferLimits(read_buffer_limit); + listener_callbacks_.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection_ = std::move(conn); @@ -1917,11 +1948,18 @@ class SslReadBufferLimitTest : public SslCertsTest, return new Buffer::WatermarkBuffer(below_low, above_high); })); - initialize(read_buffer_limit); + initialize(); EXPECT_CALL(client_callbacks_, onEvent(Network::ConnectionEvent::Connected)) .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { dispatcher_->exit(); })); + EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher_->createServerConnection(std::move(socket), server_ctx_.get()); + new_connection->setBufferLimits(read_buffer_limit); + listener_callbacks_.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection_ = std::move(conn); @@ -2030,8 +2068,15 @@ TEST_P(SslReadBufferLimitTest, TestBind) { new Network::Address::Ipv6Instance(address_string, 0)}; } - initialize(0); + initialize(); + EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + Network::ConnectionPtr new_connection = + dispatcher_->createServerConnection(std::move(socket), server_ctx_.get()); + new_connection->setBufferLimits(0); + listener_callbacks_.onNewConnection(std::move(new_connection)); + })); EXPECT_CALL(listener_callbacks_, onNewConnection_(_)) .WillOnce(Invoke([&](Network::ConnectionPtr& conn) -> void { server_connection_ = std::move(conn); diff --git a/test/common/upstream/original_dst_cluster_test.cc b/test/common/upstream/original_dst_cluster_test.cc index 22b24fb80a6f2..f676de9b640ea 100644 --- a/test/common/upstream/original_dst_cluster_test.cc +++ b/test/common/upstream/original_dst_cluster_test.cc @@ -154,7 +154,7 @@ TEST_F(OriginalDstClusterTest, NoContext) { NiceMock connection; TestLoadBalancerContext lb_context(&connection); - EXPECT_CALL(connection, usingOriginalDst()).WillOnce(Return(false)); + EXPECT_CALL(connection, localAddressRestored()).WillOnce(Return(false)); // First argument is normally the reference to the ThreadLocalCluster's HostSet, but in these // tests we do not have the thread local clusters, so we pass a reference to the HostSet of the // primary cluster. The implementation handles both cases the same. @@ -169,7 +169,7 @@ TEST_F(OriginalDstClusterTest, NoContext) { NiceMock connection; TestLoadBalancerContext lb_context(&connection); connection.local_address_ = std::make_shared("unix://foo"); - EXPECT_CALL(connection, usingOriginalDst()).WillRepeatedly(Return(true)); + EXPECT_CALL(connection, localAddressRestored()).WillRepeatedly(Return(true)); OriginalDstCluster::LoadBalancer lb(cluster_->prioritySet(), cluster_); EXPECT_CALL(dispatcher_, post(_)).Times(0); @@ -205,7 +205,7 @@ TEST_F(OriginalDstClusterTest, Membership) { NiceMock connection; TestLoadBalancerContext lb_context(&connection); connection.local_address_ = std::make_shared("10.10.11.11"); - EXPECT_CALL(connection, usingOriginalDst()).WillRepeatedly(Return(true)); + EXPECT_CALL(connection, localAddressRestored()).WillRepeatedly(Return(true)); OriginalDstCluster::LoadBalancer lb(cluster_->prioritySet(), cluster_); Event::PostCb post_cb; @@ -292,12 +292,12 @@ TEST_F(OriginalDstClusterTest, Membership2) { NiceMock connection1; TestLoadBalancerContext lb_context1(&connection1); connection1.local_address_ = std::make_shared("10.10.11.11"); - EXPECT_CALL(connection1, usingOriginalDst()).WillRepeatedly(Return(true)); + EXPECT_CALL(connection1, localAddressRestored()).WillRepeatedly(Return(true)); NiceMock connection2; TestLoadBalancerContext lb_context2(&connection2); connection2.local_address_ = std::make_shared("10.10.11.12"); - EXPECT_CALL(connection2, usingOriginalDst()).WillRepeatedly(Return(true)); + EXPECT_CALL(connection2, localAddressRestored()).WillRepeatedly(Return(true)); OriginalDstCluster::LoadBalancer lb(cluster_->prioritySet(), cluster_); @@ -382,7 +382,7 @@ TEST_F(OriginalDstClusterTest, Connection) { NiceMock connection; TestLoadBalancerContext lb_context(&connection); connection.local_address_ = std::make_shared("FD00::1"); - EXPECT_CALL(connection, usingOriginalDst()).WillRepeatedly(Return(true)); + EXPECT_CALL(connection, localAddressRestored()).WillRepeatedly(Return(true)); OriginalDstCluster::LoadBalancer lb(cluster_->prioritySet(), cluster_); Event::PostCb post_cb; @@ -431,7 +431,7 @@ TEST_F(OriginalDstClusterTest, MultipleClusters) { NiceMock connection; TestLoadBalancerContext lb_context(&connection); connection.local_address_ = std::make_shared("FD00::1"); - EXPECT_CALL(connection, usingOriginalDst()).WillRepeatedly(Return(true)); + EXPECT_CALL(connection, localAddressRestored()).WillRepeatedly(Return(true)); OriginalDstCluster::LoadBalancer lb1(cluster_->prioritySet(), cluster_); OriginalDstCluster::LoadBalancer lb2(second, cluster_); diff --git a/test/config_test/config_test.cc b/test/config_test/config_test.cc index 7de7607f71093..282ee33445a5a 100644 --- a/test/config_test/config_test.cc +++ b/test/config_test/config_test.cc @@ -49,12 +49,21 @@ class ConfigTest { return main_config.clusterManager(); })); ON_CALL(server_, listenerManager()).WillByDefault(ReturnRef(listener_manager_)); - ON_CALL(component_factory_, createFilterFactoryList(_, _)) + ON_CALL(component_factory_, createNetworkFilterFactoryList(_, _)) .WillByDefault(Invoke([&](const Protobuf::RepeatedPtrField& filters, Server::Configuration::FactoryContext& context) -> std::vector { - return Server::ProdListenerComponentFactory::createFilterFactoryList_(filters, context); + return Server::ProdListenerComponentFactory::createNetworkFilterFactoryList_(filters, + context); })); + ON_CALL(component_factory_, createListenerFilterFactoryList(_, _)) + .WillByDefault( + Invoke([&](const Protobuf::RepeatedPtrField& filters, + Server::Configuration::FactoryContext& context) + -> std::vector { + return Server::ProdListenerComponentFactory::createListenerFilterFactoryList_( + filters, context); + })); try { main_config.initialize(bootstrap, server_, *cluster_manager_factory_); diff --git a/test/integration/autonomous_upstream.cc b/test/integration/autonomous_upstream.cc index 6e86339e1ead3..c601a1642ad83 100644 --- a/test/integration/autonomous_upstream.cc +++ b/test/integration/autonomous_upstream.cc @@ -66,7 +66,7 @@ AutonomousUpstream::~AutonomousUpstream() { http_connections_.clear(); } -bool AutonomousUpstream::createFilterChain(Network::Connection& connection) { +bool AutonomousUpstream::createNetworkFilterChain(Network::Connection& connection) { AutonomousHttpConnectionPtr http_connection(new AutonomousHttpConnection( QueuedConnectionWrapperPtr{new QueuedConnectionWrapper(connection, true)}, stats_store_, http_type_)); @@ -75,4 +75,6 @@ bool AutonomousUpstream::createFilterChain(Network::Connection& connection) { return true; } +bool AutonomousUpstream::createListenerFilterChain(Network::ListenerFilterManager&) { return true; } + } // namespace Envoy diff --git a/test/integration/autonomous_upstream.h b/test/integration/autonomous_upstream.h index 592888cc12dd7..c6137070b40ef 100644 --- a/test/integration/autonomous_upstream.h +++ b/test/integration/autonomous_upstream.h @@ -48,7 +48,8 @@ class AutonomousUpstream : public FakeUpstream { Network::Address::IpVersion version) : FakeUpstream(port, type, version) {} ~AutonomousUpstream(); - bool createFilterChain(Network::Connection& connection) override; + bool createNetworkFilterChain(Network::Connection& connection) override; + bool createListenerFilterChain(Network::ListenerFilterManager& listener) override; private: std::vector http_connections_; diff --git a/test/integration/fake_upstream.cc b/test/integration/fake_upstream.cc index 28f428aba0aea..5b76d1d4d0cdf 100644 --- a/test/integration/fake_upstream.cc +++ b/test/integration/fake_upstream.cc @@ -263,7 +263,7 @@ FakeUpstream::FakeUpstream(Ssl::ServerContext* ssl_ctx, Network::ListenSocketPtr api_(new Api::Impl(std::chrono::milliseconds(10000))), dispatcher_(api_->allocateDispatcher()), handler_(new Server::ConnectionHandlerImpl(ENVOY_LOGGER(), *dispatcher_)), - allow_unexpected_disconnects_(false) { + allow_unexpected_disconnects_(false), listener_(*this) { thread_.reset(new Thread::Thread([this]() -> void { threadRoutine(); })); server_initialized_.waitReady(); } @@ -278,7 +278,7 @@ void FakeUpstream::cleanUp() { } } -bool FakeUpstream::createFilterChain(Network::Connection& connection) { +bool FakeUpstream::createNetworkFilterChain(Network::Connection& connection) { std::unique_lock lock(lock_); connection.readDisable(true); new_connections_.emplace_back( @@ -287,14 +287,10 @@ bool FakeUpstream::createFilterChain(Network::Connection& connection) { return true; } +bool FakeUpstream::createListenerFilterChain(Network::ListenerFilterManager&) { return true; } + void FakeUpstream::threadRoutine() { - if (ssl_ctx_) { - handler_->addSslListener(*this, *ssl_ctx_, *socket_, stats_store_, 0, - Network::ListenerOptions::listenerOptionsWithBindToPort()); - } else { - handler_->addListener(*this, *socket_, stats_store_, 0, - Network::ListenerOptions::listenerOptionsWithBindToPort()); - } + handler_->addListener(listener_); server_initialized_.setReady(); dispatcher_->run(Event::Dispatcher::RunType::Block); diff --git a/test/integration/fake_upstream.h b/test/integration/fake_upstream.h index 98af3ca579246..f7986604f5891 100644 --- a/test/integration/fake_upstream.h +++ b/test/integration/fake_upstream.h @@ -14,6 +14,7 @@ #include "envoy/network/connection_handler.h" #include "envoy/network/filter.h" #include "envoy/server/configuration.h" +#include "envoy/server/listener_manager.h" #include "common/buffer/buffer_impl.h" #include "common/buffer/zero_copy_input_stream_impl.h" @@ -296,7 +297,8 @@ class FakeUpstream : Logger::Loggable, public Network::Filt std::vector>& upstreams); // Network::FilterChainFactory - bool createFilterChain(Network::Connection& connection) override; + bool createNetworkFilterChain(Network::Connection& connection) override; + bool createListenerFilterChain(Network::ListenerFilterManager& listener) override; void set_allow_unexpected_disconnects(bool value) { allow_unexpected_disconnects_ = value; } protected: @@ -307,6 +309,27 @@ class FakeUpstream : Logger::Loggable, public Network::Filt private: FakeUpstream(Ssl::ServerContext* ssl_ctx, Network::ListenSocketPtr&& connection, FakeHttpConnection::Type type); + + class FakeListener : public Network::ListenerConfig { + public: + FakeListener(FakeUpstream& parent) : parent_(parent), name_("fake_upstream") {} + + private: + // Network::ListenerConfig + Network::FilterChainFactory& filterChainFactory() override { return parent_; } + Network::ListenSocket& socket() override { return *parent_.socket_; } + Ssl::ServerContext* defaultSslContext() override { return parent_.ssl_ctx_; } + bool bindToPort() override { return true; } + bool handOffRestoredDestinationConnections() const override { return false; } + uint32_t perConnectionBufferLimitBytes() override { return 0; } + Stats::Scope& listenerScope() override { return parent_.stats_store_; } + uint64_t listenerTag() const override { return 0; } + const std::string& name() const override { return name_; } + + FakeUpstream& parent_; + std::string name_; + }; + void threadRoutine(); Ssl::ServerContext* ssl_ctx_{}; @@ -322,5 +345,6 @@ class FakeUpstream : Logger::Loggable, public Network::Filt Network::ConnectionHandlerPtr handler_; std::list new_connections_; // Guarded by lock_ bool allow_unexpected_disconnects_; + FakeListener listener_; }; } // namespace Envoy diff --git a/test/mocks/event/mocks.h b/test/mocks/event/mocks.h index d27ae7e1d82ff..6998cff846d51 100644 --- a/test/mocks/event/mocks.h +++ b/test/mocks/event/mocks.h @@ -29,6 +29,11 @@ class MockDispatcher : public Dispatcher { MockDispatcher(); ~MockDispatcher(); + Network::ConnectionPtr createServerConnection(Network::ConnectionSocketPtr&& socket, + Ssl::Context* ssl_ctx) override { + return Network::ConnectionPtr{createServerConnection_(socket.get(), ssl_ctx)}; + } + Network::ClientConnectionPtr createClientConnection(Network::Address::InstanceConstSharedPtr address, Network::Address::InstanceConstSharedPtr source_address, @@ -46,20 +51,11 @@ class MockDispatcher : public Dispatcher { return Filesystem::WatcherPtr{createFilesystemWatcher_()}; } - Network::ListenerPtr createListener(Network::ConnectionHandler& conn_handler, - Network::ListenSocket& socket, Network::ListenerCallbacks& cb, - Stats::Scope& scope, - const Network::ListenerOptions& listener_options) override { - return Network::ListenerPtr{createListener_(conn_handler, socket, cb, scope, listener_options)}; - } - - Network::ListenerPtr - createSslListener(Network::ConnectionHandler& conn_handler, Ssl::ServerContext& ssl_ctx, - Network::ListenSocket& socket, Network::ListenerCallbacks& cb, - Stats::Scope& scope, - const Network::ListenerOptions& listener_options) override { + Network::ListenerPtr createListener(Network::ListenSocket& socket, Network::ListenerCallbacks& cb, + bool bind_to_port, + bool hand_off_restored_destination_connections) override { return Network::ListenerPtr{ - createSslListener_(conn_handler, ssl_ctx, socket, cb, scope, listener_options)}; + createListener_(socket, cb, bind_to_port, hand_off_restored_destination_connections)}; } TimerPtr createTimer(TimerCb cb) override { return TimerPtr{createTimer_(cb)}; } @@ -77,6 +73,8 @@ class MockDispatcher : public Dispatcher { // Event::Dispatcher MOCK_METHOD0(clearDeferredDeleteList, void()); + MOCK_METHOD2(createServerConnection_, + Network::Connection*(Network::ConnectionSocket* socket, Ssl::Context* ssl_ctx)); MOCK_METHOD3(createClientConnection_, Network::ClientConnection*(Network::Address::InstanceConstSharedPtr address, Network::Address::InstanceConstSharedPtr source_address, @@ -87,16 +85,10 @@ class MockDispatcher : public Dispatcher { MOCK_METHOD4(createFileEvent_, FileEvent*(int fd, FileReadyCb cb, FileTriggerType trigger, uint32_t events)); MOCK_METHOD0(createFilesystemWatcher_, Filesystem::Watcher*()); - MOCK_METHOD5(createListener_, - Network::Listener*(Network::ConnectionHandler& conn_handler, - Network::ListenSocket& socket, Network::ListenerCallbacks& cb, - Stats::Scope& scope, - const Network::ListenerOptions& listener_options)); - MOCK_METHOD6(createSslListener_, - Network::Listener*(Network::ConnectionHandler& conn_handler, - Ssl::ServerContext& ssl_ctx, Network::ListenSocket& socket, - Network::ListenerCallbacks& cb, Stats::Scope& scope, - const Network::ListenerOptions& listener_options)); + MOCK_METHOD4(createListener_, + Network::Listener*(Network::ListenSocket& socket, Network::ListenerCallbacks& cb, + bool bind_to_port, + bool hand_off_restored_destination_connections)); MOCK_METHOD1(createTimer_, Timer*(TimerCb cb)); MOCK_METHOD1(deferredDelete_, void(DeferredDeletablePtr& to_delete)); MOCK_METHOD0(exit, void()); diff --git a/test/mocks/network/BUILD b/test/mocks/network/BUILD index dbcfac52c0eaf..5cc4100695b20 100644 --- a/test/mocks/network/BUILD +++ b/test/mocks/network/BUILD @@ -18,6 +18,7 @@ envoy_cc_mock( "//include/envoy/network:drain_decision_interface", "//include/envoy/network:filter_interface", "//include/envoy/network:transport_socket_interface", + "//include/envoy/server:listener_manager_interface", "//source/common/network:address_lib", "//source/common/network:utility_lib", "//test/mocks/event:event_mocks", diff --git a/test/mocks/network/mocks.cc b/test/mocks/network/mocks.cc index 69de7c92de05d..dad3a38bdbc03 100644 --- a/test/mocks/network/mocks.cc +++ b/test/mocks/network/mocks.cc @@ -3,6 +3,7 @@ #include #include "envoy/buffer/buffer.h" +#include "envoy/server/listener_manager.h" #include "common/network/address_impl.h" #include "common/network/utility.h" @@ -150,7 +151,18 @@ MockListenerCallbacks::~MockListenerCallbacks() {} MockDrainDecision::MockDrainDecision() {} MockDrainDecision::~MockDrainDecision() {} -MockFilterChainFactory::MockFilterChainFactory() {} +MockListenerFilter::MockListenerFilter() {} +MockListenerFilter::~MockListenerFilter() {} + +MockListenerFilterCallbacks::MockListenerFilterCallbacks() {} +MockListenerFilterCallbacks::~MockListenerFilterCallbacks() {} + +MockListenerFilterManager::MockListenerFilterManager() {} +MockListenerFilterManager::~MockListenerFilterManager() {} + +MockFilterChainFactory::MockFilterChainFactory() { + ON_CALL(*this, createListenerFilterChain(_)).WillByDefault(Return(true)); +} MockFilterChainFactory::~MockFilterChainFactory() {} MockListenSocket::MockListenSocket() : local_address_(new Address::Ipv4Instance(80)) { @@ -159,6 +171,12 @@ MockListenSocket::MockListenSocket() : local_address_(new Address::Ipv4Instance( MockListenSocket::~MockListenSocket() {} +MockConnectionSocket::MockConnectionSocket() : local_address_(new Address::Ipv4Instance(80)) { + ON_CALL(*this, localAddress()).WillByDefault(ReturnRef(local_address_)); +} + +MockConnectionSocket::~MockConnectionSocket() {} + MockListener::MockListener() {} MockListener::~MockListener() { onDestroy(); } diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index 1b566fd2800ff..e6e178f8919b2 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -78,7 +78,7 @@ class MockConnection : public Connection, public MockConnectionBase { MOCK_METHOD1(write, void(Buffer::Instance& data)); MOCK_METHOD1(setBufferLimits, void(uint32_t limit)); MOCK_CONST_METHOD0(bufferLimit, uint32_t()); - MOCK_CONST_METHOD0(usingOriginalDst, bool()); + MOCK_CONST_METHOD0(localAddressRestored, bool()); MOCK_CONST_METHOD0(aboveHighWatermark, bool()); }; @@ -115,7 +115,7 @@ class MockClientConnection : public ClientConnection, public MockConnectionBase MOCK_METHOD1(write, void(Buffer::Instance& data)); MOCK_METHOD1(setBufferLimits, void(uint32_t limit)); MOCK_CONST_METHOD0(bufferLimit, uint32_t()); - MOCK_CONST_METHOD0(usingOriginalDst, bool()); + MOCK_CONST_METHOD0(localAddressRestored, bool()); MOCK_CONST_METHOD0(aboveHighWatermark, bool()); // Network::ClientConnection @@ -195,8 +195,12 @@ class MockListenerCallbacks : public ListenerCallbacks { MockListenerCallbacks(); ~MockListenerCallbacks(); + void onAccept(ConnectionSocketPtr&& socket, bool redirected) override { + onAccept_(socket, redirected); + } void onNewConnection(ConnectionPtr&& conn) override { onNewConnection_(conn); } + MOCK_METHOD2(onAccept_, void(ConnectionSocketPtr& socket, bool redirected)); MOCK_METHOD1(onNewConnection_, void(ConnectionPtr& conn)); }; @@ -208,12 +212,41 @@ class MockDrainDecision : public DrainDecision { MOCK_CONST_METHOD0(drainClose, bool()); }; +class MockListenerFilter : public Network::ListenerFilter { +public: + MockListenerFilter(); + ~MockListenerFilter(); + + MOCK_METHOD1(onAccept, Network::FilterStatus(Network::ListenerFilterCallbacks&)); +}; + +class MockListenerFilterCallbacks : public ListenerFilterCallbacks { +public: + MockListenerFilterCallbacks(); + ~MockListenerFilterCallbacks(); + + MOCK_METHOD0(socket, ConnectionSocket&()); + MOCK_METHOD0(dispatcher, Event::Dispatcher&()); + MOCK_METHOD1(continueFilterChain, void(bool)); +}; + +class MockListenerFilterManager : public ListenerFilterManager { +public: + MockListenerFilterManager(); + ~MockListenerFilterManager(); + + void addAcceptFilter(Network::ListenerFilterPtr&& filter) override { addAcceptFilter_(filter); } + + MOCK_METHOD1(addAcceptFilter_, void(Network::ListenerFilterPtr&)); +}; + class MockFilterChainFactory : public FilterChainFactory { public: MockFilterChainFactory(); ~MockFilterChainFactory(); - MOCK_METHOD1(createFilterChain, bool(Connection& connection)); + MOCK_METHOD1(createNetworkFilterChain, bool(Connection& connection)); + MOCK_METHOD1(createListenerFilterChain, bool(ListenerFilterManager& listener)); }; class MockListenSocket : public ListenSocket { @@ -228,6 +261,22 @@ class MockListenSocket : public ListenSocket { Address::InstanceConstSharedPtr local_address_; }; +class MockConnectionSocket : public ConnectionSocket { +public: + MockConnectionSocket(); + ~MockConnectionSocket(); + + MOCK_CONST_METHOD0(localAddress, const Address::InstanceConstSharedPtr&()); + MOCK_METHOD2(setLocalAddress, void(const Address::InstanceConstSharedPtr&, bool)); + MOCK_CONST_METHOD0(localAddressRestored, bool()); + MOCK_CONST_METHOD0(remoteAddress, const Address::InstanceConstSharedPtr&()); + MOCK_METHOD1(setRemoteAddress, void(const Address::InstanceConstSharedPtr&)); + MOCK_CONST_METHOD0(fd, int()); + MOCK_METHOD0(close, void()); + + Address::InstanceConstSharedPtr local_address_; +}; + class MockListenerConfig : public ListenerConfig { public: MockListenerConfig(); @@ -236,12 +285,11 @@ class MockListenerConfig : public ListenerConfig { MOCK_METHOD0(filterChainFactory, FilterChainFactory&()); MOCK_METHOD0(socket, ListenSocket&()); MOCK_METHOD0(defaultSslContext, Ssl::ServerContext*()); - MOCK_METHOD0(useProxyProto, bool()); MOCK_METHOD0(bindToPort, bool()); - MOCK_METHOD0(useOriginalDst, bool()); + MOCK_CONST_METHOD0(handOffRestoredDestinationConnections, bool()); MOCK_METHOD0(perConnectionBufferLimitBytes, uint32_t()); MOCK_METHOD0(listenerScope, Stats::Scope&()); - MOCK_METHOD0(listenerTag, uint64_t()); + MOCK_CONST_METHOD0(listenerTag, uint64_t()); MOCK_CONST_METHOD0(name, const std::string&()); testing::NiceMock filter_chain_factory_; @@ -264,14 +312,7 @@ class MockConnectionHandler : public ConnectionHandler { ~MockConnectionHandler(); MOCK_METHOD0(numConnections, uint64_t()); - MOCK_METHOD5(addListener, - void(Network::FilterChainFactory& factory, Network::ListenSocket& socket, - Stats::Scope& scope, uint64_t listener_tag, - const Network::ListenerOptions& listener_options)); - MOCK_METHOD6(addSslListener, - void(Network::FilterChainFactory& factory, Ssl::ServerContext& ssl_ctx, - Network::ListenSocket& socket, Stats::Scope& scope, uint64_t listener_tag, - const Network::ListenerOptions& listener_options)); + MOCK_METHOD1(addListener, void(ListenerConfig& config)); MOCK_METHOD1(findListenerByAddress, Network::Listener*(const Network::Address::Instance& address)); MOCK_METHOD1(removeListeners, void(uint64_t listener_tag)); diff --git a/test/mocks/server/mocks.h b/test/mocks/server/mocks.h index 9ff13450b37f8..dfeb1d5deb71e 100644 --- a/test/mocks/server/mocks.h +++ b/test/mocks/server/mocks.h @@ -146,10 +146,14 @@ class MockListenerComponentFactory : public ListenerComponentFactory { return DrainManagerPtr{createDrainManager_(drain_type)}; } - MOCK_METHOD2(createFilterFactoryList, + MOCK_METHOD2(createNetworkFilterFactoryList, std::vector( const Protobuf::RepeatedPtrField& filters, Configuration::FactoryContext& context)); + MOCK_METHOD2(createListenerFilterFactoryList, + std::vector( + const Protobuf::RepeatedPtrField&, + Configuration::FactoryContext& context)); MOCK_METHOD2(createListenSocket, Network::ListenSocketSharedPtr(Network::Address::InstanceConstSharedPtr address, bool bind_to_port)); diff --git a/test/server/BUILD b/test/server/BUILD index 076e4609665cd..f2723537c02f4 100644 --- a/test/server/BUILD +++ b/test/server/BUILD @@ -123,7 +123,9 @@ envoy_cc_test( data = ["//test/common/ssl/test_data:certs"], deps = [ ":utility_lib", + "//source/common/network:listen_socket_lib", "//source/server:listener_manager_lib", + "//source/server/config/listener:original_dst_lib", "//source/server/config/network:http_connection_manager_lib", "//test/mocks/server:server_mocks", "//test/test_common:environment_lib", diff --git a/test/server/connection_handler_test.cc b/test/server/connection_handler_test.cc index b10dafe9ed83a..44710fc8c9bfe 100644 --- a/test/server/connection_handler_test.cc +++ b/test/server/connection_handler_test.cc @@ -15,6 +15,7 @@ using testing::InSequence; using testing::Invoke; using testing::NiceMock; using testing::Return; +using testing::ReturnRef; using testing::_; namespace Envoy { @@ -24,11 +25,51 @@ class ConnectionHandlerTest : public testing::Test, protected Logger::Loggable { + public: + TestListener(ConnectionHandlerTest& parent, uint64_t tag, bool bind_to_port, + bool hand_off_restored_destination_connections, const std::string& name) + : parent_(parent), tag_(tag), bind_to_port_(bind_to_port), + hand_off_restored_destination_connections_(hand_off_restored_destination_connections), + name_(name) {} + + Network::FilterChainFactory& filterChainFactory() override { return parent_.factory_; } + Network::ListenSocket& socket() override { return socket_; } + Ssl::ServerContext* defaultSslContext() override { return nullptr; } + bool bindToPort() override { return bind_to_port_; } + bool handOffRestoredDestinationConnections() const override { + return hand_off_restored_destination_connections_; + } + uint32_t perConnectionBufferLimitBytes() override { return 0; } + Stats::Scope& listenerScope() override { return parent_.stats_store_; } + uint64_t listenerTag() const override { return tag_; } + const std::string& name() const override { return name_; } + + ConnectionHandlerTest& parent_; + Network::MockListenSocket socket_; + uint64_t tag_; + bool bind_to_port_; + const bool hand_off_restored_destination_connections_; + const std::string name_; + }; + + typedef std::unique_ptr TestListenerPtr; + + TestListener* addListener(uint64_t tag, bool bind_to_port, + bool hand_off_restored_destination_connections, + const std::string& name) { + TestListener* listener = + new TestListener(*this, tag, bind_to_port, hand_off_restored_destination_connections, name); + listener->moveIntoListBack(TestListenerPtr{listener}, listeners_); + return listener; + } + Stats::IsolatedStoreImpl stats_store_; NiceMock dispatcher_; Network::ConnectionHandlerPtr handler_; - Network::MockFilterChainFactory factory_; - NiceMock socket_; + NiceMock factory_; + std::list listeners_; }; TEST_F(ConnectionHandlerTest, RemoveListener) { @@ -36,19 +77,19 @@ TEST_F(ConnectionHandlerTest, RemoveListener) { Network::MockListener* listener = new NiceMock(); Network::ListenerCallbacks* listener_callbacks; - EXPECT_CALL(dispatcher_, createListener_(_, _, _, _, _)) - .WillOnce(Invoke([&](Network::ConnectionHandler&, Network::ListenSocket&, - Network::ListenerCallbacks& cb, Stats::Scope&, - const Network::ListenerOptions&) -> Network::Listener* { + EXPECT_CALL(dispatcher_, createListener_(_, _, _, false)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks& cb, bool, + bool) -> Network::Listener* { listener_callbacks = &cb; return listener; })); - handler_->addListener(factory_, socket_, stats_store_, 1, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + TestListener* test_listener = addListener(1, true, false, "test_listener"); + EXPECT_CALL(test_listener->socket_, localAddress()); + handler_->addListener(*test_listener); Network::MockConnection* connection = new NiceMock(); - EXPECT_CALL(factory_, createFilterChain(_)).WillOnce(Return(true)); + EXPECT_CALL(factory_, createNetworkFilterChain(_)).WillOnce(Return(true)); listener_callbacks->onNewConnection(Network::ConnectionPtr{connection}); EXPECT_EQ(1UL, handler_->numConnections()); @@ -74,19 +115,19 @@ TEST_F(ConnectionHandlerTest, DestroyCloseConnections) { Network::MockListener* listener = new NiceMock(); Network::ListenerCallbacks* listener_callbacks; - EXPECT_CALL(dispatcher_, createListener_(_, _, _, _, _)) - .WillOnce(Invoke([&](Network::ConnectionHandler&, Network::ListenSocket&, - Network::ListenerCallbacks& cb, Stats::Scope&, - const Network::ListenerOptions&) -> Network::Listener* { + EXPECT_CALL(dispatcher_, createListener_(_, _, _, _)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks& cb, bool, + bool) -> Network::Listener* { listener_callbacks = &cb; return listener; })); - handler_->addListener(factory_, socket_, stats_store_, 1, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + TestListener* test_listener = addListener(1, true, false, "test_listener"); + EXPECT_CALL(test_listener->socket_, localAddress()); + handler_->addListener(*test_listener); Network::MockConnection* connection = new NiceMock(); - EXPECT_CALL(factory_, createFilterChain(_)).WillOnce(Return(true)); + EXPECT_CALL(factory_, createNetworkFilterChain(_)).WillOnce(Return(true)); listener_callbacks->onNewConnection(Network::ConnectionPtr{connection}); EXPECT_EQ(1UL, handler_->numConnections()); @@ -101,19 +142,19 @@ TEST_F(ConnectionHandlerTest, CloseDuringFilterChainCreate) { Network::MockListener* listener = new Network::MockListener(); Network::ListenerCallbacks* listener_callbacks; - EXPECT_CALL(dispatcher_, createListener_(_, _, _, _, _)) - .WillOnce(Invoke([&](Network::ConnectionHandler&, Network::ListenSocket&, - Network::ListenerCallbacks& cb, Stats::Scope&, - const Network::ListenerOptions&) -> Network::Listener* { + EXPECT_CALL(dispatcher_, createListener_(_, _, _, _)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks& cb, bool, + bool) -> Network::Listener* { listener_callbacks = &cb; return listener; })); - handler_->addListener(factory_, socket_, stats_store_, 1, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + TestListener* test_listener = addListener(1, true, false, "test_listener"); + EXPECT_CALL(test_listener->socket_, localAddress()); + handler_->addListener(*test_listener); Network::MockConnection* connection = new NiceMock(); - EXPECT_CALL(factory_, createFilterChain(_)); + EXPECT_CALL(factory_, createNetworkFilterChain(_)); EXPECT_CALL(*connection, state()).WillOnce(Return(Network::Connection::State::Closed)); EXPECT_CALL(*connection, addConnectionCallbacks(_)).Times(0); listener_callbacks->onNewConnection(Network::ConnectionPtr{connection}); @@ -127,19 +168,19 @@ TEST_F(ConnectionHandlerTest, CloseConnectionOnEmptyFilterChain) { Network::MockListener* listener = new Network::MockListener(); Network::ListenerCallbacks* listener_callbacks; - EXPECT_CALL(dispatcher_, createListener_(_, _, _, _, _)) - .WillOnce(Invoke([&](Network::ConnectionHandler&, Network::ListenSocket&, - Network::ListenerCallbacks& cb, Stats::Scope&, - const Network::ListenerOptions&) -> Network::Listener* { + EXPECT_CALL(dispatcher_, createListener_(_, _, _, _)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks& cb, bool, + bool) -> Network::Listener* { listener_callbacks = &cb; return listener; })); - handler_->addListener(factory_, socket_, stats_store_, 1, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + TestListener* test_listener = addListener(1, true, false, "test_listener"); + EXPECT_CALL(test_listener->socket_, localAddress()); + handler_->addListener(*test_listener); Network::MockConnection* connection = new NiceMock(); - EXPECT_CALL(factory_, createFilterChain(_)).WillOnce(Return(false)); + EXPECT_CALL(factory_, createNetworkFilterChain(_)).WillOnce(Return(false)); EXPECT_CALL(*connection, close(Network::ConnectionCloseType::NoFlush)); listener_callbacks->onNewConnection(Network::ConnectionPtr{connection}); EXPECT_EQ(0UL, handler_->numConnections()); @@ -148,43 +189,31 @@ TEST_F(ConnectionHandlerTest, CloseConnectionOnEmptyFilterChain) { } TEST_F(ConnectionHandlerTest, FindListenerByAddress) { + TestListener* test_listener1 = addListener(1, true, true, "test_listener1"); Network::Address::InstanceConstSharedPtr alt_address( new Network::Address::Ipv4Instance("127.0.0.1", 10001)); - EXPECT_CALL(socket_, localAddress()).WillRepeatedly(Return(alt_address)); Network::MockListener* listener = new Network::MockListener(); - Network::ListenerCallbacks* listener_callbacks; - EXPECT_CALL(dispatcher_, createListener_(_, _, _, _, _)) - .WillOnce(Invoke([&](Network::ConnectionHandler&, Network::ListenSocket&, - Network::ListenerCallbacks& cb, Stats::Scope&, - const Network::ListenerOptions&) -> Network::Listener* { - listener_callbacks = &cb; - return listener; - - })); - handler_->addListener(factory_, socket_, stats_store_, 1, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + EXPECT_CALL(dispatcher_, createListener_(_, _, _, true)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks&, bool, + bool) -> Network::Listener* { return listener; })); + EXPECT_CALL(test_listener1->socket_, localAddress()).WillRepeatedly(Return(alt_address)); + handler_->addListener(*test_listener1); EXPECT_EQ(listener, handler_->findListenerByAddress(ByRef(*alt_address))); - Network::MockListenSocket socket2; + TestListener* test_listener2 = addListener(2, true, false, "test_listener2"); Network::Address::InstanceConstSharedPtr alt_address2( new Network::Address::Ipv4Instance("0.0.0.0", 10001)); Network::Address::InstanceConstSharedPtr alt_address3( new Network::Address::Ipv4Instance("127.0.0.2", 10001)); - EXPECT_CALL(socket2, localAddress()).WillRepeatedly(Return(alt_address2)); Network::MockListener* listener2 = new Network::MockListener(); - EXPECT_CALL(dispatcher_, createListener_(_, _, _, _, _)) - .WillOnce(Invoke([&](Network::ConnectionHandler&, Network::ListenSocket&, - Network::ListenerCallbacks& cb, Stats::Scope&, - const Network::ListenerOptions&) -> Network::Listener* { - listener_callbacks = &cb; - return listener2; - - })); - handler_->addListener(factory_, socket2, stats_store_, 2, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + EXPECT_CALL(dispatcher_, createListener_(_, _, _, false)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks&, bool, + bool) -> Network::Listener* { return listener2; })); + EXPECT_CALL(test_listener2->socket_, localAddress()).WillRepeatedly(Return(alt_address2)); + handler_->addListener(*test_listener2); EXPECT_EQ(listener, handler_->findListenerByAddress(ByRef(*alt_address))); EXPECT_EQ(listener2, handler_->findListenerByAddress(ByRef(*alt_address2))); @@ -198,16 +227,10 @@ TEST_F(ConnectionHandlerTest, FindListenerByAddress) { handler_->stopListeners(2); Network::MockListener* listener3 = new Network::MockListener(); - EXPECT_CALL(dispatcher_, createListener_(_, _, _, _, _)) - .WillOnce(Invoke([&](Network::ConnectionHandler&, Network::ListenSocket&, - Network::ListenerCallbacks& cb, Stats::Scope&, - const Network::ListenerOptions&) -> Network::Listener* { - listener_callbacks = &cb; - return listener3; - - })); - handler_->addListener(factory_, socket2, stats_store_, 2, - Network::ListenerOptions::listenerOptionsWithBindToPort()); + EXPECT_CALL(dispatcher_, createListener_(_, _, _, _)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks&, bool, + bool) -> Network::Listener* { return listener3; })); + handler_->addListener(*test_listener2); EXPECT_EQ(listener3, handler_->findListenerByAddress(ByRef(*alt_address2))); EXPECT_EQ(listener3, handler_->findListenerByAddress(ByRef(*alt_address3))); @@ -215,5 +238,207 @@ TEST_F(ConnectionHandlerTest, FindListenerByAddress) { EXPECT_CALL(*listener3, onDestroy()); } +TEST_F(ConnectionHandlerTest, NormalRedirect) { + TestListener* test_listener1 = addListener(1, true, true, "test_listener1"); + Network::MockListener* listener1 = new Network::MockListener(); + Network::ListenerCallbacks* listener_callbacks1; + EXPECT_CALL(dispatcher_, createListener_(_, _, _, true)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks& cb, bool, + bool) -> Network::Listener* { + listener_callbacks1 = &cb; + return listener1; + })); + Network::Address::InstanceConstSharedPtr normal_address( + new Network::Address::Ipv4Instance("127.0.0.1", 10001)); + EXPECT_CALL(test_listener1->socket_, localAddress()).WillRepeatedly(Return(normal_address)); + handler_->addListener(*test_listener1); + + TestListener* test_listener2 = addListener(1, false, false, "test_listener2"); + Network::MockListener* listener2 = new Network::MockListener(); + Network::ListenerCallbacks* listener_callbacks2; + EXPECT_CALL(dispatcher_, createListener_(_, _, _, false)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks& cb, bool, + bool) -> Network::Listener* { + listener_callbacks2 = &cb; + return listener2; + })); + Network::Address::InstanceConstSharedPtr alt_address( + new Network::Address::Ipv4Instance("127.0.0.2", 20002)); + EXPECT_CALL(test_listener2->socket_, localAddress()).WillRepeatedly(Return(alt_address)); + handler_->addListener(*test_listener2); + + Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(); + Network::MockConnectionSocket* accepted_socket = new NiceMock(); + bool redirected = false; + EXPECT_CALL(factory_, createListenerFilterChain(_)) + .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { + // Insert the Mock filter. + if (!redirected) { + manager.addAcceptFilter(Network::ListenerFilterPtr{test_filter}); + redirected = true; + } + return true; + })); + EXPECT_CALL(*test_filter, onAccept(_)) + .WillOnce(Invoke([&](Network::ListenerFilterCallbacks& cb) -> Network::FilterStatus { + cb.socket().setLocalAddress(alt_address, true); + return Network::FilterStatus::Continue; + })); + EXPECT_CALL(*accepted_socket, setLocalAddress(alt_address, true)); + EXPECT_CALL(*accepted_socket, localAddressRestored()).WillOnce(Return(true)); + EXPECT_CALL(*accepted_socket, localAddress()).WillRepeatedly(ReturnRef(alt_address)); + Network::MockConnection* connection = new NiceMock(); + EXPECT_CALL(factory_, createNetworkFilterChain(_)).WillOnce(Return(true)); + EXPECT_CALL(dispatcher_, createServerConnection_(_, _)).WillOnce(Return(connection)); + listener_callbacks1->onAccept(Network::ConnectionSocketPtr{accepted_socket}, true); + EXPECT_EQ(1UL, handler_->numConnections()); + + EXPECT_CALL(*listener2, onDestroy()); + EXPECT_CALL(*listener1, onDestroy()); +} + +TEST_F(ConnectionHandlerTest, FallbackToWildcardListener) { + TestListener* test_listener1 = addListener(1, true, true, "test_listener1"); + Network::MockListener* listener1 = new Network::MockListener(); + Network::ListenerCallbacks* listener_callbacks1; + EXPECT_CALL(dispatcher_, createListener_(_, _, _, true)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks& cb, bool, + bool) -> Network::Listener* { + listener_callbacks1 = &cb; + return listener1; + })); + Network::Address::InstanceConstSharedPtr normal_address( + new Network::Address::Ipv4Instance("127.0.0.1", 10001)); + EXPECT_CALL(test_listener1->socket_, localAddress()).WillRepeatedly(Return(normal_address)); + handler_->addListener(*test_listener1); + + TestListener* test_listener2 = addListener(1, false, false, "test_listener2"); + Network::MockListener* listener2 = new Network::MockListener(); + Network::ListenerCallbacks* listener_callbacks2; + EXPECT_CALL(dispatcher_, createListener_(_, _, _, false)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks& cb, bool, + bool) -> Network::Listener* { + listener_callbacks2 = &cb; + return listener2; + })); + Network::Address::InstanceConstSharedPtr any_address = Network::Utility::getIpv4AnyAddress(); + EXPECT_CALL(test_listener2->socket_, localAddress()).WillRepeatedly(Return(any_address)); + handler_->addListener(*test_listener2); + + Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(); + Network::MockConnectionSocket* accepted_socket = new NiceMock(); + bool redirected = false; + EXPECT_CALL(factory_, createListenerFilterChain(_)) + .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { + // Insert the Mock filter. + if (!redirected) { + manager.addAcceptFilter(Network::ListenerFilterPtr{test_filter}); + redirected = true; + } + return true; + })); + // Zero port to match the port of AnyAddress + Network::Address::InstanceConstSharedPtr alt_address( + new Network::Address::Ipv4Instance("127.0.0.2", 0)); + EXPECT_CALL(*test_filter, onAccept(_)) + .WillOnce(Invoke([&](Network::ListenerFilterCallbacks& cb) -> Network::FilterStatus { + cb.socket().setLocalAddress(alt_address, true); + return Network::FilterStatus::Continue; + })); + EXPECT_CALL(*accepted_socket, setLocalAddress(alt_address, true)); + EXPECT_CALL(*accepted_socket, localAddressRestored()).WillOnce(Return(true)); + EXPECT_CALL(*accepted_socket, localAddress()).WillRepeatedly(ReturnRef(alt_address)); + Network::MockConnection* connection = new NiceMock(); + EXPECT_CALL(factory_, createNetworkFilterChain(_)).WillOnce(Return(true)); + EXPECT_CALL(dispatcher_, createServerConnection_(_, _)).WillOnce(Return(connection)); + listener_callbacks1->onAccept(Network::ConnectionSocketPtr{accepted_socket}, true); + EXPECT_EQ(1UL, handler_->numConnections()); + + EXPECT_CALL(*listener2, onDestroy()); + EXPECT_CALL(*listener1, onDestroy()); +} + +TEST_F(ConnectionHandlerTest, WildcardListenerWithOriginalDst) { + TestListener* test_listener1 = addListener(1, true, true, "test_listener1"); + Network::MockListener* listener1 = new Network::MockListener(); + Network::ListenerCallbacks* listener_callbacks1; + EXPECT_CALL(dispatcher_, createListener_(_, _, _, true)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks& cb, bool, + bool) -> Network::Listener* { + listener_callbacks1 = &cb; + return listener1; + })); + Network::Address::InstanceConstSharedPtr normal_address( + new Network::Address::Ipv4Instance("127.0.0.1", 80)); + // Original dst address nor port number match that of the listener's address. + Network::Address::InstanceConstSharedPtr original_dst_address( + new Network::Address::Ipv4Instance("127.0.0.2", 8080)); + Network::Address::InstanceConstSharedPtr any_address = Network::Utility::getAddressWithPort( + *Network::Utility::getIpv4AnyAddress(), normal_address->ip()->port()); + EXPECT_CALL(test_listener1->socket_, localAddress()).WillRepeatedly(Return(any_address)); + handler_->addListener(*test_listener1); + + Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(); + Network::MockConnectionSocket* accepted_socket = new NiceMock(); + EXPECT_CALL(factory_, createListenerFilterChain(_)) + .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { + // Insert the Mock filter. + manager.addAcceptFilter(Network::ListenerFilterPtr{test_filter}); + return true; + })); + EXPECT_CALL(*test_filter, onAccept(_)) + .WillOnce(Invoke([&](Network::ListenerFilterCallbacks& cb) -> Network::FilterStatus { + cb.socket().setLocalAddress(original_dst_address, true); + return Network::FilterStatus::Continue; + })); + EXPECT_CALL(*accepted_socket, setLocalAddress(original_dst_address, true)); + EXPECT_CALL(*accepted_socket, localAddressRestored()).WillOnce(Return(true)); + EXPECT_CALL(*accepted_socket, localAddress()).WillRepeatedly(ReturnRef(original_dst_address)); + Network::MockConnection* connection = new NiceMock(); + EXPECT_CALL(factory_, createNetworkFilterChain(_)).WillOnce(Return(true)); + EXPECT_CALL(dispatcher_, createServerConnection_(_, _)).WillOnce(Return(connection)); + listener_callbacks1->onAccept(Network::ConnectionSocketPtr{accepted_socket}, true); + EXPECT_EQ(1UL, handler_->numConnections()); + + EXPECT_CALL(*listener1, onDestroy()); +} + +TEST_F(ConnectionHandlerTest, WildcardListenerWithNoOriginalDst) { + TestListener* test_listener1 = addListener(1, true, true, "test_listener1"); + Network::MockListener* listener1 = new Network::MockListener(); + Network::ListenerCallbacks* listener_callbacks1; + EXPECT_CALL(dispatcher_, createListener_(_, _, _, true)) + .WillOnce(Invoke([&](Network::ListenSocket&, Network::ListenerCallbacks& cb, bool, + bool) -> Network::Listener* { + listener_callbacks1 = &cb; + return listener1; + })); + Network::Address::InstanceConstSharedPtr normal_address( + new Network::Address::Ipv4Instance("127.0.0.1", 80)); + Network::Address::InstanceConstSharedPtr any_address = Network::Utility::getAddressWithPort( + *Network::Utility::getIpv4AnyAddress(), normal_address->ip()->port()); + EXPECT_CALL(test_listener1->socket_, localAddress()).WillRepeatedly(Return(any_address)); + handler_->addListener(*test_listener1); + + Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(); + Network::MockConnectionSocket* accepted_socket = new NiceMock(); + EXPECT_CALL(factory_, createListenerFilterChain(_)) + .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { + // Insert the Mock filter. + manager.addAcceptFilter(Network::ListenerFilterPtr{test_filter}); + return true; + })); + EXPECT_CALL(*test_filter, onAccept(_)).WillOnce(Return(Network::FilterStatus::Continue)); + EXPECT_CALL(*accepted_socket, localAddressRestored()).WillOnce(Return(false)); + EXPECT_CALL(*accepted_socket, localAddress()).WillRepeatedly(ReturnRef(normal_address)); + Network::MockConnection* connection = new NiceMock(); + EXPECT_CALL(factory_, createNetworkFilterChain(_)).WillOnce(Return(true)); + EXPECT_CALL(dispatcher_, createServerConnection_(_, _)).WillOnce(Return(connection)); + listener_callbacks1->onAccept(Network::ConnectionSocketPtr{accepted_socket}, true); + EXPECT_EQ(1UL, handler_->numConnections()); + + EXPECT_CALL(*listener1, onDestroy()); +} + } // namespace Server } // namespace Envoy diff --git a/test/server/http/admin_test.cc b/test/server/http/admin_test.cc index a3a059b0ee2b6..166c3695bfd4f 100644 --- a/test/server/http/admin_test.cc +++ b/test/server/http/admin_test.cc @@ -32,7 +32,8 @@ class AdminFilterTest : public testing::TestWithParamdrain_manager_)); - EXPECT_CALL(listener_factory_, createFilterFactoryList(_, _)) + EXPECT_CALL(listener_factory_, createNetworkFilterFactoryList(_, _)) .WillOnce(Invoke( [raw_listener, need_init](const Protobuf::RepeatedPtrField&, Configuration::FactoryContext& context) @@ -95,12 +98,20 @@ class ListenerManagerImplWithRealFiltersTest : public ListenerManagerImplTest { public: ListenerManagerImplWithRealFiltersTest() { // Use real filter loading by default. - ON_CALL(listener_factory_, createFilterFactoryList(_, _)) + ON_CALL(listener_factory_, createNetworkFilterFactoryList(_, _)) .WillByDefault(Invoke([](const Protobuf::RepeatedPtrField& filters, Configuration::FactoryContext& context) -> std::vector { - return ProdListenerComponentFactory::createFilterFactoryList_(filters, context); + return ProdListenerComponentFactory::createNetworkFilterFactoryList_(filters, context); })); + ON_CALL(listener_factory_, createListenerFilterFactoryList(_, _)) + .WillByDefault( + Invoke([](const Protobuf::RepeatedPtrField& filters, + Configuration::FactoryContext& context) + -> std::vector { + return ProdListenerComponentFactory::createListenerFilterFactoryList_(filters, + context); + })); } }; @@ -1170,5 +1181,119 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, TlsCertificateInvalidTrustedCA) { EnvoyException, "Failed to load trusted CA certificates from "); } +TEST_F(ListenerManagerImplWithRealFiltersTest, OriginalDstFilter) { + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: 127.0.0.1, port_value: 1111 } + filter_chains: {} + listener_filters: + - name: "envoy.listener.original_dst" + config: {} + )EOF", // " + Network::Address::IpVersion::v4); + + EXPECT_CALL(server_.random_, uuid()); + EXPECT_CALL(listener_factory_, createListenSocket(_, true)); + manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), true); + EXPECT_EQ(1U, manager_->listeners().size()); + + Network::ListenerConfig& listener = manager_->listeners().back().get(); + + Network::FilterChainFactory& filterChainFactory = listener.filterChainFactory(); + Network::MockListenerFilterManager manager; + + NiceMock callbacks; + Network::AcceptedSocketImpl socket(-1, + Network::Address::InstanceConstSharedPtr{ + new Network::Address::Ipv4Instance("127.0.0.1", 1234)}, + Network::Address::InstanceConstSharedPtr{ + new Network::Address::Ipv4Instance("127.0.0.1", 5678)}); + + EXPECT_CALL(callbacks, socket()).WillOnce(Invoke([&]() -> Network::ConnectionSocket& { + return socket; + })); + + EXPECT_CALL(manager, addAcceptFilter_(_)) + .WillOnce(Invoke([&](Network::ListenerFilterPtr& filter) -> void { + EXPECT_EQ(Network::FilterStatus::Continue, filter->onAccept(callbacks)); + })); + + EXPECT_TRUE(filterChainFactory.createListenerFilterChain(manager)); +} + +class OriginalDstTest : public Filter::Listener::OriginalDst { + Network::Address::InstanceConstSharedPtr getOriginalDst(int) override { + return Network::Address::InstanceConstSharedPtr{ + new Network::Address::Ipv4Instance("127.0.0.2", 2345)}; + } +}; + +namespace Configuration { + +class OriginalDstTestConfigFactory : public NamedListenerFilterConfigFactory { +public: + // NamedListenerFilterConfigFactory + ListenerFilterFactoryCb createFilterFactoryFromProto(const Protobuf::Message&, + FactoryContext&) override { + return [](Network::ListenerFilterManager& filter_manager) -> void { + filter_manager.addAcceptFilter(std::make_unique()); + }; + } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return std::make_unique(); + } + + std::string name() override { return "test.listener.original_dst"; } +}; + +/** + * Static registration for the original dst filter. @see RegisterFactory. + */ +static Registry::RegisterFactory + registered_; + +} // namespace Configuration + +TEST_F(ListenerManagerImplWithRealFiltersTest, OriginalDstTestFilter) { + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: 127.0.0.1, port_value: 1111 } + filter_chains: {} + listener_filters: + - name: "test.listener.original_dst" + config: {} + )EOF", // " + Network::Address::IpVersion::v4); + + EXPECT_CALL(server_.random_, uuid()); + EXPECT_CALL(listener_factory_, createListenSocket(_, true)); + manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), true); + EXPECT_EQ(1U, manager_->listeners().size()); + + Network::ListenerConfig& listener = manager_->listeners().back().get(); + + Network::FilterChainFactory& filterChainFactory = listener.filterChainFactory(); + Network::MockListenerFilterManager manager; + + NiceMock callbacks; + Network::AcceptedSocketImpl socket( + -1, std::make_unique("127.0.0.1", 1234), + std::make_unique("127.0.0.1", 5678)); + + EXPECT_CALL(callbacks, socket()).WillOnce(Invoke([&]() -> Network::ConnectionSocket& { + return socket; + })); + + EXPECT_CALL(manager, addAcceptFilter_(_)) + .WillOnce(Invoke([&](Network::ListenerFilterPtr& filter) -> void { + EXPECT_EQ(Network::FilterStatus::Continue, filter->onAccept(callbacks)); + })); + + EXPECT_TRUE(filterChainFactory.createListenerFilterChain(manager)); + EXPECT_TRUE(socket.localAddressRestored()); + EXPECT_EQ("127.0.0.2:2345", socket.localAddress()->asString()); +} + } // namespace Server } // namespace Envoy diff --git a/test/server/worker_impl_test.cc b/test/server/worker_impl_test.cc index d44ba35837c72..6a73bf302cfff 100644 --- a/test/server/worker_impl_test.cc +++ b/test/server/worker_impl_test.cc @@ -10,6 +10,7 @@ #include "gtest/gtest.h" using testing::InSequence; +using testing::Invoke; using testing::InvokeWithoutArgs; using testing::NiceMock; using testing::Return; @@ -45,9 +46,10 @@ TEST_F(WorkerImplTest, BasicFlow) { // Before a worker is started adding a listener will be posted and will get added when the // thread starts running. NiceMock listener; - ON_CALL(listener, listenerTag()).WillByDefault(Return(1)); - EXPECT_CALL(*handler_, addListener(_, _, _, 1, _)) - .WillOnce(InvokeWithoutArgs([current_thread_id]() -> void { + ON_CALL(listener, listenerTag()).WillByDefault(Return(1UL)); + EXPECT_CALL(*handler_, addListener(_)) + .WillOnce(Invoke([current_thread_id](Network::ListenerConfig& config) -> void { + EXPECT_EQ(config.listenerTag(), 1UL); EXPECT_NE(current_thread_id, std::this_thread::get_id()); })); worker_.addListener(listener, [&ci](bool success) -> void { @@ -60,9 +62,10 @@ TEST_F(WorkerImplTest, BasicFlow) { // After a worker is started adding/stopping/removing a listener happens on the worker thread. NiceMock listener2; - ON_CALL(listener2, listenerTag()).WillByDefault(Return(2)); - EXPECT_CALL(*handler_, addListener(_, _, _, 2, _)) - .WillOnce(InvokeWithoutArgs([current_thread_id]() -> void { + ON_CALL(listener2, listenerTag()).WillByDefault(Return(2UL)); + EXPECT_CALL(*handler_, addListener(_)) + .WillOnce(Invoke([current_thread_id](Network::ListenerConfig& config) -> void { + EXPECT_EQ(config.listenerTag(), 2UL); EXPECT_NE(current_thread_id, std::this_thread::get_id()); })); worker_.addListener(listener2, [&ci](bool success) -> void { @@ -91,9 +94,10 @@ TEST_F(WorkerImplTest, BasicFlow) { // Now test adding and removing a listener without stopping it first. NiceMock listener3; - ON_CALL(listener3, listenerTag()).WillByDefault(Return(3)); - EXPECT_CALL(*handler_, addListener(_, _, _, 3, _)) - .WillOnce(InvokeWithoutArgs([current_thread_id]() -> void { + ON_CALL(listener3, listenerTag()).WillByDefault(Return(3UL)); + EXPECT_CALL(*handler_, addListener(_)) + .WillOnce(Invoke([current_thread_id](Network::ListenerConfig& config) -> void { + EXPECT_EQ(config.listenerTag(), 3UL); EXPECT_NE(current_thread_id, std::this_thread::get_id()); })); worker_.addListener(listener3, [&ci](bool success) -> void { @@ -117,8 +121,8 @@ TEST_F(WorkerImplTest, ListenerException) { InSequence s; NiceMock listener; - ON_CALL(listener, listenerTag()).WillByDefault(Return(1)); - EXPECT_CALL(*handler_, addListener(_, _, _, 1, _)) + ON_CALL(listener, listenerTag()).WillByDefault(Return(1UL)); + EXPECT_CALL(*handler_, addListener(_)) .WillOnce(Throw(Network::CreateListenerException("failed"))); worker_.addListener(listener, [](bool success) -> void { EXPECT_FALSE(success); });