diff --git a/drpcmux/handle_rpc.go b/drpcmux/handle_rpc.go index 9bc6c12..a8cb2cc 100644 --- a/drpcmux/handle_rpc.go +++ b/drpcmux/handle_rpc.go @@ -4,10 +4,10 @@ package drpcmux import ( + "context" "reflect" "github.com/zeebo/errs" - "storj.io/drpc" ) @@ -30,7 +30,21 @@ func (m *Mux) HandleRPC(stream drpc.Stream, rpc string) (err error) { in = msg } - out, err := data.receiver(data.srv, stream.Context(), in, stream) + var out drpc.Message + if data.unitary && m.unaryInterceptor != nil { + out, err = m.unaryInterceptor(stream.Context(), in, rpc, + func(ctx context.Context, req interface{}) (interface{}, error) { + return data.receiver(data.srv, ctx, req, stream) + }) + } else if !data.unitary && m.streamInterceptor != nil { + out, err = m.streamInterceptor(stream.Context(), stream, rpc, + func(ctx context.Context, st drpc.Stream) (interface{}, error) { + return data.receiver(data.srv, ctx, st, stream) + }) + } else { + out, err = data.receiver(data.srv, stream.Context(), in, stream) + } + switch { case err != nil: return errs.Wrap(err) diff --git a/drpcmux/interceptor.go b/drpcmux/interceptor.go new file mode 100644 index 0000000..c649dbd --- /dev/null +++ b/drpcmux/interceptor.go @@ -0,0 +1,81 @@ +package drpcmux + +import ( + "context" + + "storj.io/drpc" +) + +// UnaryHandler defines the handler for the unary RPC. +type UnaryHandler func(ctx context.Context, in interface{}) (out interface{}, err error) + +// UnaryServerInterceptor defines the server side interceptor for unary RPC. +type UnaryServerInterceptor func( + ctx context.Context, req interface{}, rpc string, handler UnaryHandler) (out interface{}, err error) + +func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor { + switch n := len(interceptors); n { + case 0: + return nil + case 1: + return interceptors[0] + default: + return func(ctx context.Context, req interface{}, rpc string, handler UnaryHandler) ( + out interface{}, err error, + ) { + return interceptors[0]( + ctx, req, rpc, getChainedUnaryHandler(interceptors, 1, rpc, handler), + ) + } + } +} + +func getChainedUnaryHandler( + interceptors []UnaryServerInterceptor, currIdx int, rpc string, handler UnaryHandler, +) UnaryHandler { + if currIdx == len(interceptors) { + return handler + } + return func(ctx context.Context, in interface{}) (out interface{}, err error) { + return interceptors[currIdx]( + ctx, in, rpc, getChainedUnaryHandler(interceptors, currIdx+1, rpc, handler), + ) + } +} + +// StreamHandler defines the handler for the stream RPC. +type StreamHandler func(ctx context.Context, in drpc.Stream) (out interface{}, err error) + +// StreamServerInterceptor defines a server side interceptor for unary RPC. +type StreamServerInterceptor func( + ctx context.Context, stream drpc.Stream, rpc string, handler StreamHandler) (out interface{}, err error) + +func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor { + switch n := len(interceptors); n { + case 0: + return nil + case 1: + return interceptors[0] + default: + return func(ctx context.Context, stream drpc.Stream, rpc string, handler StreamHandler) ( + out interface{}, err error, + ) { + return interceptors[0]( + ctx, stream, rpc, getChainedStreamHandler(interceptors, 1, rpc, handler), + ) + } + } +} + +func getChainedStreamHandler( + interceptors []StreamServerInterceptor, currIdx int, rpc string, handler StreamHandler, +) StreamHandler { + if currIdx == len(interceptors) { + return handler + } + return func(ctx context.Context, in drpc.Stream) (out interface{}, err error) { + return interceptors[currIdx]( + ctx, in, rpc, getChainedStreamHandler(interceptors, currIdx+1, rpc, handler), + ) + } +} diff --git a/drpcmux/mux.go b/drpcmux/mux.go index 50ca443..bfca769 100644 --- a/drpcmux/mux.go +++ b/drpcmux/mux.go @@ -7,7 +7,6 @@ import ( "reflect" "github.com/zeebo/errs" - "storj.io/drpc" ) @@ -15,6 +14,20 @@ import ( // appropriate Receivers registered by Descriptions. type Mux struct { rpcs map[string]rpcData + + unaryInterceptor UnaryServerInterceptor + streamInterceptor StreamServerInterceptor +} + +// NewWithInterceptors constructs a new Mux with the provided unary and stream server interceptors. +func NewWithInterceptors( + unaryInterceptors []UnaryServerInterceptor, streamInterceptors []StreamServerInterceptor, +) *Mux { + return &Mux{ + rpcs: make(map[string]rpcData), + unaryInterceptor: chainUnaryInterceptors(unaryInterceptors), + streamInterceptor: chainStreamInterceptors(streamInterceptors), + } } // New constructs a new Mux. @@ -55,7 +68,9 @@ func (m *Mux) Register(srv interface{}, desc drpc.Description) error { } // registerOne does the work to register a single rpc. -func (m *Mux) registerOne(srv interface{}, rpc string, enc drpc.Encoding, receiver drpc.Receiver, method interface{}) error { +func (m *Mux) registerOne( + srv interface{}, rpc string, enc drpc.Encoding, receiver drpc.Receiver, method interface{}, +) error { data := rpcData{srv: srv, enc: enc, receiver: receiver} switch mt := reflect.TypeOf(method); {