diff --git a/test/common/network/filter_manager_impl_test.cc b/test/common/network/filter_manager_impl_test.cc index 94201d2010b77..e3620a5c313bb 100644 --- a/test/common/network/filter_manager_impl_test.cc +++ b/test/common/network/filter_manager_impl_test.cc @@ -1,13 +1,23 @@ #include "common/buffer/buffer_impl.h" +#include "common/filter/tcp_proxy.h" +#include "common/filter/ratelimit.h" #include "common/network/filter_manager_impl.h" +#include "common/stats/stats_impl.h" +#include "common/upstream/upstream_impl.h" #include "test/mocks/buffer/mocks.h" #include "test/mocks/network/mocks.h" +#include "test/mocks/ratelimit/mocks.h" +#include "test/mocks/runtime/mocks.h" #include "test/mocks/upstream/host.h" +#include "test/mocks/upstream/mocks.h" +using testing::_; using testing::InSequence; +using testing::Invoke; using testing::NiceMock; using testing::Return; +using testing::WithArgs; namespace Network { @@ -78,4 +88,75 @@ TEST_F(NetworkFilterManagerTest, All) { manager.onWrite(); } +// This is a very important flow so make sure it works correctly in aggregate. +TEST_F(NetworkFilterManagerTest, RateLimitAndTcpProxy) { + InSequence s; + Stats::IsolatedStoreImpl stats_store; + NiceMock runtime; + NiceMock cm; + NiceMock connection; + FilterManagerImpl manager(connection, *this); + + std::string rl_json = R"EOF( + { + "domain": "foo", + "descriptors": [ + [{"key": "hello", "value": "world"}] + ], + "stat_prefix": "name" + } + )EOF"; + + ON_CALL(runtime.snapshot_, featureEnabled("ratelimit.tcp_filter_enabled", 100)) + .WillByDefault(Return(true)); + ON_CALL(runtime.snapshot_, featureEnabled("ratelimit.tcp_filter_enforcing", 100)) + .WillByDefault(Return(true)); + + Json::ObjectPtr rl_config_loader = Json::Factory::LoadFromString(rl_json); + + RateLimit::TcpFilter::ConfigPtr rl_config( + new RateLimit::TcpFilter::Config(*rl_config_loader, stats_store, runtime)); + RateLimit::MockClient* rl_client = new RateLimit::MockClient(); + manager.addReadFilter(ReadFilterPtr{ + new RateLimit::TcpFilter::Instance(rl_config, RateLimit::ClientPtr{rl_client})}); + + std::string tcp_proxy_json = R"EOF( + { + "cluster": "fake_cluster", + "stat_prefix": "name" + } + )EOF"; + + Json::ObjectPtr tcp_proxy_config_loader = Json::Factory::LoadFromString(tcp_proxy_json); + ::Filter::TcpProxyConfigPtr tcp_proxy_config( + new ::Filter::TcpProxyConfig(*tcp_proxy_config_loader, cm, stats_store)); + manager.addReadFilter(ReadFilterPtr{new ::Filter::TcpProxy(tcp_proxy_config, cm)}); + + RateLimit::RequestCallbacks* request_callbacks{}; + EXPECT_CALL( + *rl_client, + limit(_, "foo", + testing::ContainerEq(std::vector{{{{"hello", "world"}}}}), "")) + .WillOnce(WithArgs<0>(Invoke([&](RateLimit::RequestCallbacks& callbacks) + -> void { request_callbacks = &callbacks; }))); + + manager.initializeReadFilters(); + + NiceMock* upstream_connection = + new NiceMock(); + Upstream::MockHost::MockCreateConnectionData conn_info; + conn_info.connection_ = upstream_connection; + conn_info.host_.reset(new Upstream::HostImpl(cm.cluster_, "tcp://127.0.0.1:80", false, 1, "")); + EXPECT_CALL(cm, tcpConnForCluster_("fake_cluster")).WillOnce(Return(conn_info)); + + request_callbacks->complete(RateLimit::LimitStatus::OK); + + upstream_connection->raiseEvents(Network::ConnectionEvent::Connected); + + Buffer::OwnedImpl buffer("hello"); + EXPECT_CALL(*upstream_connection, write(BufferEqual(&buffer))); + read_buffer_.add("hello"); + manager.onRead(); +} + } // Network