diff --git a/v2/invoke.go b/v2/invoke.go index 721d1af5..af1eebe3 100644 --- a/v2/invoke.go +++ b/v2/invoke.go @@ -31,15 +31,26 @@ package gax import ( "context" + "strconv" "strings" "time" "github.com/googleapis/gax-go/v2/apierror" + "google.golang.org/grpc/metadata" ) // APICall is a user defined call stub. type APICall func(context.Context, CallSettings) error +// withRetryCount returns a new context with the retry count appended to +// gRPC metadata. The retry count is the number of retries that have been +// attempted. On the initial request, retry count is 0. +// On a second request (the first retry), retry count is 1. +func withRetryCount(ctx context.Context, retryCount int) context.Context { + // Add to gRPC metadata so it's visible to StatsHandlers + return metadata.AppendToOutgoingContext(ctx, "gcp.grpc.resend_count", strconv.Itoa(retryCount)) +} + // Invoke calls the given APICall, performing retries as specified by opts, if // any. func Invoke(ctx context.Context, call APICall, opts ...CallOption) error { @@ -78,8 +89,15 @@ func invoke(ctx context.Context, call APICall, settings CallSettings, sp sleeper ctx = c } + retryCount := 0 + // Feature gate: GOOGLE_SDK_GO_EXPERIMENTAL_TRACING=true + tracingEnabled := IsFeatureEnabled("TRACING") for { - err := call(ctx, settings) + ctxToUse := ctx + if tracingEnabled { + ctxToUse = withRetryCount(ctx, retryCount) + } + err := call(ctxToUse, settings) if err == nil { return nil } @@ -110,5 +128,6 @@ func invoke(ctx context.Context, call APICall, settings CallSettings, sp sleeper } else if err = sp(ctx, d); err != nil { return err } + retryCount++ } } diff --git a/v2/invoke_test.go b/v2/invoke_test.go index 8ae6d280..2a2dd4b2 100644 --- a/v2/invoke_test.go +++ b/v2/invoke_test.go @@ -32,6 +32,8 @@ package gax import ( "context" "errors" + "fmt" + "strconv" "testing" "time" @@ -40,6 +42,7 @@ import ( "github.com/googleapis/gax-go/v2/apierror" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -264,3 +267,47 @@ func TestInvokeWithTimeout(t *testing.T) { }) } } + +func TestInvokeRetryCount(t *testing.T) { + for _, tracingEnabled := range []bool{true, false} { + t.Run(fmt.Sprintf("tracingEnabled=%v", tracingEnabled), func(t *testing.T) { + TestOnlyResetIsFeatureEnabled() + defer TestOnlyResetIsFeatureEnabled() + + if tracingEnabled { + t.Setenv("GOOGLE_SDK_GO_EXPERIMENTAL_TRACING", "true") + } else { + t.Setenv("GOOGLE_SDK_GO_EXPERIMENTAL_TRACING", "false") + } + + const target = 3 + var retryCounts []int + calls := 0 + apiCall := func(ctx context.Context, _ CallSettings) error { + calls++ + md, _ := metadata.FromOutgoingContext(ctx) + if vals := md["gcp.grpc.resend_count"]; len(vals) > 0 { + if count, err := strconv.Atoi(vals[0]); err == nil { + retryCounts = append(retryCounts, count) + } + } + if calls < target { + return errors.New("retry") + } + return nil + } + var settings CallSettings + WithRetry(func() Retryer { return boolRetryer(true) }).Resolve(&settings) + var sp recordSleeper + invoke(context.Background(), apiCall, settings, sp.sleep) + + var want []int + if tracingEnabled { + want = []int{0, 1, 2} + } + if diff := cmp.Diff(want, retryCounts); diff != "" { + t.Errorf("retry count mismatch (-want +got):\n%s", diff) + } + }) + } +}