diff --git a/include/envoy/common/token_bucket.h b/include/envoy/common/token_bucket.h index 4a8750fc013c9..2f04c7c8c461f 100644 --- a/include/envoy/common/token_bucket.h +++ b/include/envoy/common/token_bucket.h @@ -15,7 +15,7 @@ namespace Envoy { */ class TokenBucket { public: - virtual ~TokenBucket() {} + virtual ~TokenBucket() = default; /** * @param tokens supplies the number of tokens to be consumed. @@ -32,6 +32,12 @@ class TokenBucket { * returns the upper bound on the amount of time until a next token is available. */ virtual std::chrono::milliseconds nextTokenAvailable() PURE; + + /** + * Reset the bucket with a specific number of tokens. Refill will begin again from the time that + * this routine is called. + */ + virtual void reset(uint64_t num_tokens) PURE; }; typedef std::unique_ptr TokenBucketPtr; diff --git a/source/common/common/token_bucket_impl.cc b/source/common/common/token_bucket_impl.cc index bf93dc32f447c..5e7de9e6bb1a7 100644 --- a/source/common/common/token_bucket_impl.cc +++ b/source/common/common/token_bucket_impl.cc @@ -38,4 +38,10 @@ std::chrono::milliseconds TokenBucketImpl::nextTokenAvailable() { return std::chrono::milliseconds(static_cast(std::ceil((1 / fill_rate_) * 1000))); } +void TokenBucketImpl::reset(uint64_t num_tokens) { + ASSERT(num_tokens <= max_tokens_); + tokens_ = num_tokens; + last_fill_ = time_source_.monotonicTime(); +} + } // namespace Envoy diff --git a/source/common/common/token_bucket_impl.h b/source/common/common/token_bucket_impl.h index 7daa3fb8e79b0..644a4185dd5ab 100644 --- a/source/common/common/token_bucket_impl.h +++ b/source/common/common/token_bucket_impl.h @@ -23,6 +23,7 @@ class TokenBucketImpl : public TokenBucket { // TokenBucket uint64_t consume(uint64_t tokens, bool allow_partial) override; std::chrono::milliseconds nextTokenAvailable() override; + void reset(uint64_t num_tokens) override; private: const double max_tokens_; diff --git a/source/extensions/filters/http/fault/fault_filter.cc b/source/extensions/filters/http/fault/fault_filter.cc index c9502ff938c56..3d9acf69f94ae 100644 --- a/source/extensions/filters/http/fault/fault_filter.cc +++ b/source/extensions/filters/http/fault/fault_filter.cc @@ -416,6 +416,16 @@ void StreamRateLimiter::onTokenTimer() { ENVOY_LOG(trace, "limiter: timer wakeup: buffered={}", buffer_.length()); Buffer::OwnedImpl data_to_write; + if (!saw_data_) { + // The first time we see any data on this stream (via writeData()), reset the number of tokens + // to 1. This will ensure that we start pacing the data at the desired rate (and don't send a + // full 1s of data right away which might not introduce enough delay for a stream that doesn't + // have enough data to span more than 1s of rate allowance). Once we reset, we will subsequently + // allow for bursting within the second to account for our data provider being bursty. + token_bucket_.reset(1); + saw_data_ = true; + } + // Compute the number of tokens needed (rounded up), try to obtain that many tickets, and then // figure out how many bytes to write given the number of tokens we actually got. const uint64_t tokens_needed = diff --git a/source/extensions/filters/http/fault/fault_filter.h b/source/extensions/filters/http/fault/fault_filter.h index fcf78ca15d003..2b535edd666e0 100644 --- a/source/extensions/filters/http/fault/fault_filter.h +++ b/source/extensions/filters/http/fault/fault_filter.h @@ -151,6 +151,7 @@ class StreamRateLimiter : Logger::Loggable { const std::function continue_cb_; TokenBucketImpl token_bucket_; Event::TimerPtr token_timer_; + bool saw_data_{}; bool saw_end_stream_{}; bool saw_trailers_{}; Buffer::WatermarkBuffer buffer_; diff --git a/test/common/common/token_bucket_impl_test.cc b/test/common/common/token_bucket_impl_test.cc index 4a44acd847016..aec4744bc83eb 100644 --- a/test/common/common/token_bucket_impl_test.cc +++ b/test/common/common/token_bucket_impl_test.cc @@ -85,4 +85,12 @@ TEST_F(TokenBucketImplTest, PartialConsumption) { EXPECT_EQ(std::chrono::milliseconds(63), token_bucket.nextTokenAvailable()); } +// Test reset functionality. +TEST_F(TokenBucketImplTest, Reset) { + TokenBucketImpl token_bucket{16, time_system_, 16}; + token_bucket.reset(1); + EXPECT_EQ(1, token_bucket.consume(2, true)); + EXPECT_EQ(std::chrono::milliseconds(63), token_bucket.nextTokenAvailable()); +} + } // namespace Envoy diff --git a/test/extensions/filters/http/fault/fault_filter_integration_test.cc b/test/extensions/filters/http/fault/fault_filter_integration_test.cc index 3aabd78412062..e4527d30a41a9 100644 --- a/test/extensions/filters/http/fault/fault_filter_integration_test.cc +++ b/test/extensions/filters/http/fault/fault_filter_integration_test.cc @@ -74,17 +74,15 @@ TEST_P(FaultIntegrationTestAllProtocols, ResponseRateLimitNoTrailers) { codec_client_->makeHeaderOnlyRequest(default_request_headers_); waitForNextUpstreamRequest(); upstream_request_->encodeHeaders(default_response_headers_, false); - Buffer::OwnedImpl data(std::string(1152, 'a')); + Buffer::OwnedImpl data(std::string(127, 'a')); upstream_request_->encodeData(data, true); - decoder->waitForBodyData(1024); - // Advance time and wait for a tick worth of data. - simTime().sleep(std::chrono::milliseconds(63)); - decoder->waitForBodyData(1088); + // Wait for a tick worth of data. + decoder->waitForBodyData(64); - // Advance time and wait for a tick worth of data and end stream. + // Wait for a tick worth of data and end stream. simTime().sleep(std::chrono::milliseconds(63)); - decoder->waitForBodyData(1152); + decoder->waitForBodyData(127); decoder->waitForEndStream(); EXPECT_EQ(0UL, test_server_->counter("http.config_test.fault.delays_injected")->value()); @@ -110,13 +108,15 @@ TEST_P(FaultIntegrationTestAllProtocols, HeaderFaultConfig) { // Verify response body throttling. upstream_request_->encodeHeaders(default_response_headers_, false); - Buffer::OwnedImpl data(std::string(1025, 'a')); + Buffer::OwnedImpl data(std::string(128, 'a')); upstream_request_->encodeData(data, true); - decoder->waitForBodyData(1024); - // Advance time and wait for a tick worth of data and end stream. + // Wait for a tick worth of data. + decoder->waitForBodyData(64); + + // Wait for a tick worth of data and end stream. simTime().sleep(std::chrono::milliseconds(63)); - decoder->waitForBodyData(1025); + decoder->waitForBodyData(128); decoder->waitForEndStream(); EXPECT_EQ(1UL, test_server_->counter("http.config_test.fault.delays_injected")->value()); @@ -149,17 +149,15 @@ TEST_P(FaultIntegrationTestHttp2, ResponseRateLimitTrailersBodyFlushed) { codec_client_->makeHeaderOnlyRequest(default_request_headers_); waitForNextUpstreamRequest(); upstream_request_->encodeHeaders(default_response_headers_, false); - Buffer::OwnedImpl data(std::string(1152, 'a')); + Buffer::OwnedImpl data(std::string(127, 'a')); upstream_request_->encodeData(data, false); - decoder->waitForBodyData(1024); - // Advance time and wait for a tick worth of data. - simTime().sleep(std::chrono::milliseconds(63)); - decoder->waitForBodyData(1088); + // Wait for a tick worth of data. + decoder->waitForBodyData(64); // Advance time and wait for a tick worth of data. simTime().sleep(std::chrono::milliseconds(63)); - decoder->waitForBodyData(1152); + decoder->waitForBodyData(127); // Send trailers and wait for end stream. Http::TestHeaderMapImpl trailers{{"hello", "world"}}; @@ -179,19 +177,17 @@ TEST_P(FaultIntegrationTestHttp2, ResponseRateLimitTrailersBodyNotFlushed) { codec_client_->makeHeaderOnlyRequest(default_request_headers_); waitForNextUpstreamRequest(); upstream_request_->encodeHeaders(default_response_headers_, false); - Buffer::OwnedImpl data(std::string(1152, 'a')); + Buffer::OwnedImpl data(std::string(128, 'a')); upstream_request_->encodeData(data, false); Http::TestHeaderMapImpl trailers{{"hello", "world"}}; upstream_request_->encodeTrailers(trailers); - decoder->waitForBodyData(1024); - // Advance time and wait for a tick worth of data. - simTime().sleep(std::chrono::milliseconds(63)); - decoder->waitForBodyData(1088); + // Wait for a tick worth of data. + decoder->waitForBodyData(64); // Advance time and wait for a tick worth of data, trailers, and end stream. simTime().sleep(std::chrono::milliseconds(63)); - decoder->waitForBodyData(1152); + decoder->waitForBodyData(128); decoder->waitForEndStream(); EXPECT_NE(nullptr, decoder->trailers());