diff --git a/api/envoy/config/route/v3/route_components.proto b/api/envoy/config/route/v3/route_components.proto index c1bc1ace91302..1dac8a3250b00 100644 --- a/api/envoy/config/route/v3/route_components.proto +++ b/api/envoy/config/route/v3/route_components.proto @@ -1516,6 +1516,7 @@ message VirtualCluster { } // Global rate limiting :ref:`architecture overview `. +// Also applies to Local rate limiting :ref:`using descriptors `. message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit"; diff --git a/api/envoy/config/route/v4alpha/route_components.proto b/api/envoy/config/route/v4alpha/route_components.proto index 4083513009d77..b60300845ef42 100644 --- a/api/envoy/config/route/v4alpha/route_components.proto +++ b/api/envoy/config/route/v4alpha/route_components.proto @@ -1465,6 +1465,7 @@ message VirtualCluster { } // Global rate limiting :ref:`architecture overview `. +// Also applies to Local rate limiting :ref:`using descriptors `. message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.config.route.v3.RateLimit"; diff --git a/api/envoy/extensions/common/ratelimit/v3/ratelimit.proto b/api/envoy/extensions/common/ratelimit/v3/ratelimit.proto index 30efa60262187..6bb771d25af94 100644 --- a/api/envoy/extensions/common/ratelimit/v3/ratelimit.proto +++ b/api/envoy/extensions/common/ratelimit/v3/ratelimit.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.extensions.common.ratelimit.v3; import "envoy/type/v3/ratelimit_unit.proto"; +import "envoy/type/v3/token_bucket.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; @@ -92,3 +93,11 @@ message RateLimitDescriptor { // Optional rate limit override to supply to the ratelimit service. RateLimitOverride limit = 2; } + +message LocalRateLimitDescriptor { + // Descriptor entries. + repeated v3.RateLimitDescriptor.Entry entries = 1 [(validate.rules).repeated = {min_items: 1}]; + + // Token Bucket algorithm for local ratelimiting. + type.v3.TokenBucket token_bucket = 2 [(validate.rules).message = {required: true}]; +} diff --git a/api/envoy/extensions/filters/http/local_ratelimit/v3/BUILD b/api/envoy/extensions/filters/http/local_ratelimit/v3/BUILD index ad2fc9a9a84fd..6c58a43e4ff6b 100644 --- a/api/envoy/extensions/filters/http/local_ratelimit/v3/BUILD +++ b/api/envoy/extensions/filters/http/local_ratelimit/v3/BUILD @@ -7,6 +7,7 @@ licenses(["notice"]) # Apache 2 api_proto_package( deps = [ "//envoy/config/core/v3:pkg", + "//envoy/extensions/common/ratelimit/v3:pkg", "//envoy/type/v3:pkg", "@com_github_cncf_udpa//udpa/annotations:pkg", ], diff --git a/api/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto b/api/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto index 94f21edd3eedb..a7d1592746fde 100644 --- a/api/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto +++ b/api/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.extensions.filters.http.local_ratelimit.v3; import "envoy/config/core/v3/base.proto"; +import "envoy/extensions/common/ratelimit/v3/ratelimit.proto"; import "envoy/type/v3/http_status.proto"; import "envoy/type/v3/token_bucket.proto"; @@ -19,7 +20,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Local Rate limit :ref:`configuration overview `. // [#extension: envoy.filters.http.local_ratelimit] -// [#next-free-field: 7] +// [#next-free-field: 10] message LocalRateLimit { // The human readable prefix to use when emitting stats. string stat_prefix = 1 [(validate.rules).string = {min_len: 1}]; @@ -67,4 +68,23 @@ message LocalRateLimit { // have been rate limited. repeated config.core.v3.HeaderValueOption response_headers_to_add = 6 [(validate.rules).repeated = {max_items: 10}]; + + // The rate limit descriptor list to use in the local rate limit to override + // on. + // Example on how to use ::ref:`this ` + // + // .. note:: + // + // In the current implementation the descriptor's token bucket :ref:`fill_interval + // ` should be a multiple + // global :ref:`token bucket's` fill interval. + repeated common.ratelimit.v3.LocalRateLimitDescriptor descriptors = 8; + + // Specifies the rate limit configurations to be applied with the same + // stage number. If not set, the default stage number is 0. + // + // .. note:: + // + // The filter supports a range of 0 - 10 inclusively for stage numbers. + uint32 stage = 9 [(validate.rules).uint32 = {lte: 10}]; } diff --git a/docs/root/api-v3/config/common/common.rst b/docs/root/api-v3/config/common/common.rst index bb6965a5f1497..f286ba06c4e93 100644 --- a/docs/root/api-v3/config/common/common.rst +++ b/docs/root/api-v3/config/common/common.rst @@ -8,3 +8,4 @@ Common matcher/v3/* ../../extensions/common/dynamic_forward_proxy/v3/* ../../extensions/common/tap/v3/* + ../../extensions/common/ratelimit/v3/* diff --git a/docs/root/configuration/http/http_filters/local_rate_limit_filter.rst b/docs/root/configuration/http/http_filters/local_rate_limit_filter.rst index 78bbc806a78ef..a02b8d2d719bc 100644 --- a/docs/root/configuration/http/http_filters/local_rate_limit_filter.rst +++ b/docs/root/configuration/http/http_filters/local_rate_limit_filter.rst @@ -103,6 +103,89 @@ The route specific configuration: Note that if this filter is configured as globally disabled and there are no virtual host or route level token buckets, no rate limiting will be applied. +.. _config_http_filters_local_rate_limit_descriptors: + +Using Descriptors to rate limit on +---------------------------------- + +Descriptors can be used to override local rate limiting based on presence of certain descriptors/route actions. +A route's :ref:`rate limit action ` is used to match up a +:ref:`local descriptor ` in the filter config descriptor list. +The local descriptor's token bucket config is used to decide if the request should be +rate limited or not, if the local descriptor's entries match the route's rate limit actions descriptor entries. +Otherwise the default token bucket config is used. + +Example filter configuration using descriptors is as follows: + +.. validated-code-block:: yaml + :type-name: envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager + + route_config: + name: local_route + virtual_hosts: + - name: local_service + domains: ["*"] + routes: + - match: { prefix: "/foo" } + route: { cluster: service_protected_by_rate_limit } + typed_per_filter_config: + envoy.filters.http.local_ratelimit: + "@type": type.googleapis.com/envoy.extensions.filters.http.local_ratelimit.v3.LocalRateLimit + stat_prefix: test + token_bucket: + max_tokens: 1000 + tokens_per_fill: 1000 + fill_interval: 60s + filter_enabled: + runtime_key: test_enabled + default_value: + numerator: 100 + denominator: HUNDRED + filter_enforced: + runtime_key: test_enforced + default_value: + numerator: 100 + denominator: HUNDRED + response_headers_to_add: + - append: false + header: + key: x-test-rate-limit + value: 'true' + descriptors: + - entries: + - key: client_id + value: foo + - key: path + value: /foo/bar + token_bucket: + max_tokens: 10 + tokens_per_fill: 10 + fill_interval: 60s + - entries: + - key: client_id + value: foo + - key: path + value: /foo/bar2 + token_bucket: + max_tokens: 100 + tokens_per_fill: 100 + fill_interval: 60s + - match: { prefix: "/" } + route: { cluster: default_service } + rate_limits: + - actions: # any actions in here + - request_headers: + header_name: ":path" + descriptor_key: "path" + - generic_key: + descriptor_value: "foo" + descriptor_key: "client_id" + +For this config, requests are ratelimited for routes prefixed with "/foo" +In that, if requests come from client_id "foo" for "/foo/bar" path, then 10 req/min are allowed. +But if they come from client_id "foo" for "/foo/bar2" path, then 100 req/min are allowed. +Otherwise 1000 req/min are allowed. + Statistics ---------- diff --git a/generated_api_shadow/envoy/config/route/v3/route_components.proto b/generated_api_shadow/envoy/config/route/v3/route_components.proto index 89c0a1c76cf71..71fd35be4a8d1 100644 --- a/generated_api_shadow/envoy/config/route/v3/route_components.proto +++ b/generated_api_shadow/envoy/config/route/v3/route_components.proto @@ -1528,6 +1528,7 @@ message VirtualCluster { } // Global rate limiting :ref:`architecture overview `. +// Also applies to Local rate limiting :ref:`using descriptors `. message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit"; diff --git a/generated_api_shadow/envoy/config/route/v4alpha/route_components.proto b/generated_api_shadow/envoy/config/route/v4alpha/route_components.proto index b67c4efa39526..37cf45dc963b6 100644 --- a/generated_api_shadow/envoy/config/route/v4alpha/route_components.proto +++ b/generated_api_shadow/envoy/config/route/v4alpha/route_components.proto @@ -1532,6 +1532,7 @@ message VirtualCluster { } // Global rate limiting :ref:`architecture overview `. +// Also applies to Local rate limiting :ref:`using descriptors `. message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.config.route.v3.RateLimit"; diff --git a/generated_api_shadow/envoy/extensions/common/ratelimit/v3/ratelimit.proto b/generated_api_shadow/envoy/extensions/common/ratelimit/v3/ratelimit.proto index 30efa60262187..6bb771d25af94 100644 --- a/generated_api_shadow/envoy/extensions/common/ratelimit/v3/ratelimit.proto +++ b/generated_api_shadow/envoy/extensions/common/ratelimit/v3/ratelimit.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.extensions.common.ratelimit.v3; import "envoy/type/v3/ratelimit_unit.proto"; +import "envoy/type/v3/token_bucket.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; @@ -92,3 +93,11 @@ message RateLimitDescriptor { // Optional rate limit override to supply to the ratelimit service. RateLimitOverride limit = 2; } + +message LocalRateLimitDescriptor { + // Descriptor entries. + repeated v3.RateLimitDescriptor.Entry entries = 1 [(validate.rules).repeated = {min_items: 1}]; + + // Token Bucket algorithm for local ratelimiting. + type.v3.TokenBucket token_bucket = 2 [(validate.rules).message = {required: true}]; +} diff --git a/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/BUILD b/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/BUILD index ad2fc9a9a84fd..6c58a43e4ff6b 100644 --- a/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/BUILD +++ b/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/BUILD @@ -7,6 +7,7 @@ licenses(["notice"]) # Apache 2 api_proto_package( deps = [ "//envoy/config/core/v3:pkg", + "//envoy/extensions/common/ratelimit/v3:pkg", "//envoy/type/v3:pkg", "@com_github_cncf_udpa//udpa/annotations:pkg", ], diff --git a/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto b/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto index 94f21edd3eedb..a7d1592746fde 100644 --- a/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto +++ b/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.extensions.filters.http.local_ratelimit.v3; import "envoy/config/core/v3/base.proto"; +import "envoy/extensions/common/ratelimit/v3/ratelimit.proto"; import "envoy/type/v3/http_status.proto"; import "envoy/type/v3/token_bucket.proto"; @@ -19,7 +20,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Local Rate limit :ref:`configuration overview `. // [#extension: envoy.filters.http.local_ratelimit] -// [#next-free-field: 7] +// [#next-free-field: 10] message LocalRateLimit { // The human readable prefix to use when emitting stats. string stat_prefix = 1 [(validate.rules).string = {min_len: 1}]; @@ -67,4 +68,23 @@ message LocalRateLimit { // have been rate limited. repeated config.core.v3.HeaderValueOption response_headers_to_add = 6 [(validate.rules).repeated = {max_items: 10}]; + + // The rate limit descriptor list to use in the local rate limit to override + // on. + // Example on how to use ::ref:`this ` + // + // .. note:: + // + // In the current implementation the descriptor's token bucket :ref:`fill_interval + // ` should be a multiple + // global :ref:`token bucket's` fill interval. + repeated common.ratelimit.v3.LocalRateLimitDescriptor descriptors = 8; + + // Specifies the rate limit configurations to be applied with the same + // stage number. If not set, the default stage number is 0. + // + // .. note:: + // + // The filter supports a range of 0 - 10 inclusively for stage numbers. + uint32 stage = 9 [(validate.rules).uint32 = {lte: 10}]; } diff --git a/include/envoy/ratelimit/ratelimit.h b/include/envoy/ratelimit/ratelimit.h index f23c8170ef684..73bdd23074bdf 100644 --- a/include/envoy/ratelimit/ratelimit.h +++ b/include/envoy/ratelimit/ratelimit.h @@ -5,6 +5,7 @@ #include "envoy/type/v3/ratelimit_unit.pb.h" +#include "absl/time/time.h" #include "absl/types/optional.h" namespace Envoy { @@ -24,6 +25,10 @@ struct RateLimitOverride { struct DescriptorEntry { std::string key_; std::string value_; + + friend bool operator==(const DescriptorEntry& lhs, const DescriptorEntry& rhs) { + return lhs.key_ == rhs.key_ && lhs.value_ == rhs.value_; + } }; /** @@ -34,5 +39,48 @@ struct Descriptor { absl::optional limit_ = absl::nullopt; }; +/** + * A single token bucket. See token_bucket.proto. + */ +struct TokenBucket { + uint32_t max_tokens_; + uint32_t tokens_per_fill_; + absl::Duration fill_interval_; + + friend bool operator==(const TokenBucket& lhs, const TokenBucket& rhs) { + return lhs.max_tokens_ == rhs.max_tokens_ && lhs.tokens_per_fill_ == rhs.tokens_per_fill_ && + lhs.fill_interval_ == rhs.fill_interval_; + } + + // Support absl::Hash. + template + friend H AbslHashValue(H h, const TokenBucket& d) { // NOLINT(readability-identifier-naming) + h = H::combine(std::move(h), d.max_tokens_, d.tokens_per_fill_, d.fill_interval_); + return h; + } +}; + +/** + * A single rate limit request descriptor. See ratelimit.proto. + */ +struct LocalDescriptor { + std::vector entries_; + TokenBucket token_bucket_; + + friend bool operator==(const LocalDescriptor& lhs, const LocalDescriptor& rhs) { + return lhs.entries_ == rhs.entries_ && lhs.token_bucket_ == rhs.token_bucket_; + } + + // Support absl::Hash. + template + friend H AbslHashValue(H h, const LocalDescriptor& d) { // NOLINT(readability-identifier-naming) + for (const auto& entry : d.entries_) { + h = H::combine(std::move(h), entry.key_, entry.value_); + } + h = H::combine(std::move(h), d.token_bucket_); + return h; + } +}; + } // namespace RateLimit } // namespace Envoy diff --git a/include/envoy/router/router_ratelimit.h b/include/envoy/router/router_ratelimit.h index 1e6910c3b9ba2..41b00a1914c4d 100644 --- a/include/envoy/router/router_ratelimit.h +++ b/include/envoy/router/router_ratelimit.h @@ -40,7 +40,7 @@ class RateLimitAction { virtual ~RateLimitAction() = default; /** - * Potentially append a descriptor entry to the end of descriptor. + * Potentially fill a descriptor entry to the end of descriptor. * @param route supplies the target route for the request. * @param descriptor supplies the descriptor to optionally fill. * @param local_service_cluster supplies the name of the local service cluster. @@ -50,7 +50,7 @@ class RateLimitAction { * @return true if the RateLimitAction populated the descriptor. */ virtual bool - populateDescriptor(const RouteEntry& route, RateLimit::Descriptor& descriptor, + populateDescriptor(const RouteEntry& route, RateLimit::DescriptorEntry& descriptor, const std::string& local_service_cluster, const Http::HeaderMap& headers, const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata* dynamic_metadata) const PURE; @@ -89,6 +89,22 @@ class RateLimitPolicyEntry { const std::string& local_service_cluster, const Http::HeaderMap& headers, const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata* dynamic_metadata) const PURE; + + /** + * Potentially populate the local descriptor array with new local descriptors to query. + * @param route supplies the target route for the request. + * @param descriptors supplies the descriptor array to optionally fill. + * @param local_service_cluster supplies the name of the local service cluster. + * @param headers supplies the header for the request. + * @param remote_address supplies the trusted downstream address for the connection. + * @param dynamic_metadata supplies the dynamic metadata for the request. + */ + virtual void + populateLocalDescriptors(const RouteEntry& route, + std::vector& descriptors, + const std::string& local_service_cluster, const Http::HeaderMap& headers, + const Network::Address::Instance& remote_address, + const envoy::config::core::v3::Metadata* dynamic_metadata) const PURE; }; /** diff --git a/source/common/router/router_ratelimit.cc b/source/common/router/router_ratelimit.cc index 0774f8340be5f..7167cd1e60997 100644 --- a/source/common/router/router_ratelimit.cc +++ b/source/common/router/router_ratelimit.cc @@ -16,6 +16,29 @@ namespace Envoy { namespace Router { +namespace { +bool populateDescriptor(const std::vector& actions, + std::vector& descriptor_entries, + const Router::RouteEntry& route, const std::string& local_service_cluster, + const Http::HeaderMap& headers, + const Network::Address::Instance& remote_address, + const envoy::config::core::v3::Metadata* dynamic_metadata) { + bool result = true; + for (const RateLimitActionPtr& action : actions) { + RateLimit::DescriptorEntry descriptor_entry; + result = result && action->populateDescriptor(route, descriptor_entry, local_service_cluster, + headers, remote_address, dynamic_metadata); + if (!result) { + break; + } + if (!descriptor_entry.key_.empty()) { + descriptor_entries.push_back(descriptor_entry); + } + } + return result; +} +} // namespace + const uint64_t RateLimitPolicyImpl::MAX_STAGE_NUMBER = 10UL; bool DynamicMetadataRateLimitOverride::populateOverride( @@ -44,27 +67,29 @@ bool DynamicMetadataRateLimitOverride::populateOverride( } bool SourceClusterAction::populateDescriptor(const Router::RouteEntry&, - RateLimit::Descriptor& descriptor, + RateLimit::DescriptorEntry& descriptor_entry, const std::string& local_service_cluster, const Http::HeaderMap&, const Network::Address::Instance&, const envoy::config::core::v3::Metadata*) const { - descriptor.entries_.push_back({"source_cluster", local_service_cluster}); + descriptor_entry.key_ = "source_cluster"; + descriptor_entry.value_ = local_service_cluster; return true; } bool DestinationClusterAction::populateDescriptor(const Router::RouteEntry& route, - RateLimit::Descriptor& descriptor, + RateLimit::DescriptorEntry& descriptor_entry, const std::string&, const Http::HeaderMap&, const Network::Address::Instance&, const envoy::config::core::v3::Metadata*) const { - descriptor.entries_.push_back({"destination_cluster", route.clusterName()}); + descriptor_entry.key_ = "destination_cluster"; + descriptor_entry.value_ = route.clusterName(); return true; } bool RequestHeadersAction::populateDescriptor(const Router::RouteEntry&, - RateLimit::Descriptor& descriptor, const std::string&, - const Http::HeaderMap& headers, + RateLimit::DescriptorEntry& descriptor_entry, + const std::string&, const Http::HeaderMap& headers, const Network::Address::Instance&, const envoy::config::core::v3::Metadata*) const { const auto header_value = headers.get(header_name_); @@ -76,29 +101,32 @@ bool RequestHeadersAction::populateDescriptor(const Router::RouteEntry&, return skip_if_absent_; } // TODO(https://github.com/envoyproxy/envoy/issues/13454): Potentially populate all header values. - descriptor.entries_.push_back( - {descriptor_key_, std::string(header_value[0]->value().getStringView())}); + descriptor_entry.key_ = descriptor_key_; + descriptor_entry.value_ = std::string(header_value[0]->value().getStringView()); return true; } bool RemoteAddressAction::populateDescriptor(const Router::RouteEntry&, - RateLimit::Descriptor& descriptor, const std::string&, - const Http::HeaderMap&, + RateLimit::DescriptorEntry& descriptor_entry, + const std::string&, const Http::HeaderMap&, const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata*) const { if (remote_address.type() != Network::Address::Type::Ip) { return false; } - descriptor.entries_.push_back({"remote_address", remote_address.ip()->addressAsString()}); + descriptor_entry.key_ = "remote_address"; + descriptor_entry.value_ = remote_address.ip()->addressAsString(); return true; } bool GenericKeyAction::populateDescriptor(const Router::RouteEntry&, - RateLimit::Descriptor& descriptor, const std::string&, - const Http::HeaderMap&, const Network::Address::Instance&, + RateLimit::DescriptorEntry& descriptor_entry, + const std::string&, const Http::HeaderMap&, + const Network::Address::Instance&, const envoy::config::core::v3::Metadata*) const { - descriptor.entries_.push_back({descriptor_key_, descriptor_value_}); + descriptor_entry.key_ = descriptor_key_; + descriptor_entry.value_ = descriptor_value_; return true; } @@ -113,8 +141,8 @@ MetaDataAction::MetaDataAction( source_(envoy::config::route::v3::RateLimit::Action::MetaData::DYNAMIC) {} bool MetaDataAction::populateDescriptor( - const Router::RouteEntry& route, RateLimit::Descriptor& descriptor, const std::string&, - const Http::HeaderMap&, const Network::Address::Instance&, + const Router::RouteEntry& route, RateLimit::DescriptorEntry& descriptor_entry, + const std::string&, const Http::HeaderMap&, const Network::Address::Instance&, const envoy::config::core::v3::Metadata* dynamic_metadata) const { const envoy::config::core::v3::Metadata* metadata_source; @@ -133,10 +161,12 @@ bool MetaDataAction::populateDescriptor( Envoy::Config::Metadata::metadataValue(metadata_source, metadata_key_).string_value(); if (!metadata_string_value.empty()) { - descriptor.entries_.push_back({descriptor_key_, metadata_string_value}); + descriptor_entry.key_ = descriptor_key_; + descriptor_entry.value_ = metadata_string_value; return true; } else if (metadata_string_value.empty() && !default_value_.empty()) { - descriptor.entries_.push_back({descriptor_key_, default_value_}); + descriptor_entry.key_ = descriptor_key_; + descriptor_entry.value_ = default_value_; return true; } @@ -150,12 +180,13 @@ HeaderValueMatchAction::HeaderValueMatchAction( action_headers_(Http::HeaderUtility::buildHeaderDataVector(action.headers())) {} bool HeaderValueMatchAction::populateDescriptor(const Router::RouteEntry&, - RateLimit::Descriptor& descriptor, + RateLimit::DescriptorEntry& descriptor_entry, const std::string&, const Http::HeaderMap& headers, const Network::Address::Instance&, const envoy::config::core::v3::Metadata*) const { if (expect_match_ == Http::HeaderUtility::matchHeaders(headers, action_headers_)) { - descriptor.entries_.push_back({"header_match", descriptor_value_}); + descriptor_entry.key_ = "header_match"; + descriptor_entry.value_ = descriptor_value_; return true; } else { return false; @@ -214,14 +245,8 @@ void RateLimitPolicyEntryImpl::populateDescriptors( const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata* dynamic_metadata) const { RateLimit::Descriptor descriptor; - bool result = true; - for (const RateLimitActionPtr& action : actions_) { - result = result && action->populateDescriptor(route, descriptor, local_service_cluster, headers, - remote_address, dynamic_metadata); - if (!result) { - break; - } - } + bool result = populateDescriptor(actions_, descriptor.entries_, route, local_service_cluster, + headers, remote_address, dynamic_metadata); if (limit_override_) { limit_override_.value()->populateOverride(descriptor, dynamic_metadata); @@ -232,6 +257,20 @@ void RateLimitPolicyEntryImpl::populateDescriptors( } } +void RateLimitPolicyEntryImpl::populateLocalDescriptors( + const Router::RouteEntry& route, std::vector& descriptors, + const std::string& local_service_cluster, const Http::HeaderMap& headers, + const Network::Address::Instance& remote_address, + const envoy::config::core::v3::Metadata* dynamic_metadata) const { + RateLimit::LocalDescriptor descriptor; + descriptor.token_bucket_ = {}; + + if (populateDescriptor(actions_, descriptor.entries_, route, local_service_cluster, headers, + remote_address, dynamic_metadata)) { + descriptors.emplace_back(descriptor); + } +} + RateLimitPolicyImpl::RateLimitPolicyImpl( const Protobuf::RepeatedPtrField& rate_limits) : rate_limit_entries_reference_(RateLimitPolicyImpl::MAX_STAGE_NUMBER + 1) { diff --git a/source/common/router/router_ratelimit.h b/source/common/router/router_ratelimit.h index 912606fc0da85..1bee646cb733b 100644 --- a/source/common/router/router_ratelimit.h +++ b/source/common/router/router_ratelimit.h @@ -41,7 +41,7 @@ class DynamicMetadataRateLimitOverride : public RateLimitOverrideAction { class SourceClusterAction : public RateLimitAction { public: // Router::RateLimitAction - bool populateDescriptor(const Router::RouteEntry& route, RateLimit::Descriptor& descriptor, + bool populateDescriptor(const Router::RouteEntry& route, RateLimit::DescriptorEntry& descriptor, const std::string& local_service_cluster, const Http::HeaderMap& headers, const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata* dynamic_metadata) const override; @@ -53,7 +53,7 @@ class SourceClusterAction : public RateLimitAction { class DestinationClusterAction : public RateLimitAction { public: // Router::RateLimitAction - bool populateDescriptor(const Router::RouteEntry& route, RateLimit::Descriptor& descriptor, + bool populateDescriptor(const Router::RouteEntry& route, RateLimit::DescriptorEntry& descriptor, const std::string& local_service_cluster, const Http::HeaderMap& headers, const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata* dynamic_metadata) const override; @@ -69,7 +69,7 @@ class RequestHeadersAction : public RateLimitAction { skip_if_absent_(action.skip_if_absent()) {} // Router::RateLimitAction - bool populateDescriptor(const Router::RouteEntry& route, RateLimit::Descriptor& descriptor, + bool populateDescriptor(const Router::RouteEntry& route, RateLimit::DescriptorEntry& descriptor, const std::string& local_service_cluster, const Http::HeaderMap& headers, const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata* dynamic_metadata) const override; @@ -86,7 +86,7 @@ class RequestHeadersAction : public RateLimitAction { class RemoteAddressAction : public RateLimitAction { public: // Router::RateLimitAction - bool populateDescriptor(const Router::RouteEntry& route, RateLimit::Descriptor& descriptor, + bool populateDescriptor(const Router::RouteEntry& route, RateLimit::DescriptorEntry& descriptor, const std::string& local_service_cluster, const Http::HeaderMap& headers, const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata* dynamic_metadata) const override; @@ -103,7 +103,7 @@ class GenericKeyAction : public RateLimitAction { : "generic_key") {} // Router::RateLimitAction - bool populateDescriptor(const Router::RouteEntry& route, RateLimit::Descriptor& descriptor, + bool populateDescriptor(const Router::RouteEntry& route, RateLimit::DescriptorEntry& descriptor, const std::string& local_service_cluster, const Http::HeaderMap& headers, const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata* dynamic_metadata) const override; @@ -122,7 +122,7 @@ class MetaDataAction : public RateLimitAction { // for maintaining backward compatibility with the deprecated DynamicMetaData action MetaDataAction(const envoy::config::route::v3::RateLimit::Action::DynamicMetaData& action); // Router::RateLimitAction - bool populateDescriptor(const Router::RouteEntry& route, RateLimit::Descriptor& descriptor, + bool populateDescriptor(const Router::RouteEntry& route, RateLimit::DescriptorEntry& descriptor, const std::string& local_service_cluster, const Http::HeaderMap& headers, const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata* dynamic_metadata) const override; @@ -143,7 +143,7 @@ class HeaderValueMatchAction : public RateLimitAction { const envoy::config::route::v3::RateLimit::Action::HeaderValueMatch& action); // Router::RateLimitAction - bool populateDescriptor(const Router::RouteEntry& route, RateLimit::Descriptor& descriptor, + bool populateDescriptor(const Router::RouteEntry& route, RateLimit::DescriptorEntry& descriptor, const std::string& local_service_cluster, const Http::HeaderMap& headers, const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata* dynamic_metadata) const override; @@ -170,6 +170,11 @@ class RateLimitPolicyEntryImpl : public RateLimitPolicyEntry { const std::string& local_service_cluster, const Http::HeaderMap&, const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata* dynamic_metadata) const override; + void populateLocalDescriptors( + const Router::RouteEntry& route, std::vector& descriptors, + const std::string& local_service_cluster, const Http::HeaderMap&, + const Network::Address::Instance& remote_address, + const envoy::config::core::v3::Metadata* dynamic_metadata) const override; private: const std::string disable_key_; diff --git a/source/extensions/filters/common/local_ratelimit/BUILD b/source/extensions/filters/common/local_ratelimit/BUILD index 1a201025ca3fc..0234c335c3e68 100644 --- a/source/extensions/filters/common/local_ratelimit/BUILD +++ b/source/extensions/filters/common/local_ratelimit/BUILD @@ -15,6 +15,9 @@ envoy_cc_library( deps = [ "//include/envoy/event:dispatcher_interface", "//include/envoy/event:timer_interface", + "//include/envoy/ratelimit:ratelimit_interface", "//source/common/common:thread_synchronizer_lib", + "//source/common/protobuf:utility_lib", + "@envoy_api//envoy/extensions/common/ratelimit/v3:pkg_cc_proto", ], ) diff --git a/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.cc b/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.cc index 2adee384673e1..23619bf1910e7 100644 --- a/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.cc +++ b/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.cc @@ -1,28 +1,63 @@ #include "extensions/filters/common/local_ratelimit/local_ratelimit_impl.h" +#include "common/protobuf/utility.h" + namespace Envoy { namespace Extensions { namespace Filters { namespace Common { namespace LocalRateLimit { -LocalRateLimiterImpl::LocalRateLimiterImpl(const std::chrono::milliseconds fill_interval, - const uint32_t max_tokens, - const uint32_t tokens_per_fill, - Event::Dispatcher& dispatcher) +LocalRateLimiterImpl::LocalRateLimiterImpl( + const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, + const uint32_t tokens_per_fill, Event::Dispatcher& dispatcher, + const Envoy::Protobuf::RepeatedPtrField< + envoy::extensions::common::ratelimit::v3::LocalRateLimitDescriptor>& descriptors) : fill_interval_(fill_interval), max_tokens_(max_tokens), tokens_per_fill_(tokens_per_fill), fill_timer_(fill_interval_ > std::chrono::milliseconds(0) ? dispatcher.createTimer([this] { onFillTimer(); }) - : nullptr) { + : nullptr), + time_source_(dispatcher.timeSource()) { if (fill_timer_ && fill_interval_ < std::chrono::milliseconds(50)) { throw EnvoyException("local rate limit token bucket fill timer must be >= 50ms"); } - tokens_ = max_tokens; + tokens_.tokens = max_tokens; if (fill_timer_) { fill_timer_->enableTimer(fill_interval_); } + + for (const auto& descriptor : descriptors) { + Envoy::RateLimit::LocalDescriptor new_descriptor; + for (const auto& entry : descriptor.entries()) { + new_descriptor.entries_.push_back({entry.key(), entry.value()}); + } + Envoy::RateLimit::TokenBucket token_bucket; + token_bucket.fill_interval_ = + absl::Milliseconds(PROTOBUF_GET_MS_OR_DEFAULT(descriptor.token_bucket(), fill_interval, 0)); + if (absl::ToChronoMilliseconds(token_bucket.fill_interval_).count() % fill_interval_.count() != + 0) { + throw EnvoyException( + "local rate descriptor limit is not a multiple of token bucket fill timer"); + } + token_bucket.max_tokens_ = descriptor.token_bucket().max_tokens(); + token_bucket.tokens_per_fill_ = + PROTOBUF_GET_WRAPPED_OR_DEFAULT(descriptor.token_bucket(), tokens_per_fill, 1); + new_descriptor.token_bucket_ = token_bucket; + + // Push to descriptors vector to maintain the ordering in which each + // descriptor appeared in the config. + descriptors_.push_back(new_descriptor); + + // Maintain the hash map of state of token bucket for each descriptor. + std::unique_ptr descriptor_state_token = + std::make_unique(); + // Fill with max_tokens first time. + descriptor_state_token->token.tokens = token_bucket.max_tokens_; + descriptor_state_token->monotonic_time = time_source_.monotonicTime(); + tokens_per_descriptor_[new_descriptor] = std::move(descriptor_state_token); + } } LocalRateLimiterImpl::~LocalRateLimiterImpl() { @@ -32,28 +67,49 @@ LocalRateLimiterImpl::~LocalRateLimiterImpl() { } void LocalRateLimiterImpl::onFillTimer() { + onFillTimerHelper(tokens_, max_tokens_, tokens_per_fill_); + onFillTimerDescriptorHelper(); + fill_timer_->enableTimer(fill_interval_); +} + +void LocalRateLimiterImpl::onFillTimerHelper(const Token& tokens, const uint32_t max_tokens, + const uint32_t tokens_per_fill) { // Relaxed consistency is used for all operations because we don't care about ordering, just the // final atomic correctness. - uint32_t expected_tokens = tokens_.load(std::memory_order_relaxed); + uint32_t expected_tokens = tokens.tokens.load(std::memory_order_relaxed); uint32_t new_tokens_value; do { // expected_tokens is either initialized above or reloaded during the CAS failure below. - new_tokens_value = std::min(max_tokens_, expected_tokens + tokens_per_fill_); + new_tokens_value = std::min(max_tokens, expected_tokens + tokens_per_fill); // Testing hook. synchronizer_.syncPoint("on_fill_timer_pre_cas"); // Loop while the weak CAS fails trying to update the tokens value. - } while ( - !tokens_.compare_exchange_weak(expected_tokens, new_tokens_value, std::memory_order_relaxed)); + } while (!tokens.tokens.compare_exchange_weak(expected_tokens, new_tokens_value, + std::memory_order_relaxed)); +} - fill_timer_->enableTimer(fill_interval_); +void LocalRateLimiterImpl::onFillTimerDescriptorHelper() { + auto current_time = time_source_.monotonicTime(); + for (const auto& descriptor : tokens_per_descriptor_) { + if (std::chrono::duration_cast(current_time - + descriptor.second->monotonic_time) >= + absl::ToChronoMilliseconds(descriptor.first.token_bucket_.fill_interval_)) { + + onFillTimerHelper(descriptor.second->token, descriptor.first.token_bucket_.max_tokens_, + descriptor.first.token_bucket_.tokens_per_fill_); + + // Update the time. + descriptor.second->monotonic_time = current_time; + } + } } -bool LocalRateLimiterImpl::requestAllowed() const { +bool LocalRateLimiterImpl::requestAllowedHelper(const Token& tokens) const { // Relaxed consistency is used for all operations because we don't care about ordering, just the // final atomic correctness. - uint32_t expected_tokens = tokens_.load(std::memory_order_relaxed); + uint32_t expected_tokens = tokens.tokens.load(std::memory_order_relaxed); do { // expected_tokens is either initialized above or reloaded during the CAS failure below. if (expected_tokens == 0) { @@ -64,13 +120,39 @@ bool LocalRateLimiterImpl::requestAllowed() const { synchronizer_.syncPoint("allowed_pre_cas"); // Loop while the weak CAS fails trying to subtract 1 from expected. - } while (!tokens_.compare_exchange_weak(expected_tokens, expected_tokens - 1, - std::memory_order_relaxed)); + } while (!tokens.tokens.compare_exchange_weak(expected_tokens, expected_tokens - 1, + std::memory_order_relaxed)); // We successfully decremented the counter by 1. return true; } +bool LocalRateLimiterImpl::requestAllowed( + std::vector route_descriptors) const { + const Envoy::RateLimit::LocalDescriptor* descriptor = findDescriptor(route_descriptors); + if (descriptor == nullptr) { + return requestAllowedHelper(tokens_); + } + auto it = tokens_per_descriptor_.find(*descriptor); + return requestAllowedHelper(it->second->token); +} + +const Envoy::RateLimit::LocalDescriptor* LocalRateLimiterImpl::findDescriptor( + std::vector route_descriptors) const { + if (descriptors_.empty() || route_descriptors.empty()) { + return nullptr; + } + for (const auto& config_descriptor : descriptors_) { + for (const auto& route_descriptor : route_descriptors) { + if (std::equal(config_descriptor.entries_.begin(), config_descriptor.entries_.end(), + route_descriptor.entries_.begin())) { + return &config_descriptor; + } + } + } + return nullptr; +} + } // namespace LocalRateLimit } // namespace Common } // namespace Filters diff --git a/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h b/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h index 2e35dc5b0ef4b..98325efc35721 100644 --- a/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h +++ b/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h @@ -2,10 +2,14 @@ #include +#include "envoy/common/time.h" #include "envoy/event/dispatcher.h" #include "envoy/event/timer.h" +#include "envoy/extensions/common/ratelimit/v3/ratelimit.pb.h" +#include "envoy/ratelimit/ratelimit.h" #include "common/common/thread_synchronizer.h" +#include "common/protobuf/protobuf.h" namespace Envoy { namespace Extensions { @@ -15,22 +19,44 @@ namespace LocalRateLimit { class LocalRateLimiterImpl { public: - LocalRateLimiterImpl(const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, - const uint32_t tokens_per_fill, Event::Dispatcher& dispatcher); + LocalRateLimiterImpl( + const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, + const uint32_t tokens_per_fill, Event::Dispatcher& dispatcher, + const Envoy::Protobuf::RepeatedPtrField< + envoy::extensions::common::ratelimit::v3::LocalRateLimitDescriptor>& descriptors); ~LocalRateLimiterImpl(); - bool requestAllowed() const; + bool requestAllowed(std::vector route_descriptors) const; private: + struct Token { + mutable std::atomic tokens; + }; + struct DescriptorTokenState { + Token token; + Envoy::MonotonicTime monotonic_time; + }; + const Envoy::RateLimit::LocalDescriptor* + findDescriptor(std::vector descriptors) const; void onFillTimer(); + void onFillTimerHelper(const Token& tokens, const uint32_t max_tokens, + const uint32_t tokens_per_fill); + void onFillTimerDescriptorHelper(); + bool requestAllowedHelper(const Token& tokens) const; const std::chrono::milliseconds fill_interval_; const uint32_t max_tokens_; const uint32_t tokens_per_fill_; const Event::TimerPtr fill_timer_; - mutable std::atomic tokens_; + TimeSource& time_source_; + Token tokens_; mutable Thread::ThreadSynchronizer synchronizer_; // Used for testing only. + absl::flat_hash_map> + tokens_per_descriptor_; + + std::vector descriptors_; + friend class LocalRateLimiterImplTest; }; diff --git a/source/extensions/filters/common/ratelimit/BUILD b/source/extensions/filters/common/ratelimit/BUILD index 4bf0b36b1e5de..70bb00f6d2298 100644 --- a/source/extensions/filters/common/ratelimit/BUILD +++ b/source/extensions/filters/common/ratelimit/BUILD @@ -35,9 +35,12 @@ envoy_cc_library( hdrs = ["ratelimit.h"], external_deps = ["abseil_optional"], deps = [ + "//include/envoy/http:filter_interface", "//include/envoy/ratelimit:ratelimit_interface", + "//include/envoy/router:router_ratelimit_interface", "//include/envoy/singleton:manager_interface", "//include/envoy/tracing:http_tracer_interface", + "//source/common/http:utility_lib", "//source/common/stats:symbol_table_lib", "@envoy_api//envoy/service/ratelimit/v3:pkg_cc_proto", ], diff --git a/source/extensions/filters/http/local_ratelimit/BUILD b/source/extensions/filters/http/local_ratelimit/BUILD index 048d7d4ed4e0a..67940ef4bf720 100644 --- a/source/extensions/filters/http/local_ratelimit/BUILD +++ b/source/extensions/filters/http/local_ratelimit/BUILD @@ -23,9 +23,11 @@ envoy_cc_library( "//source/common/common:utility_lib", "//source/common/http:header_utility_lib", "//source/common/http:headers_lib", + "//source/common/router:config_lib", "//source/common/router:header_parser_lib", "//source/common/runtime:runtime_lib", "//source/extensions/filters/common/local_ratelimit:local_ratelimit_lib", + "//source/extensions/filters/common/ratelimit:ratelimit_client_interface", "//source/extensions/filters/http/common:pass_through_filter_lib", "@envoy_api//envoy/extensions/filters/http/local_ratelimit/v3:pkg_cc_proto", ], diff --git a/source/extensions/filters/http/local_ratelimit/config.cc b/source/extensions/filters/http/local_ratelimit/config.cc index 529fd0dd29776..a5e629d88055d 100644 --- a/source/extensions/filters/http/local_ratelimit/config.cc +++ b/source/extensions/filters/http/local_ratelimit/config.cc @@ -17,7 +17,7 @@ Http::FilterFactoryCb LocalRateLimitFilterConfig::createFilterFactoryFromProtoTy const envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit& proto_config, const std::string&, Server::Configuration::FactoryContext& context) { FilterConfigSharedPtr filter_config = std::make_shared( - proto_config, context.dispatcher(), context.scope(), context.runtime()); + proto_config, context.localInfo(), context.dispatcher(), context.scope(), context.runtime()); return [filter_config](Http::FilterChainFactoryCallbacks& callbacks) -> void { callbacks.addStreamFilter(std::make_shared(filter_config)); }; @@ -27,7 +27,8 @@ Router::RouteSpecificFilterConfigConstSharedPtr LocalRateLimitFilterConfig::createRouteSpecificFilterConfigTyped( const envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit& proto_config, Server::Configuration::ServerFactoryContext& context, ProtobufMessage::ValidationVisitor&) { - return std::make_shared(proto_config, context.dispatcher(), context.scope(), + return std::make_shared(proto_config, context.localInfo(), + context.dispatcher(), context.scope(), context.runtime(), true); } diff --git a/source/extensions/filters/http/local_ratelimit/local_ratelimit.cc b/source/extensions/filters/http/local_ratelimit/local_ratelimit.cc index 3b13bfa374ace..8cb21ca0b7a49 100644 --- a/source/extensions/filters/http/local_ratelimit/local_ratelimit.cc +++ b/source/extensions/filters/http/local_ratelimit/local_ratelimit.cc @@ -6,6 +6,7 @@ #include "envoy/http/codes.h" #include "common/http/utility.h" +#include "common/router/config_impl.h" namespace Envoy { namespace Extensions { @@ -14,16 +15,17 @@ namespace LocalRateLimitFilter { FilterConfig::FilterConfig( const envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit& config, - Event::Dispatcher& dispatcher, Stats::Scope& scope, Runtime::Loader& runtime, - const bool per_route) + const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, Stats::Scope& scope, + Runtime::Loader& runtime, const bool per_route) : status_(toErrorCode(config.status().code())), stats_(generateStats(config.stat_prefix(), scope)), rate_limiter_(Filters::Common::LocalRateLimit::LocalRateLimiterImpl( std::chrono::milliseconds( PROTOBUF_GET_MS_OR_DEFAULT(config.token_bucket(), fill_interval, 0)), config.token_bucket().max_tokens(), - PROTOBUF_GET_WRAPPED_OR_DEFAULT(config.token_bucket(), tokens_per_fill, 1), dispatcher)), - runtime_(runtime), + PROTOBUF_GET_WRAPPED_OR_DEFAULT(config.token_bucket(), tokens_per_fill, 1), dispatcher, + config.descriptors())), + local_info_(local_info), runtime_(runtime), filter_enabled_( config.has_filter_enabled() ? absl::optional( @@ -35,7 +37,9 @@ FilterConfig::FilterConfig( Envoy::Runtime::FractionalPercent(config.filter_enforced(), runtime_)) : absl::nullopt), response_headers_parser_( - Envoy::Router::HeaderParser::configure(config.response_headers_to_add())) { + Envoy::Router::HeaderParser::configure(config.response_headers_to_add())), + stage_(static_cast(config.stage())), + are_descriptors_configured_(!config.descriptors().empty()) { // Note: no token bucket is fine for the global config, which would be the case for enabling // the filter globally but disabled and then applying limits at the virtual host or // route level. At the virtual or route level, it makes no sense to have an no token @@ -44,9 +48,21 @@ FilterConfig::FilterConfig( if (per_route && !config.has_token_bucket()) { throw EnvoyException("local rate limit token bucket must be set for per filter configs"); } + + // Note: Descriptors work at the route level, as we get rate limit descriptors + // from route's rate limit actions. Hence configuring descriptors at + // global config level does not makes sense. + if (!per_route && are_descriptors_configured_) { + throw EnvoyException("no descriptors required for global config."); + } +} + +bool FilterConfig::requestAllowed( + std::vector route_descriptors) const { + return rate_limiter_.requestAllowed(route_descriptors); } -bool FilterConfig::requestAllowed() const { return rate_limiter_.requestAllowed(); } +bool FilterConfig::areDescriptorsConfigured() const { return are_descriptors_configured_; } LocalRateLimitStats FilterConfig::generateStats(const std::string& prefix, Stats::Scope& scope) { const std::string final_prefix = prefix + ".http_local_rate_limit"; @@ -61,7 +77,7 @@ bool FilterConfig::enforced() const { return filter_enforced_.has_value() ? filter_enforced_->enabled() : false; } -Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap&, bool) { +Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, bool) { const auto* config = getConfig(); if (!config->enabled()) { @@ -70,7 +86,12 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap&, bool) { config->stats().enabled_.inc(); - if (config->requestAllowed()) { + std::vector route_descriptors; + if (config->areDescriptorsConfigured()) { + getRouteSpecificDescriptors(route_descriptors, headers); + } + + if (config->requestAllowed(route_descriptors)) { config->stats().ok_.inc(); return Http::FilterHeadersStatus::Continue; } @@ -94,6 +115,36 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap&, bool) { return Http::FilterHeadersStatus::StopIteration; } +void Filter::getRouteSpecificDescriptors( + std::vector& local_descriptors, + Http::RequestHeaderMap& headers) { + Router::RouteConstSharedPtr route = decoder_callbacks_->route(); + if (!route || !route->routeEntry()) { + return; + } + + auto cluster = decoder_callbacks_->clusterInfo(); + if (!cluster) { + return; + } + + const Envoy::Router::RouteEntry* route_entry = route->routeEntry(); + // Get all applicable rate limit policy entries for the route. + const auto* config = getConfig(); + for (const Envoy::Router::RateLimitPolicyEntry& rate_limit : + route_entry->rateLimitPolicy().getApplicableRateLimit(config->stage())) { + const std::string& disable_key = rate_limit.disableKey(); + + if (!disable_key.empty()) { + continue; + } + rate_limit.populateLocalDescriptors(*route_entry, local_descriptors, + config->localInfo().clusterName(), headers, + *decoder_callbacks_->streamInfo().downstreamRemoteAddress(), + &decoder_callbacks_->streamInfo().dynamicMetadata()); + } +} + const FilterConfig* Filter::getConfig() const { const auto* config = Http::Utility::resolveMostSpecificPerFilterConfig( "envoy.filters.http.local_ratelimit", decoder_callbacks_->route()); diff --git a/source/extensions/filters/http/local_ratelimit/local_ratelimit.h b/source/extensions/filters/http/local_ratelimit/local_ratelimit.h index 6549094d07c3f..36e47b5663b44 100644 --- a/source/extensions/filters/http/local_ratelimit/local_ratelimit.h +++ b/source/extensions/filters/http/local_ratelimit/local_ratelimit.h @@ -7,9 +7,11 @@ #include "envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.pb.h" #include "envoy/http/filter.h" +#include "envoy/local_info/local_info.h" #include "envoy/runtime/runtime.h" #include "envoy/stats/scope.h" #include "envoy/stats/stats_macros.h" +#include "envoy/upstream/cluster_manager.h" #include "common/common/assert.h" #include "common/http/header_map_impl.h" @@ -17,6 +19,7 @@ #include "common/runtime/runtime_protos.h" #include "extensions/filters/common/local_ratelimit/local_ratelimit_impl.h" +#include "extensions/filters/common/ratelimit/ratelimit.h" #include "extensions/filters/http/common/pass_through_filter.h" namespace Envoy { @@ -43,20 +46,26 @@ struct LocalRateLimitStats { /** * Global configuration for the HTTP local rate limit filter. */ -class FilterConfig : public ::Envoy::Router::RouteSpecificFilterConfig { +class FilterConfig : public ::Envoy::Router::RouteSpecificFilterConfig, + Logger::Loggable { public: FilterConfig(const envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit& config, - Event::Dispatcher& dispatcher, Stats::Scope& scope, Runtime::Loader& runtime, - bool per_route = false); + const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, + Stats::Scope& scope, Runtime::Loader& runtime, bool per_route = false); ~FilterConfig() override = default; + const LocalInfo::LocalInfo& localInfo() const { return local_info_; } Runtime::Loader& runtime() { return runtime_; } - bool requestAllowed() const; bool enabled() const; bool enforced() const; + bool requestAllowed(std::vector route_descriptors) const; LocalRateLimitStats& stats() const { return stats_; } const Router::HeaderParser& responseHeadersParser() const { return *response_headers_parser_; } Http::Code status() const { return status_; } + uint64_t stage() const { return stage_; } + + bool areDescriptorsConfigured() const; + private: friend class FilterTest; @@ -73,10 +82,14 @@ class FilterConfig : public ::Envoy::Router::RouteSpecificFilterConfig { const Http::Code status_; mutable LocalRateLimitStats stats_; Filters::Common::LocalRateLimit::LocalRateLimiterImpl rate_limiter_; + + const LocalInfo::LocalInfo& local_info_; Runtime::Loader& runtime_; const absl::optional filter_enabled_; const absl::optional filter_enforced_; Router::HeaderParserPtr response_headers_parser_; + const uint64_t stage_; + const bool are_descriptors_configured_; }; using FilterConfigSharedPtr = std::shared_ptr; @@ -88,7 +101,6 @@ using FilterConfigSharedPtr = std::shared_ptr; class Filter : public Http::PassThroughFilter { public: Filter(FilterConfigSharedPtr config) : config_(config) {} - // Http::StreamDecoderFilter Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, bool end_stream) override; @@ -96,8 +108,10 @@ class Filter : public Http::PassThroughFilter { private: friend class FilterTest; - const FilterConfig* getConfig() const; + void getRouteSpecificDescriptors(std::vector& descriptors, + Http::RequestHeaderMap& headers); + const FilterConfig* getConfig() const; FilterConfigSharedPtr config_; }; diff --git a/source/extensions/filters/network/local_ratelimit/local_ratelimit.cc b/source/extensions/filters/network/local_ratelimit/local_ratelimit.cc index 773daf1751395..66b0c9700355d 100644 --- a/source/extensions/filters/network/local_ratelimit/local_ratelimit.cc +++ b/source/extensions/filters/network/local_ratelimit/local_ratelimit.cc @@ -18,7 +18,9 @@ Config::Config( PROTOBUF_GET_MS_REQUIRED(proto_config.token_bucket(), fill_interval)), proto_config.token_bucket().max_tokens(), PROTOBUF_GET_WRAPPED_OR_DEFAULT(proto_config.token_bucket(), tokens_per_fill, 1), - dispatcher)), + dispatcher, + Envoy::Protobuf::RepeatedPtrField< + envoy::extensions::common::ratelimit::v3::LocalRateLimitDescriptor>())), enabled_(proto_config.runtime_enabled(), runtime), stats_(generateStats(proto_config.stat_prefix(), scope)) {} @@ -27,7 +29,7 @@ LocalRateLimitStats Config::generateStats(const std::string& prefix, Stats::Scop return {ALL_LOCAL_RATE_LIMIT_STATS(POOL_COUNTER_PREFIX(scope, final_prefix))}; } -bool Config::canCreateConnection() { return rate_limiter_.requestAllowed(); } +bool Config::canCreateConnection() { return rate_limiter_.requestAllowed(descriptors_); } Network::FilterStatus Filter::onNewConnection() { if (!config_->enabled()) { diff --git a/source/extensions/filters/network/local_ratelimit/local_ratelimit.h b/source/extensions/filters/network/local_ratelimit/local_ratelimit.h index e1cd52ac1beed..330852a266dca 100644 --- a/source/extensions/filters/network/local_ratelimit/local_ratelimit.h +++ b/source/extensions/filters/network/local_ratelimit/local_ratelimit.h @@ -49,6 +49,7 @@ class Config : Logger::Loggable { Runtime::FeatureFlag enabled_; LocalRateLimitStats stats_; + std::vector descriptors_; friend class LocalRateLimitTestBase; }; diff --git a/test/common/router/router_ratelimit_test.cc b/test/common/router/router_ratelimit_test.cc index d8ab967f72de5..74b4d03500130 100644 --- a/test/common/router/router_ratelimit_test.cc +++ b/test/common/router/router_ratelimit_test.cc @@ -207,12 +207,17 @@ TEST_F(RateLimitConfiguration, TestVirtualHost) { EXPECT_EQ(1U, rate_limits.size()); std::vector descriptors; + std::vector local_descriptors; for (const RateLimitPolicyEntry& rate_limit : rate_limits) { rate_limit.populateDescriptors(*route_, descriptors, "service_cluster", header_, default_remote_address_, dynamic_metadata_); + rate_limit.populateLocalDescriptors(*route_, local_descriptors, "service_cluster", header_, + default_remote_address_, dynamic_metadata_); } EXPECT_THAT(std::vector({{{{"destination_cluster", "www2test"}}}}), testing::ContainerEq(descriptors)); + EXPECT_THAT(std::vector({{"destination_cluster", "www2test"}}), + testing::ContainerEq(local_descriptors.at(0).entries_)); } TEST_F(RateLimitConfiguration, Stages) { @@ -247,26 +252,40 @@ TEST_F(RateLimitConfiguration, Stages) { EXPECT_EQ(2U, rate_limits.size()); std::vector descriptors; + std::vector local_descriptors; for (const RateLimitPolicyEntry& rate_limit : rate_limits) { rate_limit.populateDescriptors(*route_, descriptors, "service_cluster", header_, default_remote_address_, dynamic_metadata_); + rate_limit.populateLocalDescriptors(*route_, local_descriptors, "service_cluster", header_, + default_remote_address_, dynamic_metadata_); } EXPECT_THAT(std::vector( {{{{"destination_cluster", "www2test"}}}, {{{"destination_cluster", "www2test"}, {"source_cluster", "service_cluster"}}}}), testing::ContainerEq(descriptors)); + Envoy::RateLimit::TokenBucket token_bucket{}; + EXPECT_THAT(std::vector( + {{{{"destination_cluster", "www2test"}}, {token_bucket}}, + {{{"destination_cluster", "www2test"}, {"source_cluster", "service_cluster"}}, + {token_bucket}}}), + testing::ContainerEq(local_descriptors)); descriptors.clear(); + local_descriptors.clear(); rate_limits = route_->rateLimitPolicy().getApplicableRateLimit(1UL); EXPECT_EQ(1U, rate_limits.size()); for (const RateLimitPolicyEntry& rate_limit : rate_limits) { rate_limit.populateDescriptors(*route_, descriptors, "service_cluster", header_, default_remote_address_, dynamic_metadata_); + rate_limit.populateLocalDescriptors(*route_, local_descriptors, "service_cluster", header_, + default_remote_address_, dynamic_metadata_); } EXPECT_THAT(std::vector({{{{"remote_address", "10.0.0.1"}}}}), testing::ContainerEq(descriptors)); - + EXPECT_THAT(std::vector( + {{{{"remote_address", "10.0.0.1"}}, {token_bucket}}}), + testing::ContainerEq(local_descriptors)); rate_limits = route_->rateLimitPolicy().getApplicableRateLimit(10UL); EXPECT_TRUE(rate_limits.empty()); } @@ -276,14 +295,17 @@ class RateLimitPolicyEntryTest : public testing::Test { void setupTest(const std::string& yaml) { rate_limit_entry_ = std::make_unique(parseRateLimitFromV3Yaml(yaml)); descriptors_.clear(); + local_descriptors_.clear(); } std::unique_ptr rate_limit_entry_; Http::TestRequestHeaderMapImpl header_; NiceMock route_; std::vector descriptors_; + std::vector local_descriptors_; Network::Address::Ipv4Instance default_remote_address_{"10.0.0.1"}; const envoy::config::core::v3::Metadata* dynamic_metadata_; + Envoy::RateLimit::TokenBucket token_bucket_{}; }; TEST_F(RateLimitPolicyEntryTest, RateLimitPolicyEntryMembers) { @@ -310,8 +332,13 @@ TEST_F(RateLimitPolicyEntryTest, RemoteAddress) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, dynamic_metadata_); EXPECT_THAT(std::vector({{{{"remote_address", "10.0.0.1"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector( + {{{{"remote_address", "10.0.0.1"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } // Verify no descriptor is emitted if remote is a pipe. @@ -326,7 +353,10 @@ TEST_F(RateLimitPolicyEntryTest, PipeAddress) { Network::Address::PipeInstance pipe_address("/hello"); rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, pipe_address, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, pipe_address, + dynamic_metadata_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, SourceService) { @@ -339,9 +369,14 @@ TEST_F(RateLimitPolicyEntryTest, SourceService) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "service_cluster", header_, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "service_cluster", + header_, default_remote_address_, dynamic_metadata_); EXPECT_THAT( std::vector({{{{"source_cluster", "service_cluster"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector( + {{{{"source_cluster", "service_cluster"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, DestinationService) { @@ -354,9 +389,14 @@ TEST_F(RateLimitPolicyEntryTest, DestinationService) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "service_cluster", header_, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "service_cluster", + header_, default_remote_address_, dynamic_metadata_); EXPECT_THAT( std::vector({{{{"destination_cluster", "fake_cluster"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector( + {{{{"destination_cluster", "fake_cluster"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, RequestHeaders) { @@ -372,8 +412,13 @@ TEST_F(RateLimitPolicyEntryTest, RequestHeaders) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "service_cluster", header, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "service_cluster", header, + default_remote_address_, dynamic_metadata_); EXPECT_THAT(std::vector({{{{"my_header_name", "test_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector( + {{{{"my_header_name", "test_value"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } // Validate that a descriptor is added if the missing request header @@ -396,8 +441,13 @@ TEST_F(RateLimitPolicyEntryTest, RequestHeadersWithSkipIfAbsent) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "service_cluster", header, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "service_cluster", header, + default_remote_address_, dynamic_metadata_); EXPECT_THAT(std::vector({{{{"my_header_name", "test_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector( + {{{{"my_header_name", "test_value"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } // Tests if the descriptors are added if one of the headers is missing @@ -420,7 +470,10 @@ TEST_F(RateLimitPolicyEntryTest, RequestHeadersWithDefaultSkipIfAbsent) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "service_cluster", header, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "service_cluster", header, + default_remote_address_, dynamic_metadata_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, RequestHeadersNoMatch) { @@ -436,7 +489,10 @@ TEST_F(RateLimitPolicyEntryTest, RequestHeadersNoMatch) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "service_cluster", header, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "service_cluster", header, + default_remote_address_, dynamic_metadata_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, RateLimitKey) { @@ -450,8 +506,13 @@ TEST_F(RateLimitPolicyEntryTest, RateLimitKey) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, dynamic_metadata_); EXPECT_THAT(std::vector({{{{"generic_key", "fake_key"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector( + {{{{"generic_key", "fake_key"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, GenericKeyWithSetDescriptorKey) { @@ -466,8 +527,13 @@ TEST_F(RateLimitPolicyEntryTest, GenericKeyWithSetDescriptorKey) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, dynamic_metadata_); EXPECT_THAT(std::vector({{{{"fake_key", "fake_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector( + {{{{"fake_key", "fake_value"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, GenericKeyWithEmptyDescriptorKey) { @@ -482,8 +548,13 @@ TEST_F(RateLimitPolicyEntryTest, GenericKeyWithEmptyDescriptorKey) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, dynamic_metadata_); EXPECT_THAT(std::vector({{{{"generic_key", "fake_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector( + {{{{"generic_key", "fake_value"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, DEPRECATED_FEATURE_TEST(DynamicMetaDataMatch)) { @@ -513,9 +584,13 @@ TEST_F(RateLimitPolicyEntryTest, DEPRECATED_FEATURE_TEST(DynamicMetaDataMatch)) rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, &metadata); - + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, &metadata); EXPECT_THAT(std::vector({{{{"fake_key", "foo"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT( + std::vector({{{{"fake_key", "foo"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, MetaDataMatchDynamicSourceByDefault) { @@ -545,9 +620,13 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataMatchDynamicSourceByDefault) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, &metadata); - + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, &metadata); EXPECT_THAT(std::vector({{{{"fake_key", "foo"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT( + std::vector({{{{"fake_key", "foo"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, MetaDataMatchDynamicSource) { @@ -578,9 +657,13 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataMatchDynamicSource) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, &metadata); - + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, &metadata); EXPECT_THAT(std::vector({{{{"fake_key", "foo"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT( + std::vector({{{{"fake_key", "foo"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, MetaDataMatchRouteEntrySource) { @@ -610,9 +693,13 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataMatchRouteEntrySource) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, dynamic_metadata_); - + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, dynamic_metadata_); EXPECT_THAT(std::vector({{{{"fake_key", "foo"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT( + std::vector({{{{"fake_key", "foo"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } // Tests that the default_value is used in the descriptor when the metadata_key is empty. @@ -643,9 +730,13 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataNoMatchWithDefaultValue) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, &metadata); - + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, &metadata); EXPECT_THAT(std::vector({{{{"fake_key", "fake_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector( + {{{{"fake_key", "fake_value"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, MetaDataNoMatch) { @@ -674,8 +765,10 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataNoMatch) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, &metadata); - + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, &metadata); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, MetaDataEmptyValue) { @@ -704,8 +797,11 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataEmptyValue) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, &metadata); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, &metadata); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } // Tests that no descriptor is generated when both the metadata_key and default_value are empty. TEST_F(RateLimitPolicyEntryTest, MetaDataAndDefaultValueEmpty) { @@ -735,8 +831,10 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataAndDefaultValueEmpty) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, &metadata); - + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, &metadata); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, MetaDataNonStringNoMatch) { @@ -766,8 +864,10 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataNonStringNoMatch) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header_, default_remote_address_, &metadata); - + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header_, + default_remote_address_, &metadata); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, HeaderValueMatch) { @@ -785,8 +885,13 @@ TEST_F(RateLimitPolicyEntryTest, HeaderValueMatch) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header, + default_remote_address_, dynamic_metadata_); EXPECT_THAT(std::vector({{{{"header_match", "fake_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector( + {{{{"header_match", "fake_value"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, HeaderValueMatchNoMatch) { @@ -804,7 +909,10 @@ TEST_F(RateLimitPolicyEntryTest, HeaderValueMatchNoMatch) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header, + default_remote_address_, dynamic_metadata_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, HeaderValueMatchHeadersNotPresent) { @@ -823,8 +931,13 @@ TEST_F(RateLimitPolicyEntryTest, HeaderValueMatchHeadersNotPresent) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header, + default_remote_address_, dynamic_metadata_); EXPECT_THAT(std::vector({{{{"header_match", "fake_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector( + {{{{"header_match", "fake_value"}}, {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, HeaderValueMatchHeadersPresent) { @@ -843,7 +956,10 @@ TEST_F(RateLimitPolicyEntryTest, HeaderValueMatchHeadersPresent) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "", header, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "", header, + default_remote_address_, dynamic_metadata_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, CompoundActions) { @@ -857,10 +973,17 @@ TEST_F(RateLimitPolicyEntryTest, CompoundActions) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "service_cluster", header_, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "service_cluster", + header_, default_remote_address_, dynamic_metadata_); EXPECT_THAT( std::vector( {{{{"destination_cluster", "fake_cluster"}, {"source_cluster", "service_cluster"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT( + std::vector( + {{{{"destination_cluster", "fake_cluster"}, {"source_cluster", "service_cluster"}}, + {token_bucket_}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, CompoundActionsNoDescriptor) { @@ -878,7 +1001,10 @@ TEST_F(RateLimitPolicyEntryTest, CompoundActionsNoDescriptor) { rate_limit_entry_->populateDescriptors(route_, descriptors_, "service_cluster", header_, default_remote_address_, dynamic_metadata_); + rate_limit_entry_->populateLocalDescriptors(route_, local_descriptors_, "service_cluster", + header_, default_remote_address_, dynamic_metadata_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, DynamicMetadataRateLimitOverride) { diff --git a/test/extensions/filters/common/local_ratelimit/local_ratelimit_test.cc b/test/extensions/filters/common/local_ratelimit/local_ratelimit_test.cc index a6142dfb16aa4..49a494c03707d 100644 --- a/test/extensions/filters/common/local_ratelimit/local_ratelimit_test.cc +++ b/test/extensions/filters/common/local_ratelimit/local_ratelimit_test.cc @@ -1,6 +1,7 @@ #include "extensions/filters/common/local_ratelimit/local_ratelimit_impl.h" #include "test/mocks/event/mocks.h" +#include "test/test_common/utility.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -16,19 +17,27 @@ namespace LocalRateLimit { class LocalRateLimiterImplTest : public testing::Test { public: - void initialize(const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, - const uint32_t tokens_per_fill) { - + void initializeTimer() { fill_timer_ = new Event::MockTimer(&dispatcher_); EXPECT_CALL(*fill_timer_, enableTimer(_, nullptr)); EXPECT_CALL(*fill_timer_, disableTimer()); + } + + void initialize(const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, + const uint32_t tokens_per_fill) { - rate_limiter_ = std::make_shared(fill_interval, max_tokens, - tokens_per_fill, dispatcher_); + initializeTimer(); + + rate_limiter_ = std::make_shared( + fill_interval, max_tokens, tokens_per_fill, dispatcher_, descriptors_); } Thread::ThreadSynchronizer& synchronizer() { return rate_limiter_->synchronizer_; } + Envoy::Protobuf::RepeatedPtrField< + envoy::extensions::common::ratelimit::v3::LocalRateLimitDescriptor> + descriptors_; + std::vector route_descriptors_; NiceMock dispatcher_; Event::MockTimer* fill_timer_{}; std::shared_ptr rate_limiter_; @@ -37,8 +46,8 @@ class LocalRateLimiterImplTest : public testing::Test { // Make sure we fail with a fill rate this is too fast. TEST_F(LocalRateLimiterImplTest, TooFastFillRate) { EXPECT_THROW_WITH_MESSAGE( - LocalRateLimiterImpl(std::chrono::milliseconds(49), 100, 1, dispatcher_), EnvoyException, - "local rate limit token bucket fill timer must be >= 50ms"); + LocalRateLimiterImpl(std::chrono::milliseconds(49), 100, 1, dispatcher_, descriptors_), + EnvoyException, "local rate limit token bucket fill timer must be >= 50ms"); } // Verify various token bucket CAS edge cases. @@ -59,15 +68,15 @@ TEST_F(LocalRateLimiterImplTest, CasEdgeCases) { synchronizer().barrierOn("on_fill_timer_pre_cas"); // This should succeed. - EXPECT_TRUE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); // Now signal the thread to continue which should cause a CAS failure and the loop to repeat. synchronizer().signal("on_fill_timer_pre_cas"); t1.join(); // 1 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); } // This tests the case in which two allowed checks race. @@ -78,12 +87,12 @@ TEST_F(LocalRateLimiterImplTest, CasEdgeCases) { // Start a thread and see if we are under limit. This will wait pre-CAS. synchronizer().waitOn("allowed_pre_cas"); - std::thread t1([&] { EXPECT_FALSE(rate_limiter_->requestAllowed()); }); + std::thread t1([&] { EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); }); // Wait until the thread is actually waiting. synchronizer().barrierOn("allowed_pre_cas"); // Consume a token on this thread, which should cause the CAS to fail on the other thread. - EXPECT_TRUE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); synchronizer().signal("allowed_pre_cas"); t1.join(); } @@ -94,17 +103,17 @@ TEST_F(LocalRateLimiterImplTest, TokenBucket) { initialize(std::chrono::milliseconds(200), 1, 1); // 1 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); // 0 -> 1 tokens EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); fill_timer_->invokeCallback(); // 1 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); // 0 -> 1 tokens EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); @@ -115,8 +124,8 @@ TEST_F(LocalRateLimiterImplTest, TokenBucket) { fill_timer_->invokeCallback(); // 1 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); } // Verify token bucket functionality with max tokens and tokens per fill > 1. @@ -124,25 +133,25 @@ TEST_F(LocalRateLimiterImplTest, TokenBucketMultipleTokensPerFill) { initialize(std::chrono::milliseconds(200), 2, 2); // 2 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); // 0 -> 2 tokens EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); fill_timer_->invokeCallback(); // 2 -> 1 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); // 1 -> 2 tokens EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); fill_timer_->invokeCallback(); // 2 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); } // Verify token bucket functionality with max tokens > tokens per fill. @@ -150,17 +159,235 @@ TEST_F(LocalRateLimiterImplTest, TokenBucketMaxTokensGreaterThanTokensPerFill) { initialize(std::chrono::milliseconds(200), 2, 1); // 2 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); // 0 -> 1 tokens EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); fill_timer_->invokeCallback(); // 1 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); +} + +class LocalRateLimiterDescriptorImplTest : public LocalRateLimiterImplTest { +public: + void initializeWithDescriptor(const std::chrono::milliseconds fill_interval, + const uint32_t max_tokens, const uint32_t tokens_per_fill) { + + initializeTimer(); + + rate_limiter_ = std::make_shared( + fill_interval, max_tokens, tokens_per_fill, dispatcher_, descriptors_); + } + const std::string single_descriptor_config_yaml = R"( + entries: + - key: foo2 + value: bar2 + token_bucket: + max_tokens: {} + tokens_per_fill: {} + fill_interval: {} + )"; + + const std::string multiple_descriptor_config_yaml = R"( + entries: + - key: hello + value: world + - key: foo + value: bar + token_bucket: + max_tokens: 1 + tokens_per_fill: 1 + fill_interval: 0.05s + )"; + + // Default token bucket + RateLimit::TokenBucket bucket; + std::vector descriptor_{{{{"foo2", "bar2"}}, {bucket}}}; + std::vector descriptor2_{{{{ + {"hello", "world"}, + {"foo", "bar"}, + }, + bucket}}}; +}; + +// Verify descriptor rate limit time interval is multiple of token bucket fill interval. +TEST_F(LocalRateLimiterDescriptorImplTest, DescriptorRateLimitDivisibleByTokenFillInterval) { + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 10, 10, "60s"), + *descriptors_.Add()); + + EXPECT_THROW_WITH_MESSAGE( + LocalRateLimiterImpl(std::chrono::milliseconds(59000), 2, 1, dispatcher_, descriptors_), + EnvoyException, "local rate descriptor limit is not a multiple of token bucket fill timer"); +} + +// Verify no exception for per route config without descriptors. +TEST_F(LocalRateLimiterDescriptorImplTest, DescriptorRateLimitNoExceptionWithoutDescriptor) { + VERBOSE_EXPECT_NO_THROW( + LocalRateLimiterImpl(std::chrono::milliseconds(59000), 2, 1, dispatcher_, descriptors_)); +} + +// Verify various token bucket CAS edge cases for descriptors. +TEST_F(LocalRateLimiterDescriptorImplTest, CasEdgeCasesDescriptor) { + // This tests the case in which an allowed check races with the fill timer. + { + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "0.1s"), + *descriptors_.Add()); + initializeWithDescriptor(std::chrono::milliseconds(50), 1, 1); + + synchronizer().enable(); + + // Start a thread and start the fill callback. This will wait pre-CAS. + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + synchronizer().waitOn("on_fill_timer_pre_cas"); + std::thread t1([&] { + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + }); + // Wait until the thread is actually waiting. + synchronizer().barrierOn("on_fill_timer_pre_cas"); + + // This should succeed. + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + + // Now signal the thread to continue which should cause a CAS failure and the loop to repeat. + synchronizer().signal("on_fill_timer_pre_cas"); + t1.join(); + + // 1 -> 0 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + } + + // This tests the case in which two allowed checks race. + { + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "0.1s"), + *descriptors_.Add()); + initializeWithDescriptor(std::chrono::milliseconds(50), 1, 1); + + synchronizer().enable(); + + // Start a thread and see if we are under limit. This will wait pre-CAS. + synchronizer().waitOn("allowed_pre_cas"); + std::thread t1([&] { EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); }); + // Wait until the thread is actually waiting. + synchronizer().barrierOn("allowed_pre_cas"); + + // Consume a token on this thread, which should cause the CAS to fail on the other thread. + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + synchronizer().signal("allowed_pre_cas"); + t1.join(); + } +} + +TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDescriptor2) { + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "0.1s"), + *descriptors_.Add()); + initializeWithDescriptor(std::chrono::milliseconds(50), 1, 1); + + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); +} + +// Verify token bucket functionality with a single token. +TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDescriptor) { + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "0.1s"), + *descriptors_.Add()); + initializeWithDescriptor(std::chrono::milliseconds(50), 1, 1); + + // 1 -> 0 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + + // 0 -> 1 tokens + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + + // 1 -> 0 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + + // 0 -> 1 tokens + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + + // 1 -> 1 tokens + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + + // 1 -> 0 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); +} + +// Verify token bucket functionality with request per unit > 1. +TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketMultipleTokensPerFillDescriptor) { + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 2, 2, "0.1s"), + *descriptors_.Add()); + initializeWithDescriptor(std::chrono::milliseconds(50), 2, 2); + + // 2 -> 0 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + + // 0 -> 2 tokens + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + + // 2 -> 1 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + + // 1 -> 2 tokens + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + + // 2 -> 0 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); +} + +// Verify token bucket functionality with multiple descriptors. +TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDifferentDescriptorDifferentRateLimits) { + TestUtility::loadFromYaml(multiple_descriptor_config_yaml, *descriptors_.Add()); + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "1000s"), + *descriptors_.Add()); + initializeWithDescriptor(std::chrono::milliseconds(50), 2, 1); + + // 1 -> 0 tokens for descriptor_ and descriptor2_ + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor2_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor2_)); + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + + // 0 -> 1 tokens for descriptor2_ + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(50), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + + // 1 -> 0 tokens for descriptor2_ and 0 only for descriptor_ + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor2_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor2_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); } } // Namespace LocalRateLimit diff --git a/test/extensions/filters/http/local_ratelimit/BUILD b/test/extensions/filters/http/local_ratelimit/BUILD index 38cd85098ee73..0aeca92a857a8 100644 --- a/test/extensions/filters/http/local_ratelimit/BUILD +++ b/test/extensions/filters/http/local_ratelimit/BUILD @@ -19,6 +19,7 @@ envoy_extension_cc_test( "//source/extensions/filters/http/local_ratelimit:local_ratelimit_lib", "//test/common/stream_info:test_util", "//test/mocks/http:http_mocks", + "//test/mocks/local_info:local_info_mocks", "@envoy_api//envoy/extensions/filters/http/local_ratelimit/v3:pkg_cc_proto", ], ) diff --git a/test/extensions/filters/http/local_ratelimit/config_test.cc b/test/extensions/filters/http/local_ratelimit/config_test.cc index 3f48a5830e21a..d22a4642c2a91 100644 --- a/test/extensions/filters/http/local_ratelimit/config_test.cc +++ b/test/extensions/filters/http/local_ratelimit/config_test.cc @@ -63,7 +63,8 @@ stat_prefix: test const auto route_config = factory.createRouteSpecificFilterConfig( *proto_config, context, ProtobufMessage::getNullValidationVisitor()); const auto* config = dynamic_cast(route_config.get()); - EXPECT_TRUE(config->requestAllowed()); + std::vector route_descriptors; + EXPECT_TRUE(config->requestAllowed(route_descriptors)); } TEST(Factory, EnabledEnforcedDisabledByDefault) { @@ -125,6 +126,210 @@ stat_prefix: test EnvoyException); } +TEST(Factory, RouteSpecificFilterConfigWithDescriptorsWithNoTokenBucket) { + const std::string config_yaml = R"( +stat_prefix: test +token_bucket: + max_tokens: 1 + tokens_per_fill: 1 + fill_interval: 1000s +filter_enabled: + runtime_key: test_enabled + default_value: + numerator: 100 + denominator: HUNDRED +filter_enforced: + runtime_key: test_enforced + default_value: + numerator: 100 + denominator: HUNDRED +response_headers_to_add: + - append: false + header: + key: x-test-rate-limit + value: 'true' +descriptors: +- entries: + - key: hello + value: world + - key: foo + value: bar +- entries: + - key: foo2 + value: bar2 + )"; + + LocalRateLimitFilterConfig factory; + ProtobufTypes::MessagePtr proto_config = factory.createEmptyRouteConfigProto(); + TestUtility::loadFromYaml(config_yaml, *proto_config); + + NiceMock context; + + EXPECT_CALL(context.dispatcher_, createTimer_(_)).Times(0); + ; + EXPECT_THROW(factory.createRouteSpecificFilterConfig(*proto_config, context, + ProtobufMessage::getNullValidationVisitor()), + EnvoyException); +} + +TEST(Factory, RouteSpecificFilterConfigWithDescriptors) { + const std::string config_yaml = R"( +stat_prefix: test +token_bucket: + max_tokens: 1 + tokens_per_fill: 1 + fill_interval: 60s +filter_enabled: + runtime_key: test_enabled + default_value: + numerator: 100 + denominator: HUNDRED +filter_enforced: + runtime_key: test_enforced + default_value: + numerator: 100 + denominator: HUNDRED +response_headers_to_add: + - append: false + header: + key: x-test-rate-limit + value: 'true' +descriptors: +- entries: + - key: hello + value: world + - key: foo + value: bar + token_bucket: + max_tokens: 10 + tokens_per_fill: 10 + fill_interval: 60s +- entries: + - key: foo2 + value: bar2 + token_bucket: + max_tokens: 100 + tokens_per_fill: 100 + fill_interval: 3600s + )"; + + LocalRateLimitFilterConfig factory; + ProtobufTypes::MessagePtr proto_config = factory.createEmptyRouteConfigProto(); + TestUtility::loadFromYaml(config_yaml, *proto_config); + + NiceMock context; + + EXPECT_CALL(context.dispatcher_, createTimer_(_)); + const auto route_config = factory.createRouteSpecificFilterConfig( + *proto_config, context, ProtobufMessage::getNullValidationVisitor()); + const auto* config = dynamic_cast(route_config.get()); + std::vector route_descriptors; + EXPECT_TRUE(config->requestAllowed(route_descriptors)); +} + +TEST(Factory, RouteSpecificFilterConfigWithDescriptorsTimerNotDivisible) { + const std::string config_yaml = R"( +stat_prefix: test +token_bucket: + max_tokens: 1 + tokens_per_fill: 1 + fill_interval: 100s +filter_enabled: + runtime_key: test_enabled + default_value: + numerator: 100 + denominator: HUNDRED +filter_enforced: + runtime_key: test_enforced + default_value: + numerator: 100 + denominator: HUNDRED +response_headers_to_add: + - append: false + header: + key: x-test-rate-limit + value: 'true' +descriptors: +- entries: + - key: hello + value: world + - key: foo + value: bar + token_bucket: + max_tokens: 10 + tokens_per_fill: 10 + fill_interval: 1s +- entries: + - key: foo2 + value: bar2 + token_bucket: + max_tokens: 100 + tokens_per_fill: 100 + fill_interval: 86400s + )"; + + LocalRateLimitFilterConfig factory; + ProtobufTypes::MessagePtr proto_config = factory.createEmptyRouteConfigProto(); + TestUtility::loadFromYaml(config_yaml, *proto_config); + + NiceMock context; + + EXPECT_CALL(context.dispatcher_, createTimer_(_)); + EXPECT_THROW(factory.createRouteSpecificFilterConfig(*proto_config, context, + ProtobufMessage::getNullValidationVisitor()), + EnvoyException); +} + +TEST(Factory, GlobalConfigWithDescriptors) { + const std::string config_yaml = R"( +stat_prefix: test +token_bucket: + max_tokens: 1 + tokens_per_fill: 1 + fill_interval: 60s +filter_enabled: + runtime_key: test_enabled + default_value: + numerator: 100 + denominator: HUNDRED +filter_enforced: + runtime_key: test_enforced + default_value: + numerator: 100 + denominator: HUNDRED +response_headers_to_add: + - append: false + header: + key: x-test-rate-limit + value: 'true' +descriptors: +- entries: + - key: hello + value: world + - key: foo + value: bar + token_bucket: + max_tokens: 10 + tokens_per_fill: 10 + fill_interval: 60s +- entries: + - key: foo2 + value: bar2 + token_bucket: + max_tokens: 100 + tokens_per_fill: 100 + fill_interval: 3600s + )"; + + LocalRateLimitFilterConfig factory; + ProtobufTypes::MessagePtr proto_config = factory.createEmptyRouteConfigProto(); + TestUtility::loadFromYaml(config_yaml, *proto_config); + + NiceMock context; + EXPECT_THROW(factory.createFilterFactoryFromProto(*proto_config, "stats", context), + EnvoyException); +} + } // namespace LocalRateLimitFilter } // namespace HttpFilters } // namespace Extensions diff --git a/test/extensions/filters/http/local_ratelimit/filter_test.cc b/test/extensions/filters/http/local_ratelimit/filter_test.cc index 9662f9a783e10..e72c48ca6a493 100644 --- a/test/extensions/filters/http/local_ratelimit/filter_test.cc +++ b/test/extensions/filters/http/local_ratelimit/filter_test.cc @@ -3,6 +3,7 @@ #include "extensions/filters/http/local_ratelimit/local_ratelimit.h" #include "test/mocks/http/mocks.h" +#include "test/mocks/local_info/mocks.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -39,7 +40,8 @@ class FilterTest : public testing::Test { public: FilterTest() = default; - void setup(const std::string& yaml, const bool enabled = true, const bool enforced = true) { + void setupPerRoute(const std::string& yaml, const bool enabled = true, const bool enforced = true, + const bool per_route = false) { EXPECT_CALL( runtime_.snapshot_, featureEnabled(absl::string_view("test_enabled"), @@ -53,10 +55,14 @@ class FilterTest : public testing::Test { envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit config; TestUtility::loadFromYaml(yaml, config); - config_ = std::make_shared(config, dispatcher_, stats_, runtime_); + config_ = std::make_shared(config, local_info_, dispatcher_, stats_, runtime_, + per_route); filter_ = std::make_shared(config_); filter_->setDecoderFilterCallbacks(decoder_callbacks_); } + void setup(const std::string& yaml, const bool enabled = true, const bool enforced = true) { + setupPerRoute(yaml, enabled, enforced); + } uint64_t findCounter(const std::string& name) { const auto counter = TestUtility::findCounter(stats_, name); @@ -69,6 +75,7 @@ class FilterTest : public testing::Test { testing::NiceMock decoder_callbacks_; NiceMock dispatcher_; NiceMock runtime_; + NiceMock local_info_; std::shared_ptr config_; std::shared_ptr filter_; }; @@ -140,6 +147,191 @@ TEST_F(FilterTest, RequestRateLimitedButNotEnforced) { EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.rate_limited")); } +static const std::string descriptor_config_yaml = R"( +stat_prefix: test +token_bucket: + max_tokens: {} + tokens_per_fill: 1 + fill_interval: 60s +filter_enabled: + runtime_key: test_enabled + default_value: + numerator: 100 + denominator: HUNDRED +filter_enforced: + runtime_key: test_enforced + default_value: + numerator: 100 + denominator: HUNDRED +response_headers_to_add: + - append: false + header: + key: x-test-rate-limit + value: 'true' +descriptors: +- entries: + - key: hello + value: world + - key: foo + value: bar + token_bucket: + max_tokens: 10 + tokens_per_fill: 10 + fill_interval: 60s +- entries: + - key: foo2 + value: bar2 + token_bucket: + max_tokens: {} + tokens_per_fill: 1 + fill_interval: 60s +stage: {} + )"; + +class DescriptorFilterTest : public FilterTest { +public: + DescriptorFilterTest() = default; + + void setUpTest(const std::string& yaml) { + setupPerRoute(yaml, true, true, true); + decoder_callbacks_.route_->route_entry_.rate_limit_policy_.rate_limit_policy_entry_.clear(); + decoder_callbacks_.route_->route_entry_.rate_limit_policy_.rate_limit_policy_entry_ + .emplace_back(route_rate_limit_); + } + + // Default token bucket + RateLimit::TokenBucket bucket; + std::vector descriptor_{{{{"foo2", "bar2"}}, {bucket}}}; + std::vector descriptor_first_match_{{ + {{ + {"hello", "world"}, + {"foo", "bar"}, + }, + {bucket}}, + {{{"foo2", "bar2"}}, {bucket}}, + }}; + std::vector descriptor_not_found_{{{{"foo", "bar"}}, {bucket}}}; + NiceMock route_rate_limit_; +}; + +TEST_F(DescriptorFilterTest, NoRouteEntry) { + setupPerRoute(fmt::format(descriptor_config_yaml, "1", "1", "0"), true, true, true); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.ok")); +} + +TEST_F(DescriptorFilterTest, NoCluster) { + setUpTest(fmt::format(descriptor_config_yaml, "1", "1", "0")); + + EXPECT_CALL(decoder_callbacks_, clusterInfo()).WillRepeatedly(testing::Return(nullptr)); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.ok")); +} + +TEST_F(DescriptorFilterTest, DisabledInRoute) { + setUpTest(fmt::format(descriptor_config_yaml, "1", "1", "0")); + + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(0)); + + route_rate_limit_.disable_key_ = "disabled"; + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.ok")); +} + +TEST_F(DescriptorFilterTest, RouteDescriptorRequestOk) { + setUpTest(fmt::format(descriptor_config_yaml, "1", "1", "0")); + + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(0)); + + EXPECT_CALL(route_rate_limit_, populateLocalDescriptors(_, _, _, _, _, _)) + .WillOnce(testing::SetArgReferee<1>(descriptor_)); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.ok")); +} + +TEST_F(DescriptorFilterTest, RouteDescriptorRequestRatelimited) { + setUpTest(fmt::format(descriptor_config_yaml, "0", "0", "0")); + + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(0)); + + EXPECT_CALL(route_rate_limit_, populateLocalDescriptors(_, _, _, _, _, _)) + .WillOnce(testing::SetArgReferee<1>(descriptor_)); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.rate_limited")); +} + +TEST_F(DescriptorFilterTest, RouteDescriptorNotFound) { + setUpTest(fmt::format(descriptor_config_yaml, "1", "1", "0")); + + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(0)); + + EXPECT_CALL(route_rate_limit_, populateLocalDescriptors(_, _, _, _, _, _)) + .WillOnce(testing::SetArgReferee<1>(descriptor_not_found_)); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.rate_limited")); +} + +TEST_F(DescriptorFilterTest, RouteDescriptorFirstMatch) { + // Request should not be rate limited as it should match first descriptor with 10 req/min + setUpTest(fmt::format(descriptor_config_yaml, "0", "0", "0")); + + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(0)); + + EXPECT_CALL(route_rate_limit_, populateLocalDescriptors(_, _, _, _, _, _)) + .WillOnce(testing::SetArgReferee<1>(descriptor_first_match_)); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.rate_limited")); +} + +TEST_F(DescriptorFilterTest, RouteDescriptorWithStageConfig) { + setUpTest(fmt::format(descriptor_config_yaml, "1", "1", "1")); + + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(1)); + + EXPECT_CALL(route_rate_limit_, populateLocalDescriptors(_, _, _, _, _, _)) + .WillOnce(testing::SetArgReferee<1>(descriptor_)); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.ok")); +} + } // namespace LocalRateLimitFilter } // namespace HttpFilters } // namespace Extensions diff --git a/test/mocks/ratelimit/mocks.h b/test/mocks/ratelimit/mocks.h index 7f983beabbca8..2cb8fbef7471f 100644 --- a/test/mocks/ratelimit/mocks.h +++ b/test/mocks/ratelimit/mocks.h @@ -14,10 +14,6 @@ inline bool operator==(const RateLimitOverride& lhs, const RateLimitOverride& rh return lhs.requests_per_unit_ == rhs.requests_per_unit_ && lhs.unit_ == rhs.unit_; } -inline bool operator==(const DescriptorEntry& lhs, const DescriptorEntry& rhs) { - return lhs.key_ == rhs.key_ && lhs.value_ == rhs.value_; -} - inline bool operator==(const Descriptor& lhs, const Descriptor& rhs) { return lhs.entries_ == rhs.entries_ && lhs.limit_ == rhs.limit_; } diff --git a/test/mocks/router/mocks.h b/test/mocks/router/mocks.h index 02be9a0824c15..6ddd5390ffdd2 100644 --- a/test/mocks/router/mocks.h +++ b/test/mocks/router/mocks.h @@ -195,6 +195,12 @@ class MockRateLimitPolicyEntry : public RateLimitPolicyEntry { const Network::Address::Instance& remote_address, const envoy::config::core::v3::Metadata* dynamic_metadata), (const)); + MOCK_METHOD(void, populateLocalDescriptors, + (const RouteEntry& route, std::vector& descriptors, + const std::string& local_service_cluster, const Http::HeaderMap& headers, + const Network::Address::Instance& remote_address, + const envoy::config::core::v3::Metadata* dynamic_metadata), + (const)); uint64_t stage_{}; std::string disable_key_;