Skip to content

Commit

Permalink
Add error instead of bool to ratelimit interceptor
Browse files Browse the repository at this point in the history
Signed-off-by: dmaiocchi <[email protected]>
  • Loading branch information
MalloZup committed Jan 14, 2021
1 parent 08b17eb commit 6bfe427
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions interceptors/ratelimit/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
// It does not limit any request because Limit function always returns false.
type alwaysPassLimiter struct{}

func (*alwaysPassLimiter) Limit(_ context.Context) bool {
return false
func (*alwaysPassLimiter) Limit(_ context.Context) error {
return nil
}

// Simple example of server initialization code.
Expand Down
12 changes: 6 additions & 6 deletions interceptors/ratelimit/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ import (
)

// Limiter defines the interface to perform request rate limiting.
// If Limit function return true, the request will be rejected.
// If Limit function return an error, the request will be rejected.
// Otherwise, the request will pass.
type Limiter interface {
Limit(ctx context.Context) bool
Limit(ctx context.Context) error
}

// UnaryServerInterceptor returns a new unary server interceptors that performs request rate limiting.
func UnaryServerInterceptor(limiter Limiter) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if limiter.Limit(ctx) {
return nil, status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later.", info.FullMethod)
if err := limiter.Limit(ctx); err != nil {
return nil, status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later. %s", info.FullMethod, err)
}
return handler(ctx, req)
}
Expand All @@ -28,8 +28,8 @@ func UnaryServerInterceptor(limiter Limiter) grpc.UnaryServerInterceptor {
// StreamServerInterceptor returns a new stream server interceptor that performs rate limiting on the request.
func StreamServerInterceptor(limiter Limiter) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if limiter.Limit(stream.Context()) {
return status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later.", info.FullMethod)
if err := limiter.Limit(stream.Context()); err != nil {
return status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later. %s", info.FullMethod, err)
}
return handler(srv, stream)
}
Expand Down
10 changes: 5 additions & 5 deletions interceptors/ratelimit/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ func (m *mockGRPCServerStream) Context() context.Context {

type mockContextBasedLimiter struct{}

func (*mockContextBasedLimiter) Limit(ctx context.Context) bool {
l, ok := ctx.Value(ctxLimitKey).(bool)
return ok && l
func (*mockContextBasedLimiter) Limit(ctx context.Context) error {
l, _ := ctx.Value(ctxLimitKey).(error)
return l
}

func TestUnaryServerInterceptor_RateLimitPass(t *testing.T) {
Expand Down Expand Up @@ -74,7 +74,7 @@ func TestUnaryServerInterceptor_RateLimitFail(t *testing.T) {
}
resp, err := interceptor(ctx, nil, info, handler)
assert.Nil(t, resp)
assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.")
assert.EqualError(t, err, errMsgFake)
}

func TestStreamServerInterceptor_RateLimitFail(t *testing.T) {
Expand All @@ -89,5 +89,5 @@ func TestStreamServerInterceptor_RateLimitFail(t *testing.T) {
FullMethod: "FakeMethod",
}
err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, info, handler)
assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.")
assert.EqualError(t, err, errMsgFake)
}

0 comments on commit 6bfe427

Please sign in to comment.