diff --git a/source/common/http/BUILD b/source/common/http/BUILD index 60d3cb0cf17eb..1a0f09e2c2b3c 100644 --- a/source/common/http/BUILD +++ b/source/common/http/BUILD @@ -169,6 +169,7 @@ envoy_cc_library( "alternate_protocols_cache_impl.h", "alternate_protocols_cache_manager_impl.h", ], + external_deps = ["quiche_quic_platform"], deps = [ "//envoy/common:time_interface", "//envoy/event:dispatcher_interface", diff --git a/source/common/http/alternate_protocols_cache_impl.cc b/source/common/http/alternate_protocols_cache_impl.cc index 851209796c9c5..8f2122049ecd8 100644 --- a/source/common/http/alternate_protocols_cache_impl.cc +++ b/source/common/http/alternate_protocols_cache_impl.cc @@ -64,8 +64,9 @@ AlternateProtocolsCacheImpl::protocolsFromString(absl::string_view alt_svc_strin } AlternateProtocolsCacheImpl::AlternateProtocolsCacheImpl( - TimeSource& time_source, std::unique_ptr&& key_value_store) - : time_source_(time_source), key_value_store_(std::move(key_value_store)) {} + TimeSource& time_source, std::unique_ptr&& key_value_store, size_t max_entries) + : time_source_(time_source), key_value_store_(std::move(key_value_store)), + max_entries_(max_entries > 0 ? max_entries : 1024) {} AlternateProtocolsCacheImpl::~AlternateProtocolsCacheImpl() = default; @@ -76,6 +77,11 @@ void AlternateProtocolsCacheImpl::setAlternatives(const Origin& origin, ENVOY_LOG_MISC(trace, "Too many alternate protocols: {}, truncating", protocols.size()); protocols.erase(protocols.begin() + max_protocols, protocols.end()); } + while (protocols_.size() >= max_entries_) { + auto iter = protocols_.begin(); + key_value_store_->remove(originToString(iter->first)); + protocols_.erase(iter); + } protocols_[origin] = protocols; if (key_value_store_) { key_value_store_->addOrUpdate(originToString(origin), diff --git a/source/common/http/alternate_protocols_cache_impl.h b/source/common/http/alternate_protocols_cache_impl.h index df216bf0407c3..4d3d2ddb10c0b 100644 --- a/source/common/http/alternate_protocols_cache_impl.h +++ b/source/common/http/alternate_protocols_cache_impl.h @@ -11,6 +11,7 @@ #include "envoy/http/alternate_protocols_cache.h" #include "absl/strings/string_view.h" +#include "quiche/common/quiche_linked_hash_map.h" namespace Envoy { namespace Http { @@ -19,7 +20,8 @@ namespace Http { // See: source/docs/http3_upstream.md class AlternateProtocolsCacheImpl : public AlternateProtocolsCache { public: - AlternateProtocolsCacheImpl(TimeSource& time_source, std::unique_ptr&& store); + AlternateProtocolsCacheImpl(TimeSource& time_source, std::unique_ptr&& store, + size_t max_entries); ~AlternateProtocolsCacheImpl() override; // Convert an AlternateProtocol vector to a string to cache to the key value @@ -48,12 +50,23 @@ class AlternateProtocolsCacheImpl : public AlternateProtocolsCache { // Time source used to check expiration of entries. TimeSource& time_source_; - // Map from hostname to list of alternate protocols. - // TODO(RyanTheOptimist): Add a limit to the size of this map and evict based on usage. - std::map> protocols_; + struct OriginHash { + size_t operator()(const Origin& origin) const { + // Multiply the hashes by the magic number 37 to spread the bits around. + size_t hash = std::hash()(origin.scheme_) + + 37 * (std::hash()(origin.hostname_) + + 37 * std::hash()(origin.port_)); + return hash; + } + }; + + // Map from origin to list of alternate protocols. + quiche::QuicheLinkedHashMap, OriginHash> protocols_; // The key value store, if flushing to persistent storage. std::unique_ptr key_value_store_; + + const size_t max_entries_; }; } // namespace Http diff --git a/source/common/http/alternate_protocols_cache_manager_impl.cc b/source/common/http/alternate_protocols_cache_manager_impl.cc index 9496a5ea1c8b2..c0bf2016258c3 100644 --- a/source/common/http/alternate_protocols_cache_manager_impl.cc +++ b/source/common/http/alternate_protocols_cache_manager_impl.cc @@ -51,7 +51,7 @@ AlternateProtocolsCacheSharedPtr AlternateProtocolsCacheManagerImpl::getCache( } AlternateProtocolsCacheSharedPtr new_cache = std::make_shared( - data_.dispatcher_.timeSource(), std::move(store)); + data_.dispatcher_.timeSource(), std::move(store), options.max_entries().value()); (*slot_).caches_.emplace(options.name(), CacheWithOptions{options, new_cache}); return new_cache; } diff --git a/test/common/http/alternate_protocols_cache_impl_test.cc b/test/common/http/alternate_protocols_cache_impl_test.cc index 9aa0dc5f0016e..651e371bb89fb 100644 --- a/test/common/http/alternate_protocols_cache_impl_test.cc +++ b/test/common/http/alternate_protocols_cache_impl_test.cc @@ -15,10 +15,13 @@ class AlternateProtocolsCacheImplTest : public testing::Test, public Event::Test public: AlternateProtocolsCacheImplTest() : store_(new NiceMock()), - protocols_(simTime(), std::unique_ptr(store_)) {} + protocols_(simTime(), std::unique_ptr(store_), max_entries_) {} + + const size_t max_entries_ = 10; MockKeyValueStore* store_; AlternateProtocolsCacheImpl protocols_; + const std::string hostname1_ = "hostname1"; const std::string hostname2_ = "hostname2"; const uint32_t port1_ = 1; @@ -132,6 +135,22 @@ TEST_F(AlternateProtocolsCacheImplTest, FindAlternativesAfterTruncation) { EXPECT_EQ(expected_protocols, protocols.ref()); } +TEST_F(AlternateProtocolsCacheImplTest, MaxEntries) { + EXPECT_EQ(0, protocols_.size()); + const std::string hostname = "hostname"; + for (uint32_t i = 0; i <= max_entries_; ++i) { + const AlternateProtocolsCache::Origin origin = {https_, hostname, i}; + AlternateProtocolsCache::AlternateProtocol protocol = {alpn1_, hostname, i, expiration1_}; + std::vector protocols = {protocol}; + EXPECT_CALL(*store_, addOrUpdate(absl::StrCat("https://hostname:", i), + absl::StrCat("alpn1=\"hostname:", i, "\"; ma=5"))); + if (i == max_entries_) { + EXPECT_CALL(*store_, remove("https://hostname:0")); + } + protocols_.setAlternatives(origin, protocols); + } +} + TEST_F(AlternateProtocolsCacheImplTest, ToAndFromString) { auto testAltSvc = [&](const std::string& original_alt_svc, const std::string& expected_alt_svc) -> void { diff --git a/test/common/http/conn_pool_grid_test.cc b/test/common/http/conn_pool_grid_test.cc index ebe3accff1774..69528416002db 100644 --- a/test/common/http/conn_pool_grid_test.cc +++ b/test/common/http/conn_pool_grid_test.cc @@ -102,7 +102,7 @@ class ConnectivityGridTest : public Event::TestUsingSimulatedTime, public testin public: ConnectivityGridTest() : options_({Http::Protocol::Http11, Http::Protocol::Http2, Http::Protocol::Http3}), - alternate_protocols_(std::make_shared(simTime(), nullptr)), + alternate_protocols_(std::make_shared(simTime(), nullptr, 10)), quic_stat_names_(store_.symbolTable()), grid_(dispatcher_, random_, Upstream::makeTestHost(cluster_, "hostname", "tcp://127.0.0.1:9000", simTime()), @@ -120,7 +120,7 @@ class ConnectivityGridTest : public Event::TestUsingSimulatedTime, public testin if (!use_alternate_protocols) { return nullptr; } - return std::make_shared(simTime(), nullptr); + return std::make_shared(simTime(), nullptr, 10); } void addHttp3AlternateProtocol() {