Skip to content

Commit 776edd3

Browse files
WeiranFangmenghanl
authored andcommitted
interceptor: new APIs for chaining client interceptors. (#2696)
1 parent a9de79b commit 776edd3

File tree

3 files changed

+296
-3
lines changed

3 files changed

+296
-3
lines changed

Diff for: call_test.go

+203-1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ type server struct {
123123
conns map[transport.ServerTransport]bool
124124
}
125125

126+
type ctxKey string
127+
126128
func newTestServer() *server {
127129
return &server{startedErr: make(chan error, 1)}
128130
}
@@ -202,17 +204,217 @@ func (s *server) stop() {
202204
}
203205

204206
func setUp(t *testing.T, port int, maxStreams uint32) (*server, *ClientConn) {
207+
return setUpWithOptions(t, port, maxStreams)
208+
}
209+
210+
func setUpWithOptions(t *testing.T, port int, maxStreams uint32, dopts ...DialOption) (*server, *ClientConn) {
205211
server := newTestServer()
206212
go server.start(t, port, maxStreams)
207213
server.wait(t, 2*time.Second)
208214
addr := "localhost:" + server.port
209-
cc, err := Dial(addr, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
215+
dopts = append(dopts, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
216+
cc, err := Dial(addr, dopts...)
210217
if err != nil {
211218
t.Fatalf("Failed to create ClientConn: %v", err)
212219
}
213220
return server, cc
214221
}
215222

223+
func (s) TestUnaryClientInterceptor(t *testing.T) {
224+
parentKey := ctxKey("parentKey")
225+
226+
interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
227+
if ctx.Value(parentKey) == nil {
228+
t.Fatalf("interceptor should have %v in context", parentKey)
229+
}
230+
return invoker(ctx, method, req, reply, cc, opts...)
231+
}
232+
233+
server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(interceptor))
234+
defer func() {
235+
cc.Close()
236+
server.stop()
237+
}()
238+
239+
var reply string
240+
ctx := context.Background()
241+
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
242+
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
243+
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
244+
}
245+
}
246+
247+
func (s) TestChainUnaryClientInterceptor(t *testing.T) {
248+
var (
249+
parentKey = ctxKey("parentKey")
250+
firstIntKey = ctxKey("firstIntKey")
251+
secondIntKey = ctxKey("secondIntKey")
252+
)
253+
254+
firstInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
255+
if ctx.Value(parentKey) == nil {
256+
t.Fatalf("first interceptor should have %v in context", parentKey)
257+
}
258+
if ctx.Value(firstIntKey) != nil {
259+
t.Fatalf("first interceptor should not have %v in context", firstIntKey)
260+
}
261+
if ctx.Value(secondIntKey) != nil {
262+
t.Fatalf("first interceptor should not have %v in context", secondIntKey)
263+
}
264+
firstCtx := context.WithValue(ctx, firstIntKey, 1)
265+
err := invoker(firstCtx, method, req, reply, cc, opts...)
266+
*(reply.(*string)) += "1"
267+
return err
268+
}
269+
270+
secondInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
271+
if ctx.Value(parentKey) == nil {
272+
t.Fatalf("second interceptor should have %v in context", parentKey)
273+
}
274+
if ctx.Value(firstIntKey) == nil {
275+
t.Fatalf("second interceptor should have %v in context", firstIntKey)
276+
}
277+
if ctx.Value(secondIntKey) != nil {
278+
t.Fatalf("second interceptor should not have %v in context", secondIntKey)
279+
}
280+
secondCtx := context.WithValue(ctx, secondIntKey, 2)
281+
err := invoker(secondCtx, method, req, reply, cc, opts...)
282+
*(reply.(*string)) += "2"
283+
return err
284+
}
285+
286+
lastInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
287+
if ctx.Value(parentKey) == nil {
288+
t.Fatalf("last interceptor should have %v in context", parentKey)
289+
}
290+
if ctx.Value(firstIntKey) == nil {
291+
t.Fatalf("last interceptor should have %v in context", firstIntKey)
292+
}
293+
if ctx.Value(secondIntKey) == nil {
294+
t.Fatalf("last interceptor should have %v in context", secondIntKey)
295+
}
296+
err := invoker(ctx, method, req, reply, cc, opts...)
297+
*(reply.(*string)) += "3"
298+
return err
299+
}
300+
301+
server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainUnaryInterceptor(firstInt, secondInt, lastInt))
302+
defer func() {
303+
cc.Close()
304+
server.stop()
305+
}()
306+
307+
var reply string
308+
ctx := context.Background()
309+
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
310+
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse+"321" {
311+
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
312+
}
313+
}
314+
315+
func (s) TestChainOnBaseUnaryClientInterceptor(t *testing.T) {
316+
var (
317+
parentKey = ctxKey("parentKey")
318+
baseIntKey = ctxKey("baseIntKey")
319+
)
320+
321+
baseInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
322+
if ctx.Value(parentKey) == nil {
323+
t.Fatalf("base interceptor should have %v in context", parentKey)
324+
}
325+
if ctx.Value(baseIntKey) != nil {
326+
t.Fatalf("base interceptor should not have %v in context", baseIntKey)
327+
}
328+
baseCtx := context.WithValue(ctx, baseIntKey, 1)
329+
return invoker(baseCtx, method, req, reply, cc, opts...)
330+
}
331+
332+
chainInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
333+
if ctx.Value(parentKey) == nil {
334+
t.Fatalf("chain interceptor should have %v in context", parentKey)
335+
}
336+
if ctx.Value(baseIntKey) == nil {
337+
t.Fatalf("chain interceptor should have %v in context", baseIntKey)
338+
}
339+
return invoker(ctx, method, req, reply, cc, opts...)
340+
}
341+
342+
server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(baseInt), WithChainUnaryInterceptor(chainInt))
343+
defer func() {
344+
cc.Close()
345+
server.stop()
346+
}()
347+
348+
var reply string
349+
ctx := context.Background()
350+
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
351+
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
352+
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
353+
}
354+
}
355+
356+
func (s) TestChainStreamClientInterceptor(t *testing.T) {
357+
var (
358+
parentKey = ctxKey("parentKey")
359+
firstIntKey = ctxKey("firstIntKey")
360+
secondIntKey = ctxKey("secondIntKey")
361+
)
362+
363+
firstInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
364+
if ctx.Value(parentKey) == nil {
365+
t.Fatalf("first interceptor should have %v in context", parentKey)
366+
}
367+
if ctx.Value(firstIntKey) != nil {
368+
t.Fatalf("first interceptor should not have %v in context", firstIntKey)
369+
}
370+
if ctx.Value(secondIntKey) != nil {
371+
t.Fatalf("first interceptor should not have %v in context", secondIntKey)
372+
}
373+
firstCtx := context.WithValue(ctx, firstIntKey, 1)
374+
return streamer(firstCtx, desc, cc, method, opts...)
375+
}
376+
377+
secondInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
378+
if ctx.Value(parentKey) == nil {
379+
t.Fatalf("second interceptor should have %v in context", parentKey)
380+
}
381+
if ctx.Value(firstIntKey) == nil {
382+
t.Fatalf("second interceptor should have %v in context", firstIntKey)
383+
}
384+
if ctx.Value(secondIntKey) != nil {
385+
t.Fatalf("second interceptor should not have %v in context", secondIntKey)
386+
}
387+
secondCtx := context.WithValue(ctx, secondIntKey, 2)
388+
return streamer(secondCtx, desc, cc, method, opts...)
389+
}
390+
391+
lastInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
392+
if ctx.Value(parentKey) == nil {
393+
t.Fatalf("last interceptor should have %v in context", parentKey)
394+
}
395+
if ctx.Value(firstIntKey) == nil {
396+
t.Fatalf("last interceptor should have %v in context", firstIntKey)
397+
}
398+
if ctx.Value(secondIntKey) == nil {
399+
t.Fatalf("last interceptor should have %v in context", secondIntKey)
400+
}
401+
return streamer(ctx, desc, cc, method, opts...)
402+
}
403+
404+
server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainStreamInterceptor(firstInt, secondInt, lastInt))
405+
defer func() {
406+
cc.Close()
407+
server.stop()
408+
}()
409+
410+
ctx := context.Background()
411+
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
412+
_, err := cc.NewStream(parentCtx, &StreamDesc{}, "/foo/bar")
413+
if err != nil {
414+
t.Fatalf("grpc.NewStream(_, _, _) = %v, want <nil>", err)
415+
}
416+
}
417+
216418
func (s) TestInvoke(t *testing.T) {
217419
server, cc := setUp(t, 0, math.MaxUint32)
218420
var reply string

Diff for: clientconn.go

+65
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
137137
opt.apply(&cc.dopts)
138138
}
139139

140+
chainUnaryClientInterceptors(cc)
141+
chainStreamClientInterceptors(cc)
142+
140143
defer func() {
141144
if err != nil {
142145
cc.Close()
@@ -327,6 +330,68 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
327330
return cc, nil
328331
}
329332

333+
// chainUnaryClientInterceptors chains all unary client interceptors into one.
334+
func chainUnaryClientInterceptors(cc *ClientConn) {
335+
interceptors := cc.dopts.chainUnaryInts
336+
// Prepend dopts.unaryInt to the chaining interceptors if it exists, since unaryInt will
337+
// be executed before any other chained interceptors.
338+
if cc.dopts.unaryInt != nil {
339+
interceptors = append([]UnaryClientInterceptor{cc.dopts.unaryInt}, interceptors...)
340+
}
341+
var chainedInt UnaryClientInterceptor
342+
if len(interceptors) == 0 {
343+
chainedInt = nil
344+
} else if len(interceptors) == 1 {
345+
chainedInt = interceptors[0]
346+
} else {
347+
chainedInt = func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
348+
return interceptors[0](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, 0, invoker), opts...)
349+
}
350+
}
351+
cc.dopts.unaryInt = chainedInt
352+
}
353+
354+
// getChainUnaryInvoker recursively generate the chained unary invoker.
355+
func getChainUnaryInvoker(interceptors []UnaryClientInterceptor, curr int, finalInvoker UnaryInvoker) UnaryInvoker {
356+
if curr == len(interceptors)-1 {
357+
return finalInvoker
358+
}
359+
return func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
360+
return interceptors[curr+1](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, curr+1, finalInvoker), opts...)
361+
}
362+
}
363+
364+
// chainStreamClientInterceptors chains all stream client interceptors into one.
365+
func chainStreamClientInterceptors(cc *ClientConn) {
366+
interceptors := cc.dopts.chainStreamInts
367+
// Prepend dopts.streamInt to the chaining interceptors if it exists, since streamInt will
368+
// be executed before any other chained interceptors.
369+
if cc.dopts.streamInt != nil {
370+
interceptors = append([]StreamClientInterceptor{cc.dopts.streamInt}, interceptors...)
371+
}
372+
var chainedInt StreamClientInterceptor
373+
if len(interceptors) == 0 {
374+
chainedInt = nil
375+
} else if len(interceptors) == 1 {
376+
chainedInt = interceptors[0]
377+
} else {
378+
chainedInt = func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
379+
return interceptors[0](ctx, desc, cc, method, getChainStreamer(interceptors, 0, streamer), opts...)
380+
}
381+
}
382+
cc.dopts.streamInt = chainedInt
383+
}
384+
385+
// getChainStreamer recursively generate the chained client stream constructor.
386+
func getChainStreamer(interceptors []StreamClientInterceptor, curr int, finalStreamer Streamer) Streamer {
387+
if curr == len(interceptors)-1 {
388+
return finalStreamer
389+
}
390+
return func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
391+
return interceptors[curr+1](ctx, desc, cc, method, getChainStreamer(interceptors, curr+1, finalStreamer), opts...)
392+
}
393+
}
394+
330395
// connectivityStateManager keeps the connectivity.State of ClientConn.
331396
// This struct will eventually be exported so the balancers can access it.
332397
type connectivityStateManager struct {

Diff for: dialoptions.go

+28-2
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,12 @@ import (
3939
// dialOptions configure a Dial call. dialOptions are set by the DialOption
4040
// values passed to Dial.
4141
type dialOptions struct {
42-
unaryInt UnaryClientInterceptor
43-
streamInt StreamClientInterceptor
42+
unaryInt UnaryClientInterceptor
43+
streamInt StreamClientInterceptor
44+
45+
chainUnaryInts []UnaryClientInterceptor
46+
chainStreamInts []StreamClientInterceptor
47+
4448
cp Compressor
4549
dc Decompressor
4650
bs backoff.Strategy
@@ -414,6 +418,17 @@ func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
414418
})
415419
}
416420

421+
// WithChainUnaryInterceptor returns a DialOption that specifies the chained
422+
// interceptor for unary RPCs. The first interceptor will be the outer most,
423+
// while the last interceptor will be the inner most wrapper around the real call.
424+
// All interceptors added by this method will be chained, and the interceptor
425+
// defined by WithUnaryInterceptor will always be prepended to the chain.
426+
func WithChainUnaryInterceptor(interceptors ...UnaryClientInterceptor) DialOption {
427+
return newFuncDialOption(func(o *dialOptions) {
428+
o.chainUnaryInts = append(o.chainUnaryInts, interceptors...)
429+
})
430+
}
431+
417432
// WithStreamInterceptor returns a DialOption that specifies the interceptor for
418433
// streaming RPCs.
419434
func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
@@ -422,6 +437,17 @@ func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
422437
})
423438
}
424439

440+
// WithChainStreamInterceptor returns a DialOption that specifies the chained
441+
// interceptor for unary RPCs. The first interceptor will be the outer most,
442+
// while the last interceptor will be the inner most wrapper around the real call.
443+
// All interceptors added by this method will be chained, and the interceptor
444+
// defined by WithStreamInterceptor will always be prepended to the chain.
445+
func WithChainStreamInterceptor(interceptors ...StreamClientInterceptor) DialOption {
446+
return newFuncDialOption(func(o *dialOptions) {
447+
o.chainStreamInts = append(o.chainStreamInts, interceptors...)
448+
})
449+
}
450+
425451
// WithAuthority returns a DialOption that specifies the value to be used as the
426452
// :authority pseudo-header. This value only works with WithInsecure and has no
427453
// effect if TransportCredentials are present.

0 commit comments

Comments
 (0)