diff --git a/README.md b/README.md index e07764b..049cd65 100644 --- a/README.md +++ b/README.md @@ -133,7 +133,7 @@ func AddUnaryServerInterceptor(ctx context.Context, i ...grpc.UnaryServerInterce AddUnaryServerInterceptor adds a server interceptor to default server interceptors. Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [DebugLogInterceptor]() +## func [DebugLogInterceptor]() ```go func DebugLogInterceptor() grpc.UnaryServerInterceptor @@ -147,7 +147,7 @@ DebugLogInterceptor enables per\-request log level override based on a proto fie Combined with ColdBrew's trace ID propagation, this allows enabling debug logging for a single request and following it across services via trace ID. -## func [DebugLoggingInterceptor]() +## func [DebugLoggingInterceptor]() ```go func DebugLoggingInterceptor() grpc.UnaryServerInterceptor @@ -192,25 +192,25 @@ func DefaultClientStreamInterceptors(defaultOpts ...any) []grpc.StreamClientInte DefaultClientStreamInterceptors are the set of default interceptors that should be applied to all stream client calls -## func [DefaultInterceptors]() +## func [DefaultInterceptors]() ```go func DefaultInterceptors() []grpc.UnaryServerInterceptor ``` -DefaultInterceptors are the set of default interceptors that are applied to all coldbrew methods +DefaultInterceptors returns the default unary server interceptor chain. The ordering is defined by the unaryPos\* constants above; this function assigns each interceptor to its named slot and drops any slot that is disabled via configuration. See the ordering contract above for semantics. -## func [DefaultStreamInterceptors]() +## func [DefaultStreamInterceptors]() ```go func DefaultStreamInterceptors() []grpc.StreamServerInterceptor ``` -DefaultStreamInterceptors are the set of default interceptors that should be applied to all coldbrew streams +DefaultStreamInterceptors returns the default stream server interceptor chain. The ordering is defined by the streamPos\* constants above; this function assigns each interceptor to its named slot and drops any slot that is disabled via configuration. See the ordering contract above for semantics. -## func [DefaultTimeoutInterceptor]() +## func [DefaultTimeoutInterceptor]() ```go func DefaultTimeoutInterceptor() grpc.UnaryServerInterceptor @@ -314,7 +314,7 @@ func NewRelicClientInterceptor() grpc.UnaryClientInterceptor NewRelicClientInterceptor intercepts all client actions and reports them to newrelic. When NewRelic app is nil \(no license key configured\), returns a pass\-through interceptor to avoid overhead. -## func [NewRelicInterceptor]() +## func [NewRelicInterceptor]() ```go func NewRelicInterceptor() grpc.UnaryServerInterceptor @@ -323,7 +323,7 @@ func NewRelicInterceptor() grpc.UnaryServerInterceptor NewRelicInterceptor intercepts all server actions and reports them to newrelic. When NewRelic app is nil \(no license key configured\), returns a pass\-through interceptor to avoid overhead. -## func [OptionsInterceptor]() +## func [OptionsInterceptor]() ```go func OptionsInterceptor() grpc.UnaryServerInterceptor @@ -332,7 +332,7 @@ func OptionsInterceptor() grpc.UnaryServerInterceptor -## func [PanicRecoveryInterceptor]() +## func [PanicRecoveryInterceptor]() ```go func PanicRecoveryInterceptor() grpc.UnaryServerInterceptor @@ -341,7 +341,7 @@ func PanicRecoveryInterceptor() grpc.UnaryServerInterceptor -## func [PanicRecoveryStreamInterceptor]() +## func [PanicRecoveryStreamInterceptor]() ```go func PanicRecoveryStreamInterceptor() grpc.StreamServerInterceptor @@ -368,7 +368,7 @@ func ProtoValidateStreamInterceptor() grpc.StreamServerInterceptor ProtoValidateStreamInterceptor returns a stream server interceptor that validates incoming messages using protovalidate annotations. -## func [ResponseTimeLoggingInterceptor]() +## func [ResponseTimeLoggingInterceptor]() ```go func ResponseTimeLoggingInterceptor(ff FilterFunc) grpc.UnaryServerInterceptor @@ -377,7 +377,7 @@ func ResponseTimeLoggingInterceptor(ff FilterFunc) grpc.UnaryServerInterceptor ResponseTimeLoggingInterceptor logs response time for each request on server -## func [ResponseTimeLoggingStreamInterceptor]() +## func [ResponseTimeLoggingStreamInterceptor]() ```go func ResponseTimeLoggingStreamInterceptor() grpc.StreamServerInterceptor @@ -386,7 +386,7 @@ func ResponseTimeLoggingStreamInterceptor() grpc.StreamServerInterceptor ResponseTimeLoggingStreamInterceptor logs response time for stream RPCs. -## func [ServerErrorInterceptor]() +## func [ServerErrorInterceptor]() ```go func ServerErrorInterceptor() grpc.UnaryServerInterceptor @@ -395,7 +395,7 @@ func ServerErrorInterceptor() grpc.UnaryServerInterceptor ServerErrorInterceptor intercepts all server actions and reports them to error notifier -## func [ServerErrorStreamInterceptor]() +## func [ServerErrorStreamInterceptor]() ```go func ServerErrorStreamInterceptor() grpc.StreamServerInterceptor @@ -539,7 +539,7 @@ func SetServerMetricsOptions(opts ...grpcprom.ServerMetricsOption) SetServerMetricsOptions appends gRPC server metrics options \(histogram, labels, namespace, etc.\). Must be called during initialization, before the server starts. Not safe for concurrent use. -## func [TraceIdInterceptor]() +## func [TraceIdInterceptor]() ```go func TraceIdInterceptor() grpc.UnaryServerInterceptor diff --git a/interceptors_test.go b/interceptors_test.go index 8c94917..59d1972 100644 --- a/interceptors_test.go +++ b/interceptors_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/go-coldbrew/errors/notifier" "github.com/go-coldbrew/log" "github.com/go-coldbrew/log/loggers" ratelimit_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit" @@ -1823,3 +1824,376 @@ func TestDefaultStreamInterceptors_IncludesPanicRecovery(t *testing.T) { t.Fatal("expected error from panic recovery in default stream chain") } } + +// --- Interceptor ordering contract tests --- +// +// These tests guard the layering described on DefaultInterceptors / +// DefaultStreamInterceptors in server.go. They fail if a position constant is +// reordered or a slot is wired to the wrong interceptor — any such change +// silently alters observable server semantics (panic recovery coverage, +// validation-error reporting, deadline application). + +type traceIDRequest struct{ id string } + +func (r *traceIDRequest) GetTraceId() string { return r.id } + +func TestInterceptorPositionConstants(t *testing.T) { + // Unary: required relative ordering encoding the contract. + // See ordering contract godoc in server.go. + if unaryPosTimeout != 0 { + t.Errorf("unaryPosTimeout must be outermost (0); got %d", unaryPosTimeout) + } + if unaryPosPanicRecovery != unaryPosCount-1 { + t.Errorf("unaryPosPanicRecovery must be innermost (unaryPosCount-1=%d); got %d", + unaryPosCount-1, unaryPosPanicRecovery) + } + unaryOrder := []struct { + name string + pos int + }{ + {"Timeout", unaryPosTimeout}, + {"RateLimit", unaryPosRateLimit}, + {"ResponseTimeLog", unaryPosResponseTimeLog}, + {"TraceID", unaryPosTraceID}, + {"DebugLog", unaryPosDebugLog}, + {"ProtoValidate", unaryPosProtoValidate}, + {"Metrics", unaryPosMetrics}, + {"ServerError", unaryPosServerError}, + {"NewRelic", unaryPosNewRelic}, + {"PanicRecovery", unaryPosPanicRecovery}, + } + for i := 1; i < len(unaryOrder); i++ { + if unaryOrder[i-1].pos >= unaryOrder[i].pos { + t.Errorf("unary position order violation: %s (%d) must precede %s (%d)", + unaryOrder[i-1].name, unaryOrder[i-1].pos, + unaryOrder[i].name, unaryOrder[i].pos) + } + } + + // Critical semantic invariants: panic recovery must be INNER to + // metrics, error-reporting, and tracing so those layers observe the + // synthesized error from a recovered handler panic. + if unaryPosPanicRecovery <= unaryPosServerError { + t.Error("panic recovery must be INNER to ServerErrorInterceptor") + } + if unaryPosPanicRecovery <= unaryPosMetrics { + t.Error("panic recovery must be INNER to metrics") + } + if unaryPosPanicRecovery <= unaryPosNewRelic { + t.Error("panic recovery must be INNER to NewRelic") + } + // Protovalidate sits OUTER to metrics / error-reporting so that + // obviously-invalid requests short-circuit with InvalidArgument before + // any metrics or error-reporting work runs. + if unaryPosProtoValidate >= unaryPosMetrics { + t.Error("protovalidate must be OUTER to metrics (short-circuits before metrics work runs)") + } + if unaryPosProtoValidate >= unaryPosServerError { + t.Error("protovalidate must be OUTER to ServerErrorInterceptor (short-circuits before error-reporting runs)") + } + + // Stream variants. + if streamPosRateLimit != 0 { + t.Errorf("streamPosRateLimit must be outermost (0); got %d", streamPosRateLimit) + } + if streamPosPanicRecovery != streamPosCount-1 { + t.Errorf("streamPosPanicRecovery must be innermost (streamPosCount-1=%d); got %d", + streamPosCount-1, streamPosPanicRecovery) + } + streamOrder := []struct { + name string + pos int + }{ + {"RateLimit", streamPosRateLimit}, + {"ResponseTimeLog", streamPosResponseTimeLog}, + {"ProtoValidate", streamPosProtoValidate}, + {"Metrics", streamPosMetrics}, + {"ServerError", streamPosServerError}, + {"PanicRecovery", streamPosPanicRecovery}, + } + for i := 1; i < len(streamOrder); i++ { + if streamOrder[i-1].pos >= streamOrder[i].pos { + t.Errorf("stream position order violation: %s (%d) must precede %s (%d)", + streamOrder[i-1].name, streamOrder[i-1].pos, + streamOrder[i].name, streamOrder[i].pos) + } + } + if streamPosPanicRecovery <= streamPosServerError { + t.Error("stream panic recovery must be INNER to ServerErrorStreamInterceptor") + } + if streamPosProtoValidate >= streamPosMetrics { + t.Error("stream protovalidate must be OUTER to metrics (short-circuits before metrics work runs)") + } +} + +func TestDefaultInterceptors_AllSlotsPopulated(t *testing.T) { + resetGlobals() + defer resetGlobals() + + // Enable the rate limit slot (default config leaves the limiter nil). + SetDefaultRateLimit(1000, 1000) + + ints := DefaultInterceptors() + if len(ints) != unaryPosCount { + t.Fatalf("expected %d interceptors, got %d", unaryPosCount, len(ints)) + } +} + +func TestDefaultStreamInterceptors_AllSlotsPopulated(t *testing.T) { + resetGlobals() + defer resetGlobals() + SetDefaultRateLimit(1000, 1000) + + ints := DefaultStreamInterceptors() + if len(ints) != streamPosCount { + t.Fatalf("expected %d stream interceptors, got %d", streamPosCount, len(ints)) + } +} + +// TestDefaultInterceptors_SlotWiring verifies that each named position holds +// the expected interceptor by probing its characteristic side effect. Combined +// with TestInterceptorPositionConstants, this enforces both the positional +// layering and the slot-to-interceptor mapping. +func TestDefaultInterceptors_SlotWiring(t *testing.T) { + resetGlobals() + defer resetGlobals() + SetDefaultRateLimit(1000, 1000) + + ints := DefaultInterceptors() + if len(ints) != unaryPosCount { + t.Fatalf("expected %d interceptors, got %d", unaryPosCount, len(ints)) + } + info := &grpc.UnaryServerInfo{FullMethod: "/test.Svc/Method"} + + t.Run("Timeout", func(t *testing.T) { + gotDeadline := false + _, err := ints[unaryPosTimeout](context.Background(), nil, info, + func(ctx context.Context, _ any) (any, error) { + _, gotDeadline = ctx.Deadline() + return "ok", nil + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !gotDeadline { + t.Errorf("slot unaryPosTimeout (%d) should set a deadline", unaryPosTimeout) + } + }) + + t.Run("ResponseTimeLog", func(t *testing.T) { + var grpcMethod any + var found bool + _, err := ints[unaryPosResponseTimeLog](context.Background(), nil, info, + func(ctx context.Context, _ any) (any, error) { + grpcMethod, found = loggers.FromContext(ctx).Load("grpcMethod") + return "ok", nil + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !found || grpcMethod != "/test.Svc/Method" { + t.Errorf("slot unaryPosResponseTimeLog (%d) should add grpcMethod to log context; found=%v val=%v", + unaryPosResponseTimeLog, found, grpcMethod) + } + }) + + t.Run("TraceID", func(t *testing.T) { + var observed string + _, err := ints[unaryPosTraceID](context.Background(), &traceIDRequest{id: "trace-xyz"}, info, + func(ctx context.Context, _ any) (any, error) { + observed = notifier.GetTraceId(ctx) + return "ok", nil + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if observed != "trace-xyz" { + t.Errorf("slot unaryPosTraceID (%d) should extract trace id from request; got %q", + unaryPosTraceID, observed) + } + }) + + t.Run("DebugLog", func(t *testing.T) { + var level loggers.Level + var found bool + _, err := ints[unaryPosDebugLog](context.Background(), &debugRequest{debug: true}, info, + func(ctx context.Context, _ any) (any, error) { + level, found = log.GetOverridenLogLevel(ctx) + return "ok", nil + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !found || level != loggers.DebugLevel { + t.Errorf("slot unaryPosDebugLog (%d) should override log level on GetDebug()=true; found=%v level=%v", + unaryPosDebugLog, found, level) + } + }) + + t.Run("PanicRecovery", func(t *testing.T) { + _, err := ints[unaryPosPanicRecovery](context.Background(), nil, info, + func(_ context.Context, _ any) (any, error) { + panic("slot panic") + }) + if err == nil { + t.Errorf("slot unaryPosPanicRecovery (%d) should convert handler panic to error", + unaryPosPanicRecovery) + } + }) + + t.Run("ServerError_PanicPropagates", func(t *testing.T) { + // ServerErrorInterceptor must NOT recover panics — if it did, a + // misplaced recovery would silently swallow errors and break the + // contract that only unaryPosPanicRecovery recovers. + defer func() { + if r := recover(); r == nil { + t.Errorf("slot unaryPosServerError (%d) must NOT recover panics", unaryPosServerError) + } + }() + _, _ = ints[unaryPosServerError](context.Background(), nil, info, + func(_ context.Context, _ any) (any, error) { + panic("should propagate past ServerErrorInterceptor") + }) + }) +} + +// TestDefaultInterceptors_PanicThroughFullChain runs the full default unary +// chain with a panicking handler and asserts (a) the chain recovers and +// returns an error, and (b) outer layers (user interceptor registered via +// AddUnaryServerInterceptor) observe that error. This validates end-to-end +// that panic-recovery is wired into the chain and is innermost relative to +// user interceptors. +func TestDefaultInterceptors_PanicThroughFullChain(t *testing.T) { + resetGlobals() + defer resetGlobals() + + // protovalidate sits outer to panic recovery and rejects nil requests + // with InvalidArgument; without disabling it, the handler never runs + // and the test would pass on the validation error instead of the + // recovered panic. + SetDisableProtoValidate(true) + + var userSawErr error + AddUnaryServerInterceptor(context.Background(), + func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + resp, err := handler(ctx, req) + userSawErr = err + return resp, err + }) + + chain := chainUnaryServer(DefaultInterceptors()) + info := &grpc.UnaryServerInfo{FullMethod: "/test.Svc/Panic"} + + resp, err := chain(context.Background(), nil, info, + func(_ context.Context, _ any) (any, error) { + panic("boom") + }) + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("chain should surface the recovered panic (want error containing 'boom'); got %v", err) + } + if resp != nil { + t.Errorf("chain should return nil resp on panic, got %v", resp) + } + if userSawErr == nil || !strings.Contains(userSawErr.Error(), "boom") { + t.Errorf("user interceptor (outermost) should observe the recovered panic error; got %v", userSawErr) + } +} + +// TestDefaultInterceptors_UserInterceptorsOutermost verifies that interceptors +// registered via AddUnaryServerInterceptor run BEFORE the CB set — i.e., they +// see a context without the default timeout deadline applied. +func TestDefaultInterceptors_UserInterceptorsOutermost(t *testing.T) { + resetGlobals() + defer resetGlobals() + + // Disable protovalidate; it rejects nil requests which would short-circuit + // the chain before it reaches the handler. We only care about ordering here. + SetDisableProtoValidate(true) + + var userSawDeadline bool + AddUnaryServerInterceptor(context.Background(), + func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + _, userSawDeadline = ctx.Deadline() + return handler(ctx, req) + }) + + chain := chainUnaryServer(DefaultInterceptors()) + info := &grpc.UnaryServerInfo{FullMethod: "/test.Svc/Method"} + + _, err := chain(context.Background(), nil, info, + func(_ context.Context, _ any) (any, error) { return "ok", nil }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if userSawDeadline { + t.Error("user interceptor should run before DefaultTimeoutInterceptor and see no deadline") + } +} + +func TestDefaultStreamInterceptors_SlotWiring(t *testing.T) { + resetGlobals() + defer resetGlobals() + SetDefaultRateLimit(1000, 1000) + + ints := DefaultStreamInterceptors() + if len(ints) != streamPosCount { + t.Fatalf("expected %d stream interceptors, got %d", streamPosCount, len(ints)) + } + info := &grpc.StreamServerInfo{FullMethod: "/test.Svc/Stream"} + stream := &mockServerStream{ctx: context.Background()} + + t.Run("PanicRecovery", func(t *testing.T) { + err := ints[streamPosPanicRecovery](nil, stream, info, + func(_ any, _ grpc.ServerStream) error { + panic("stream slot panic") + }) + if err == nil { + t.Errorf("slot streamPosPanicRecovery (%d) should convert handler panic to error", + streamPosPanicRecovery) + } + }) + + t.Run("ServerError_PanicPropagates", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("slot streamPosServerError (%d) must NOT recover panics", streamPosServerError) + } + }() + _ = ints[streamPosServerError](nil, stream, info, + func(_ any, _ grpc.ServerStream) error { + panic("should propagate past ServerErrorStreamInterceptor") + }) + }) +} + +// TestDefaultStreamInterceptors_UserInterceptorsOutermost — user stream +// interceptors registered via AddStreamServerInterceptor run before the CB set. +func TestDefaultStreamInterceptors_UserInterceptorsOutermost(t *testing.T) { + resetGlobals() + defer resetGlobals() + + var order []string + AddStreamServerInterceptor(context.Background(), + func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + order = append(order, "user") + return handler(srv, ss) + }) + + ints := DefaultStreamInterceptors() + // The user interceptor must be the first element of the slice. + if len(ints) == 0 { + t.Fatal("no interceptors returned") + } + info := &grpc.StreamServerInfo{FullMethod: "/test.Svc/Stream"} + stream := &mockServerStream{ctx: context.Background()} + + // Run the outermost interceptor directly and verify the user interceptor + // is hit first. + _ = ints[0](nil, stream, info, func(_ any, _ grpc.ServerStream) error { + order = append(order, "handler") + return nil + }) + if len(order) == 0 || order[0] != "user" { + t.Errorf("user stream interceptor should be outermost (first); order=%v", order) + } +} diff --git a/server.go b/server.go index 6db7174..30905fa 100644 --- a/server.go +++ b/server.go @@ -60,62 +60,138 @@ func ProtoValidateStreamInterceptor() grpc.StreamServerInterceptor { return protovalidate_middleware.StreamServerInterceptor(getProtoValidator()) } -// DefaultInterceptors are the set of default interceptors that are applied to all coldbrew methods +// Interceptor ordering contract (read before reordering). +// +// The gRPC server chain helper (chainUnaryServer in chain.go) wraps +// interceptors LAST-FIRST, so the LAST element of the slice returned by +// DefaultInterceptors / DefaultStreamInterceptors is the INNERMOST (runs +// closest to the handler) and the FIRST element is the OUTERMOST (runs +// first on an incoming request). +// +// The unaryPos* / streamPos* constants below encode the required layering; +// changing a position changes observable server semantics: +// +// - Timeout / rate-limit are OUTERMOST. They short-circuit or cap work +// before any other interceptor runs. +// - Response-time logging, trace-id propagation, and the debug-log override +// run next. They set up context fields that downstream interceptors and +// the handler rely on. +// - Protovalidate runs BEFORE (outer to) metrics / error reporting / +// tracing. A validation failure short-circuits the chain with +// InvalidArgument so no metrics or error-reporting work is done for +// obviously bad requests; the trade-off is that inner layers do not +// observe validation rejections. +// - Metrics, ServerErrorInterceptor, and New Relic wrap the handler from +// the OUTSIDE of the inner stack. They observe the final error/response +// that propagates back outward — including errors synthesized by the +// panic-recovery layer — but not validation rejections short-circuited +// by the outer protovalidate layer. +// - Panic recovery is INNERMOST. Handler panics are recovered and converted +// to errors, which then propagate outward through error reporting, +// metrics, and tracing so those layers record the call as a failure +// rather than a success. +// +// User-supplied interceptors registered via AddUnaryServerInterceptor / +// AddStreamServerInterceptor are prepended OUTERMOST, before the ColdBrew +// (CB) set. +// +// Tests in interceptors_test.go (TestInterceptorPositionConstants, +// TestDefaultInterceptors_SlotWiring, TestDefaultInterceptors_PanicThroughFullChain, +// TestDefaultInterceptors_UserInterceptorsOutermost, and their stream +// variants) guard this contract. +const ( + unaryPosTimeout = iota // outermost + unaryPosRateLimit + unaryPosResponseTimeLog + unaryPosTraceID + unaryPosDebugLog + unaryPosProtoValidate + unaryPosMetrics + unaryPosServerError + unaryPosNewRelic + unaryPosPanicRecovery // innermost + unaryPosCount +) + +const ( + streamPosRateLimit = iota // outermost + streamPosResponseTimeLog + streamPosProtoValidate + streamPosMetrics + streamPosServerError + streamPosPanicRecovery // innermost + streamPosCount +) + +// DefaultInterceptors returns the default unary server interceptor chain. +// The ordering is defined by the unaryPos* constants above; this function +// assigns each interceptor to its named slot and drops any slot that is +// disabled via configuration. See the ordering contract above for semantics. func DefaultInterceptors() []grpc.UnaryServerInterceptor { - ints := []grpc.UnaryServerInterceptor{} - if len(defaultConfig.unaryServerInterceptors) > 0 { - ints = append(ints, defaultConfig.unaryServerInterceptors...) + ints := make([]grpc.UnaryServerInterceptor, 0, len(defaultConfig.unaryServerInterceptors)+unaryPosCount) + ints = append(ints, defaultConfig.unaryServerInterceptors...) + if !defaultConfig.useCBServerInterceptors { + return ints } - if defaultConfig.useCBServerInterceptors { - ints = append(ints, DefaultTimeoutInterceptor()) - if !defaultConfig.disableRateLimit { - if limiter := getRateLimiter(); limiter != nil { - ints = append(ints, ratelimit_middleware.UnaryServerInterceptor(limiter)) - } - } - ints = append(ints, - ResponseTimeLoggingInterceptor(defaultConfig.filterFunc), - TraceIdInterceptor(), - ) - if !defaultConfig.disableDebugLogInterceptor { - ints = append(ints, DebugLogInterceptor()) + + cb := make([]grpc.UnaryServerInterceptor, unaryPosCount) + cb[unaryPosTimeout] = DefaultTimeoutInterceptor() + if !defaultConfig.disableRateLimit { + if limiter := getRateLimiter(); limiter != nil { + cb[unaryPosRateLimit] = ratelimit_middleware.UnaryServerInterceptor(limiter) } - if !defaultConfig.disableProtoValidate { - ints = append(ints, ProtoValidateInterceptor()) + } + cb[unaryPosResponseTimeLog] = ResponseTimeLoggingInterceptor(defaultConfig.filterFunc) + cb[unaryPosTraceID] = TraceIdInterceptor() + if !defaultConfig.disableDebugLogInterceptor { + cb[unaryPosDebugLog] = DebugLogInterceptor() + } + if !defaultConfig.disableProtoValidate { + cb[unaryPosProtoValidate] = ProtoValidateInterceptor() + } + cb[unaryPosMetrics] = getServerMetrics().UnaryServerInterceptor() + cb[unaryPosServerError] = ServerErrorInterceptor() + cb[unaryPosNewRelic] = NewRelicInterceptor() + cb[unaryPosPanicRecovery] = PanicRecoveryInterceptor() + + for _, i := range cb { + if i != nil { + ints = append(ints, i) } - ints = append(ints, - getServerMetrics().UnaryServerInterceptor(), - ServerErrorInterceptor(), - NewRelicInterceptor(), - PanicRecoveryInterceptor(), - ) } return ints } -// DefaultStreamInterceptors are the set of default interceptors that should be applied to all coldbrew streams +// DefaultStreamInterceptors returns the default stream server interceptor +// chain. The ordering is defined by the streamPos* constants above; this +// function assigns each interceptor to its named slot and drops any slot +// that is disabled via configuration. See the ordering contract above for +// semantics. func DefaultStreamInterceptors() []grpc.StreamServerInterceptor { - ints := []grpc.StreamServerInterceptor{} - if len(defaultConfig.streamServerInterceptors) > 0 { - ints = append(ints, defaultConfig.streamServerInterceptors...) + ints := make([]grpc.StreamServerInterceptor, 0, len(defaultConfig.streamServerInterceptors)+streamPosCount) + ints = append(ints, defaultConfig.streamServerInterceptors...) + if !defaultConfig.useCBServerInterceptors { + return ints } - if defaultConfig.useCBServerInterceptors { - if !defaultConfig.disableRateLimit { - if limiter := getRateLimiter(); limiter != nil { - ints = append(ints, ratelimit_middleware.StreamServerInterceptor(limiter)) - } + + cb := make([]grpc.StreamServerInterceptor, streamPosCount) + if !defaultConfig.disableRateLimit { + if limiter := getRateLimiter(); limiter != nil { + cb[streamPosRateLimit] = ratelimit_middleware.StreamServerInterceptor(limiter) } - ints = append(ints, - ResponseTimeLoggingStreamInterceptor(), - ) - if !defaultConfig.disableProtoValidate { - ints = append(ints, ProtoValidateStreamInterceptor()) + } + cb[streamPosResponseTimeLog] = ResponseTimeLoggingStreamInterceptor() + if !defaultConfig.disableProtoValidate { + cb[streamPosProtoValidate] = ProtoValidateStreamInterceptor() + } + cb[streamPosMetrics] = getServerMetrics().StreamServerInterceptor() + cb[streamPosServerError] = ServerErrorStreamInterceptor() + cb[streamPosPanicRecovery] = PanicRecoveryStreamInterceptor() + + for _, i := range cb { + if i != nil { + ints = append(ints, i) } - ints = append(ints, - getServerMetrics().StreamServerInterceptor(), - ServerErrorStreamInterceptor(), - PanicRecoveryStreamInterceptor(), - ) } return ints }