diff --git a/README.md b/README.md index c74d06e..9e900b0 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ Interceptor configuration functions \(AddUnaryServerInterceptor, SetFilterFunc, - [func NewRelicInterceptor\(\) grpc.UnaryServerInterceptor](<#NewRelicInterceptor>) - [func OptionsInterceptor\(\) grpc.UnaryServerInterceptor](<#OptionsInterceptor>) - [func PanicRecoveryInterceptor\(\) grpc.UnaryServerInterceptor](<#PanicRecoveryInterceptor>) +- [func PanicRecoveryStreamInterceptor\(\) grpc.StreamServerInterceptor](<#PanicRecoveryStreamInterceptor>) - [func ProtoValidateInterceptor\(\) grpc.UnaryServerInterceptor](<#ProtoValidateInterceptor>) - [func ProtoValidateStreamInterceptor\(\) grpc.StreamServerInterceptor](<#ProtoValidateStreamInterceptor>) - [func ResponseTimeLoggingInterceptor\(ff FilterFunc\) grpc.UnaryServerInterceptor](<#ResponseTimeLoggingInterceptor>) @@ -93,7 +94,7 @@ var ( ``` -## func [AddStreamClientInterceptor]() +## func [AddStreamClientInterceptor]() ```go func AddStreamClientInterceptor(ctx context.Context, i ...grpc.StreamClientInterceptor) @@ -102,7 +103,7 @@ func AddStreamClientInterceptor(ctx context.Context, i ...grpc.StreamClientInter AddStreamClientInterceptor adds a client stream interceptor to default client stream interceptors. Must be called during initialization, before any RPCs are made. Not safe for concurrent use. -## func [AddStreamServerInterceptor]() +## func [AddStreamServerInterceptor]() ```go func AddStreamServerInterceptor(ctx context.Context, i ...grpc.StreamServerInterceptor) @@ -111,7 +112,7 @@ func AddStreamServerInterceptor(ctx context.Context, i ...grpc.StreamServerInter AddStreamServerInterceptor adds a server interceptor to default server interceptors. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [AddUnaryClientInterceptor]() +## func [AddUnaryClientInterceptor]() ```go func AddUnaryClientInterceptor(ctx context.Context, i ...grpc.UnaryClientInterceptor) @@ -120,7 +121,7 @@ func AddUnaryClientInterceptor(ctx context.Context, i ...grpc.UnaryClientInterce AddUnaryClientInterceptor adds a client interceptor to default client interceptors. Must be called during initialization, before any RPCs are made. Not safe for concurrent use. -## func [AddUnaryServerInterceptor]() +## func [AddUnaryServerInterceptor]() ```go func AddUnaryServerInterceptor(ctx context.Context, i ...grpc.UnaryServerInterceptor) @@ -129,7 +130,7 @@ func AddUnaryServerInterceptor(ctx context.Context, i ...grpc.UnaryServerInterce AddUnaryServerInterceptor adds a server interceptor to default server interceptors. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [DebugLogInterceptor]() +## func [DebugLogInterceptor]() ```go func DebugLogInterceptor() grpc.UnaryServerInterceptor @@ -143,7 +144,7 @@ DebugLogInterceptor enables per\-request log level override based on a proto fie Combined with ColdBrew's trace ID propagation, this allows enabling debug logging for a single request and following it across services via trace ID. -## func [DebugLoggingInterceptor]() +## func [DebugLoggingInterceptor]() ```go func DebugLoggingInterceptor() grpc.UnaryServerInterceptor @@ -152,7 +153,7 @@ func DebugLoggingInterceptor() grpc.UnaryServerInterceptor DebugLoggingInterceptor is the interceptor that logs all request/response from a handler -## func [DefaultClientInterceptor]() +## func [DefaultClientInterceptor]() ```go func DefaultClientInterceptor(defaultOpts ...any) grpc.UnaryClientInterceptor @@ -161,7 +162,7 @@ func DefaultClientInterceptor(defaultOpts ...any) grpc.UnaryClientInterceptor DefaultClientInterceptor are the set of default interceptors that should be applied to all client calls -## func [DefaultClientInterceptors]() +## func [DefaultClientInterceptors]() ```go func DefaultClientInterceptors(defaultOpts ...any) []grpc.UnaryClientInterceptor @@ -170,7 +171,7 @@ func DefaultClientInterceptors(defaultOpts ...any) []grpc.UnaryClientInterceptor DefaultClientInterceptors are the set of default interceptors that should be applied to all client calls -## func [DefaultClientStreamInterceptor]() +## func [DefaultClientStreamInterceptor]() ```go func DefaultClientStreamInterceptor(defaultOpts ...any) grpc.StreamClientInterceptor @@ -179,7 +180,7 @@ func DefaultClientStreamInterceptor(defaultOpts ...any) grpc.StreamClientInterce DefaultClientStreamInterceptor are the set of default interceptors that should be applied to all stream client calls -## func [DefaultClientStreamInterceptors]() +## func [DefaultClientStreamInterceptors]() ```go func DefaultClientStreamInterceptors(defaultOpts ...any) []grpc.StreamClientInterceptor @@ -188,7 +189,7 @@ func DefaultClientStreamInterceptors(defaultOpts ...any) []grpc.StreamClientInte DefaultClientStreamInterceptors are the set of default interceptors that should be applied to all stream client calls -## func [DefaultInterceptors]() +## func [DefaultInterceptors]() ```go func DefaultInterceptors() []grpc.UnaryServerInterceptor @@ -197,7 +198,7 @@ func DefaultInterceptors() []grpc.UnaryServerInterceptor DefaultInterceptors are the set of default interceptors that are applied to all coldbrew methods -## func [DefaultStreamInterceptors]() +## func [DefaultStreamInterceptors]() ```go func DefaultStreamInterceptors() []grpc.StreamServerInterceptor @@ -206,7 +207,7 @@ func DefaultStreamInterceptors() []grpc.StreamServerInterceptor DefaultStreamInterceptors are the set of default interceptors that should be applied to all coldbrew streams -## func [DefaultTimeoutInterceptor]() +## func [DefaultTimeoutInterceptor]() ```go func DefaultTimeoutInterceptor() grpc.UnaryServerInterceptor @@ -215,7 +216,7 @@ func DefaultTimeoutInterceptor() grpc.UnaryServerInterceptor DefaultTimeoutInterceptor returns a unary server interceptor that applies a default deadline to incoming requests that have no deadline set. If the incoming context already has a deadline \(regardless of duration\), it is left unchanged. When defaultTimeout is \<= 0, the interceptor is a no\-op pass\-through. -## func [DoHTTPtoGRPC]() +## func [DoHTTPtoGRPC]() ```go func DoHTTPtoGRPC(ctx context.Context, svr any, handler func(ctx context.Context, req any) (any, error), in any) (any, error) @@ -241,7 +242,7 @@ func (s *svc) echo(ctx context.Context, req *proto.EchoRequest) (*proto.EchoResp ``` -## func [FilterMethodsFunc]() +## func [FilterMethodsFunc]() ```go func FilterMethodsFunc(ctx context.Context, fullMethodName string) bool @@ -250,7 +251,7 @@ func FilterMethodsFunc(ctx context.Context, fullMethodName string) bool FilterMethodsFunc is the default implementation of Filter function -## func [GRPCClientInterceptor]() +## func [GRPCClientInterceptor]() ```go func GRPCClientInterceptor(_ ...any) grpc.UnaryClientInterceptor @@ -259,7 +260,7 @@ func GRPCClientInterceptor(_ ...any) grpc.UnaryClientInterceptor Deprecated: GRPCClientInterceptor is no longer needed. gRPC tracing is now handled by google.golang.org/grpc/stats/opentelemetry, configured via opentelemetry.DialOption\(\) at the client level. This function is retained for backwards compatibility but returns a no\-op interceptor. -## func [GetDebugLogHeaderName]() +## func [GetDebugLogHeaderName]() ```go func GetDebugLogHeaderName() string @@ -268,7 +269,7 @@ func GetDebugLogHeaderName() string GetDebugLogHeaderName returns the current debug log header name. -## func [HystrixClientInterceptor]() +## func [HystrixClientInterceptor]() ```go func HystrixClientInterceptor(defaultOpts ...grpc.CallOption) grpc.UnaryClientInterceptor @@ -281,7 +282,7 @@ Note: This interceptor wraps github.com/afex/hystrix\-go which has been unmainta The interceptor applies provided default and per\-call client options to configure Hystrix behavior \(for example the command name, disabled flag, excluded errors, and excluded gRPC status codes\). If Hystrix is disabled via options, the RPC is invoked directly. If the underlying RPC returns an error that matches any configured excluded error or whose gRPC status code matches any configured excluded code, Hystrix fallback is skipped and the RPC error is returned. Panics raised during the RPC invocation are captured and reported to the notifier before being converted into an error. If the RPC itself returns an error, that error is returned; otherwise any error produced by Hystrix is returned. -## func [NRHttpTracer]() +## func [NRHttpTracer]() ```go func NRHttpTracer(pattern string, h http.HandlerFunc) (string, http.HandlerFunc) @@ -290,7 +291,7 @@ func NRHttpTracer(pattern string, h http.HandlerFunc) (string, http.HandlerFunc) NRHttpTracer adds newrelic tracing to this http function -## func [NewRelicClientInterceptor]() +## func [NewRelicClientInterceptor]() ```go func NewRelicClientInterceptor() grpc.UnaryClientInterceptor @@ -299,7 +300,7 @@ func NewRelicClientInterceptor() grpc.UnaryClientInterceptor NewRelicClientInterceptor intercepts all client actions and reports them to newrelic. When NewRelic app is nil \(no license key configured\), returns a pass\-through interceptor to avoid overhead. -## func [NewRelicInterceptor]() +## func [NewRelicInterceptor]() ```go func NewRelicInterceptor() grpc.UnaryServerInterceptor @@ -308,7 +309,7 @@ func NewRelicInterceptor() grpc.UnaryServerInterceptor NewRelicInterceptor intercepts all server actions and reports them to newrelic. When NewRelic app is nil \(no license key configured\), returns a pass\-through interceptor to avoid overhead. -## func [OptionsInterceptor]() +## func [OptionsInterceptor]() ```go func OptionsInterceptor() grpc.UnaryServerInterceptor @@ -317,7 +318,7 @@ func OptionsInterceptor() grpc.UnaryServerInterceptor -## func [PanicRecoveryInterceptor]() +## func [PanicRecoveryInterceptor]() ```go func PanicRecoveryInterceptor() grpc.UnaryServerInterceptor @@ -325,8 +326,17 @@ func PanicRecoveryInterceptor() grpc.UnaryServerInterceptor + +## func [PanicRecoveryStreamInterceptor]() + +```go +func PanicRecoveryStreamInterceptor() grpc.StreamServerInterceptor +``` + +PanicRecoveryStreamInterceptor recovers from panics in stream handlers, logs the panic and stack trace, and reports it to the error notifier. + -## func [ProtoValidateInterceptor]() +## func [ProtoValidateInterceptor]() ```go func ProtoValidateInterceptor() grpc.UnaryServerInterceptor @@ -335,7 +345,7 @@ func ProtoValidateInterceptor() grpc.UnaryServerInterceptor ProtoValidateInterceptor returns a unary server interceptor that validates incoming messages using protovalidate annotations. Returns InvalidArgument on validation failure. Uses GlobalValidator by default; if custom options are set via SetProtoValidateOptions, creates a new validator with those options. -## func [ProtoValidateStreamInterceptor]() +## func [ProtoValidateStreamInterceptor]() ```go func ProtoValidateStreamInterceptor() grpc.StreamServerInterceptor @@ -344,7 +354,7 @@ func ProtoValidateStreamInterceptor() grpc.StreamServerInterceptor ProtoValidateStreamInterceptor returns a stream server interceptor that validates incoming messages using protovalidate annotations. -## func [ResponseTimeLoggingInterceptor]() +## func [ResponseTimeLoggingInterceptor]() ```go func ResponseTimeLoggingInterceptor(ff FilterFunc) grpc.UnaryServerInterceptor @@ -353,7 +363,7 @@ func ResponseTimeLoggingInterceptor(ff FilterFunc) grpc.UnaryServerInterceptor ResponseTimeLoggingInterceptor logs response time for each request on server -## func [ResponseTimeLoggingStreamInterceptor]() +## func [ResponseTimeLoggingStreamInterceptor]() ```go func ResponseTimeLoggingStreamInterceptor() grpc.StreamServerInterceptor @@ -362,7 +372,7 @@ func ResponseTimeLoggingStreamInterceptor() grpc.StreamServerInterceptor ResponseTimeLoggingStreamInterceptor logs response time for stream RPCs. -## func [ServerErrorInterceptor]() +## func [ServerErrorInterceptor]() ```go func ServerErrorInterceptor() grpc.UnaryServerInterceptor @@ -371,7 +381,7 @@ func ServerErrorInterceptor() grpc.UnaryServerInterceptor ServerErrorInterceptor intercepts all server actions and reports them to error notifier -## func [ServerErrorStreamInterceptor]() +## func [ServerErrorStreamInterceptor]() ```go func ServerErrorStreamInterceptor() grpc.StreamServerInterceptor @@ -380,16 +390,16 @@ func ServerErrorStreamInterceptor() grpc.StreamServerInterceptor ServerErrorStreamInterceptor intercepts server errors for stream RPCs and reports them to the error notifier. -## func [SetClientMetricsOptions]() +## func [SetClientMetricsOptions]() ```go func SetClientMetricsOptions(opts ...grpcprom.ClientMetricsOption) ``` -SetClientMetricsOptions appends gRPC client metrics options. Must be called during initialization, before the server starts. Not safe for concurrent use. +SetClientMetricsOptions appends gRPC client metrics options. Must be called during initialization, before any RPCs are made. Not safe for concurrent use. -## func [SetDebugLogHeaderName]() +## func [SetDebugLogHeaderName]() ```go func SetDebugLogHeaderName(name string) @@ -398,7 +408,7 @@ func SetDebugLogHeaderName(name string) SetDebugLogHeaderName sets the gRPC metadata header name that triggers per\-request log level override. Default is "x\-debug\-log\-level". The header value should be a valid log level \(e.g., "debug"\). Empty names are ignored. Must be called during initialization. -## func [SetDefaultRateLimit]() +## func [SetDefaultRateLimit]() ```go func SetDefaultRateLimit(rps float64, burst int) @@ -407,7 +417,7 @@ func SetDefaultRateLimit(rps float64, burst int) SetDefaultRateLimit configures the built\-in token bucket rate limiter. rps is requests per second, burst is the maximum burst size. This is a per\-pod in\-memory limit — with N pods, the effective cluster\-wide limit is N × rps. For distributed rate limiting, use SetRateLimiter\(\) with a custom implementation \(e.g., Redis\-backed\). Must be called during initialization. -## func [SetDefaultTimeout]() +## func [SetDefaultTimeout]() ```go func SetDefaultTimeout(d time.Duration) @@ -416,7 +426,7 @@ func SetDefaultTimeout(d time.Duration) SetDefaultTimeout sets the default timeout applied to incoming unary RPCs that arrive without a deadline. When set to \<= 0, the timeout interceptor is disabled \(pass\-through\). Default is 60s. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [SetDisableDebugLogInterceptor]() +## func [SetDisableDebugLogInterceptor]() ```go func SetDisableDebugLogInterceptor(disable bool) @@ -425,7 +435,7 @@ func SetDisableDebugLogInterceptor(disable bool) SetDisableDebugLogInterceptor disables the DebugLogInterceptor in the default interceptor chain. Must be called during initialization, before the server starts. -## func [SetDisableProtoValidate]() +## func [SetDisableProtoValidate]() ```go func SetDisableProtoValidate(disable bool) @@ -434,7 +444,7 @@ func SetDisableProtoValidate(disable bool) SetDisableProtoValidate disables the protovalidate interceptor in the default chain. Must be called during init\(\) — not safe for concurrent use. -## func [SetDisableRateLimit]() +## func [SetDisableRateLimit]() ```go func SetDisableRateLimit(disable bool) @@ -443,7 +453,7 @@ func SetDisableRateLimit(disable bool) SetDisableRateLimit disables the rate limiting interceptor in the default interceptor chain. Must be called during initialization. -## func [SetFilterFunc]() +## func [SetFilterFunc]() ```go func SetFilterFunc(ctx context.Context, ff FilterFunc) @@ -452,7 +462,7 @@ func SetFilterFunc(ctx context.Context, ff FilterFunc) SetFilterFunc sets the default filter function to be used by interceptors. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [SetFilterMethods]() +## func [SetFilterMethods]() ```go func SetFilterMethods(ctx context.Context, methods []string) @@ -461,7 +471,7 @@ func SetFilterMethods(ctx context.Context, methods []string) SetFilterMethods sets the list of method substrings to exclude from tracing/logging. It rebuilds the internal cache. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [SetProtoValidateOptions]() +## func [SetProtoValidateOptions]() ```go func SetProtoValidateOptions(opts ...protovalidate.ValidatorOption) @@ -470,7 +480,7 @@ func SetProtoValidateOptions(opts ...protovalidate.ValidatorOption) SetProtoValidateOptions configures custom protovalidate options \(e.g., custom constraints\). Must be called during init\(\) — not safe for concurrent use. Follows ColdBrew's init\-only config pattern. -## func [SetRateLimiter]() +## func [SetRateLimiter]() ```go func SetRateLimiter(limiter ratelimit_middleware.Limiter) @@ -479,7 +489,7 @@ func SetRateLimiter(limiter ratelimit_middleware.Limiter) SetRateLimiter sets a custom rate limiter implementation. This overrides the built\-in token bucket limiter. Must be called during initialization. -## func [SetResponseTimeLogErrorOnly]() +## func [SetResponseTimeLogErrorOnly]() ```go func SetResponseTimeLogErrorOnly(errorOnly bool) @@ -488,7 +498,7 @@ func SetResponseTimeLogErrorOnly(errorOnly bool) SetResponseTimeLogErrorOnly when set to true, only logs response time when the request returns an error. Successful requests are not logged. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [SetResponseTimeLogLevel]() +## func [SetResponseTimeLogLevel]() ```go func SetResponseTimeLogLevel(ctx context.Context, level loggers.Level) @@ -497,7 +507,7 @@ func SetResponseTimeLogLevel(ctx context.Context, level loggers.Level) SetResponseTimeLogLevel sets the log level for response time logging. Default is InfoLevel. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [SetServerMetricsOptions]() +## func [SetServerMetricsOptions]() ```go func SetServerMetricsOptions(opts ...grpcprom.ServerMetricsOption) @@ -506,7 +516,7 @@ func SetServerMetricsOptions(opts ...grpcprom.ServerMetricsOption) SetServerMetricsOptions appends gRPC server metrics options \(histogram, labels, namespace, etc.\). Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [TraceIdInterceptor]() +## func [TraceIdInterceptor]() ```go func TraceIdInterceptor() grpc.UnaryServerInterceptor @@ -515,7 +525,7 @@ func TraceIdInterceptor() grpc.UnaryServerInterceptor TraceIdInterceptor allows injecting trace id from request objects -## func [UseColdBrewClientInterceptors]() +## func [UseColdBrewClientInterceptors]() ```go func UseColdBrewClientInterceptors(ctx context.Context, flag bool) @@ -524,7 +534,7 @@ func UseColdBrewClientInterceptors(ctx context.Context, flag bool) UseColdBrewClientInterceptors allows enabling/disabling coldbrew client interceptors. When set to false, the coldbrew client interceptors will not be used. Must be called during initialization, before any RPCs are made. Not safe for concurrent use. -## func [UseColdBrewServerInterceptors]() +## func [UseColdBrewServerInterceptors]() ```go func UseColdBrewServerInterceptors(ctx context.Context, flag bool) @@ -533,7 +543,7 @@ func UseColdBrewServerInterceptors(ctx context.Context, flag bool) UseColdBrewServerInterceptors allows enabling/disabling coldbrew server interceptors. When set to false, the coldbrew server interceptors will not be used. Must be called during initialization, before the server starts. Not safe for concurrent use. -## type [FilterFunc]() +## type [FilterFunc]() If it returns false, the given request will not be traced. diff --git a/chain.go b/chain.go new file mode 100644 index 0000000..5c75839 --- /dev/null +++ b/chain.go @@ -0,0 +1,52 @@ +package interceptors + +import ( + "context" + + "google.golang.org/grpc" +) + +// chainUnaryServer chains multiple unary server interceptors into one. +func chainUnaryServer(interceptors []grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + chain := handler + for i := len(interceptors) - 1; i >= 0; i-- { + interceptor := interceptors[i] + next := chain + chain = func(ctx context.Context, req any) (any, error) { + return interceptor(ctx, req, info, next) + } + } + return chain(ctx, req) + } +} + +// chainUnaryClient chains multiple unary client interceptors into one. +func chainUnaryClient(interceptors []grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + chain := invoker + for i := len(interceptors) - 1; i >= 0; i-- { + interceptor := interceptors[i] + next := chain + chain = func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + return interceptor(ctx, method, req, reply, cc, next, opts...) + } + } + return chain(ctx, method, req, reply, cc, opts...) + } +} + +// chainStreamClient chains multiple stream client interceptors into one. +func chainStreamClient(interceptors []grpc.StreamClientInterceptor) grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + chain := streamer + for i := len(interceptors) - 1; i >= 0; i-- { + interceptor := interceptors[i] + next := chain + chain = func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return interceptor(ctx, desc, cc, method, next, opts...) + } + } + return chain(ctx, desc, cc, method, opts...) + } +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..6b71608 --- /dev/null +++ b/client.go @@ -0,0 +1,161 @@ +package interceptors + +import ( + "context" + stdError "errors" + "fmt" + "slices" + + "github.com/afex/hystrix-go/hystrix" + "github.com/go-coldbrew/errors" + "github.com/go-coldbrew/errors/notifier" + "github.com/go-coldbrew/log" + nrutil "github.com/go-coldbrew/tracing/newrelic" + grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry" + "github.com/newrelic/go-agent/v3/integrations/nrgrpc" + "google.golang.org/grpc" + "google.golang.org/grpc/status" +) + +// DefaultClientInterceptors are the set of default interceptors that should be applied to all client calls +func DefaultClientInterceptors(defaultOpts ...any) []grpc.UnaryClientInterceptor { + ints := []grpc.UnaryClientInterceptor{} + if len(defaultConfig.unaryClientInterceptors) > 0 { + ints = append(ints, defaultConfig.unaryClientInterceptors...) + } + if defaultConfig.useCBClientInterceptors { + hystrixOptions := make([]grpc.CallOption, 0) + for _, opt := range defaultOpts { + if opt == nil { + continue + } + if o, ok := opt.(grpc.CallOption); ok { + hystrixOptions = append(hystrixOptions, o) + } + } + ints = append(ints, + HystrixClientInterceptor(hystrixOptions...), + grpc_retry.UnaryClientInterceptor(), + NewRelicClientInterceptor(), + getClientMetrics().UnaryClientInterceptor(), + ) + } + return ints +} + +// DefaultClientStreamInterceptors are the set of default interceptors that should be applied to all stream client calls +func DefaultClientStreamInterceptors(defaultOpts ...any) []grpc.StreamClientInterceptor { + ints := []grpc.StreamClientInterceptor{} + if len(defaultConfig.streamClientInterceptors) > 0 { + ints = append(ints, defaultConfig.streamClientInterceptors...) + } + if defaultConfig.useCBClientInterceptors { + if nrutil.GetNewRelicApp() != nil { + ints = append(ints, nrgrpc.StreamClientInterceptor) + } + ints = append(ints, getClientMetrics().StreamClientInterceptor()) + } + return ints +} + +// DefaultClientInterceptor are the set of default interceptors that should be applied to all client calls +func DefaultClientInterceptor(defaultOpts ...any) grpc.UnaryClientInterceptor { + return chainUnaryClient(DefaultClientInterceptors(defaultOpts...)) +} + +// DefaultClientStreamInterceptor are the set of default interceptors that should be applied to all stream client calls +func DefaultClientStreamInterceptor(defaultOpts ...any) grpc.StreamClientInterceptor { + return chainStreamClient(DefaultClientStreamInterceptors(defaultOpts...)) +} + +// NewRelicClientInterceptor intercepts all client actions and reports them to newrelic. +// When NewRelic app is nil (no license key configured), returns a pass-through +// interceptor to avoid overhead. +func NewRelicClientInterceptor() grpc.UnaryClientInterceptor { + app := nrutil.GetNewRelicApp() + if app == nil { + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + return invoker(ctx, method, req, reply, cc, opts...) + } + } + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if defaultConfig.filterFunc(ctx, method) { + return nrgrpc.UnaryClientInterceptor(ctx, method, req, reply, cc, invoker, opts...) + } else { + return invoker(ctx, method, req, reply, cc, opts...) + } + } +} + +// Deprecated: GRPCClientInterceptor is no longer needed. gRPC tracing is now handled +// by google.golang.org/grpc/stats/opentelemetry, configured via opentelemetry.DialOption() +// at the client level. This function is retained for backwards compatibility but returns +// a no-op interceptor. +func GRPCClientInterceptor(_ ...any) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +// HystrixClientInterceptor returns a unary client interceptor that executes the RPC inside a Hystrix command. +// +// Note: This interceptor wraps github.com/afex/hystrix-go which has been unmaintained since 2018. +// Consider migrating to github.com/failsafe-go/failsafe-go for circuit breaker functionality. +// +// The interceptor applies provided default and per-call client options to configure Hystrix behavior (for example the command name, disabled flag, excluded errors, and excluded gRPC status codes). +// If Hystrix is disabled via options, the RPC is invoked directly. If the underlying RPC returns an error that matches any configured excluded error or whose gRPC status code matches any configured excluded code, Hystrix fallback is skipped and the RPC error is returned. +// Panics raised during the RPC invocation are captured and reported to the notifier before being converted into an error. If the RPC itself returns an error, that error is returned; otherwise any error produced by Hystrix is returned. +func HystrixClientInterceptor(defaultOpts ...grpc.CallOption) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + options := clientOptions{ + hystrixName: method, + } + for _, opt := range defaultOpts { + if opt != nil { + if o, ok := opt.(clientOption); ok { + o.process(&options) + } + } + } + for _, opt := range opts { + if opt != nil { + if o, ok := opt.(clientOption); ok { + o.process(&options) + } + } + } + if options.disableHystrix { + // short circuit if hystrix is disabled + return invoker(ctx, method, req, reply, cc, opts...) + } + newCtx, cancel := context.WithCancel(ctx) + defer cancel() + + var invokerErr error + hystrixErr := hystrix.Do(options.hystrixName, func() (err error) { + defer func() { + if r := recover(); r != nil { + err = errors.Wrap(fmt.Errorf("panic inside hystrix method: %s, req: %v, reply: %v", method, req, reply), "Hystrix") + log.Error(ctx, "panic", r, "method", method, "req", req, "reply", reply) + } + }() + defer notifier.NotifyOnPanic(newCtx, method) + invokerErr = invoker(newCtx, method, req, reply, cc, opts...) + for _, excludedErr := range options.excludedErrors { + if stdError.Is(invokerErr, excludedErr) { + return nil + } + } + if st, ok := status.FromError(invokerErr); ok { + if slices.Contains(options.excludedCodes, st.Code()) { + return nil + } + } + return invokerErr + }, nil) + if invokerErr != nil { + return invokerErr + } + return hystrixErr + } +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..d01a6fc --- /dev/null +++ b/config.go @@ -0,0 +1,205 @@ +package interceptors + +import ( + "context" + "strings" + "time" + + "buf.build/go/protovalidate" + "github.com/go-coldbrew/log/loggers" + grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" + ratelimit_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit" + "golang.org/x/time/rate" + "google.golang.org/grpc" +) + +// interceptorConfig consolidates all package-level configuration. +// All Set* functions modify fields on defaultConfig. +type interceptorConfig struct { + // Interceptor slices + unaryServerInterceptors []grpc.UnaryServerInterceptor + streamServerInterceptors []grpc.StreamServerInterceptor + unaryClientInterceptors []grpc.UnaryClientInterceptor + streamClientInterceptors []grpc.StreamClientInterceptor + + // Feature flags + useCBServerInterceptors bool + useCBClientInterceptors bool + disableProtoValidate bool + disableDebugLogInterceptor bool + disableRateLimit bool + responseTimeLogErrorOnly bool + + // Configuration values + responseTimeLogLevel loggers.Level + defaultTimeout time.Duration + debugLogHeaderName string + filterFunc FilterFunc + + // Metrics options + srvMetricsOpts []grpcprom.ServerMetricsOption + cltMetricsOpts []grpcprom.ClientMetricsOption + + // ProtoValidate options + protoValidateOpts []protovalidate.ValidatorOption + + // Rate limiting + rateLimiter ratelimit_middleware.Limiter + defaultRateLimit rate.Limit + defaultRateBurst int +} + +var defaultConfig = newDefaultConfig() + +func newDefaultConfig() interceptorConfig { + return interceptorConfig{ + useCBServerInterceptors: true, + useCBClientInterceptors: true, + responseTimeLogLevel: loggers.InfoLevel, + defaultTimeout: 60 * time.Second, + debugLogHeaderName: "x-debug-log-level", + filterFunc: FilterMethodsFunc, + defaultRateLimit: rate.Inf, + } +} + +// SetResponseTimeLogLevel sets the log level for response time logging. +// Default is InfoLevel. Must be called during initialization, before the server starts. Not safe for concurrent use. +func SetResponseTimeLogLevel(ctx context.Context, level loggers.Level) { + defaultConfig.responseTimeLogLevel = level +} + +// SetResponseTimeLogErrorOnly when set to true, only logs response time when +// the request returns an error. Successful requests are not logged. +// Must be called during initialization, before the server starts. Not safe for concurrent use. +func SetResponseTimeLogErrorOnly(errorOnly bool) { + defaultConfig.responseTimeLogErrorOnly = errorOnly +} + +// SetDefaultTimeout sets the default timeout applied to incoming unary RPCs +// that arrive without a deadline. When set to <= 0, the timeout interceptor is +// disabled (pass-through). Default is 60s. +// Must be called during initialization, before the server starts. Not safe for concurrent use. +func SetDefaultTimeout(d time.Duration) { + defaultConfig.defaultTimeout = d +} + +// SetFilterFunc sets the default filter function to be used by interceptors. +// Must be called during initialization, before the server starts. Not safe for concurrent use. +func SetFilterFunc(ctx context.Context, ff FilterFunc) { + if ff != nil { + defaultConfig.filterFunc = ff + } +} + +// AddUnaryServerInterceptor adds a server interceptor to default server interceptors. +// Must be called during initialization, before the server starts. Not safe for concurrent use. +func AddUnaryServerInterceptor(ctx context.Context, i ...grpc.UnaryServerInterceptor) { + defaultConfig.unaryServerInterceptors = append(defaultConfig.unaryServerInterceptors, i...) +} + +// AddStreamServerInterceptor adds a server interceptor to default server interceptors. +// Must be called during initialization, before the server starts. Not safe for concurrent use. +func AddStreamServerInterceptor(ctx context.Context, i ...grpc.StreamServerInterceptor) { + defaultConfig.streamServerInterceptors = append(defaultConfig.streamServerInterceptors, i...) +} + +// UseColdBrewServerInterceptors allows enabling/disabling coldbrew server interceptors. +// When set to false, the coldbrew server interceptors will not be used. +// Must be called during initialization, before the server starts. Not safe for concurrent use. +func UseColdBrewServerInterceptors(ctx context.Context, flag bool) { + defaultConfig.useCBServerInterceptors = flag +} + +// AddUnaryClientInterceptor adds a client interceptor to default client interceptors. +// Must be called during initialization, before any RPCs are made. Not safe for concurrent use. +func AddUnaryClientInterceptor(ctx context.Context, i ...grpc.UnaryClientInterceptor) { + defaultConfig.unaryClientInterceptors = append(defaultConfig.unaryClientInterceptors, i...) +} + +// AddStreamClientInterceptor adds a client stream interceptor to default client stream interceptors. +// Must be called during initialization, before any RPCs are made. Not safe for concurrent use. +func AddStreamClientInterceptor(ctx context.Context, i ...grpc.StreamClientInterceptor) { + defaultConfig.streamClientInterceptors = append(defaultConfig.streamClientInterceptors, i...) +} + +// UseColdBrewClientInterceptors allows enabling/disabling coldbrew client interceptors. +// When set to false, the coldbrew client interceptors will not be used. +// Must be called during initialization, before any RPCs are made. Not safe for concurrent use. +func UseColdBrewClientInterceptors(ctx context.Context, flag bool) { + defaultConfig.useCBClientInterceptors = flag +} + +// SetServerMetricsOptions appends gRPC server metrics options (histogram, labels, namespace, etc.). +// Must be called during initialization, before the server starts. Not safe for concurrent use. +func SetServerMetricsOptions(opts ...grpcprom.ServerMetricsOption) { + defaultConfig.srvMetricsOpts = append(defaultConfig.srvMetricsOpts, opts...) +} + +// SetClientMetricsOptions appends gRPC client metrics options. +// Must be called during initialization, before any RPCs are made. Not safe for concurrent use. +func SetClientMetricsOptions(opts ...grpcprom.ClientMetricsOption) { + defaultConfig.cltMetricsOpts = append(defaultConfig.cltMetricsOpts, opts...) +} + +// SetProtoValidateOptions configures custom protovalidate options (e.g., +// custom constraints). Must be called during init() — not safe for +// concurrent use. Follows ColdBrew's init-only config pattern. +func SetProtoValidateOptions(opts ...protovalidate.ValidatorOption) { + defaultConfig.protoValidateOpts = append(defaultConfig.protoValidateOpts, opts...) +} + +// SetDisableProtoValidate disables the protovalidate interceptor in the +// default chain. Must be called during init() — not safe for concurrent use. +func SetDisableProtoValidate(disable bool) { + defaultConfig.disableProtoValidate = disable +} + +// SetDisableDebugLogInterceptor disables the DebugLogInterceptor in the default +// interceptor chain. Must be called during initialization, before the server starts. +func SetDisableDebugLogInterceptor(disable bool) { + defaultConfig.disableDebugLogInterceptor = disable +} + +// SetDebugLogHeaderName sets the gRPC metadata header name that triggers +// per-request log level override. Default is "x-debug-log-level". The header +// value should be a valid log level (e.g., "debug"). Empty names are ignored. +// Must be called during initialization. +func SetDebugLogHeaderName(name string) { + name = strings.ToLower(strings.TrimSpace(name)) + if name == "" { + return + } + defaultConfig.debugLogHeaderName = name +} + +// GetDebugLogHeaderName returns the current debug log header name. +func GetDebugLogHeaderName() string { + return defaultConfig.debugLogHeaderName +} + +// SetDisableRateLimit disables the rate limiting interceptor in the default +// interceptor chain. Must be called during initialization. +func SetDisableRateLimit(disable bool) { + defaultConfig.disableRateLimit = disable +} + +// SetRateLimiter sets a custom rate limiter implementation. This overrides the +// built-in token bucket limiter. Must be called during initialization. +func SetRateLimiter(limiter ratelimit_middleware.Limiter) { + defaultConfig.rateLimiter = limiter +} + +// SetDefaultRateLimit configures the built-in token bucket rate limiter. +// rps is requests per second, burst is the maximum burst size. +// This is a per-pod in-memory limit — with N pods, the effective cluster-wide +// limit is N × rps. For distributed rate limiting, use SetRateLimiter() with +// a custom implementation (e.g., Redis-backed). +// Must be called during initialization. +func SetDefaultRateLimit(rps float64, burst int) { + defaultConfig.defaultRateLimit = rate.Limit(rps) + if burst < 1 { + burst = 1 + } + defaultConfig.defaultRateBurst = burst +} diff --git a/documentations.go b/documentations.go deleted file mode 100644 index 8f46784..0000000 --- a/documentations.go +++ /dev/null @@ -1 +0,0 @@ -package interceptors diff --git a/filter.go b/filter.go new file mode 100644 index 0000000..f945b04 --- /dev/null +++ b/filter.go @@ -0,0 +1,114 @@ +package interceptors + +import ( + "context" + "strings" + "sync" + "sync/atomic" + + "google.golang.org/grpc" +) + +// If it returns false, the given request will not be traced. +type FilterFunc func(ctx context.Context, fullMethodName string) bool + +var ( + // Deprecated: FilterMethods is the list of methods that are filtered by default. + // Use SetFilterMethods instead. Only some direct mutations (replacing the slice + // or changing the first element) are detected by internal change detection; + // other in-place changes may not invalidate caches correctly. + FilterMethods = []string{"healthcheck", "readycheck", "serverreflectioninfo"} +) + +// filterState holds pre-computed filter data and a per-instance cache. +// A new filterState is created whenever FilterMethods changes, which +// atomically invalidates the old cache. +type filterState struct { + methods []string // lowercased filter substrings + cache sync.Map // map[string]bool + sourceLen int // len(FilterMethods) when this state was built + sourceFirst string // FilterMethods[0] when built (fast mutation check) +} + +var currentFilter atomic.Pointer[filterState] + +func init() { + currentFilter.Store(buildFilterState()) +} + +func buildFilterState() *filterState { + lower := make([]string, len(FilterMethods)) + for i, m := range FilterMethods { + lower[i] = strings.ToLower(m) + } + s := &filterState{ + methods: lower, + sourceLen: len(FilterMethods), + } + if len(FilterMethods) > 0 { + s.sourceFirst = FilterMethods[0] + } + return s +} + +// changed reports whether the deprecated FilterMethods variable +// has been mutated since this filterState was built. +func (s *filterState) changed() bool { + if len(FilterMethods) != s.sourceLen { + return true + } + if s.sourceLen > 0 && FilterMethods[0] != s.sourceFirst { + return true + } + return false +} + +// SetFilterMethods sets the list of method substrings to exclude from tracing/logging. +// It rebuilds the internal cache. Must be called during initialization, before +// the server starts. Not safe for concurrent use. +func SetFilterMethods(ctx context.Context, methods []string) { + // Defensive copy to prevent aliasing: if the caller later mutates + // their slice, it won't silently affect filtering. + cp := make([]string, len(methods)) + copy(cp, methods) + FilterMethods = cp + currentFilter.Store(buildFilterState()) +} + +// isGRPCRequest returns true if the context is a gRPC server context. +// Uses grpc.Method(ctx) which is a single context value lookup with zero +// allocations. HTTP handlers pass plain contexts where this returns false. +// This is used to decide whether to cache filter decisions — gRPC method +// names are stable and finite, while HTTP paths can be high-cardinality. +func isGRPCRequest(ctx context.Context) bool { + _, ok := grpc.Method(ctx) + return ok +} + +// FilterMethodsFunc is the default implementation of Filter function +func FilterMethodsFunc(ctx context.Context, fullMethodName string) bool { + f := currentFilter.Load() + // Auto-detect direct mutation of the deprecated FilterMethods variable. + if f.changed() { + f = buildFilterState() + currentFilter.Store(f) + } + cacheable := isGRPCRequest(ctx) + if cacheable { + if v, ok := f.cache.Load(fullMethodName); ok { + return v.(bool) + } + } + lowerMethod := strings.ToLower(fullMethodName) + result := true + for _, name := range f.methods { + if strings.Contains(lowerMethod, name) { + result = false + break + } + } + if cacheable { + f.cache.Store(fullMethodName, result) + } + return result +} diff --git a/http.go b/http.go new file mode 100644 index 0000000..f5c93e7 --- /dev/null +++ b/http.go @@ -0,0 +1,80 @@ +package interceptors + +import ( + "context" + "net/http" + "sync" + + nrutil "github.com/go-coldbrew/tracing/newrelic" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + newrelic "github.com/newrelic/go-agent/v3/newrelic" + "google.golang.org/grpc" +) + +var ( + httpToGRPCOnce sync.Once + httpToGRPCInterceptor grpc.UnaryServerInterceptor +) + +func getHTTPtoGRPCInterceptor() grpc.UnaryServerInterceptor { + httpToGRPCOnce.Do(func() { + httpToGRPCInterceptor = chainUnaryServer(DefaultInterceptors()) + }) + return httpToGRPCInterceptor +} + +// DoHTTPtoGRPC allows calling the interceptors when you use the RegisterHandlerServer in grpc-gateway. +// This enables in-process HTTP-to-gRPC calls with the full interceptor chain (logging, tracing, metrics, +// panic recovery) without a network hop — the fastest option for gateway performance. +// The interceptor chain is cached on first invocation. All interceptor configuration +// (AddUnaryServerInterceptor, SetFilterFunc, etc.) must be finalized before the first call. +// See example below for reference. +// +// func (s *svc) Echo(ctx context.Context, req *proto.EchoRequest) (*proto.EchoResponse, error) { +// handler := func(ctx context.Context, req interface{}) (interface{}, error) { +// return s.echo(ctx, req.(*proto.EchoRequest)) +// } +// r, err := DoHTTPtoGRPC(ctx, s, handler, req) +// if err != nil { +// return nil, err +// } +// return r.(*proto.EchoResponse), nil +// } +// +// func (s *svc) echo(ctx context.Context, req *proto.EchoRequest) (*proto.EchoResponse, error) { +// .... implementation .... +// } +func DoHTTPtoGRPC(ctx context.Context, svr any, handler func(ctx context.Context, req any) (any, error), in any) (any, error) { + method, ok := runtime.RPCMethod(ctx) + if ok { + interceptor := getHTTPtoGRPCInterceptor() + info := &grpc.UnaryServerInfo{ + Server: svr, + FullMethod: method, + } + return interceptor(ctx, in, info, handler) + } + return handler(ctx, in) +} + +// NRHttpTracer adds newrelic tracing to this http function +func NRHttpTracer(pattern string, h http.HandlerFunc) (string, http.HandlerFunc) { + app := nrutil.GetNewRelicApp() + if app == nil { + return pattern, h + } + if pattern != "" { + return newrelic.WrapHandleFunc(app, pattern, h) + } + return pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // filter functions we do not need + if defaultConfig.filterFunc(context.Background(), r.URL.Path) { + txn := app.StartTransaction(r.Method + " " + r.URL.Path) + defer txn.End() + w = txn.SetWebResponse(w) + txn.SetWebRequestHTTP(r) + r = newrelic.RequestWithTransactionContext(r, txn) + } + h.ServeHTTP(w, r) + }) +} diff --git a/interceptors.go b/interceptors.go index 50d8c8a..9dfb514 100644 --- a/interceptors.go +++ b/interceptors.go @@ -6,37 +6,7 @@ package interceptors import ( - "context" - stdError "errors" - "fmt" - "net/http" - "runtime/debug" - "slices" - "strings" - "sync" - "sync/atomic" - "time" - - "buf.build/go/protovalidate" - "github.com/afex/hystrix-go/hystrix" "github.com/go-coldbrew/errors" - "github.com/go-coldbrew/errors/notifier" - "github.com/go-coldbrew/log" - "github.com/go-coldbrew/log/loggers" - "github.com/go-coldbrew/options" - nrutil "github.com/go-coldbrew/tracing/newrelic" - grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" - protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate" - ratelimit_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry" - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "github.com/newrelic/go-agent/v3/integrations/nrgrpc" - newrelic "github.com/newrelic/go-agent/v3/newrelic" - "github.com/prometheus/client_golang/prometheus" - "golang.org/x/time/rate" - "google.golang.org/grpc" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" ) // SupportPackageIsVersion1 is a compile-time assertion constant. @@ -47,890 +17,3 @@ const SupportPackageIsVersion1 = true // Compile-time version compatibility check. var _ = errors.SupportPackageIsVersion1 - -var ( - // Deprecated: FilterMethods is the list of methods that are filtered by default. - // Use SetFilterMethods instead. Only some direct mutations (replacing the slice - // or changing the first element) are detected by internal change detection; - // other in-place changes may not invalidate caches correctly. - FilterMethods = []string{"healthcheck", "readycheck", "serverreflectioninfo"} - defaultFilterFunc = FilterMethodsFunc - unaryServerInterceptors = []grpc.UnaryServerInterceptor{} - streamServerInterceptors = []grpc.StreamServerInterceptor{} - useCBServerInterceptors = true - unaryClientInterceptors = []grpc.UnaryClientInterceptor{} - streamClientInterceptors = []grpc.StreamClientInterceptor{} - useCBClientInterceptors = true - responseTimeLogLevel loggers.Level = loggers.InfoLevel - responseTimeLogErrorOnly bool - defaultTimeout time.Duration = 60 * time.Second // 0 disables - protoValidateOpts []protovalidate.ValidatorOption - disableProtoValidate bool - srvMetricsOpts []grpcprom.ServerMetricsOption - cltMetricsOpts []grpcprom.ClientMetricsOption - srvMetricsOnce sync.Once - srvMetrics *grpcprom.ServerMetrics - cltMetricsOnce sync.Once - cltMetrics *grpcprom.ClientMetrics - disableDebugLogInterceptor bool - debugLogHeaderName = "x-debug-log-level" - disableRateLimit bool - rateLimiter ratelimit_middleware.Limiter - rateLimiterOnce sync.Once - rateLimiterVal ratelimit_middleware.Limiter - defaultRateLimit rate.Limit = rate.Inf - defaultRateBurst int -) - -// SetResponseTimeLogLevel sets the log level for response time logging. -// Default is InfoLevel. Must be called during initialization, before the server starts. Not safe for concurrent use. -func SetResponseTimeLogLevel(ctx context.Context, level loggers.Level) { - responseTimeLogLevel = level -} - -// SetResponseTimeLogErrorOnly when set to true, only logs response time when -// the request returns an error. Successful requests are not logged. -// Must be called during initialization, before the server starts. Not safe for concurrent use. -func SetResponseTimeLogErrorOnly(errorOnly bool) { - responseTimeLogErrorOnly = errorOnly -} - -// SetDefaultTimeout sets the default timeout applied to incoming unary RPCs -// that arrive without a deadline. When set to <= 0, the timeout interceptor is -// disabled (pass-through). Default is 60s. -// Must be called during initialization, before the server starts. Not safe for concurrent use. -func SetDefaultTimeout(d time.Duration) { - defaultTimeout = d -} - -// If it returns false, the given request will not be traced. -type FilterFunc func(ctx context.Context, fullMethodName string) bool - -// filterState holds pre-computed filter data and a per-instance cache. -// A new filterState is created whenever FilterMethods changes, which -// atomically invalidates the old cache. -type filterState struct { - methods []string // lowercased filter substrings - cache sync.Map // map[string]bool - sourceLen int // len(FilterMethods) when this state was built - sourceFirst string // FilterMethods[0] when built (fast mutation check) -} - -var currentFilter atomic.Pointer[filterState] - -func init() { - currentFilter.Store(buildFilterState()) -} - -func buildFilterState() *filterState { - lower := make([]string, len(FilterMethods)) - for i, m := range FilterMethods { - lower[i] = strings.ToLower(m) - } - s := &filterState{ - methods: lower, - sourceLen: len(FilterMethods), - } - if len(FilterMethods) > 0 { - s.sourceFirst = FilterMethods[0] - } - return s -} - -// changed reports whether the deprecated FilterMethods variable -// has been mutated since this filterState was built. -func (s *filterState) changed() bool { - if len(FilterMethods) != s.sourceLen { - return true - } - if s.sourceLen > 0 && FilterMethods[0] != s.sourceFirst { - return true - } - return false -} - -// SetFilterMethods sets the list of method substrings to exclude from tracing/logging. -// It rebuilds the internal cache. Must be called during initialization, before -// the server starts. Not safe for concurrent use. -func SetFilterMethods(ctx context.Context, methods []string) { - // Defensive copy to prevent aliasing: if the caller later mutates - // their slice, it won't silently affect filtering. - cp := make([]string, len(methods)) - copy(cp, methods) - FilterMethods = cp - currentFilter.Store(buildFilterState()) -} - -// isGRPCRequest returns true if the context is a gRPC server context. -// Uses grpc.Method(ctx) which is a single context value lookup with zero -// allocations. HTTP handlers pass plain contexts where this returns false. -// This is used to decide whether to cache filter decisions — gRPC method -// names are stable and finite, while HTTP paths can be high-cardinality. -func isGRPCRequest(ctx context.Context) bool { - _, ok := grpc.Method(ctx) - return ok -} - -// FilterMethodsFunc is the default implementation of Filter function -func FilterMethodsFunc(ctx context.Context, fullMethodName string) bool { - f := currentFilter.Load() - // Auto-detect direct mutation of the deprecated FilterMethods variable. - if f.changed() { - f = buildFilterState() - currentFilter.Store(f) - } - cacheable := isGRPCRequest(ctx) - if cacheable { - if v, ok := f.cache.Load(fullMethodName); ok { - return v.(bool) - } - } - lowerMethod := strings.ToLower(fullMethodName) - result := true - for _, name := range f.methods { - if strings.Contains(lowerMethod, name) { - result = false - break - } - } - if cacheable { - f.cache.Store(fullMethodName, result) - } - return result -} - -// SetFilterFunc sets the default filter function to be used by interceptors. -// Must be called during initialization, before the server starts. Not safe for concurrent use. -func SetFilterFunc(ctx context.Context, ff FilterFunc) { - if ff != nil { - defaultFilterFunc = ff - } -} - -// AddUnaryServerInterceptor adds a server interceptor to default server interceptors. -// Must be called during initialization, before the server starts. Not safe for concurrent use. -func AddUnaryServerInterceptor(ctx context.Context, i ...grpc.UnaryServerInterceptor) { - unaryServerInterceptors = append(unaryServerInterceptors, i...) -} - -// AddStreamServerInterceptor adds a server interceptor to default server interceptors. -// Must be called during initialization, before the server starts. Not safe for concurrent use. -func AddStreamServerInterceptor(ctx context.Context, i ...grpc.StreamServerInterceptor) { - streamServerInterceptors = append(streamServerInterceptors, i...) -} - -// UseColdBrewServerInterceptors allows enabling/disabling coldbrew server interceptors. -// When set to false, the coldbrew server interceptors will not be used. -// Must be called during initialization, before the server starts. Not safe for concurrent use. -func UseColdBrewServerInterceptors(ctx context.Context, flag bool) { - useCBServerInterceptors = flag -} - -// AddUnaryClientInterceptor adds a client interceptor to default client interceptors. -// Must be called during initialization, before any RPCs are made. Not safe for concurrent use. -func AddUnaryClientInterceptor(ctx context.Context, i ...grpc.UnaryClientInterceptor) { - unaryClientInterceptors = append(unaryClientInterceptors, i...) -} - -// AddStreamClientInterceptor adds a client stream interceptor to default client stream interceptors. -// Must be called during initialization, before any RPCs are made. Not safe for concurrent use. -func AddStreamClientInterceptor(ctx context.Context, i ...grpc.StreamClientInterceptor) { - streamClientInterceptors = append(streamClientInterceptors, i...) -} - -// UseColdBrewClientInterceptors allows enabling/disabling coldbrew client interceptors. -// When set to false, the coldbrew client interceptors will not be used. -// Must be called during initialization, before any RPCs are made. Not safe for concurrent use. -func UseColdBrewClientInterceptors(ctx context.Context, flag bool) { - useCBClientInterceptors = flag -} - -// SetServerMetricsOptions appends gRPC server metrics options (histogram, labels, namespace, etc.). -// Must be called during initialization, before the server starts. Not safe for concurrent use. -func SetServerMetricsOptions(opts ...grpcprom.ServerMetricsOption) { - srvMetricsOpts = append(srvMetricsOpts, opts...) -} - -// SetClientMetricsOptions appends gRPC client metrics options. -// Must be called during initialization, before the server starts. Not safe for concurrent use. -func SetClientMetricsOptions(opts ...grpcprom.ClientMetricsOption) { - cltMetricsOpts = append(cltMetricsOpts, opts...) -} - -// SetProtoValidateOptions configures custom protovalidate options (e.g., -// custom constraints). Must be called during init() — not safe for -// concurrent use. Follows ColdBrew's init-only config pattern. -func SetProtoValidateOptions(opts ...protovalidate.ValidatorOption) { - protoValidateOpts = append(protoValidateOpts, opts...) -} - -// SetDisableProtoValidate disables the protovalidate interceptor in the -// default chain. Must be called during init() — not safe for concurrent use. -func SetDisableProtoValidate(disable bool) { - disableProtoValidate = disable -} - -// ProtoValidateInterceptor returns a unary server interceptor that validates -// incoming messages using protovalidate annotations. Returns InvalidArgument -// on validation failure. Uses GlobalValidator by default; if custom options -// are set via SetProtoValidateOptions, creates a new validator with those options. -func ProtoValidateInterceptor() grpc.UnaryServerInterceptor { - return protovalidate_middleware.UnaryServerInterceptor(getProtoValidator()) -} - -// ProtoValidateStreamInterceptor returns a stream server interceptor that -// validates incoming messages using protovalidate annotations. -func ProtoValidateStreamInterceptor() grpc.StreamServerInterceptor { - return protovalidate_middleware.StreamServerInterceptor(getProtoValidator()) -} - -var ( - protoValidatorOnce sync.Once - protoValidatorVal protovalidate.Validator -) - -// getProtoValidator returns a cached protovalidate.Validator configured with -// custom options if set, falling back to GlobalValidator. -func getProtoValidator() protovalidate.Validator { - protoValidatorOnce.Do(func() { - if len(protoValidateOpts) > 0 { - v, err := protovalidate.New(protoValidateOpts...) - if err != nil { - log.Error(context.Background(), "msg", "failed to create protovalidate validator with custom options, falling back to global", "err", err) - protoValidatorVal = protovalidate.GlobalValidator - return - } - protoValidatorVal = v - return - } - protoValidatorVal = protovalidate.GlobalValidator - }) - return protoValidatorVal -} - -func registerCollector(c prometheus.Collector) { - if err := prometheus.Register(c); err != nil { - var are prometheus.AlreadyRegisteredError - if stdError.As(err, &are) { - prometheus.Unregister(are.ExistingCollector) - if err := prometheus.Register(c); err != nil { - log.Warn(context.Background(), "msg", "failed to re-register gRPC metrics with Prometheus", "err", err) - } - return - } - log.Error(context.Background(), "msg", "gRPC Prometheus metrics registration failed. If you are using github.com/go-coldbrew/core, it may need to be updated to the latest version.", "err", err) - } -} - -func getServerMetrics() *grpcprom.ServerMetrics { - srvMetricsOnce.Do(func() { - srvMetrics = grpcprom.NewServerMetrics(srvMetricsOpts...) - registerCollector(srvMetrics) - }) - return srvMetrics -} - -func getClientMetrics() *grpcprom.ClientMetrics { - cltMetricsOnce.Do(func() { - cltMetrics = grpcprom.NewClientMetrics(cltMetricsOpts...) - registerCollector(cltMetrics) - }) - return cltMetrics -} - -// chainUnaryServer chains multiple unary server interceptors into one. -func chainUnaryServer(interceptors []grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - chain := handler - for i := len(interceptors) - 1; i >= 0; i-- { - interceptor := interceptors[i] - next := chain - chain = func(ctx context.Context, req any) (any, error) { - return interceptor(ctx, req, info, next) - } - } - return chain(ctx, req) - } -} - -// chainUnaryClient chains multiple unary client interceptors into one. -func chainUnaryClient(interceptors []grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor { - return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - chain := invoker - for i := len(interceptors) - 1; i >= 0; i-- { - interceptor := interceptors[i] - next := chain - chain = func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error { - return interceptor(ctx, method, req, reply, cc, next, opts...) - } - } - return chain(ctx, method, req, reply, cc, opts...) - } -} - -// chainStreamClient chains multiple stream client interceptors into one. -func chainStreamClient(interceptors []grpc.StreamClientInterceptor) grpc.StreamClientInterceptor { - return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { - chain := streamer - for i := len(interceptors) - 1; i >= 0; i-- { - interceptor := interceptors[i] - next := chain - chain = func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { - return interceptor(ctx, desc, cc, method, next, opts...) - } - } - return chain(ctx, desc, cc, method, opts...) - } -} - -var ( - httpToGRPCOnce sync.Once - httpToGRPCInterceptor grpc.UnaryServerInterceptor -) - -func getHTTPtoGRPCInterceptor() grpc.UnaryServerInterceptor { - httpToGRPCOnce.Do(func() { - httpToGRPCInterceptor = chainUnaryServer(DefaultInterceptors()) - }) - return httpToGRPCInterceptor -} - -// DoHTTPtoGRPC allows calling the interceptors when you use the RegisterHandlerServer in grpc-gateway. -// This enables in-process HTTP-to-gRPC calls with the full interceptor chain (logging, tracing, metrics, -// panic recovery) without a network hop — the fastest option for gateway performance. -// The interceptor chain is cached on first invocation. All interceptor configuration -// (AddUnaryServerInterceptor, SetFilterFunc, etc.) must be finalized before the first call. -// See example below for reference. -// -// func (s *svc) Echo(ctx context.Context, req *proto.EchoRequest) (*proto.EchoResponse, error) { -// handler := func(ctx context.Context, req interface{}) (interface{}, error) { -// return s.echo(ctx, req.(*proto.EchoRequest)) -// } -// r, err := DoHTTPtoGRPC(ctx, s, handler, req) -// if err != nil { -// return nil, err -// } -// return r.(*proto.EchoResponse), nil -// } -// -// func (s *svc) echo(ctx context.Context, req *proto.EchoRequest) (*proto.EchoResponse, error) { -// .... implementation .... -// } -func DoHTTPtoGRPC(ctx context.Context, svr any, handler func(ctx context.Context, req any) (any, error), in any) (any, error) { - method, ok := runtime.RPCMethod(ctx) - if ok { - interceptor := getHTTPtoGRPCInterceptor() - info := &grpc.UnaryServerInfo{ - Server: svr, - FullMethod: method, - } - return interceptor(ctx, in, info, handler) - } - return handler(ctx, in) -} - -// DefaultInterceptors are the set of default interceptors that are applied to all coldbrew methods -func DefaultInterceptors() []grpc.UnaryServerInterceptor { - ints := []grpc.UnaryServerInterceptor{} - if len(unaryServerInterceptors) > 0 { - ints = append(ints, unaryServerInterceptors...) - } - if useCBServerInterceptors { - ints = append(ints, DefaultTimeoutInterceptor()) - if !disableRateLimit { - if limiter := getRateLimiter(); limiter != nil { - ints = append(ints, ratelimit_middleware.UnaryServerInterceptor(limiter)) - } - } - ints = append(ints, - ResponseTimeLoggingInterceptor(defaultFilterFunc), - TraceIdInterceptor(), - ) - if !disableDebugLogInterceptor { - ints = append(ints, DebugLogInterceptor()) - } - if !disableProtoValidate { - ints = append(ints, ProtoValidateInterceptor()) - } - ints = append(ints, - getServerMetrics().UnaryServerInterceptor(), - ServerErrorInterceptor(), - NewRelicInterceptor(), - PanicRecoveryInterceptor(), - ) - } - return ints -} - -// DefaultClientInterceptors are the set of default interceptors that should be applied to all client calls -func DefaultClientInterceptors(defaultOpts ...any) []grpc.UnaryClientInterceptor { - ints := []grpc.UnaryClientInterceptor{} - if len(unaryClientInterceptors) > 0 { - ints = append(ints, unaryClientInterceptors...) - } - if useCBClientInterceptors { - hystrixOptions := make([]grpc.CallOption, 0) - for _, opt := range defaultOpts { - if opt == nil { - continue - } - if o, ok := opt.(grpc.CallOption); ok { - hystrixOptions = append(hystrixOptions, o) - } - } - ints = append(ints, - HystrixClientInterceptor(hystrixOptions...), - grpc_retry.UnaryClientInterceptor(), - NewRelicClientInterceptor(), - getClientMetrics().UnaryClientInterceptor(), - ) - } - return ints -} - -// DefaultClientStreamInterceptors are the set of default interceptors that should be applied to all stream client calls -func DefaultClientStreamInterceptors(defaultOpts ...any) []grpc.StreamClientInterceptor { - ints := []grpc.StreamClientInterceptor{} - if len(streamClientInterceptors) > 0 { - ints = append(ints, streamClientInterceptors...) - } - if useCBClientInterceptors { - if nrutil.GetNewRelicApp() != nil { - ints = append(ints, nrgrpc.StreamClientInterceptor) - } - ints = append(ints, getClientMetrics().StreamClientInterceptor()) - } - return ints -} - -// DefaultStreamInterceptors are the set of default interceptors that should be applied to all coldbrew streams -func DefaultStreamInterceptors() []grpc.StreamServerInterceptor { - ints := []grpc.StreamServerInterceptor{} - if len(streamServerInterceptors) > 0 { - ints = append(ints, streamServerInterceptors...) - } - if useCBServerInterceptors { - if !disableRateLimit { - if limiter := getRateLimiter(); limiter != nil { - ints = append(ints, ratelimit_middleware.StreamServerInterceptor(limiter)) - } - } - ints = append(ints, - ResponseTimeLoggingStreamInterceptor(), - ) - if !disableProtoValidate { - ints = append(ints, ProtoValidateStreamInterceptor()) - } - ints = append(ints, - getServerMetrics().StreamServerInterceptor(), - ServerErrorStreamInterceptor(), - ) - } - return ints -} - -// DefaultClientInterceptor are the set of default interceptors that should be applied to all client calls -func DefaultClientInterceptor(defaultOpts ...any) grpc.UnaryClientInterceptor { - return chainUnaryClient(DefaultClientInterceptors(defaultOpts...)) -} - -// DefaultClientStreamInterceptor are the set of default interceptors that should be applied to all stream client calls -func DefaultClientStreamInterceptor(defaultOpts ...any) grpc.StreamClientInterceptor { - return chainStreamClient(DefaultClientStreamInterceptors(defaultOpts...)) -} - -// DebugLoggingInterceptor is the interceptor that logs all request/response from a handler -func DebugLoggingInterceptor() grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - log.Debug(ctx, "method", info.FullMethod, "request", req) - resp, err := handler(ctx, req) - log.Debug(ctx, "method", info.FullMethod, "response", resp, "err", err) - return resp, err - } -} - -// DefaultTimeoutInterceptor returns a unary server interceptor that applies a -// default deadline to incoming requests that have no deadline set. If the -// incoming context already has a deadline (regardless of duration), it is left -// unchanged. When defaultTimeout is <= 0, the interceptor is a no-op pass-through. -func DefaultTimeoutInterceptor() grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - if defaultTimeout <= 0 { - return handler(ctx, req) - } - if _, ok := ctx.Deadline(); ok { - return handler(ctx, req) - } - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() - return handler(ctx, req) - } -} - -// ResponseTimeLoggingInterceptor logs response time for each request on server -func ResponseTimeLoggingInterceptor(ff FilterFunc) grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { - ctx = loggers.AddToLogContext(ctx, "grpcMethod", info.FullMethod) - defer func(ctx context.Context, method string, begin time.Time) { - if ff != nil && !ff(ctx, method) { - return - } - if responseTimeLogErrorOnly && err == nil { - return - } - logArgs := make([]any, 0, 6) - logArgs = append(logArgs, "error", err, "took", time.Since(begin)) - if err != nil { - logArgs = append(logArgs, "grpcCode", status.Code(err)) - } - log.GetLogger().Log(ctx, responseTimeLogLevel, 1, logArgs...) - }(ctx, info.FullMethod, time.Now()) - resp, err = handler(ctx, req) - return resp, err - } -} - -func OptionsInterceptor() grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - ctx = options.AddToOptions(ctx, "", "") - // loggers.AddToLogContext(ctx, "transport", "gRPC") - return handler(ctx, req) - } -} - -// NewRelicInterceptor intercepts all server actions and reports them to newrelic. -// When NewRelic app is nil (no license key configured), returns a pass-through -// interceptor to avoid overhead. -func NewRelicInterceptor() grpc.UnaryServerInterceptor { - app := nrutil.GetNewRelicApp() - if app == nil { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - return handler(ctx, req) - } - } - nrh := nrgrpc.UnaryServerInterceptor(app) - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { - if defaultFilterFunc(ctx, info.FullMethod) { - return nrh(ctx, req, info, handler) - } else { - return handler(ctx, req) - } - } -} - -// ServerErrorInterceptor intercepts all server actions and reports them to error notifier -func ServerErrorInterceptor() grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { - // set trace id if not set - ctx, _ = notifier.SetTraceIdWithValue(ctx) - start := time.Now() - resp, err = handler(ctx, req) - if err != nil && defaultFilterFunc(ctx, info.FullMethod) { - _ = notifier.NotifyAsync(err, ctx, notifier.Tags{ - "grpcMethod": info.FullMethod, - "duration": time.Since(start).Truncate(time.Millisecond).String(), - }) - } - return resp, err - } -} - -func PanicRecoveryInterceptor() grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { - defer func(ctx context.Context) { - // panic handler - if r := recover(); r != nil { - stack := string(debug.Stack()) - log.Error(ctx, "panic", r, "method", info.FullMethod, "stack", stack) - if e, ok := r.(error); ok { - err = e - } else { - err = errors.New(fmt.Sprintf("panic: %s", r)) - } - nrutil.FinishNRTransaction(ctx, err) - _ = notifier.NotifyWithLevel(err, "critical", info.FullMethod, ctx, stack) - } - }(ctx) - - resp, err = handler(ctx, req) - return resp, err - } -} - -// NewRelicClientInterceptor intercepts all client actions and reports them to newrelic. -// When NewRelic app is nil (no license key configured), returns a pass-through -// interceptor to avoid overhead. -func NewRelicClientInterceptor() grpc.UnaryClientInterceptor { - app := nrutil.GetNewRelicApp() - if app == nil { - return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - return invoker(ctx, method, req, reply, cc, opts...) - } - } - return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - if defaultFilterFunc(ctx, method) { - return nrgrpc.UnaryClientInterceptor(ctx, method, req, reply, cc, invoker, opts...) - } else { - return invoker(ctx, method, req, reply, cc, opts...) - } - } -} - -// Deprecated: GRPCClientInterceptor is no longer needed. gRPC tracing is now handled -// by google.golang.org/grpc/stats/opentelemetry, configured via opentelemetry.DialOption() -// at the client level. This function is retained for backwards compatibility but returns -// a no-op interceptor. -func GRPCClientInterceptor(_ ...any) grpc.UnaryClientInterceptor { - return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - return invoker(ctx, method, req, reply, cc, opts...) - } -} - -// HystrixClientInterceptor returns a unary client interceptor that executes the RPC inside a Hystrix command. -// -// Note: This interceptor wraps github.com/afex/hystrix-go which has been unmaintained since 2018. -// Consider migrating to github.com/failsafe-go/failsafe-go for circuit breaker functionality. -// -// The interceptor applies provided default and per-call client options to configure Hystrix behavior (for example the command name, disabled flag, excluded errors, and excluded gRPC status codes). -// If Hystrix is disabled via options, the RPC is invoked directly. If the underlying RPC returns an error that matches any configured excluded error or whose gRPC status code matches any configured excluded code, Hystrix fallback is skipped and the RPC error is returned. -// Panics raised during the RPC invocation are captured and reported to the notifier before being converted into an error. If the RPC itself returns an error, that error is returned; otherwise any error produced by Hystrix is returned. -func HystrixClientInterceptor(defaultOpts ...grpc.CallOption) grpc.UnaryClientInterceptor { - return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - options := clientOptions{ - hystrixName: method, - } - for _, opt := range defaultOpts { - if opt != nil { - if o, ok := opt.(clientOption); ok { - o.process(&options) - } - } - } - for _, opt := range opts { - if opt != nil { - if o, ok := opt.(clientOption); ok { - o.process(&options) - } - } - } - if options.disableHystrix { - // short circuit if hystrix is disabled - return invoker(ctx, method, req, reply, cc, opts...) - } - newCtx, cancel := context.WithCancel(ctx) - defer cancel() - - var invokerErr error - hystrixErr := hystrix.Do(options.hystrixName, func() (err error) { - defer func() { - if r := recover(); r != nil { - err = errors.Wrap(fmt.Errorf("panic inside hystrix method: %s, req: %v, reply: %v", method, req, reply), "Hystrix") - log.Error(ctx, "panic", r, "method", method, "req", req, "reply", reply) - } - }() - defer notifier.NotifyOnPanic(newCtx, method) - invokerErr = invoker(newCtx, method, req, reply, cc, opts...) - for _, excludedErr := range options.excludedErrors { - if stdError.Is(invokerErr, excludedErr) { - return nil - } - } - if st, ok := status.FromError(invokerErr); ok { - if slices.Contains(options.excludedCodes, st.Code()) { - return nil - } - } - return invokerErr - }, nil) - if invokerErr != nil { - return invokerErr - } - return hystrixErr - } -} - -// ResponseTimeLoggingStreamInterceptor logs response time for stream RPCs. -func ResponseTimeLoggingStreamInterceptor() grpc.StreamServerInterceptor { - return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { - defer func(begin time.Time) { - if responseTimeLogErrorOnly && err == nil { - return - } - logArgs := make([]any, 0, 8) - logArgs = append(logArgs, "method", info.FullMethod, "error", err, "took", time.Since(begin)) - if err != nil { - logArgs = append(logArgs, "grpcCode", status.Code(err)) - } - log.GetLogger().Log(stream.Context(), responseTimeLogLevel, 1, logArgs...) - }(time.Now()) - err = handler(srv, stream) - return err - } -} - -// ServerErrorStreamInterceptor intercepts server errors for stream RPCs and -// reports them to the error notifier. -func ServerErrorStreamInterceptor() grpc.StreamServerInterceptor { - return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { - ctx := stream.Context() - ctx, _ = notifier.SetTraceIdWithValue(ctx) - start := time.Now() - err = handler(srv, stream) - if err != nil && defaultFilterFunc(ctx, info.FullMethod) { - _ = notifier.NotifyAsync(err, ctx, notifier.Tags{ - "grpcMethod": info.FullMethod, - "duration": time.Since(start).Truncate(time.Millisecond).String(), - }) - } - return err - } -} - -// NRHttpTracer adds newrelic tracing to this http function -func NRHttpTracer(pattern string, h http.HandlerFunc) (string, http.HandlerFunc) { - app := nrutil.GetNewRelicApp() - if app == nil { - return pattern, h - } - if pattern != "" { - return newrelic.WrapHandleFunc(app, pattern, h) - } - return pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // filter functions we do not need - if defaultFilterFunc(context.Background(), r.URL.Path) { - txn := app.StartTransaction(r.Method + " " + r.URL.Path) - defer txn.End() - w = txn.SetWebResponse(w) - txn.SetWebRequestHTTP(r) - r = newrelic.RequestWithTransactionContext(r, txn) - } - h.ServeHTTP(w, r) - }) -} - -// TraceIdInterceptor allows injecting trace id from request objects -func TraceIdInterceptor() grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { - if req != nil { - // fetch and update trace id from request - if r, ok := req.(interface{ GetTraceId() string }); ok { - ctx = notifier.UpdateTraceId(ctx, r.GetTraceId()) - } else if r, ok := req.(interface{ GetTraceID() string }); ok { - ctx = notifier.UpdateTraceId(ctx, r.GetTraceID()) - } - } - return handler(ctx, req) - } -} - -// SetDisableDebugLogInterceptor disables the DebugLogInterceptor in the default -// interceptor chain. Must be called during initialization, before the server starts. -func SetDisableDebugLogInterceptor(disable bool) { - disableDebugLogInterceptor = disable -} - -// SetDebugLogHeaderName sets the gRPC metadata header name that triggers -// per-request log level override. Default is "x-debug-log-level". The header -// value should be a valid log level (e.g., "debug"). Empty names are ignored. -// Must be called during initialization. -func SetDebugLogHeaderName(name string) { - name = strings.ToLower(strings.TrimSpace(name)) - if name == "" { - return - } - debugLogHeaderName = name -} - -// GetDebugLogHeaderName returns the current debug log header name. -func GetDebugLogHeaderName() string { - return debugLogHeaderName -} - -// DebugLogInterceptor enables per-request log level override based on a proto -// field or gRPC metadata header. It checks (in order): -// 1. Proto field: GetDebug() bool or GetEnableDebug() bool — always sets DebugLevel -// 2. Metadata header: configurable via SetDebugLogHeaderName (default "x-debug-log-level") -// — the header value is parsed as a log level, allowing any valid level (debug, info, warn, error) -// -// Combined with ColdBrew's trace ID propagation, this allows enabling debug -// logging for a single request and following it across services via trace ID. -func DebugLogInterceptor() grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { - // Check proto field first - if req != nil { - if r, ok := req.(interface{ GetDebug() bool }); ok && r.GetDebug() { - ctx = log.OverrideLogLevel(ctx, loggers.DebugLevel) - return handler(ctx, req) - } - if r, ok := req.(interface{ GetEnableDebug() bool }); ok && r.GetEnableDebug() { - ctx = log.OverrideLogLevel(ctx, loggers.DebugLevel) - return handler(ctx, req) - } - } - // Check gRPC metadata header - if md, ok := metadata.FromIncomingContext(ctx); ok { - if vals := md.Get(debugLogHeaderName); len(vals) > 0 { - if level, err := loggers.ParseLevel(vals[0]); err == nil { - ctx = log.OverrideLogLevel(ctx, level) - } - } - } - return handler(ctx, req) - } -} - -// SetDisableRateLimit disables the rate limiting interceptor in the default -// interceptor chain. Must be called during initialization. -func SetDisableRateLimit(disable bool) { - disableRateLimit = disable -} - -// SetRateLimiter sets a custom rate limiter implementation. This overrides the -// built-in token bucket limiter. Must be called during initialization. -func SetRateLimiter(limiter ratelimit_middleware.Limiter) { - rateLimiter = limiter -} - -// SetDefaultRateLimit configures the built-in token bucket rate limiter. -// rps is requests per second, burst is the maximum burst size. -// This is a per-pod in-memory limit — with N pods, the effective cluster-wide -// limit is N × rps. For distributed rate limiting, use SetRateLimiter() with -// a custom implementation (e.g., Redis-backed). -// Must be called during initialization. -func SetDefaultRateLimit(rps float64, burst int) { - defaultRateLimit = rate.Limit(rps) - if burst < 1 { - burst = 1 - } - defaultRateBurst = burst -} - -// tokenBucketLimiter wraps golang.org/x/time/rate.Limiter to implement -// the ratelimit.Limiter interface. -type tokenBucketLimiter struct { - limiter *rate.Limiter -} - -func (l *tokenBucketLimiter) Limit(_ context.Context) error { - if !l.limiter.Allow() { - return fmt.Errorf("rate limit exceeded") - } - return nil -} - -func getRateLimiter() ratelimit_middleware.Limiter { - rateLimiterOnce.Do(func() { - if rateLimiter != nil { - rateLimiterVal = rateLimiter - return - } - if defaultRateLimit == rate.Inf { - rateLimiterVal = nil - return - } - rateLimiterVal = &tokenBucketLimiter{ - limiter: rate.NewLimiter(defaultRateLimit, defaultRateBurst), - } - }) - return rateLimiterVal -} diff --git a/interceptors_test.go b/interceptors_test.go index 1f093c8..84988b3 100644 --- a/interceptors_test.go +++ b/interceptors_test.go @@ -14,7 +14,6 @@ import ( "github.com/go-coldbrew/log/loggers" ratelimit_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "golang.org/x/time/rate" "google.golang.org/grpc" "google.golang.org/grpc/codes" grpcmd "google.golang.org/grpc/metadata" @@ -37,28 +36,19 @@ func grpcContext() context.Context { // resetGlobals restores package-level state so tests don't interfere with each other. func resetGlobals() { + defaultConfig = newDefaultConfig() FilterMethods = []string{"healthcheck", "readycheck", "serverreflectioninfo"} currentFilter.Store(buildFilterState()) - defaultFilterFunc = FilterMethodsFunc - unaryServerInterceptors = []grpc.UnaryServerInterceptor{} - streamServerInterceptors = []grpc.StreamServerInterceptor{} - useCBServerInterceptors = true - unaryClientInterceptors = []grpc.UnaryClientInterceptor{} - streamClientInterceptors = []grpc.StreamClientInterceptor{} - useCBClientInterceptors = true - responseTimeLogErrorOnly = false - responseTimeLogLevel = loggers.InfoLevel - defaultTimeout = 60 * time.Second httpToGRPCOnce = sync.Once{} httpToGRPCInterceptor = nil - disableDebugLogInterceptor = false - debugLogHeaderName = "x-debug-log-level" - disableRateLimit = false - rateLimiter = nil rateLimiterOnce = sync.Once{} rateLimiterVal = nil - defaultRateLimit = rate.Inf - defaultRateBurst = 0 + srvMetricsOnce = sync.Once{} + srvMetrics = nil + cltMetricsOnce = sync.Once{} + cltMetrics = nil + protoValidatorOnce = sync.Once{} + protoValidatorVal = nil } func TestFilterMethodsFunc(t *testing.T) { @@ -95,18 +85,18 @@ func TestSetFilterFunc(t *testing.T) { } SetFilterFunc(ctx, custom) - if defaultFilterFunc(ctx, "allow") != true { + if defaultConfig.filterFunc(ctx, "allow") != true { t.Error("custom filter should return true for 'allow'") } - if defaultFilterFunc(ctx, "deny") != false { + if defaultConfig.filterFunc(ctx, "deny") != false { t.Error("custom filter should return false for 'deny'") } // Setting nil should not change the filter. - prev := defaultFilterFunc + prev := defaultConfig.filterFunc SetFilterFunc(ctx, nil) // We can't compare funcs directly, so just verify behaviour is unchanged. - if defaultFilterFunc(ctx, "allow") != prev(ctx, "allow") { + if defaultConfig.filterFunc(ctx, "allow") != prev(ctx, "allow") { t.Error("SetFilterFunc(nil) should not change the filter") } } @@ -677,12 +667,12 @@ func BenchmarkResponseTimeLogging(b *testing.B) { resetGlobals() // Use debug level — the slog default logger discards debug, so we // measure interceptor + log-args-building overhead without I/O noise. - responseTimeLogLevel = loggers.DebugLevel + defaultConfig.responseTimeLogLevel = loggers.DebugLevel ctx := grpcContext() ff := FilterMethodsFunc b.Run("default/success", func(b *testing.B) { - responseTimeLogErrorOnly = false + defaultConfig.responseTimeLogErrorOnly = false interceptor := ResponseTimeLoggingInterceptor(ff) b.ResetTimer() b.ReportAllocs() @@ -692,7 +682,7 @@ func BenchmarkResponseTimeLogging(b *testing.B) { }) b.Run("default/error", func(b *testing.B) { - responseTimeLogErrorOnly = false + defaultConfig.responseTimeLogErrorOnly = false interceptor := ResponseTimeLoggingInterceptor(ff) b.ResetTimer() b.ReportAllocs() @@ -702,7 +692,7 @@ func BenchmarkResponseTimeLogging(b *testing.B) { }) b.Run("error_only/success", func(b *testing.B) { - responseTimeLogErrorOnly = true + defaultConfig.responseTimeLogErrorOnly = true interceptor := ResponseTimeLoggingInterceptor(ff) b.ResetTimer() b.ReportAllocs() @@ -712,7 +702,7 @@ func BenchmarkResponseTimeLogging(b *testing.B) { }) b.Run("error_only/error", func(b *testing.B) { - responseTimeLogErrorOnly = true + defaultConfig.responseTimeLogErrorOnly = true interceptor := ResponseTimeLoggingInterceptor(ff) b.ResetTimer() b.ReportAllocs() @@ -722,7 +712,7 @@ func BenchmarkResponseTimeLogging(b *testing.B) { }) // Restore default. - responseTimeLogErrorOnly = false + defaultConfig.responseTimeLogErrorOnly = false } func BenchmarkDefaultInterceptors(b *testing.B) { @@ -1435,3 +1425,111 @@ func TestRateLimitInterceptor_Disabled(t *testing.T) { } } } + +// mockServerStream implements grpc.ServerStream for testing stream interceptors. +type mockServerStream struct { + grpc.ServerStream + ctx context.Context +} + +func (s *mockServerStream) Context() context.Context { return s.ctx } + +func TestPanicRecoveryStreamInterceptor_NoPanic(t *testing.T) { + interceptor := PanicRecoveryStreamInterceptor() + stream := &mockServerStream{ctx: context.Background()} + info := &grpc.StreamServerInfo{FullMethod: "/test.Svc/Stream"} + + called := false + err := interceptor(nil, stream, info, func(_ any, _ grpc.ServerStream) error { + called = true + return nil + }) + if !called { + t.Fatal("handler should have been called") + } + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } +} + +func TestPanicRecoveryStreamInterceptor_Panic(t *testing.T) { + interceptor := PanicRecoveryStreamInterceptor() + stream := &mockServerStream{ctx: context.Background()} + info := &grpc.StreamServerInfo{FullMethod: "/test.Svc/Stream"} + + err := interceptor(nil, stream, info, func(_ any, _ grpc.ServerStream) error { + panic("test panic") + }) + if err == nil { + t.Fatal("expected error from panic recovery") + } + if err.Error() != "panic: test panic" { + t.Fatalf("expected 'panic: test panic', got %q", err.Error()) + } +} + +func TestPanicRecoveryStreamInterceptor_PanicWithError(t *testing.T) { + interceptor := PanicRecoveryStreamInterceptor() + stream := &mockServerStream{ctx: context.Background()} + info := &grpc.StreamServerInfo{FullMethod: "/test.Svc/Stream"} + origErr := errors.New("original error") + + err := interceptor(nil, stream, info, func(_ any, _ grpc.ServerStream) error { + panic(origErr) + }) + if err != origErr { + t.Fatalf("expected original error, got %v", err) + } +} + +func TestServerErrorStreamInterceptor_ContextWrapped(t *testing.T) { + resetGlobals() + interceptor := ServerErrorStreamInterceptor() + stream := &mockServerStream{ctx: context.Background()} + info := &grpc.StreamServerInfo{FullMethod: "/test.Svc/Stream"} + + var handlerCtx context.Context + _ = interceptor(nil, stream, info, func(_ any, s grpc.ServerStream) error { + handlerCtx = s.Context() + return nil + }) + // The handler should receive a wrapped stream with trace ID set + if handlerCtx == nil { + t.Fatal("handler context should not be nil") + } + // Context should differ from original (trace ID was added) + if handlerCtx == context.Background() { + t.Fatal("handler should receive a wrapped context, not the original") + } +} + +func TestServerErrorStreamInterceptor_Error(t *testing.T) { + resetGlobals() + interceptor := ServerErrorStreamInterceptor() + stream := &mockServerStream{ctx: context.Background()} + info := &grpc.StreamServerInfo{FullMethod: "/test.Svc/Stream"} + testErr := errors.New("stream error") + + err := interceptor(nil, stream, info, func(_ any, _ grpc.ServerStream) error { + return testErr + }) + if err != testErr { + t.Fatalf("expected test error, got %v", err) + } +} + +func TestDefaultStreamInterceptors_IncludesPanicRecovery(t *testing.T) { + resetGlobals() + ints := DefaultStreamInterceptors() + // Verify the chain handles panics by running a panicking handler + chain := ints[len(ints)-1] // PanicRecoveryStreamInterceptor is last + stream := &mockServerStream{ctx: context.Background()} + info := &grpc.StreamServerInfo{FullMethod: "/test.Svc/Stream"} + + err := chain(nil, stream, info, func(_ any, _ grpc.ServerStream) error { + panic("chain panic test") + }) + if err == nil { + t.Fatal("expected error from panic recovery in default stream chain") + } +} diff --git a/metrics.go b/metrics.go new file mode 100644 index 0000000..a3deacd --- /dev/null +++ b/metrics.go @@ -0,0 +1,48 @@ +package interceptors + +import ( + "context" + stdError "errors" + "sync" + + "github.com/go-coldbrew/log" + grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" + "github.com/prometheus/client_golang/prometheus" +) + +var ( + srvMetricsOnce sync.Once + srvMetrics *grpcprom.ServerMetrics + cltMetricsOnce sync.Once + cltMetrics *grpcprom.ClientMetrics +) + +func registerCollector(c prometheus.Collector) { + if err := prometheus.Register(c); err != nil { + var are prometheus.AlreadyRegisteredError + if stdError.As(err, &are) { + prometheus.Unregister(are.ExistingCollector) + if err := prometheus.Register(c); err != nil { + log.Warn(context.Background(), "msg", "failed to re-register gRPC metrics with Prometheus", "err", err) + } + return + } + log.Error(context.Background(), "msg", "gRPC Prometheus metrics registration failed. If you are using github.com/go-coldbrew/core, it may need to be updated to the latest version.", "err", err) + } +} + +func getServerMetrics() *grpcprom.ServerMetrics { + srvMetricsOnce.Do(func() { + srvMetrics = grpcprom.NewServerMetrics(defaultConfig.srvMetricsOpts...) + registerCollector(srvMetrics) + }) + return srvMetrics +} + +func getClientMetrics() *grpcprom.ClientMetrics { + cltMetricsOnce.Do(func() { + cltMetrics = grpcprom.NewClientMetrics(defaultConfig.cltMetricsOpts...) + registerCollector(cltMetrics) + }) + return cltMetrics +} diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 0000000..63f92c6 --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,45 @@ +package interceptors + +import ( + "context" + "fmt" + "sync" + + ratelimit_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit" + "golang.org/x/time/rate" +) + +var ( + rateLimiterOnce sync.Once + rateLimiterVal ratelimit_middleware.Limiter +) + +// tokenBucketLimiter wraps golang.org/x/time/rate.Limiter to implement +// the ratelimit.Limiter interface. +type tokenBucketLimiter struct { + limiter *rate.Limiter +} + +func (l *tokenBucketLimiter) Limit(_ context.Context) error { + if !l.limiter.Allow() { + return fmt.Errorf("rate limit exceeded") + } + return nil +} + +func getRateLimiter() ratelimit_middleware.Limiter { + rateLimiterOnce.Do(func() { + if defaultConfig.rateLimiter != nil { + rateLimiterVal = defaultConfig.rateLimiter + return + } + if defaultConfig.defaultRateLimit == rate.Inf { + rateLimiterVal = nil + return + } + rateLimiterVal = &tokenBucketLimiter{ + limiter: rate.NewLimiter(defaultConfig.defaultRateLimit, defaultConfig.defaultRateBurst), + } + }) + return rateLimiterVal +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..6db7174 --- /dev/null +++ b/server.go @@ -0,0 +1,353 @@ +package interceptors + +import ( + "context" + "fmt" + "runtime/debug" + "sync" + "time" + + "buf.build/go/protovalidate" + "github.com/go-coldbrew/errors" + "github.com/go-coldbrew/errors/notifier" + "github.com/go-coldbrew/log" + "github.com/go-coldbrew/log/loggers" + "github.com/go-coldbrew/options" + nrutil "github.com/go-coldbrew/tracing/newrelic" + protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate" + ratelimit_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit" + "github.com/newrelic/go-agent/v3/integrations/nrgrpc" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +var ( + protoValidatorOnce sync.Once + protoValidatorVal protovalidate.Validator +) + +// getProtoValidator returns a cached protovalidate.Validator configured with +// custom options if set, falling back to GlobalValidator. +func getProtoValidator() protovalidate.Validator { + protoValidatorOnce.Do(func() { + if len(defaultConfig.protoValidateOpts) > 0 { + v, err := protovalidate.New(defaultConfig.protoValidateOpts...) + if err != nil { + log.Error(context.Background(), "msg", "failed to create protovalidate validator with custom options, falling back to global", "err", err) + protoValidatorVal = protovalidate.GlobalValidator + return + } + protoValidatorVal = v + return + } + protoValidatorVal = protovalidate.GlobalValidator + }) + return protoValidatorVal +} + +// ProtoValidateInterceptor returns a unary server interceptor that validates +// incoming messages using protovalidate annotations. Returns InvalidArgument +// on validation failure. Uses GlobalValidator by default; if custom options +// are set via SetProtoValidateOptions, creates a new validator with those options. +func ProtoValidateInterceptor() grpc.UnaryServerInterceptor { + return protovalidate_middleware.UnaryServerInterceptor(getProtoValidator()) +} + +// ProtoValidateStreamInterceptor returns a stream server interceptor that +// validates incoming messages using protovalidate annotations. +func ProtoValidateStreamInterceptor() grpc.StreamServerInterceptor { + return protovalidate_middleware.StreamServerInterceptor(getProtoValidator()) +} + +// DefaultInterceptors are the set of default interceptors that are applied to all coldbrew methods +func DefaultInterceptors() []grpc.UnaryServerInterceptor { + ints := []grpc.UnaryServerInterceptor{} + if len(defaultConfig.unaryServerInterceptors) > 0 { + ints = append(ints, defaultConfig.unaryServerInterceptors...) + } + if defaultConfig.useCBServerInterceptors { + ints = append(ints, DefaultTimeoutInterceptor()) + if !defaultConfig.disableRateLimit { + if limiter := getRateLimiter(); limiter != nil { + ints = append(ints, ratelimit_middleware.UnaryServerInterceptor(limiter)) + } + } + ints = append(ints, + ResponseTimeLoggingInterceptor(defaultConfig.filterFunc), + TraceIdInterceptor(), + ) + if !defaultConfig.disableDebugLogInterceptor { + ints = append(ints, DebugLogInterceptor()) + } + if !defaultConfig.disableProtoValidate { + ints = append(ints, ProtoValidateInterceptor()) + } + ints = append(ints, + getServerMetrics().UnaryServerInterceptor(), + ServerErrorInterceptor(), + NewRelicInterceptor(), + PanicRecoveryInterceptor(), + ) + } + return ints +} + +// DefaultStreamInterceptors are the set of default interceptors that should be applied to all coldbrew streams +func DefaultStreamInterceptors() []grpc.StreamServerInterceptor { + ints := []grpc.StreamServerInterceptor{} + if len(defaultConfig.streamServerInterceptors) > 0 { + ints = append(ints, defaultConfig.streamServerInterceptors...) + } + if defaultConfig.useCBServerInterceptors { + if !defaultConfig.disableRateLimit { + if limiter := getRateLimiter(); limiter != nil { + ints = append(ints, ratelimit_middleware.StreamServerInterceptor(limiter)) + } + } + ints = append(ints, + ResponseTimeLoggingStreamInterceptor(), + ) + if !defaultConfig.disableProtoValidate { + ints = append(ints, ProtoValidateStreamInterceptor()) + } + ints = append(ints, + getServerMetrics().StreamServerInterceptor(), + ServerErrorStreamInterceptor(), + PanicRecoveryStreamInterceptor(), + ) + } + return ints +} + +// DefaultTimeoutInterceptor returns a unary server interceptor that applies a +// default deadline to incoming requests that have no deadline set. If the +// incoming context already has a deadline (regardless of duration), it is left +// unchanged. When defaultTimeout is <= 0, the interceptor is a no-op pass-through. +func DefaultTimeoutInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + if defaultConfig.defaultTimeout <= 0 { + return handler(ctx, req) + } + if _, ok := ctx.Deadline(); ok { + return handler(ctx, req) + } + ctx, cancel := context.WithTimeout(ctx, defaultConfig.defaultTimeout) + defer cancel() + return handler(ctx, req) + } +} + +// ResponseTimeLoggingInterceptor logs response time for each request on server +func ResponseTimeLoggingInterceptor(ff FilterFunc) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + ctx = loggers.AddToLogContext(ctx, "grpcMethod", info.FullMethod) + defer func(ctx context.Context, method string, begin time.Time) { + if ff != nil && !ff(ctx, method) { + return + } + if defaultConfig.responseTimeLogErrorOnly && err == nil { + return + } + logArgs := make([]any, 0, 6) + logArgs = append(logArgs, "error", err, "took", time.Since(begin)) + if err != nil { + logArgs = append(logArgs, "grpcCode", status.Code(err)) + } + log.GetLogger().Log(ctx, defaultConfig.responseTimeLogLevel, 1, logArgs...) + }(ctx, info.FullMethod, time.Now()) + resp, err = handler(ctx, req) + return resp, err + } +} + +// ResponseTimeLoggingStreamInterceptor logs response time for stream RPCs. +func ResponseTimeLoggingStreamInterceptor() grpc.StreamServerInterceptor { + return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { + defer func(begin time.Time) { + if defaultConfig.responseTimeLogErrorOnly && err == nil { + return + } + logArgs := make([]any, 0, 8) + logArgs = append(logArgs, "method", info.FullMethod, "error", err, "took", time.Since(begin)) + if err != nil { + logArgs = append(logArgs, "grpcCode", status.Code(err)) + } + log.GetLogger().Log(stream.Context(), defaultConfig.responseTimeLogLevel, 1, logArgs...) + }(time.Now()) + err = handler(srv, stream) + return err + } +} + +func OptionsInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + ctx = options.AddToOptions(ctx, "", "") + return handler(ctx, req) + } +} + +// NewRelicInterceptor intercepts all server actions and reports them to newrelic. +// When NewRelic app is nil (no license key configured), returns a pass-through +// interceptor to avoid overhead. +func NewRelicInterceptor() grpc.UnaryServerInterceptor { + app := nrutil.GetNewRelicApp() + if app == nil { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + return handler(ctx, req) + } + } + nrh := nrgrpc.UnaryServerInterceptor(app) + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + if defaultConfig.filterFunc(ctx, info.FullMethod) { + return nrh(ctx, req, info, handler) + } else { + return handler(ctx, req) + } + } +} + +// ServerErrorInterceptor intercepts all server actions and reports them to error notifier +func ServerErrorInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + // set trace id if not set + ctx, _ = notifier.SetTraceIdWithValue(ctx) + start := time.Now() + resp, err = handler(ctx, req) + if err != nil && defaultConfig.filterFunc(ctx, info.FullMethod) { + _ = notifier.NotifyAsync(err, ctx, notifier.Tags{ + "grpcMethod": info.FullMethod, + "duration": time.Since(start).Truncate(time.Millisecond).String(), + }) + } + return resp, err + } +} + +// wrappedStream wraps a grpc.ServerStream to override its context. +type wrappedStream struct { + grpc.ServerStream + ctx context.Context +} + +func (w *wrappedStream) Context() context.Context { return w.ctx } + +// ServerErrorStreamInterceptor intercepts server errors for stream RPCs and +// reports them to the error notifier. +func ServerErrorStreamInterceptor() grpc.StreamServerInterceptor { + return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { + ctx := stream.Context() + ctx, _ = notifier.SetTraceIdWithValue(ctx) + start := time.Now() + err = handler(srv, &wrappedStream{ServerStream: stream, ctx: ctx}) + if err != nil && defaultConfig.filterFunc(ctx, info.FullMethod) { + _ = notifier.NotifyAsync(err, ctx, notifier.Tags{ + "grpcMethod": info.FullMethod, + "duration": time.Since(start).Truncate(time.Millisecond).String(), + }) + } + return err + } +} + +func PanicRecoveryInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + defer func(ctx context.Context) { + // panic handler + if r := recover(); r != nil { + stack := string(debug.Stack()) + log.Error(ctx, "panic", r, "method", info.FullMethod, "stack", stack) + if e, ok := r.(error); ok { + err = e + } else { + err = errors.New(fmt.Sprintf("panic: %v", r)) + } + nrutil.FinishNRTransaction(ctx, err) + _ = notifier.NotifyWithLevel(err, "critical", info.FullMethod, ctx, stack) + } + }(ctx) + + resp, err = handler(ctx, req) + return resp, err + } +} + +// PanicRecoveryStreamInterceptor recovers from panics in stream handlers, +// logs the panic and stack trace, and reports it to the error notifier. +func PanicRecoveryStreamInterceptor() grpc.StreamServerInterceptor { + return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { + defer func() { + if r := recover(); r != nil { + ctx := stream.Context() + stack := string(debug.Stack()) + log.Error(ctx, "panic", r, "method", info.FullMethod, "stack", stack) + if e, ok := r.(error); ok { + err = e + } else { + err = errors.New(fmt.Sprintf("panic: %v", r)) + } + nrutil.FinishNRTransaction(ctx, err) + _ = notifier.NotifyWithLevel(err, "critical", info.FullMethod, ctx, stack) + } + }() + return handler(srv, stream) + } +} + +// TraceIdInterceptor allows injecting trace id from request objects +func TraceIdInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + if req != nil { + // fetch and update trace id from request + if r, ok := req.(interface{ GetTraceId() string }); ok { + ctx = notifier.UpdateTraceId(ctx, r.GetTraceId()) + } else if r, ok := req.(interface{ GetTraceID() string }); ok { + ctx = notifier.UpdateTraceId(ctx, r.GetTraceID()) + } + } + return handler(ctx, req) + } +} + +// DebugLogInterceptor enables per-request log level override based on a proto +// field or gRPC metadata header. It checks (in order): +// 1. Proto field: GetDebug() bool or GetEnableDebug() bool — always sets DebugLevel +// 2. Metadata header: configurable via SetDebugLogHeaderName (default "x-debug-log-level") +// — the header value is parsed as a log level, allowing any valid level (debug, info, warn, error) +// +// Combined with ColdBrew's trace ID propagation, this allows enabling debug +// logging for a single request and following it across services via trace ID. +func DebugLogInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + // Check proto field first + if req != nil { + if r, ok := req.(interface{ GetDebug() bool }); ok && r.GetDebug() { + ctx = log.OverrideLogLevel(ctx, loggers.DebugLevel) + return handler(ctx, req) + } + if r, ok := req.(interface{ GetEnableDebug() bool }); ok && r.GetEnableDebug() { + ctx = log.OverrideLogLevel(ctx, loggers.DebugLevel) + return handler(ctx, req) + } + } + // Check gRPC metadata header + if md, ok := metadata.FromIncomingContext(ctx); ok { + if vals := md.Get(defaultConfig.debugLogHeaderName); len(vals) > 0 { + if level, err := loggers.ParseLevel(vals[0]); err == nil { + ctx = log.OverrideLogLevel(ctx, level) + } + } + } + return handler(ctx, req) + } +} + +// DebugLoggingInterceptor is the interceptor that logs all request/response from a handler +func DebugLoggingInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + log.Debug(ctx, "method", info.FullMethod, "request", req) + resp, err := handler(ctx, req) + log.Debug(ctx, "method", info.FullMethod, "response", resp, "err", err) + return resp, err + } +}