Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions api/envoy/api/v2/listener/listener.proto
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ message Filter {
//
// The following order applies:
//
// [#comment:TODO(PiotrSikora): destination IP / ranges are going to be 1.]
// 1. Server name (e.g. SNI for TLS protocol),
// 2. Transport protocol.
// 3. Application protocols (e.g. ALPN for TLS protocol).
// 1. Destination port.
// 2. Destination IP address.
// 3. Server name (e.g. SNI for TLS protocol),
// 4. Transport protocol.
// 5. Application protocols (e.g. ALPN for TLS protocol).
//
// For criterias that allow ranges or wildcards, the most specific value in any
// of the configured filter chains that matches the incoming connection is going
Expand All @@ -71,9 +72,12 @@ message Filter {
//
// [#comment:TODO(PiotrSikora): Add support for configurable precedence of the rules]
message FilterChainMatch {
// Optional destination port to consider when use_original_dst is set on the
// listener in determining a filter chain match.
google.protobuf.UInt32Value destination_port = 8 [(validate.rules).uint32 = {gte: 1, lte: 65535}];

// If non-empty, an IP address and prefix length to match addresses when the
// listener is bound to 0.0.0.0/:: or when use_original_dst is specified.
// [#not-implemented-hide:]
repeated core.CidrRange prefix_ranges = 3;

// If non-empty, an IP address and suffix length to match addresses when the
Expand All @@ -97,11 +101,6 @@ message FilterChainMatch {
// [#not-implemented-hide:]
repeated google.protobuf.UInt32Value source_ports = 7;

// Optional destination port to consider when use_original_dst is set on the
// listener in determining a filter chain match.
// [#not-implemented-hide:]
google.protobuf.UInt32Value destination_port = 8;

// If non-empty, a list of server names (e.g. SNI for TLS protocol) to consider when determining
// a filter chain match. Those values will be compared against the server names of a new
// connection, when detected by one of the listener filters.
Expand Down
3 changes: 3 additions & 0 deletions docs/root/intro/version_history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ Version history
* proxy_protocol: added support for HAProxy Proxy Protocol v2 (AF_INET/AF_INET6 only).
* http: added generic +:ref:`Upgrade support
<envoy_api_field_config.filter.network.http_connection_manager.v2.HttpConnectionManager.upgrade_configs>`
* listeners: added the ability to match :ref:`FilterChain <envoy_api_msg_listener.FilterChain>` using
:ref:`destination_port <envoy_api_field_listener.FilterChainMatch.destination_port>` and
:ref:`prefix_ranges <envoy_api_field_listener.FilterChainMatch.prefix_ranges>`.
* lua: added :ref:`connection() <config_http_filters_lua_connection_wrapper>` wrapper and *ssl()* API.
* lua: added :ref:`requestInfo() <config_http_filters_lua_request_info_wrapper>` wrapper and *protocol()* API.
* ratelimit: added support for :repo:`api/envoy/service/ratelimit/v2/rls.proto`.
Expand Down
2 changes: 2 additions & 0 deletions source/server/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ envoy_cc_library(
"//source/common/api:os_sys_calls_lib",
"//source/common/common:empty_string",
"//source/common/config:utility_lib",
"//source/common/network:cidr_range_lib",
"//source/common/network:lc_trie_lib",
"//source/common/network:listen_socket_lib",
"//source/common/network:resolver_lib",
"//source/common/network:socket_option_factory_lib",
Expand Down
181 changes: 149 additions & 32 deletions source/server/listener_manager_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st
ProtobufTypes::MessagePtr message =
Config::Utility::translateToFactoryConfig(transport_socket, config_factory);

// Validate IP addresses.
std::vector<std::string> destination_ips;
for (const auto& destination_ip : filter_chain_match.prefix_ranges()) {
const auto& cidr_range = Network::Address::CidrRange::create(destination_ip);
destination_ips.push_back(cidr_range.asString());
}

std::vector<std::string> server_names;
if (!filter_chain_match.server_names().empty()) {
if (!filter_chain_match.sni_domains().empty()) {
Expand Down Expand Up @@ -233,7 +240,9 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st
filter_chain_match.application_protocols().begin(),
filter_chain_match.application_protocols().end());

addFilterChain(server_names, filter_chain_match.transport_protocol(), application_protocols,
addFilterChain(PROTOBUF_GET_WRAPPED_OR_DEFAULT(filter_chain_match, destination_port, 0),
destination_ips, server_names, filter_chain_match.transport_protocol(),
application_protocols,
config_factory.createTransportSocketFactory(*message, *this, server_names),
parent_.factory_.createNetworkFilterFactoryList(filter_chain.filters(), *this));

Expand All @@ -242,6 +251,9 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st
(!server_names.empty() || !application_protocols.empty()));
}

// Convert DestinationIPsMap to DestinationIPsTrie for faster lookups.
convertDestinationIPsMapToTrie();

// Automatically inject TLS Inspector if it wasn't configured explicitly and it's needed.
if (need_tls_inspector) {
for (const auto& filter : config.listener_filters()) {
Expand Down Expand Up @@ -274,118 +286,223 @@ ListenerImpl::~ListenerImpl() {
// active. This is done here explicitly by setting a boolean and then clearing the factory
// vector for clarity.
initialize_canceled_ = true;
filter_chains_.clear();
destination_ports_map_.clear();
}

bool ListenerImpl::isWildcardServerName(const std::string& name) {
return absl::StartsWith(name, "*.");
}

void ListenerImpl::addFilterChain(const std::vector<std::string>& server_names,
void ListenerImpl::addFilterChain(uint16_t destination_port,
const std::vector<std::string>& destination_ips,
const std::vector<std::string>& server_names,
const std::string& transport_protocol,
const std::vector<std::string>& application_protocols,
Network::TransportSocketFactoryPtr&& transport_socket_factory,
std::vector<Network::FilterFactoryCb> filters_factory) {
const auto filter_chain = std::make_shared<FilterChainImpl>(std::move(transport_socket_factory),
std::move(filters_factory));
// Save mappings.
addFilterChainForDestinationPorts(destination_ports_map_, destination_port, destination_ips,
server_names, transport_protocol, application_protocols,
filter_chain);
}

void ListenerImpl::addFilterChainForDestinationPorts(
DestinationPortsMap& destination_ports_map, uint16_t destination_port,
const std::vector<std::string>& destination_ips, const std::vector<std::string>& server_names,
const std::string& transport_protocol, const std::vector<std::string>& application_protocols,
const Network::FilterChainSharedPtr& filter_chain) {
if (destination_ports_map.find(destination_port) == destination_ports_map.end()) {
destination_ports_map[destination_port] =
std::make_pair<DestinationIPsMap, DestinationIPsTriePtr>(DestinationIPsMap{}, nullptr);
}
addFilterChainForDestinationIPs(destination_ports_map[destination_port].first, destination_ips,
server_names, transport_protocol, application_protocols,
filter_chain);
}

void ListenerImpl::addFilterChainForDestinationIPs(
DestinationIPsMap& destination_ips_map, const std::vector<std::string>& destination_ips,
const std::vector<std::string>& server_names, const std::string& transport_protocol,
const std::vector<std::string>& application_protocols,
const Network::FilterChainSharedPtr& filter_chain) {
if (destination_ips.empty()) {
addFilterChainForServerNames(destination_ips_map[EMPTY_STRING], server_names,
transport_protocol, application_protocols, filter_chain);
} else {
for (const auto& destination_ip : destination_ips) {
addFilterChainForServerNames(destination_ips_map[destination_ip], server_names,
transport_protocol, application_protocols, filter_chain);
}
}
}

void ListenerImpl::addFilterChainForServerNames(
ServerNamesMap& server_names_map, const std::vector<std::string>& server_names,
const std::string& transport_protocol, const std::vector<std::string>& application_protocols,
const Network::FilterChainSharedPtr& filter_chain) {
if (server_names.empty()) {
addFilterChainForApplicationProtocols(filter_chains_[EMPTY_STRING][transport_protocol],
addFilterChainForApplicationProtocols(server_names_map[EMPTY_STRING][transport_protocol],
application_protocols, filter_chain);
} else {
for (const auto& server_name : server_names) {
if (isWildcardServerName(server_name)) {
// Add mapping for the wildcard domain, i.e. ".example.com" for "*.example.com".
addFilterChainForApplicationProtocols(
filter_chains_[server_name.substr(1)][transport_protocol], application_protocols,
server_names_map[server_name.substr(1)][transport_protocol], application_protocols,
filter_chain);
} else {
addFilterChainForApplicationProtocols(filter_chains_[server_name][transport_protocol],
addFilterChainForApplicationProtocols(server_names_map[server_name][transport_protocol],
application_protocols, filter_chain);
}
}
}
}

void ListenerImpl::addFilterChainForApplicationProtocols(
std::unordered_map<std::string, Network::FilterChainSharedPtr>& transport_protocol_map,
ApplicationProtocolsMap& application_protocols_map,
const std::vector<std::string>& application_protocols,
const Network::FilterChainSharedPtr& filter_chain) {
if (application_protocols.empty()) {
transport_protocol_map[EMPTY_STRING] = filter_chain;
application_protocols_map[EMPTY_STRING] = filter_chain;
} else {
for (const auto& application_protocol : application_protocols) {
transport_protocol_map[application_protocol] = filter_chain;
application_protocols_map[application_protocol] = filter_chain;
}
}
}

void ListenerImpl::convertDestinationIPsMapToTrie() {
for (auto& port : destination_ports_map_) {
auto& destination_ips_pair = port.second;
auto& destination_ips_map = destination_ips_pair.first;
std::vector<std::pair<ServerNamesMapSharedPtr, std::vector<Network::Address::CidrRange>>> list;
for (const auto& entry : destination_ips_map) {
std::vector<Network::Address::CidrRange> subnets;
if (entry.first == EMPTY_STRING) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

entry.first.empty() ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, although I feel that the old check was a bit more correct, since we're checking if it's the map[EMPTY_STRING] entry.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, right, I forgot that's what it is doing. In that case, I'd be ok with putting it back to how you had it.

list.push_back(
std::make_pair<ServerNamesMapSharedPtr, std::vector<Network::Address::CidrRange>>(
std::make_shared<ServerNamesMap>(entry.second),
{Network::Address::CidrRange::create("0.0.0.0/0"),
Network::Address::CidrRange::create("::/0")}));
} else {
list.push_back(
std::make_pair<ServerNamesMapSharedPtr, std::vector<Network::Address::CidrRange>>(
std::make_shared<ServerNamesMap>(entry.second),
{Network::Address::CidrRange::create(entry.first)}));
}
}
destination_ips_pair.second = std::make_unique<DestinationIPsTrie>(list, true);
}
}

const Network::FilterChain*
ListenerImpl::findFilterChain(const Network::ConnectionSocket& socket) const {
const auto& address = socket.localAddress();

// Match on destination port (only for IP addresses).
if (address->type() == Network::Address::Type::Ip) {
const auto port_match = destination_ports_map_.find(address->ip()->port());
if (port_match != destination_ports_map_.end()) {
return findFilterChainForDestinationIP(*port_match->second.second, socket);
}
}

// Match on catch-all port 0.
const auto port_match = destination_ports_map_.find(0);
if (port_match != destination_ports_map_.end()) {
return findFilterChainForDestinationIP(*port_match->second.second, socket);
}

return nullptr;
}

const Network::FilterChain*
ListenerImpl::findFilterChainForDestinationIP(const DestinationIPsTrie& destination_ips_trie,
const Network::ConnectionSocket& socket) const {
// Use invalid IP address (matching only filter chains without IP requirements) for UDS.
static const auto& fake_address = Network::Utility::parseInternetAddress("255.255.255.255");

auto address = socket.localAddress();
if (address->type() != Network::Address::Type::Ip) {
address = fake_address;
}

// Match on both: exact IP and wider CIDR ranges using LcTrie.
const auto& data = destination_ips_trie.getData(address);
if (!data.empty()) {
ASSERT(data.size() == 1);
return findFilterChainForServerName(*data.back(), socket);
}

return nullptr;
}

const Network::FilterChain*
ListenerImpl::findFilterChainForServerName(const ServerNamesMap& server_names_map,
const Network::ConnectionSocket& socket) const {
const std::string server_name(socket.requestedServerName());

// Match on exact server name, i.e. "www.example.com" for "www.example.com".
const auto server_name_exact_match = filter_chains_.find(server_name);
if (server_name_exact_match != filter_chains_.end()) {
return findFilterChainForServerName(server_name_exact_match->second, socket);
const auto server_name_exact_match = server_names_map.find(server_name);
if (server_name_exact_match != server_names_map.end()) {
return findFilterChainForTransportProtocol(server_name_exact_match->second, socket);
}

// Match on all wildcard domains, i.e. ".example.com" and ".com" for "www.example.com".
size_t pos = server_name.find('.', 1);
while (pos < server_name.size() - 1 && pos != std::string::npos) {
const std::string wildcard = server_name.substr(pos);
const auto server_name_wildcard_match = filter_chains_.find(wildcard);
if (server_name_wildcard_match != filter_chains_.end()) {
return findFilterChainForServerName(server_name_wildcard_match->second, socket);
const auto server_name_wildcard_match = server_names_map.find(wildcard);
if (server_name_wildcard_match != server_names_map.end()) {
return findFilterChainForTransportProtocol(server_name_wildcard_match->second, socket);
}
pos = server_name.find('.', pos + 1);
}

// Match on a filter chain without server name requirements.
const auto server_name_catchall_match = filter_chains_.find(EMPTY_STRING);
if (server_name_catchall_match != filter_chains_.end()) {
return findFilterChainForServerName(server_name_catchall_match->second, socket);
const auto server_name_catchall_match = server_names_map.find(EMPTY_STRING);
if (server_name_catchall_match != server_names_map.end()) {
return findFilterChainForTransportProtocol(server_name_catchall_match->second, socket);
}

return nullptr;
}

const Network::FilterChain* ListenerImpl::findFilterChainForServerName(
const std::unordered_map<std::string,
std::unordered_map<std::string, Network::FilterChainSharedPtr>>&
server_name_match,
const Network::FilterChain* ListenerImpl::findFilterChainForTransportProtocol(
const TransportProtocolsMap& transport_protocols_map,
const Network::ConnectionSocket& socket) const {
const std::string transport_protocol(socket.detectedTransportProtocol());

// Match on exact transport protocol, e.g. "tls".
const auto transport_protocol_match = server_name_match.find(transport_protocol);
if (transport_protocol_match != server_name_match.end()) {
const auto transport_protocol_match = transport_protocols_map.find(transport_protocol);
if (transport_protocol_match != transport_protocols_map.end()) {
return findFilterChainForApplicationProtocols(transport_protocol_match->second, socket);
}

// Match on a filter chain without transport protocol requirements.
const auto any_protocol_match = server_name_match.find(EMPTY_STRING);
if (any_protocol_match != server_name_match.end()) {
const auto any_protocol_match = transport_protocols_map.find(EMPTY_STRING);
if (any_protocol_match != transport_protocols_map.end()) {
return findFilterChainForApplicationProtocols(any_protocol_match->second, socket);
}

return nullptr;
}

const Network::FilterChain* ListenerImpl::findFilterChainForApplicationProtocols(
const std::unordered_map<std::string, Network::FilterChainSharedPtr>& transport_protocol_match,
const ApplicationProtocolsMap& application_protocols_map,
const Network::ConnectionSocket& socket) const {
// Match on exact application protocol, e.g. "h2" or "http/1.1".
for (const auto& application_protocol : socket.requestedApplicationProtocols()) {
const auto application_protocol_match = transport_protocol_match.find(application_protocol);
if (application_protocol_match != transport_protocol_match.end()) {
const auto application_protocol_match = application_protocols_map.find(application_protocol);
if (application_protocol_match != application_protocols_map.end()) {
return application_protocol_match->second.get();
}
}

// Match on a filter chain without application protocol requirements.
const auto any_protocol_match = transport_protocol_match.find(EMPTY_STRING);
if (any_protocol_match != transport_protocol_match.end()) {
const auto any_protocol_match = application_protocols_map.find(EMPTY_STRING);
if (any_protocol_match != application_protocols_map.end()) {
return any_protocol_match->second.get();
}

Expand Down
Loading