diff --git a/apps/docs/docs.json b/apps/docs/docs.json index 3d4d5dda1f..a15a25aa70 100644 --- a/apps/docs/docs.json +++ b/apps/docs/docs.json @@ -209,7 +209,10 @@ }, { "group": "User Errors", - "pages": ["errors/user/bad_request/permissions_query_syntax_error"] + "pages": [ + "errors/user/bad_request/permissions_query_syntax_error", + "errors/user/bad_request/request_body_too_large" + ] } ] } diff --git a/apps/docs/errors/user/bad_request/request_body_too_large.mdx b/apps/docs/errors/user/bad_request/request_body_too_large.mdx new file mode 100644 index 0000000000..10d19a0b4c --- /dev/null +++ b/apps/docs/errors/user/bad_request/request_body_too_large.mdx @@ -0,0 +1,87 @@ +--- +title: "request_body_too_large" +description: "Request body exceeds the maximum allowed size limit" +--- + +`err:user:bad_request:request_body_too_large` + +```json Example +{ + "meta": { + "requestId": "req_4dgzrNP3Je5mU1tD" + }, + "error": { + "detail": "The request body exceeds the maximum allowed size of 100 bytes.", + "status": 413, + "title": "Request Entity Too Large", + "type": "https://unkey.com/docs/errors/user/bad_request/request_body_too_large", + "errors": [] + } +} +``` + +## What Happened? + +Your request was too big! We limit how much data you can send in a single API request to keep everything running smoothly. + +This usually happens when you're trying to send a lot of data at once - like huge metadata objects or really long strings in your request. + +## How to Fix It + +### 1. Trim Down Your Request + +The most common cause is putting too much data in the `meta` field or other parts of your request. + + + +```bash Too Big +curl -X POST https://api.unkey.com/v2/keys.create \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer unkey_XXXX" \ + -d '{ + "apiId": "api_123", + "name": "My Key", + "meta": { + "userProfile": "... really long user profile data ...", + "settings": { /* huge nested object with tons of properties */ } + } + }' +``` + +```bash Just Right +curl -X POST https://api.unkey.com/v2/keys.create \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer unkey_XXXX" \ + -d '{ + "apiId": "api_123", + "name": "My Key", + "meta": { + "userId": "user_123", + "tier": "premium" + } + }' +``` + + + +### 2. Store Big Data Elsewhere + +Instead of cramming everything into your API request: + +- Store large data in your own database +- Only send IDs or references to Unkey +- Fetch the full data when you need it + +## Need a Higher Limit? + + +**Got a special use case?** If you have a legitimate need to send larger requests, we'd love to hear about it! + +[Contact our support team](mailto:support@unkey.com) and include: +- What you're building +- Why you need to send large requests +- An example of the data you're trying to send + +We'll work with you to find a solution that works for your use case. + +``` diff --git a/go/apps/api/config.go b/go/apps/api/config.go index 9da9962c55..da7e13ae03 100644 --- a/go/apps/api/config.go +++ b/go/apps/api/config.go @@ -78,6 +78,11 @@ type Config struct { // ChproxyToken is the authentication token for ClickHouse proxy endpoints ChproxyToken string + + // MaxRequestBodySize sets the maximum allowed request body size in bytes. + // If 0 or negative, no limit is enforced. Default is 0 (no limit). + // This helps prevent DoS attacks from excessively large request bodies. + MaxRequestBodySize int64 } func (c Config) Validate() error { diff --git a/go/apps/api/routes/chproxy_metrics/handler.go b/go/apps/api/routes/chproxy_metrics/handler.go index 59f76ea3c6..b0bcc9ac7f 100644 --- a/go/apps/api/routes/chproxy_metrics/handler.go +++ b/go/apps/api/routes/chproxy_metrics/handler.go @@ -33,6 +33,8 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { + s.DisableClickHouseLogging() + // Authenticate using Bearer token token, err := zen.Bearer(s) if err != nil { diff --git a/go/apps/api/routes/chproxy_ratelimits/handler.go b/go/apps/api/routes/chproxy_ratelimits/handler.go index 93b1c7e515..ed02c8bb1b 100644 --- a/go/apps/api/routes/chproxy_ratelimits/handler.go +++ b/go/apps/api/routes/chproxy_ratelimits/handler.go @@ -33,6 +33,8 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { + s.DisableClickHouseLogging() + // Authenticate using Bearer token token, err := zen.Bearer(s) if err != nil { diff --git a/go/apps/api/routes/chproxy_verifications/handler.go b/go/apps/api/routes/chproxy_verifications/handler.go index 776be8b7c3..eafcc1fe4c 100644 --- a/go/apps/api/routes/chproxy_verifications/handler.go +++ b/go/apps/api/routes/chproxy_verifications/handler.go @@ -33,6 +33,8 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { + s.DisableClickHouseLogging() + // Authenticate using Bearer token token, err := zen.Bearer(s) if err != nil { diff --git a/go/apps/api/routes/openapi/handler.go b/go/apps/api/routes/openapi/handler.go index d471597a29..afaa527aea 100644 --- a/go/apps/api/routes/openapi/handler.go +++ b/go/apps/api/routes/openapi/handler.go @@ -26,6 +26,8 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - s.AddHeader("Content-Type", "text/html") + s.DisableClickHouseLogging() + + s.AddHeader("Content-Type", "application/yaml") return s.Send(200, openapi.Spec) } diff --git a/go/apps/api/routes/reference/handler.go b/go/apps/api/routes/reference/handler.go index c4eaf918fd..69b3debd3f 100644 --- a/go/apps/api/routes/reference/handler.go +++ b/go/apps/api/routes/reference/handler.go @@ -27,6 +27,8 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { + s.DisableClickHouseLogging() + html := fmt.Sprintf(` diff --git a/go/apps/api/routes/v2_liveness/handler.go b/go/apps/api/routes/v2_liveness/handler.go index 6a7f111fd5..5d7624475d 100644 --- a/go/apps/api/routes/v2_liveness/handler.go +++ b/go/apps/api/routes/v2_liveness/handler.go @@ -27,13 +27,14 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - res := Response{ + s.DisableClickHouseLogging() + + return s.JSON(http.StatusOK, Response{ Meta: openapi.Meta{ RequestId: s.RequestID(), }, Data: openapi.V2LivenessResponseData{ Message: "we're cooking", }, - } - return s.JSON(http.StatusOK, res) + }) } diff --git a/go/apps/api/routes/v2_ratelimit_limit/handler.go b/go/apps/api/routes/v2_ratelimit_limit/handler.go index 7fa1bd5b9e..ee09c7cc08 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/handler.go +++ b/go/apps/api/routes/v2_ratelimit_limit/handler.go @@ -51,6 +51,10 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { + if s.Request().Header.Get("X-Unkey-Metrics") == "disabled" { + s.DisableClickHouseLogging() + } + // Authenticate the request with a root key auth, emit, err := h.Keys.GetRootKey(ctx, s) defer emit() @@ -206,7 +210,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } - if s.Request().Header.Get("X-Unkey-Metrics") != "disabled" { + if s.ShouldLogRequestToClickHouse() { h.ClickHouse.BufferRatelimit(schema.RatelimitRequestV1{ RequestID: s.RequestID(), WorkspaceID: auth.AuthorizedWorkspaceID, @@ -216,6 +220,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { Passed: result.Success, }) } + res := Response{ Meta: openapi.Meta{ RequestId: s.RequestID(), diff --git a/go/apps/api/run.go b/go/apps/api/run.go index 6f728eebc5..924c09c21d 100644 --- a/go/apps/api/run.go +++ b/go/apps/api/run.go @@ -145,7 +145,8 @@ func Run(ctx context.Context, cfg Config) error { Flags: &zen.Flags{ TestMode: cfg.TestMode, }, - TLS: cfg.TLSConfig, + TLS: cfg.TLSConfig, + MaxRequestBodySize: cfg.MaxRequestBodySize, }) if err != nil { return fmt.Errorf("unable to create server: %w", err) diff --git a/go/cmd/api/main.go b/go/cmd/api/main.go index 9f0e26ece8..c24874ec06 100644 --- a/go/cmd/api/main.go +++ b/go/cmd/api/main.go @@ -77,6 +77,10 @@ var Cmd = &cli.Command{ "Authentication token for ClickHouse proxy endpoints. Required when proxy is enabled.", cli.EnvVar("UNKEY_CHPROXY_AUTH_TOKEN"), ), + + // Request Body Configuration + cli.Int64("max-request-body-size", "Maximum allowed request body size in bytes. Set to 0 or negative to disable limit. Default: 10485760 (10MB)", + cli.Default(int64(10485760)), cli.EnvVar("UNKEY_MAX_REQUEST_BODY_SIZE")), }, Action: action, @@ -146,6 +150,9 @@ func action(ctx context.Context, cmd *cli.Command) error { // ClickHouse proxy configuration ChproxyToken: cmd.String("chproxy-auth-token"), + + // Request body configuration + MaxRequestBodySize: cmd.Int64("max-request-body-size"), } err := config.Validate() diff --git a/go/internal/services/usagelimiter/limit.go b/go/internal/services/usagelimiter/limit.go index 862f7b9766..6200b75b01 100644 --- a/go/internal/services/usagelimiter/limit.go +++ b/go/internal/services/usagelimiter/limit.go @@ -42,7 +42,6 @@ func (s *service) Limit(ctx context.Context, req UsageRequest) (UsageResponse, e } metrics.UsagelimiterDecisions.WithLabelValues("db", "allowed").Inc() - metrics.UsagelimiterCreditsProcessed.Add(float64(req.Cost)) return UsageResponse{Valid: true, Remaining: max(0, remaining-req.Cost)}, nil } diff --git a/go/internal/services/usagelimiter/redis.go b/go/internal/services/usagelimiter/redis.go index 9ba3268032..563c533ae5 100644 --- a/go/internal/services/usagelimiter/redis.go +++ b/go/internal/services/usagelimiter/redis.go @@ -191,7 +191,6 @@ func (s *counterService) Limit(ctx context.Context, req UsageRequest) (UsageResp // Attempt decrement if key already exists in Redis remaining, exists, success, err := s.counter.DecrementIfExists(ctx, redisKey, int64(req.Cost)) if err != nil { - metrics.UsagelimiterFallbackOperations.Inc() return s.dbFallback.Limit(ctx, req) } @@ -215,7 +214,6 @@ func (s *counterService) handleResult(req UsageRequest, remaining int64, success }) metrics.UsagelimiterDecisions.WithLabelValues("redis", "allowed").Inc() - metrics.UsagelimiterCreditsProcessed.Add(float64(req.Cost)) return UsageResponse{Valid: true, Remaining: int32(remaining)}, nil } @@ -262,7 +260,6 @@ func (s *counterService) initializeFromDatabase(ctx context.Context, req UsageRe wasSet, err := s.counter.SetIfNotExists(ctx, redisKey, initValue, s.ttl) if err != nil { - metrics.UsagelimiterFallbackOperations.Inc() s.logger.Debug("failed to initialize counter with SetIfNotExists, falling back to DB", "error", err, "keyId", req.KeyId) return s.dbFallback.Limit(ctx, req) } @@ -281,7 +278,6 @@ func (s *counterService) initializeFromDatabase(ctx context.Context, req UsageRe // Another node already initialized the key, check if we have enough after decrement remaining, exists, success, err := s.counter.DecrementIfExists(ctx, redisKey, int64(req.Cost)) if err != nil || !exists { - metrics.UsagelimiterFallbackOperations.Inc() s.logger.Debug("failed to decrement after initialization attempt", "error", err, "exists", exists, "keyId", req.KeyId) return s.dbFallback.Limit(ctx, req) } diff --git a/go/pkg/cli/command.go b/go/pkg/cli/command.go index 0ce764c43e..e0f528cafb 100644 --- a/go/pkg/cli/command.go +++ b/go/pkg/cli/command.go @@ -125,6 +125,33 @@ func (c *Command) RequireInt(name string) int { return inf.Value() } +// Int64 returns the value of an int64 flag by name +// Returns 0 if flag doesn't exist or isn't an Int64Flag +func (c *Command) Int64(name string) int64 { + if flag, ok := c.flagMap[name]; ok { + if i64f, ok := flag.(*Int64Flag); ok { + return i64f.Value() + } + } + return 0 +} + +// RequireInt64 returns the value of an int64 flag by name +// Panics if flag doesn't exist or isn't an Int64Flag +func (c *Command) RequireInt64(name string) int64 { + flag, ok := c.flagMap[name] + if !ok { + panic(c.newFlagNotFoundError(name)) + } + + i64f, ok := flag.(*Int64Flag) + if !ok { + panic(c.newWrongFlagTypeError(name, flag, "Int64Flag")) + } + + return i64f.Value() +} + // Float returns the value of a float flag by name // Returns 0.0 if flag doesn't exist or isn't a FloatFlag func (c *Command) Float(name string) float64 { @@ -220,6 +247,8 @@ func (c *Command) getFlagType(flag Flag) string { return "BoolFlag" case *IntFlag: return "IntFlag" + case *Int64Flag: + return "Int64Flag" case *FloatFlag: return "FloatFlag" case *StringSliceFlag: diff --git a/go/pkg/cli/flag.go b/go/pkg/cli/flag.go index 85bafc0dcc..269f94a452 100644 --- a/go/pkg/cli/flag.go +++ b/go/pkg/cli/flag.go @@ -12,6 +12,7 @@ var ( ErrValidationFailed = errors.New("validation failed") ErrInvalidBoolValue = errors.New("invalid boolean value") ErrInvalidIntValue = errors.New("invalid integer value") + ErrInvalidInt64Value = errors.New("invalid int64 value") ErrInvalidFloatValue = errors.New("invalid float value") ) @@ -150,6 +151,38 @@ func (f *IntFlag) Value() int { return f.value } // HasValue returns true if the flag has a non-zero value or came from environment func (f *IntFlag) HasValue() bool { return f.value != 0 || f.hasEnvValue } +// Int64Flag represents an int64 command line flag +type Int64Flag struct { + baseFlag + value int64 // Current value + hasEnvValue bool // Track if value came from environment +} + +// Parse sets the flag value from a string +func (f *Int64Flag) Parse(value string) error { + parsed, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return fmt.Errorf("%w: %s", ErrInvalidInt64Value, value) + } + + // Run validation if provided + if f.validate != nil { + if err := f.validate(value); err != nil { + return newValidationError(f.name, err) + } + } + + f.value = parsed + f.set = true + return nil +} + +// Value returns the current int64 value +func (f *Int64Flag) Value() int64 { return f.value } + +// HasValue returns true if the flag has a non-zero value or came from environment +func (f *Int64Flag) HasValue() bool { return f.value != 0 || f.hasEnvValue } + // FloatFlag represents a float64 command line flag type FloatFlag struct { baseFlag @@ -240,6 +273,8 @@ func Required() FlagOption { flag.required = true case *IntFlag: flag.required = true + case *Int64Flag: + flag.required = true case *FloatFlag: flag.required = true case *StringSliceFlag: @@ -258,6 +293,8 @@ func EnvVar(envVar string) FlagOption { flag.envVar = envVar case *IntFlag: flag.envVar = envVar + case *Int64Flag: + flag.envVar = envVar case *FloatFlag: flag.envVar = envVar case *StringSliceFlag: @@ -276,6 +313,8 @@ func Validate(fn ValidateFunc) FlagOption { flag.validate = fn case *IntFlag: flag.validate = fn + case *Int64Flag: + flag.validate = fn case *FloatFlag: flag.validate = fn case *StringSliceFlag: @@ -307,6 +346,12 @@ func Default(value any) FlagOption { } else { err = fmt.Errorf("default value for int flag '%s' must be int, got %T", flag.name, value) } + case *Int64Flag: + if v, ok := value.(int64); ok { + flag.value = v + } else { + err = fmt.Errorf("default value for int64 flag '%s' must be int64, got %T", flag.name, value) + } case *FloatFlag: if v, ok := value.(float64); ok { flag.value = v @@ -516,6 +561,46 @@ func StringSlice(name, usage string, opts ...FlagOption) *StringSliceFlag { return flag } +// Int64 creates a new int64 flag with optional configuration +func Int64(name, usage string, opts ...FlagOption) *Int64Flag { + flag := &Int64Flag{ + baseFlag: baseFlag{ + name: name, + usage: usage, + required: false, // Default to not required + }, + value: 0, // Default to zero + } + + // Apply options + for _, opt := range opts { + opt(flag) + } + + // Check environment variable for default value if specified + if flag.envVar != "" { + if envValue := os.Getenv(flag.envVar); envValue != "" { + parsed, err := strconv.ParseInt(envValue, 10, 64) + if err != nil { + Exit(fmt.Sprintf("Environment variable error: invalid int64 value in %s=%q: %v", + flag.envVar, envValue, err), 1) + } + // Apply validation to environment variable values + if flag.validate != nil { + if err := flag.validate(envValue); err != nil { + Exit(fmt.Sprintf("Environment variable error: validation failed for %s=%q: %v", + flag.envVar, envValue, err), 1) + } + } + flag.value = parsed + flag.hasEnvValue = true + // Don't mark as explicitly set - this is from environment + } + } + + return flag +} + func newValidationError(flagName string, err error) error { return fmt.Errorf("%w for flag %s: %w", ErrValidationFailed, flagName, err) } diff --git a/go/pkg/codes/constants_gen.go b/go/pkg/codes/constants_gen.go index 70deeca842..61b6767f88 100644 --- a/go/pkg/codes/constants_gen.go +++ b/go/pkg/codes/constants_gen.go @@ -14,6 +14,8 @@ const ( // PermissionsQuerySyntaxError indicates a syntax or lexical error in verifyKey permissions query parsing. UserErrorsBadRequestPermissionsQuerySyntaxError URN = "err:user:bad_request:permissions_query_syntax_error" + // RequestBodyTooLarge indicates the request body exceeds the maximum allowed size. + UserErrorsBadRequestRequestBodyTooLarge URN = "err:user:bad_request:request_body_too_large" // ---------------- // UnkeyAuthErrors diff --git a/go/pkg/codes/user_request.go b/go/pkg/codes/user_request.go index 40bba1636b..7a94795a73 100644 --- a/go/pkg/codes/user_request.go +++ b/go/pkg/codes/user_request.go @@ -4,6 +4,8 @@ package codes type userBadRequest struct { // PermissionsQuerySyntaxError indicates a syntax or lexical error in verifyKey permissions query parsing. PermissionsQuerySyntaxError Code + // RequestBodyTooLarge indicates the request body exceeds the maximum allowed size. + RequestBodyTooLarge Code } // UserErrors defines all user-related errors in the Unkey system. @@ -19,5 +21,6 @@ type UserErrors struct { var User = UserErrors{ BadRequest: userBadRequest{ PermissionsQuerySyntaxError: Code{SystemUser, CategoryUserBadRequest, "permissions_query_syntax_error"}, + RequestBodyTooLarge: Code{SystemUser, CategoryUserBadRequest, "request_body_too_large"}, }, } diff --git a/go/pkg/prometheus/metrics/usagelimiter.go b/go/pkg/prometheus/metrics/usagelimiter.go index 8a429e4f5a..a396453803 100644 --- a/go/pkg/prometheus/metrics/usagelimiter.go +++ b/go/pkg/prometheus/metrics/usagelimiter.go @@ -62,34 +62,4 @@ var ( ConstLabels: constLabels, }, ) - - // UsagelimiterFallbackOperations counts fallback operations to direct DB access - // This counter helps monitor Redis health and fallback frequency. - // - // Example usage: - // metrics.UsagelimiterFallbackOperations.Inc() - UsagelimiterFallbackOperations = promauto.NewCounter( - prometheus.CounterOpts{ - Namespace: "unkey", - Subsystem: "usagelimiter", - Name: "fallback_operations_total", - Help: "Total number of fallback operations to direct database access.", - ConstLabels: constLabels, - }, - ) - - // UsagelimiterCreditsProcessed counts the total number of credits processed - // This counter helps track the overall usage and throughput of the system. - // - // Example usage: - // metrics.UsagelimiterCreditsProcessed.Add(float64(creditsUsed)) - UsagelimiterCreditsProcessed = promauto.NewCounter( - prometheus.CounterOpts{ - Namespace: "unkey", - Subsystem: "usagelimiter", - Name: "credits_processed_total", - Help: "Total number of credits processed by the usage limiter.", - ConstLabels: constLabels, - }, - ) ) diff --git a/go/pkg/zen/auth_test.go b/go/pkg/zen/auth_test.go index c1488d1c69..d082443bd0 100644 --- a/go/pkg/zen/auth_test.go +++ b/go/pkg/zen/auth_test.go @@ -103,7 +103,7 @@ func TestBearer_Integration(t *testing.T) { w := httptest.NewRecorder() sess := &Session{} - err := sess.init(w, req) + err := sess.init(w, req, 0) require.NoError(t, err) token, err := Bearer(sess) diff --git a/go/pkg/zen/middleware_errors.go b/go/pkg/zen/middleware_errors.go index 7dd1ee1261..eda2b40c1c 100644 --- a/go/pkg/zen/middleware_errors.go +++ b/go/pkg/zen/middleware_errors.go @@ -74,6 +74,21 @@ func WithErrorHandling(logger logging.Logger) Middleware { }, }) + // Request Entity Too Large errors + case codes.UserErrorsBadRequestRequestBodyTooLarge: + return s.JSON(http.StatusRequestEntityTooLarge, openapi.BadRequestErrorResponse{ + Meta: openapi.Meta{ + RequestId: s.RequestID(), + }, + Error: openapi.BadRequestErrorDetails{ + Title: "Request Entity Too Large", + Type: code.DocsURL(), + Detail: fault.UserFacingMessage(err), + Status: http.StatusRequestEntityTooLarge, + Errors: []openapi.ValidationError{}, + }, + }) + // Unauthorized errors case codes.UnkeyAuthErrorsAuthenticationKeyNotFound: diff --git a/go/pkg/zen/middleware_metrics.go b/go/pkg/zen/middleware_metrics.go index 66e4d7d1d0..ae7a5f5c8d 100644 --- a/go/pkg/zen/middleware_metrics.go +++ b/go/pkg/zen/middleware_metrics.go @@ -49,20 +49,6 @@ func WithMetrics(eventBuffer EventBuffer) Middleware { nextErr := next(ctx, s) serviceLatency := time.Since(start) - requestHeaders := []string{} - for k, vv := range s.r.Header { - if strings.ToLower(k) == "authorization" { - requestHeaders = append(requestHeaders, fmt.Sprintf("%s: %s", k, "[REDACTED]")) - } else { - requestHeaders = append(requestHeaders, fmt.Sprintf("%s: %s", k, strings.Join(vv, ","))) - } - } - - responseHeaders := []string{} - for k, vv := range s.w.Header() { - responseHeaders = append(responseHeaders, fmt.Sprintf("%s: %s", k, strings.Join(vv, ","))) - } - // "method", "path", "status" labelValues := []string{s.r.Method, s.r.URL.Path, strconv.Itoa(s.responseStatus)} @@ -70,14 +56,29 @@ func WithMetrics(eventBuffer EventBuffer) Middleware { metrics.HTTPRequestTotal.WithLabelValues(labelValues...).Inc() metrics.HTTPRequestLatency.WithLabelValues(labelValues...).Observe(serviceLatency.Seconds()) - // https://docs.aws.amazon.com/elasticloadbalancing/latest/classic/x-forwarded-headers.html#x-forwarded-for - ips := strings.Split(s.r.Header.Get("X-Forwarded-For"), ",") - ipAddress := "" - if len(ips) > 0 { - ipAddress = ips[0] - } + // Only log if we should log request to ClickHouse + if s.ShouldLogRequestToClickHouse() { + requestHeaders := []string{} + for k, vv := range s.r.Header { + if strings.ToLower(k) == "authorization" { + requestHeaders = append(requestHeaders, fmt.Sprintf("%s: %s", k, "[REDACTED]")) + } else { + requestHeaders = append(requestHeaders, fmt.Sprintf("%s: %s", k, strings.Join(vv, ","))) + } + } + + responseHeaders := []string{} + for k, vv := range s.w.Header() { + responseHeaders = append(responseHeaders, fmt.Sprintf("%s: %s", k, strings.Join(vv, ","))) + } + + // https://docs.aws.amazon.com/elasticloadbalancing/latest/classic/x-forwarded-headers.html#x-forwarded-for + ips := strings.Split(s.r.Header.Get("X-Forwarded-For"), ",") + ipAddress := "" + if len(ips) > 0 { + ipAddress = ips[0] + } - if s.r.Header.Get("X-Unkey-Metrics") != "disabled" { eventBuffer.BufferApiRequest(schema.ApiRequestV1{ WorkspaceID: s.WorkspaceID, RequestID: s.RequestID(), diff --git a/go/pkg/zen/request_util.go b/go/pkg/zen/request_util.go index d7b2766846..4a48045024 100644 --- a/go/pkg/zen/request_util.go +++ b/go/pkg/zen/request_util.go @@ -14,7 +14,8 @@ func BindBody[T any](s *Session) (T, error) { if err != nil { return req, fault.Wrap(err, fault.Code(codes.App.Validation.InvalidInput.URN()), - fault.Internal("invalid request body"), fault.Public("The request body is invalid."), + fault.Internal("invalid request body"), + fault.Public("The request body is invalid."), ) } diff --git a/go/pkg/zen/server.go b/go/pkg/zen/server.go index 9e47d1b8da..9f850e4605 100644 --- a/go/pkg/zen/server.go +++ b/go/pkg/zen/server.go @@ -28,7 +28,7 @@ type Server struct { mux *http.ServeMux srv *http.Server flags Flags - tlsConfig *tls.Config + config Config sessions sync.Pool } @@ -49,6 +49,11 @@ type Config struct { TLS *tls.Config Flags *Flags + + // MaxRequestBodySize sets the maximum allowed request body size in bytes. + // If 0 or negative, no limit is enforced. Default is 0 (no limit). + // This helps prevent DoS attacks from excessively large request bodies. + MaxRequestBodySize int64 } // New creates a new server with the provided configuration. @@ -102,7 +107,7 @@ func New(config Config) (*Server, error) { mux: mux, srv: srv, flags: flags, - tlsConfig: config.TLS, + config: config, sessions: sync.Pool{ New: func() any { return &Session{ @@ -196,10 +201,10 @@ func (s *Server) Serve(ctx context.Context, ln net.Listener) error { var err error // Check if TLS should be used - if s.tlsConfig != nil { + if s.config.TLS != nil { s.logger.Info("listening", "srv", "https", "addr", ln.Addr().String()) - s.srv.TLSConfig = s.tlsConfig + s.srv.TLSConfig = s.config.TLS // ListenAndServeTLS with empty strings will use the certificates from TLSConfig err = s.srv.ServeTLS(ln, "", "") @@ -241,14 +246,24 @@ func (s *Server) RegisterRoute(middlewares []Middleware, route Route) { if !ok { panic("Unable to cast session") } + defer func() { sess.reset() s.returnSession(sess) }() - err := sess.init(w, r) + err := sess.init(w, r, s.config.MaxRequestBodySize) if err != nil { - s.logger.Error("failed to init session") + s.logger.Error("failed to init session", "error", err) + + // Apply error handling middleware for session initialization errors + errorHandler := WithErrorHandling(s.logger) + handleFn := func(ctx context.Context, session *Session) error { + return err // Return the session init error + } + wrappedHandler := errorHandler(handleFn) + _ = wrappedHandler(r.Context(), sess) + return } diff --git a/go/pkg/zen/session.go b/go/pkg/zen/session.go index 457bfa1ed2..adc9562780 100644 --- a/go/pkg/zen/session.go +++ b/go/pkg/zen/session.go @@ -1,7 +1,9 @@ package zen import ( + "bytes" "encoding/json" + "errors" "fmt" "io" "net" @@ -10,6 +12,7 @@ import ( "strconv" "strings" + "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/fault" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -39,13 +42,57 @@ type Session struct { requestBody []byte responseStatus int responseBody []byte + + // ClickHouse request logging control - defaults to true (log by default) + logRequestToClickHouse bool } -func (s *Session) init(w http.ResponseWriter, r *http.Request) error { +func (s *Session) init(w http.ResponseWriter, r *http.Request, maxBodySize int64) error { s.requestID = uid.New(uid.RequestPrefix) s.w = w s.r = r + s.logRequestToClickHouse = true // Default to logging requests to ClickHouse + + // Apply body size limit if configured + if maxBodySize > 0 { + s.r.Body = http.MaxBytesReader(s.w, s.r.Body, maxBodySize) + } + + // Read and cache the request body so metrics middleware can access it even on early errors. + // We need to replace r.Body with a fresh reader afterwards so other middleware + // can still read the body if necessary. + var err error + s.requestBody, err = io.ReadAll(s.r.Body) + closeErr := s.r.Body.Close() + + // Handle read errors (including MaxBytesError) + if err != nil { + // Check if this is a MaxBytesError from http.MaxBytesReader + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + return fault.Wrap(err, + fault.Code(codes.User.BadRequest.RequestBodyTooLarge.URN()), + fault.Internal(fmt.Sprintf("request body exceeds size limit of %d bytes", maxBytesErr.Limit)), + fault.Public(fmt.Sprintf("The request body exceeds the maximum allowed size of %d bytes.", maxBytesErr.Limit)), + ) + } + return fault.Wrap(err, + fault.Internal("unable to read request body"), + fault.Public("The request body could not be read."), + ) + } + + // Handle close error (incase that ever happens) + if closeErr != nil { + return fault.Wrap(closeErr, + fault.Internal("failed to close request body"), + fault.Public("An error occurred processing the request."), + ) + } + + // Replace body with a fresh reader for subsequent middleware + s.r.Body = io.NopCloser(bytes.NewReader(s.requestBody)) s.WorkspaceID = "" return nil } @@ -58,6 +105,21 @@ func (s *Session) AuthorizedWorkspaceID() string { return s.WorkspaceID } +// DisableClickHouseLogging prevents this request from being logged to ClickHouse. +// By default, all requests are logged to ClickHouse unless explicitly disabled. +// +// This is useful for internal endpoints like health checks, OpenAPI specs, +// or requests that should not appear in analytics. +func (s *Session) DisableClickHouseLogging() { + s.logRequestToClickHouse = false +} + +// ShouldLogRequestToClickHouse returns whether this request should be logged to ClickHouse. +// Returns true by default, false only if explicitly disabled. +func (s *Session) ShouldLogRequestToClickHouse() bool { + return s.logRequestToClickHouse +} + func (s *Session) UserAgent() string { return s.r.UserAgent() } @@ -114,19 +176,14 @@ func (s *Session) ResponseWriter() http.ResponseWriter { // } // // Use the parsed user data func (s *Session) BindBody(dst any) error { - var err error - s.requestBody, err = io.ReadAll(s.r.Body) - if err != nil { - return fault.Wrap(err, fault.Internal("unable to read request body"), fault.Public("The request body is malformed.")) - } - defer s.r.Body.Close() - - err = json.Unmarshal(s.requestBody, dst) + err := json.Unmarshal(s.requestBody, dst) if err != nil { return fault.Wrap(err, - fault.Internal("failed to unmarshal request body"), fault.Public("The request body was not valid json."), + fault.Internal("failed to unmarshal request body"), + fault.Public("The request body was not valid JSON."), ) } + return nil } @@ -362,4 +419,5 @@ func (s *Session) reset() { s.requestBody = nil s.responseStatus = 0 s.responseBody = nil + s.logRequestToClickHouse = true // Reset ClickHouse logging control to default (enabled) } diff --git a/go/pkg/zen/session_bind_body_test.go b/go/pkg/zen/session_bind_body_test.go index a23251ae47..c5bbcce23e 100644 --- a/go/pkg/zen/session_bind_body_test.go +++ b/go/pkg/zen/session_bind_body_test.go @@ -86,12 +86,12 @@ func TestSession_BindBody(t *testing.T) { req.Header.Set("Content-Type", "application/json") // Create a session - sess := &Session{ - r: req, - } + sess := &Session{} + err := sess.init(httptest.NewRecorder(), req, 0) + require.NoError(t, err) // Call BindBody - err := sess.BindBody(tt.target) + err = sess.BindBody(tt.target) // Check error conditions if tt.wantErr { @@ -117,16 +117,11 @@ func TestSession_BindBody_ReadError(t *testing.T) { errReader := &errorReader{err: io.ErrUnexpectedEOF} req := httptest.NewRequest(http.MethodPost, "/", errReader) - // Create a session - sess := &Session{ - r: req, - } - - // Call BindBody - var target map[string]interface{} - err := sess.BindBody(&target) + // Create a session and try to init it (this should fail) + sess := &Session{} + err := sess.init(httptest.NewRecorder(), req, 0) - // Verify error + // Verify the init error require.Error(t, err) assert.Contains(t, err.Error(), "unable to read request body") } @@ -168,9 +163,9 @@ func TestSession_BindBody_LargeBody(t *testing.T) { req.Header.Set("Content-Type", "application/json") // Create a session - sess := &Session{ - r: req, - } + sess := &Session{} + err = sess.init(httptest.NewRecorder(), req, 0) + require.NoError(t, err) type LargeStruct struct { Items []Item `json:"items"` @@ -193,7 +188,7 @@ func TestSession_BindBody_Integration(t *testing.T) { w := httptest.NewRecorder() sess := &Session{} - err := sess.init(w, req) + err := sess.init(w, req, 0) require.NoError(t, err) type TestData struct { diff --git a/go/pkg/zen/session_bind_query_test.go b/go/pkg/zen/session_bind_query_test.go index 15b70e2923..8c2b654839 100644 --- a/go/pkg/zen/session_bind_query_test.go +++ b/go/pkg/zen/session_bind_query_test.go @@ -255,7 +255,7 @@ func TestSession_BindQuery_Init(t *testing.T) { // Create and initialize a session sess := &Session{} - err := sess.init(w, req) + err := sess.init(w, req, 0) require.NoError(t, err) // Bind query params diff --git a/go/pkg/zen/session_body_limit_test.go b/go/pkg/zen/session_body_limit_test.go new file mode 100644 index 0000000000..8eb7d190dc --- /dev/null +++ b/go/pkg/zen/session_body_limit_test.go @@ -0,0 +1,242 @@ +package zen + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/pkg/fault" + "github.com/unkeyed/unkey/go/pkg/otel/logging" +) + +func TestSession_BodySizeLimit(t *testing.T) { + tests := []struct { + name string + bodyContent string + maxBodySize int64 + wantErr bool + errSubstr string + }{ + { + name: "body within limit", + bodyContent: `{"name":"test"}`, + maxBodySize: 100, + wantErr: false, + }, + { + name: "body exceeds limit", + bodyContent: strings.Repeat("x", 200), + maxBodySize: 100, + wantErr: true, + errSubstr: "request body exceeds size limit of 100 bytes", + }, + { + name: "no limit enforced when maxBodySize is 0", + bodyContent: strings.Repeat("x", 1000), + maxBodySize: 0, + wantErr: false, + }, + { + name: "no limit enforced when maxBodySize is negative", + bodyContent: strings.Repeat("x", 1000), + maxBodySize: -1, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/", strings.NewReader(tt.bodyContent)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + sess := &Session{} + err := sess.init(w, req, tt.maxBodySize) + + if tt.wantErr { + require.Error(t, err) + if tt.errSubstr != "" { + require.Contains(t, err.Error(), tt.errSubstr) + } + return + } + + require.NoError(t, err) + require.Equal(t, []byte(tt.bodyContent), sess.requestBody) + }) + } +} + +func TestSession_BodySizeLimitWithBindBody(t *testing.T) { + // Test that BindBody still works correctly with body size limits + bodyContent := `{"name":"test","value":42}` + + req := httptest.NewRequest("POST", "/", strings.NewReader(bodyContent)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + sess := &Session{} + err := sess.init(w, req, 1024) // 1KB limit + require.NoError(t, err) + + type TestData struct { + Name string `json:"name"` + Value int `json:"value"` + } + + var data TestData + err = sess.BindBody(&data) + require.NoError(t, err) + require.Equal(t, "test", data.Name) + require.Equal(t, 42, data.Value) +} + +func TestSession_MaxBytesErrorMessage(t *testing.T) { + // Test that different size limits produce correct error messages + tests := []struct { + name string + bodySize int + maxBodySize int64 + wantErrMsg string + }{ + { + name: "512 byte limit", + bodySize: 1024, + maxBodySize: 512, + wantErrMsg: "request body exceeds size limit of 512 bytes", + }, + { + name: "1KB limit", + bodySize: 2048, + maxBodySize: 1024, + wantErrMsg: "request body exceeds size limit of 1024 bytes", + }, + { + name: "10KB limit", + bodySize: 20000, + maxBodySize: 10240, + wantErrMsg: "request body exceeds size limit of 10240 bytes", + }, + { + name: "1MB limit", + bodySize: 2097152, // 2MB + maxBodySize: 1048576, // 1MB + wantErrMsg: "request body exceeds size limit of 1048576 bytes", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a body larger than the limit + bodyContent := strings.Repeat("x", tt.bodySize) + req := httptest.NewRequest("POST", "/", strings.NewReader(bodyContent)) + w := httptest.NewRecorder() + + sess := &Session{} + err := sess.init(w, req, tt.maxBodySize) + + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErrMsg) + + // Also verify the user-facing message includes the limit + userMsg := fault.UserFacingMessage(err) + expectedUserMsg := fmt.Sprintf("The request body exceeds the maximum allowed size of %d bytes.", tt.maxBodySize) + require.Equal(t, expectedUserMsg, userMsg) + }) + } +} + +func TestSession_BodySizeLimitHTTPStatus(t *testing.T) { + // Test that oversized request bodies return 413 status through zen server + logger := logging.NewNoop() + + // Create server with small body size limit + srv, err := New(Config{ + Logger: logger, + MaxRequestBodySize: 100, // 100 byte limit + }) + require.NoError(t, err) + + // Flag to track if handler was invoked (should remain false) + handlerInvoked := false + + // Register a simple route that would process the body + testRoute := NewRoute("POST", "/test", func(ctx context.Context, s *Session) error { + // This should never be reached due to the body size limit + handlerInvoked = true + return s.JSON(http.StatusOK, map[string]string{"status": "ok"}) + }) + + srv.RegisterRoute( + []Middleware{ + WithErrorHandling(logger), + }, + testRoute, + ) + + // Create request with body larger than limit (200 bytes vs 100 byte limit) + bodyContent := strings.Repeat("x", 200) + req := httptest.NewRequest("POST", "/test", strings.NewReader(bodyContent)) + w := httptest.NewRecorder() + + // Call through the zen server + srv.Mux().ServeHTTP(w, req) + + // Check that the response is 413 Request Entity Too Large + require.Equal(t, http.StatusRequestEntityTooLarge, w.Code, "Should return 413 Request Entity Too Large status") + + // Parse and validate JSON response structure + require.Contains(t, w.Header().Get("Content-Type"), "application/json") + + var response map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err, "Response should be valid JSON") + + // Validate that response contains error field + require.Contains(t, response, "error", "Response should contain 'error' field") + + errorObj, ok := response["error"].(map[string]interface{}) + require.True(t, ok, "Error field should be an object") + + // Validate required JSON fields in error object + require.Contains(t, errorObj, "title", "Error should contain 'title' field") + require.Contains(t, errorObj, "detail", "Error should contain 'detail' field") + require.Contains(t, errorObj, "status", "Error should contain 'status' field") + + // Validate field values + require.Equal(t, "Request Entity Too Large", errorObj["title"]) + require.Equal(t, float64(413), errorObj["status"]) // JSON unmarshals numbers as float64 + require.Contains(t, errorObj["detail"], "request body exceeds") + + // Ensure the handler was never invoked + require.False(t, handlerInvoked, "Handler should not have been invoked due to body size limit") +} + +func TestSession_ClickHouseLoggingControl(t *testing.T) { + // Test that the new ClickHouse logging control methods work correctly + req := httptest.NewRequest("POST", "/", strings.NewReader("test")) + w := httptest.NewRecorder() + + sess := &Session{} + err := sess.init(w, req, 0) + require.NoError(t, err) + + // Should default to true (logging enabled) + require.True(t, sess.ShouldLogRequestToClickHouse(), "Should default to logging enabled") + + // Disable ClickHouse logging + sess.DisableClickHouseLogging() + require.False(t, sess.ShouldLogRequestToClickHouse(), "Should be disabled after calling DisableClickHouseLogging") + + // Reset should re-enable logging + sess.reset() + require.True(t, sess.ShouldLogRequestToClickHouse(), "Should be enabled again after reset") + require.Empty(t, sess.requestBody) + require.Equal(t, 0, sess.responseStatus) + require.Empty(t, sess.responseBody) +}