diff --git a/sdk/data/azcosmos/CHANGELOG.md b/sdk/data/azcosmos/CHANGELOG.md index 905d7fac617e..bf7a37c36e9b 100644 --- a/sdk/data/azcosmos/CHANGELOG.md +++ b/sdk/data/azcosmos/CHANGELOG.md @@ -6,6 +6,8 @@ ### Features Added +* Added retry policy for transient `500`, `502`, and `504` server errors on read requests. The request is retried once in the current region and, if applicable, once against the next preferred region. Writes are not retried. This matches the behavior of the .NET, Java, and Python Cosmos SDKs. See [PR 26821](https://github.com/Azure/azure-sdk-for-go/pull/26821). + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/data/azcosmos/cosmos_client_retry_policy.go b/sdk/data/azcosmos/cosmos_client_retry_policy.go index 1f78a708eda5..53b2bc1fc273 100644 --- a/sdk/data/azcosmos/cosmos_client_retry_policy.go +++ b/sdk/data/azcosmos/cosmos_client_retry_policy.go @@ -114,6 +114,10 @@ type retryContext struct { retryCount int sessionRetryCount int preferredLocationIndex int + // serverErrorRetryCount tracks the number of retries attempted for a + // transient 5xx server error (500/502/504). Only reads are retried; + // the budget is one in-region retry followed by one cross-region retry. + serverErrorRetryCount int // sameRegionRetryCount tracks the number of consecutive retries we have // attempted against the currently-resolved endpoint for a connection // error chain. It resets to 0 whenever we fail over to another region @@ -138,6 +142,11 @@ type retryContext struct { const maxRetryCount = 120 const defaultBackoff = 1 +// maxServerErrorRetryCount is the total number of retries attempted for a +// transient 5xx server error: one in-region retry followed by one +// cross-region retry. +const maxServerErrorRetryCount = 2 + // sleepWithContext sleeps for d, but returns early with the context's error // if ctx is cancelled or its deadline expires. 
Use this in retry paths so // the policy honors caller-set context deadlines instead of consuming the @@ -234,6 +243,11 @@ func (p *clientRetryPolicy) Do(req *policy.Request) (*http.Response, error) { subStatus := response.Header.Get(cosmosHeaderSubstatus) if p.shouldRetryStatus(response.StatusCode, subStatus) { retryContext.useWriteEndpoint = false + // advanceLocation gates whether the post-switch logic advances + // retryCount (which moves the resolved endpoint to the next + // region). An in-region 5xx retry sets it to false so the + // retry targets the same endpoint. + advanceLocation := true switch response.StatusCode { case http.StatusForbidden: shouldRetry, err := p.attemptRetryOnEndpointFailure(req, o.isWriteOperation, &retryContext) @@ -259,12 +273,25 @@ func (p *clientRetryPolicy) Do(req *policy.Request) (*http.Response, error) { if !shouldRetry { return nil, errorinfo.NonRetriableError(azruntime.NewResponseErrorWithErrorCode(response, response.Status)) } + case http.StatusInternalServerError, http.StatusBadGateway, http.StatusGatewayTimeout: + shouldRetry, inRegion := p.attemptRetryOnServerError(o.isWriteOperation, &retryContext) + if !shouldRetry { + return nil, errorinfo.NonRetriableError(azruntime.NewResponseErrorWithErrorCode(response, response.Status)) + } + // The in-region retry targets the same endpoint, so do not + // advance retryCount. The cross-region retry advances the + // location via preferredLocationIndex and retryCount. + if inRegion { + advanceLocation = false + } } err = req.RewindBody() if err != nil { return response, err } - retryContext.retryCount += 1 + if advanceLocation { + retryContext.retryCount += 1 + } // HTTP-status retries can change the endpoint (via retryCount // or preferredLocationIndex). 
Reset the connection-error // same-region budget so a fresh chain of connection errors @@ -283,7 +310,10 @@ func (p *clientRetryPolicy) shouldRetryStatus(status int, subStatus string) (sho if (status == http.StatusForbidden && (subStatus == subStatusWriteForbidden || subStatus == subStatusDatabaseAccountNotFound)) || (status == http.StatusNotFound && subStatus == subStatusReadSessionNotAvailable) || (status == http.StatusServiceUnavailable) || - (status == http.StatusRequestTimeout) { + (status == http.StatusRequestTimeout) || + (status == http.StatusInternalServerError) || + (status == http.StatusBadGateway) || + (status == http.StatusGatewayTimeout) { return true } return false @@ -510,6 +540,39 @@ func (p *clientRetryPolicy) attemptRetryOnServiceUnavailable(isWriteOperation bo return true } +// attemptRetryOnServerError applies the 5xx retry policy for transient server +// errors (500 Internal Server Error, 502 Bad Gateway, 504 Gateway Timeout). +// Consistent with the other Cosmos SDKs (.NET, Python, Java), only read +// operations are retried. The retry budget is one in-region retry followed by +// one cross-region retry, after which the error is surfaced to the caller. The +// cross-region retry is only attempted when cross-region retries are enabled, a +// preferred location is available to fail over to, and the location cache has +// resolved more than one read endpoint -- otherwise the "cross-region" retry +// would just hit the same endpoint as the in-region retry. The returned +// inRegion flag tells the caller whether to keep targeting the current endpoint +// (true) or to advance to the next preferred region (false). 
+func (p *clientRetryPolicy) attemptRetryOnServerError(isWriteOperation bool, retryContext *retryContext) (shouldRetry bool, inRegion bool) { + if isWriteOperation { + return false, false + } + if retryContext.serverErrorRetryCount >= maxServerErrorRetryCount { + return false, false + } + if retryContext.serverErrorRetryCount == 0 { + retryContext.serverErrorRetryCount += 1 + return true, true + } + if !p.gem.locationCache.enableCrossRegionRetries || retryContext.preferredLocationIndex >= len(p.gem.preferredLocations) { + return false, false + } + if p.gem.locationCache.readEndpointCount() <= 1 { + return false, false + } + retryContext.serverErrorRetryCount += 1 + retryContext.preferredLocationIndex += 1 + return true, false +} + // attemptRetryOnRequestTimeout handles an HTTP 408 from the service. A // 408 is ambiguous from a write-safety standpoint (the request may or // may not have been processed before the server timed out), so only diff --git a/sdk/data/azcosmos/cosmos_client_retry_policy_test.go b/sdk/data/azcosmos/cosmos_client_retry_policy_test.go index 2acdde358d22..342d9e6ec5c0 100644 --- a/sdk/data/azcosmos/cosmos_client_retry_policy_test.go +++ b/sdk/data/azcosmos/cosmos_client_retry_policy_test.go @@ -1348,3 +1348,536 @@ func (p *clientRetryPolicyVerifier) Do(req *policy.Request) (*http.Response, err p.requests = append(p.requests, pr) return resp, err } + +// multiReadEndpointLC builds a mock location cache whose readEndpoints slice has +// more than one entry, all pointing at the single mock server (defaultEndpoint). +// The server-error cross-region retry path requires len(readEndpoints) > 1; using +// identical URLs keeps every retry routed back to the same mock server so the +// queued responses are consumed in order. 
+func multiReadEndpointLC(defaultEndpoint url.URL, isMultiMaster bool) *locationCache { + lc := CreateMockLC(defaultEndpoint, isMultiMaster) + lc.locationInfo.readEndpoints = []url.URL{defaultEndpoint, defaultEndpoint, defaultEndpoint} + return lc +} + +// hostRoutingTransport dispatches each request to a backing transport selected +// by the request's URL host, and records the host of every attempt in order. +// It lets a test prove that a cross-region retry actually changes the request +// target endpoint (instead of merely advancing retry counters). +type hostRoutingTransport struct { + routes map[string]policy.Transporter + seenHosts []string +} + +func (t *hostRoutingTransport) Do(req *http.Request) (*http.Response, error) { + t.seenHosts = append(t.seenHosts, req.URL.Host) + backing, ok := t.routes[req.URL.Host] + if !ok { + return nil, fmt.Errorf("no backing transport registered for host %q", req.URL.Host) + } + return backing.Do(req) +} + +// TestReadServerError_CrossRegionRoutesToDifferentEndpoint proves that the +// cross-region 5xx retry targets a genuinely different endpoint than the +// in-region attempts. The in-region read endpoint and the next preferred +// region are backed by two distinct mock servers, and a host-routing transport +// records which endpoint each attempt actually hit. The in-region server fails +// both the initial request and the in-region retry; the request only succeeds +// once the cross-region retry routes to the second server. +func TestReadServerError_CrossRegionRoutesToDifferentEndpoint(t *testing.T) { + for _, statusCode := range []int{ + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusGatewayTimeout, + } { + t.Run(http.StatusText(statusCode), func(t *testing.T) { + // In-region endpoint: fails the initial request and the in-region retry. 
+ srvA, closeA := mock.NewTLSServer() + defer closeA() + srvA.AppendResponse(mock.WithStatusCode(statusCode)) + srvA.AppendResponse(mock.WithStatusCode(statusCode)) + + // Cross-region endpoint: succeeds once the failover routes here. + srvB, closeB := mock.NewTLSServer() + defer closeB() + srvB.AppendResponse(mock.WithStatusCode(200)) + + endpointA, err := url.Parse(srvA.URL()) + assert.NoError(t, err) + endpointB, err := url.Parse(srvB.URL()) + assert.NoError(t, err) + // Sanity: the two regions must be genuinely distinct endpoints, + // otherwise the routing assertion below would be meaningless. + assert.NotEqual(t, endpointA.Host, endpointB.Host) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + // Resolve read index 0 -> endpointA (in-region) and index 1 -> + // endpointB (next preferred region). 
+ lc := CreateMockLC(*endpointA, false) + lc.locationInfo.readEndpoints = []url.URL{*endpointA, *endpointB} + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: lc, + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + verifier := clientRetryPolicyVerifier{} + + transport := &hostRoutingTransport{ + routes: map[string]policy.Transporter{ + endpointA.Host: srvA, + endpointB.Host: srvB, + }, + } + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{&verifier, retryPolicy}}, &policy.ClientOptions{Transport: transport}) + + client := &Client{endpoint: srvA.URL(), endpointUrl: endpointA, internal: internalClient, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + assert.NoError(t, err) + + // Counter behavior: one in-region retry then one cross-region retry. + assert.Equal(t, 2, verifier.requests[0].retryContext.serverErrorRetryCount) + assert.Equal(t, 1, verifier.requests[0].retryContext.retryCount) + assert.Equal(t, 1, verifier.requests[0].retryContext.preferredLocationIndex) + + // Routing behavior: the first two attempts hit the in-region + // endpoint and only the third (cross-region) attempt hit the + // second, distinct endpoint. This proves the failover changes the + // request target rather than only mutating counters. 
+ assert.Equal(t, []string{endpointA.Host, endpointA.Host, endpointB.Host}, transport.seenHosts) + }) + } +} + +func TestReadServerError(t *testing.T) { + for _, statusCode := range []int{ + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusGatewayTimeout, + } { + t.Run(http.StatusText(statusCode), func(t *testing.T) { + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: multiReadEndpointLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + verifier := clientRetryPolicyVerifier{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{&verifier, retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + // Setting up responses for in-region retry + cross-region retry then succeeding. + srv.AppendResponse(mock.WithStatusCode(statusCode)) + srv.AppendResponse(mock.WithStatusCode(statusCode)) + srv.AppendResponse(mock.WithStatusCode(200)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + // Request should retry in-region then cross-region and succeed on the third attempt. 
+ assert.NoError(t, err) + assert.True(t, verifier.requests[0].retryContext.serverErrorRetryCount == 2) + // retryCount advances only on the cross-region retry, not the in-region one. + assert.True(t, verifier.requests[0].retryContext.retryCount == 1) + assert.True(t, verifier.requests[0].retryContext.preferredLocationIndex == 1) + + // Setting up responses for both retries failing. + srv.AppendResponse(mock.WithStatusCode(statusCode)) + srv.AppendResponse(mock.WithStatusCode(statusCode)) + srv.AppendResponse(mock.WithStatusCode(statusCode)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + // Request should retry once in-region and once cross-region and then fail. + assert.Error(t, err) + assert.True(t, verifier.requests[1].retryContext.serverErrorRetryCount == 2) + assert.True(t, verifier.requests[1].retryContext.retryCount == 1) + + // Without preferred locations, only the in-region retry should occur. + gem.preferredLocations = []string{} + srv.AppendResponse(mock.WithStatusCode(statusCode)) + srv.AppendResponse(mock.WithStatusCode(statusCode)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + assert.Error(t, err) + assert.True(t, verifier.requests[2].retryContext.serverErrorRetryCount == 1) + assert.True(t, verifier.requests[2].retryContext.retryCount == 0) + }) + } +} + +func TestWriteServerErrorNotRetried(t *testing.T) { + for _, statusCode := range []int{ + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusGatewayTimeout, + } { + t.Run(http.StatusText(statusCode), func(t *testing.T) { + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, 
&policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: CreateMockLC(*defaultEndpoint, true), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + verifier := clientRetryPolicyVerifier{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{&verifier, retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + item := map[string]interface{}{ + "id": "1", + "value": "2", + } + marshalled, err := json.Marshal(item) + if err != nil { + t.Fatal(err) + } + + // Even with preferred locations and multi-master configured, writes must not retry 5xx. 
+ srv.AppendResponse(mock.WithStatusCode(statusCode)) + srv.AppendResponse(mock.WithStatusCode(200)) + _, err = container.CreateItem(context.TODO(), NewPartitionKeyString("1"), marshalled, nil) + assert.Error(t, err) + assert.True(t, verifier.requests[0].retryContext.retryCount == 0) + assert.True(t, verifier.requests[0].retryContext.serverErrorRetryCount == 0) + }) + } +} + +func TestReadServerError_InRegionRetrySucceeds(t *testing.T) { + for _, statusCode := range []int{ + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusGatewayTimeout, + } { + t.Run(http.StatusText(statusCode), func(t *testing.T) { + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + verifier := clientRetryPolicyVerifier{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{&verifier, retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + // In-region retry should be enough; no cross-region failover required. 
+ srv.AppendResponse(mock.WithStatusCode(statusCode)) + srv.AppendResponse(mock.WithStatusCode(200)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + assert.NoError(t, err) + assert.True(t, verifier.requests[0].retryContext.serverErrorRetryCount == 1) + assert.True(t, verifier.requests[0].retryContext.retryCount == 0) + assert.True(t, verifier.requests[0].retryContext.preferredLocationIndex == 0) + }) + } +} + +func TestReadServerError_CrossRegionRetriesDisabled(t *testing.T) { + for _, statusCode := range []int{ + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusGatewayTimeout, + } { + t.Run(http.StatusText(statusCode), func(t *testing.T) { + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + lc := CreateMockLC(*defaultEndpoint, false) + lc.enableCrossRegionRetries = false + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: lc, + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + verifier := clientRetryPolicyVerifier{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{&verifier, retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + // With cross-region disabled only the in-region retry should 
be attempted. + srv.AppendResponse(mock.WithStatusCode(statusCode)) + srv.AppendResponse(mock.WithStatusCode(statusCode)) + srv.AppendResponse(mock.WithStatusCode(200)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + assert.Error(t, err) + assert.True(t, verifier.requests[0].retryContext.serverErrorRetryCount == 1) + assert.True(t, verifier.requests[0].retryContext.retryCount == 0) + assert.True(t, verifier.requests[0].retryContext.preferredLocationIndex == 0) + }) + } +} + +func TestReadServerError_ErrorIsResponseError(t *testing.T) { + for _, statusCode := range []int{ + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusGatewayTimeout, + } { + t.Run(http.StatusText(statusCode), func(t *testing.T) { + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: multiReadEndpointLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + verifier := clientRetryPolicyVerifier{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{&verifier, retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + // Exhaust both retries and verify the 
surfaced error preserves the response. + srv.AppendResponse(mock.WithStatusCode(statusCode)) + srv.AppendResponse(mock.WithStatusCode(statusCode)) + srv.AppendResponse(mock.WithStatusCode(statusCode)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + assert.Error(t, err) + var responseErr *azcore.ResponseError + assert.True(t, errors.As(err, &responseErr)) + assert.Equal(t, statusCode, responseErr.StatusCode) + }) + } +} + +func TestReadServerError_NonRetriable5xxNotRetried(t *testing.T) { + // 501 Not Implemented is a 5xx status that should NOT be retried. + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + verifier := clientRetryPolicyVerifier{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{&verifier, retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + srv.AppendResponse(mock.WithStatusCode(http.StatusNotImplemented)) + srv.AppendResponse(mock.WithStatusCode(200)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + assert.Error(t, 
err) + assert.True(t, verifier.requests[0].retryContext.serverErrorRetryCount == 0) + assert.True(t, verifier.requests[0].retryContext.retryCount == 0) +} + +func TestReadServerError_MixedWith503(t *testing.T) { + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: multiReadEndpointLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + verifier := clientRetryPolicyVerifier{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{&verifier, retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + // Sequence: 500 (in-region retry) -> 503 (cross-region via preferredLocationIndex) -> + // 500 (cross-region 5xx retry; consumes the last preferred location) -> 200. 
+ srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) + srv.AppendResponse(mock.WithStatusCode(http.StatusServiceUnavailable)) + srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) + srv.AppendResponse(mock.WithStatusCode(200)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + assert.NoError(t, err) + assert.True(t, verifier.requests[0].retryContext.serverErrorRetryCount == 2) + // 503 advances preferredLocationIndex by 1 and retryCount by 1; + // the cross-region 5xx retry advances both by 1 again. + assert.True(t, verifier.requests[0].retryContext.preferredLocationIndex == 2) + assert.True(t, verifier.requests[0].retryContext.retryCount == 2) +} + +// TestReadServerError_SingleReadEndpoint verifies that when the location cache has +// only one resolved read endpoint, the cross-region retry is skipped because failing +// over would just hit the same endpoint as the in-region retry. This covers +// single-region accounts and the case where preferred locations resolve to only one +// available read endpoint. +func TestReadServerError_SingleReadEndpoint(t *testing.T) { + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + lc := CreateMockLC(*defaultEndpoint, false) + // Simulate a single-region account: only one resolved read endpoint, even though + // the caller provided multiple preferred locations. 
+ lc.locationInfo.readEndpoints = []url.URL{*defaultEndpoint} + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: lc, + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + verifier := clientRetryPolicyVerifier{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{&verifier, retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + // Two 5xx responses queued: the in-region retry consumes the second one and + // fails. The cross-region retry must NOT fire because readEndpoints has length 1. + srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) + srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + assert.Error(t, err) + assert.Equal(t, 1, verifier.requests[0].retryContext.serverErrorRetryCount) + assert.Equal(t, 0, verifier.requests[0].retryContext.retryCount) + assert.Equal(t, 0, verifier.requests[0].retryContext.preferredLocationIndex) +} diff --git a/sdk/data/azcosmos/cosmos_location_cache.go b/sdk/data/azcosmos/cosmos_location_cache.go index be68f7491baf..981c5ea903e8 100644 --- a/sdk/data/azcosmos/cosmos_location_cache.go +++ b/sdk/data/azcosmos/cosmos_location_cache.go @@ -269,6 +269,17 @@ func (lc *locationCache) sessionRetrySnapshot() (multiWrite bool, readN, writeN len(lc.locationInfo.availWriteLocations) } +// readEndpointCount returns the number of resolved preferred read endpoints +// under RLock. 
The server-error retry path uses it to decide whether a +// cross-region failover would actually target a different endpoint. Reading +// the slice length under the lock prevents a torn read against a concurrent +// locationCache.update (e.g. from an async GEM refresh). +func (lc *locationCache) readEndpointCount() int { + lc.mapMutex.RLock() + defer lc.mapMutex.RUnlock() + return len(lc.locationInfo.readEndpoints) +} + func (lc *locationCache) markEndpointUnavailableForRead(endpoint url.URL) (wasAlreadyUnavailable bool, err error) { return lc.markEndpointUnavailable(endpointKey(endpoint), read) }