Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,10 @@ func (o ContentSubtypeCallOption) before(c *callInfo) error {
}
func (o ContentSubtypeCallOption) after(c *callInfo, attempt *csAttempt) {}

// ForceCodec returns a CallOption that will set codec to be
// used for all request and response messages for a call. The result of calling
// Name() will be used as the content-subtype in a case-insensitive manner.
// ForceCodec returns a CallOption that will set codec to be used for all
// request and response messages for a call. The result of calling Name() will
// be used as the content-subtype after converting to lowercase, unless
// CallContentSubtype is also used.
//
// See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
Expand Down Expand Up @@ -853,7 +854,17 @@ func toRPCErr(err error) error {
// setCallInfoCodec should only be called after CallOptions have been applied.
func setCallInfoCodec(c *callInfo) error {
if c.codec != nil {
// codec was already set by a CallOption; use it.
// codec was already set by a CallOption; use it, but set the content
// subtype if it is not set.
if c.contentSubtype == "" {
// c.codec is a baseCodec to hide the difference between grpc.Codec and
// encoding.Codec (Name vs. String method name). We only support
// setting content subtype from encoding.Codec to avoid a behavior
// change with the deprecated version.
if ec, ok := c.codec.(encoding.Codec); ok {
c.contentSubtype = strings.ToLower(ec.Name())
}
}
return nil
}

Expand Down
53 changes: 51 additions & 2 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5293,7 +5293,7 @@ func (s) TestGRPCMethod(t *testing.T) {
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
Expand All @@ -5305,6 +5305,55 @@ func (s) TestGRPCMethod(t *testing.T) {
}
}

// renameProtoCodec is an encoding.Codec wrapper that allows customizing the
// Name() of another codec.
type renameProtoCodec struct {
encoding.Codec
name string
}

func (r *renameProtoCodec) Name() string { return r.name }

// TestForceCodecName confirms that the ForceCodec call option sets the subtype
// in the content-type header according to the Name() of the codec provided.
func (s) TestForceCodecName(t *testing.T) {
wantContentTypeCh := make(chan []string, 1)
defer close(wantContentTypeCh)

ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Internal, "no metadata in context")
}
if got, want := md["content-type"], <-wantContentTypeCh; !reflect.DeepEqual(got, want) {
return nil, status.Errorf(codes.Internal, "got content-type=%q; want [%q]", got, want)
}
return &testpb.Empty{}, nil
},
}
if err := ss.Start([]grpc.ServerOption{grpc.ForceServerCodec(encoding.GetCodec("proto"))}); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

codec := &renameProtoCodec{Codec: encoding.GetCodec("proto"), name: "some-test-name"}
wantContentTypeCh <- []string{"application/grpc+some-test-name"}
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
}

// Confirm the name is converted to lowercase before transmitting.
codec.name = "aNoTHeRNaME"
wantContentTypeCh <- []string{"application/grpc+anothername"}
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
}
}

func (s) TestForceServerCodec(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
Expand All @@ -5317,7 +5366,7 @@ func (s) TestForceServerCodec(t *testing.T) {
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
Expand Down