diff --git a/sdk/data/azcosmos/CHANGELOG.md b/sdk/data/azcosmos/CHANGELOG.md index 2e1208900046..4c3042b6ade6 100644 --- a/sdk/data/azcosmos/CHANGELOG.md +++ b/sdk/data/azcosmos/CHANGELOG.md @@ -4,12 +4,16 @@ ### Features Added +* Added a dedicated 429 (Too Many Requests) throttling retry policy that honors the `x-ms-retry-after-ms` response header and is configurable via `ClientOptions.ThrottlingRetryOptions` (`MaxRetryAttempts`, `MaxRetryWaitTime`). This brings parity with the throttling retry behavior in the .NET, Java, and Python Cosmos SDKs. When `ClientOptions.Retry.StatusCodes` and `ClientOptions.Retry.ShouldRetry` are both unset, 429 is no longer in the azcore retry policy's default status codes (it is now handled exclusively by the throttling retry policy); the other transient status codes (408, 500, 502, 503, 504) remain. + ### Breaking Changes ### Bugs Fixed ### Other Changes +* Throttling retry policy: an explicit `x-ms-retry-after-ms: 0` header is now honored as "retry immediately" instead of being treated as a missing header (which would have applied the default delay). NaN/Inf values for the header are now rejected as invalid. The request body is rewound before the 429 response body is drained so a rewind failure surfaces a usable 429 response to the caller. + ## 1.5.0-beta.6 (2026-05-15) ### Features Added diff --git a/sdk/data/azcosmos/cosmos_client.go b/sdk/data/azcosmos/cosmos_client.go index a160924e08c5..5653359d50d6 100644 --- a/sdk/data/azcosmos/cosmos_client.go +++ b/sdk/data/azcosmos/cosmos_client.go @@ -184,6 +184,16 @@ func newClient(authPolicy policy.Policy, gem *globalEndpointManager, options *Cl if options == nil { options = &ClientOptions{} } + // Copy the embedded azcore.ClientOptions so adjustments to retry defaults + // don't mutate the caller-supplied struct. The throttleRetryPolicy below + // owns 429 handling with Cosmos-specific semantics (x-ms-retry-after-ms + // header, cumulative wait budget), so when the caller hasn't customized + // the retry status codes or supplied a ShouldRetry callback, exclude 429 + // from azcore's default retry list to avoid double-retry. + clientOpts := options.ClientOptions + if clientOpts.Retry.StatusCodes == nil && clientOpts.Retry.ShouldRetry == nil { + clientOpts.Retry.StatusCodes = defaultAzcoreRetryStatusCodesWithout429() + } return azcore.NewClient(moduleName, serviceLibVersion, azruntime.PipelineOptions{ AllowedHeaders: getAllowedHeaders(), @@ -197,13 +207,29 @@ func newClient(authPolicy policy.Policy, gem *globalEndpointManager, options *Cl }, PerRetry: []policy.Policy{ authPolicy, + newThrottleRetryPolicy(&options.ThrottlingRetryOptions), &clientRetryPolicy{gem: gem}, }, Tracing: azruntime.TracingOptions{ Namespace: "Microsoft.DocumentDB", }, }, - &options.ClientOptions) + &clientOpts) +} + +// defaultAzcoreRetryStatusCodesWithout429 returns azcore's default retryable +// HTTP status codes with 429 removed. The Cosmos throttleRetryPolicy already +// retries 429 with x-ms-retry-after-ms semantics and a cumulative wait budget, +// so layering azcore's default 429 retry on top would result in compounded +// retry attempts. +func defaultAzcoreRetryStatusCodesWithout429() []int { + return []int{ + http.StatusRequestTimeout, // 408 + http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 + } } func newInternalPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pipeline { diff --git a/sdk/data/azcosmos/cosmos_client_options.go b/sdk/data/azcosmos/cosmos_client_options.go index 33b88bbc26f3..364a41647cbc 100644 --- a/sdk/data/azcosmos/cosmos_client_options.go +++ b/sdk/data/azcosmos/cosmos_client_options.go @@ -4,6 +4,8 @@ package azcosmos import ( + "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) @@ -25,4 +27,24 @@ type ClientOptions struct { // The valid range is 1 to 5 (inclusive). // Can be overridden per-request via the operation options. ThroughputBucket *int32 + // ThrottlingRetryOptions configures how the client retries requests that fail with + // HTTP 429 (Too Many Requests). When unset, defaults consistent with the other + // Cosmos SDKs are used (9 attempts, 30s cumulative wait). + ThrottlingRetryOptions ThrottlingRetryOptions +} + +// ThrottlingRetryOptions configures the retry behavior for HTTP 429 +// (Too Many Requests) responses. The Cosmos service indicates the recommended +// retry delay via the x-ms-retry-after-ms response header; the client respects +// that value subject to the limits in this struct. +type ThrottlingRetryOptions struct { + // MaxRetryAttempts is the maximum number of times the client will retry a + // throttled request. The default is 9. Set to a negative value to disable + // throttling retries. + MaxRetryAttempts int + // MaxRetryWaitTime is the maximum cumulative time the client will spend + // waiting between throttled retries for a single request. Once this budget + // is exhausted, the most recent 429 response is returned to the caller. + // The default is 30 seconds. + MaxRetryWaitTime time.Duration } diff --git a/sdk/data/azcosmos/cosmos_http_constants.go b/sdk/data/azcosmos/cosmos_http_constants.go index c9d29c6cc5c2..aba7ef026c37 100644 --- a/sdk/data/azcosmos/cosmos_http_constants.go +++ b/sdk/data/azcosmos/cosmos_http_constants.go @@ -91,6 +91,7 @@ const ( headerDedicatedGatewayBypassCache string = "x-ms-dedicatedgateway-bypass-cache" cosmosHeaderPriorityLevel string = "x-ms-cosmos-priority-level" cosmosHeaderThroughputBucket string = "x-ms-cosmos-throughput-bucket" + cosmosHeaderRetryAfterMs string = "x-ms-retry-after-ms" ) const ( diff --git a/sdk/data/azcosmos/cosmos_throttle_retry_policy.go b/sdk/data/azcosmos/cosmos_throttle_retry_policy.go new file mode 100644 index 000000000000..bd9a01ace379 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_throttle_retry_policy.go @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "math" + "net/http" + "strconv" + "time" + + azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +const ( + defaultMaxThrottleRetryAttempts = 9 + defaultMaxThrottleRetryWaitTime = 30 * time.Second + defaultThrottleRetryDelay = 5 * time.Second +) + +// throttleRetryPolicy retries requests that fail with HTTP 429 (Too Many Requests). +// It honors the Cosmos-specific x-ms-retry-after-ms header to determine the +// delay between attempts and caps the number of attempts and total cumulative +// retry delay. This matches the throttling retry behavior of the other Cosmos +// SDKs (.NET, Java, Python). +type throttleRetryPolicy struct { + maxRetryAttempts int + maxRetryWaitTime time.Duration + // defaultDelay is used when a 429 response is missing the + // x-ms-retry-after-ms header. Defaults to defaultThrottleRetryDelay. + defaultDelay time.Duration +} + +// newThrottleRetryPolicy constructs a throttleRetryPolicy. For MaxRetryAttempts, +// a positive value is used as the cap, zero falls back to the default +// (defaultMaxThrottleRetryAttempts), and a negative value disables throttling +// retries entirely. For MaxRetryWaitTime, a non-positive value falls back to +// the default (defaultMaxThrottleRetryWaitTime). +func newThrottleRetryPolicy(o *ThrottlingRetryOptions) *throttleRetryPolicy { + p := &throttleRetryPolicy{ + maxRetryAttempts: defaultMaxThrottleRetryAttempts, + maxRetryWaitTime: defaultMaxThrottleRetryWaitTime, + defaultDelay: defaultThrottleRetryDelay, + } + if o != nil { + if o.MaxRetryAttempts > 0 { + p.maxRetryAttempts = o.MaxRetryAttempts + } else if o.MaxRetryAttempts < 0 { + // negative values disable throttling retries entirely + p.maxRetryAttempts = 0 + } + if o.MaxRetryWaitTime > 0 { + p.maxRetryWaitTime = o.MaxRetryWaitTime + } + } + return p +} + +func (p *throttleRetryPolicy) Do(req *policy.Request) (*http.Response, error) { + attemptCount := 0 + cumulativeDelay := time.Duration(0) + for { + response, err := req.Next() + // Transport / non-HTTP errors are not throttling; let other policies decide. + if err != nil || response == nil || response.StatusCode != http.StatusTooManyRequests { + return response, err + } + + if attemptCount >= p.maxRetryAttempts { + log.Writef(azlog.EventRetryPolicy, "Cosmos throttle retry exhausted attempts (%d); returning 429 to caller", p.maxRetryAttempts) + return response, nil + } + + delay, ok := readRetryAfterMs(response) + if !ok { + // header missing or unparseable; fall back to the default delay. + // an explicit "0" header is honored (retry immediately). + delay = p.defaultDelay + } + + if cumulativeDelay+delay > p.maxRetryWaitTime { + log.Writef(azlog.EventRetryPolicy, "Cosmos throttle retry exceeded cumulative wait time (%s); returning 429 to caller", p.maxRetryWaitTime) + return response, nil + } + + cumulativeDelay += delay + attemptCount++ + + // Rewind the request body before discarding the response so that, if + // the body isn't seekable, the caller still receives a usable 429 + // response for diagnostics. + if err := req.RewindBody(); err != nil { + return response, err + } + + // drain and close the response body so the connection can be reused + azruntime.Drain(response) + + log.Writef(azlog.EventRetryPolicy, "Cosmos throttle retry attempt %d after %s (cumulative %s)", attemptCount, delay, cumulativeDelay) + + timer := time.NewTimer(delay) + select { + case <-timer.C: + case <-req.Raw().Context().Done(): + timer.Stop() + return response, req.Raw().Context().Err() + } + } +} + +// readRetryAfterMs parses the Cosmos x-ms-retry-after-ms header (milliseconds). +// Returns (delay, true) on a successful parse of a non-negative finite value +// (including an explicit "0", which means "retry immediately"). Returns +// (0, false) when the header is missing, unparseable, NaN, infinite, or +// negative so that the caller can apply a default delay only in that case. +func readRetryAfterMs(resp *http.Response) (time.Duration, bool) { + if resp == nil { + return 0, false + } + v := resp.Header.Get(cosmosHeaderRetryAfterMs) + if v == "" { + return 0, false + } + ms, err := strconv.ParseFloat(v, 64) + if err != nil || math.IsNaN(ms) || math.IsInf(ms, 0) || ms < 0 { + return 0, false + } + return time.Duration(ms * float64(time.Millisecond)), true +} diff --git a/sdk/data/azcosmos/cosmos_throttle_retry_policy_test.go b/sdk/data/azcosmos/cosmos_throttle_retry_policy_test.go new file mode 100644 index 000000000000..154be45f1367 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_throttle_retry_policy_test.go @@ -0,0 +1,482 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "bytes" + "context" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/stretchr/testify/require" +) + +type attemptCounter struct { + attempts int +} + +func (a *attemptCounter) Do(req *policy.Request) (*http.Response, error) { + a.attempts++ + return req.Next() +} + +// throttleTestPipeline builds an azcore client wired with a specific +// throttleRetryPolicy and an attempt counter for inspection in tests. +// azcore's built-in retry policy is disabled so the throttleRetryPolicy is the +// only thing retrying. The counter is placed *after* the throttle policy so it +// gets invoked on every retry the throttle policy issues. +func throttleTestPipeline(t *testing.T, srv *mock.Server, p *throttleRetryPolicy) (*azcore.Client, *attemptCounter) { + t.Helper() + counter := &attemptCounter{} + internal, err := azcore.NewClient("azcosmosthrottletest", "v1.0.0", + azruntime.PipelineOptions{ + PerRetry: []policy.Policy{p, counter}, + }, + &policy.ClientOptions{ + Transport: srv, + Retry: policy.RetryOptions{MaxRetries: -1}, + }) + require.NoError(t, err) + return internal, counter +} + +func doThrottleRequest(t *testing.T, c *azcore.Client, url string) (*http.Response, error) { + t.Helper() + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, url) + require.NoError(t, err) + return c.Pipeline().Do(req) +} + +func TestThrottleRetry_SucceedsAfterRetries(t *testing.T) { + srv, closeFn := mock.NewTLSServer() + defer closeFn() + + srv.AppendResponse(mock.WithStatusCode(429), mock.WithHeader(cosmosHeaderRetryAfterMs, "1")) + srv.AppendResponse(mock.WithStatusCode(429), mock.WithHeader(cosmosHeaderRetryAfterMs, "1")) + srv.AppendResponse(mock.WithStatusCode(200)) + + client, counter := throttleTestPipeline(t, srv, &throttleRetryPolicy{ + maxRetryAttempts: 5, + maxRetryWaitTime: 5 * time.Second, + defaultDelay: time.Millisecond, + }) + + resp, err := doThrottleRequest(t, client, srv.URL()) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, 3, counter.attempts) +} + +func TestThrottleRetry_ExhaustsAttempts(t *testing.T) { + srv, closeFn := mock.NewTLSServer() + defer closeFn() + + srv.SetResponse(mock.WithStatusCode(429), mock.WithHeader(cosmosHeaderRetryAfterMs, "1")) + + client, counter := throttleTestPipeline(t, srv, &throttleRetryPolicy{ + maxRetryAttempts: 3, + maxRetryWaitTime: 10 * time.Second, + defaultDelay: time.Millisecond, + }) + + resp, err := doThrottleRequest(t, client, srv.URL()) + require.NoError(t, err) + require.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + // 1 initial attempt + 3 retries + require.Equal(t, 4, counter.attempts) +} + +func TestThrottleRetry_ExhaustsCumulativeWaitTime(t *testing.T) { + srv, closeFn := mock.NewTLSServer() + defer closeFn() + + // Each 429 asks for a 60ms delay; budget is 100ms so only one retry fits. + srv.SetResponse(mock.WithStatusCode(429), mock.WithHeader(cosmosHeaderRetryAfterMs, "60")) + + client, counter := throttleTestPipeline(t, srv, &throttleRetryPolicy{ + maxRetryAttempts: 100, + maxRetryWaitTime: 100 * time.Millisecond, + defaultDelay: time.Millisecond, + }) + + resp, err := doThrottleRequest(t, client, srv.URL()) + require.NoError(t, err) + require.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + // 1 initial + 1 retry (cumulative 60ms used; second retry would push past 100ms) + require.Equal(t, 2, counter.attempts) +} + +func TestThrottleRetry_MissingHeaderUsesDefault(t *testing.T) { + srv, closeFn := mock.NewTLSServer() + defer closeFn() + + // 429 with no retry-after header followed by success. The policy should fall + // back to its defaultDelay value. + srv.AppendResponse(mock.WithStatusCode(429)) + srv.AppendResponse(mock.WithStatusCode(200)) + + client, counter := throttleTestPipeline(t, srv, &throttleRetryPolicy{ + maxRetryAttempts: 5, + maxRetryWaitTime: time.Second, + defaultDelay: 5 * time.Millisecond, + }) + + resp, err := doThrottleRequest(t, client, srv.URL()) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, 2, counter.attempts) +} + +func TestThrottleRetry_ExplicitZeroRetryAfterIsHonored(t *testing.T) { + srv, closeFn := mock.NewTLSServer() + defer closeFn() + + // Several 429s explicitly asking for "retry immediately" (header value "0"), + // followed by success. If the policy treated an explicit 0 as "missing" and + // fell back to defaultDelay (5s here) it would either exceed the 100ms + // cumulative budget (so the first retry would be skipped and the test would + // receive a 429) or take much longer than the assertion below allows. + for i := 0; i < 4; i++ { + srv.AppendResponse(mock.WithStatusCode(429), mock.WithHeader(cosmosHeaderRetryAfterMs, "0")) + } + srv.AppendResponse(mock.WithStatusCode(200)) + + client, counter := throttleTestPipeline(t, srv, &throttleRetryPolicy{ + maxRetryAttempts: 10, + maxRetryWaitTime: 100 * time.Millisecond, + defaultDelay: 5 * time.Second, + }) + + start := time.Now() + resp, err := doThrottleRequest(t, client, srv.URL()) + elapsed := time.Since(start) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, 5, counter.attempts, "expected 4 retries + 1 success") + require.Less(t, elapsed, time.Second, "explicit zero retry-after should not have waited the default delay") +} + +func TestThrottleRetry_Non429PassesThrough(t *testing.T) { + srv, closeFn := mock.NewTLSServer() + defer closeFn() + + srv.SetResponse(mock.WithStatusCode(503)) + + client, counter := throttleTestPipeline(t, srv, newThrottleRetryPolicy(&ThrottlingRetryOptions{ + MaxRetryAttempts: 5, + MaxRetryWaitTime: 5 * time.Second, + })) + + resp, err := doThrottleRequest(t, client, srv.URL()) + require.NoError(t, err) + require.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + require.Equal(t, 1, counter.attempts) +} + +func TestThrottleRetry_ContextCancellationAbortsRetry(t *testing.T) { + srv, closeFn := mock.NewTLSServer() + defer closeFn() + + // Ask for a long retry-after so the policy is asleep when the context is cancelled. + srv.SetResponse(mock.WithStatusCode(429), mock.WithHeader(cosmosHeaderRetryAfterMs, "5000")) + + client, counter := throttleTestPipeline(t, srv, &throttleRetryPolicy{ + maxRetryAttempts: 10, + maxRetryWaitTime: time.Minute, + defaultDelay: time.Second, + }) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + req, err := azruntime.NewRequest(ctx, http.MethodGet, srv.URL()) + require.NoError(t, err) + _, err = client.Pipeline().Do(req) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + // Exactly one transport attempt: the retry was aborted while sleeping. + require.Equal(t, 1, counter.attempts) +} + +func TestThrottleRetry_DefaultsAppliedWhenOptionsNil(t *testing.T) { + p := newThrottleRetryPolicy(nil) + require.Equal(t, defaultMaxThrottleRetryAttempts, p.maxRetryAttempts) + require.Equal(t, defaultMaxThrottleRetryWaitTime, p.maxRetryWaitTime) + require.Equal(t, defaultThrottleRetryDelay, p.defaultDelay) +} + +func TestThrottleRetry_DefaultsAppliedWhenOptionsZero(t *testing.T) { + p := newThrottleRetryPolicy(&ThrottlingRetryOptions{}) + require.Equal(t, defaultMaxThrottleRetryAttempts, p.maxRetryAttempts) + require.Equal(t, defaultMaxThrottleRetryWaitTime, p.maxRetryWaitTime) +} + +func TestThrottleRetry_NegativeAttemptsDisablesRetry(t *testing.T) { + srv, closeFn := mock.NewTLSServer() + defer closeFn() + + srv.SetResponse(mock.WithStatusCode(429), mock.WithHeader(cosmosHeaderRetryAfterMs, "1")) + + client, counter := throttleTestPipeline(t, srv, + newThrottleRetryPolicy(&ThrottlingRetryOptions{MaxRetryAttempts: -1})) + + resp, err := doThrottleRequest(t, client, srv.URL()) + require.NoError(t, err) + require.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + require.Equal(t, 1, counter.attempts) +} + +func TestReadRetryAfterMs(t *testing.T) { + tests := []struct { + name string + header string + want time.Duration + wantOK bool + }{ + {"missing", "", 0, false}, + {"integer", "1500", 1500 * time.Millisecond, true}, + {"float", "12.5", 12500 * time.Microsecond, true}, + {"explicit-zero", "0", 0, true}, + {"invalid", "not-a-number", 0, false}, + {"negative", "-10", 0, false}, + {"nan", "NaN", 0, false}, + {"positive-inf", "Inf", 0, false}, + {"negative-inf", "-Inf", 0, false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + resp := &http.Response{Header: http.Header{}} + if tc.header != "" { + resp.Header.Set(cosmosHeaderRetryAfterMs, tc.header) + } + got, ok := readRetryAfterMs(resp) + require.Equal(t, tc.want, got) + require.Equal(t, tc.wantOK, ok) + }) + } + got, ok := readRetryAfterMs(nil) + require.Equal(t, time.Duration(0), got) + require.False(t, ok) +} + +// trackingBody is a strings.Reader-backed body that records how many times Seek(0,0) +// was called so tests can assert that the throttle policy rewinds the body across retries. +type trackingBody struct { + *strings.Reader + rewinds int + closes int +} + +func newTrackingBody(s string) *trackingBody { + return &trackingBody{Reader: strings.NewReader(s)} +} + +func (b *trackingBody) Seek(offset int64, whence int) (int64, error) { + if offset == 0 && whence == io.SeekStart { + b.rewinds++ + } + return b.Reader.Seek(offset, whence) +} + +func (b *trackingBody) Close() error { + b.closes++ + return nil +} + +// bodyEchoCounter is a per-retry policy that drains the request body and records what it +// saw on each transport attempt. It's used to prove that retries see the full body again. +type bodyEchoCounter struct { + bodies [][]byte +} + +func (b *bodyEchoCounter) Do(req *policy.Request) (*http.Response, error) { + if raw := req.Raw(); raw != nil && raw.Body != nil { + data, err := io.ReadAll(raw.Body) + if err != nil { + return nil, err + } + b.bodies = append(b.bodies, data) + raw.Body = io.NopCloser(bytes.NewReader(data)) + } else { + b.bodies = append(b.bodies, nil) + } + return req.Next() +} + +func TestThrottleRetry_RewindsRequestBodyAcrossRetries(t *testing.T) { + srv, closeFn := mock.NewTLSServer() + defer closeFn() + + srv.AppendResponse(mock.WithStatusCode(429), mock.WithHeader(cosmosHeaderRetryAfterMs, "1")) + srv.AppendResponse(mock.WithStatusCode(429), mock.WithHeader(cosmosHeaderRetryAfterMs, "1")) + srv.AppendResponse(mock.WithStatusCode(200)) + + echo := &bodyEchoCounter{} + internal, err := azcore.NewClient("azcosmosthrottletest", "v1.0.0", + azruntime.PipelineOptions{ + PerRetry: []policy.Policy{ + &throttleRetryPolicy{maxRetryAttempts: 5, maxRetryWaitTime: 5 * time.Second, defaultDelay: time.Millisecond}, + echo, + }, + }, + &policy.ClientOptions{ + Transport: srv, + Retry: policy.RetryOptions{MaxRetries: -1}, + }) + require.NoError(t, err) + + body := newTrackingBody(`{"id":"42"}`) + req, err := azruntime.NewRequest(context.Background(), http.MethodPost, srv.URL()) + require.NoError(t, err) + require.NoError(t, req.SetBody(body, "application/json")) + + resp, err := internal.Pipeline().Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // All three transport attempts should have observed the full request body. + require.Len(t, echo.bodies, 3) + for i, b := range echo.bodies { + require.Equal(t, `{"id":"42"}`, string(b), "transport attempt %d saw a truncated body", i) + } + // Two retries means the throttle policy rewound the body at least twice. + require.GreaterOrEqual(t, body.rewinds, 2) +} + +// fullPipelineClient wires up the same pipeline newClient builds (minus the +// globalEndpointManager bits) so we can verify the cosmos-level retry config +// in concert with azcore's retry policy. +func fullPipelineClient(t *testing.T, srv *mock.Server, opts *ClientOptions) (*azcore.Client, *attemptCounter) { + t.Helper() + if opts == nil { + opts = &ClientOptions{} + } + clientOpts := opts.ClientOptions + clientOpts.Transport = srv + if clientOpts.Retry.RetryDelay == 0 { + // Keep azcore's exponential backoff from making tests slow. + clientOpts.Retry.RetryDelay = time.Millisecond + } + if clientOpts.Retry.StatusCodes == nil && clientOpts.Retry.ShouldRetry == nil { + clientOpts.Retry.StatusCodes = defaultAzcoreRetryStatusCodesWithout429() + } + counter := &attemptCounter{} + internal, err := azcore.NewClient("azcosmosthrottletest", "v1.0.0", + azruntime.PipelineOptions{ + PerRetry: []policy.Policy{ + newThrottleRetryPolicy(&opts.ThrottlingRetryOptions), + counter, + }, + }, + &clientOpts) + require.NoError(t, err) + return internal, counter +} + +// TestThrottleRetry_NoDoubleRetryWith429 verifies that, given the cosmos pipeline's +// default retry configuration, azcore's retry policy does not also retry 429s after +// the throttleRetryPolicy has exhausted its attempts. Without the StatusCodes override +// in newClient, azcore's default StatusCodes (which include 429) would retry the +// whole pipeline three additional times, multiplying the attempt count by 4. +func TestThrottleRetry_NoDoubleRetryWith429(t *testing.T) { + srv, closeFn := mock.NewTLSServer() + defer closeFn() + + srv.SetResponse(mock.WithStatusCode(429), mock.WithHeader(cosmosHeaderRetryAfterMs, "1")) + + client, counter := fullPipelineClient(t, srv, &ClientOptions{ + ThrottlingRetryOptions: ThrottlingRetryOptions{ + MaxRetryAttempts: 2, + MaxRetryWaitTime: 10 * time.Second, + }, + }) + + resp, err := doThrottleRequest(t, client, srv.URL()) + require.NoError(t, err) + require.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + // Strictly 1 initial + 2 throttle retries. If azcore also retried 429s + // the count would be 12 (3 * 4). + require.Equal(t, 3, counter.attempts) +} + +// TestThrottleRetry_AzcoreStillRetriesOther5xx ensures that excluding 429 from +// azcore's default retry StatusCodes leaves the other transient codes (e.g. 503) +// intact, so non-429 retries still happen. +func TestThrottleRetry_AzcoreStillRetriesOther5xx(t *testing.T) { + srv, closeFn := mock.NewTLSServer() + defer closeFn() + + srv.AppendResponse(mock.WithStatusCode(503)) + srv.AppendResponse(mock.WithStatusCode(503)) + srv.AppendResponse(mock.WithStatusCode(200)) + + client, counter := fullPipelineClient(t, srv, nil) + + resp, err := doThrottleRequest(t, client, srv.URL()) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + // azcore retried twice and then the third attempt succeeded. + require.Equal(t, 3, counter.attempts) +} + +// TestThrottleRetry_CallerStatusCodesPreserved verifies that we only override +// azcore's StatusCodes when the caller hasn't supplied their own. If they +// explicitly opt in to retrying 429 at the azcore layer, we respect that +// (even though it stacks with the throttle policy). +func TestThrottleRetry_CallerStatusCodesPreserved(t *testing.T) { + srv, closeFn := mock.NewTLSServer() + defer closeFn() + + srv.SetResponse(mock.WithStatusCode(429), mock.WithHeader(cosmosHeaderRetryAfterMs, "1")) + + caller429 := []int{http.StatusTooManyRequests} + client, counter := fullPipelineClient(t, srv, &ClientOptions{ + ClientOptions: azcore.ClientOptions{ + Retry: policy.RetryOptions{ + MaxRetries: 1, + StatusCodes: caller429, + RetryDelay: time.Millisecond, + }, + }, + ThrottlingRetryOptions: ThrottlingRetryOptions{ + MaxRetryAttempts: 1, + MaxRetryWaitTime: 10 * time.Second, + }, + }) + + resp, err := doThrottleRequest(t, client, srv.URL()) + require.NoError(t, err) + require.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + // Throttle: 1 initial + 1 retry = 2 per azcore try. azcore: 1 initial + 1 retry = 2 tries. + // Total transport attempts = 2 * 2 = 4. + require.Equal(t, 4, counter.attempts) +} + +func TestDefaultAzcoreRetryStatusCodesWithout429(t *testing.T) { + codes := defaultAzcoreRetryStatusCodesWithout429() + for _, c := range codes { + require.NotEqual(t, http.StatusTooManyRequests, c) + } + // Sanity check: the other transient HTTP failures azcore retries by default + // are still in the list. + for _, want := range []int{ + http.StatusRequestTimeout, + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout, + } { + require.Contains(t, codes, want) + } +}