diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 870d274f3d..c9ee1c6a1e 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -1315,10 +1315,12 @@ func (s *graphServer) buildGraphMux( if s.redisClient != nil { handlerOpts.RateLimitConfig = s.rateLimit handlerOpts.RateLimiter, err = NewCosmoRateLimiter(&CosmoRateLimiterOptions{ - RedisClient: s.redisClient, - Debug: s.rateLimit.Debug, - RejectStatusCode: s.rateLimit.SimpleStrategy.RejectStatusCode, + RedisClient: s.redisClient, + Debug: s.rateLimit.Debug, + //RejectStatusCode: s.rateLimit.SimpleStrategy.RejectStatusCode, KeySuffixExpression: s.rateLimit.KeySuffixExpression, + RateLimitConfig: s.rateLimit.SimpleStrategy, + BaseRateLimitKey: s.rateLimit.Storage.KeyPrefix, ExprManager: exprManager, }) if err != nil { diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index c494fff4ce..e443c8f4ba 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -269,12 +269,8 @@ func (h *GraphQLHandler) configureRateLimiting(ctx *resolve.Context) *resolve.Co ctx.SetRateLimiter(h.rateLimiter) ctx.RateLimitOptions = resolve.RateLimitOptions{ Enable: true, - IncludeStatsInResponseExtension: !h.rateLimitConfig.SimpleStrategy.HideStatsFromResponseExtension, - Rate: h.rateLimitConfig.SimpleStrategy.Rate, - Burst: h.rateLimitConfig.SimpleStrategy.Burst, - Period: h.rateLimitConfig.SimpleStrategy.Period, - RateLimitKey: h.rateLimitConfig.Storage.KeyPrefix, - RejectExceedingRequests: h.rateLimitConfig.SimpleStrategy.RejectExceedingRequests, + IncludeStatsInResponseExtension: !h.rateLimitConfig.HideStatsFromResponseExtension, + //RateLimitKey: h.rateLimitConfig.Storage.KeyPrefix, ErrorExtensionCode: resolve.RateLimitErrorExtensionCode{ Enabled: h.rateLimitConfig.ErrorExtensionCode.Enabled, Code: h.rateLimitConfig.ErrorExtensionCode.Code, @@ -317,7 +313,7 @@ func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolv Code: h.rateLimitConfig.ErrorExtensionCode.Code, } } - if !h.rateLimitConfig.SimpleStrategy.HideStatsFromResponseExtension { + if !h.rateLimitConfig.HideStatsFromResponseExtension { buf := bytes.NewBuffer(make([]byte, 0, 1024)) err = h.rateLimiter.RenderResponseExtension(ctx, buf) if err != nil { diff --git a/router/core/ratelimiter.go b/router/core/ratelimiter.go index 0e47f30b60..f6dd3d9ab8 100644 --- a/router/core/ratelimiter.go +++ b/router/core/ratelimiter.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" rd "github.com/wundergraph/cosmo/router/internal/persistedoperation/operationstorage/redis" + "github.com/wundergraph/cosmo/router/pkg/config" "io" "reflect" "sync" @@ -25,10 +26,12 @@ type CosmoRateLimiterOptions struct { RedisClient rd.RDCloser Debug bool - RejectStatusCode int - KeySuffixExpression string ExprManager *expr.Manager + + BaseRateLimitKey string + + RateLimitConfig config.RateLimitSimpleStrategy } func NewCosmoRateLimiter(opts *CosmoRateLimiterOptions) (rl *CosmoRateLimiter, err error) { @@ -37,7 +40,8 @@ func NewCosmoRateLimiter(opts *CosmoRateLimiterOptions) (rl *CosmoRateLimiter, e client: opts.RedisClient, limiter: limiter, debug: opts.Debug, - rejectStatusCode: opts.RejectStatusCode, + baseRateLimitKey: opts.BaseRateLimitKey, + rateLimitConfig: opts.RateLimitConfig, } if rl.rejectStatusCode == 0 { rl.rejectStatusCode = 200 @@ -59,6 +63,10 @@ type CosmoRateLimiter struct { rejectStatusCode int keySuffixProgram *vm.Program + + // TODO: To decouple from the config + rateLimitConfig config.RateLimitSimpleStrategy + baseRateLimitKey string } func (c *CosmoRateLimiter) RateLimitPreFetch(ctx *resolve.Context, info *resolve.FetchInfo, input json.RawMessage) (result *resolve.RateLimitDeny, err error) { @@ -66,15 +74,30 @@ func (c *CosmoRateLimiter) RateLimitPreFetch(ctx *resolve.Context, info *resolve return nil, nil } requestRate := c.calculateRate() - limit := redis_rate.Limit{ - Rate: ctx.RateLimitOptions.Rate, - Burst: ctx.RateLimitOptions.Burst, - Period: ctx.RateLimitOptions.Period, - } - key, err := c.generateKey(ctx) + rawExprKey, key, err := c.generateKey(ctx, info) if err != nil { return nil, err } + + limitDetails, found := c.rateLimitConfig.KeyMapping[rawExprKey] + if !found { + if !c.rateLimitConfig.Enabled { + return nil, nil + } else { + limitDetails = c.rateLimitConfig.RateLimitSimpleStrategyEntry + } + } else { + if !limitDetails.Enabled { + return nil, nil + } + } + + limit := redis_rate.Limit{ + Rate: limitDetails.Rate, + Burst: limitDetails.Burst, + Period: limitDetails.Period, + } + allow, err := c.limiter.AllowN(ctx.Context(), key, limit, requestRate) if err != nil { return nil, err @@ -83,29 +106,32 @@ func (c *CosmoRateLimiter) RateLimitPreFetch(ctx *resolve.Context, info *resolve if allow.Allowed >= requestRate { return nil, nil } - if ctx.RateLimitOptions.RejectExceedingRequests { + if limitDetails.RejectExceedingRequests { return nil, ErrRateLimitExceeded } return &resolve.RateLimitDeny{}, nil } -func (c *CosmoRateLimiter) generateKey(ctx *resolve.Context) (string, error) { +func (c *CosmoRateLimiter) generateKey(ctx *resolve.Context, info *resolve.FetchInfo) (string, string, error) { if c.keySuffixProgram == nil { - return ctx.RateLimitOptions.RateLimitKey, nil + return "", c.baseRateLimitKey, nil } rc := getRequestContext(ctx.Context()) if rc == nil { - return "", errors.New("no request context") + return "", "", errors.New("no request context") } - str, err := expr.ResolveStringExpression(c.keySuffixProgram, rc.expressionContext) + clonedEc := rc.expressionContext.Clone() + clonedEc.Subgraph.Id = info.DataSourceID + clonedEc.Subgraph.Name = info.DataSourceName + str, err := expr.ResolveStringExpression(c.keySuffixProgram, *clonedEc) if err != nil { - return "", fmt.Errorf("failed to resolve key suffix expression: %w", err) + return "", "", fmt.Errorf("failed to resolve key suffix expression: %w", err) } - buf := bytes.NewBuffer(make([]byte, 0, len(ctx.RateLimitOptions.RateLimitKey)+len(str)+1)) - _, _ = buf.WriteString(ctx.RateLimitOptions.RateLimitKey) + buf := bytes.NewBuffer(make([]byte, 0, len(c.baseRateLimitKey)+len(str)+1)) + _, _ = buf.WriteString(c.baseRateLimitKey) _ = buf.WriteByte(':') _, _ = buf.WriteString(str) - return buf.String(), nil + return str, buf.String(), nil } func (c *CosmoRateLimiter) RejectStatusCode() int { diff --git a/router/core/ratelimiter_test.go b/router/core/ratelimiter_test.go index f260ac1ae0..73a8c5b5a8 100644 --- a/router/core/ratelimiter_test.go +++ b/router/core/ratelimiter_test.go @@ -40,7 +40,7 @@ func TestRateLimiterGenerateKey(t *testing.T) { t.Parallel() rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{}) assert.NoError(t, err) - key, err := rl.generateKey(expressionResolveContext(t, nil, nil)) + key, err := rl.generateKey(expressionResolveContext(t, nil, nil), nil) assert.NoError(t, err) assert.Equal(t, "test", key) }) @@ -51,9 +51,7 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) require.NoError(t, err) - key, err := rl.generateKey( - expressionResolveContext(t, http.Header{"Authorization": []string{"token"}}, nil), - ) + key, err := rl.generateKey(expressionResolveContext(t, http.Header{"Authorization": []string{"token"}}, nil), nil) assert.NoError(t, err) assert.Equal(t, "test:token", key) }) @@ -64,9 +62,7 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) assert.NoError(t, err) - key, err := rl.generateKey( - expressionResolveContext(t, http.Header{"Authorization": []string{"123"}}, nil), - ) + key, err := rl.generateKey(expressionResolveContext(t, http.Header{"Authorization": []string{"123"}}, nil), nil) assert.NoError(t, err) assert.Equal(t, "test:123", key) }) @@ -77,9 +73,7 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) assert.NoError(t, err) - key, err := rl.generateKey( - expressionResolveContext(t, http.Header{"Authorization": []string{" token "}}, nil), - ) + key, err := rl.generateKey(expressionResolveContext(t, http.Header{"Authorization": []string{" token "}}, nil), nil) assert.NoError(t, err) assert.Equal(t, "test:token", key) }) @@ -90,9 +84,7 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) assert.NoError(t, err) - key, err := rl.generateKey( - expressionResolveContext(t, nil, map[string]any{"sub": "token"}), - ) + key, err := rl.generateKey(expressionResolveContext(t, nil, map[string]any{"sub": "token"}), nil) assert.NoError(t, err) assert.Equal(t, "test:token", key) }) @@ -103,9 +95,7 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) assert.NoError(t, err) - key, err := rl.generateKey( - expressionResolveContext(t, nil, map[string]any{"sub": 123}), - ) + key, err := rl.generateKey(expressionResolveContext(t, nil, map[string]any{"sub": 123}), nil) assert.Error(t, err) assert.Empty(t, key) }) @@ -116,9 +106,7 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) assert.NoError(t, err) - key, err := rl.generateKey( - expressionResolveContext(t, http.Header{"X-Forwarded-For": []string{"192.168.0.1"}}, map[string]any{"sub": "token"}), - ) + key, err := rl.generateKey(expressionResolveContext(t, http.Header{"X-Forwarded-For": []string{"192.168.0.1"}}, map[string]any{"sub": "token"}), nil) assert.NoError(t, err) assert.Equal(t, "test:token", key) }) @@ -129,9 +117,7 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) assert.NoError(t, err) - key, err := rl.generateKey( - expressionResolveContext(t, http.Header{"X-Forwarded-For": []string{"192.168.0.1"}}, nil), - ) + key, err := rl.generateKey(expressionResolveContext(t, http.Header{"X-Forwarded-For": []string{"192.168.0.1"}}, nil), nil) assert.NoError(t, err) assert.Equal(t, "test:192.168.0.1", key) }) diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 53b252632c..bd66bd4be5 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -96,7 +96,7 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS c.URL: store, }, PrioritizeHTTP: true, - RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1), + RefreshUnknownKID: rate.NewLimiter(rate.Every(2*time.Second), 1), } jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index c51e2810f5..0d03f6e5e8 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -507,6 +507,10 @@ type RateLimitConfiguration struct { Debug bool `yaml:"debug" envDefault:"false" env:"RATE_LIMIT_DEBUG"` KeySuffixExpression string `yaml:"key_suffix_expression,omitempty" env:"RATE_LIMIT_KEY_SUFFIX_EXPRESSION"` ErrorExtensionCode RateLimitErrorExtensionCode `yaml:"error_extension_code"` + + // This makes more sense here since it's a single value for the whole rate limit configuration + // Since we cannot have this per subgraph + HideStatsFromResponseExtension bool `yaml:"hide_stats_from_response_extension" envDefault:"false" env:"RATE_LIMIT_SIMPLE_HIDE_STATS_FROM_RESPONSE_EXTENSION"` } type RateLimitErrorExtensionCode struct { @@ -521,12 +525,23 @@ type RedisConfiguration struct { } type RateLimitSimpleStrategy struct { - Rate int `yaml:"rate" envDefault:"10" env:"RATE_LIMIT_SIMPLE_RATE"` - Burst int `yaml:"burst" envDefault:"10" env:"RATE_LIMIT_SIMPLE_BURST"` - Period time.Duration `yaml:"period" envDefault:"1s" env:"RATE_LIMIT_SIMPLE_PERIOD"` - RejectExceedingRequests bool `yaml:"reject_exceeding_requests" envDefault:"false" env:"RATE_LIMIT_SIMPLE_REJECT_EXCEEDING_REQUESTS"` - RejectStatusCode int `yaml:"reject_status_code" envDefault:"200" env:"RATE_LIMIT_SIMPLE_REJECT_STATUS_CODE"` - HideStatsFromResponseExtension bool `yaml:"hide_stats_from_response_extension" envDefault:"false" env:"RATE_LIMIT_SIMPLE_HIDE_STATS_FROM_RESPONSE_EXTENSION"` + RateLimitSimpleStrategyEntry + //Enabled bool `yaml:"enabled" envDefault:"true" env:"RATE_LIMIT_SIMPLE_ENABLED"` + //Rate int `yaml:"rate" envDefault:"10" env:"RATE_LIMIT_SIMPLE_RATE"` + //Burst int `yaml:"burst" envDefault:"10" env:"RATE_LIMIT_SIMPLE_BURST"` + //Period time.Duration `yaml:"period" envDefault:"1s" env:"RATE_LIMIT_SIMPLE_PERIOD"` + //RejectExceedingRequests bool `yaml:"reject_exceeding_requests" envDefault:"false" env:"RATE_LIMIT_SIMPLE_REJECT_EXCEEDING_REQUESTS"` + //RejectStatusCode int `yaml:"reject_status_code" envDefault:"200" env:"RATE_LIMIT_SIMPLE_REJECT_STATUS_CODE"` + KeyMapping map[string]RateLimitSimpleStrategyEntry `yaml:"key_mapping"` +} + +type RateLimitSimpleStrategyEntry struct { + Enabled bool `yaml:"enabled" envDefault:"true" env:"RATE_LIMIT_SIMPLE_ENABLED"` + Rate int `yaml:"rate" envDefault:"10" env:"RATE_LIMIT_SIMPLE_RATE"` + Burst int `yaml:"burst" envDefault:"10" env:"RATE_LIMIT_SIMPLE_BURST"` + Period time.Duration `yaml:"period" envDefault:"1s" env:"RATE_LIMIT_SIMPLE_PERIOD"` + RejectExceedingRequests bool `yaml:"reject_exceeding_requests" envDefault:"false" env:"RATE_LIMIT_SIMPLE_REJECT_EXCEEDING_REQUESTS"` + RejectStatusCode int `yaml:"reject_status_code" envDefault:"200" env:"RATE_LIMIT_SIMPLE_REJECT_STATUS_CODE"` } type CDNConfiguration struct { diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 54b2635627..da8a8cc3e8 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1820,6 +1820,47 @@ "type": "boolean", "default": false, "description": "Hide the rate limit stats from the response extension. If the value is true, the rate limit stats are not included in the response extension." + }, + "key_mapping": { + "type": "object", + "description": "The configuration per key entry.", + "additionalProperties": { + "type": "object", + "additionalProperties": false, + "description": "The configuration for all subgraphs. The configuration is used to configure the traffic shaping for all subgraphs.", + "properties": { + "enabled": { + "type": "boolean", + "description": "The rate at which the requests are allowed. The rate is specified as a number of requests per second." + }, + "rate": { + "type": "integer", + "description": "The rate at which the requests are allowed. The rate is specified as a number of requests per second.", + "minimum": 1 + }, + "burst": { + "type": "integer", + "description": "The maximum number of requests that are allowed to exceed the rate. The burst is specified as a number of requests.", + "minimum": 1 + }, + "period": { + "type": "string", + "description": "The period of time over which the rate limit is enforced. The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'.", + "duration": { + "minimum": "1s" + } + }, + "reject_exceeding_requests": { + "type": "boolean", + "description": "Reject the requests that exceed the rate limit. If the value is true, the requests that exceed the rate limit are rejected." + }, + "reject_status_code": { + "type": "integer", + "description": "The status code to return when the request is rejected. The default value is 200 (OK) as we're returning a well formed GraphQL response.", + "default": 200 + } + } + } } }, "required": ["rate", "burst", "period"]