diff --git a/gateway/field_presence.go b/gateway/field_presence.go index e531cd98..c3143585 100644 --- a/gateway/field_presence.go +++ b/gateway/field_presence.go @@ -68,7 +68,7 @@ func NewPresenceAnnotator(methods ...string) func(context.Context, *http.Request if m, ok := item.node.(map[string]interface{}); ok { // if the item is an object, then enqueue all of its children for k, v := range m { - newPath := extendPath(item.path, k, v) + newPath := extendPath(item.path, k) queue = append(queue, pathItem{path: newPath, node: v}) } } @@ -87,9 +87,9 @@ func NewPresenceAnnotator(methods ...string) func(context.Context, *http.Request } } -func extendPath(parrent []string, key string, value interface{}) []string { - newPath := make([]string, len(parrent)+1) - copy(newPath, parrent) +func extendPath(parent []string, key string) []string { + newPath := make([]string, len(parent)+1) + copy(newPath, parent) newPath[len(newPath)-1] = generator.CamelCase(key) return newPath } @@ -145,9 +145,26 @@ type pathItem struct { node interface{} } +type presenceInterceptorOptionsDecorator struct { + overrideFieldMask bool +} + +type presenceInterceptorOption func(*presenceInterceptorOptionsDecorator) + +//WithOverrideFieldMask represent an option to override field mask generated by grpc-gateway +func WithOverrideFieldMask(d *presenceInterceptorOptionsDecorator) { + d.overrideFieldMask = true +} + // PresenceClientInterceptor gets the interceptor for populating a fieldmask in a // proto message from the fields given in the metadata/context -func PresenceClientInterceptor() grpc.UnaryClientInterceptor { +func PresenceClientInterceptor(options ...presenceInterceptorOption) grpc.UnaryClientInterceptor { + ops := &presenceInterceptorOptionsDecorator{} + + for _, op := range options { + op(ops) + } + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) { defer func() { err = invoker(ctx, method, req, reply, cc, opts...) @@ -175,8 +192,12 @@ func PresenceClientInterceptor() grpc.UnaryClientInterceptor { for i := 0; i < t.NumField(); i++ { f := t.Field(i) - if f.Type() == reflect.TypeOf(fieldMask) && f.IsNil() { - f.Set(reflect.ValueOf(fieldMask)) + if f.Type() == reflect.TypeOf(fieldMask) { + if ops.overrideFieldMask || f.IsNil() { + f.Set(reflect.ValueOf(fieldMask)) + } + + return } } return diff --git a/gateway/field_presence_test.go b/gateway/field_presence_test.go index 4bcb6da4..968b2544 100644 --- a/gateway/field_presence_test.go +++ b/gateway/field_presence_test.go @@ -157,3 +157,100 @@ func TestUnaryServerInterceptor(t *testing.T) { } }) } + +type RequestWithFieldMask struct { + FieldMask *field_mask.FieldMask +} + +func TestOverrideFieldMaskOption(t *testing.T) { + defaultInvoker := func(context.Context, string, interface{}, interface{}, *grpc.ClientConn, ...grpc.CallOption) error { + return nil + } + + f := func(ctx context.Context, req *RequestWithFieldMask, expected *field_mask.FieldMask, overrideEnabled bool) { + interceptor := PresenceClientInterceptor() + if overrideEnabled { + interceptor = PresenceClientInterceptor(WithOverrideFieldMask) + } + + if err := interceptor(ctx, "", req, nil, nil, defaultInvoker); err != nil { + t.Logf("Unexpected error %s\n", err) + t.Fail() + } + + result := req.FieldMask + + if !isEqualFieldMasks(expected.Paths, result.Paths) { + t.Logf("Unexpected result field mask, expect %+v, got %+v\n", expected.Paths, result.Paths) + t.Fail() + } + + } + + r1 := &RequestWithFieldMask{FieldMask: &field_mask.FieldMask{Paths: []string{"One"}}} + f(context.Background(), r1, &field_mask.FieldMask{Paths: []string{"One"}}, true) + + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD{fieldPresenceMetaKey: []string{"One"}}) + r2 := &RequestWithFieldMask{} + f(ctx, r2, &field_mask.FieldMask{Paths: []string{"One"}}, true) + + ctx = metadata.NewIncomingContext(context.Background(), metadata.MD{fieldPresenceMetaKey: []string{"Two"}}) + r3 := &RequestWithFieldMask{FieldMask: &field_mask.FieldMask{Paths: []string{"One"}}} + f(ctx, r3, &field_mask.FieldMask{Paths: []string{"Two"}}, true) + + ctx = metadata.NewIncomingContext(context.Background(), metadata.MD{fieldPresenceMetaKey: []string{"Two"}}) + r4 := &RequestWithFieldMask{FieldMask: &field_mask.FieldMask{Paths: []string{"One"}}} + f(ctx, r4, &field_mask.FieldMask{Paths: []string{"One"}}, false) +} + +type RequestWithMultiFieldMask struct { + FieldMasks []*field_mask.FieldMask +} + +func TestOverrideMultipleFieldMasksOption(t *testing.T) { + defaultInvoker := func(context.Context, string, interface{}, interface{}, *grpc.ClientConn, ...grpc.CallOption) error { + return nil + } + + f := func(ctx context.Context, req *RequestWithMultiFieldMask, expected []*field_mask.FieldMask, overrideEnabled bool) { + interceptor := PresenceClientInterceptor() + if overrideEnabled { + interceptor = PresenceClientInterceptor(WithOverrideFieldMask) + } + + if err := interceptor(ctx, "", req, nil, nil, defaultInvoker); err != nil { + t.Logf("Unexpected error %s\n", err) + t.Fail() + } + + result := req.FieldMasks + if len(result) != len(expected) { + t.Logf("Unexpected field masks expect %+v, got %+v", expected, result) + t.Fail() + return + } + + for i := 0; i < len(expected); i++ { + if !isEqualFieldMasks(expected[i].Paths, result[i].Paths) { + t.Logf("Unexpected result field mask on index %d, expect %+v, got %+v\n", i, expected[0].Paths, result[0].Paths) + t.Fail() + } + } + + } + + r1 := &RequestWithMultiFieldMask{FieldMasks: []*field_mask.FieldMask{{Paths: []string{"One"}}, {Paths: []string{"Two"}}}} + f(context.Background(), r1, []*field_mask.FieldMask{{Paths: []string{"One"}}, {Paths: []string{"Two"}}}, true) + + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD{fieldPresenceMetaKey: []string{"One", "Two"}}) + r2 := &RequestWithMultiFieldMask{} + f(ctx, r2, []*field_mask.FieldMask{{Paths: []string{"One"}}, {Paths: []string{"Two"}}}, true) + + ctx = metadata.NewIncomingContext(context.Background(), metadata.MD{fieldPresenceMetaKey: []string{"Four", "Five"}}) + r3 := &RequestWithMultiFieldMask{FieldMasks: []*field_mask.FieldMask{{Paths: []string{"One"}}, {Paths: []string{"Two"}}}} + f(ctx, r3, []*field_mask.FieldMask{{Paths: []string{"Four"}}, {Paths: []string{"Five"}}}, true) + + ctx = metadata.NewIncomingContext(context.Background(), metadata.MD{fieldPresenceMetaKey: []string{"Four", "Five"}}) + r4 := &RequestWithMultiFieldMask{FieldMasks: []*field_mask.FieldMask{{Paths: []string{"One"}}, {Paths: []string{"Two"}}}} + f(ctx, r4, []*field_mask.FieldMask{{Paths: []string{"One"}}, {Paths: []string{"Two"}}}, false) +}