Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ require (
github.com/newrelic/go-agent/v3/integrations/nrgrpc v1.4.7
github.com/prometheus/client_golang v1.23.2
go.uber.org/goleak v1.3.0
golang.org/x/time v0.15.0
google.golang.org/grpc v1.79.3
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200329025819-fd4102a86c65/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8=
Expand Down
114 changes: 94 additions & 20 deletions interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ import (
nrutil "github.com/go-coldbrew/tracing/newrelic"
grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate"
ratelimit_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/newrelic/go-agent/v3/integrations/nrgrpc"
newrelic "github.com/newrelic/go-agent/v3/newrelic"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/time/rate"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
Expand All @@ -51,27 +53,33 @@ var (
// Use SetFilterMethods instead. Only some direct mutations (replacing the slice
// or changing the first element) are detected by internal change detection;
// other in-place changes may not invalidate caches correctly.
FilterMethods = []string{"healthcheck", "readycheck", "serverreflectioninfo"}
defaultFilterFunc = FilterMethodsFunc
unaryServerInterceptors = []grpc.UnaryServerInterceptor{}
streamServerInterceptors = []grpc.StreamServerInterceptor{}
useCBServerInterceptors = true
unaryClientInterceptors = []grpc.UnaryClientInterceptor{}
streamClientInterceptors = []grpc.StreamClientInterceptor{}
useCBClientInterceptors = true
responseTimeLogLevel loggers.Level = loggers.InfoLevel
responseTimeLogErrorOnly bool
defaultTimeout time.Duration = 60 * time.Second // 0 disables
protoValidateOpts []protovalidate.ValidatorOption
disableProtoValidate bool
srvMetricsOpts []grpcprom.ServerMetricsOption
cltMetricsOpts []grpcprom.ClientMetricsOption
srvMetricsOnce sync.Once
srvMetrics *grpcprom.ServerMetrics
cltMetricsOnce sync.Once
cltMetrics *grpcprom.ClientMetrics
FilterMethods = []string{"healthcheck", "readycheck", "serverreflectioninfo"}
defaultFilterFunc = FilterMethodsFunc
unaryServerInterceptors = []grpc.UnaryServerInterceptor{}
streamServerInterceptors = []grpc.StreamServerInterceptor{}
useCBServerInterceptors = true
unaryClientInterceptors = []grpc.UnaryClientInterceptor{}
streamClientInterceptors = []grpc.StreamClientInterceptor{}
useCBClientInterceptors = true
responseTimeLogLevel loggers.Level = loggers.InfoLevel
responseTimeLogErrorOnly bool
defaultTimeout time.Duration = 60 * time.Second // 0 disables
protoValidateOpts []protovalidate.ValidatorOption
disableProtoValidate bool
srvMetricsOpts []grpcprom.ServerMetricsOption
cltMetricsOpts []grpcprom.ClientMetricsOption
srvMetricsOnce sync.Once
srvMetrics *grpcprom.ServerMetrics
cltMetricsOnce sync.Once
cltMetrics *grpcprom.ClientMetrics
disableDebugLogInterceptor bool
debugLogHeaderName = "x-debug-log-level"
disableRateLimit bool
rateLimiter ratelimit_middleware.Limiter
rateLimiterOnce sync.Once
rateLimiterVal ratelimit_middleware.Limiter
defaultRateLimit rate.Limit = rate.Inf
defaultRateBurst int
)

// SetResponseTimeLogLevel sets the log level for response time logging.
Expand Down Expand Up @@ -428,8 +436,13 @@ func DefaultInterceptors() []grpc.UnaryServerInterceptor {
ints = append(ints, unaryServerInterceptors...)
}
if useCBServerInterceptors {
ints = append(ints, DefaultTimeoutInterceptor())
if !disableRateLimit {
if limiter := getRateLimiter(); limiter != nil {
ints = append(ints, ratelimit_middleware.UnaryServerInterceptor(limiter))
}
}
ints = append(ints,
DefaultTimeoutInterceptor(),
ResponseTimeLoggingInterceptor(defaultFilterFunc),
TraceIdInterceptor(),
)
Expand Down Expand Up @@ -497,6 +510,11 @@ func DefaultStreamInterceptors() []grpc.StreamServerInterceptor {
ints = append(ints, streamServerInterceptors...)
}
if useCBServerInterceptors {
if !disableRateLimit {
if limiter := getRateLimiter(); limiter != nil {
ints = append(ints, ratelimit_middleware.StreamServerInterceptor(limiter))
}
}
Comment thread
ankurs marked this conversation as resolved.
ints = append(ints,
ResponseTimeLoggingStreamInterceptor(),
)
Expand Down Expand Up @@ -860,3 +878,59 @@ func DebugLogInterceptor() grpc.UnaryServerInterceptor {
return handler(ctx, req)
}
}

// SetDisableRateLimit disables the rate limiting interceptor in the default
// interceptor chain. Must be called during initialization.
func SetDisableRateLimit(disable bool) {
disableRateLimit = disable
}

// SetRateLimiter sets a custom rate limiter implementation. This overrides the
// built-in token bucket limiter. Must be called during initialization.
func SetRateLimiter(limiter ratelimit_middleware.Limiter) {
rateLimiter = limiter
}

// SetDefaultRateLimit configures the built-in token bucket rate limiter.
// rps is requests per second, burst is the maximum burst size.
// This is a per-pod in-memory limit — with N pods, the effective cluster-wide
// limit is N × rps. For distributed rate limiting, use SetRateLimiter() with
// a custom implementation (e.g., Redis-backed).
// Must be called during initialization.
func SetDefaultRateLimit(rps float64, burst int) {
defaultRateLimit = rate.Limit(rps)
if burst < 1 {
burst = 1
}
defaultRateBurst = burst
Comment thread
ankurs marked this conversation as resolved.
}
Comment thread
ankurs marked this conversation as resolved.
Comment thread
ankurs marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// tokenBucketLimiter wraps golang.org/x/time/rate.Limiter to implement
// the ratelimit.Limiter interface.
type tokenBucketLimiter struct {
limiter *rate.Limiter
}

func (l *tokenBucketLimiter) Limit(_ context.Context) error {
if !l.limiter.Allow() {
return fmt.Errorf("rate limit exceeded")
}
return nil
}

func getRateLimiter() ratelimit_middleware.Limiter {
rateLimiterOnce.Do(func() {
if rateLimiter != nil {
rateLimiterVal = rateLimiter
return
}
if defaultRateLimit == rate.Inf {
rateLimiterVal = nil
return
}
rateLimiterVal = &tokenBucketLimiter{
limiter: rate.NewLimiter(defaultRateLimit, defaultRateBurst),
}
})
return rateLimiterVal
}
129 changes: 129 additions & 0 deletions interceptors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@ import (

"github.com/go-coldbrew/log"
"github.com/go-coldbrew/log/loggers"
ratelimit_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"golang.org/x/time/rate"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
grpcmd "google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

// mockStream implements grpc.ServerTransportStream for testing.
Expand Down Expand Up @@ -49,6 +53,12 @@ func resetGlobals() {
httpToGRPCInterceptor = nil
disableDebugLogInterceptor = false
debugLogHeaderName = "x-debug-log-level"
disableRateLimit = false
rateLimiter = nil
rateLimiterOnce = sync.Once{}
rateLimiterVal = nil
defaultRateLimit = rate.Inf
defaultRateBurst = 0
}

func TestFilterMethodsFunc(t *testing.T) {
Expand Down Expand Up @@ -1306,3 +1316,122 @@ func TestDebugLogInterceptor_Disabled(t *testing.T) {
}
}
}

// --- RateLimit interceptor tests ---

type alwaysRejectLimiter struct{}

func (l *alwaysRejectLimiter) Limit(_ context.Context) error {
return fmt.Errorf("always rejected")
}

func TestRateLimitInterceptor_DefaultInf(t *testing.T) {
resetGlobals()
// Default is rate.Inf — no rate limiting, getRateLimiter returns nil
limiter := getRateLimiter()
if limiter != nil {
t.Error("expected nil limiter with default rate.Inf")
}
}

func TestRateLimitInterceptor_Allowed(t *testing.T) {
resetGlobals()
SetDefaultRateLimit(1000, 100)
info := &grpc.UnaryServerInfo{FullMethod: "/test/RateLimit"}

handler := func(ctx context.Context, req any) (any, error) {
return "ok", nil
}

limiter := getRateLimiter()
if limiter == nil {
t.Fatal("expected non-nil limiter after SetDefaultRateLimit")
}

interceptor := ratelimit_middleware.UnaryServerInterceptor(limiter)
resp, err := interceptor(context.Background(), nil, info, handler)
if err != nil {
t.Fatalf("expected request to pass, got: %v", err)
}
if resp != "ok" {
t.Errorf("expected 'ok', got %v", resp)
}
}

func TestRateLimitInterceptor_Exceeded(t *testing.T) {
resetGlobals()
SetDefaultRateLimit(1, 1) // 1 rps, burst 1
info := &grpc.UnaryServerInfo{FullMethod: "/test/RateLimit"}

handler := func(ctx context.Context, req any) (any, error) {
return "ok", nil
}

limiter := getRateLimiter()
interceptor := ratelimit_middleware.UnaryServerInterceptor(limiter)

// First request should pass
_, err := interceptor(context.Background(), nil, info, handler)
if err != nil {
t.Fatalf("first request should pass, got: %v", err)
}

// Second request should be rate limited
_, err = interceptor(context.Background(), nil, info, handler)
if err == nil {
t.Fatal("second request should be rate limited")
}
st, ok := status.FromError(err)
if !ok || st.Code() != codes.ResourceExhausted {
t.Errorf("expected ResourceExhausted, got: %v", err)
}
}

func TestRateLimitInterceptor_CustomLimiter(t *testing.T) {
resetGlobals()
SetRateLimiter(&alwaysRejectLimiter{})
info := &grpc.UnaryServerInfo{FullMethod: "/test/CustomLimit"}

handler := func(ctx context.Context, req any) (any, error) {
return "ok", nil
}

limiter := getRateLimiter()
interceptor := ratelimit_middleware.UnaryServerInterceptor(limiter)

_, err := interceptor(context.Background(), nil, info, handler)
if err == nil {
t.Fatal("expected rejection from custom limiter")
}
st, ok := status.FromError(err)
if !ok || st.Code() != codes.ResourceExhausted {
t.Errorf("expected ResourceExhausted, got: %v", err)
}
}

func TestRateLimitInterceptor_Disabled(t *testing.T) {
resetGlobals()
SetDefaultRateLimit(1, 1)
SetDisableRateLimit(true)

ints := DefaultInterceptors()
// Verify no ratelimit interceptor in chain by running all interceptors
// with a handler that should always succeed
info := &grpc.UnaryServerInfo{FullMethod: "/test/Disabled"}
handler := func(ctx context.Context, req any) (any, error) {
return "ok", nil
}

// Fire multiple requests through the chain — none should be rate limited
for i := 0; i < 5; i++ {
for _, interceptor := range ints {
_, err := interceptor(context.Background(), nil, info, handler)
if err != nil {
st, ok := status.FromError(err)
if ok && st.Code() == codes.ResourceExhausted {
t.Fatal("rate limiting should be disabled")
}
}
}
}
}
Loading