Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion v2/invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,26 @@ package gax

import (
"context"
"strconv"
"strings"
"time"

"github.com/googleapis/gax-go/v2/apierror"
"google.golang.org/grpc/metadata"
)

// APICall is a user defined call stub.
type APICall func(context.Context, CallSettings) error

// withRetryCount returns a new context with the retry count appended to
// gRPC metadata. The retry count is the number of retries that have been
// attempted. On the initial request, retry count is 0.
// On a second request (the first retry), retry count is 1.
func withRetryCount(ctx context.Context, retryCount int) context.Context {
// Add to gRPC metadata so it's visible to StatsHandlers
return metadata.AppendToOutgoingContext(ctx, "gcp.grpc.resend_count", strconv.Itoa(retryCount))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not super clear on the transport agnosticism here, as the key and comments indicate this is gRPC only. If we're going to have a different function per transport, it might be worth naming them based on the transport type. Then again this is not an exported func so this is a decision that can be punted to future PRs.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're using gRPC metadata for both transports. I can clarify that here if you like. It's also discussed somewhat in the design doc.

}

// Invoke calls the given APICall, performing retries as specified by opts, if
// any.
func Invoke(ctx context.Context, call APICall, opts ...CallOption) error {
Expand Down Expand Up @@ -78,8 +89,15 @@ func invoke(ctx context.Context, call APICall, settings CallSettings, sp sleeper
ctx = c
}

retryCount := 0
// Feature gate: GOOGLE_SDK_GO_EXPERIMENTAL_TRACING=true
tracingEnabled := IsFeatureEnabled("TRACING")
for {
err := call(ctx, settings)
ctxToUse := ctx
if tracingEnabled {
ctxToUse = withRetryCount(ctx, retryCount)
}
err := call(ctxToUse, settings)
if err == nil {
return nil
}
Expand Down Expand Up @@ -110,5 +128,6 @@ func invoke(ctx context.Context, call APICall, settings CallSettings, sp sleeper
} else if err = sp(ctx, d); err != nil {
return err
}
retryCount++
}
}
47 changes: 47 additions & 0 deletions v2/invoke_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ package gax
import (
"context"
"errors"
"fmt"
"strconv"
"testing"
"time"

Expand All @@ -40,6 +42,7 @@ import (
"github.com/googleapis/gax-go/v2/apierror"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

Expand Down Expand Up @@ -264,3 +267,47 @@ func TestInvokeWithTimeout(t *testing.T) {
})
}
}

func TestInvokeRetryCount(t *testing.T) {
for _, tracingEnabled := range []bool{true, false} {
t.Run(fmt.Sprintf("tracingEnabled=%v", tracingEnabled), func(t *testing.T) {
TestOnlyResetIsFeatureEnabled()
defer TestOnlyResetIsFeatureEnabled()

if tracingEnabled {
t.Setenv("GOOGLE_SDK_GO_EXPERIMENTAL_TRACING", "true")
} else {
t.Setenv("GOOGLE_SDK_GO_EXPERIMENTAL_TRACING", "false")
}

const target = 3
var retryCounts []int
calls := 0
apiCall := func(ctx context.Context, _ CallSettings) error {
calls++
md, _ := metadata.FromOutgoingContext(ctx)
if vals := md["gcp.grpc.resend_count"]; len(vals) > 0 {
if count, err := strconv.Atoi(vals[0]); err == nil {
retryCounts = append(retryCounts, count)
}
}
if calls < target {
return errors.New("retry")
}
return nil
}
var settings CallSettings
WithRetry(func() Retryer { return boolRetryer(true) }).Resolve(&settings)
var sp recordSleeper
invoke(context.Background(), apiCall, settings, sp.sleep)

var want []int
if tracingEnabled {
want = []int{0, 1, 2}
}
if diff := cmp.Diff(want, retryCounts); diff != "" {
t.Errorf("retry count mismatch (-want +got):\n%s", diff)
}
})
}
}