diff --git a/sdk/data/azcosmos/CHANGELOG.md b/sdk/data/azcosmos/CHANGELOG.md index 91374edeff3d..66aeec26e72c 100644 --- a/sdk/data/azcosmos/CHANGELOG.md +++ b/sdk/data/azcosmos/CHANGELOG.md @@ -1,5 +1,7 @@ # Release History + + ## 1.5.0-beta.7 (Unreleased) ### Features Added @@ -8,6 +10,7 @@ ### Bugs Fixed +* 403/`WriteForbidden` retries now refresh the global endpoint manager asynchronously (fire-and-forget, CAS-gated so at most one refresh is in flight) instead of blocking on a synchronous `gem.Update`, so a failed or stalled metadata refresh no longer prevents the cross-region retry. See [PR 26889](https://github.com/Azure/azure-sdk-for-go/pull/26889). * Connection-error retry policy now attempts up to 3 retries against the current region before failing over, and performs at most one cross-region failover per call. Cross-region failover for writes only occurs when the error proves the request never reached the service (DNS, dial, TLS handshake, `ECONNREFUSED`, etc.); writes on ambiguous transport failures (e.g. `ECONNRESET`, `EOF`, transport-level timeouts) no longer fail over to another region, avoiding potential duplicate writes. Reads still fail over for any transport error. Caller-set context deadlines or cancellations short-circuit the policy without consuming the caller's budget with retries. See [PR 26858](https://github.com/Azure/azure-sdk-for-go/pull/26858). * HTTP `408 Request Timeout` responses are now handled by the Cosmos client retry policy: reads are retried exactly once against another region, and writes are returned to the caller immediately to avoid potential duplicates. See [PR 26858](https://github.com/Azure/azure-sdk-for-go/pull/26858). * Fixed excessive `GetDatabaseAccount` HTTP calls when using preferred regions, and stopped data-plane retries from trailing into the customer-supplied (default) endpoint once account topology is populated. See [PR 26815](https://github.com/Azure/azure-sdk-for-go/pull/26815). 
diff --git a/sdk/data/azcosmos/cosmos_client_retry_policy.go b/sdk/data/azcosmos/cosmos_client_retry_policy.go index 132e4bd1a7f1..1f78a708eda5 100644 --- a/sdk/data/azcosmos/cosmos_client_retry_policy.go +++ b/sdk/data/azcosmos/cosmos_client_retry_policy.go @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// cSpell:ignore azlog Writef Retriable unrecovered + package azcosmos import ( @@ -11,16 +13,99 @@ import ( "io" "net" "net/http" + "runtime/debug" + "sync/atomic" "syscall" "time" + azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) type clientRetryPolicy struct { gem *globalEndpointManager + // asyncRefreshState tracks the in-flight goroutine spawned by + // asyncForceRefreshGEM (Idle/Pending/Failed). See its doc comment. + asyncRefreshState atomic.Int32 + // lastForcedRefreshUnixNano is the completion time of the most + // recent asyncForceRefreshGEM. Read by staleForcedRefresh to + // rate-limit repeat refreshes against the same endpoint. + lastForcedRefreshUnixNano atomic.Int64 +} + +const ( + asyncRefreshIdle int32 = 0 + asyncRefreshPending int32 = 1 + asyncRefreshFailed int32 = 2 + + // forcedRefreshMinInterval rate-limits repeat forced refreshes + // against an already-unavailable endpoint. Must be >= + // defaultBackoff*time.Second so a tight 403 loop cannot bypass it. + forcedRefreshMinInterval = 2 * time.Second +) + +// asyncForceRefreshGEM kicks off a forced GEM topology refresh in a +// detached goroutine. The refresh must never block a data-plane retry: +// during a regional outage the global FQDN often resolves to the same +// regional FE pool we just marked unavailable, so a synchronous Update +// can stall and prevent failover. 
+// +// asyncRefreshState (CAS-gated) caps in-flight refreshes at one per +// policy. We run on context.Background() so a near-expired caller +// deadline cannot abort the refresh. Panics from gem.Update are +// recovered + logged but NOT re-panicked (an unrecovered panic in a +// detached goroutine terminates the process). +// +// Returns true if a refresh was actually spawned. +func (p *clientRetryPolicy) asyncForceRefreshGEM() bool { + for { + state := p.asyncRefreshState.Load() + if state == asyncRefreshPending { + return false + } + if p.asyncRefreshState.CompareAndSwap(state, asyncRefreshPending) { + break + } + } + go func() { + err := error(nil) + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic in azcosmos retry-policy async GEM refresh: %v", r) + log.Writef(azlog.EventResponse, "%v\n%s", err, debug.Stack()) + } + // Record completion time BEFORE flipping state so callers + // that observe Idle also see the freshly-updated timestamp. + p.lastForcedRefreshUnixNano.Store(time.Now().UnixNano()) + if err != nil { + p.asyncRefreshState.Store(asyncRefreshFailed) + } else { + p.asyncRefreshState.Store(asyncRefreshIdle) + } + }() + err = p.gem.Update(context.Background(), true) + if err != nil { + log.Writef(azlog.EventResponse, + "azcosmos retry-policy async GEM refresh failed: %v", err) + } + }() + return true +} + +// staleForcedRefresh reports whether the rate-limit window has +// elapsed since the last completed asyncForceRefreshGEM (or no refresh +// has run yet). Used to permit follow-up refreshes for repeat 403s +// against an already-unavailable endpoint -- critical for single-master +// writes, which cannot reroute locally. 
+func (p *clientRetryPolicy) staleForcedRefresh() bool { + last := p.lastForcedRefreshUnixNano.Load() + if last == 0 { + return true + } + return time.Since(time.Unix(0, last)) >= forcedRefreshMinInterval } // Retry context for the request @@ -43,6 +128,11 @@ type retryContext struct { // single cross-region retry for an HTTP 408. Only reads are retried // on 408; writes are returned to the caller immediately. requestTimeoutRetryDone bool + // resolveFromHead is a one-shot signal to the outer Do loop to use + // locationIndex 0 instead of retryCount. Set by retry paths that + // demote-to-tail (MarkEndpointUnavailable* moves the bad endpoint + // to the tail of the route list). + resolveFromHead bool } const maxRetryCount = 120 @@ -97,7 +187,13 @@ func (p *clientRetryPolicy) Do(req *policy.Request) (*http.Response, error) { for { // Update the retry context with the latest retry values req.SetOperationValue(retryContext) - resolvedEndpoint := p.gem.ResolveServiceEndpoint(retryContext.retryCount, o.resourceType, o.isWriteOperation, retryContext.useWriteEndpoint) + // Consume the one-shot resolveFromHead override. + locationIndex := retryContext.retryCount + if retryContext.resolveFromHead { + locationIndex = 0 + retryContext.resolveFromHead = false + } + resolvedEndpoint := p.gem.ResolveServiceEndpoint(locationIndex, o.resourceType, o.isWriteOperation, retryContext.useWriteEndpoint) regionName := p.gem.GetEndpointLocation(resolvedEndpoint) req.Raw().Host = resolvedEndpoint.Host req.Raw().URL.Host = resolvedEndpoint.Host @@ -264,50 +360,43 @@ func (p *clientRetryPolicy) attemptRetryOnNetworkError(req *policy.Request, kind // without producing a different endpoint. canCrossRegionWrite := !isWriteOperation || p.gem.CanUseMultipleWriteLocations() if isWriteOperation && (kind != connectionErrorNotSent || !canCrossRegionWrite) { - // Ambiguous failure, or single-master write: we cannot safely - // retry on another region. 
Mark the endpoint unavailable for - // reads so concurrent requests learn about the outage, but do - // not mark it unavailable for writes on a single-master - // account (we have nowhere else to send writes). - // - // Intentionally no gem.Update(ctx, true) here: as of PR #26815 - // MarkEndpointUnavailable* invalidates the GEM cache once per - // newly-unavailable endpoint, so the *next* caller's - // Update(false) will issue a refresh on its own. We skip the - // synchronous refresh because connection errors do not - // indicate that account topology has changed — they just say - // "this region is unhealthy right now." Forcing a refresh on - // every give-up under a regional outage would amplify the - // outage by piling GetDatabaseAccount calls on the metadata - // endpoint precisely when we want to be most responsive. - if err := p.gem.MarkEndpointUnavailableForRead(*req.Raw().URL); err != nil { + // Ambiguous failure or single-master write: cannot safely retry + // on another region. Mark unavailable for reads (concurrent + // readers learn), but not for writes on single-master (nowhere + // else to send writes). No forced gem.Update: the invalidate() + // inside MarkEndpointUnavailable* will arm the next non-force + // Update on its own, and a connection error is not a topology + // change. 
+ if _, err := p.gem.MarkEndpointUnavailableForRead(*req.Raw().URL); err != nil { return false, err } if canCrossRegionWrite { - if err := p.gem.MarkEndpointUnavailableForWrite(*req.Raw().URL); err != nil { + if _, err := p.gem.MarkEndpointUnavailableForWrite(*req.Raw().URL); err != nil { return false, err } } return false, nil } - if err := p.gem.MarkEndpointUnavailableForRead(*req.Raw().URL); err != nil { + if _, err := p.gem.MarkEndpointUnavailableForRead(*req.Raw().URL); err != nil { return false, err } if isWriteOperation { - if err := p.gem.MarkEndpointUnavailableForWrite(*req.Raw().URL); err != nil { + if _, err := p.gem.MarkEndpointUnavailableForWrite(*req.Raw().URL); err != nil { return false, err } } - // Force a refresh so the new unavailability is reflected in - // readEndpoints / writeEndpoints for both this request and any - // concurrent requests racing through resolveServiceEndpoint. - if err := p.gem.Update(req.Raw().Context(), true); err != nil { - return false, err - } + // No forced gem.Update: gating the failover on it would surface a + // metadata-endpoint timeout (the global FQDN often resolves to the + // same regional FE pool we just marked unavailable) and skip the + // cross-region retry. invalidate() inside MarkEndpointUnavailable* + // arms the next non-force Update for real topology changes. retryContext.sameRegionRetryCount = 0 - retryContext.retryCount += 1 + // Demote-to-tail leaves the bad endpoint at index 1+; force the + // next resolve to use index 0 instead of the (possibly bumped) + // retryCount, otherwise we'd route right back to the demoted slot. 
+ retryContext.resolveFromHead = true retryContext.crossRegionFailoverDone = true if sleepErr := sleepWithContext(req.Raw().Context(), defaultBackoff*time.Second); sleepErr != nil { return false, fmt.Errorf("%w: underlying transport error: %v", sleepErr, transportErr) @@ -319,34 +408,85 @@ func (p *clientRetryPolicy) attemptRetryOnEndpointFailure(req *policy.Request, i if (retryContext.retryCount > maxRetryCount) || !p.gem.locationCache.enableCrossRegionRetries { return false, nil } + var wasAlreadyUnavailable bool + var err error if isWriteOperation { - err := p.gem.MarkEndpointUnavailableForWrite(*req.Raw().URL) + wasAlreadyUnavailable, err = p.gem.MarkEndpointUnavailableForWrite(*req.Raw().URL) if err != nil { return false, err } } else { - err := p.gem.MarkEndpointUnavailableForRead(*req.Raw().URL) + wasAlreadyUnavailable, err = p.gem.MarkEndpointUnavailableForRead(*req.Raw().URL) if err != nil { return false, err } } - err := p.gem.Update(req.Raw().Context(), true) - if err != nil { - return false, err - } + // Kick off a forced async refresh on: + // - a NEW unavailability event for this endpoint (first + // transition; we always want fresh topology after a brand-new + // mark), OR + // - a repeat mark when no refresh is currently in flight AND + // the last completed forced refresh is older than + // forcedRefreshMinInterval. This single condition covers + // both recovery from a successful-but-stale prior refresh + // (single-master writes can't reroute locally) and recovery + // from a failed prior refresh (metadata endpoint was + // transiently unhealthy) without storming GetDatabaseAccount + // when the metadata endpoint is sustained-unhealthy. 
+ // + // MarkEndpointUnavailable* already calls invalidate() on the first + // transition, so the next non-force Update will refresh anyway -- + // but for single-master writes the local route list cannot reroute + // around the bad write endpoint, so without these additional + // forced refreshes the client could be stuck on the failed write + // region for refreshTimeInterval (default 5 min). + // + // Fire-and-forget: we do NOT block the retry on its outcome. + // MarkEndpointUnavailable* has already invalidated the GEM cache + // and demoted the bad endpoint locally, so the next + // ResolveServiceEndpoint will pick the failover region (in + // multi-region scenarios) whether or not the metadata refresh + // succeeds. Blocking here would surface a transient metadata + // failure to the caller and skip the very cross-region retry this + // function is supposed to perform. + state := p.asyncRefreshState.Load() + shouldForceRefresh := !wasAlreadyUnavailable || + (state != asyncRefreshPending && p.staleForcedRefresh()) + if shouldForceRefresh { + p.asyncForceRefreshGEM() + } + + // Force the next resolve to use locationIndex 0. Without this, the + // outer Do() loop bumps retryCount += 1 after we return true, which + // for a two-region account turns readEndpoints[1 % 2] back into the + // just-marked unhealthy endpoint that MarkEndpointUnavailable* + // demoted to the tail. resolveFromHead is a one-shot consumed by + // the outer loop's ResolveServiceEndpoint call. 
+ retryContext.resolveFromHead = true - time.Sleep(defaultBackoff * time.Second) + if sleepErr := sleepWithContext(req.Raw().Context(), defaultBackoff*time.Second); sleepErr != nil { + return false, sleepErr + } return true, nil } func (p *clientRetryPolicy) attemptRetryOnSessionUnavailable(isWriteOperation bool, retryContext *retryContext) bool { - if p.gem.CanUseMultipleWriteLocations() { - endpoints := p.gem.locationCache.locationInfo.availReadLocations + // Snapshot multi-write capability AND the relevant slice length + // under a single RLock. The async refresh paths (in this file and + // in globalEndpointManagerPolicy) can call locationCache.update + // concurrently, which rewrites enableMultipleWriteLocations and + // availRead/WriteLocations under mapMutex.Lock(). Sampling these + // across two separate lock acquisitions can yield a mixed snapshot + // (multi-write decision from before a refresh + slice length from + // after it, or vice versa), causing the wrong branch to be taken. + multiWrite, readN, writeN := p.gem.locationCache.sessionRetrySnapshot() + if multiWrite { + n := readN if isWriteOperation { - endpoints = p.gem.locationCache.locationInfo.availWriteLocations + n = writeN } - if retryContext.sessionRetryCount >= len(endpoints) { + if retryContext.sessionRetryCount >= n { return false } } else { @@ -387,14 +527,6 @@ func (p *clientRetryPolicy) attemptRetryOnRequestTimeout(req *policy.Request, is return false, nil } - if err := p.gem.MarkEndpointUnavailableForRead(*req.Raw().URL); err != nil { - return false, err - } - // Force a refresh so the unavailability is reflected in - // readEndpoints for this and concurrent requests. 
- if err := p.gem.Update(req.Raw().Context(), true); err != nil { - return false, err - } retryContext.requestTimeoutRetryDone = true // Preserve the caller's cancellation cause if their context fires // during the backoff so errors.Is(returned, context.DeadlineExceeded) diff --git a/sdk/data/azcosmos/cosmos_client_retry_policy_test.go b/sdk/data/azcosmos/cosmos_client_retry_policy_test.go index 8229f3abbb20..2acdde358d22 100644 --- a/sdk/data/azcosmos/cosmos_client_retry_policy_test.go +++ b/sdk/data/azcosmos/cosmos_client_retry_policy_test.go @@ -1,12 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// cSpell:ignore azcosmosgemtest azcosmostest retriable + package azcosmos import ( "context" "encoding/json" "errors" + "fmt" "io" "net" "net/http" @@ -598,7 +601,147 @@ func TestConnectionErrorReadFailsOverAfterThreeSameRegionAttempts(t *testing.T) // succeeded. After failover sameRegionRetryCount is reset and // retryCount is incremented to pick a different endpoint. assert.Equal(t, 0, rc.sameRegionRetryCount) - assert.Equal(t, 1, rc.retryCount) + assert.Equal(t, 0, rc.retryCount) // post-fix: retryCount not incremented on connection-error failover; demote-in-cache handles routing +} + +// TestConnectionErrorReadFailsOverWhenGlobalEndpointIsUnreachable simulates +// a regional gateway outage where the global account endpoint also resolves +// to the same regional FE pool that has been blocked (the common case for +// single-region writes — global FQDN points at the write region's FE). +// +// Before the fix, attemptRetryOnNetworkError had three interlocking +// problems that prevented the cross-region failover from ever taking effect: +// 1. It forced a synchronous gem.Update(ctx, true) after +// MarkEndpointUnavailable*. 
With the global endpoint unreachable, the +// refresh failed (in production typically a connect timeout when the +// global FQDN resolves to a blocked regional FE pool; the test injects +// a net.DNSError as a deterministic stand-in for any gem.Update +// failure) — causing the policy to surface the original connection +// failure without ever attempting the cross-region retry. +// 2. It incremented retryCount after the mark. MarkEndpointUnavailable* +// demotes the bad endpoint to the TAIL of readEndpoints rather than +// removing it, so readEndpoints becomes [good, bad]. With retryCount +// bumped to 1, ResolveServiceEndpoint(1 % 2) returns the still-bad +// endpoint at the tail — the failover attempt would hit the same +// dead region again. +// 3. MarkEndpointUnavailable* was called with the full request URL +// (path, query, etc. included) but the unavailability map and the +// cache's per-region endpoint lookup were keyed by base URLs +// (scheme+host). The marks were therefore written under keys nothing +// else looked up, so isEndpointUnavailableLocked always returned false +// and the demote silently did nothing. +// +// The fix drops the forced refresh, leaves retryCount at 0 so the next +// ResolveServiceEndpoint returns readEndpoints[0] (the just-promoted +// preferred region), and normalizes URLs to scheme+host on both write +// and read sides of the unavailability map. +// +// To actually exercise the routing this test wires up TWO distinct mock +// servers (badSrv = original/unhealthy region, goodSrv = failover region) +// and points the location cache's read endpoints at both. badSrv only +// serves DNS errors; goodSrv serves the 200 the request needs. If the +// resolver returns badSrv after failover (because of any of the three +// pre-fix conditions) the test fails. 
+func TestConnectionErrorReadFailsOverWhenGlobalEndpointIsUnreachable(t *testing.T) { + badSrv, badClose := mock.NewTLSServer() + defer badClose() + goodSrv, goodClose := mock.NewTLSServer() + defer goodClose() + + badURL, err := url.Parse(badSrv.URL()) + require.NoError(t, err) + goodURL, err := url.Parse(goodSrv.URL()) + require.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + // Simulate the global endpoint being unreachable for the duration of + // the regional outage. In production this typically manifests as a + // connect timeout (global FQDN resolves to a blocked regional FE + // pool); a net.DNSError gives us the same gem.Update(ctx,true) + // failure deterministically and without test-time sleeps. + gemServer.SetError(&net.DNSError{}) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + // Build a location cache with TWO distinct regional endpoints so the + // routing decision after failover is observable. "East US" (badSrv) + // is the user's application region (index 0); "Central US" (goodSrv) + // is the next preferred. + lc := newLocationCache([]string{"East US", "Central US"}, *badURL, true /*enableCrossRegionRetries*/) + require.NoError(t, lc.update( + []accountRegion{{Name: "East US", Endpoint: badSrv.URL()}}, + []accountRegion{ + {Name: "East US", Endpoint: badSrv.URL()}, + {Name: "Central US", Endpoint: goodSrv.URL()}, + }, + []string{"East US", "Central US"}, + nil, + )) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: lc, + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + // azcore needs to dispatch to whichever URL the policy resolves to, + // not a fixed Transport. 
routingMockTransport keys by host so a single + // client sees distinct backing servers per region. + routingTransport := routingMockTransport{ + byHost: map[string]*mock.Server{ + badURL.Host: badSrv, + goodURL.Host: goodSrv, + }, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + verifier := &clientRetryPolicyVerifier{} + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{verifier, retryPolicy}}, &policy.ClientOptions{Transport: &routingTransport}) + client := &Client{endpoint: badSrv.URL(), endpointUrl: badURL, internal: internalClient, gem: gem} + + dnsErr := &net.DNSError{} + // 1 initial + 3 same-region retries on the bad region. + for i := 0; i < 4; i++ { + badSrv.AppendError(dnsErr) + } + // Cross-region failover should hit the good region. + goodSrv.AppendResponse(mock.WithStatusCode(200)) + + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + + require.NoError(t, err, "cross-region failover should reach the good region") + rc := verifier.requests[0].retryContext + assert.True(t, rc.crossRegionFailoverDone, "expected one cross-region failover") + // retryCount stays at 0: MarkEndpointUnavailable* demoted the bad + // endpoint so ResolveServiceEndpoint(0) now returns the good region. + // Bumping retryCount to 1 would index back to the demoted-tail slot. + assert.Equal(t, 0, rc.retryCount) + assert.Equal(t, 0, rc.sameRegionRetryCount) + // 1 initial + 3 same-region retries against badSrv. + assert.Equal(t, 4, badSrv.Requests()) + // Exactly one request against the good region (the failover). + assert.Equal(t, 1, goodSrv.Requests()) +} + +// routingMockTransport routes each request to the mock server matching +// the request URL's host. This lets a single client see distinct backing +// servers per region without azcore short-circuiting to a fixed mock. 
+type routingMockTransport struct { + byHost map[string]*mock.Server +} + +func (r *routingMockTransport) Do(req *http.Request) (*http.Response, error) { + srv, ok := r.byHost[req.URL.Host] + if !ok { + return nil, fmt.Errorf("no mock server registered for host %q", req.URL.Host) + } + return srv.Do(req) } func TestNotSentConnectionErrorWriteFailsOver(t *testing.T) { @@ -623,7 +766,7 @@ func TestNotSentConnectionErrorWriteFailsOver(t *testing.T) { assert.NoError(t, err) rc := verifier.requests[0].retryContext assert.Equal(t, 0, rc.sameRegionRetryCount) - assert.Equal(t, 1, rc.retryCount) + assert.Equal(t, 0, rc.retryCount) // post-fix: retryCount not incremented on connection-error failover; demote-in-cache handles routing } // fakeAmbiguousNetError satisfies net.Error and wraps syscall.ECONNRESET @@ -686,7 +829,7 @@ func TestAmbiguousConnectionErrorReadFailsOver(t *testing.T) { assert.NoError(t, err) rc := verifier.requests[0].retryContext assert.Equal(t, 0, rc.sameRegionRetryCount) - assert.Equal(t, 1, rc.retryCount) + assert.Equal(t, 0, rc.retryCount) // post-fix: retryCount not incremented on connection-error failover; demote-in-cache handles routing } func TestCallerDeadlineExceededDoesNotRetry(t *testing.T) { @@ -732,7 +875,7 @@ func TestNotSentConnectionErrorMultiMasterWriteFailsOver(t *testing.T) { assert.NoError(t, err) rc := verifier.requests[0].retryContext assert.True(t, rc.crossRegionFailoverDone, "expected one cross-region failover") - assert.Equal(t, 1, rc.retryCount) + assert.Equal(t, 0, rc.retryCount) // post-fix: retryCount not incremented on connection-error failover; demote-in-cache handles routing assert.Equal(t, 0, rc.sameRegionRetryCount) } @@ -756,7 +899,7 @@ func TestConnectionErrorGivesUpAfterSingleCrossRegionFailover(t *testing.T) { rc := verifier.requests[0].retryContext // One cross-region failover happened and then we gave up. 
assert.True(t, rc.crossRegionFailoverDone) - assert.Equal(t, 1, rc.retryCount) + assert.Equal(t, 0, rc.retryCount) // post-fix: retryCount not incremented on connection-error failover; demote-in-cache handles routing // Mock server should have served exactly 5 requests: // 1 initial + 3 same-region retries + 1 cross-region failover. assert.Equal(t, 5, srv.Requests()) @@ -975,6 +1118,181 @@ func TestClassifyNetworkError(t *testing.T) { } } +// TestWriteForbiddenFailsOverToHealthyRegion is the routing-level +// regression test for the 403/WriteForbidden path. It mirrors +// TestConnectionErrorReadFailsOverWhenGlobalEndpointIsUnreachable for +// the network-error path: two distinct backend mock servers wired +// through a host-routing transport, the first returns +// 403/WriteForbidden, and the failover must reach the second. +// +// Before this PR fixed the 403 path as well, MarkEndpointUnavailable* +// demoted the bad write endpoint to the tail of writeEndpoints, then +// the outer Do() loop bumped retryContext.retryCount += 1, and the +// next ResolveServiceEndpoint(1 % 2) routed right back to the demoted +// bad endpoint. The fix sets retryContext.resolveFromHead = true so +// the next resolve uses locationIndex 0. 
+func TestWriteForbiddenFailsOverToHealthyRegion(t *testing.T) { + badSrv, badClose := mock.NewTLSServer() + defer badClose() + goodSrv, goodClose := mock.NewTLSServer() + defer goodClose() + + badURL, err := url.Parse(badSrv.URL()) + require.NoError(t, err) + goodURL, err := url.Parse(goodSrv.URL()) + require.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetError(&net.DNSError{}) + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + lc := newLocationCache([]string{"East US", "Central US"}, *badURL, true /*enableCrossRegionRetries*/) + require.NoError(t, lc.update( + []accountRegion{ + {Name: "East US", Endpoint: badSrv.URL()}, + {Name: "Central US", Endpoint: goodSrv.URL()}, + }, + []accountRegion{ + {Name: "East US", Endpoint: badSrv.URL()}, + {Name: "Central US", Endpoint: goodSrv.URL()}, + }, + []string{"East US", "Central US"}, + boolPtr(true), // enable multi-master so writes can fail over + )) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: lc, + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + routingTransport := routingMockTransport{ + byHost: map[string]*mock.Server{ + badURL.Host: badSrv, + goodURL.Host: goodSrv, + }, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + verifier := &clientRetryPolicyVerifier{} + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{verifier, retryPolicy}}, &policy.ClientOptions{Transport: &routingTransport}) + client := &Client{endpoint: badSrv.URL(), endpointUrl: badURL, internal: internalClient, gem: gem} + + // 1 initial 403/WriteForbidden on the bad region. 
+ badSrv.AppendResponse( + mock.WithHeader("x-ms-substatus", subStatusWriteForbidden), + mock.WithStatusCode(http.StatusForbidden)) + // Cross-region failover should hit the good region. + goodSrv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + item := map[string]interface{}{"id": "1", "value": "2"} + marshalled, mErr := json.Marshal(item) + require.NoError(t, mErr) + _, err = container.CreateItem(context.TODO(), NewPartitionKeyString("1"), marshalled, nil) + require.NoError(t, err, "403/WriteForbidden must fail over to the healthy region") + + // Exactly one request on each: 1 initial 403 against badSrv, 1 + // failover success against goodSrv. A regression that re-routed to + // the demoted endpoint would show 2 requests on badSrv. + assert.Equal(t, 1, badSrv.Requests(), "no further requests should hit the demoted write endpoint") + assert.Equal(t, 1, goodSrv.Requests(), "the failover request must reach the healthy write endpoint") +} + +// TestConnectionErrorFailoverResetsNonZeroRetryCount covers the mixed +// failure sequence: a prior HTTP-status retry (e.g. 408) bumps +// retryCount, then a connection error triggers the cross-region +// failover. The failover must still land on the healthy region; if the +// connection-error path merely "does not increment" retryCount instead +// of forcing the next resolve to head, the inherited non-zero index +// indexes back to the demoted-tail bad endpoint. 
+func TestConnectionErrorFailoverResetsNonZeroRetryCount(t *testing.T) { + badSrv, badClose := mock.NewTLSServer() + defer badClose() + goodSrv, goodClose := mock.NewTLSServer() + defer goodClose() + + badURL, err := url.Parse(badSrv.URL()) + require.NoError(t, err) + goodURL, err := url.Parse(goodSrv.URL()) + require.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetError(&net.DNSError{}) + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + lc := newLocationCache([]string{"East US", "Central US"}, *badURL, true) + require.NoError(t, lc.update( + []accountRegion{{Name: "East US", Endpoint: badSrv.URL()}}, + []accountRegion{ + {Name: "East US", Endpoint: badSrv.URL()}, + {Name: "Central US", Endpoint: goodSrv.URL()}, + }, + []string{"East US", "Central US"}, + nil, + )) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: lc, + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + routingTransport := routingMockTransport{ + byHost: map[string]*mock.Server{ + badURL.Host: badSrv, + goodURL.Host: goodSrv, + }, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + verifier := &clientRetryPolicyVerifier{} + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{verifier, retryPolicy}}, &policy.ClientOptions{Transport: &routingTransport}) + client := &Client{endpoint: badSrv.URL(), endpointUrl: badURL, internal: internalClient, gem: gem} + + // Sequence on badSrv: + // 1) 408 (read) -> outer loop bumps retryCount to 1, picks + // readEndpoints[1] = Central US for the next attempt. 
+ // Sequence on goodSrv (now selected): + // 2) initial attempt: 4x DNSError (initial + 3 same-region) -> + // triggers cross-region failover via attemptRetryOnNetworkError. + // After the failover the inherited retryCount is 1 (or higher). + // resolveFromHead must force the next resolve to index 0, which + // after demote-to-tail of Central US is East US (badSrv). + // 3) failover hits badSrv again: serve a 200. + badSrv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout)) + badSrv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + + dnsErr := &net.DNSError{} + for i := 0; i < 4; i++ { + goodSrv.AppendError(dnsErr) + } + + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + require.NoError(t, err, "mixed 408+connection-error sequence must still fail over to a healthy host") + + rc := verifier.requests[0].retryContext + assert.True(t, rc.requestTimeoutRetryDone, "the 408 retry should have run") + assert.True(t, rc.crossRegionFailoverDone, "the connection-error failover should have run") + // 1 initial 408 + 1 final success against badSrv. + assert.Equal(t, 2, badSrv.Requests(), "expected initial 408 + post-failover success on the head-of-list endpoint") + // 1 initial + 3 same-region retries against goodSrv (the 408 routed us here, then DNS killed it). 
+ assert.Equal(t, 4, goodSrv.Requests(), "expected initial + 3 same-region attempts before failover gave up on the bad region") +} + +func boolPtr(b bool) *bool { return &b } + func CreateMockLC(defaultEndpoint url.URL, isMultiMaster bool) *locationCache { availableWriteLocs := []string{"East US"} if isMultiMaster { diff --git a/sdk/data/azcosmos/cosmos_dbaccount_refresh_test.go b/sdk/data/azcosmos/cosmos_dbaccount_refresh_test.go index 43a84f818222..1ea3ca4399d6 100644 --- a/sdk/data/azcosmos/cosmos_dbaccount_refresh_test.go +++ b/sdk/data/azcosmos/cosmos_dbaccount_refresh_test.go @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// cSpell:ignore azcosmosgemtest azcosmostest retriable + // Regression tests for excess GetDatabaseAccount calls observed with // preferred regions configured. // @@ -153,10 +155,17 @@ func TestFix2_ConcurrentUpdateCallersCoalesce(t *testing.T) { } // ---------------------------------------------------------------------------- -// F3: write-retry on 403/WriteForbidden force-refreshes the GEM on every -// retry attempt so the client picks up topology changes immediately. +// F3: write-retry on 403/WriteForbidden kicks off an opportunistic GEM +// refresh on the FIRST mark for each endpoint and on subsequent marks +// at most once per forcedRefreshMinInterval (rate-limited). It must +// NOT issue one refresh per retry (that would storm +// GetDatabaseAccount during a sustained 403 flap) and it must NOT +// issue zero refreshes after the first (single-master writes need +// recovery when the first refresh returns stale topology). The +// refresh is fire-and-forget so a stalled metadata endpoint cannot +// block the retry. 
// ---------------------------------------------------------------------------- -func TestFix3_WriteRetryForceRefreshesGEM(t *testing.T) { +func TestFix3_WriteRetryKicksOffFireAndForgetRefresh(t *testing.T) { defaultEndpoint, err := url.Parse("https://fake.documents.azure.com:443/") require.NoError(t, err) @@ -184,14 +193,717 @@ func TestFix3_WriteRetryForceRefreshesGEM(t *testing.T) { const writeRetries = 5 rc := &retryContext{} + start := time.Now() for i := 0; i < writeRetries; i++ { shouldRetry, err := retry.attemptRetryOnEndpointFailure(req, true, rc) - require.NoError(t, err) - require.True(t, shouldRetry) + require.NoError(t, err, "retry must not surface metadata-refresh errors") + require.True(t, shouldRetry, "write retries on 403/WriteForbidden must continue") rc.retryCount++ } - require.Equal(t, int64(writeRetries), transport.count.Load(), - "write retries on 403/WriteForbidden must force-refresh the GEM on every attempt") + elapsed := time.Since(start) + + // At least one refresh must run (the first 403 always forces). + require.Eventually(t, func() bool { + return transport.count.Load() >= 1 + }, 5*time.Second, 10*time.Millisecond, + "first 403 must kick off a GEM refresh") + // Give any racing follow-up refresh a chance to land. + time.Sleep(200 * time.Millisecond) + + // Upper bound: rate-limited to at most one refresh per + // forcedRefreshMinInterval. With 5 retries each sleeping + // defaultBackoff*time.Second between attempts, elapsed ~= 5s. + // Expected refreshes: 1 (initial) + floor(elapsed / + // forcedRefreshMinInterval) at most. Allow +1 for boundary + // scheduling slack. 
+ maxExpected := int64(1+elapsed/forcedRefreshMinInterval) + 1 + require.LessOrEqual(t, transport.count.Load(), maxExpected, + "sustained 403s against the same endpoint must be rate-limited (elapsed=%v got=%d max=%d)", + elapsed, transport.count.Load(), maxExpected) +} + +// ---------------------------------------------------------------------------- +// F3a: 403/WriteForbidden retry must complete successfully even when the +// asynchronous GEM refresh hits a hard failure (5xx, dial error, etc.). +// Before the async-refresh change the policy did +// +// if err := gem.Update(ctx, true); err != nil { return false, err } +// +// so any GEM-refresh failure short-circuited the retry and surfaced the +// metadata error to the caller -- exactly the regression we're guarding +// against here. +// ---------------------------------------------------------------------------- +func TestFix3a_WriteRetrySucceedsWhenGEMRefreshFails(t *testing.T) { + defaultEndpoint, err := url.Parse("https://fake.documents.azure.com:443/") + require.NoError(t, err) + + // Transport that always returns 500 -- gem.Update will fail every + // time. The retry must still return (true, nil). 
+ transport := &countingTransport{status: http.StatusInternalServerError} + gemPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", + azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: transport}) + + gem := &globalEndpointManager{ + clientEndpoint: defaultEndpoint.String(), + pipeline: gemPipeline, + preferredLocations: []string{"West US"}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Now(), + } + retry := &clientRetryPolicy{gem: gem} + + req, err := azruntime.NewRequest(context.Background(), http.MethodPost, defaultEndpoint.String()) + require.NoError(t, err) + + rc := &retryContext{} + shouldRetry, err := retry.attemptRetryOnEndpointFailure(req, true, rc) + require.NoError(t, err, "GEM refresh failures must not surface to the retry caller") + require.True(t, shouldRetry, "retry must proceed regardless of GEM refresh outcome") + + // Sanity check: the async refresh actually ran (and failed). Without + // this, a regression that made asyncForceRefreshGEM a no-op would + // still pass the (true, nil) assertions above. + require.Eventually(t, func() bool { + return transport.count.Load() >= 1 + }, 5*time.Second, 10*time.Millisecond, + "async GEM refresh must run even though it fails") +} + +// ---------------------------------------------------------------------------- +// F3b: 408 read retry must NOT mark the request's endpoint unavailable. +// A 408 is a per-request signal (the gateway accepted the request but +// did not produce a response in time); demoting the whole region in the +// location cache for unavailableLocationExpirationTime based on one +// slow request would penalize concurrent reads against a region that +// may be perfectly healthy. 
+// ---------------------------------------------------------------------------- +func TestFix3b_RequestTimeoutDoesNotMarkEndpointUnavailable(t *testing.T) { + defaultEndpoint, err := url.Parse("https://fake.documents.azure.com:443/") + require.NoError(t, err) + + transport := &countingTransport{status: http.StatusOK, body: []byte("{}")} + gemPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", + azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: transport}) + + gem := &globalEndpointManager{ + clientEndpoint: defaultEndpoint.String(), + pipeline: gemPipeline, + preferredLocations: []string{"West US"}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Now(), + } + retry := &clientRetryPolicy{gem: gem} + + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, defaultEndpoint.String()) + require.NoError(t, err) + + rc := &retryContext{} + shouldRetry, err := retry.attemptRetryOnRequestTimeout(req, false /*isWriteOperation*/, rc) + require.NoError(t, err) + require.True(t, shouldRetry) + require.True(t, rc.requestTimeoutRetryDone) + + require.Empty(t, gem.locationCache.locationUnavailabilityInfoMap, + "408 must not record any endpoint as unavailable") + require.Equal(t, int64(0), transport.count.Load(), + "408 retry must not synchronously hit the GEM endpoint") +} + +// ---------------------------------------------------------------------------- +// F3c: 408 read retry is non-blocking -- even a permanently-stalled GEM +// endpoint cannot delay or fail the retry, because the 408 path no +// longer calls gem.Update at all. +// ---------------------------------------------------------------------------- +func TestFix3c_RequestTimeoutDoesNotBlockOnStalledGEM(t *testing.T) { + defaultEndpoint, err := url.Parse("https://fake.documents.azure.com:443/") + require.NoError(t, err) + + // Hanging transport: any call would block effectively forever. 
+ transport := &countingTransport{status: http.StatusOK, body: []byte("{}"), delay: 10 * time.Minute} + gemPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", + azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: transport}) + + gem := &globalEndpointManager{ + clientEndpoint: defaultEndpoint.String(), + pipeline: gemPipeline, + preferredLocations: []string{"West US"}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Now(), + } + retry := &clientRetryPolicy{gem: gem} + + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, defaultEndpoint.String()) + require.NoError(t, err) + + // defaultBackoff*time.Second is 1s, so give a comfortable upper bound. + done := make(chan struct{}) + var shouldRetry bool + var retryErr error + go func() { + shouldRetry, retryErr = retry.attemptRetryOnRequestTimeout(req, false, &retryContext{}) + close(done) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("attemptRetryOnRequestTimeout blocked on the stalled GEM endpoint") + } + require.NoError(t, retryErr) + require.True(t, shouldRetry) +} + +// ---------------------------------------------------------------------------- +// F3d: asyncForceRefreshGEM's CAS gate (asyncRefreshPending) must coalesce +// a retry storm. With N concurrent retries hitting a slow GEM endpoint, +// only one refresh goroutine should reach gem.Update at a time; the GEM's +// own single-in-flight pattern further coalesces to a single HTTP call. +// Without the gate every retry would queue its own goroutine inside +// gem.Update, wasting goroutine + channel overhead for no benefit. 
+// ---------------------------------------------------------------------------- +func TestFix3d_AsyncRefreshCASGateCoalescesRetryStorm(t *testing.T) { + defaultEndpoint, err := url.Parse("https://fake.documents.azure.com:443/") + require.NoError(t, err) + + westRegion := accountRegion{Name: "West US", Endpoint: defaultEndpoint.String()} + body, _ := json.Marshal(accountProperties{ + ReadRegions: []accountRegion{westRegion}, + WriteRegions: []accountRegion{westRegion}, + }) + // Slow-but-successful GEM transport so refreshes overlap in time. + transport := &countingTransport{status: http.StatusOK, body: body, delay: 200 * time.Millisecond} + gemPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", + azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: transport}) + + gem := &globalEndpointManager{ + clientEndpoint: defaultEndpoint.String(), + pipeline: gemPipeline, + preferredLocations: []string{"West US"}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Now(), + } + retry := &clientRetryPolicy{gem: gem} + + const concurrency = 50 + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + retry.asyncForceRefreshGEM() + }() + } + wg.Wait() + + // All asyncForceRefreshGEM callers have returned (the CAS gate is + // non-blocking). Wait for whichever refresh goroutine did get spawned + // to complete its HTTP call. + require.Eventually(t, func() bool { + return transport.count.Load() >= 1 + }, 5*time.Second, 10*time.Millisecond, + "at least one refresh goroutine must reach gem.Update") + + // Give any racing follow-up refresh a chance to land before we assert + // the upper bound. asyncRefreshPending clears in the spawned + // goroutine's defer, which runs AFTER gem.Update returns, so a second + // caller arriving in that tiny window could legitimately spawn a + // second refresh. 
The bound is "no goroutine-per-retry storm", not + // strictly one. + time.Sleep(100 * time.Millisecond) + require.Less(t, transport.count.Load(), int64(concurrency/2), + "CAS gate must coalesce concurrent retries; got %d HTTP calls for %d retries", + transport.count.Load(), concurrency) +} + +// ---------------------------------------------------------------------------- +// F3e: asyncForceRefreshGEM must use a background context internally, not +// the caller's context. If we threaded the caller's context (or any +// derivative that inherits its deadline/cancellation) into gem.Update, +// an already-cancelled or near-expired caller context would abort the +// refresh immediately and defeat its purpose. +// ---------------------------------------------------------------------------- +func TestFix3e_AsyncRefreshIgnoresCallerContextCancellation(t *testing.T) { + defaultEndpoint, err := url.Parse("https://fake.documents.azure.com:443/") + require.NoError(t, err) + + westRegion := accountRegion{Name: "West US", Endpoint: defaultEndpoint.String()} + body, _ := json.Marshal(accountProperties{ + ReadRegions: []accountRegion{westRegion}, + WriteRegions: []accountRegion{westRegion}, + }) + transport := &countingTransport{status: http.StatusOK, body: body} + gemPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", + azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: transport}) + + gem := &globalEndpointManager{ + clientEndpoint: defaultEndpoint.String(), + pipeline: gemPipeline, + preferredLocations: []string{"West US"}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Now(), + } + retry := &clientRetryPolicy{gem: gem} + + // Call asyncForceRefreshGEM directly. We deliberately bypass + // attemptRetryOnEndpointFailure here because that path now uses + // sleepWithContext for the backoff (so caller-cancellation correctly + // short-circuits the retry budget). 
The contract we're verifying is + // narrower: asyncForceRefreshGEM itself must not inherit any caller + // context -- it runs on context.Background() so an already-cancelled + // or near-expired caller deadline cannot abort the background HTTP + // call. + retry.asyncForceRefreshGEM() + + require.Eventually(t, func() bool { + return transport.count.Load() >= 1 + }, 5*time.Second, 10*time.Millisecond, + "asyncForceRefreshGEM must run on context.Background() and complete its HTTP call") +} + +// ---------------------------------------------------------------------------- +// F3f: when the first forced async refresh fails, a subsequent 403 on +// the SAME endpoint must be allowed to spawn another forced refresh. +// Before the fix the wasAlreadyUnavailable guard unconditionally +// suppressed every subsequent forced refresh, which stranded +// single-master writes on the failed write endpoint for the GEM +// throttle window (refreshTimeInterval, default 5 min) after the +// metadata endpoint recovered. +// +// We exercise this at the policy unit level by driving +// asyncForceRefreshGEM's state machine directly: the first call must +// land in asyncRefreshFailed, and the retry-policy's gating logic +// (mirrored here) must permit a new refresh when state == Failed. +// ---------------------------------------------------------------------------- +func TestFix3f_FailedAsyncRefreshIsRetriedOnNextSameEndpoint403(t *testing.T) { + defaultEndpoint, err := url.Parse("https://fake.documents.azure.com:443/") + require.NoError(t, err) + + // A transport that returns an error on demand; we toggle errOn to + // flip between failing and succeeding refreshes. 
+ var errOn atomic.Bool + errOn.Store(true) + body, _ := json.Marshal(accountProperties{ + ReadRegions: []accountRegion{{Name: "West US", Endpoint: defaultEndpoint.String()}}, + WriteRegions: []accountRegion{{Name: "West US", Endpoint: defaultEndpoint.String()}}, + }) + transport := &countingTransport{ + respFunc: func() (int, []byte) { + if errOn.Load() { + // Return 500 with a body that will not parse; pipeline + // surfaces this as an error and azcore retry does NOT + // retry a 500 with no Retry-After... wait, it does. + // Instead use a non-retriable status the GEM treats as + // an error: 400. + return http.StatusBadRequest, []byte(`{"code":"BadRequest"}`) + } + return http.StatusOK, body + }, + } + gemPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", + azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: transport}) + gem := &globalEndpointManager{ + clientEndpoint: defaultEndpoint.String(), + pipeline: gemPipeline, + preferredLocations: []string{"West US"}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Now(), + } + retry := &clientRetryPolicy{gem: gem} + + // First forced refresh: spawns goroutine, transport returns 400 -> + // gem.Update returns an error -> state should land at Failed. + require.True(t, retry.asyncForceRefreshGEM(), "first call must spawn the refresh goroutine") + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if retry.asyncRefreshState.Load() == asyncRefreshFailed { + break + } + time.Sleep(10 * time.Millisecond) + } + require.Equal(t, asyncRefreshFailed, retry.asyncRefreshState.Load(), + "first forced refresh must record Failed (count=%d)", transport.count.Load()) + firstCount := transport.count.Load() + require.GreaterOrEqual(t, firstCount, int64(1), "first refresh must hit the transport") + + // Now make the transport succeed and call asyncForceRefreshGEM + // again. 
With state == Failed it MUST spawn a new goroutine even + // though no fresh invalidation happened. + errOn.Store(false) + require.True(t, retry.asyncForceRefreshGEM(), + "second call must be allowed because previous refresh failed") + deadline = time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if retry.asyncRefreshState.Load() == asyncRefreshIdle && transport.count.Load() > firstCount { + break + } + time.Sleep(10 * time.Millisecond) + } + require.Equal(t, asyncRefreshIdle, retry.asyncRefreshState.Load(), + "second forced refresh must succeed (count=%d)", transport.count.Load()) + require.Greater(t, transport.count.Load(), firstCount, + "second refresh must hit the transport (first=%d total=%d)", firstCount, transport.count.Load()) + + // And once Idle, a third call (no failure, no invalidation) is + // allowed too (Idle state always permits a new refresh). + require.True(t, retry.asyncForceRefreshGEM(), + "third call from Idle state must be allowed") +} + +// ---------------------------------------------------------------------------- +// F1c: a forced refresh request that arrives while a stale refresh is +// in flight, where invalidate() fires WHILE the waiter is blocked on +// the in-flight, must still trigger a fresh post-invalidation refresh. +// Before the fix the waiter sampled invalidationGen before the wait, +// so any invalidation during the wait was lost and the waiter returned +// the stale (pre-invalidation) flight result. 
+// ---------------------------------------------------------------------------- +func TestFix1c_ForceRefreshWaiterReReadsInvalidationGenAfterWait(t *testing.T) { + body, _ := json.Marshal(accountProperties{ + ReadRegions: []accountRegion{{Name: "West US", Endpoint: "https://fake.documents.azure.com:443/"}}, + WriteRegions: []accountRegion{{Name: "West US", Endpoint: "https://fake.documents.azure.com:443/"}}, + }) + transport := &countingTransport{status: http.StatusOK, body: body, delay: 250 * time.Millisecond} + gem := newGEMWithTransport(t, []string{"West US"}, transport, 5*time.Minute) + + // 1) Kick off a refresh (becomes the leader/inflight). + leaderDone := make(chan struct{}) + go func() { + defer close(leaderDone) + _ = gem.Update(context.Background(), false) + }() + require.Eventually(t, gem.hasInflight, 2*time.Second, 5*time.Millisecond, + "leader must enter the in-flight slot") + + // 2) Start a forceRefresh waiter while the leader is still in flight. + waiterStart := make(chan struct{}) + waiterDone := make(chan struct{}) + go func() { + close(waiterStart) + _ = gem.Update(context.Background(), true) + close(waiterDone) + }() + <-waiterStart + + // 3) While the waiter is blocked on <-leaderFlight.done, invalidate. + // The leader will finish soon (delay=250ms); the waiter must + // observe the post-invalidate genAtStart on re-read and loop. + time.Sleep(50 * time.Millisecond) + mark, _ := url.Parse("https://fake.documents.azure.com:443/") + _, markErr := gem.MarkEndpointUnavailableForWrite(*mark) + require.NoError(t, markErr) + + // 4) Both goroutines must complete. 
+ select { + case <-leaderDone: + case <-time.After(5 * time.Second): + t.Fatal("leader refresh did not complete") + } + select { + case <-waiterDone: + case <-time.After(5 * time.Second): + t.Fatal("waiter did not complete; likely stuck looping or never loops") + } + + // 5) Expect exactly 2 HTTP calls: the leader's, plus the + // post-invalidation refresh the waiter should have led after + // looping. Without the fix this would be 1 (waiter returned the + // stale flight without looping). + require.Equal(t, int64(2), transport.count.Load(), + "waiter must loop and lead a fresh refresh after invalidation during wait") +} + +// ---------------------------------------------------------------------------- +// F1d: a forceRefresh LEADER (not waiter) whose flight gets invalidated +// mid-flight must loop and lead a fresh refresh. Before the fix, only +// waiters re-read invalidationGen after their wait; a leader whose own +// genAtStart predated a mid-flight invalidate() would simply return, +// leaving asyncRefreshState=Idle in the retry policy and silently +// skipping the post-invalidation refresh the caller actually needed. +// ---------------------------------------------------------------------------- +func TestFix1d_ForceRefreshLeaderLoopsOnMidFlightInvalidation(t *testing.T) { + body, _ := json.Marshal(accountProperties{ + ReadRegions: []accountRegion{{Name: "West US", Endpoint: "https://fake.documents.azure.com:443/"}}, + WriteRegions: []accountRegion{{Name: "West US", Endpoint: "https://fake.documents.azure.com:443/"}}, + }) + transport := &countingTransport{status: http.StatusOK, body: body, delay: 250 * time.Millisecond} + gem := newGEMWithTransport(t, []string{"West US"}, transport, 5*time.Minute) + + // Kick off a forceRefresh as leader. 
+ leaderDone := make(chan error, 1) + go func() { + leaderDone <- gem.Update(context.Background(), true /*forceRefresh*/) + }() + require.Eventually(t, gem.hasInflight, 2*time.Second, 5*time.Millisecond, + "leader must enter the in-flight slot") + + // Fire invalidate() during the leader's flight. + time.Sleep(50 * time.Millisecond) + mark, _ := url.Parse("https://fake.documents.azure.com:443/") + _, err := gem.MarkEndpointUnavailableForWrite(*mark) + require.NoError(t, err) + + // The leader's first flight will complete; the outer loop must + // detect latestGen > genAtStart and lead a second flight. Wait for + // the entire Update call to return. + select { + case updateErr := <-leaderDone: + require.NoError(t, updateErr) + case <-time.After(5 * time.Second): + t.Fatal("leader Update did not return; likely never loops") + } + + // Expect exactly 2 HTTP calls: the original flight that predated + // the invalidation + the post-invalidation refresh the leader led + // after looping. Without the fix this would be 1. + require.Equal(t, int64(2), transport.count.Load(), + "leader must loop and lead a fresh refresh after mid-flight invalidation") +} + +// ---------------------------------------------------------------------------- +// F3g: attemptRetryOnSessionUnavailable must snapshot +// (enableMultipleWriteLocations, availReadLocations, availWriteLocations) +// atomically. Otherwise a concurrent locationCache.update can flip the +// multi-write decision between the CanUseMultipleWriteLocations() check +// and the subsequent slice-length sampling, producing a routing decision +// that mixes pre- and post-refresh topology state. 
+// ---------------------------------------------------------------------------- +func TestFix3g_SessionUnavailableSnapshotIsAtomic(t *testing.T) { + defaultEndpoint, err := url.Parse("https://fake.documents.azure.com:443/") + require.NoError(t, err) + lc := CreateMockLC(*defaultEndpoint, true /*multiMaster*/) + multiWrite, readN, writeN := lc.sessionRetrySnapshot() + require.True(t, multiWrite, "multi-write must be reported") + require.Greater(t, readN, 0, "read count must be populated") + require.Greater(t, writeN, 0, "write count must be populated") + + // Concurrent flip+snapshot race: a hostile updater toggles + // enableMultipleWriteLocations and rewrites the slices repeatedly + // while a reader takes snapshots. Each snapshot must be internally + // consistent: multiWrite == true => write/read slices are the + // multi-master shape that was current at the lock acquisition. + stop := make(chan struct{}) + go func() { + toggle := true + for { + select { + case <-stop: + return + default: + } + toggle = !toggle + _ = lc.update(nil, nil, nil, &toggle) + } + }() + defer close(stop) + + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) { + mw, rN, wN := lc.sessionRetrySnapshot() + // The snapshot must come from a single locked read. We can't + // directly verify that without instrumentation, but at minimum + // the returned tuple must be a value (not a panic / negative). + require.GreaterOrEqual(t, rN, 0) + require.GreaterOrEqual(t, wN, 0) + _ = mw + } +} + +// ---------------------------------------------------------------------------- +// F3h: when the first forced refresh returns successfully but the +// metadata still reflects pre-failover topology (a common race during +// single-master account failovers), the retry policy MUST be able to +// force another refresh against the same already-unavailable endpoint +// after forcedRefreshMinInterval. 
Otherwise the policy's wasAlreadyUnavailable +// gate plus asyncRefreshState=Idle on success would suppress every +// subsequent forced refresh and leave the client stuck on the bad +// write endpoint until the GEM throttle expires (default 5 minutes). +// +// We verify the gating function staleForcedRefresh() and the spawn +// decision directly rather than driving the full attemptRetryOnEndpointFailure +// (which sleeps defaultBackoff between calls and would couple this +// test to that timing). +// ---------------------------------------------------------------------------- +func TestFix3h_RepeatWriteForbiddenForcesRefreshAfterRateWindow(t *testing.T) { + defaultEndpoint, err := url.Parse("https://fake.documents.azure.com:443/") + require.NoError(t, err) + body, _ := json.Marshal(accountProperties{ + ReadRegions: []accountRegion{{Name: "West US", Endpoint: defaultEndpoint.String()}}, + WriteRegions: []accountRegion{{Name: "West US", Endpoint: defaultEndpoint.String()}}, + }) + transport := &countingTransport{status: http.StatusOK, body: body} + gemPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", + azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: transport}) + gem := &globalEndpointManager{ + clientEndpoint: defaultEndpoint.String(), + pipeline: gemPipeline, + preferredLocations: []string{"West US"}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Now(), + } + retry := &clientRetryPolicy{gem: gem} + + // Before any refresh has run, staleForcedRefresh() must return true + // so the very first 403 unconditionally triggers a refresh. + require.True(t, retry.staleForcedRefresh(), + "with no recorded prior refresh, staleForcedRefresh must allow a spawn") + + // Spawn a real refresh and wait for it to complete -- this populates + // lastForcedRefreshUnixNano with "now". 
+ require.True(t, retry.asyncForceRefreshGEM()) + require.Eventually(t, func() bool { + return retry.asyncRefreshState.Load() == asyncRefreshIdle && transport.count.Load() >= 1 + }, 5*time.Second, 10*time.Millisecond, "first refresh must complete") + require.NotZero(t, retry.lastForcedRefreshUnixNano.Load(), + "the goroutine's defer must record completion time") + + // Immediately after the first refresh completes, staleForcedRefresh() + // must return false so a follow-up 403 on the same endpoint does + // NOT spawn another refresh -- the rate-limit window has not + // elapsed and a tight 403 loop must not storm GetDatabaseAccount. + require.False(t, retry.staleForcedRefresh(), + "within forcedRefreshMinInterval of a completed refresh, follow-up 403s must be rate-limited") + + // Simulate the rate window elapsing by rewinding the recorded + // timestamp by forcedRefreshMinInterval + a small slack. Now the + // gate must allow a new refresh. + retry.lastForcedRefreshUnixNano.Store(time.Now().Add(-forcedRefreshMinInterval - 50*time.Millisecond).UnixNano()) + require.True(t, retry.staleForcedRefresh(), + "after forcedRefreshMinInterval has elapsed, repeat 403 must be allowed to spawn a new refresh") +} + +// ---------------------------------------------------------------------------- +// F1e: the forceRefresh leader's mid-flight-invalidation loop must be +// bounded. Under sustained invalidations the loop could otherwise spin +// indefinitely inside a single Update call, monopolizing the inflight +// slot. After maxForceRefreshRetries iterations the call returns; the +// retry policy state machine then sees the outcome and the NEXT caller +// can lead a fresh attempt. 
+// ---------------------------------------------------------------------------- +func TestFix1e_ForceRefreshLeaderLoopIsBounded(t *testing.T) { + body, _ := json.Marshal(accountProperties{ + ReadRegions: []accountRegion{{Name: "West US", Endpoint: "https://fake.documents.azure.com:443/"}}, + WriteRegions: []accountRegion{{Name: "West US", Endpoint: "https://fake.documents.azure.com:443/"}}, + }) + // Slow transport so we have time to invalidate during each flight. + transport := &countingTransport{status: http.StatusOK, body: body, delay: 80 * time.Millisecond} + gem := newGEMWithTransport(t, []string{"West US"}, transport, 5*time.Minute) + + // Hostile concurrent invalidator: keeps firing invalidate() so + // every leader iteration sees latestGen > genAtStart. + stop := make(chan struct{}) + go func() { + mark, _ := url.Parse("https://fake.documents.azure.com:443/") + for { + select { + case <-stop: + return + case <-time.After(5 * time.Millisecond): + _, _ = gem.MarkEndpointUnavailableForWrite(*mark) + } + } + }() + defer close(stop) + + start := time.Now() + err := gem.Update(context.Background(), true /*forceRefresh*/) + elapsed := time.Since(start) + require.NoError(t, err) + + // 1 initial + maxForceRefreshRetries loop iterations = up to + // maxForceRefreshRetries+1 flights. Each takes ~80ms; total must + // be far less than runaway (multi-second). 80ms * (3+1) = 320ms + // nominal; allow up to 2s of scheduler slack. 
+ require.LessOrEqual(t, transport.count.Load(), int64(maxForceRefreshRetries+1), + "leader loop must be bounded by maxForceRefreshRetries (count=%d)", transport.count.Load()) + require.Less(t, elapsed, 2*time.Second, + "leader loop must return promptly under sustained invalidations (elapsed=%v count=%d)", elapsed, transport.count.Load()) +} + +// ---------------------------------------------------------------------------- +// F1f: a forceRefresh WAITER (one that joins an in-flight refresh +// rather than leading) must also be bounded by maxForceRefreshRetries. +// Before this fix, the waiter "continue" path did not increment the +// retry counter, so a sustained leadership-churn pattern (other +// goroutines repeatedly winning each subsequent flight while +// invalidate() keeps firing) could keep one waiter joining stale +// flights indefinitely. Now the leader and waiter paths share a +// single budget. +// ---------------------------------------------------------------------------- +func TestFix1f_ForceRefreshWaiterLoopIsBounded(t *testing.T) { + body, _ := json.Marshal(accountProperties{ + ReadRegions: []accountRegion{{Name: "West US", Endpoint: "https://fake.documents.azure.com:443/"}}, + WriteRegions: []accountRegion{{Name: "West US", Endpoint: "https://fake.documents.azure.com:443/"}}, + }) + // Each flight takes ~80ms. + transport := &countingTransport{status: http.StatusOK, body: body, delay: 80 * time.Millisecond} + gem := newGEMWithTransport(t, []string{"West US"}, transport, 5*time.Minute) + + // 1) Kick off a long-running leader so the waiter under test will + // arrive while inflight != nil. 
+ leaderDone := make(chan struct{}) + go func() { + defer close(leaderDone) + _ = gem.Update(context.Background(), false /*not force*/) + }() + require.Eventually(t, gem.hasInflight, 2*time.Second, 5*time.Millisecond, + "leader must enter the in-flight slot first") + + // 2) Hostile background pattern: keep firing invalidate() so every + // flight the waiter observes pre-dates a later invalidation, + // AND keep starting new leaders the moment the previous + // flight clears so the waiter never gets to lead itself. + stop := make(chan struct{}) + defer close(stop) + mark, _ := url.Parse("https://fake.documents.azure.com:443/") + go func() { + for { + select { + case <-stop: + return + case <-time.After(5 * time.Millisecond): + _, _ = gem.MarkEndpointUnavailableForWrite(*mark) + } + } + }() + go func() { + for { + select { + case <-stop: + return + default: + _ = gem.Update(context.Background(), false) + } + } + }() + + // 3) Now start a forceRefresh waiter and time how long it takes + // to return. It should bail out after the shared retry budget + // is exhausted, NOT spin indefinitely. + start := time.Now() + waiterDone := make(chan error, 1) + go func() { + waiterDone <- gem.Update(context.Background(), true /*forceRefresh*/) + }() + select { + case waiterErr := <-waiterDone: + require.NoError(t, waiterErr) + case <-time.After(5 * time.Second): + t.Fatalf("forceRefresh waiter did not return; likely unbounded") + } + elapsed := time.Since(start) + // 4 flights * 80ms each (leader's flight + 3 retry-budget loops) = + // ~320ms nominal upper bound. Allow scheduler slack but require + // significantly less than runaway (multi-second). 
+ require.Less(t, elapsed, 2*time.Second, + "forceRefresh waiter must return promptly under leadership churn (elapsed=%v)", elapsed) + + <-leaderDone } // ---------------------------------------------------------------------------- @@ -266,7 +978,7 @@ func TestFix3c_ConcurrentSameEndpointMarksAreBounded(t *testing.T) { for i := 0; i < concurrency; i++ { go func() { defer wg.Done() - _ = gem.MarkEndpointUnavailableForWrite(*defaultEndpoint) + _, _ = gem.MarkEndpointUnavailableForWrite(*defaultEndpoint) }() go func() { defer wg.Done() @@ -710,7 +1422,8 @@ func TestFix7_InvalidateThenRefreshFailureDoesNotStallDataPlane(t *testing.T) { // 2. Simulate a regional 403 -> MarkEndpointUnavailableForWrite -> invalidate. west, _ := url.Parse("https://west-us.documents.azure.com:443/") - require.NoError(t, gem.MarkEndpointUnavailableForWrite(*west)) + _, err = gem.MarkEndpointUnavailableForWrite(*west) + require.NoError(t, err) // 3. populated() must remain true even though lastUpdateTime is zero. require.True(t, gem.populated(), diff --git a/sdk/data/azcosmos/cosmos_global_endpoint_manager.go b/sdk/data/azcosmos/cosmos_global_endpoint_manager.go index cbb7798b4e82..3d61c44a0379 100644 --- a/sdk/data/azcosmos/cosmos_global_endpoint_manager.go +++ b/sdk/data/azcosmos/cosmos_global_endpoint_manager.go @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// cSpell:ignore azlog ctxt + package azcosmos import ( @@ -61,6 +63,11 @@ type globalEndpointManager struct { type updateFlight struct { done chan struct{} err error + // genAtStart is the invalidationGen the leader observed when it + // claimed the flight. Waiters use it to decide whether the flight + // pre-dates a later invalidate() (in which case forceRefresh callers + // must loop and start a fresh flight after this one completes). 
+ genAtStart uint64 } func newGlobalEndpointManager(clientEndpoint string, pipeline azruntime.Pipeline, preferredLocations []string, refreshTimeInterval time.Duration, enableCrossRegionRetries bool) (*globalEndpointManager, error) { @@ -93,7 +100,7 @@ func (gem *globalEndpointManager) GetReadEndpoints() ([]url.URL, error) { return gem.locationCache.readEndpoints() } -func (gem *globalEndpointManager) MarkEndpointUnavailableForWrite(endpoint url.URL) error { +func (gem *globalEndpointManager) MarkEndpointUnavailableForWrite(endpoint url.URL) (wasAlreadyUnavailable bool, err error) { // markEndpointUnavailableForWrite atomically reports whether the // endpoint was already unavailable for write in the same critical // section that performs the mark. This eliminates the check-then-act @@ -104,25 +111,29 @@ func (gem *globalEndpointManager) MarkEndpointUnavailableForWrite(endpoint url.U // may indicate a failover and we want to learn about new write // regions promptly. Subsequent retries within the unavailability // window do not invalidate. - wasAlreadyUnavailable, err := gem.locationCache.markEndpointUnavailableForWrite(endpoint) + // + // The caller receives wasAlreadyUnavailable so it can decide whether + // to additionally force a fresh refresh; the GEM-internal invalidate + // already arms the next non-force Update either way. 
+ wasAlreadyUnavailable, err = gem.locationCache.markEndpointUnavailableForWrite(endpoint) if err != nil { - return err + return false, err } if !wasAlreadyUnavailable { gem.invalidate() } - return nil + return wasAlreadyUnavailable, nil } -func (gem *globalEndpointManager) MarkEndpointUnavailableForRead(endpoint url.URL) error { - wasAlreadyUnavailable, err := gem.locationCache.markEndpointUnavailableForRead(endpoint) +func (gem *globalEndpointManager) MarkEndpointUnavailableForRead(endpoint url.URL) (wasAlreadyUnavailable bool, err error) { + wasAlreadyUnavailable, err = gem.locationCache.markEndpointUnavailableForRead(endpoint) if err != nil { - return err + return false, err } if !wasAlreadyUnavailable { gem.invalidate() } - return nil + return wasAlreadyUnavailable, nil } // invalidate forces the next non-force Update to actually issue a refresh by @@ -216,52 +227,132 @@ func (gem *globalEndpointManager) ResolveServiceEndpoint(locationIndex int, reso // (GetAccountProperties applies its own 60s timeout). Waiters select between // flight completion and their own ctx.Done() so a caller-side timeout cannot // be exceeded by an unrelated stuck refresh. +// maxForceRefreshRetries bounds the total number of times a +// forceRefresh caller may re-enter the leadership-or-wait loop after +// observing that the last flight predates a later invalidation. The +// cap covers BOTH paths: a leader whose own flight saw a mid-flight +// invalidate(), and a waiter that joined a flight which turned out +// to predate a later invalidate(). Without this combined cap, a +// hostile invalidation+leadership-churn pattern could keep a single +// goroutine bouncing between waiter and leader roles indefinitely. 
+// After the cap is reached the call returns its last result; the +// retry-policy state machine then sees the outcome (Idle on success, +// Failed on error) and the NEXT request that needs a fresh refresh +// will spawn a new goroutine, naturally rate-limiting the topology +// poll while still making progress. +const maxForceRefreshRetries = 3 + func (gem *globalEndpointManager) Update(ctx context.Context, forceRefresh bool) error { - gem.gemMutex.Lock() - if !gem.shouldRefresh() && !forceRefresh { - // Throttled. Surface the cached error only if we have NEVER - // successfully populated the GEM -- otherwise the data plane has - // a valid cached topology and should continue working until the - // next refresh attempt succeeds. The cached error is shared across - // force=true and force=false callers: both want to surface - // "bootstrap is broken" and there's no caller-visible distinction. - var cached error - if !gem.everPopulated.Load() { - cached = gem.lastUpdateErr + retries := 0 + for { + var flight *updateFlight + var genAtStart uint64 + // Acquire leadership or wait on an in-flight refresh. + leader: + for { + gem.gemMutex.Lock() + if !gem.shouldRefresh() && !forceRefresh { + // Throttled. Surface the cached error only if we have NEVER + // successfully populated the GEM -- otherwise the data plane has + // a valid cached topology and should continue working until the + // next refresh attempt succeeds. The cached error is shared across + // force=true and force=false callers: both want to surface + // "bootstrap is broken" and there's no caller-visible distinction. + var cached error + if !gem.everPopulated.Load() { + cached = gem.lastUpdateErr + } + gem.gemMutex.Unlock() + return cached + } + if gem.inflight != nil { + // Another goroutine is performing a refresh. Wait for it and share + // its result rather than spawning a duplicate HTTP call. The result + // lives on the per-flight struct so subsequent flights cannot + // overwrite it. 
Honour the waiter's ctx so a caller-side timeout + // is not extended by the leader's HTTP call duration. + waitFlight := gem.inflight + gem.gemMutex.Unlock() + select { + case <-waitFlight.done: + if forceRefresh { + // Re-read invalidationGen UNDER THE LOCK after the + // flight completes. Sampling it before the wait would + // miss any invalidate() that fires while we are + // blocked on <-waitFlight.done -- in which case the + // flight we waited on would still pre-date the + // latest invalidation but we would not know it, and + // we would return the stale result instead of + // looping to lead a fresh post-invalidation refresh. + gem.gemMutex.Lock() + latestGen := gem.invalidationGen + gem.gemMutex.Unlock() + // Consume one slot of the shared force-refresh + // retry budget so a waiter that keeps joining + // stale flights cannot loop forever. + if waitFlight.genAtStart < latestGen && retries < maxForceRefreshRetries { + retries++ + continue + } + } + return waitFlight.err + case <-ctx.Done(): + return ctx.Err() + } + } + // We are the leader. Publish the inflight flight and snapshot the + // invalidation generation BEFORE releasing the lock so the HTTP + // call below does not block ShouldRefresh and other non-Update + // callers on a network round-trip. + genAtStart = gem.invalidationGen + flight = &updateFlight{done: make(chan struct{}), genAtStart: genAtStart} + gem.inflight = flight + gem.gemMutex.Unlock() + break leader } - gem.gemMutex.Unlock() - return cached - } - if gem.inflight != nil { - // Another goroutine is performing a refresh. Wait for it and share - // its result rather than spawning a duplicate HTTP call. The result - // lives on the per-flight struct so subsequent flights cannot - // overwrite it. Honour the waiter's ctx so a caller-side timeout - // is not extended by the leader's HTTP call duration. 
- flight := gem.inflight - gem.gemMutex.Unlock() - select { - case <-flight.done: - return flight.err - case <-ctx.Done(): - return ctx.Err() + + err := gem.runLeaderFlight(ctx, flight, genAtStart) + + // If we are a forceRefresh caller and an invalidate() fired during + // our flight, our flight's genAtStart predates the latest + // invalidation -- the topology we just learned does not reflect + // the most recent unavailability mark. Loop and either lead a + // fresh flight that covers it or coalesce into one another + // goroutine has already started. Without this, a 403 (or + // connection error) whose MarkEndpointUnavailable* invalidate() + // happened mid-flight would be silently coalesced into the + // in-progress refresh, leave asyncRefreshState=Idle in the retry + // policy, and not trigger another forced refresh. + // + // The same shared retries budget bounds this leader-side + // re-entry so a sustained invalidation storm cannot keep one + // goroutine spinning here forever. + if forceRefresh { + gem.gemMutex.Lock() + latestGen := gem.invalidationGen + gem.gemMutex.Unlock() + if latestGen > genAtStart && retries < maxForceRefreshRetries { + retries++ + continue + } } + return err } - // We are the leader. Publish the inflight flight and snapshot the - // invalidation generation, then release the lock while we perform the - // HTTP call so ShouldRefresh and other non-Update paths don't block on - // a network round-trip. - flight := &updateFlight{done: make(chan struct{})} - gem.inflight = flight - genAtStart := gem.invalidationGen - gem.gemMutex.Unlock() +} +// runLeaderFlight performs the actual GetAccountProperties HTTP call as +// the in-flight leader and runs the panic-safe defer that commits the +// flight result and timestamps. It MUST be called only after the caller +// has claimed leadership (set gem.inflight = flight). Returns the +// flight's error (nil on success). 
+func (gem *globalEndpointManager) runLeaderFlight(ctx context.Context, flight *updateFlight, genAtStart uint64) (err error) { // Panic-safe cleanup: if refreshOnce (or anything it transitively calls // -- the pipeline, JSON unmarshal, locationCache.update) panics, we // MUST still clear gem.inflight and close flight.done, otherwise every // subsequent Update caller blocks forever on <-flight.done. We capture - // any panic, record it as the flight error, and re-panic after cleanup. - var err error + // any panic, record it as the flight error, and re-panic after + // cleanup so the original stack trace is preserved for the host's + // panic handler. defer func() { r := recover() gem.gemMutex.Lock() diff --git a/sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go b/sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go index b6c997970d80..37ebb178747b 100644 --- a/sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go +++ b/sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+// cSpell:ignore azcosmosgemtest azcosmostest westus + package azcosmos import ( @@ -88,7 +90,7 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForRead(t *testing.T) { gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true) assert.NoError(t, err) - err = gem.MarkEndpointUnavailableForRead(*endpoint) + _, err = gem.MarkEndpointUnavailableForRead(*endpoint) assert.NoError(t, err) unavailable := gem.IsEndpointUnavailable(*endpoint, 1) @@ -108,7 +110,7 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForWrite(t *testing.T) { gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true) assert.NoError(t, err) - err = gem.MarkEndpointUnavailableForWrite(*endpoint) + _, err = gem.MarkEndpointUnavailableForWrite(*endpoint) assert.NoError(t, err) unavailable := gem.IsEndpointUnavailable(*endpoint, 2) diff --git a/sdk/data/azcosmos/cosmos_location_cache.go b/sdk/data/azcosmos/cosmos_location_cache.go index 083219b2d1cc..be68f7491baf 100644 --- a/sdk/data/azcosmos/cosmos_location_cache.go +++ b/sdk/data/azcosmos/cosmos_location_cache.go @@ -145,6 +145,14 @@ func (lc *locationCache) updateLocked(writeLocations []accountRegion, readLocati } func (lc *locationCache) resolveServiceEndpoint(locationIndex int, resourceType resourceType, isWriteOperation, useWriteEndpoint bool) url.URL { + // Take a read lock for the duration of endpoint resolution. The + // fields read here (locationInfo, enableMultipleWriteLocations) are + // rewritten atomically under mapMutex.Lock() by update/updateLocked, + // and a concurrent forced refresh (e.g. from the retry policy's + // asyncForceRefreshGEM or the GEM policy's background refresh) can + // race with us without this lock. 
+ lc.mapMutex.RLock() + defer lc.mapMutex.RUnlock() if (isWriteOperation || useWriteEndpoint) && !lc.canUseMultipleWriteLocsToRoute(resourceType) { if lc.enableCrossRegionRetries && len(lc.locationInfo.availWriteLocations) > 0 { locationIndex = min(locationIndex%2, len(lc.locationInfo.availWriteLocations)-1) @@ -201,6 +209,14 @@ func (lc *locationCache) writeEndpoints() ([]url.URL, error) { } func (lc *locationCache) getLocation(endpoint url.URL) string { + // Take a read lock for the duration of the lookup. The reads of + // locationInfo.availWriteEndpointsByLocation / + // availReadEndpointsByLocation and enableMultipleWriteLocations race + // the writes in update / updateLocked, especially now that the retry + // policy's asyncForceRefreshGEM can trigger a refresh concurrently + // with the data-plane lookup that calls into here. + lc.mapMutex.RLock() + defer lc.mapMutex.RUnlock() firstLoc := "" for location, uri := range lc.locationInfo.availWriteEndpointsByLocation { if uri == endpoint { @@ -238,12 +254,39 @@ func (lc *locationCache) CanUseMultipleWriteLocs() bool { return lc.enableMultipleWriteLocations } +// sessionRetrySnapshot returns a coherent snapshot of the fields the +// session-unavailable retry path needs to make a routing decision: +// (canUseMultipleWriteLocs, availReadLocationCount, availWriteLocationCount). +// Taking these reads under a single RLock prevents a concurrent +// locationCache.update (e.g. from an async GEM refresh) from rewriting +// enableMultipleWriteLocations between the multi-write branch decision +// and the slice-length sampling that follows it. 
+func (lc *locationCache) sessionRetrySnapshot() (multiWrite bool, readN, writeN int) { + lc.mapMutex.RLock() + defer lc.mapMutex.RUnlock() + return lc.enableMultipleWriteLocations, + len(lc.locationInfo.availReadLocations), + len(lc.locationInfo.availWriteLocations) +} + func (lc *locationCache) markEndpointUnavailableForRead(endpoint url.URL) (wasAlreadyUnavailable bool, err error) { - return lc.markEndpointUnavailable(endpoint, read) + return lc.markEndpointUnavailable(endpointKey(endpoint), read) } func (lc *locationCache) markEndpointUnavailableForWrite(endpoint url.URL) (wasAlreadyUnavailable bool, err error) { - return lc.markEndpointUnavailable(endpoint, write) + return lc.markEndpointUnavailable(endpointKey(endpoint), write) +} + +// endpointKey normalizes a url.URL to the form used as a key in +// locationUnavailabilityInfoMap and stored in availReadEndpointsByLocation +// / availWriteEndpointsByLocation: scheme + host only. Callers of +// MarkEndpointUnavailable* commonly pass the full request URL (including +// path, query, fragment, RawPath, etc.), which would never struct-equal +// the base URLs the cache uses; without normalization the marks are +// recorded under keys nothing else ever looks up and the demote step in +// getPrefAvailableEndpointsLocked silently does nothing. +func endpointKey(u url.URL) url.URL { + return url.URL{Scheme: u.Scheme, Host: u.Host} } // markEndpointUnavailable atomically samples whether the endpoint was already @@ -302,7 +345,7 @@ func (lc *locationCache) isEndpointUnavailable(endpoint url.URL, ops requestedOp } func (lc *locationCache) isEndpointUnavailableLocked(endpoint url.URL, ops requestedOperations) bool { - info, ok := lc.locationUnavailabilityInfoMap[endpoint] + info, ok := lc.locationUnavailabilityInfoMap[endpointKey(endpoint)] if ops == none || !ok || ops&info.unavailableOps != ops { return false }