diff --git a/interceptors/selector/doc.go b/interceptors/selector/doc.go new file mode 100644 index 000000000..496773ccb --- /dev/null +++ b/interceptors/selector/doc.go @@ -0,0 +1,15 @@ +// Copyright (c) The go-grpc-middleware Authors. +// Licensed under the Apache License 2.0. + +/* +Package selector + +`selector` a generic server-side selector middleware for gRPC. + +# Server Side Selector Middleware +It allows to set check rules to allowlist or blocklist middleware such as Auth +interceptors to toggle behavior on or off based on the request path. + +Please see examples for simple examples of use. +*/ +package selector diff --git a/interceptors/selector/selector.go b/interceptors/selector/selector.go new file mode 100644 index 000000000..5a5bd6a83 --- /dev/null +++ b/interceptors/selector/selector.go @@ -0,0 +1,34 @@ +// Copyright (c) The go-grpc-middleware Authors. +// Licensed under the Apache License 2.0. + +package selector + +import ( + "context" + + "google.golang.org/grpc" +) + +type MatchFunc func(ctx context.Context, fullMethod string) bool + +// UnaryServerInterceptor returns a new unary server interceptor that will decide whether to call +// the interceptor on the behavior of the MatchFunc. +func UnaryServerInterceptor(interceptors grpc.UnaryServerInterceptor, match MatchFunc) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + if match(ctx, info.FullMethod) { + return interceptors(ctx, req, info, handler) + } + return handler(ctx, req) + } +} + +// StreamServerInterceptor returns a new stream server interceptor that will decide whether to call +// the interceptor on the behavior of the MatchFunc. +func StreamServerInterceptor(interceptors grpc.StreamServerInterceptor, match MatchFunc) grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if match(ss.Context(), info.FullMethod) { + return interceptors(srv, ss, info, handler) + } + return handler(srv, ss) + } +} diff --git a/interceptors/selector/selector_example_test.go b/interceptors/selector/selector_example_test.go new file mode 100644 index 000000000..a5c95f754 --- /dev/null +++ b/interceptors/selector/selector_example_test.go @@ -0,0 +1,84 @@ +// Copyright (c) The go-grpc-middleware Authors. +// Licensed under the Apache License 2.0. + +package selector_test + +import ( + "context" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector" +) + +// alwaysPassLimiter is an example limiter which implements Limiter interface. +// It does not limit any request because Limit function always returns false. +type alwaysPassLimiter struct{} + +func (*alwaysPassLimiter) Limit(_ context.Context) error { + return nil +} + +func healthSkip(ctx context.Context, fullMethod string) bool { + return fullMethod != "/ping.v1.PingService/Health" +} + +func Example_ratelimit() { + limiter := &alwaysPassLimiter{} + _ = grpc.NewServer( + grpc.ChainUnaryInterceptor( + selector.UnaryServerInterceptor(ratelimit.UnaryServerInterceptor(limiter), healthSkip), + ), + grpc.ChainStreamInterceptor( + selector.StreamServerInterceptor(ratelimit.StreamServerInterceptor(limiter), healthSkip), + ), + ) +} + +var tokenInfoKey struct{} + +func parseToken(token string) (struct{}, error) { + return struct{}{}, nil +} + +func userClaimFromToken(struct{}) string { + return "foobar" +} + +// exampleAuthFunc is used by a middleware to authenticate requests +func exampleAuthFunc(ctx context.Context) (context.Context, error) { + token, err := auth.AuthFromMD(ctx, "bearer") + if err != nil { + return nil, err + } + + tokenInfo, err := parseToken(token) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "invalid auth token: %v", err) + } + + ctx = logging.InjectFields(ctx, logging.Fields{"auth.sub", userClaimFromToken(tokenInfo)}) + + // WARNING: In production define your own type to avoid context collisions. + return context.WithValue(ctx, tokenInfoKey, tokenInfo), nil +} + +func loginSkip(ctx context.Context, fullMethod string) bool { + return fullMethod != "/auth.v1.AuthService/Login" +} + +func Example_login() { + _ = grpc.NewServer( + grpc.ChainUnaryInterceptor( + selector.UnaryServerInterceptor(auth.UnaryServerInterceptor(exampleAuthFunc), loginSkip), + ), + grpc.ChainStreamInterceptor( + selector.StreamServerInterceptor(auth.StreamServerInterceptor(exampleAuthFunc), loginSkip), + ), + ) +} diff --git a/interceptors/selector/selector_test.go b/interceptors/selector/selector_test.go new file mode 100644 index 000000000..1bb15aa4a --- /dev/null +++ b/interceptors/selector/selector_test.go @@ -0,0 +1,157 @@ +// Copyright (c) The go-grpc-middleware Authors. +// Licensed under the Apache License 2.0. + +package selector + +import ( + "context" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" +) + +var blockList = []string{"/auth.v1beta1.AuthService/Login"} + +const errMsgFake = "fake error" + +var ctxKey = struct{}{} + +// allow After the method is matched, the interceptor is run +func allow(methods []string) MatchFunc { + return func(ctx context.Context, fullMethod string) bool { + for _, s := range methods { + if s == fullMethod { + return true + } + } + return false + } +} + +// Block the interceptor will not run after the method matches +func block(methods []string) MatchFunc { + allow := allow(methods) + return func(ctx context.Context, fullMethod string) bool { + return !allow(ctx, fullMethod) + } +} + +type mockGRPCServerStream struct { + grpc.ServerStream + + ctx context.Context +} + +func (m *mockGRPCServerStream) Context() context.Context { + return m.ctx +} + +func TestUnaryServerInterceptor(t *testing.T) { + ctx := context.Background() + interceptor := UnaryServerInterceptor(auth.UnaryServerInterceptor( + func(ctx context.Context) (context.Context, error) { + newCtx := context.WithValue(ctx, ctxKey, true) + return newCtx, nil + }, + ), block(blockList)) + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + val := ctx.Value(ctxKey) + if b, ok := val.(bool); ok && b { + return "good", nil + } + return nil, errors.New(errMsgFake) + } + + t.Run("nextStep", func(t *testing.T) { + info := &grpc.UnaryServerInfo{ + FullMethod: "FakeMethod", + } + resp, err := interceptor(ctx, nil, info, handler) + assert.Nil(t, err) + assert.Equal(t, resp, "good") + }) + + t.Run("skipped", func(t *testing.T) { + info := &grpc.UnaryServerInfo{ + FullMethod: "/auth.v1beta1.AuthService/Login", + } + resp, err := interceptor(ctx, nil, info, handler) + assert.Nil(t, resp) + assert.EqualError(t, err, errMsgFake) + }) +} + +func TestStreamServerInterceptor(t *testing.T) { + ctx := context.Background() + interceptor := StreamServerInterceptor(auth.StreamServerInterceptor( + func(ctx context.Context) (context.Context, error) { + newCtx := context.WithValue(ctx, ctxKey, true) + return newCtx, nil + }, + ), block(blockList)) + + handler := func(srv interface{}, stream grpc.ServerStream) error { + ctx := stream.Context() + val := ctx.Value(ctxKey) + if b, ok := val.(bool); ok && b { + return nil + } + return errors.New(errMsgFake) + } + + t.Run("nextStep", func(t *testing.T) { + info := &grpc.StreamServerInfo{ + FullMethod: "FakeMethod", + } + err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, info, handler) + assert.Nil(t, err) + }) + + t.Run("skipped", func(t *testing.T) { + info := &grpc.StreamServerInfo{ + FullMethod: "/auth.v1beta1.AuthService/Login", + } + err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, info, handler) + assert.EqualError(t, err, errMsgFake) + }) +} + +func TestAllow(t *testing.T) { + type args struct { + methods []string + } + tests := []struct { + name string + args args + method string + want bool + }{ + { + name: "false", + args: args{ + methods: []string{"/auth.v1beta1.AuthService/Login"}, + }, + method: "/testing.testpb.v1.TestService/PingList", + want: false, + }, + { + name: "true", + args: args{ + methods: []string{"/auth.v1beta1.AuthService/Login"}, + }, + method: "/auth.v1beta1.AuthService/Login", + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + allow := allow(tt.args.methods) + want := allow(context.Background(), tt.method) + assert.Equalf(t, tt.want, want, "Allow(%v)(ctx, %v)", tt.args.methods, tt.method) + }) + } +}