Skip to content

Commit

Permalink
cherry-pick #7557 to v1.66.x branch (#7564)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfawley authored Aug 26, 2024
1 parent 62baa5f commit 8e3596c
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 133 deletions.
10 changes: 2 additions & 8 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,11 @@ type baseCodec interface {
// with encoding.GetCodec and if it is registered wraps it with newCodecV1Bridge
// to turn it into an encoding.CodecV2. Returns nil otherwise.
func getCodec(name string) encoding.CodecV2 {
codecV2 := encoding.GetCodecV2(name)
if codecV2 != nil {
return codecV2
}

codecV1 := encoding.GetCodec(name)
if codecV1 != nil {
if codecV1 := encoding.GetCodec(name); codecV1 != nil {
return newCodecV1Bridge(codecV1)
}

return nil
return encoding.GetCodecV2(name)
}

func newCodecV0Bridge(c Codec) baseCodec {
Expand Down
2 changes: 1 addition & 1 deletion codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

func (s) TestGetCodecForProtoIsNotNil(t *testing.T) {
if encoding.GetCodec(proto.Name) == nil {
if encoding.GetCodecV2(proto.Name) == nil {
t.Fatalf("encoding.GetCodec(%q) must not be nil by default", proto.Name)
}
}
5 changes: 3 additions & 2 deletions encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ type Codec interface {
Name() string
}

var registeredCodecs = make(map[string]Codec)
var registeredCodecs = make(map[string]any)

// RegisterCodec registers the provided Codec for use with all gRPC clients and
// servers.
Expand Down Expand Up @@ -126,5 +126,6 @@ func RegisterCodec(codec Codec) {
//
// The content-subtype is expected to be lowercase.
func GetCodec(contentSubtype string) Codec {
return registeredCodecs[contentSubtype]
c, _ := registeredCodecs[contentSubtype].(Codec)
return c
}
37 changes: 19 additions & 18 deletions encoding/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

Expand Down Expand Up @@ -90,18 +91,18 @@ type errProtoCodec struct {
decodingErr error
}

func (c *errProtoCodec) Marshal(v any) ([]byte, error) {
func (c *errProtoCodec) Marshal(v any) (mem.BufferSlice, error) {
if c.encodingErr != nil {
return nil, c.encodingErr
}
return encoding.GetCodec(proto.Name).Marshal(v)
return encoding.GetCodecV2(proto.Name).Marshal(v)
}

func (c *errProtoCodec) Unmarshal(data []byte, v any) error {
func (c *errProtoCodec) Unmarshal(data mem.BufferSlice, v any) error {
if c.decodingErr != nil {
return c.decodingErr
}
return encoding.GetCodec(proto.Name).Unmarshal(data, v)
return encoding.GetCodecV2(proto.Name).Unmarshal(data, v)
}

func (c *errProtoCodec) Name() string {
Expand All @@ -118,7 +119,7 @@ func (s) TestEncodeDoesntPanicOnServer(t *testing.T) {
ec := &errProtoCodec{name: t.Name(), encodingErr: encodingErr}

// Start a server with the above codec.
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(ec))
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodecV2(ec))
defer backend.Stop()

// Create a channel to the above server.
Expand Down Expand Up @@ -154,7 +155,7 @@ func (s) TestDecodeDoesntPanicOnServer(t *testing.T) {
ec := &errProtoCodec{name: t.Name(), decodingErr: decodingErr}

// Start a server with the above codec.
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(ec))
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodecV2(ec))
defer backend.Stop()

// Create a channel to the above server. Since we do not specify any codec
Expand Down Expand Up @@ -206,15 +207,15 @@ func (s) TestEncodeDoesntPanicOnClient(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
client := testgrpc.NewTestServiceClient(cc)
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec))
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec))
if err == nil || !strings.Contains(err.Error(), encodingErr.Error()) {
t.Fatalf("RPC failed with error: %v, want: %v", err, encodingErr)
}

// Configure the codec on the client to not return errors anymore and expect
// the RPC to succeed.
ec.encodingErr = nil
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec)); err != nil {
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec)); err != nil {
t.Fatalf("RPC failed with error: %v", err)
}
}
Expand Down Expand Up @@ -242,15 +243,15 @@ func (s) TestDecodeDoesntPanicOnClient(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
client := testgrpc.NewTestServiceClient(cc)
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec))
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec))
if err == nil || !strings.Contains(err.Error(), decodingErr.Error()) {
t.Fatalf("RPC failed with error: %v, want: %v", err, decodingErr)
}

// Configure the codec on the client to not return errors anymore and expect
// the RPC to succeed.
ec.decodingErr = nil
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec)); err != nil {
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec)); err != nil {
t.Fatalf("RPC failed with error: %v", err)
}
}
Expand All @@ -265,14 +266,14 @@ type countingProtoCodec struct {
unmarshalCount int32
}

func (p *countingProtoCodec) Marshal(v any) ([]byte, error) {
func (p *countingProtoCodec) Marshal(v any) (mem.BufferSlice, error) {
atomic.AddInt32(&p.marshalCount, 1)
return encoding.GetCodec(proto.Name).Marshal(v)
return encoding.GetCodecV2(proto.Name).Marshal(v)
}

func (p *countingProtoCodec) Unmarshal(data []byte, v any) error {
func (p *countingProtoCodec) Unmarshal(data mem.BufferSlice, v any) error {
atomic.AddInt32(&p.unmarshalCount, 1)
return encoding.GetCodec(proto.Name).Unmarshal(data, v)
return encoding.GetCodecV2(proto.Name).Unmarshal(data, v)
}

func (p *countingProtoCodec) Name() string {
Expand All @@ -284,7 +285,7 @@ func (p *countingProtoCodec) Name() string {
func (s) TestForceServerCodec(t *testing.T) {
// Create an server with the counting proto codec.
codec := &countingProtoCodec{name: t.Name()}
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(codec))
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodecV2(codec))
defer backend.Stop()

// Create a channel to the above server.
Expand Down Expand Up @@ -317,7 +318,7 @@ func (s) TestForceServerCodec(t *testing.T) {

// renameProtoCodec wraps the proto codec and allows customizing the Name().
type renameProtoCodec struct {
encoding.Codec
encoding.CodecV2
name string
}

Expand Down Expand Up @@ -356,9 +357,9 @@ func (s) TestForceCodecName(t *testing.T) {

// Force the use of the custom codec on the client with the ForceCodec call
// option. Confirm the name is converted to lowercase before transmitting.
codec := &renameProtoCodec{Codec: encoding.GetCodec(proto.Name), name: t.Name()}
codec := &renameProtoCodec{CodecV2: encoding.GetCodecV2(proto.Name), name: t.Name()}
wantContentTypeCh <- []string{fmt.Sprintf("application/grpc+%s", strings.ToLower(t.Name()))}
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil {
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(codec)); err != nil {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
}
}
7 changes: 3 additions & 4 deletions encoding/encoding_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ type CodecV2 interface {
Name() string
}

var registeredV2Codecs = make(map[string]CodecV2)

// RegisterCodecV2 registers the provided CodecV2 for use with all gRPC clients and
// servers.
//
Expand All @@ -70,13 +68,14 @@ func RegisterCodecV2(codec CodecV2) {
panic("cannot register CodecV2 with empty string result for Name()")
}
contentSubtype := strings.ToLower(codec.Name())
registeredV2Codecs[contentSubtype] = codec
registeredCodecs[contentSubtype] = codec
}

// GetCodecV2 gets a registered CodecV2 by content-subtype, or nil if no CodecV2 is
// registered for the content-subtype.
//
// The content-subtype is expected to be lowercase.
func GetCodecV2(contentSubtype string) CodecV2 {
return registeredV2Codecs[contentSubtype]
c, _ := registeredCodecs[contentSubtype].(CodecV2)
return c
}
44 changes: 34 additions & 10 deletions encoding/proto/proto.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright 2018 gRPC authors.
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,6 +24,7 @@ import (
"fmt"

"google.golang.org/grpc/encoding"
"google.golang.org/grpc/mem"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/protoadapt"
)
Expand All @@ -32,28 +33,51 @@ import (
const Name = "proto"

func init() {
encoding.RegisterCodec(codec{})
encoding.RegisterCodecV2(&codecV2{})
}

// codec is a Codec implementation with protobuf. It is the default codec for gRPC.
type codec struct{}
// codec is a CodecV2 implementation with protobuf. It is the default codec for
// gRPC.
type codecV2 struct{}

func (codec) Marshal(v any) ([]byte, error) {
func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) {
vv := messageV2Of(v)
if vv == nil {
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v)
return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v)
}

return proto.Marshal(vv)
size := proto.Size(vv)
if mem.IsBelowBufferPoolingThreshold(size) {
buf, err := proto.Marshal(vv)
if err != nil {
return nil, err
}
data = append(data, mem.SliceBuffer(buf))
} else {
pool := mem.DefaultBufferPool()
buf := pool.Get(size)
if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil {
pool.Put(buf)
return nil, err
}
data = append(data, mem.NewBuffer(buf, pool))
}

return data, nil
}

func (codec) Unmarshal(data []byte, v any) error {
func (c *codecV2) Unmarshal(data mem.BufferSlice, v any) (err error) {
vv := messageV2Of(v)
if vv == nil {
return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v)
}

return proto.Unmarshal(data, vv)
buf := data.MaterializeToBuffer(mem.DefaultBufferPool())
defer buf.Free()
// TODO: Upgrade proto.Unmarshal to support mem.BufferSlice. Right now, it's not
// really possible without a major overhaul of the proto package, but the
// vtprotobuf library may be able to support this.
return proto.Unmarshal(buf.ReadOnlyData(), vv)
}

func messageV2Of(v any) proto.Message {
Expand All @@ -67,6 +91,6 @@ func messageV2Of(v any) proto.Message {
return nil
}

func (codec) Name() string {
func (c *codecV2) Name() string {
return Name
}
6 changes: 3 additions & 3 deletions encoding/proto/proto_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func BenchmarkProtoCodec(b *testing.B) {
protoStructs := setupBenchmarkProtoCodecInputs(s)
name := fmt.Sprintf("MinPayloadSize:%v/SetParallelism(%v)", s, p)
b.Run(name, func(b *testing.B) {
codec := &codec{}
codec := &codecV2{}
b.SetParallelism(p)
b.RunParallel(func(pb *testing.PB) {
benchmarkProtoCodec(codec, protoStructs, pb, b)
Expand All @@ -78,7 +78,7 @@ func BenchmarkProtoCodec(b *testing.B) {
}
}

func benchmarkProtoCodec(codec *codec, protoStructs []proto.Message, pb *testing.PB, b *testing.B) {
func benchmarkProtoCodec(codec *codecV2, protoStructs []proto.Message, pb *testing.PB, b *testing.B) {
counter := 0
for pb.Next() {
counter++
Expand All @@ -87,7 +87,7 @@ func benchmarkProtoCodec(codec *codec, protoStructs []proto.Message, pb *testing
}
}

func fastMarshalAndUnmarshal(codec encoding.Codec, protoStruct proto.Message, b *testing.B) {
func fastMarshalAndUnmarshal(codec encoding.CodecV2, protoStruct proto.Message, b *testing.B) {
marshaledBytes, err := codec.Marshal(protoStruct)
if err != nil {
b.Errorf("codec.Marshal(_) returned an error")
Expand Down
13 changes: 7 additions & 6 deletions encoding/proto/proto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ import (

"google.golang.org/grpc/encoding"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/mem"
pb "google.golang.org/grpc/test/codec_perf"
)

func marshalAndUnmarshal(t *testing.T, codec encoding.Codec, expectedBody []byte) {
func marshalAndUnmarshal(t *testing.T, codec encoding.CodecV2, expectedBody []byte) {
p := &pb.Buffer{}
p.Body = expectedBody

Expand All @@ -55,7 +56,7 @@ func Test(t *testing.T) {
}

func (s) TestBasicProtoCodecMarshalAndUnmarshal(t *testing.T) {
marshalAndUnmarshal(t, codec{}, []byte{1, 2, 3})
marshalAndUnmarshal(t, &codecV2{}, []byte{1, 2, 3})
}

// Try to catch possible race conditions around use of pools
Expand All @@ -75,7 +76,7 @@ func (s) TestConcurrentUsage(t *testing.T) {
}

var wg sync.WaitGroup
codec := codec{}
codec := &codecV2{}

for i := 0; i < numGoRoutines; i++ {
wg.Add(1)
Expand All @@ -93,16 +94,16 @@ func (s) TestConcurrentUsage(t *testing.T) {
// TestStaggeredMarshalAndUnmarshalUsingSamePool tries to catch potential errors in which slices get
// stomped on during reuse of a proto.Buffer.
func (s) TestStaggeredMarshalAndUnmarshalUsingSamePool(t *testing.T) {
codec1 := codec{}
codec2 := codec{}
codec1 := &codecV2{}
codec2 := &codecV2{}

expectedBody1 := []byte{1, 2, 3}
expectedBody2 := []byte{4, 5, 6}

proto1 := pb.Buffer{Body: expectedBody1}
proto2 := pb.Buffer{Body: expectedBody2}

var m1, m2 []byte
var m1, m2 mem.BufferSlice
var err error

if m1, err = codec1.Marshal(&proto1); err != nil {
Expand Down
Loading

0 comments on commit 8e3596c

Please sign in to comment.