Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions drpcmux/handle_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
package drpcmux

import (
"context"
"reflect"

"github.com/zeebo/errs"

"storj.io/drpc"
)

Expand All @@ -30,7 +30,21 @@ func (m *Mux) HandleRPC(stream drpc.Stream, rpc string) (err error) {
in = msg
}

out, err := data.receiver(data.srv, stream.Context(), in, stream)
var out drpc.Message
if data.unitary && m.unaryInterceptor != nil {
out, err = m.unaryInterceptor(stream.Context(), in, rpc,
func(ctx context.Context, req interface{}) (interface{}, error) {
return data.receiver(data.srv, ctx, req, stream)
})
} else if !data.unitary && m.streamInterceptor != nil {
out, err = m.streamInterceptor(stream.Context(), stream, rpc,
func(ctx context.Context, st drpc.Stream) (interface{}, error) {
return data.receiver(data.srv, ctx, st, stream)
})
} else {
out, err = data.receiver(data.srv, stream.Context(), in, stream)
}

switch {
case err != nil:
return errs.Wrap(err)
Expand Down
81 changes: 81 additions & 0 deletions drpcmux/interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package drpcmux

import (
"context"

"storj.io/drpc"
)

// UnaryHandler defines the handler for the unary RPC.
type UnaryHandler func(ctx context.Context, in interface{}) (out interface{}, err error)

// UnaryServerInterceptor defines the server side interceptor for unary RPC.
type UnaryServerInterceptor func(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comments and contract for the interceptors.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

ctx context.Context, req interface{}, rpc string, handler UnaryHandler) (out interface{}, err error)

func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor {
switch n := len(interceptors); n {
case 0:
return nil
case 1:
return interceptors[0]
default:
return func(ctx context.Context, req interface{}, rpc string, handler UnaryHandler) (
out interface{}, err error,
) {
return interceptors[0](
ctx, req, rpc, getChainedUnaryHandler(interceptors, 1, rpc, handler),
)
}
}
}

func getChainedUnaryHandler(
interceptors []UnaryServerInterceptor, currIdx int, rpc string, handler UnaryHandler,
) UnaryHandler {
if currIdx == len(interceptors) {
return handler
}
return func(ctx context.Context, in interface{}) (out interface{}, err error) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about extracting this into asUnaryHandler(i UnaryInterceptor) UnaryHandler function? I find it more readable and easy to understand.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's a bit more complicated. We're not just adapting the interceptor as a handler; we're also capturing arguments from getChainedUnaryHandler. This means if a method is declared as asUnaryHandler, it will still accept those arguments, making the cleanup pointless.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you are right. I would suggest adding a comment on what that code is meant to do for future reference.

return interceptors[currIdx](
ctx, in, rpc, getChainedUnaryHandler(interceptors, currIdx+1, rpc, handler),
)
}
}

// StreamHandler defines the handler for the stream RPC.
type StreamHandler func(ctx context.Context, in drpc.Stream) (out interface{}, err error)

// StreamServerInterceptor defines a server side interceptor for unary RPC.
type StreamServerInterceptor func(
ctx context.Context, stream drpc.Stream, rpc string, handler StreamHandler) (out interface{}, err error)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like we discussed before, rpc string can be replaced with StreamServerInfo similar to gRPC. Since stream interceptors apply to all forms of stream RPCs (unary request + stream response, stream request + unary response, stream request + stream response), this metadata will be useful for writing an interceptor that only applies to one form of stream interceptor.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm in favor of adding this, but I'm waiting to come across any use case where we need more than an RPC string.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not necessarily suggesting to block on this but IMO interfaces should also help with future needs. I'm fine with the current approach and revisit it if needed.


func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor {
switch n := len(interceptors); n {
case 0:
return nil
case 1:
return interceptors[0]
default:
return func(ctx context.Context, stream drpc.Stream, rpc string, handler StreamHandler) (
out interface{}, err error,
) {
return interceptors[0](
ctx, stream, rpc, getChainedStreamHandler(interceptors, 1, rpc, handler),
)
}
}
}

func getChainedStreamHandler(
interceptors []StreamServerInterceptor, currIdx int, rpc string, handler StreamHandler,
) StreamHandler {
if currIdx == len(interceptors) {
return handler
}
return func(ctx context.Context, in drpc.Stream) (out interface{}, err error) {
return interceptors[currIdx](
ctx, in, rpc, getChainedStreamHandler(interceptors, currIdx+1, rpc, handler),
)
}
}
19 changes: 17 additions & 2 deletions drpcmux/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,27 @@ import (
"reflect"

"github.com/zeebo/errs"

"storj.io/drpc"
)

// Mux is an implementation of Handler to serve drpc connections to the
// appropriate Receivers registered by Descriptions.
type Mux struct {
rpcs map[string]rpcData

unaryInterceptor UnaryServerInterceptor
streamInterceptor StreamServerInterceptor
}

// NewWithInterceptors constructs a new Mux with the provided unary and stream server interceptors.
func NewWithInterceptors(
unaryInterceptors []UnaryServerInterceptor, streamInterceptors []StreamServerInterceptor,
) *Mux {
return &Mux{
rpcs: make(map[string]rpcData),
unaryInterceptor: chainUnaryInterceptors(unaryInterceptors),
streamInterceptor: chainStreamInterceptors(streamInterceptors),
}
}

// New constructs a new Mux.
Expand Down Expand Up @@ -55,7 +68,9 @@ func (m *Mux) Register(srv interface{}, desc drpc.Description) error {
}

// registerOne does the work to register a single rpc.
func (m *Mux) registerOne(srv interface{}, rpc string, enc drpc.Encoding, receiver drpc.Receiver, method interface{}) error {
func (m *Mux) registerOne(
srv interface{}, rpc string, enc drpc.Encoding, receiver drpc.Receiver, method interface{},
) error {
data := rpcData{srv: srv, enc: enc, receiver: receiver}

switch mt := reflect.TypeOf(method); {
Expand Down