Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions experimental/experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,12 @@ func WithBufferPool(bufferPool mem.BufferPool) grpc.DialOption {
func BufferPool(bufferPool mem.BufferPool) grpc.ServerOption {
return internal.BufferPool.(func(mem.BufferPool) grpc.ServerOption)(bufferPool)
}

// AcceptCompressors returns a CallOption that limits the values
// advertised in the grpc-accept-encoding header for the provided RPC. The
// supplied names must correspond to compressors registered via
// encoding.RegisterCompressor. Passing no names advertises "identity" (no
// compression) only.
func AcceptCompressors(names ...string) grpc.CallOption {
return internal.AcceptCompressors.(func(...string) grpc.CallOption)(names...)
}
4 changes: 4 additions & 0 deletions internal/experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,8 @@ var (
// BufferPool is implemented by the grpc package and returns a server
// option to configure a shared buffer pool for a grpc.Server.
BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption

// AcceptCompressors is implemented by the grpc package and returns
// a call option that restricts the grpc-accept-encoding header for a call.
AcceptCompressors any // func(...string) grpc.CallOption
)
3 changes: 3 additions & 0 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,9 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te
hfLen += len(authData) + len(callAuthData)
registeredCompressors := t.registeredCompressors
if callHdr.AcceptedCompressors != nil {
registeredCompressors = *callHdr.AcceptedCompressors
}
if callHdr.PreviousAttempts > 0 {
hfLen++
}
Expand Down
6 changes: 6 additions & 0 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,12 @@ type CallHdr struct {
// outbound message.
SendCompress string

// AcceptedCompressors overrides the grpc-accept-encoding header for this
// call. When nil, the transport advertises the default set of registered
// compressors. A non-nil pointer overrides that value (including the empty
// string to advertise none).
AcceptedCompressors *string

// Creds specifies credentials.PerRPCCredentials for a call.
Creds credentials.PerRPCCredentials

Expand Down
93 changes: 81 additions & 12 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
Expand All @@ -41,6 +43,10 @@ import (
"google.golang.org/grpc/status"
)

func init() {
internal.AcceptCompressors = AcceptCompressors
}

// Compressor defines the interface gRPC uses to compress a message.
//
// Deprecated: use package encoding.
Expand Down Expand Up @@ -151,16 +157,32 @@ func (d *gzipDecompressor) Type() string {

// callInfo contains all related configuration and information about an RPC.
type callInfo struct {
compressorName string
failFast bool
maxReceiveMessageSize *int
maxSendMessageSize *int
creds credentials.PerRPCCredentials
contentSubtype string
codec baseCodec
maxRetryRPCBufferSize int
onFinish []func(err error)
authority string
compressorName string
failFast bool
maxReceiveMessageSize *int
maxSendMessageSize *int
creds credentials.PerRPCCredentials
contentSubtype string
codec baseCodec
maxRetryRPCBufferSize int
onFinish []func(err error)
authority string
acceptedResponseCompressors []string
}

func acceptedCompressorAllows(allowed []string, name string) bool {
if allowed == nil {
return true
}
if name == "" || name == encoding.Identity {
return true
}
for _, a := range allowed {
if a == name {
return true
}
}
return false
}

func defaultCallInfo() *callInfo {
Expand All @@ -170,6 +192,29 @@ func defaultCallInfo() *callInfo {
}
}

func newAcceptedCompressionConfig(names []string) ([]string, error) {
if len(names) == 0 {
return nil, nil
}
var allowed []string
seen := make(map[string]struct{}, len(names))
for _, name := range names {
name = strings.TrimSpace(name)
if name == "" || name == encoding.Identity {
continue
}
if !grpcutil.IsCompressorNameRegistered(name) {
return nil, status.Errorf(codes.InvalidArgument, "grpc: compressor %q is not registered", name)
}
if _, dup := seen[name]; dup {
continue
}
seen[name] = struct{}{}
allowed = append(allowed, name)
}
return allowed, nil
}

// CallOption configures a Call before it starts or extracts information from
// a Call after it completes.
type CallOption interface {
Expand Down Expand Up @@ -471,6 +516,31 @@ func (o CompressorCallOption) before(c *callInfo) error {
}
func (o CompressorCallOption) after(*callInfo, *csAttempt) {}

// AcceptCompressors returns a CallOption that limits the compression algorithms
// advertised in the grpc-accept-encoding header for response messages.
// Compression algorithms not in the provided list will not be advertised, and
// responses compressed with non-listed algorithms will be rejected.
func AcceptCompressors(names ...string) CallOption {
cp := append([]string(nil), names...)
return AcceptCompressorsCallOption{names: cp}
}

// AcceptCompressorsCallOption is a CallOption that limits response compression.
type AcceptCompressorsCallOption struct {
names []string
}

func (o AcceptCompressorsCallOption) before(c *callInfo) error {
allowed, err := newAcceptedCompressionConfig(o.names)
if err != nil {
return err
}
c.acceptedResponseCompressors = allowed
return nil
}

func (AcceptCompressorsCallOption) after(*callInfo, *csAttempt) {}

// CallContentSubtype returns a CallOption that will set the content-subtype
// for a call. For example, if content-subtype is "json", the Content-Type over
// the wire will be "application/grpc+json". The content-subtype is converted
Expand Down Expand Up @@ -857,8 +927,7 @@ func (p *payloadInfo) free() {
// the buffer is no longer needed.
// TODO: Refactor this function to reduce the number of arguments.
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
) (out mem.BufferSlice, err error) {
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) (out mem.BufferSlice, err error) {
pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return nil, err
Expand Down
112 changes: 112 additions & 0 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,118 @@ const (
decompressionErrorMsg = "invalid compression format"
)

type testCompressorForRegistry struct {
name string
}

func (c *testCompressorForRegistry) Compress(w io.Writer) (io.WriteCloser, error) {
return &testWriteCloser{w}, nil
}

func (c *testCompressorForRegistry) Decompress(r io.Reader) (io.Reader, error) {
return r, nil
}

func (c *testCompressorForRegistry) Name() string {
return c.name
}

type testWriteCloser struct {
io.Writer
}

func (w *testWriteCloser) Close() error {
return nil
}

func (s) TestNewAcceptedCompressionConfig(t *testing.T) {
// Register a test compressor for multi-compressor tests
testCompressor := &testCompressorForRegistry{name: "test-compressor"}
encoding.RegisterCompressor(testCompressor)
defer func() {
// Unregister the test compressor
encoding.RegisterCompressor(&testCompressorForRegistry{name: "test-compressor"})
}()

tests := []struct {
name string
input []string
wantAllowed []string
wantErr bool
}{
{
name: "identity-only",
input: nil,
wantAllowed: nil,
},
{
name: "single valid",
input: []string{"gzip"},
wantAllowed: []string{"gzip"},
},
{
name: "dedupe and trim",
input: []string{" gzip ", "gzip"},
wantAllowed: []string{"gzip"},
},
{
name: "ignores identity",
input: []string{"identity", "gzip"},
wantAllowed: []string{"gzip"},
},
{
name: "explicit identity only",
input: []string{"identity"},
wantAllowed: nil,
},
{
name: "invalid compressor",
input: []string{"does-not-exist"},
wantErr: true,
},
{
name: "only whitespace",
input: []string{" ", "\t"},
wantAllowed: nil,
},
{
name: "multiple valid compressors",
input: []string{"gzip", "test-compressor"},
wantAllowed: []string{"gzip", "test-compressor"},
},
{
name: "multiple with identity and whitespace",
input: []string{"gzip", "identity", " test-compressor ", " "},
wantAllowed: []string{"gzip", "test-compressor"},
},
{
name: "empty string in list",
input: []string{"gzip", "", "test-compressor"},
wantAllowed: []string{"gzip", "test-compressor"},
},
{
name: "mixed valid and invalid",
input: []string{"gzip", "invalid-comp"},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
allowed, err := newAcceptedCompressionConfig(tt.input)
if (err != nil) != tt.wantErr {
t.Fatalf("newAcceptedCompressionConfig(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr)
}
if tt.wantErr {
return
}
if diff := cmp.Diff(tt.wantAllowed, allowed); diff != "" {
t.Fatalf("allowed diff (-want +got): %v", diff)
}
})
}
}

type fullReader struct {
data []byte
}
Expand Down
13 changes: 13 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"math"
rand "math/rand/v2"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -301,6 +302,10 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
DoneFunc: doneFunc,
Authority: callInfo.authority,
}
if allowed := callInfo.acceptedResponseCompressors; len(allowed) > 0 {
headerValue := strings.Join(allowed, ",")
callHdr.AcceptedCompressors = &headerValue
}

// Set our outgoing compression according to the UseCompressor CallOption, if
// set. In that case, also find the compressor from the encoding package.
Expand Down Expand Up @@ -1134,6 +1139,10 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
a.decompressorV0 = nil
a.decompressorV1 = encoding.GetCompressor(ct)
}
// Validate that the compression method is acceptable for this call.
if !acceptedCompressorAllows(cs.callInfo.acceptedResponseCompressors, ct) {
return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct)
}
} else {
// No compression is used; disable our decompressor.
a.decompressorV0 = nil
Expand Down Expand Up @@ -1479,6 +1488,10 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
as.decompressorV0 = nil
as.decompressorV1 = encoding.GetCompressor(ct)
}
// Validate that the compression method is acceptable for this call.
if !acceptedCompressorAllows(as.callInfo.acceptedResponseCompressors, ct) {
return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct)
}
} else {
// No compression is used; disable our decompressor.
as.decompressorV0 = nil
Expand Down
Loading