diff --git a/api/envoy/api/v2/listener/listener.proto b/api/envoy/api/v2/listener/listener.proto index 2968e873df8bd..c4274733335af 100644 --- a/api/envoy/api/v2/listener/listener.proto +++ b/api/envoy/api/v2/listener/listener.proto @@ -55,10 +55,11 @@ message Filter { // // The following order applies: // -// [#comment:TODO(PiotrSikora): destination IP / ranges are going to be 1.] -// 1. Server name (e.g. SNI for TLS protocol), -// 2. Transport protocol. -// 3. Application protocols (e.g. ALPN for TLS protocol). +// 1. Destination port. +// 2. Destination IP address. +// 3. Server name (e.g. SNI for TLS protocol), +// 4. Transport protocol. +// 5. Application protocols (e.g. ALPN for TLS protocol). // // For criterias that allow ranges or wildcards, the most specific value in any // of the configured filter chains that matches the incoming connection is going @@ -71,9 +72,12 @@ message Filter { // // [#comment:TODO(PiotrSikora): Add support for configurable precedence of the rules] message FilterChainMatch { + // Optional destination port to consider when use_original_dst is set on the + // listener in determining a filter chain match. + google.protobuf.UInt32Value destination_port = 8 [(validate.rules).uint32 = {gte: 1, lte: 65535}]; + // If non-empty, an IP address and prefix length to match addresses when the // listener is bound to 0.0.0.0/:: or when use_original_dst is specified. - // [#not-implemented-hide:] repeated core.CidrRange prefix_ranges = 3; // If non-empty, an IP address and suffix length to match addresses when the @@ -97,11 +101,6 @@ message FilterChainMatch { // [#not-implemented-hide:] repeated google.protobuf.UInt32Value source_ports = 7; - // Optional destination port to consider when use_original_dst is set on the - // listener in determining a filter chain match. - // [#not-implemented-hide:] - google.protobuf.UInt32Value destination_port = 8; - // If non-empty, a list of server names (e.g. SNI for TLS protocol) to consider when determining // a filter chain match. Those values will be compared against the server names of a new // connection, when detected by one of the listener filters. diff --git a/docs/root/intro/version_history.rst b/docs/root/intro/version_history.rst index 2f74c99e65f71..f9217f55e99a7 100644 --- a/docs/root/intro/version_history.rst +++ b/docs/root/intro/version_history.rst @@ -19,6 +19,9 @@ Version history * proxy_protocol: added support for HAProxy Proxy Protocol v2 (AF_INET/AF_INET6 only). * http: added generic +:ref:`Upgrade support ` +* listeners: added the ability to match :ref:`FilterChain ` using + :ref:`destination_port ` and + :ref:`prefix_ranges `. * lua: added :ref:`connection() ` wrapper and *ssl()* API. * lua: added :ref:`requestInfo() ` wrapper and *protocol()* API. * ratelimit: added support for :repo:`api/envoy/service/ratelimit/v2/rls.proto`. diff --git a/source/server/BUILD b/source/server/BUILD index 6ff07ca41dd86..a85b98271e938 100644 --- a/source/server/BUILD +++ b/source/server/BUILD @@ -208,6 +208,8 @@ envoy_cc_library( "//source/common/api:os_sys_calls_lib", "//source/common/common:empty_string", "//source/common/config:utility_lib", + "//source/common/network:cidr_range_lib", + "//source/common/network:lc_trie_lib", "//source/common/network:listen_socket_lib", "//source/common/network:resolver_lib", "//source/common/network:socket_option_factory_lib", diff --git a/source/server/listener_manager_impl.cc b/source/server/listener_manager_impl.cc index f27081dac3743..84ecdac8f9111 100644 --- a/source/server/listener_manager_impl.cc +++ b/source/server/listener_manager_impl.cc @@ -202,6 +202,13 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st ProtobufTypes::MessagePtr message = Config::Utility::translateToFactoryConfig(transport_socket, config_factory); + // Validate IP addresses. + std::vector destination_ips; + for (const auto& destination_ip : filter_chain_match.prefix_ranges()) { + const auto& cidr_range = Network::Address::CidrRange::create(destination_ip); + destination_ips.push_back(cidr_range.asString()); + } + std::vector server_names; if (!filter_chain_match.server_names().empty()) { if (!filter_chain_match.sni_domains().empty()) { @@ -233,7 +240,9 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st filter_chain_match.application_protocols().begin(), filter_chain_match.application_protocols().end()); - addFilterChain(server_names, filter_chain_match.transport_protocol(), application_protocols, + addFilterChain(PROTOBUF_GET_WRAPPED_OR_DEFAULT(filter_chain_match, destination_port, 0), + destination_ips, server_names, filter_chain_match.transport_protocol(), + application_protocols, config_factory.createTransportSocketFactory(*message, *this, server_names), parent_.factory_.createNetworkFilterFactoryList(filter_chain.filters(), *this)); @@ -242,6 +251,9 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st (!server_names.empty() || !application_protocols.empty())); } + // Convert DestinationIPsMap to DestinationIPsTrie for faster lookups. + convertDestinationIPsMapToTrie(); + // Automatically inject TLS Inspector if it wasn't configured explicitly and it's needed. if (need_tls_inspector) { for (const auto& filter : config.listener_filters()) { @@ -274,33 +286,73 @@ ListenerImpl::~ListenerImpl() { // active. This is done here explicitly by setting a boolean and then clearing the factory // vector for clarity. initialize_canceled_ = true; - filter_chains_.clear(); + destination_ports_map_.clear(); } bool ListenerImpl::isWildcardServerName(const std::string& name) { return absl::StartsWith(name, "*."); } -void ListenerImpl::addFilterChain(const std::vector& server_names, +void ListenerImpl::addFilterChain(uint16_t destination_port, + const std::vector& destination_ips, + const std::vector& server_names, const std::string& transport_protocol, const std::vector& application_protocols, Network::TransportSocketFactoryPtr&& transport_socket_factory, std::vector filters_factory) { const auto filter_chain = std::make_shared(std::move(transport_socket_factory), std::move(filters_factory)); - // Save mappings. + addFilterChainForDestinationPorts(destination_ports_map_, destination_port, destination_ips, + server_names, transport_protocol, application_protocols, + filter_chain); +} + +void ListenerImpl::addFilterChainForDestinationPorts( + DestinationPortsMap& destination_ports_map, uint16_t destination_port, + const std::vector& destination_ips, const std::vector& server_names, + const std::string& transport_protocol, const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain) { + if (destination_ports_map.find(destination_port) == destination_ports_map.end()) { + destination_ports_map[destination_port] = + std::make_pair(DestinationIPsMap{}, nullptr); + } + addFilterChainForDestinationIPs(destination_ports_map[destination_port].first, destination_ips, + server_names, transport_protocol, application_protocols, + filter_chain); +} + +void ListenerImpl::addFilterChainForDestinationIPs( + DestinationIPsMap& destination_ips_map, const std::vector& destination_ips, + const std::vector& server_names, const std::string& transport_protocol, + const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain) { + if (destination_ips.empty()) { + addFilterChainForServerNames(destination_ips_map[EMPTY_STRING], server_names, + transport_protocol, application_protocols, filter_chain); + } else { + for (const auto& destination_ip : destination_ips) { + addFilterChainForServerNames(destination_ips_map[destination_ip], server_names, + transport_protocol, application_protocols, filter_chain); + } + } +} + +void ListenerImpl::addFilterChainForServerNames( + ServerNamesMap& server_names_map, const std::vector& server_names, + const std::string& transport_protocol, const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain) { if (server_names.empty()) { - addFilterChainForApplicationProtocols(filter_chains_[EMPTY_STRING][transport_protocol], + addFilterChainForApplicationProtocols(server_names_map[EMPTY_STRING][transport_protocol], application_protocols, filter_chain); } else { for (const auto& server_name : server_names) { if (isWildcardServerName(server_name)) { // Add mapping for the wildcard domain, i.e. ".example.com" for "*.example.com". addFilterChainForApplicationProtocols( - filter_chains_[server_name.substr(1)][transport_protocol], application_protocols, + server_names_map[server_name.substr(1)][transport_protocol], application_protocols, filter_chain); } else { - addFilterChainForApplicationProtocols(filter_chains_[server_name][transport_protocol], + addFilterChainForApplicationProtocols(server_names_map[server_name][transport_protocol], application_protocols, filter_chain); } } @@ -308,64 +360,129 @@ void ListenerImpl::addFilterChain(const std::vector& server_names, } void ListenerImpl::addFilterChainForApplicationProtocols( - std::unordered_map& transport_protocol_map, + ApplicationProtocolsMap& application_protocols_map, const std::vector& application_protocols, const Network::FilterChainSharedPtr& filter_chain) { if (application_protocols.empty()) { - transport_protocol_map[EMPTY_STRING] = filter_chain; + application_protocols_map[EMPTY_STRING] = filter_chain; } else { for (const auto& application_protocol : application_protocols) { - transport_protocol_map[application_protocol] = filter_chain; + application_protocols_map[application_protocol] = filter_chain; + } + } +} + +void ListenerImpl::convertDestinationIPsMapToTrie() { + for (auto& port : destination_ports_map_) { + auto& destination_ips_pair = port.second; + auto& destination_ips_map = destination_ips_pair.first; + std::vector>> list; + for (const auto& entry : destination_ips_map) { + std::vector subnets; + if (entry.first == EMPTY_STRING) { + list.push_back( + std::make_pair>( + std::make_shared(entry.second), + {Network::Address::CidrRange::create("0.0.0.0/0"), + Network::Address::CidrRange::create("::/0")})); + } else { + list.push_back( + std::make_pair>( + std::make_shared(entry.second), + {Network::Address::CidrRange::create(entry.first)})); + } } + destination_ips_pair.second = std::make_unique(list, true); } } const Network::FilterChain* ListenerImpl::findFilterChain(const Network::ConnectionSocket& socket) const { + const auto& address = socket.localAddress(); + + // Match on destination port (only for IP addresses). + if (address->type() == Network::Address::Type::Ip) { + const auto port_match = destination_ports_map_.find(address->ip()->port()); + if (port_match != destination_ports_map_.end()) { + return findFilterChainForDestinationIP(*port_match->second.second, socket); + } + } + + // Match on catch-all port 0. + const auto port_match = destination_ports_map_.find(0); + if (port_match != destination_ports_map_.end()) { + return findFilterChainForDestinationIP(*port_match->second.second, socket); + } + + return nullptr; +} + +const Network::FilterChain* +ListenerImpl::findFilterChainForDestinationIP(const DestinationIPsTrie& destination_ips_trie, + const Network::ConnectionSocket& socket) const { + // Use invalid IP address (matching only filter chains without IP requirements) for UDS. + static const auto& fake_address = Network::Utility::parseInternetAddress("255.255.255.255"); + + auto address = socket.localAddress(); + if (address->type() != Network::Address::Type::Ip) { + address = fake_address; + } + + // Match on both: exact IP and wider CIDR ranges using LcTrie. + const auto& data = destination_ips_trie.getData(address); + if (!data.empty()) { + ASSERT(data.size() == 1); + return findFilterChainForServerName(*data.back(), socket); + } + + return nullptr; +} + +const Network::FilterChain* +ListenerImpl::findFilterChainForServerName(const ServerNamesMap& server_names_map, + const Network::ConnectionSocket& socket) const { const std::string server_name(socket.requestedServerName()); // Match on exact server name, i.e. "www.example.com" for "www.example.com". - const auto server_name_exact_match = filter_chains_.find(server_name); - if (server_name_exact_match != filter_chains_.end()) { - return findFilterChainForServerName(server_name_exact_match->second, socket); + const auto server_name_exact_match = server_names_map.find(server_name); + if (server_name_exact_match != server_names_map.end()) { + return findFilterChainForTransportProtocol(server_name_exact_match->second, socket); } // Match on all wildcard domains, i.e. ".example.com" and ".com" for "www.example.com". size_t pos = server_name.find('.', 1); while (pos < server_name.size() - 1 && pos != std::string::npos) { const std::string wildcard = server_name.substr(pos); - const auto server_name_wildcard_match = filter_chains_.find(wildcard); - if (server_name_wildcard_match != filter_chains_.end()) { - return findFilterChainForServerName(server_name_wildcard_match->second, socket); + const auto server_name_wildcard_match = server_names_map.find(wildcard); + if (server_name_wildcard_match != server_names_map.end()) { + return findFilterChainForTransportProtocol(server_name_wildcard_match->second, socket); } pos = server_name.find('.', pos + 1); } // Match on a filter chain without server name requirements. - const auto server_name_catchall_match = filter_chains_.find(EMPTY_STRING); - if (server_name_catchall_match != filter_chains_.end()) { - return findFilterChainForServerName(server_name_catchall_match->second, socket); + const auto server_name_catchall_match = server_names_map.find(EMPTY_STRING); + if (server_name_catchall_match != server_names_map.end()) { + return findFilterChainForTransportProtocol(server_name_catchall_match->second, socket); } return nullptr; } -const Network::FilterChain* ListenerImpl::findFilterChainForServerName( - const std::unordered_map>& - server_name_match, +const Network::FilterChain* ListenerImpl::findFilterChainForTransportProtocol( + const TransportProtocolsMap& transport_protocols_map, const Network::ConnectionSocket& socket) const { const std::string transport_protocol(socket.detectedTransportProtocol()); // Match on exact transport protocol, e.g. "tls". - const auto transport_protocol_match = server_name_match.find(transport_protocol); - if (transport_protocol_match != server_name_match.end()) { + const auto transport_protocol_match = transport_protocols_map.find(transport_protocol); + if (transport_protocol_match != transport_protocols_map.end()) { return findFilterChainForApplicationProtocols(transport_protocol_match->second, socket); } // Match on a filter chain without transport protocol requirements. - const auto any_protocol_match = server_name_match.find(EMPTY_STRING); - if (any_protocol_match != server_name_match.end()) { + const auto any_protocol_match = transport_protocols_map.find(EMPTY_STRING); + if (any_protocol_match != transport_protocols_map.end()) { return findFilterChainForApplicationProtocols(any_protocol_match->second, socket); } @@ -373,19 +490,19 @@ const Network::FilterChain* ListenerImpl::findFilterChainForServerName( } const Network::FilterChain* ListenerImpl::findFilterChainForApplicationProtocols( - const std::unordered_map& transport_protocol_match, + const ApplicationProtocolsMap& application_protocols_map, const Network::ConnectionSocket& socket) const { // Match on exact application protocol, e.g. "h2" or "http/1.1". for (const auto& application_protocol : socket.requestedApplicationProtocols()) { - const auto application_protocol_match = transport_protocol_match.find(application_protocol); - if (application_protocol_match != transport_protocol_match.end()) { + const auto application_protocol_match = application_protocols_map.find(application_protocol); + if (application_protocol_match != application_protocols_map.end()) { return application_protocol_match->second.get(); } } // Match on a filter chain without application protocol requirements. - const auto any_protocol_match = transport_protocol_match.find(EMPTY_STRING); - if (any_protocol_match != transport_protocol_match.end()) { + const auto any_protocol_match = application_protocols_map.find(EMPTY_STRING); + if (any_protocol_match != application_protocols_map.end()) { return any_protocol_match->second.get(); } diff --git a/source/server/listener_manager_impl.h b/source/server/listener_manager_impl.h index 5349e4c319ce0..14dfadf30ab58 100644 --- a/source/server/listener_manager_impl.h +++ b/source/server/listener_manager_impl.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "envoy/api/v2/listener/listener.pb.h" #include "envoy/network/filter.h" #include "envoy/server/filter_config.h" @@ -9,6 +11,8 @@ #include "envoy/server/worker.h" #include "common/common/logger.h" +#include "common/network/cidr_range.h" +#include "common/network/lc_trie.h" #include "server/init_manager_impl.h" #include "server/lds_api.h" @@ -305,36 +309,67 @@ class ListenerImpl : public Network::ListenerConfig, SystemTime last_updated_; private: - void addFilterChain(const std::vector& server_names, + typedef std::unordered_map ApplicationProtocolsMap; + typedef std::unordered_map TransportProtocolsMap; + // Both exact server names and wildcard domains are part of the same map, in which wildcard + // domains are prefixed with "." (i.e. ".example.com" for "*.example.com") to differentiate + // between exact and wildcard entries. + typedef std::unordered_map ServerNamesMap; + typedef std::unordered_map DestinationIPsMap; + typedef std::shared_ptr ServerNamesMapSharedPtr; + typedef Network::LcTrie::LcTrie DestinationIPsTrie; + typedef std::unique_ptr DestinationIPsTriePtr; + typedef std::unordered_map> + DestinationPortsMap; + + void addFilterChain(uint16_t destination_port, const std::vector& destination_ips, + const std::vector& server_names, const std::string& transport_protocol, const std::vector& application_protocols, Network::TransportSocketFactoryPtr&& transport_socket_factory, std::vector filters_factory); - void addFilterChainForApplicationProtocols( - std::unordered_map& transport_protocol_map, - const std::vector& application_protocols, - const Network::FilterChainSharedPtr& filter_chain); - const Network::FilterChain* findFilterChainForServerName( - const std::unordered_map>& - server_name_match, - const Network::ConnectionSocket& socket) const; - const Network::FilterChain* findFilterChainForApplicationProtocols( - const std::unordered_map& - transport_protocol_match, - const Network::ConnectionSocket& socket) const; + void addFilterChainForDestinationPorts(DestinationPortsMap& destination_ports_map, + uint16_t destination_port, + const std::vector& destination_ips, + const std::vector& server_names, + const std::string& transport_protocol, + const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain); + void addFilterChainForDestinationIPs(DestinationIPsMap& destination_ips_map, + const std::vector& destination_ips, + const std::vector& server_names, + const std::string& transport_protocol, + const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain); + void addFilterChainForServerNames(ServerNamesMap& server_names_map, + const std::vector& server_names, + const std::string& transport_protocol, + const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain); + void addFilterChainForApplicationProtocols(ApplicationProtocolsMap& application_protocol_map, + const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain); + + void convertDestinationIPsMapToTrie(); + + const Network::FilterChain* + findFilterChainForDestinationIP(const DestinationIPsTrie& destination_ips_trie, + const Network::ConnectionSocket& socket) const; + const Network::FilterChain* + findFilterChainForServerName(const ServerNamesMap& server_names_map, + const Network::ConnectionSocket& socket) const; + const Network::FilterChain* + findFilterChainForTransportProtocol(const TransportProtocolsMap& transport_protocols_map, + const Network::ConnectionSocket& socket) const; + const Network::FilterChain* + findFilterChainForApplicationProtocols(const ApplicationProtocolsMap& application_protocols_map, + const Network::ConnectionSocket& socket) const; + static bool isWildcardServerName(const std::string& name); - // Mapping of FilterChain's configured server name and transport protocol, i.e. - // map[server_name][transport_protocol][application_protocol] => FilterChainSharedPtr - // - // For the server_name lookups, both exact server names and wildcard domains are part of the same - // map, in which wildcard domains are prefixed with "." (i.e. ".example.com" for "*.example.com") - // to differentiate between exact and wildcard entries. - std::unordered_map< - std::string, std::unordered_map< - std::string, std::unordered_map>> - filter_chains_; + // Mapping of FilterChain's configured destination ports, IPs, server names, transport protocols + // and application protocols, using structures defined above. + DestinationPortsMap destination_ports_map_; ListenerManagerImpl& parent_; Network::Address::InstanceConstSharedPtr address_; diff --git a/test/server/listener_manager_impl_test.cc b/test/server/listener_manager_impl_test.cc index 7f2f1cf646005..94698c3d76d3d 100644 --- a/test/server/listener_manager_impl_test.cc +++ b/test/server/listener_manager_impl_test.cc @@ -23,6 +23,7 @@ #include "test/test_common/threadsafe_singleton_injector.h" #include "test/test_common/utility.h" +#include "absl/strings/match.h" #include "gtest/gtest.h" using testing::InSequence; @@ -137,31 +138,49 @@ class ListenerManagerImplWithRealFiltersTest : public ListenerManagerImplTest { context); })); socket_.reset(new NiceMock()); + address_.reset(new Network::Address::Ipv4Instance("127.0.0.1", 1234)); } const Network::FilterChain* - findFilterChain(const std::string& server_name, bool expect_server_name_match, + findFilterChain(uint16_t destination_port, bool expect_destination_port_match, + const std::string& destination_address, bool expect_destination_address_match, + const std::string& server_name, bool expect_server_name_match, const std::string& transport_protocol, bool expect_transport_protocol_match, const std::vector& application_protocols) { - EXPECT_CALL(*socket_, requestedServerName()).WillOnce(Return(absl::string_view(server_name))); + const int times = expect_destination_port_match ? 2 : 1; + if (absl::StartsWith(destination_address, "/")) { + address_.reset(new Network::Address::PipeInstance(destination_address)); + } else { + address_.reset(new Network::Address::Ipv4Instance(destination_address, destination_port)); + } + EXPECT_CALL(*socket_, localAddress()).Times(times).WillRepeatedly(ReturnRef(address_)); + + if (expect_destination_address_match) { + EXPECT_CALL(*socket_, requestedServerName()).WillOnce(Return(absl::string_view(server_name))); + } else { + EXPECT_CALL(*socket_, requestedServerName()).Times(0); + } + if (expect_server_name_match) { EXPECT_CALL(*socket_, detectedTransportProtocol()) .WillOnce(Return(absl::string_view(transport_protocol))); - if (expect_transport_protocol_match) { - EXPECT_CALL(*socket_, requestedApplicationProtocols()) - .WillOnce(ReturnRef(application_protocols)); - } else { - EXPECT_CALL(*socket_, requestedApplicationProtocols()).Times(0); - } } else { EXPECT_CALL(*socket_, detectedTransportProtocol()).Times(0); + } + + if (expect_transport_protocol_match) { + EXPECT_CALL(*socket_, requestedApplicationProtocols()) + .WillOnce(ReturnRef(application_protocols)); + } else { EXPECT_CALL(*socket_, requestedApplicationProtocols()).Times(0); } + return manager_->listeners().back().get().filterChainManager().findFilterChain(*socket_); } private: std::unique_ptr socket_; + Network::Address::InstanceConstSharedPtr address_; }; class MockLdsApi : public LdsApi { @@ -232,7 +251,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SslContext) { manager_->addOrUpdateListener(parseListenerFromJson(json), "", true); EXPECT_EQ(1U, manager_->listeners().size()); - auto filter_chain = findFilterChain("", true, "tls", true, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); } @@ -1033,6 +1052,90 @@ TEST_F(ListenerManagerImplTest, EarlyShutdown) { manager_->stopWorkers(); } +TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithDestinationPortMatch) { + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: 127.0.0.1, port_value: 1234 } + listener_filters: + - name: "envoy.listener.tls_inspector" + config: {} + filter_chains: + - filter_chain_match: + destination_port: 8080 + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem" } + )EOF", + Network::Address::IpVersion::v4); + + EXPECT_CALL(server_.random_, uuid()); + EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); + manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true); + EXPECT_EQ(1U, manager_->listeners().size()); + + // IPv4 client connects to unknown port - no match. + auto filter_chain = findFilterChain(1234, false, "127.0.0.1", false, "", false, "tls", false, {}); + EXPECT_EQ(filter_chain, nullptr); + + // IPv4 client connects to valid port - using 1st filter chain. + filter_chain = findFilterChain(8080, true, "127.0.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto ssl_socket = dynamic_cast(transport_socket.get()); + auto server_names = ssl_socket->dnsSansLocalCertificate(); + EXPECT_EQ(server_names.size(), 1); + EXPECT_EQ(server_names.front(), "server1.example.com"); + + // UDS client - no match. + filter_chain = findFilterChain(0, false, "/tmp/test.sock", false, "", false, "tls", false, {}); + EXPECT_EQ(filter_chain, nullptr); +} + +TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithDestinationIPMatch) { + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: 127.0.0.1, port_value: 1234 } + listener_filters: + - name: "envoy.listener.tls_inspector" + config: {} + filter_chains: + - filter_chain_match: + prefix_ranges: { address_prefix: 127.0.0.0, prefix_len: 8 } + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem" } + )EOF", + Network::Address::IpVersion::v4); + + EXPECT_CALL(server_.random_, uuid()); + EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); + manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true); + EXPECT_EQ(1U, manager_->listeners().size()); + + // IPv4 client connects to unknown IP - no match. + auto filter_chain = findFilterChain(1234, true, "1.2.3.4", false, "", false, "tls", false, {}); + EXPECT_EQ(filter_chain, nullptr); + + // IPv4 client connects to valid IP - using 1st filter chain. + filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto ssl_socket = dynamic_cast(transport_socket.get()); + auto server_names = ssl_socket->dnsSansLocalCertificate(); + EXPECT_EQ(server_names.size(), 1); + EXPECT_EQ(server_names.front(), "server1.example.com"); + + // UDS client - no match. + filter_chain = findFilterChain(0, true, "/tmp/test.sock", false, "", false, "tls", false, {}); + EXPECT_EQ(filter_chain, nullptr); +} + TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithServerNamesMatch) { const std::string yaml = TestEnvironment::substitute(R"EOF( address: @@ -1057,15 +1160,17 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithServerNamesM EXPECT_EQ(1U, manager_->listeners().size()); // TLS client without SNI - no match. - auto filter_chain = findFilterChain("", false, "tls", false, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", false, "tls", false, {}); EXPECT_EQ(filter_chain, nullptr); // TLS client without matching SNI - no match. - filter_chain = findFilterChain("www.example.com", false, "tls", false, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "www.example.com", false, "tls", false, {}); EXPECT_EQ(filter_chain, nullptr); // TLS client with matching SNI - using 1st filter chain. - filter_chain = findFilterChain("server1.example.com", true, "tls", true, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "server1.example.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1099,11 +1204,12 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithTransportPro EXPECT_EQ(1U, manager_->listeners().size()); // TCP client - no match. - auto filter_chain = findFilterChain("", true, "raw_buffer", false, {}); + auto filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "", true, "raw_buffer", false, {}); EXPECT_EQ(filter_chain, nullptr); // TLS client - using 1st filter chain. - filter_chain = findFilterChain("", true, "tls", true, {}); + filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1137,11 +1243,12 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithApplicationP EXPECT_EQ(1U, manager_->listeners().size()); // TLS client without ALPN - no match. - auto filter_chain = findFilterChain("", true, "tls", true, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); EXPECT_EQ(filter_chain, nullptr); // TLS client with "http/1.1" ALPN - using 1st filter chain. - filter_chain = findFilterChain("", true, "tls", true, {"h2", "http/1.1"}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {"h2", "http/1.1"}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1151,6 +1258,158 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithApplicationP EXPECT_EQ(server_names.front(), "server1.example.com"); } +TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinationPortMatch) { + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: 127.0.0.1, port_value: 1234 } + listener_filters: + - name: "envoy.listener.tls_inspector" + config: {} + filter_chains: + - filter_chain_match: + # empty + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_uri_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_uri_key.pem" } + - filter_chain_match: + destination_port: 8080 + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem" } + - filter_chain_match: + destination_port: 8081 + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_multiple_dns_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_multiple_dns_key.pem" } + )EOF", + Network::Address::IpVersion::v4); + + EXPECT_CALL(server_.random_, uuid()); + EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); + manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true); + EXPECT_EQ(1U, manager_->listeners().size()); + + // IPv4 client connects to default port - using 1st filter chain. + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto ssl_socket = dynamic_cast(transport_socket.get()); + auto uri = ssl_socket->uriSanLocalCertificate(); + EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); + + // IPv4 client connects to port 8080 - using 2nd filter chain. + filter_chain = findFilterChain(8080, true, "127.0.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + ssl_socket = dynamic_cast(transport_socket.get()); + auto server_names = ssl_socket->dnsSansLocalCertificate(); + EXPECT_EQ(server_names.size(), 1); + EXPECT_EQ(server_names.front(), "server1.example.com"); + + // IPv4 client connects to port 8081 - using 3nd filter chain. + filter_chain = findFilterChain(8081, true, "127.0.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + ssl_socket = dynamic_cast(transport_socket.get()); + server_names = ssl_socket->dnsSansLocalCertificate(); + EXPECT_EQ(server_names.size(), 2); + EXPECT_EQ(server_names.front(), "*.example.com"); + + // UDS client - using 1st filter chain. + filter_chain = findFilterChain(0, true, "/tmp/test.sock", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + ssl_socket = dynamic_cast(transport_socket.get()); + uri = ssl_socket->uriSanLocalCertificate(); + EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); +} + +TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinationIPMatch) { + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: 127.0.0.1, port_value: 1234 } + listener_filters: + - name: "envoy.listener.tls_inspector" + config: {} + filter_chains: + - filter_chain_match: + # empty + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_uri_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_uri_key.pem" } + - filter_chain_match: + prefix_ranges: { address_prefix: 192.168.0.1, prefix_len: 32 } + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem" } + - filter_chain_match: + prefix_ranges: { address_prefix: 192.168.0.0, prefix_len: 16 } + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_multiple_dns_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_multiple_dns_key.pem" } + )EOF", + Network::Address::IpVersion::v4); + + EXPECT_CALL(server_.random_, uuid()); + EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); + manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true); + EXPECT_EQ(1U, manager_->listeners().size()); + + // IPv4 client connects to default IP - using 1st filter chain. + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto ssl_socket = dynamic_cast(transport_socket.get()); + auto uri = ssl_socket->uriSanLocalCertificate(); + EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); + + // IPv4 client connects to exact IP match - using 2nd filter chain. + filter_chain = findFilterChain(1234, true, "192.168.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + ssl_socket = dynamic_cast(transport_socket.get()); + auto server_names = ssl_socket->dnsSansLocalCertificate(); + EXPECT_EQ(server_names.size(), 1); + EXPECT_EQ(server_names.front(), "server1.example.com"); + + // IPv4 client connects to wildcard IP match - using 3nd filter chain. + filter_chain = findFilterChain(1234, true, "192.168.1.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + ssl_socket = dynamic_cast(transport_socket.get()); + server_names = ssl_socket->dnsSansLocalCertificate(); + EXPECT_EQ(server_names.size(), 2); + EXPECT_EQ(server_names.front(), "*.example.com"); + + // UDS client - using 1st filter chain. + filter_chain = findFilterChain(0, true, "/tmp/test.sock", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + ssl_socket = dynamic_cast(transport_socket.get()); + uri = ssl_socket->uriSanLocalCertificate(); + EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); +} + TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNamesMatch) { const std::string yaml = TestEnvironment::substitute(R"EOF( address: @@ -1198,7 +1457,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam EXPECT_EQ(1U, manager_->listeners().size()); // TLS client without SNI - using 1st filter chain. - auto filter_chain = findFilterChain("", true, "tls", true, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1207,7 +1466,8 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); // TLS client with exact SNI match - using 2nd filter chain. - filter_chain = findFilterChain("server1.example.com", true, "tls", true, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "server1.example.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1217,7 +1477,8 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam EXPECT_EQ(server_names.front(), "server1.example.com"); // TLS client with wildcard SNI match - using 3nd filter chain. - filter_chain = findFilterChain("server2.example.com", true, "tls", true, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "server2.example.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1227,7 +1488,8 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam EXPECT_EQ(server_names.front(), "*.example.com"); // TLS client with wildcard SNI match - using 3nd filter chain. - filter_chain = findFilterChain("www.wildcard.com", true, "tls", true, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "www.wildcard.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1263,12 +1525,13 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithTransport EXPECT_EQ(1U, manager_->listeners().size()); // TCP client - using 1st filter chain. - auto filter_chain = findFilterChain("", true, "raw_buffer", true, {}); + auto filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "", true, "raw_buffer", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_FALSE(filter_chain->transportSocketFactory().implementsSecureTransport()); // TLS client - using 2nd filter chain. - filter_chain = findFilterChain("", true, "tls", true, {}); + filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1304,12 +1567,13 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithApplicati EXPECT_EQ(1U, manager_->listeners().size()); // TLS client without ALPN - using 1st filter chain. - auto filter_chain = findFilterChain("", true, "tls", true, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_FALSE(filter_chain->transportSocketFactory().implementsSecureTransport()); // TLS client with "h2,http/1.1" ALPN - using 2nd filter chain. - filter_chain = findFilterChain("", true, "tls", true, {"h2", "http/1.1"}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {"h2", "http/1.1"}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1347,21 +1611,24 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithMultipleR EXPECT_EQ(1U, manager_->listeners().size()); // TLS client without SNI and ALPN - using 1st filter chain. - auto filter_chain = findFilterChain("", true, "tls", true, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_FALSE(filter_chain->transportSocketFactory().implementsSecureTransport()); // TLS client with exact SNI match but without ALPN - no match (SNI blackholed by configuration). - filter_chain = findFilterChain("server1.example.com", true, "tls", true, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "server1.example.com", true, "tls", true, {}); EXPECT_EQ(filter_chain, nullptr); // TLS client with ALPN match but without SNI - using 1st filter chain. - filter_chain = findFilterChain("", true, "tls", true, {"h2", "http/1.1"}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {"h2", "http/1.1"}); ASSERT_NE(filter_chain, nullptr); EXPECT_FALSE(filter_chain->transportSocketFactory().implementsSecureTransport()); // TLS client with exact SNI match and ALPN match - using 2nd filter chain. - filter_chain = findFilterChain("server1.example.com", true, "tls", true, {"h2", "http/1.1"}); + filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "server1.example.com", true, "tls", + true, {"h2", "http/1.1"}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1443,6 +1710,23 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, EXPECT_EQ(1U, manager_->listeners().size()); } +TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithInvalidDestinationIPMatch) { + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: 127.0.0.1, port_value: 1234 } + listener_filters: + - name: "envoy.listener.tls_inspector" + config: {} + filter_chains: + - filter_chain_match: + prefix_ranges: { address_prefix: a.b.c.d, prefix_len: 32 } + )EOF", + Network::Address::IpVersion::v4); + + EXPECT_THROW_WITH_MESSAGE(manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true), + EnvoyException, "malformed IP address: a.b.c.d"); +} + TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithInvalidServerNamesMatch) { const std::string yaml = TestEnvironment::substitute(R"EOF( address: @@ -1615,15 +1899,17 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithDeprecatedSn EXPECT_EQ(1U, manager_->listeners().size()); // TLS client without SNI - no match. - auto filter_chain = findFilterChain("", false, "tls", false, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", false, "tls", false, {}); EXPECT_EQ(filter_chain, nullptr); // TLS client without matching SNI - no match. - filter_chain = findFilterChain("www.example.com", false, "tls", false, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "www.example.com", false, "tls", false, {}); EXPECT_EQ(filter_chain, nullptr); // TLS client with matching SNI - using 1st filter chain. - filter_chain = findFilterChain("server1.example.com", true, "tls", true, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "server1.example.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket();