diff --git a/README.md b/README.md index df3668b..c74d06e 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ Interceptor configuration functions \(AddUnaryServerInterceptor, SetFilterFunc, - [func AddStreamServerInterceptor\(ctx context.Context, i ...grpc.StreamServerInterceptor\)](<#AddStreamServerInterceptor>) - [func AddUnaryClientInterceptor\(ctx context.Context, i ...grpc.UnaryClientInterceptor\)](<#AddUnaryClientInterceptor>) - [func AddUnaryServerInterceptor\(ctx context.Context, i ...grpc.UnaryServerInterceptor\)](<#AddUnaryServerInterceptor>) +- [func DebugLogInterceptor\(\) grpc.UnaryServerInterceptor](<#DebugLogInterceptor>) - [func DebugLoggingInterceptor\(\) grpc.UnaryServerInterceptor](<#DebugLoggingInterceptor>) - [func DefaultClientInterceptor\(defaultOpts ...any\) grpc.UnaryClientInterceptor](<#DefaultClientInterceptor>) - [func DefaultClientInterceptors\(defaultOpts ...any\) \[\]grpc.UnaryClientInterceptor](<#DefaultClientInterceptors>) @@ -36,6 +37,7 @@ Interceptor configuration functions \(AddUnaryServerInterceptor, SetFilterFunc, - [func DoHTTPtoGRPC\(ctx context.Context, svr any, handler func\(ctx context.Context, req any\) \(any, error\), in any\) \(any, error\)](<#DoHTTPtoGRPC>) - [func FilterMethodsFunc\(ctx context.Context, fullMethodName string\) bool](<#FilterMethodsFunc>) - [func GRPCClientInterceptor\(\_ ...any\) grpc.UnaryClientInterceptor](<#GRPCClientInterceptor>) +- [func GetDebugLogHeaderName\(\) string](<#GetDebugLogHeaderName>) - [func HystrixClientInterceptor\(defaultOpts ...grpc.CallOption\) grpc.UnaryClientInterceptor](<#HystrixClientInterceptor>) - [func NRHttpTracer\(pattern string, h http.HandlerFunc\) \(string, http.HandlerFunc\)](<#NRHttpTracer>) - [func NewRelicClientInterceptor\(\) grpc.UnaryClientInterceptor](<#NewRelicClientInterceptor>) @@ -49,11 +51,16 @@ Interceptor configuration functions \(AddUnaryServerInterceptor, SetFilterFunc, - [func ServerErrorInterceptor\(\) grpc.UnaryServerInterceptor](<#ServerErrorInterceptor>) - [func ServerErrorStreamInterceptor\(\) grpc.StreamServerInterceptor](<#ServerErrorStreamInterceptor>) - [func SetClientMetricsOptions\(opts ...grpcprom.ClientMetricsOption\)](<#SetClientMetricsOptions>) +- [func SetDebugLogHeaderName\(name string\)](<#SetDebugLogHeaderName>) +- [func SetDefaultRateLimit\(rps float64, burst int\)](<#SetDefaultRateLimit>) - [func SetDefaultTimeout\(d time.Duration\)](<#SetDefaultTimeout>) +- [func SetDisableDebugLogInterceptor\(disable bool\)](<#SetDisableDebugLogInterceptor>) - [func SetDisableProtoValidate\(disable bool\)](<#SetDisableProtoValidate>) +- [func SetDisableRateLimit\(disable bool\)](<#SetDisableRateLimit>) - [func SetFilterFunc\(ctx context.Context, ff FilterFunc\)](<#SetFilterFunc>) - [func SetFilterMethods\(ctx context.Context, methods \[\]string\)](<#SetFilterMethods>) - [func SetProtoValidateOptions\(opts ...protovalidate.ValidatorOption\)](<#SetProtoValidateOptions>) +- [func SetRateLimiter\(limiter ratelimit\_middleware.Limiter\)](<#SetRateLimiter>) - [func SetResponseTimeLogErrorOnly\(errorOnly bool\)](<#SetResponseTimeLogErrorOnly>) - [func SetResponseTimeLogLevel\(ctx context.Context, level loggers.Level\)](<#SetResponseTimeLogLevel>) - [func SetServerMetricsOptions\(opts ...grpcprom.ServerMetricsOption\)](<#SetServerMetricsOptions>) @@ -86,7 +93,7 @@ var ( ``` -## func [AddStreamClientInterceptor]() +## func [AddStreamClientInterceptor]() ```go func AddStreamClientInterceptor(ctx context.Context, i ...grpc.StreamClientInterceptor) @@ -95,7 +102,7 @@ func AddStreamClientInterceptor(ctx context.Context, i ...grpc.StreamClientInter AddStreamClientInterceptor adds a client stream interceptor to default client stream interceptors. Must be called during initialization, before any RPCs are made. Not safe for concurrent use. -## func [AddStreamServerInterceptor]() +## func [AddStreamServerInterceptor]() ```go func AddStreamServerInterceptor(ctx context.Context, i ...grpc.StreamServerInterceptor) @@ -104,7 +111,7 @@ func AddStreamServerInterceptor(ctx context.Context, i ...grpc.StreamServerInter AddStreamServerInterceptor adds a server interceptor to default server interceptors. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [AddUnaryClientInterceptor]() +## func [AddUnaryClientInterceptor]() ```go func AddUnaryClientInterceptor(ctx context.Context, i ...grpc.UnaryClientInterceptor) @@ -113,7 +120,7 @@ func AddUnaryClientInterceptor(ctx context.Context, i ...grpc.UnaryClientInterce AddUnaryClientInterceptor adds a client interceptor to default client interceptors. Must be called during initialization, before any RPCs are made. Not safe for concurrent use. -## func [AddUnaryServerInterceptor]() +## func [AddUnaryServerInterceptor]() ```go func AddUnaryServerInterceptor(ctx context.Context, i ...grpc.UnaryServerInterceptor) @@ -121,8 +128,22 @@ func AddUnaryServerInterceptor(ctx context.Context, i ...grpc.UnaryServerInterce AddUnaryServerInterceptor adds a server interceptor to default server interceptors. Must be called during initialization, before the server starts. Not safe for concurrent use. + +## func [DebugLogInterceptor]() + +```go +func DebugLogInterceptor() grpc.UnaryServerInterceptor +``` + +DebugLogInterceptor enables per\-request log level override based on a proto field or gRPC metadata header. It checks \(in order\): + +1. Proto field: GetDebug\(\) bool or GetEnableDebug\(\) bool — always sets DebugLevel +2. Metadata header: configurable via SetDebugLogHeaderName \(default "x\-debug\-log\-level"\) — the header value is parsed as a log level, allowing any valid level \(debug, info, warn, error\) + +Combined with ColdBrew's trace ID propagation, this allows enabling debug logging for a single request and following it across services via trace ID. + -## func [DebugLoggingInterceptor]() +## func [DebugLoggingInterceptor]() ```go func DebugLoggingInterceptor() grpc.UnaryServerInterceptor @@ -131,7 +152,7 @@ func DebugLoggingInterceptor() grpc.UnaryServerInterceptor DebugLoggingInterceptor is the interceptor that logs all request/response from a handler -## func [DefaultClientInterceptor]() +## func [DefaultClientInterceptor]() ```go func DefaultClientInterceptor(defaultOpts ...any) grpc.UnaryClientInterceptor @@ -140,7 +161,7 @@ func DefaultClientInterceptor(defaultOpts ...any) grpc.UnaryClientInterceptor DefaultClientInterceptor are the set of default interceptors that should be applied to all client calls -## func [DefaultClientInterceptors]() +## func [DefaultClientInterceptors]() ```go func DefaultClientInterceptors(defaultOpts ...any) []grpc.UnaryClientInterceptor @@ -149,7 +170,7 @@ func DefaultClientInterceptors(defaultOpts ...any) []grpc.UnaryClientInterceptor DefaultClientInterceptors are the set of default interceptors that should be applied to all client calls -## func [DefaultClientStreamInterceptor]() +## func [DefaultClientStreamInterceptor]() ```go func DefaultClientStreamInterceptor(defaultOpts ...any) grpc.StreamClientInterceptor @@ -158,7 +179,7 @@ func DefaultClientStreamInterceptor(defaultOpts ...any) grpc.StreamClientInterce DefaultClientStreamInterceptor are the set of default interceptors that should be applied to all stream client calls -## func [DefaultClientStreamInterceptors]() +## func [DefaultClientStreamInterceptors]() ```go func DefaultClientStreamInterceptors(defaultOpts ...any) []grpc.StreamClientInterceptor @@ -167,7 +188,7 @@ func DefaultClientStreamInterceptors(defaultOpts ...any) []grpc.StreamClientInte DefaultClientStreamInterceptors are the set of default interceptors that should be applied to all stream client calls -## func [DefaultInterceptors]() +## func [DefaultInterceptors]() ```go func DefaultInterceptors() []grpc.UnaryServerInterceptor @@ -176,7 +197,7 @@ func DefaultInterceptors() []grpc.UnaryServerInterceptor DefaultInterceptors are the set of default interceptors that are applied to all coldbrew methods -## func [DefaultStreamInterceptors]() +## func [DefaultStreamInterceptors]() ```go func DefaultStreamInterceptors() []grpc.StreamServerInterceptor @@ -185,7 +206,7 @@ func DefaultStreamInterceptors() []grpc.StreamServerInterceptor DefaultStreamInterceptors are the set of default interceptors that should be applied to all coldbrew streams -## func [DefaultTimeoutInterceptor]() +## func [DefaultTimeoutInterceptor]() ```go func DefaultTimeoutInterceptor() grpc.UnaryServerInterceptor @@ -194,7 +215,7 @@ func DefaultTimeoutInterceptor() grpc.UnaryServerInterceptor DefaultTimeoutInterceptor returns a unary server interceptor that applies a default deadline to incoming requests that have no deadline set. If the incoming context already has a deadline \(regardless of duration\), it is left unchanged. When defaultTimeout is \<= 0, the interceptor is a no\-op pass\-through. -## func [DoHTTPtoGRPC]() +## func [DoHTTPtoGRPC]() ```go func DoHTTPtoGRPC(ctx context.Context, svr any, handler func(ctx context.Context, req any) (any, error), in any) (any, error) @@ -220,7 +241,7 @@ func (s *svc) echo(ctx context.Context, req *proto.EchoRequest) (*proto.EchoResp ``` -## func [FilterMethodsFunc]() +## func [FilterMethodsFunc]() ```go func FilterMethodsFunc(ctx context.Context, fullMethodName string) bool @@ -229,7 +250,7 @@ func FilterMethodsFunc(ctx context.Context, fullMethodName string) bool FilterMethodsFunc is the default implementation of Filter function -## func [GRPCClientInterceptor]() +## func [GRPCClientInterceptor]() ```go func GRPCClientInterceptor(_ ...any) grpc.UnaryClientInterceptor @@ -237,8 +258,17 @@ func GRPCClientInterceptor(_ ...any) grpc.UnaryClientInterceptor Deprecated: GRPCClientInterceptor is no longer needed. gRPC tracing is now handled by google.golang.org/grpc/stats/opentelemetry, configured via opentelemetry.DialOption\(\) at the client level. This function is retained for backwards compatibility but returns a no\-op interceptor. + +## func [GetDebugLogHeaderName]() + +```go +func GetDebugLogHeaderName() string +``` + +GetDebugLogHeaderName returns the current debug log header name. + -## func [HystrixClientInterceptor]() +## func [HystrixClientInterceptor]() ```go func HystrixClientInterceptor(defaultOpts ...grpc.CallOption) grpc.UnaryClientInterceptor @@ -251,7 +281,7 @@ Note: This interceptor wraps github.com/afex/hystrix\-go which has been unmainta The interceptor applies provided default and per\-call client options to configure Hystrix behavior \(for example the command name, disabled flag, excluded errors, and excluded gRPC status codes\). If Hystrix is disabled via options, the RPC is invoked directly. If the underlying RPC returns an error that matches any configured excluded error or whose gRPC status code matches any configured excluded code, Hystrix fallback is skipped and the RPC error is returned. Panics raised during the RPC invocation are captured and reported to the notifier before being converted into an error. If the RPC itself returns an error, that error is returned; otherwise any error produced by Hystrix is returned. -## func [NRHttpTracer]() +## func [NRHttpTracer]() ```go func NRHttpTracer(pattern string, h http.HandlerFunc) (string, http.HandlerFunc) @@ -260,7 +290,7 @@ func NRHttpTracer(pattern string, h http.HandlerFunc) (string, http.HandlerFunc) NRHttpTracer adds newrelic tracing to this http function -## func [NewRelicClientInterceptor]() +## func [NewRelicClientInterceptor]() ```go func NewRelicClientInterceptor() grpc.UnaryClientInterceptor @@ -269,7 +299,7 @@ func NewRelicClientInterceptor() grpc.UnaryClientInterceptor NewRelicClientInterceptor intercepts all client actions and reports them to newrelic. When NewRelic app is nil \(no license key configured\), returns a pass\-through interceptor to avoid overhead. -## func [NewRelicInterceptor]() +## func [NewRelicInterceptor]() ```go func NewRelicInterceptor() grpc.UnaryServerInterceptor @@ -278,7 +308,7 @@ func NewRelicInterceptor() grpc.UnaryServerInterceptor NewRelicInterceptor intercepts all server actions and reports them to newrelic. When NewRelic app is nil \(no license key configured\), returns a pass\-through interceptor to avoid overhead. -## func [OptionsInterceptor]() +## func [OptionsInterceptor]() ```go func OptionsInterceptor() grpc.UnaryServerInterceptor @@ -287,7 +317,7 @@ func OptionsInterceptor() grpc.UnaryServerInterceptor -## func [PanicRecoveryInterceptor]() +## func [PanicRecoveryInterceptor]() ```go func PanicRecoveryInterceptor() grpc.UnaryServerInterceptor @@ -296,7 +326,7 @@ func PanicRecoveryInterceptor() grpc.UnaryServerInterceptor -## func [ProtoValidateInterceptor]() +## func [ProtoValidateInterceptor]() ```go func ProtoValidateInterceptor() grpc.UnaryServerInterceptor @@ -305,7 +335,7 @@ func ProtoValidateInterceptor() grpc.UnaryServerInterceptor ProtoValidateInterceptor returns a unary server interceptor that validates incoming messages using protovalidate annotations. Returns InvalidArgument on validation failure. Uses GlobalValidator by default; if custom options are set via SetProtoValidateOptions, creates a new validator with those options. -## func [ProtoValidateStreamInterceptor]() +## func [ProtoValidateStreamInterceptor]() ```go func ProtoValidateStreamInterceptor() grpc.StreamServerInterceptor @@ -314,7 +344,7 @@ func ProtoValidateStreamInterceptor() grpc.StreamServerInterceptor ProtoValidateStreamInterceptor returns a stream server interceptor that validates incoming messages using protovalidate annotations. -## func [ResponseTimeLoggingInterceptor]() +## func [ResponseTimeLoggingInterceptor]() ```go func ResponseTimeLoggingInterceptor(ff FilterFunc) grpc.UnaryServerInterceptor @@ -323,7 +353,7 @@ func ResponseTimeLoggingInterceptor(ff FilterFunc) grpc.UnaryServerInterceptor ResponseTimeLoggingInterceptor logs response time for each request on server -## func [ResponseTimeLoggingStreamInterceptor]() +## func [ResponseTimeLoggingStreamInterceptor]() ```go func ResponseTimeLoggingStreamInterceptor() grpc.StreamServerInterceptor @@ -332,7 +362,7 @@ func ResponseTimeLoggingStreamInterceptor() grpc.StreamServerInterceptor ResponseTimeLoggingStreamInterceptor logs response time for stream RPCs. -## func [ServerErrorInterceptor]() +## func [ServerErrorInterceptor]() ```go func ServerErrorInterceptor() grpc.UnaryServerInterceptor @@ -341,7 +371,7 @@ func ServerErrorInterceptor() grpc.UnaryServerInterceptor ServerErrorInterceptor intercepts all server actions and reports them to error notifier -## func [ServerErrorStreamInterceptor]() +## func [ServerErrorStreamInterceptor]() ```go func ServerErrorStreamInterceptor() grpc.StreamServerInterceptor @@ -350,7 +380,7 @@ func ServerErrorStreamInterceptor() grpc.StreamServerInterceptor ServerErrorStreamInterceptor intercepts server errors for stream RPCs and reports them to the error notifier. -## func [SetClientMetricsOptions]() +## func [SetClientMetricsOptions]() ```go func SetClientMetricsOptions(opts ...grpcprom.ClientMetricsOption) @@ -358,8 +388,26 @@ func SetClientMetricsOptions(opts ...grpcprom.ClientMetricsOption) SetClientMetricsOptions appends gRPC client metrics options. Must be called during initialization, before the server starts. Not safe for concurrent use. + +## func [SetDebugLogHeaderName]() + +```go +func SetDebugLogHeaderName(name string) +``` + +SetDebugLogHeaderName sets the gRPC metadata header name that triggers per\-request log level override. Default is "x\-debug\-log\-level". The header value should be a valid log level \(e.g., "debug"\). Empty names are ignored. Must be called during initialization. + + +## func [SetDefaultRateLimit]() + +```go +func SetDefaultRateLimit(rps float64, burst int) +``` + +SetDefaultRateLimit configures the built\-in token bucket rate limiter. rps is requests per second, burst is the maximum burst size. This is a per\-pod in\-memory limit — with N pods, the effective cluster\-wide limit is N × rps. For distributed rate limiting, use SetRateLimiter\(\) with a custom implementation \(e.g., Redis\-backed\). Must be called during initialization. + -## func [SetDefaultTimeout]() +## func [SetDefaultTimeout]() ```go func SetDefaultTimeout(d time.Duration) @@ -367,8 +415,17 @@ func SetDefaultTimeout(d time.Duration) SetDefaultTimeout sets the default timeout applied to incoming unary RPCs that arrive without a deadline. When set to \<= 0, the timeout interceptor is disabled \(pass\-through\). Default is 60s. Must be called during initialization, before the server starts. Not safe for concurrent use. + +## func [SetDisableDebugLogInterceptor]() + +```go +func SetDisableDebugLogInterceptor(disable bool) +``` + +SetDisableDebugLogInterceptor disables the DebugLogInterceptor in the default interceptor chain. Must be called during initialization, before the server starts. + -## func [SetDisableProtoValidate]() +## func [SetDisableProtoValidate]() ```go func SetDisableProtoValidate(disable bool) @@ -376,8 +433,17 @@ func SetDisableProtoValidate(disable bool) SetDisableProtoValidate disables the protovalidate interceptor in the default chain. Must be called during init\(\) — not safe for concurrent use. + +## func [SetDisableRateLimit]() + +```go +func SetDisableRateLimit(disable bool) +``` + +SetDisableRateLimit disables the rate limiting interceptor in the default interceptor chain. Must be called during initialization. + -## func [SetFilterFunc]() +## func [SetFilterFunc]() ```go func SetFilterFunc(ctx context.Context, ff FilterFunc) @@ -386,7 +452,7 @@ func SetFilterFunc(ctx context.Context, ff FilterFunc) SetFilterFunc sets the default filter function to be used by interceptors. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [SetFilterMethods]() +## func [SetFilterMethods]() ```go func SetFilterMethods(ctx context.Context, methods []string) @@ -395,7 +461,7 @@ func SetFilterMethods(ctx context.Context, methods []string) SetFilterMethods sets the list of method substrings to exclude from tracing/logging. It rebuilds the internal cache. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [SetProtoValidateOptions]() +## func [SetProtoValidateOptions]() ```go func SetProtoValidateOptions(opts ...protovalidate.ValidatorOption) @@ -403,8 +469,17 @@ func SetProtoValidateOptions(opts ...protovalidate.ValidatorOption) SetProtoValidateOptions configures custom protovalidate options \(e.g., custom constraints\). Must be called during init\(\) — not safe for concurrent use. Follows ColdBrew's init\-only config pattern. + +## func [SetRateLimiter]() + +```go +func SetRateLimiter(limiter ratelimit_middleware.Limiter) +``` + +SetRateLimiter sets a custom rate limiter implementation. This overrides the built\-in token bucket limiter. Must be called during initialization. + -## func [SetResponseTimeLogErrorOnly]() +## func [SetResponseTimeLogErrorOnly]() ```go func SetResponseTimeLogErrorOnly(errorOnly bool) @@ -413,7 +488,7 @@ func SetResponseTimeLogErrorOnly(errorOnly bool) SetResponseTimeLogErrorOnly when set to true, only logs response time when the request returns an error. Successful requests are not logged. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [SetResponseTimeLogLevel]() +## func [SetResponseTimeLogLevel]() ```go func SetResponseTimeLogLevel(ctx context.Context, level loggers.Level) @@ -422,7 +497,7 @@ func SetResponseTimeLogLevel(ctx context.Context, level loggers.Level) SetResponseTimeLogLevel sets the log level for response time logging. Default is InfoLevel. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [SetServerMetricsOptions]() +## func [SetServerMetricsOptions]() ```go func SetServerMetricsOptions(opts ...grpcprom.ServerMetricsOption) @@ -431,7 +506,7 @@ func SetServerMetricsOptions(opts ...grpcprom.ServerMetricsOption) SetServerMetricsOptions appends gRPC server metrics options \(histogram, labels, namespace, etc.\). Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [TraceIdInterceptor]() +## func [TraceIdInterceptor]() ```go func TraceIdInterceptor() grpc.UnaryServerInterceptor @@ -440,7 +515,7 @@ func TraceIdInterceptor() grpc.UnaryServerInterceptor TraceIdInterceptor allows injecting trace id from request objects -## func [UseColdBrewClientInterceptors]() +## func [UseColdBrewClientInterceptors]() ```go func UseColdBrewClientInterceptors(ctx context.Context, flag bool) @@ -449,7 +524,7 @@ func UseColdBrewClientInterceptors(ctx context.Context, flag bool) UseColdBrewClientInterceptors allows enabling/disabling coldbrew client interceptors. When set to false, the coldbrew client interceptors will not be used. Must be called during initialization, before any RPCs are made. Not safe for concurrent use. -## func [UseColdBrewServerInterceptors]() +## func [UseColdBrewServerInterceptors]() ```go func UseColdBrewServerInterceptors(ctx context.Context, flag bool) @@ -458,7 +533,7 @@ func UseColdBrewServerInterceptors(ctx context.Context, flag bool) UseColdBrewServerInterceptors allows enabling/disabling coldbrew server interceptors. When set to false, the coldbrew server interceptors will not be used. Must be called during initialization, before the server starts. Not safe for concurrent use. -## type [FilterFunc]() +## type [FilterFunc]() If it returns false, the given request will not be traced. diff --git a/go.mod b/go.mod index 2ba34de..0a7a659 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/newrelic/go-agent/v3/integrations/nrgrpc v1.4.7 github.com/prometheus/client_golang v1.23.2 go.uber.org/goleak v1.3.0 + golang.org/x/time v0.15.0 google.golang.org/grpc v1.79.3 ) diff --git a/go.sum b/go.sum index 8a1ecf3..aefca08 100644 --- a/go.sum +++ b/go.sum @@ -752,6 +752,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200329025819-fd4102a86c65/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= diff --git a/interceptors.go b/interceptors.go index e70609e..50d8c8a 100644 --- a/interceptors.go +++ b/interceptors.go @@ -27,11 +27,13 @@ import ( nrutil "github.com/go-coldbrew/tracing/newrelic" grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate" + ratelimit_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit" grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/newrelic/go-agent/v3/integrations/nrgrpc" newrelic "github.com/newrelic/go-agent/v3/newrelic" "github.com/prometheus/client_golang/prometheus" + "golang.org/x/time/rate" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -51,27 +53,33 @@ var ( // Use SetFilterMethods instead. Only some direct mutations (replacing the slice // or changing the first element) are detected by internal change detection; // other in-place changes may not invalidate caches correctly. - FilterMethods = []string{"healthcheck", "readycheck", "serverreflectioninfo"} - defaultFilterFunc = FilterMethodsFunc - unaryServerInterceptors = []grpc.UnaryServerInterceptor{} - streamServerInterceptors = []grpc.StreamServerInterceptor{} - useCBServerInterceptors = true - unaryClientInterceptors = []grpc.UnaryClientInterceptor{} - streamClientInterceptors = []grpc.StreamClientInterceptor{} - useCBClientInterceptors = true - responseTimeLogLevel loggers.Level = loggers.InfoLevel - responseTimeLogErrorOnly bool - defaultTimeout time.Duration = 60 * time.Second // 0 disables - protoValidateOpts []protovalidate.ValidatorOption - disableProtoValidate bool - srvMetricsOpts []grpcprom.ServerMetricsOption - cltMetricsOpts []grpcprom.ClientMetricsOption - srvMetricsOnce sync.Once - srvMetrics *grpcprom.ServerMetrics - cltMetricsOnce sync.Once - cltMetrics *grpcprom.ClientMetrics + FilterMethods = []string{"healthcheck", "readycheck", "serverreflectioninfo"} + defaultFilterFunc = FilterMethodsFunc + unaryServerInterceptors = []grpc.UnaryServerInterceptor{} + streamServerInterceptors = []grpc.StreamServerInterceptor{} + useCBServerInterceptors = true + unaryClientInterceptors = []grpc.UnaryClientInterceptor{} + streamClientInterceptors = []grpc.StreamClientInterceptor{} + useCBClientInterceptors = true + responseTimeLogLevel loggers.Level = loggers.InfoLevel + responseTimeLogErrorOnly bool + defaultTimeout time.Duration = 60 * time.Second // 0 disables + protoValidateOpts []protovalidate.ValidatorOption + disableProtoValidate bool + srvMetricsOpts []grpcprom.ServerMetricsOption + cltMetricsOpts []grpcprom.ClientMetricsOption + srvMetricsOnce sync.Once + srvMetrics *grpcprom.ServerMetrics + cltMetricsOnce sync.Once + cltMetrics *grpcprom.ClientMetrics disableDebugLogInterceptor bool debugLogHeaderName = "x-debug-log-level" + disableRateLimit bool + rateLimiter ratelimit_middleware.Limiter + rateLimiterOnce sync.Once + rateLimiterVal ratelimit_middleware.Limiter + defaultRateLimit rate.Limit = rate.Inf + defaultRateBurst int ) // SetResponseTimeLogLevel sets the log level for response time logging. @@ -428,8 +436,13 @@ func DefaultInterceptors() []grpc.UnaryServerInterceptor { ints = append(ints, unaryServerInterceptors...) } if useCBServerInterceptors { + ints = append(ints, DefaultTimeoutInterceptor()) + if !disableRateLimit { + if limiter := getRateLimiter(); limiter != nil { + ints = append(ints, ratelimit_middleware.UnaryServerInterceptor(limiter)) + } + } ints = append(ints, - DefaultTimeoutInterceptor(), ResponseTimeLoggingInterceptor(defaultFilterFunc), TraceIdInterceptor(), ) @@ -497,6 +510,11 @@ func DefaultStreamInterceptors() []grpc.StreamServerInterceptor { ints = append(ints, streamServerInterceptors...) } if useCBServerInterceptors { + if !disableRateLimit { + if limiter := getRateLimiter(); limiter != nil { + ints = append(ints, ratelimit_middleware.StreamServerInterceptor(limiter)) + } + } ints = append(ints, ResponseTimeLoggingStreamInterceptor(), ) @@ -860,3 +878,59 @@ func DebugLogInterceptor() grpc.UnaryServerInterceptor { return handler(ctx, req) } } + +// SetDisableRateLimit disables the rate limiting interceptor in the default +// interceptor chain. Must be called during initialization. +func SetDisableRateLimit(disable bool) { + disableRateLimit = disable +} + +// SetRateLimiter sets a custom rate limiter implementation. This overrides the +// built-in token bucket limiter. Must be called during initialization. +func SetRateLimiter(limiter ratelimit_middleware.Limiter) { + rateLimiter = limiter +} + +// SetDefaultRateLimit configures the built-in token bucket rate limiter. +// rps is requests per second, burst is the maximum burst size. +// This is a per-pod in-memory limit — with N pods, the effective cluster-wide +// limit is N × rps. For distributed rate limiting, use SetRateLimiter() with +// a custom implementation (e.g., Redis-backed). +// Must be called during initialization. +func SetDefaultRateLimit(rps float64, burst int) { + defaultRateLimit = rate.Limit(rps) + if burst < 1 { + burst = 1 + } + defaultRateBurst = burst +} + +// tokenBucketLimiter wraps golang.org/x/time/rate.Limiter to implement +// the ratelimit.Limiter interface. +type tokenBucketLimiter struct { + limiter *rate.Limiter +} + +func (l *tokenBucketLimiter) Limit(_ context.Context) error { + if !l.limiter.Allow() { + return fmt.Errorf("rate limit exceeded") + } + return nil +} + +func getRateLimiter() ratelimit_middleware.Limiter { + rateLimiterOnce.Do(func() { + if rateLimiter != nil { + rateLimiterVal = rateLimiter + return + } + if defaultRateLimit == rate.Inf { + rateLimiterVal = nil + return + } + rateLimiterVal = &tokenBucketLimiter{ + limiter: rate.NewLimiter(defaultRateLimit, defaultRateBurst), + } + }) + return rateLimiterVal +} diff --git a/interceptors_test.go b/interceptors_test.go index 9ed6d3b..1f093c8 100644 --- a/interceptors_test.go +++ b/interceptors_test.go @@ -12,9 +12,13 @@ import ( "github.com/go-coldbrew/log" "github.com/go-coldbrew/log/loggers" + ratelimit_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "golang.org/x/time/rate" "google.golang.org/grpc" + "google.golang.org/grpc/codes" grpcmd "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) // mockStream implements grpc.ServerTransportStream for testing. @@ -49,6 +53,12 @@ func resetGlobals() { httpToGRPCInterceptor = nil disableDebugLogInterceptor = false debugLogHeaderName = "x-debug-log-level" + disableRateLimit = false + rateLimiter = nil + rateLimiterOnce = sync.Once{} + rateLimiterVal = nil + defaultRateLimit = rate.Inf + defaultRateBurst = 0 } func TestFilterMethodsFunc(t *testing.T) { @@ -1306,3 +1316,122 @@ func TestDebugLogInterceptor_Disabled(t *testing.T) { } } } + +// --- RateLimit interceptor tests --- + +type alwaysRejectLimiter struct{} + +func (l *alwaysRejectLimiter) Limit(_ context.Context) error { + return fmt.Errorf("always rejected") +} + +func TestRateLimitInterceptor_DefaultInf(t *testing.T) { + resetGlobals() + // Default is rate.Inf — no rate limiting, getRateLimiter returns nil + limiter := getRateLimiter() + if limiter != nil { + t.Error("expected nil limiter with default rate.Inf") + } +} + +func TestRateLimitInterceptor_Allowed(t *testing.T) { + resetGlobals() + SetDefaultRateLimit(1000, 100) + info := &grpc.UnaryServerInfo{FullMethod: "/test/RateLimit"} + + handler := func(ctx context.Context, req any) (any, error) { + return "ok", nil + } + + limiter := getRateLimiter() + if limiter == nil { + t.Fatal("expected non-nil limiter after SetDefaultRateLimit") + } + + interceptor := ratelimit_middleware.UnaryServerInterceptor(limiter) + resp, err := interceptor(context.Background(), nil, info, handler) + if err != nil { + t.Fatalf("expected request to pass, got: %v", err) + } + if resp != "ok" { + t.Errorf("expected 'ok', got %v", resp) + } +} + +func TestRateLimitInterceptor_Exceeded(t *testing.T) { + resetGlobals() + SetDefaultRateLimit(1, 1) // 1 rps, burst 1 + info := &grpc.UnaryServerInfo{FullMethod: "/test/RateLimit"} + + handler := func(ctx context.Context, req any) (any, error) { + return "ok", nil + } + + limiter := getRateLimiter() + interceptor := ratelimit_middleware.UnaryServerInterceptor(limiter) + + // First request should pass + _, err := interceptor(context.Background(), nil, info, handler) + if err != nil { + t.Fatalf("first request should pass, got: %v", err) + } + + // Second request should be rate limited + _, err = interceptor(context.Background(), nil, info, handler) + if err == nil { + t.Fatal("second request should be rate limited") + } + st, ok := status.FromError(err) + if !ok || st.Code() != codes.ResourceExhausted { + t.Errorf("expected ResourceExhausted, got: %v", err) + } +} + +func TestRateLimitInterceptor_CustomLimiter(t *testing.T) { + resetGlobals() + SetRateLimiter(&alwaysRejectLimiter{}) + info := &grpc.UnaryServerInfo{FullMethod: "/test/CustomLimit"} + + handler := func(ctx context.Context, req any) (any, error) { + return "ok", nil + } + + limiter := getRateLimiter() + interceptor := ratelimit_middleware.UnaryServerInterceptor(limiter) + + _, err := interceptor(context.Background(), nil, info, handler) + if err == nil { + t.Fatal("expected rejection from custom limiter") + } + st, ok := status.FromError(err) + if !ok || st.Code() != codes.ResourceExhausted { + t.Errorf("expected ResourceExhausted, got: %v", err) + } +} + +func TestRateLimitInterceptor_Disabled(t *testing.T) { + resetGlobals() + SetDefaultRateLimit(1, 1) + SetDisableRateLimit(true) + + ints := DefaultInterceptors() + // Verify no ratelimit interceptor in chain by running all interceptors + // with a handler that should always succeed + info := &grpc.UnaryServerInfo{FullMethod: "/test/Disabled"} + handler := func(ctx context.Context, req any) (any, error) { + return "ok", nil + } + + // Fire multiple requests through the chain — none should be rate limited + for i := 0; i < 5; i++ { + for _, interceptor := range ints { + _, err := interceptor(context.Background(), nil, info, handler) + if err != nil { + st, ok := status.FromError(err) + if ok && st.Code() == codes.ResourceExhausted { + t.Fatal("rate limiting should be disabled") + } + } + } + } +}