Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion include/envoy/common/token_bucket.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace Envoy {
*/
class TokenBucket {
public:
virtual ~TokenBucket() {}
virtual ~TokenBucket() = default;

/**
* @param tokens supplies the number of tokens to be consumed.
Expand All @@ -32,6 +32,12 @@ class TokenBucket {
* returns the upper bound on the amount of time until a next token is available.
*/
virtual std::chrono::milliseconds nextTokenAvailable() PURE;

/**
* Reset the bucket with a specific number of tokens. Refill will begin again from the time that
* this routine is called.
*/
virtual void reset(uint64_t num_tokens) PURE;
};

typedef std::unique_ptr<TokenBucket> TokenBucketPtr;
Expand Down
6 changes: 6 additions & 0 deletions source/common/common/token_bucket_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,10 @@ std::chrono::milliseconds TokenBucketImpl::nextTokenAvailable() {
return std::chrono::milliseconds(static_cast<uint64_t>(std::ceil((1 / fill_rate_) * 1000)));
}

void TokenBucketImpl::reset(uint64_t num_tokens) {
ASSERT(num_tokens <= max_tokens_);
tokens_ = num_tokens;
last_fill_ = time_source_.monotonicTime();
}

} // namespace Envoy
1 change: 1 addition & 0 deletions source/common/common/token_bucket_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class TokenBucketImpl : public TokenBucket {
// TokenBucket
uint64_t consume(uint64_t tokens, bool allow_partial) override;
std::chrono::milliseconds nextTokenAvailable() override;
void reset(uint64_t num_tokens) override;

private:
const double max_tokens_;
Expand Down
10 changes: 10 additions & 0 deletions source/extensions/filters/http/fault/fault_filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,16 @@ void StreamRateLimiter::onTokenTimer() {
ENVOY_LOG(trace, "limiter: timer wakeup: buffered={}", buffer_.length());
Buffer::OwnedImpl data_to_write;

if (!saw_data_) {
Comment thread
mattklein123 marked this conversation as resolved.
// The first time we see any data on this stream (via writeData()), reset the number of tokens
// to 1. This will ensure that we start pacing the data at the desired rate (and don't send a
// full 1s of data right away which might not introduce enough delay for a stream that doesn't
// have enough data to span more than 1s of rate allowance). Once we reset, we will subsequently
// allow for bursting within the second to account for our data provider being bursty.
token_bucket_.reset(1);
saw_data_ = true;
}

// Compute the number of tokens needed (rounded up), try to obtain that many tickets, and then
// figure out how many bytes to write given the number of tokens we actually got.
const uint64_t tokens_needed =
Expand Down
1 change: 1 addition & 0 deletions source/extensions/filters/http/fault/fault_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class StreamRateLimiter : Logger::Loggable<Logger::Id::filter> {
const std::function<void()> continue_cb_;
TokenBucketImpl token_bucket_;
Event::TimerPtr token_timer_;
bool saw_data_{};
bool saw_end_stream_{};
bool saw_trailers_{};
Buffer::WatermarkBuffer buffer_;
Expand Down
8 changes: 8 additions & 0 deletions test/common/common/token_bucket_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,12 @@ TEST_F(TokenBucketImplTest, PartialConsumption) {
EXPECT_EQ(std::chrono::milliseconds(63), token_bucket.nextTokenAvailable());
}

// Test reset functionality.
TEST_F(TokenBucketImplTest, Reset) {
TokenBucketImpl token_bucket{16, time_system_, 16};
token_bucket.reset(1);
EXPECT_EQ(1, token_bucket.consume(2, true));
EXPECT_EQ(std::chrono::milliseconds(63), token_bucket.nextTokenAvailable());
}

} // namespace Envoy
42 changes: 19 additions & 23 deletions test/extensions/filters/http/fault/fault_filter_integration_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,15 @@ TEST_P(FaultIntegrationTestAllProtocols, ResponseRateLimitNoTrailers) {
codec_client_->makeHeaderOnlyRequest(default_request_headers_);
waitForNextUpstreamRequest();
upstream_request_->encodeHeaders(default_response_headers_, false);
Buffer::OwnedImpl data(std::string(1152, 'a'));
Buffer::OwnedImpl data(std::string(127, 'a'));
upstream_request_->encodeData(data, true);
decoder->waitForBodyData(1024);

// Advance time and wait for a tick worth of data.
simTime().sleep(std::chrono::milliseconds(63));
decoder->waitForBodyData(1088);
// Wait for a tick worth of data.
decoder->waitForBodyData(64);

// Advance time and wait for a tick worth of data and end stream.
// Wait for a tick worth of data and end stream.
simTime().sleep(std::chrono::milliseconds(63));
decoder->waitForBodyData(1152);
decoder->waitForBodyData(127);
decoder->waitForEndStream();

EXPECT_EQ(0UL, test_server_->counter("http.config_test.fault.delays_injected")->value());
Expand All @@ -110,13 +108,15 @@ TEST_P(FaultIntegrationTestAllProtocols, HeaderFaultConfig) {

// Verify response body throttling.
upstream_request_->encodeHeaders(default_response_headers_, false);
Buffer::OwnedImpl data(std::string(1025, 'a'));
Buffer::OwnedImpl data(std::string(128, 'a'));
upstream_request_->encodeData(data, true);
decoder->waitForBodyData(1024);

// Advance time and wait for a tick worth of data and end stream.
// Wait for a tick worth of data.
decoder->waitForBodyData(64);

// Wait for a tick worth of data and end stream.
simTime().sleep(std::chrono::milliseconds(63));
decoder->waitForBodyData(1025);
decoder->waitForBodyData(128);
decoder->waitForEndStream();

EXPECT_EQ(1UL, test_server_->counter("http.config_test.fault.delays_injected")->value());
Expand Down Expand Up @@ -149,17 +149,15 @@ TEST_P(FaultIntegrationTestHttp2, ResponseRateLimitTrailersBodyFlushed) {
codec_client_->makeHeaderOnlyRequest(default_request_headers_);
waitForNextUpstreamRequest();
upstream_request_->encodeHeaders(default_response_headers_, false);
Buffer::OwnedImpl data(std::string(1152, 'a'));
Buffer::OwnedImpl data(std::string(127, 'a'));
upstream_request_->encodeData(data, false);
decoder->waitForBodyData(1024);

// Advance time and wait for a tick worth of data.
simTime().sleep(std::chrono::milliseconds(63));
decoder->waitForBodyData(1088);
// Wait for a tick worth of data.
decoder->waitForBodyData(64);

// Advance time and wait for a tick worth of data.
simTime().sleep(std::chrono::milliseconds(63));
decoder->waitForBodyData(1152);
decoder->waitForBodyData(127);

// Send trailers and wait for end stream.
Http::TestHeaderMapImpl trailers{{"hello", "world"}};
Expand All @@ -179,19 +177,17 @@ TEST_P(FaultIntegrationTestHttp2, ResponseRateLimitTrailersBodyNotFlushed) {
codec_client_->makeHeaderOnlyRequest(default_request_headers_);
waitForNextUpstreamRequest();
upstream_request_->encodeHeaders(default_response_headers_, false);
Buffer::OwnedImpl data(std::string(1152, 'a'));
Buffer::OwnedImpl data(std::string(128, 'a'));
upstream_request_->encodeData(data, false);
Http::TestHeaderMapImpl trailers{{"hello", "world"}};
upstream_request_->encodeTrailers(trailers);
decoder->waitForBodyData(1024);

// Advance time and wait for a tick worth of data.
simTime().sleep(std::chrono::milliseconds(63));
decoder->waitForBodyData(1088);
// Wait for a tick worth of data.
decoder->waitForBodyData(64);

// Advance time and wait for a tick worth of data, trailers, and end stream.
simTime().sleep(std::chrono::milliseconds(63));
decoder->waitForBodyData(1152);
decoder->waitForBodyData(128);
decoder->waitForEndStream();
EXPECT_NE(nullptr, decoder->trailers());

Expand Down
15 changes: 8 additions & 7 deletions tools/check_spelling_pedantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,14 +525,15 @@ def execute(files, dictionary_file, fix):
DEBUG = args.debug
MARK = args.mark

target_paths = args.target_paths
if not target_paths:
exts = ['.cc', '.h', '.proto']
target_paths = []
paths = args.target_paths
if not paths:
paths = ['./api', './include', './source', './test']
for p in paths:
for root, _, files in os.walk(p):
target_paths += [os.path.join(root, f) for f in files if os.path.splitext(f)[1] in exts]

exts = ['.cc', '.h', '.proto']
target_paths = []
for p in paths:
for root, _, files in os.walk(p):
target_paths += [os.path.join(root, f) for f in files if os.path.splitext(f)[1] in exts]

rv = execute(target_paths, args.dictionary, args.operation_type == 'fix')

Expand Down