diff --git a/rpc_util.go b/rpc_util.go index c8ae0e4444c7..6db356fa56a7 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -429,9 +429,10 @@ func (o ContentSubtypeCallOption) before(c *callInfo) error { } func (o ContentSubtypeCallOption) after(c *callInfo, attempt *csAttempt) {} -// ForceCodec returns a CallOption that will set codec to be -// used for all request and response messages for a call. The result of calling -// Name() will be used as the content-subtype in a case-insensitive manner. +// ForceCodec returns a CallOption that will set codec to be used for all +// request and response messages for a call. The result of calling Name() will +// be used as the content-subtype after converting to lowercase, unless +// CallContentSubtype is also used. // // See Content-Type on // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for @@ -853,7 +854,17 @@ func toRPCErr(err error) error { // setCallInfoCodec should only be called after CallOptions have been applied. func setCallInfoCodec(c *callInfo) error { if c.codec != nil { - // codec was already set by a CallOption; use it. + // codec was already set by a CallOption; use it, but set the content + // subtype if it is not set. + if c.contentSubtype == "" { + // c.codec is a baseCodec to hide the difference between grpc.Codec and + // encoding.Codec (Name vs. String method name). We only support + // setting content subtype from encoding.Codec to avoid a behavior + // change with the deprecated version. + if ec, ok := c.codec.(encoding.Codec); ok { + c.contentSubtype = strings.ToLower(ec.Name()) + } + } return nil } diff --git a/test/end2end_test.go b/test/end2end_test.go index 861a2f29623d..eb91d09afdf0 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5293,7 +5293,7 @@ func (s) TestGRPCMethod(t *testing.T) { } defer ss.Stop() - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { @@ -5305,6 +5305,55 @@ func (s) TestGRPCMethod(t *testing.T) { } } +// renameProtoCodec is an encoding.Codec wrapper that allows customizing the +// Name() of another codec. +type renameProtoCodec struct { + encoding.Codec + name string +} + +func (r *renameProtoCodec) Name() string { return r.name } + +// TestForceCodecName confirms that the ForceCodec call option sets the subtype +// in the content-type header according to the Name() of the codec provided. +func (s) TestForceCodecName(t *testing.T) { + wantContentTypeCh := make(chan []string, 1) + defer close(wantContentTypeCh) + + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Internal, "no metadata in context") + } + if got, want := md["content-type"], <-wantContentTypeCh; !reflect.DeepEqual(got, want) { + return nil, status.Errorf(codes.Internal, "got content-type=%q; want [%q]", got, want) + } + return &testpb.Empty{}, nil + }, + } + if err := ss.Start([]grpc.ServerOption{grpc.ForceServerCodec(encoding.GetCodec("proto"))}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + codec := &renameProtoCodec{Codec: encoding.GetCodec("proto"), name: "some-test-name"} + wantContentTypeCh <- []string{"application/grpc+some-test-name"} + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) + } + + // Confirm the name is converted to lowercase before transmitting. + codec.name = "aNoTHeRNaME" + wantContentTypeCh <- []string{"application/grpc+anothername"} + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) + } +} + func (s) TestForceServerCodec(t *testing.T) { ss := &stubserver.StubServer{ EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { @@ -5317,7 +5366,7 @@ func (s) TestForceServerCodec(t *testing.T) { } defer ss.Stop() - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {