diff --git a/include/envoy/http/alternate_protocols_cache.h b/include/envoy/http/alternate_protocols_cache.h index 6d94f32eafcb6..2d02cda9bd48e 100644 --- a/include/envoy/http/alternate_protocols_cache.h +++ b/include/envoy/http/alternate_protocols_cache.h @@ -64,16 +64,18 @@ class AlternateProtocolsCache { }; /** - * Represents an alternative protocol that can be used to connect to an origin. + * Represents an alternative protocol that can be used to connect to an origin + * with a specified expiration time. */ struct AlternateProtocol { public: - AlternateProtocol(absl::string_view alpn, absl::string_view hostname, uint32_t port) - : alpn_(alpn), hostname_(hostname), port_(port) {} + AlternateProtocol(absl::string_view alpn, absl::string_view hostname, uint32_t port, + MonotonicTime expiration) + : alpn_(alpn), hostname_(hostname), port_(port), expiration_(expiration) {} bool operator==(const AlternateProtocol& other) const { - return std::tie(alpn_, hostname_, port_) == - std::tie(other.alpn_, other.hostname_, other.port_); + return std::tie(alpn_, hostname_, port_, expiration_) == + std::tie(other.alpn_, other.hostname_, other.port_, other.expiration_); } bool operator!=(const AlternateProtocol& other) const { return !this->operator==(other); } @@ -81,6 +83,7 @@ class AlternateProtocolsCache { std::string alpn_; std::string hostname_; uint32_t port_; + MonotonicTime expiration_; }; virtual ~AlternateProtocolsCache() = default; @@ -90,11 +93,9 @@ class AlternateProtocolsCache { * specified origin. Expires after the specified expiration time. * @param origin The origin to set alternate protocols for. * @param protocols A list of alternate protocols. - * @param expiration The time after which the alternatives are no longer valid. */ virtual void setAlternatives(const Origin& origin, - const std::vector& protocols, - const MonotonicTime& expiration) PURE; + const std::vector& protocols) PURE; /** * Returns the possible alternative protocols which can be used to connect to the diff --git a/source/common/http/BUILD b/source/common/http/BUILD index bc2b727a1b3fb..a581bc5d8d54e 100644 --- a/source/common/http/BUILD +++ b/source/common/http/BUILD @@ -176,6 +176,7 @@ envoy_cc_library( "//include/envoy/singleton:manager_interface", "//include/envoy/thread_local:thread_local_interface", "//include/envoy/upstream:resource_manager_interface", + "//source/common/common:logger_lib", "@envoy_api//envoy/config/core/v3:pkg_cc_proto", ], ) diff --git a/source/common/http/alternate_protocols_cache_impl.cc b/source/common/http/alternate_protocols_cache_impl.cc index 89d31df615aaf..4368e9a778599 100644 --- a/source/common/http/alternate_protocols_cache_impl.cc +++ b/source/common/http/alternate_protocols_cache_impl.cc @@ -1,5 +1,7 @@ #include "common/http/alternate_protocols_cache_impl.h" +#include "common/common/logger.h" + namespace Envoy { namespace Http { @@ -9,14 +11,13 @@ AlternateProtocolsCacheImpl::AlternateProtocolsCacheImpl(TimeSource& time_source AlternateProtocolsCacheImpl::~AlternateProtocolsCacheImpl() = default; void AlternateProtocolsCacheImpl::setAlternatives(const Origin& origin, - const std::vector& protocols, - const MonotonicTime& expiration) { - Entry& entry = protocols_[origin]; - if (entry.protocols_ != protocols) { - entry.protocols_ = protocols; - } - if (entry.expiration_ != expiration) { - entry.expiration_ = expiration; + const std::vector& protocols) { + protocols_[origin] = protocols; + static const size_t max_protocols = 10; + if (protocols.size() > max_protocols) { + ENVOY_LOG_MISC(trace, "Too many alternate protocols: {}, truncating", protocols.size()); + std::vector& p = protocols_[origin]; + p.erase(p.begin() + max_protocols, p.end()); } } @@ -24,19 +25,24 @@ OptRef> AlternateProtocolsCacheImpl::findAlternatives(const Origin& origin) { auto entry_it = protocols_.find(origin); if (entry_it == protocols_.end()) { - return makeOptRefFromPtr>( - nullptr); + return makeOptRefFromPtr>(nullptr); } - const Entry& entry = entry_it->second; - if (time_source_.monotonicTime() > entry.expiration_) { - // Expire the entry. - // TODO(RyanTheOptimist): expire entries based on a timer. + std::vector& protocols = entry_it->second; + + const MonotonicTime now = time_source_.monotonicTime(); + protocols.erase(std::remove_if(protocols.begin(), protocols.end(), + [now](const AlternateProtocol& protocol) { + return (now > protocol.expiration_); + }), + protocols.end()); + + if (protocols.empty()) { protocols_.erase(entry_it); - return makeOptRefFromPtr>( - nullptr); + return makeOptRefFromPtr>(nullptr); } - return makeOptRef(entry.protocols_); + + return makeOptRef(const_cast&>(protocols)); } size_t AlternateProtocolsCacheImpl::size() const { return protocols_.size(); } diff --git a/source/common/http/alternate_protocols_cache_impl.h b/source/common/http/alternate_protocols_cache_impl.h index f72ec506dc2d4..a029a970c763f 100644 --- a/source/common/http/alternate_protocols_cache_impl.h +++ b/source/common/http/alternate_protocols_cache_impl.h @@ -22,23 +22,18 @@ class AlternateProtocolsCacheImpl : public AlternateProtocolsCache { ~AlternateProtocolsCacheImpl() override; // AlternateProtocolsCache - void setAlternatives(const Origin& origin, const std::vector& protocols, - const MonotonicTime& expiration) override; + void setAlternatives(const Origin& origin, + const std::vector& protocols) override; OptRef> findAlternatives(const Origin& origin) override; size_t size() const override; private: - struct Entry { - std::vector protocols_; - MonotonicTime expiration_; - }; - // Time source used to check expiration of entries. TimeSource& time_source_; // Map from hostname to list of alternate protocols. // TODO(RyanTheOptimist): Add a limit to the size of this map and evict based on usage. - std::map protocols_; + std::map> protocols_; }; } // namespace Http diff --git a/test/common/http/alternate_protocols_cache_impl_test.cc b/test/common/http/alternate_protocols_cache_impl_test.cc index dc0f95848d2b3..88347e98bac3e 100644 --- a/test/common/http/alternate_protocols_cache_impl_test.cc +++ b/test/common/http/alternate_protocols_cache_impl_test.cc @@ -23,29 +23,31 @@ class AlternateProtocolsCacheImplTest : public testing::Test, public Event::Test const std::string alpn1_ = "alpn1"; const std::string alpn2_ = "alpn2"; + const MonotonicTime expiration1_ = simTime().monotonicTime() + Seconds(5); + const MonotonicTime expiration2_ = simTime().monotonicTime() + Seconds(10); + const AlternateProtocolsCacheImpl::Origin origin1_ = {https_, hostname1_, port1_}; const AlternateProtocolsCacheImpl::Origin origin2_ = {https_, hostname2_, port2_}; - const AlternateProtocolsCacheImpl::AlternateProtocol protocol1_ = {alpn1_, hostname1_, port1_}; - const AlternateProtocolsCacheImpl::AlternateProtocol protocol2_ = {alpn2_, hostname2_, port2_}; + const AlternateProtocolsCacheImpl::AlternateProtocol protocol1_ = {alpn1_, hostname1_, port1_, + expiration1_}; + const AlternateProtocolsCacheImpl::AlternateProtocol protocol2_ = {alpn2_, hostname2_, port2_, + expiration2_}; const std::vector protocols1_ = {protocol1_}; const std::vector protocols2_ = {protocol2_}; - - const MonotonicTime expiration1_ = simTime().monotonicTime() + Seconds(5); - const MonotonicTime expiration2_ = simTime().monotonicTime() + Seconds(10); }; TEST_F(AlternateProtocolsCacheImplTest, Init) { EXPECT_EQ(0, protocols_.size()); } TEST_F(AlternateProtocolsCacheImplTest, SetAlternatives) { EXPECT_EQ(0, protocols_.size()); - protocols_.setAlternatives(origin1_, protocols1_, expiration1_); + protocols_.setAlternatives(origin1_, protocols1_); EXPECT_EQ(1, protocols_.size()); } TEST_F(AlternateProtocolsCacheImplTest, FindAlternatives) { - protocols_.setAlternatives(origin1_, protocols1_, expiration1_); + protocols_.setAlternatives(origin1_, protocols1_); OptRef> protocols = protocols_.findAlternatives(origin1_); ASSERT_TRUE(protocols.has_value()); @@ -53,8 +55,8 @@ TEST_F(AlternateProtocolsCacheImplTest, FindAlternatives) { } TEST_F(AlternateProtocolsCacheImplTest, FindAlternativesAfterReplacement) { - protocols_.setAlternatives(origin1_, protocols1_, expiration1_); - protocols_.setAlternatives(origin1_, protocols2_, expiration2_); + protocols_.setAlternatives(origin1_, protocols1_); + protocols_.setAlternatives(origin1_, protocols2_); OptRef> protocols = protocols_.findAlternatives(origin1_); ASSERT_TRUE(protocols.has_value()); @@ -63,8 +65,8 @@ TEST_F(AlternateProtocolsCacheImplTest, FindAlternativesAfterReplacement) { } TEST_F(AlternateProtocolsCacheImplTest, FindAlternativesForMultipleOrigins) { - protocols_.setAlternatives(origin1_, protocols1_, expiration1_); - protocols_.setAlternatives(origin2_, protocols2_, expiration2_); + protocols_.setAlternatives(origin1_, protocols1_); + protocols_.setAlternatives(origin2_, protocols2_); OptRef> protocols = protocols_.findAlternatives(origin1_); ASSERT_TRUE(protocols.has_value()); @@ -75,7 +77,7 @@ TEST_F(AlternateProtocolsCacheImplTest, FindAlternativesForMultipleOrigins) { } TEST_F(AlternateProtocolsCacheImplTest, FindAlternativesAfterExpiration) { - protocols_.setAlternatives(origin1_, protocols1_, expiration1_); + protocols_.setAlternatives(origin1_, protocols1_); simTime().setMonotonicTime(expiration1_ + Seconds(1)); OptRef> protocols = protocols_.findAlternatives(origin1_); @@ -83,6 +85,37 @@ TEST_F(AlternateProtocolsCacheImplTest, FindAlternativesAfterExpiration) { EXPECT_EQ(0, protocols_.size()); } +TEST_F(AlternateProtocolsCacheImplTest, FindAlternativesAfterPartialExpiration) { + protocols_.setAlternatives(origin1_, {protocol1_, protocol2_}); + simTime().setMonotonicTime(expiration1_ + Seconds(1)); + OptRef> protocols = + protocols_.findAlternatives(origin1_); + ASSERT_TRUE(protocols.has_value()); + EXPECT_EQ(protocols2_.size(), protocols->size()); + EXPECT_EQ(protocols2_, protocols.ref()); +} + +TEST_F(AlternateProtocolsCacheImplTest, FindAlternativesAfterTruncation) { + AlternateProtocolsCacheImpl::AlternateProtocol protocol = protocol1_; + + std::vector expected_protocols; + for (size_t i = 0; i < 10; ++i) { + protocol.port_++; + expected_protocols.push_back(protocol); + } + std::vector full_protocols = expected_protocols; + protocol.port_++; + full_protocols.push_back(protocol); + full_protocols.push_back(protocol); + + protocols_.setAlternatives(origin1_, full_protocols); + OptRef> protocols = + protocols_.findAlternatives(origin1_); + ASSERT_TRUE(protocols.has_value()); + EXPECT_EQ(10, protocols->size()); + EXPECT_EQ(expected_protocols, protocols.ref()); +} + } // namespace } // namespace Http } // namespace Envoy diff --git a/test/common/http/conn_pool_grid_test.cc b/test/common/http/conn_pool_grid_test.cc index e25551aaaf5dc..d52fd391620af 100644 --- a/test/common/http/conn_pool_grid_test.cc +++ b/test/common/http/conn_pool_grid_test.cc @@ -122,9 +122,8 @@ class ConnectivityGridTestBase : public Event::TestUsingSimulatedTime, public te void addHttp3AlternateProtocol() { AlternateProtocolsCacheImpl::Origin origin("https", "hostname", 9000); const std::vector protocols = { - {"h3-29", "", origin.port_}}; - alternate_protocols_->setAlternatives(origin, protocols, - simTime().monotonicTime() + Seconds(5)); + {"h3-29", "", origin.port_, simTime().monotonicTime() + Seconds(5)}}; + alternate_protocols_->setAlternatives(origin, protocols); } const Network::ConnectionSocket::OptionsSharedPtr socket_options_; @@ -513,8 +512,8 @@ TEST_F(ConnectivityGridWithAlternateProtocolsCacheImplTest, SuccessWithoutHttp3) TEST_F(ConnectivityGridWithAlternateProtocolsCacheImplTest, SuccessWithExpiredHttp3) { AlternateProtocolsCacheImpl::Origin origin("https", "hostname", 9000); const std::vector protocols = { - {"h3-29", "", origin.port_}}; - alternate_protocols_->setAlternatives(origin, protocols, simTime().monotonicTime() + Seconds(5)); + {"h3-29", "", origin.port_, simTime().monotonicTime() + Seconds(5)}}; + alternate_protocols_->setAlternatives(origin, protocols); simTime().setMonotonicTime(simTime().monotonicTime() + Seconds(10)); EXPECT_EQ(grid_.first(), nullptr); @@ -536,8 +535,8 @@ TEST_F(ConnectivityGridWithAlternateProtocolsCacheImplTest, SuccessWithExpiredHt TEST_F(ConnectivityGridWithAlternateProtocolsCacheImplTest, SuccessWithoutHttp3NoMatchingHostname) { AlternateProtocolsCacheImpl::Origin origin("https", "hostname", 9000); const std::vector protocols = { - {"h3-29", "otherhostname", origin.port_}}; - alternate_protocols_->setAlternatives(origin, protocols, simTime().monotonicTime() + Seconds(5)); + {"h3-29", "otherhostname", origin.port_, simTime().monotonicTime() + Seconds(5)}}; + alternate_protocols_->setAlternatives(origin, protocols); EXPECT_EQ(grid_.first(), nullptr); @@ -557,8 +556,8 @@ TEST_F(ConnectivityGridWithAlternateProtocolsCacheImplTest, SuccessWithoutHttp3N TEST_F(ConnectivityGridWithAlternateProtocolsCacheImplTest, SuccessWithoutHttp3NoMatchingPort) { AlternateProtocolsCacheImpl::Origin origin("https", "hostname", 9000); const std::vector protocols = { - {"h3-29", "", origin.port_ + 1}}; - alternate_protocols_->setAlternatives(origin, protocols, simTime().monotonicTime() + Seconds(5)); + {"h3-29", "", origin.port_ + 1, simTime().monotonicTime() + Seconds(5)}}; + alternate_protocols_->setAlternatives(origin, protocols); EXPECT_EQ(grid_.first(), nullptr); @@ -577,8 +576,8 @@ TEST_F(ConnectivityGridWithAlternateProtocolsCacheImplTest, SuccessWithoutHttp3N TEST_F(ConnectivityGridWithAlternateProtocolsCacheImplTest, SuccessWithoutHttp3NoMatchingAlpn) { AlternateProtocolsCacheImpl::Origin origin("https", "hostname", 9000); const std::vector protocols = { - {"http/2", "", origin.port_}}; - alternate_protocols_->setAlternatives(origin, protocols, simTime().monotonicTime() + Seconds(5)); + {"http/2", "", origin.port_, simTime().monotonicTime() + Seconds(5)}}; + alternate_protocols_->setAlternatives(origin, protocols); EXPECT_EQ(grid_.first(), nullptr);