Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/envoy/network/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ class ReadFilterCallbacks {
* Set the currently selected upstream host for the connection.
*/
virtual void upstreamHost(Upstream::HostDescriptionConstSharedPtr host) PURE;

/**
* Return the requested server name (e.g. SNI in TLS) of the network level, if any.
*/
virtual absl::string_view networkLevelRequestedServerName() PURE;

/**
* Set the requested server name (e.g. SNI in TLS) of the network level, if any.
*/
virtual void networkLevelRequestedServerName(absl::string_view name) PURE;
};

/**
Expand Down
8 changes: 8 additions & 0 deletions source/common/network/filter_manager_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ class FilterManagerImpl {
void upstreamHost(Upstream::HostDescriptionConstSharedPtr host) override {
parent_.host_description_ = host;
}
absl::string_view networkLevelRequestedServerName() override {
// TODO: write a warning to log when inner SNI reader is not set.
return parent_.network_level_requested_server_name_;
}
void networkLevelRequestedServerName(absl::string_view name) override {
parent_.network_level_requested_server_name_ = std::string(name);
}

FilterManagerImpl& parent_;
ReadFilterSharedPtr filter_;
Expand All @@ -73,6 +80,7 @@ class FilterManagerImpl {

Connection& connection_;
BufferSource& buffer_source_;
std::string network_level_requested_server_name_;
Upstream::HostDescriptionConstSharedPtr host_description_;
std::list<ActiveReadFilterPtr> upstream_filters_;
std::list<WriteFilterSharedPtr> downstream_filters_;
Expand Down
4 changes: 4 additions & 0 deletions source/common/ssl/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,8 @@ envoy_cc_library(
external_deps = [
"ssl",
],
deps = [
"//include/envoy/stats:stats_macros",
"//source/common/common:assert_lib",
],
)
5 changes: 5 additions & 0 deletions source/common/ssl/utility.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
#include "common/ssl/utility.h"

#include "common/common/assert.h"

#include "openssl/bytestring.h"
#include "openssl/ssl.h"

namespace Envoy {
namespace Ssl {

Expand Down
2 changes: 2 additions & 0 deletions source/common/ssl/utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <string>

#include "envoy/stats/stats_macros.h"

#include "openssl/ssl.h"

namespace Envoy {
Expand Down
1 change: 1 addition & 0 deletions source/extensions/extensions_build_config.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ EXTENSIONS = {
"envoy.filters.network.echo": "//source/extensions/filters/network/echo:config",
"envoy.filters.network.ext_authz": "//source/extensions/filters/network/ext_authz:config",
"envoy.filters.network.http_connection_manager": "//source/extensions/filters/network/http_connection_manager:config",
"envoy.filters.network.network_level_sni_reader": "//source/extensions/filters/network/network_level_sni_reader:config",
"envoy.filters.network.mongo_proxy": "//source/extensions/filters/network/mongo_proxy:config",
"envoy.filters.network.ratelimit": "//source/extensions/filters/network/ratelimit:config",
"envoy.filters.network.rbac": "//source/extensions/filters/network/rbac:config",
Expand Down
94 changes: 64 additions & 30 deletions source/extensions/filters/listener/tls_inspector/tls_inspector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ namespace Extensions {
namespace ListenerFilters {
namespace TlsInspector {

Config::Config(Stats::Scope& scope, uint32_t max_client_hello_size)
: stats_{ALL_TLS_INSPECTOR_STATS(POOL_COUNTER_PREFIX(scope, "tls_inspector."))},
Config::Config(Stats::Scope& scope, uint32_t max_client_hello_size, const std::string& stat_prefix)
: stats_{TLS_STATS(POOL_COUNTER_PREFIX(scope, stat_prefix))},
ssl_ctx_(SSL_CTX_new(TLS_with_buffers_method())),
max_client_hello_size_(max_client_hello_size) {

Expand All @@ -42,14 +42,14 @@ Config::Config(Stats::Scope& scope, uint32_t max_client_hello_size)
size_t len;
if (SSL_early_callback_ctx_extension_get(
client_hello, TLSEXT_TYPE_application_layer_protocol_negotiation, &data, &len)) {
Filter* filter = static_cast<Filter*>(SSL_get_app_data(client_hello->ssl));
TlsFilterBase* filter = static_cast<TlsFilterBase*>(SSL_get_app_data(client_hello->ssl));
filter->onALPN(data, len);
}
return ssl_select_cert_success;
});
SSL_CTX_set_tlsext_servername_callback(
ssl_ctx_.get(), [](SSL* ssl, int* out_alert, void*) -> int {
Filter* filter = static_cast<Filter*>(SSL_get_app_data(ssl));
TlsFilterBase* filter = static_cast<TlsFilterBase*>(SSL_get_app_data(ssl));
filter->onServername(SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name));

// Return an error to stop the handshake; we have what we wanted already.
Expand All @@ -63,10 +63,16 @@ bssl::UniquePtr<SSL> Config::newSsl() { return bssl::UniquePtr<SSL>{SSL_new(ssl_
thread_local uint8_t Filter::buf_[Config::TLS_MAX_CLIENT_HELLO];

Filter::Filter(const ConfigSharedPtr config) : config_(config), ssl_(config_->newSsl()) {
RELEASE_ASSERT(sizeof(buf_) >= config_->maxClientHelloSize(), "");
initializeSsl(config->maxClientHelloSize(), sizeof(buf_), ssl_,
static_cast<TlsFilterBase*>(this));
}

void Filter::initializeSsl(uint32_t maxClientHelloSize, size_t bufSize,
const bssl::UniquePtr<SSL>& ssl, void* appData) {
RELEASE_ASSERT(bufSize >= maxClientHelloSize, "");

SSL_set_app_data(ssl_.get(), this);
SSL_set_accept_state(ssl_.get());
SSL_set_app_data(ssl.get(), appData);
SSL_set_accept_state(ssl.get());
}

Network::FilterStatus Filter::onAccept(Network::ListenerFilterCallbacks& cb) {
Expand Down Expand Up @@ -100,6 +106,16 @@ Network::FilterStatus Filter::onAccept(Network::ListenerFilterCallbacks& cb) {
}

void Filter::onALPN(const unsigned char* data, unsigned int len) {
doOnALPN(data, len,
[&](std::vector<absl::string_view> protocols) {
cb_->socket().setRequestedApplicationProtocols(protocols);
},
alpn_found_);
}

void Filter::doOnALPN(const unsigned char* data, unsigned int len,
std::function<void(std::vector<absl::string_view> protocols)> onAlpnCb,
bool& alpn_found) {
CBS wire, list;
CBS_init(&wire, reinterpret_cast<const uint8_t*>(data), static_cast<size_t>(len));
if (!CBS_get_u16_length_prefixed(&wire, &list) || CBS_len(&wire) != 0 || CBS_len(&list) < 2) {
Expand All @@ -115,19 +131,28 @@ void Filter::onALPN(const unsigned char* data, unsigned int len) {
}
protocols.emplace_back(reinterpret_cast<const char*>(CBS_data(&name)), CBS_len(&name));
}
cb_->socket().setRequestedApplicationProtocols(protocols);
alpn_found_ = true;
onAlpnCb(protocols);
alpn_found = true;
}

void Filter::onServername(absl::string_view servername) {
ENVOY_LOG(debug, "tls:onServerName(), requestedServerName: {}", servername);
doOnServername(
servername, config_->stats(),
[&](absl::string_view name) -> void { cb_->socket().setRequestedServerName(name); },
clienthello_success_);
}

void Filter::onServername(absl::string_view name) {
void Filter::doOnServername(absl::string_view name, const TlsStats& stats,
std::function<void(absl::string_view name)> onServernameCb,
bool& clienthello_success) {
if (!name.empty()) {
config_->stats().sni_found_.inc();
cb_->socket().setRequestedServerName(name);
ENVOY_LOG(debug, "tls:onServerName(), requestedServerName: {}", name);
stats.sni_found_.inc();
onServernameCb(name);
} else {
config_->stats().sni_not_found_.inc();
stats.sni_not_found_.inc();
}
clienthello_success_ = true;
clienthello_success = true;
}

void Filter::onRead() {
Expand Down Expand Up @@ -162,7 +187,13 @@ void Filter::onRead() {
const uint8_t* data = buf_ + read_;
const size_t len = result.rc_ - read_;
read_ = result.rc_;
parseClientHello(data, len);
parseClientHello(data, len, ssl_, read_, config_->maxClientHelloSize(), config_->stats(),
[&](bool success) -> void { done(success); }, alpn_found_,
clienthello_success_,
[&]() -> void {
cb_->socket().setDetectedTransportProtocol(
TransportSockets::TransportSocketNames::get().Tls);
});
}
}

Expand All @@ -179,41 +210,44 @@ void Filter::done(bool success) {
cb_->continueFilterChain(success);
}

void Filter::parseClientHello(const void* data, size_t len) {
// Ownership is passed to ssl_ in SSL_set_bio()
void Filter::parseClientHello(const void* data, size_t len, bssl::UniquePtr<SSL>& ssl,
uint64_t read, uint32_t maxClientHelloSize, const TlsStats& stats,
std::function<void(bool)> done, bool& alpn_found,
bool& clienthello_success, std::function<void()> onSuccess) {
// Ownership is passed to ssl in SSL_set_bio()
bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(data, len));

// Make the mem-BIO return that there is more data
// available beyond it's end
BIO_set_mem_eof_return(bio.get(), -1);

SSL_set_bio(ssl_.get(), bio.get(), bio.get());
SSL_set_bio(ssl.get(), bio.get(), bio.get());
bio.release();

int ret = SSL_do_handshake(ssl_.get());
int ret = SSL_do_handshake(ssl.get());

// This should never succeed because an error is always returned from the SNI callback.
ASSERT(ret <= 0);
switch (SSL_get_error(ssl_.get(), ret)) {
switch (SSL_get_error(ssl.get(), ret)) {
case SSL_ERROR_WANT_READ:
if (read_ == config_->maxClientHelloSize()) {
if (read == maxClientHelloSize) {
// We've hit the specified size limit. This is an unreasonably large ClientHello;
// indicate failure.
config_->stats().client_hello_too_large_.inc();
stats.client_hello_too_large_.inc();
done(false);
}
break;
case SSL_ERROR_SSL:
if (clienthello_success_) {
config_->stats().tls_found_.inc();
if (alpn_found_) {
config_->stats().alpn_found_.inc();
if (clienthello_success) {
stats.tls_found_.inc();
if (alpn_found) {
stats.alpn_found_.inc();
} else {
config_->stats().alpn_not_found_.inc();
stats.alpn_not_found_.inc();
}
cb_->socket().setDetectedTransportProtocol(TransportSockets::TransportSocketNames::get().Tls);
onSuccess();
} else {
config_->stats().tls_not_found_.inc();
stats.tls_not_found_.inc();
}
done(true);
break;
Expand Down
52 changes: 38 additions & 14 deletions source/extensions/filters/listener/tls_inspector/tls_inspector.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace TlsInspector {
/**
* All stats for the TLS inspector. @see stats_macros.h
*/
#define ALL_TLS_INSPECTOR_STATS(COUNTER) \
#define TLS_STATS(COUNTER) \
COUNTER(connection_closed) \
COUNTER(client_hello_too_large) \
COUNTER(read_error) \
Expand All @@ -32,50 +32,77 @@ namespace TlsInspector {
COUNTER(sni_not_found)

/**
* Definition of all stats for the TLS inspector. @see stats_macros.h
* Definition of stats for the TLS. @see stats_macros.h
*/
struct TlsInspectorStats {
ALL_TLS_INSPECTOR_STATS(GENERATE_COUNTER_STRUCT)
struct TlsStats {
TLS_STATS(GENERATE_COUNTER_STRUCT)
};

/**
* Global configuration for TLS inspector.
*/
class Config {
public:
Config(Stats::Scope& scope, uint32_t max_client_hello_size = TLS_MAX_CLIENT_HELLO);
Config(Stats::Scope& scope, uint32_t max_client_hello_size = TLS_MAX_CLIENT_HELLO,
const std::string& stat_prefix = "tls_inspector.");

const TlsInspectorStats& stats() const { return stats_; }
const TlsStats& stats() const { return stats_; }
bssl::UniquePtr<SSL> newSsl();
uint32_t maxClientHelloSize() const { return max_client_hello_size_; }

static constexpr size_t TLS_MAX_CLIENT_HELLO = 64 * 1024;

private:
TlsInspectorStats stats_;
TlsStats stats_;
bssl::UniquePtr<SSL_CTX> ssl_ctx_;
const uint32_t max_client_hello_size_;
};

typedef std::shared_ptr<Config> ConfigSharedPtr;

class TlsFilterBase {
public:
virtual ~TlsFilterBase() {}

private:
virtual void onALPN(const unsigned char* data, unsigned int len) PURE;
virtual void onServername(absl::string_view name) PURE;

// Allows callbacks on the SSL_CTX to set fields in this class.
friend class Config;
};

/**
* TLS inspector listener filter.
*/
class Filter : public Network::ListenerFilter, Logger::Loggable<Logger::Id::filter> {
class Filter : public Network::ListenerFilter,
public TlsFilterBase,
Logger::Loggable<Logger::Id::filter> {
public:
Filter(const ConfigSharedPtr config);

// Network::ListenerFilter
Network::FilterStatus onAccept(Network::ListenerFilterCallbacks& cb) override;
static void initializeSsl(uint32_t maxClientHelloSize, size_t bufSize,
const bssl::UniquePtr<SSL>& ssl, void* appData);
static void parseClientHello(const void* data, size_t len, bssl::UniquePtr<SSL>& ssl,
uint64_t read, uint32_t maxClientHelloSize, const TlsStats& stats,
std::function<void(bool)> done, bool& alpn_found,
bool& clienthello_success, std::function<void()> onSuccess);
static void doOnServername(absl::string_view name, const TlsStats& stats,
std::function<void(absl::string_view name)> onServernameCb,
bool& clienthello_success_);
static void doOnALPN(const unsigned char* data, unsigned int len,
std::function<void(std::vector<absl::string_view> protocols)> onAlpnCb,
bool& alpn_found);

private:
void parseClientHello(const void* data, size_t len);
void onRead();
void onTimeout();
void done(bool success);
void onALPN(const unsigned char* data, unsigned int len);
void onServername(absl::string_view name);
// Extensions::ListenerFilters::TlsInspector::TlsFilterBase
void onALPN(const unsigned char* data, unsigned int len) override;
void onServername(absl::string_view name) override;

ConfigSharedPtr config_;
Network::ListenerFilterCallbacks* cb_;
Expand All @@ -88,9 +115,6 @@ class Filter : public Network::ListenerFilter, Logger::Loggable<Logger::Id::filt
bool clienthello_success_{false};

static thread_local uint8_t buf_[Config::TLS_MAX_CLIENT_HELLO];

// Allows callbacks on the SSL_CTX to set fields in this class.
friend class Config;
};

} // namespace TlsInspector
Expand Down
36 changes: 36 additions & 0 deletions source/extensions/filters/network/network_level_sni_reader/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
licenses(["notice"]) # Apache 2

load(
"//bazel:envoy_build_system.bzl",
"envoy_cc_library",
"envoy_package",
)

envoy_package()

envoy_cc_library(
name = "network_level_sni_reader",
srcs = ["network_level_sni_reader.cc"],
hdrs = ["network_level_sni_reader.h"],
external_deps = ["ssl"],
deps = [
"//include/envoy/buffer:buffer_interface",
"//include/envoy/network:connection_interface",
"//include/envoy/network:filter_interface",
"//source/common/common:assert_lib",
"//source/common/common:minimal_logger_lib",
"//source/extensions/filters/listener/tls_inspector:tls_inspector_lib",
],
)

envoy_cc_library(
name = "config",
srcs = ["config.cc"],
deps = [
":network_level_sni_reader",
"//include/envoy/registry",
"//include/envoy/server:filter_config_interface",
"//source/extensions/filters/listener/tls_inspector:tls_inspector_lib",
"//source/extensions/filters/network:well_known_names",
],
)
Loading