diff --git a/interceptors/ratelimit/examples_test.go b/interceptors/ratelimit/examples_test.go index 9b73fabeb..370bed827 100644 --- a/interceptors/ratelimit/examples_test.go +++ b/interceptors/ratelimit/examples_test.go @@ -13,8 +13,8 @@ import ( // It does not limit any request because Limit function always returns false. type alwaysPassLimiter struct{} -func (*alwaysPassLimiter) Limit(_ context.Context) bool { - return false +func (*alwaysPassLimiter) Limit(_ context.Context) error { + return nil } // Simple example of server initialization code. diff --git a/interceptors/ratelimit/ratelimit.go b/interceptors/ratelimit/ratelimit.go index afc843106..54df8921a 100644 --- a/interceptors/ratelimit/ratelimit.go +++ b/interceptors/ratelimit/ratelimit.go @@ -9,17 +9,17 @@ import ( ) // Limiter defines the interface to perform request rate limiting. -// If Limit function return true, the request will be rejected. +// If Limit function return an error, the request will be rejected. // Otherwise, the request will pass. type Limiter interface { - Limit(ctx context.Context) bool + Limit(ctx context.Context) error } // UnaryServerInterceptor returns a new unary server interceptors that performs request rate limiting. func UnaryServerInterceptor(limiter Limiter) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - if limiter.Limit(ctx) { - return nil, status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later.", info.FullMethod) + if err := limiter.Limit(ctx); err != nil { + return nil, status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later. %s", info.FullMethod, err) } return handler(ctx, req) } @@ -28,8 +28,8 @@ func UnaryServerInterceptor(limiter Limiter) grpc.UnaryServerInterceptor { // StreamServerInterceptor returns a new stream server interceptor that performs rate limiting on the request. func StreamServerInterceptor(limiter Limiter) grpc.StreamServerInterceptor { return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - if limiter.Limit(stream.Context()) { - return status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later.", info.FullMethod) + if err := limiter.Limit(stream.Context()); err != nil { + return status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later. %s", info.FullMethod, err) } return handler(srv, stream) } diff --git a/interceptors/ratelimit/ratelimit_test.go b/interceptors/ratelimit/ratelimit_test.go index 019f0a5de..8c7c5371f 100644 --- a/interceptors/ratelimit/ratelimit_test.go +++ b/interceptors/ratelimit/ratelimit_test.go @@ -25,9 +25,9 @@ func (m *mockGRPCServerStream) Context() context.Context { type mockContextBasedLimiter struct{} -func (*mockContextBasedLimiter) Limit(ctx context.Context) bool { - l, ok := ctx.Value(ctxLimitKey).(bool) - return ok && l +func (*mockContextBasedLimiter) Limit(ctx context.Context) error { + l, _ := ctx.Value(ctxLimitKey).(error) + return l } func TestUnaryServerInterceptor_RateLimitPass(t *testing.T) { @@ -74,7 +74,7 @@ func TestUnaryServerInterceptor_RateLimitFail(t *testing.T) { } resp, err := interceptor(ctx, nil, info, handler) assert.Nil(t, resp) - assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.") + assert.EqualError(t, err, errMsgFake) } func TestStreamServerInterceptor_RateLimitFail(t *testing.T) { @@ -89,5 +89,5 @@ func TestStreamServerInterceptor_RateLimitFail(t *testing.T) { FullMethod: "FakeMethod", } err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, info, handler) - assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.") + assert.EqualError(t, err, errMsgFake) }