Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve v2 rate-limiter #380

Merged
merged 3 commits into from
Jan 22, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions interceptors/ratelimit/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
// It does not limit any request because Limit function always returns false.
type alwaysPassLimiter struct{}

func (*alwaysPassLimiter) Limit(_ context.Context) bool {
return false
func (*alwaysPassLimiter) Limit(_ context.Context) error {
return nil
}

// Simple example of server initialization code.
Expand Down
12 changes: 6 additions & 6 deletions interceptors/ratelimit/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ import (
)

// Limiter defines the interface to perform request rate limiting.
// If Limit function return true, the request will be rejected.
// If Limit function return an error, the request will be rejected.
MalloZup marked this conversation as resolved.
Show resolved Hide resolved
// Otherwise, the request will pass.
type Limiter interface {
Limit(ctx context.Context) bool
Limit(ctx context.Context) error
}

// UnaryServerInterceptor returns a new unary server interceptors that performs request rate limiting.
func UnaryServerInterceptor(limiter Limiter) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if limiter.Limit(ctx) {
return nil, status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later.", info.FullMethod)
if err := limiter.Limit(ctx); err != nil {
return nil, status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later. %s", info.FullMethod, err)
}
return handler(ctx, req)
}
Expand All @@ -28,8 +28,8 @@ func UnaryServerInterceptor(limiter Limiter) grpc.UnaryServerInterceptor {
// StreamServerInterceptor returns a new stream server interceptor that performs rate limiting on the request.
func StreamServerInterceptor(limiter Limiter) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if limiter.Limit(stream.Context()) {
return status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later.", info.FullMethod)
if err := limiter.Limit(stream.Context()); err != nil {
return status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later. %s", info.FullMethod, err)
}
return handler(srv, stream)
}
Expand Down
10 changes: 5 additions & 5 deletions interceptors/ratelimit/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ func (m *mockGRPCServerStream) Context() context.Context {

type mockContextBasedLimiter struct{}

func (*mockContextBasedLimiter) Limit(ctx context.Context) bool {
l, ok := ctx.Value(ctxLimitKey).(bool)
return ok && l
func (*mockContextBasedLimiter) Limit(ctx context.Context) error {
l, _ := ctx.Value(ctxLimitKey).(error)
return l
}

func TestUnaryServerInterceptor_RateLimitPass(t *testing.T) {
Expand Down Expand Up @@ -74,7 +74,7 @@ func TestUnaryServerInterceptor_RateLimitFail(t *testing.T) {
}
resp, err := interceptor(ctx, nil, info, handler)
assert.Nil(t, resp)
assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.")
assert.EqualError(t, err, errMsgFake)
}

func TestStreamServerInterceptor_RateLimitFail(t *testing.T) {
Expand All @@ -89,5 +89,5 @@ func TestStreamServerInterceptor_RateLimitFail(t *testing.T) {
FullMethod: "FakeMethod",
}
err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, info, handler)
assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.")
assert.EqualError(t, err, errMsgFake)
}