From b2025bd42a3d4c8f24d991cd8f491ddfd0087a1a Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 20 May 2026 00:47:26 -0700 Subject: [PATCH 1/2] azcosmos: single-pending-I/O pkrange cache and mid-pagination retry Rewrite partitionKeyRangeCache so concurrent callers share a single in-flight refresh per container, and the cached routing map remains readable while a refresh is in flight. The refresh runs on a detached context.Background() so a single caller's context cancellation no longer aborts the shared fetch for other waiters; each caller still honors its own ctx.Done(). forceRefresh now takes the routing map pointer the caller observed. When the entry already holds a fresher map than the caller's view, the refresh is suppressed and the fresher map is returned immediately (pointer-identity dedup of stale-view triggers). Make change-feed pagination resilient to mid-loop transient failures: each page is retried up to changeFeedPageMaxAttempts times with linear backoff on 5xx / 408 / 429 / network errors, preserving the pages already accumulated. Non-transient errors (other 4xx, ctx errors) fail fast. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk/data/azcosmos/CHANGELOG.md | 3 + .../azcosmos/cosmos_container_read_many.go | 2 +- .../cosmos_partition_key_range_cache.go | 295 +++++++++--- .../cosmos_partition_key_range_cache_test.go | 441 +++++++++++++++++- 4 files changed, 676 insertions(+), 65 deletions(-) diff --git a/sdk/data/azcosmos/CHANGELOG.md b/sdk/data/azcosmos/CHANGELOG.md index 2e1208900046..49bcfd2eb33d 100644 --- a/sdk/data/azcosmos/CHANGELOG.md +++ b/sdk/data/azcosmos/CHANGELOG.md @@ -4,6 +4,9 @@ ### Features Added +* Partition key range cache now serves concurrent callers from a single in-flight refresh per container, and the cached routing map remains readable while a refresh is in progress. The refresh runs on a detached background context so one caller's cancellation no longer aborts the shared fetch for other waiters; each caller continues to honor its own context deadline. +* Partition key range cache change-feed pagination is now resilient to transient mid-pagination failures (5xx, 408, 429, network errors). The failing page is retried with linear backoff while preserving the pages already accumulated, instead of restarting the entire drain from page 1 on the next refresh. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/data/azcosmos/cosmos_container_read_many.go b/sdk/data/azcosmos/cosmos_container_read_many.go index 84421b8f0027..bba68af9659c 100644 --- a/sdk/data/azcosmos/cosmos_container_read_many.go +++ b/sdk/data/azcosmos/cosmos_container_read_many.go @@ -381,7 +381,7 @@ func (c *ContainerClient) refreshPKRangeCache(ctx context.Context) error { if err != nil { return err } - _, err = c.database.client.getPKRangeCache().forceRefresh(ctx, containerRID, c.link, c.database.client) + _, err = c.database.client.getPKRangeCache().forceRefresh(ctx, containerRID, c.link, c.database.client, nil) if err != nil { return err } diff --git a/sdk/data/azcosmos/cosmos_partition_key_range_cache.go b/sdk/data/azcosmos/cosmos_partition_key_range_cache.go index a4183e15c7bd..32610ea33dcb 100644 --- a/sdk/data/azcosmos/cosmos_partition_key_range_cache.go +++ b/sdk/data/azcosmos/cosmos_partition_key_range_cache.go @@ -6,10 +6,13 @@ package azcosmos import ( "context" "encoding/json" + "errors" "fmt" "net/http" "sync" + "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" @@ -21,14 +24,41 @@ import ( // invalidation (no TTL). Refreshes are incremental using the change-feed ETag. // Keying by RID (rather than name-based link) ensures the cache survives // container renames and matches the service's partition key range addressing. +// +// Concurrency model (single-pending-I/O per container): +// +// - At most one refresh runs per container at any time. Concurrent callers +// that arrive while a refresh is in flight share its result. +// - The cached routing map remains readable while a refresh is in flight; +// getRoutingMap returns immediately whenever the entry already has a +// non-nil routingMap and does not wait for the in-flight refresh. +// - The refresh goroutine runs on context.Background() so a single caller's +// ctx cancellation does not abort the shared fetch for other waiters. +// Each waiter still honors its own ctx and returns ctx.Err() if it fires +// before the shared refresh completes. +// - forceRefresh accepts the routing map pointer the caller observed when +// it decided to refresh ("previous"). If the entry already holds a +// different (fresher) routing map, the caller is served that map +// immediately without starting a new refresh — i.e. a refresh triggered +// by a stale-view caller is suppressed (pointer-identity dedup). type partitionKeyRangeCache struct { mu sync.RWMutex entries map[string]*pkRangeCacheEntry // keyed by container ResourceID } +// refreshOp represents an in-flight partition-key-range refresh for one +// container. Awaiters receive the (rm, err) pair by reading the fields after +// done is closed. +type refreshOp struct { + done chan struct{} + rm *collectionRoutingMap + err error +} + type pkRangeCacheEntry struct { - mu sync.Mutex // single-flights refresh for this container + mu sync.Mutex // protects routingMap and inFlight routingMap *collectionRoutingMap + inFlight *refreshOp } func newPartitionKeyRangeCache() *partitionKeyRangeCache { @@ -37,42 +67,39 @@ func newPartitionKeyRangeCache() *partitionKeyRangeCache { } } -// getRoutingMap returns the cached routing map for the given container RID. -// If the cache is empty for this container, it fetches from the service. -// containerLink is the name-based path used for the HTTP request. -func (c *partitionKeyRangeCache) getRoutingMap( - ctx context.Context, - containerRID string, - containerLink string, - client *Client, -) (*collectionRoutingMap, error) { - // Fast path: read lock check +// getOrCreateEntry returns the entry for the given container RID, creating it +// under the write lock if necessary. The returned entry is safe to use after +// the cache-level lock is released; each entry has its own per-entry mutex +// guarding its routingMap/inFlight fields. +func (c *partitionKeyRangeCache) getOrCreateEntry(containerRID string) *pkRangeCacheEntry { c.mu.RLock() entry, exists := c.entries[containerRID] c.mu.RUnlock() - if exists { - entry.mu.Lock() - if entry.routingMap != nil { - rm := entry.routingMap - entry.mu.Unlock() - return rm, nil - } - // Cache entry exists but routing map is nil (invalidated) — refresh under lock - rm, err := c.refreshEntry(ctx, containerLink, entry, client) - entry.mu.Unlock() - return rm, err + return entry } - // Slow path: create entry c.mu.Lock() - // Double check after acquiring write lock - entry, exists = c.entries[containerRID] - if !exists { - entry = &pkRangeCacheEntry{} - c.entries[containerRID] = entry + defer c.mu.Unlock() + if entry, exists = c.entries[containerRID]; exists { + return entry } - c.mu.Unlock() + entry = &pkRangeCacheEntry{} + c.entries[containerRID] = entry + return entry +} + +// getRoutingMap returns the cached routing map for the given container RID. +// If a routing map is already cached it is returned immediately, even when a +// refresh is in flight. Otherwise the caller joins or starts a shared +// in-flight refresh and waits on its own context. +func (c *partitionKeyRangeCache) getRoutingMap( + ctx context.Context, + containerRID string, + containerLink string, + client *Client, +) (*collectionRoutingMap, error) { + entry := c.getOrCreateEntry(containerRID) entry.mu.Lock() if entry.routingMap != nil { @@ -80,36 +107,96 @@ func (c *partitionKeyRangeCache) getRoutingMap( entry.mu.Unlock() return rm, nil } - rm, err := c.refreshEntry(ctx, containerLink, entry, client) + op := c.ensureInFlightLocked(entry, containerLink, client) entry.mu.Unlock() - return rm, err + + return awaitRefresh(ctx, op) } -// forceRefresh triggers an incremental refresh of the routing map for the given -// container. If the incremental merge fails (incomplete covering), it falls back -// to a full refresh. containerRID is the cache key; containerLink is used for HTTP requests. +// forceRefresh starts (or joins) a refresh for the given container and +// returns the resulting routing map. The previous parameter is the routing +// map pointer the caller observed when it decided to refresh: when non-nil +// and the entry already holds a different routing map, the caller is served +// that fresher map immediately without starting a new refresh. Pass nil to +// always start or join a refresh (e.g. when the caller has no prior view). func (c *partitionKeyRangeCache) forceRefresh( ctx context.Context, containerRID string, containerLink string, client *Client, + previous *collectionRoutingMap, ) (*collectionRoutingMap, error) { - c.mu.RLock() - entry, exists := c.entries[containerRID] - c.mu.RUnlock() + entry := c.getOrCreateEntry(containerRID) + + entry.mu.Lock() + // Suppress refresh if another caller already installed a fresher map past + // our view. Dedup of stale-view triggers via pointer identity. + if previous != nil && entry.routingMap != nil && entry.routingMap != previous { + rm := entry.routingMap + entry.mu.Unlock() + return rm, nil + } + op := c.ensureInFlightLocked(entry, containerLink, client) + entry.mu.Unlock() - if !exists { - // No entry yet — just do a normal get which will create and populate - return c.getRoutingMap(ctx, containerRID, containerLink, client) + return awaitRefresh(ctx, op) +} + +// ensureInFlightLocked returns the entry's in-flight refresh op, creating +// (and spawning) one if none is already running. Caller MUST hold entry.mu. +func (c *partitionKeyRangeCache) ensureInFlightLocked( + entry *pkRangeCacheEntry, + containerLink string, + client *Client, +) *refreshOp { + if entry.inFlight != nil { + return entry.inFlight } + op := &refreshOp{done: make(chan struct{})} + entry.inFlight = op + go c.runRefresh(entry, containerLink, client, op) + return op +} + +// runRefresh executes the change-feed refresh on a detached context.Background() +// so caller cancellations do not abort the shared fetch. On completion it +// updates the entry under entry.mu, clears the in-flight slot, and signals +// awaiters by closing op.done. +func (c *partitionKeyRangeCache) runRefresh( + entry *pkRangeCacheEntry, + containerLink string, + client *Client, + op *refreshOp, +) { + rm, err := c.refreshEntryDetached(containerLink, entry, client) entry.mu.Lock() - defer entry.mu.Unlock() - return c.refreshEntry(ctx, containerLink, entry, client) + if err == nil && rm != nil { + entry.routingMap = rm + } + op.rm = rm + op.err = err + entry.inFlight = nil + entry.mu.Unlock() + + close(op.done) } -// invalidate removes the cached routing map for the given container RID, -// forcing the next access to fetch fresh data. +// awaitRefresh blocks the caller until either the refresh completes or the +// caller's context is cancelled. The refresh continues running in the +// background even when individual awaiters return early via ctx.Err(). +func awaitRefresh(ctx context.Context, op *refreshOp) (*collectionRoutingMap, error) { + select { + case <-op.done: + return op.rm, op.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// invalidate clears the cached routing map for the given container RID, +// forcing the next access to fetch fresh data. An in-flight refresh continues +// and its result will replace the cleared map. func (c *partitionKeyRangeCache) invalidate(containerRID string) { c.mu.RLock() entry, exists := c.entries[containerRID] @@ -123,9 +210,20 @@ func (c *partitionKeyRangeCache) invalidate(containerRID string) { } // maxChangeFeedIterations caps the number of change-feed fetch loops -// to prevent runaway requests during large-scale splits. Aligned with the Rust SDK. +// to prevent runaway requests during large-scale splits. const maxChangeFeedIterations = 1000 +// changeFeedPageMaxAttempts bounds how many times we retry a single +// change-feed page after a transient failure. The pipeline already does +// per-request retry; this is an additional safety net that preserves +// already-accumulated pages instead of restarting the entire change-feed +// drain from page 1 on a single bad page. +const changeFeedPageMaxAttempts = 3 + +// changeFeedPageRetryBaseDelay is the base sleep before retrying a failed +// change-feed page. Backoff is linear: base, 2*base, 3*base, ... +const changeFeedPageRetryBaseDelay = 100 * time.Millisecond + // changeFeedResult holds the result of draining all change-feed pages. type changeFeedResult struct { ranges []partitionKeyRange @@ -136,6 +234,11 @@ type changeFeedResult struct { // fetchAllChangeFeedPages fetches change-feed pages starting from startETag // until 304 Not Modified or the iteration cap. Returns the accumulated ranges, // the final ETag, and whether the loop completed cleanly (304 received). +// +// Each individual page is retried up to changeFeedPageMaxAttempts times on +// transient errors (5xx, 408, 429, network errors) with linear backoff so +// a single bad page doesn't discard the pages already accumulated. Non- +// transient errors (4xx other than 408/429, context errors) fail fast. func fetchAllChangeFeedPages( ctx context.Context, containerLink string, @@ -148,8 +251,12 @@ func fetchAllChangeFeedPages( if err := ctx.Err(); err != nil { return changeFeedResult{}, err } - result, err := fetchPartitionKeyRanges(ctx, containerLink, currentETag, client) + result, err := fetchOneChangeFeedPageWithRetry(ctx, containerLink, currentETag, client) if err != nil { + // Even though we're surfacing the error to the caller, log how + // far we got so operators can correlate partial-drain failures + // with the next refresh re-starting from scratch. + log.Writef(azlog.EventResponse, "partition key range change-feed page failed for container %s after %d successful pages (%d ranges accumulated): %v", containerLink, i, len(allRanges), err) return changeFeedResult{}, err } @@ -170,18 +277,91 @@ func fetchAllChangeFeedPages( return changeFeedResult{ranges: allRanges, finalETag: currentETag, completed: false}, nil } -// refreshEntry fetches PK ranges from the service and populates the entry. -// It attempts an incremental refresh if a previous routing map with an ETag exists, -// accumulating all change-feed pages before merging. Falls back to a full -// change-feed refresh if the incremental merge is incomplete. -// Caller must hold entry.mu. -func (c *partitionKeyRangeCache) refreshEntry( +// fetchOneChangeFeedPageWithRetry fetches a single change-feed page, +// retrying on transient errors so a transient hiccup mid-pagination +// doesn't discard the already-accumulated pages in the caller. Returns +// the last error if all attempts fail or the caller's context fires. +func fetchOneChangeFeedPageWithRetry( ctx context.Context, + containerLink string, + currentETag string, + client *Client, +) (fetchPartitionKeyRangesResult, error) { + var lastErr error + for attempt := 0; attempt < changeFeedPageMaxAttempts; attempt++ { + if err := ctx.Err(); err != nil { + return fetchPartitionKeyRangesResult{}, err + } + result, err := fetchPartitionKeyRanges(ctx, containerLink, currentETag, client) + if err == nil { + return result, nil + } + lastErr = err + if !isTransientPKRangeFetchError(err) { + return fetchPartitionKeyRangesResult{}, err + } + // Last attempt — don't sleep, just return. + if attempt == changeFeedPageMaxAttempts-1 { + break + } + // Linear backoff: 1×base, 2×base, ... + delay := time.Duration(attempt+1) * changeFeedPageRetryBaseDelay + log.Writef(azlog.EventResponse, "partition key range change-feed page transient failure for container %s (attempt %d/%d, retrying in %s): %v", containerLink, attempt+1, changeFeedPageMaxAttempts, delay, err) + timer := time.NewTimer(delay) + select { + case <-timer.C: + case <-ctx.Done(): + timer.Stop() + return fetchPartitionKeyRangesResult{}, ctx.Err() + } + } + return fetchPartitionKeyRangesResult{}, lastErr +} + +// isTransientPKRangeFetchError reports whether a /pkranges fetch error is +// worth retrying mid-pagination. Returns true for 5xx, 408, 429, and +// network-class errors (any error without an HTTP response). Returns false +// for 4xx (other than 408/429) and for context cancellation. +func isTransientPKRangeFetchError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + var respErr *azcore.ResponseError + if !errors.As(err, &respErr) { + // Network / transport error — treat as transient. + return true + } + switch { + case respErr.StatusCode >= 500: + return true + case respErr.StatusCode == http.StatusRequestTimeout: // 408 + return true + case respErr.StatusCode == http.StatusTooManyRequests: // 429 + return true + } + return false +} + +// refreshEntryDetached fetches PK ranges from the service and returns a fresh +// routing map. It snapshots the entry's previous routing map under entry.mu +// (briefly), then performs all network I/O without holding any lock. The +// caller is responsible for installing the returned map onto the entry. +// +// The function uses context.Background() internally via the runRefresh +// goroutine; callers must not pass a caller-scoped context. +func (c *partitionKeyRangeCache) refreshEntryDetached( containerLink string, entry *pkRangeCacheEntry, client *Client, ) (*collectionRoutingMap, error) { + entry.mu.Lock() previousMap := entry.routingMap + entry.mu.Unlock() + + ctx := context.Background() if previousMap != nil && previousMap.changeFeedETag != "" { // Incremental refresh: accumulate ALL change-feed pages first, then @@ -196,21 +376,21 @@ func (c *partitionKeyRangeCache) refreshEntry( if result.completed { if len(result.ranges) == 0 { - // No changes since last refresh — update ETag if changed + // No changes since last refresh — surface the ETag bump if any, + // otherwise return the previous map unchanged. if result.finalETag != previousMap.changeFeedETag { - entry.routingMap = &collectionRoutingMap{ + return &collectionRoutingMap{ orderedRanges: previousMap.orderedRanges, rangeByID: previousMap.rangeByID, goneRanges: previousMap.goneRanges, changeFeedETag: result.finalETag, - } + }, nil } - return entry.routingMap, nil + return previousMap, nil } merged := previousMap.tryCombine(result.ranges, result.finalETag) if merged != nil { - entry.routingMap = merged return merged, nil } } @@ -235,7 +415,6 @@ func (c *partitionKeyRangeCache) refreshEntry( return nil, fmt.Errorf("partition key range cache refresh failed: service returned an incomplete set of ranges for container %s (raw ranges=%d, final ranges=%d, issue: %s). This may indicate a transient issue during a partition split", containerLink, len(result.ranges), len(newMap.orderedRanges), issue) } - entry.routingMap = newMap return newMap, nil } diff --git a/sdk/data/azcosmos/cosmos_partition_key_range_cache_test.go b/sdk/data/azcosmos/cosmos_partition_key_range_cache_test.go index 0f6afdaac8a3..cebd9360e44a 100644 --- a/sdk/data/azcosmos/cosmos_partition_key_range_cache_test.go +++ b/sdk/data/azcosmos/cosmos_partition_key_range_cache_test.go @@ -247,7 +247,7 @@ func Test_partitionKeyRangeCache_incrementalRefresh_success(t *testing.T) { client.caches.pkRangeCache.mu.Unlock() // forceRefresh should do incremental refresh - rm, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client) + rm, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, nil) require.NoError(t, err) require.Equal(t, 2, len(rm.orderedRanges)) require.Equal(t, "", rm.orderedRanges[0].MinInclusive) @@ -282,7 +282,7 @@ func Test_partitionKeyRangeCache_incrementalRefresh_304_immediate(t *testing.T) client.caches.pkRangeCache.entries["testRID"] = entry client.caches.pkRangeCache.mu.Unlock() - rm, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client) + rm, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, nil) require.NoError(t, err) // Ranges should be preserved require.Equal(t, 2, len(rm.orderedRanges)) @@ -347,7 +347,7 @@ func Test_partitionKeyRangeCache_incrementalRefresh_mergeFailure_fullRefresh(t * client.caches.pkRangeCache.entries["testRID"] = entry client.caches.pkRangeCache.mu.Unlock() - rm, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client) + rm, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, nil) require.NoError(t, err) // Should have the 3 ranges from full refresh require.Equal(t, 3, len(rm.orderedRanges)) @@ -383,7 +383,7 @@ func Test_partitionKeyRangeCache_incrementalRefresh_contextCancelled(t *testing. ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := client.caches.pkRangeCache.forceRefresh(ctx, "testRID", "dbs/db1/colls/col1", client) + _, err := client.caches.pkRangeCache.forceRefresh(ctx, "testRID", "dbs/db1/colls/col1", client, nil) require.Error(t, err) require.ErrorIs(t, err, context.Canceled) @@ -689,7 +689,7 @@ func Test_partitionKeyRangeCache_incrementalRefresh_emptyPagesBeforeData(t *test require.Equal(t, 1, len(rm.orderedRanges)) // Trigger incremental refresh - rm, err = client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client) + rm, err = client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, nil) require.NoError(t, err) // Should have 2 ranges after split (parent "0" filtered) @@ -766,7 +766,7 @@ func Test_partitionKeyRangeCache_incrementalRefresh_cascadingSplitAcrossPages(t require.Equal(t, "0", rm.orderedRanges[0].ID) // Trigger incremental refresh via forceRefresh - rm, err = client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client) + rm, err = client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, nil) require.NoError(t, err) // Final state should have 2 ranges: "1" and "2" (parent "0" filtered) @@ -782,3 +782,432 @@ func Test_partitionKeyRangeCache_incrementalRefresh_cascadingSplitAcrossPages(t // Verify parent is marked as gone require.True(t, rm.isGone("0")) } + +// gatePolicy blocks each request on a "started" signal until the test +// closes the "release" channel. It also counts requests so tests can assert +// single-flight (factory invoked exactly once). +type gatePolicy struct { + started chan struct{} + release chan struct{} + count atomic.Int32 +} + +func newGatePolicy() *gatePolicy { + return &gatePolicy{ + started: make(chan struct{}, 16), + release: make(chan struct{}), + } +} + +func (g *gatePolicy) Do(req *policy.Request) (*http.Response, error) { + g.count.Add(1) + // non-blocking signal so tests can count "arrivals" without deadlock + select { + case g.started <- struct{}{}: + default: + } + select { + case <-g.release: + case <-req.Raw().Context().Done(): + return nil, req.Raw().Context().Err() + } + return req.Next() +} + +func createGatedClientForPKRangeCache(srv *mock.Server, gate *gatePolicy) *Client { + defaultEndpoint, _ := url.Parse(srv.URL()) + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", + azruntime.PipelineOptions{PerCall: []policy.Policy{gate}}, + &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + return &Client{ + endpoint: srv.URL(), + endpointUrl: defaultEndpoint, + internal: internalClient, + gem: gem, + caches: &sharedCacheSet{ + pkRangeCache: newPartitionKeyRangeCache(), + containerCache: newContainerPropertiesCache(), + }, + } +} + +// appendPKRangesAndTerminator queues a 200 with the given ranges + ETag and +// a follow-up 304 to terminate the change-feed loop. +func appendPKRangesAndTerminator(srv *mock.Server, body []byte, etag string) { + srv.AppendResponse( + mock.WithBody(body), + mock.WithHeader(cosmosHeaderEtag, etag), + mock.WithStatusCode(200), + ) + srv.AppendResponse( + mock.WithStatusCode(304), + mock.WithHeader(cosmosHeaderEtag, etag), + ) +} + +func Test_partitionKeyRangeCache_getRoutingMap_returnsCachedDuringRefresh(t *testing.T) { + // While a forced refresh is in flight, a concurrent getRoutingMap should + // return the cached map immediately rather than waiting for the refresh. + srv, closeSrv := mock.NewTLSServer() + defer closeSrv() + + body := []byte(`{"_rid":"testRID","PartitionKeyRanges":[{"_rid":"r0","id":"0","minInclusive":"","maxExclusive":"FF","parents":[]}],"_count":1}`) + appendPKRangesAndTerminator(srv, body, "etag2") + + gate := newGatePolicy() + client := createGatedClientForPKRangeCache(srv, gate) + + // Seed the entry with a routing map so getRoutingMap should fast-path. + existing := newCollectionRoutingMap([]partitionKeyRange{{ID: "0", MinInclusive: "", MaxExclusive: "FF"}}, "etag1") + entry := &pkRangeCacheEntry{routingMap: existing} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + // Start a forced refresh in the background; it will block in the gate. + refreshDone := make(chan struct{}) + go func() { + defer close(refreshDone) + _, _ = client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, existing) + }() + + // Wait for the refresh goroutine to actually issue its HTTP request. + <-gate.started + + // Concurrent getRoutingMap MUST return immediately with the existing map. + got, err := client.caches.pkRangeCache.getRoutingMap(context.Background(), "testRID", "dbs/db1/colls/col1", client) + require.NoError(t, err) + require.Same(t, existing, got) + + close(gate.release) + <-refreshDone +} + +func Test_partitionKeyRangeCache_forceRefresh_concurrentCallersShareOneFetch(t *testing.T) { + // Multiple concurrent forceRefresh callers must trigger only one network + // fetch and all observe the same resulting routing map. + srv, closeSrv := mock.NewTLSServer() + defer closeSrv() + + body := []byte(`{"_rid":"testRID","PartitionKeyRanges":[{"_rid":"r0","id":"0","minInclusive":"","maxExclusive":"FF","parents":[]}],"_count":1}`) + appendPKRangesAndTerminator(srv, body, "etag1") + + gate := newGatePolicy() + client := createGatedClientForPKRangeCache(srv, gate) + + // Seed entry without a routing map so concurrent forceRefresh callers all + // fall into the in-flight path. + entry := &pkRangeCacheEntry{} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + const N = 8 + var wg sync.WaitGroup + results := make([]*collectionRoutingMap, N) + errs := make([]error, N) + wg.Add(N) + for i := 0; i < N; i++ { + i := i + go func() { + defer wg.Done() + results[i], errs[i] = client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, nil) + }() + } + + // Let the (single) fetch arrive then release it. + <-gate.started + close(gate.release) + wg.Wait() + + for i := 0; i < N; i++ { + require.NoError(t, errs[i]) + require.NotNil(t, results[i]) + require.Same(t, results[0], results[i], "all callers must share the same routing map pointer") + } + // Exactly one full-refresh sequence (200 + 304) = 2 HTTP requests. + require.Equal(t, int32(2), gate.count.Load()) +} + +func Test_partitionKeyRangeCache_forceRefresh_staleViewSuppressesNewFetch(t *testing.T) { + // When the entry already holds a fresher routing map than the caller's + // `previous` pointer, forceRefresh must return the fresher map without + // issuing a new fetch (pointer-identity dedup). + srv, closeSrv := mock.NewTLSServer() + defer closeSrv() + + gate := newGatePolicy() + close(gate.release) // never block + client := createGatedClientForPKRangeCache(srv, gate) + + stale := newCollectionRoutingMap([]partitionKeyRange{{ID: "0", MinInclusive: "", MaxExclusive: "FF"}}, "etagOld") + fresh := newCollectionRoutingMap([]partitionKeyRange{{ID: "0", MinInclusive: "", MaxExclusive: "FF"}}, "etagNew") + + entry := &pkRangeCacheEntry{routingMap: fresh} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + got, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, stale) + require.NoError(t, err) + require.Same(t, fresh, got) + // No HTTP request should have been issued. + require.Equal(t, int32(0), gate.count.Load()) +} + +func Test_partitionKeyRangeCache_callerCancelDoesNotAbortSharedFetch(t *testing.T) { + // A caller whose context is cancelled while awaiting the in-flight refresh + // must return ctx.Err(); the refresh must continue and serve other waiters. + srv, closeSrv := mock.NewTLSServer() + defer closeSrv() + + body := []byte(`{"_rid":"testRID","PartitionKeyRanges":[{"_rid":"r0","id":"0","minInclusive":"","maxExclusive":"FF","parents":[]}],"_count":1}`) + appendPKRangesAndTerminator(srv, body, "etag1") + + gate := newGatePolicy() + client := createGatedClientForPKRangeCache(srv, gate) + + entry := &pkRangeCacheEntry{} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + // Caller A cancels mid-await. + ctxA, cancelA := context.WithCancel(context.Background()) + errA := make(chan error, 1) + go func() { + _, err := client.caches.pkRangeCache.getRoutingMap(ctxA, "testRID", "dbs/db1/colls/col1", client) + errA <- err + }() + + // Caller B uses background and should succeed. + errB := make(chan error, 1) + rmB := make(chan *collectionRoutingMap, 1) + go func() { + rm, err := client.caches.pkRangeCache.getRoutingMap(context.Background(), "testRID", "dbs/db1/colls/col1", client) + rmB <- rm + errB <- err + }() + + <-gate.started + cancelA() + require.ErrorIs(t, <-errA, context.Canceled) + + // Release the refresh and B should observe the populated map. + close(gate.release) + require.NoError(t, <-errB) + require.NotNil(t, <-rmB) +} + +func Test_partitionKeyRangeCache_invalidateDuringRefresh_keepsNewResult(t *testing.T) { + // invalidate() during an in-flight refresh must NOT abort the refresh; + // when the refresh completes, its result becomes the cached map. + srv, closeSrv := mock.NewTLSServer() + defer closeSrv() + + body := []byte(`{"_rid":"testRID","PartitionKeyRanges":[{"_rid":"r0","id":"0","minInclusive":"","maxExclusive":"FF","parents":[]}],"_count":1}`) + appendPKRangesAndTerminator(srv, body, "etagNew") + + gate := newGatePolicy() + client := createGatedClientForPKRangeCache(srv, gate) + + entry := &pkRangeCacheEntry{} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + done := make(chan struct { + rm *collectionRoutingMap + err error + }, 1) + go func() { + rm, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, nil) + done <- struct { + rm *collectionRoutingMap + err error + }{rm, err} + }() + + <-gate.started + client.caches.pkRangeCache.invalidate("testRID") + close(gate.release) + + res := <-done + require.NoError(t, res.err) + require.NotNil(t, res.rm) + require.Equal(t, "etagNew", res.rm.changeFeedETag) + + // The cached map is the refresh result (not nil from the invalidate). + entry.mu.Lock() + require.Same(t, res.rm, entry.routingMap) + entry.mu.Unlock() +} + +func Test_partitionKeyRangeCache_refreshError_clearsInFlightForRetry(t *testing.T) { + // When a refresh fails, current waiters observe the error and the + // in-flight slot is cleared so the next caller starts a fresh op. + srv, closeSrv := mock.NewTLSServer() + defer closeSrv() + + // First refresh fails with 401 (non-transient, won't be retried by either + // the pipeline or our per-page retry). Second refresh succeeds. + srv.AppendResponse(mock.WithBody([]byte(`{"code":"Unauthorized"}`)), mock.WithStatusCode(401)) + body := []byte(`{"_rid":"testRID","PartitionKeyRanges":[{"_rid":"r0","id":"0","minInclusive":"","maxExclusive":"FF","parents":[]}],"_count":1}`) + appendPKRangesAndTerminator(srv, body, "etag1") + + client := createMockClientForPKRangeCache(srv) + entry := &pkRangeCacheEntry{} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + _, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, nil) + require.Error(t, err) + + // in-flight slot cleared so next call kicks off another op. + entry.mu.Lock() + require.Nil(t, entry.inFlight) + entry.mu.Unlock() + + rm, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, nil) + require.NoError(t, err) + require.NotNil(t, rm) + require.Equal(t, "etag1", rm.changeFeedETag) +} + +// createMockClientForPKRangeCacheNoRetry is like createMockClientForPKRangeCache +// but disables the pipeline's built-in retry policy so tests can exercise the +// change-feed loop's own per-page retry behaviour without the pipeline +// silently retrying transient failures first. +func createMockClientForPKRangeCacheNoRetry(srv *mock.Server) *Client { + defaultEndpoint, _ := url.Parse(srv.URL()) + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{ + Transport: srv, + Retry: policy.RetryOptions{MaxRetries: -1}, + }) + gem := &globalEndpointManager{preferredLocations: []string{}} + return &Client{ + endpoint: srv.URL(), + endpointUrl: defaultEndpoint, + internal: internalClient, + gem: gem, + caches: &sharedCacheSet{ + pkRangeCache: newPartitionKeyRangeCache(), + containerCache: newContainerPropertiesCache(), + }, + } +} + +func Test_partitionKeyRangeCache_midPagination_transientRetrySucceeds(t *testing.T) { + // The change-feed loop must survive a transient 503 between pages by + // retrying the failing page, preserving the pages already accumulated. + srv, closeSrv := mock.NewTLSServer() + defer closeSrv() + + // Page 1: 200 with one range, etag "p1" + srv.AppendResponse( + mock.WithBody([]byte(`{"_rid":"testRID","PartitionKeyRanges":[{"_rid":"r0","id":"0","minInclusive":"","maxExclusive":"80","parents":[]}],"_count":1}`)), + mock.WithHeader(cosmosHeaderEtag, "p1"), + mock.WithStatusCode(200), + ) + // Page 2: 503 transient — must be retried + srv.AppendResponse( + mock.WithBody([]byte(`{"code":"ServiceUnavailable"}`)), + mock.WithStatusCode(503), + ) + // Page 2 retry: 200 with second range, etag "p2" + srv.AppendResponse( + mock.WithBody([]byte(`{"_rid":"testRID","PartitionKeyRanges":[{"_rid":"r1","id":"1","minInclusive":"80","maxExclusive":"FF","parents":[]}],"_count":1}`)), + mock.WithHeader(cosmosHeaderEtag, "p2"), + mock.WithStatusCode(200), + ) + // Terminator + srv.AppendResponse(mock.WithStatusCode(304), mock.WithHeader(cosmosHeaderEtag, "p2")) + + client := createMockClientForPKRangeCacheNoRetry(srv) + entry := &pkRangeCacheEntry{} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + rm, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, nil) + require.NoError(t, err) + require.NotNil(t, rm) + require.Equal(t, 2, len(rm.orderedRanges), "both pages must be in the final routing map") + require.Equal(t, "0", rm.orderedRanges[0].ID) + require.Equal(t, "1", rm.orderedRanges[1].ID) + require.Equal(t, "p2", rm.changeFeedETag) +} + +func Test_partitionKeyRangeCache_midPagination_exhaustedRetriesFails(t *testing.T) { + // If all retry attempts of a single page fail, the refresh fails as a + // whole (cache stays in its prior state). Verify the configured + // attempt cap is honored. + srv, closeSrv := mock.NewTLSServer() + defer closeSrv() + + // Page 1: success + srv.AppendResponse( + mock.WithBody([]byte(`{"_rid":"testRID","PartitionKeyRanges":[{"_rid":"r0","id":"0","minInclusive":"","maxExclusive":"80","parents":[]}],"_count":1}`)), + mock.WithHeader(cosmosHeaderEtag, "p1"), + mock.WithStatusCode(200), + ) + // Page 2: 503 on every attempt. Enqueue exactly changeFeedPageMaxAttempts 503s. + for i := 0; i < changeFeedPageMaxAttempts; i++ { + srv.AppendResponse( + mock.WithBody([]byte(`{"code":"ServiceUnavailable"}`)), + mock.WithStatusCode(503), + ) + } + + client := createMockClientForPKRangeCacheNoRetry(srv) + entry := &pkRangeCacheEntry{} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + _, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, nil) + require.Error(t, err) + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, http.StatusServiceUnavailable, respErr.StatusCode) +} + +func Test_partitionKeyRangeCache_midPagination_nonTransientFailsFast(t *testing.T) { + // A non-transient error (e.g. 401) on a mid-loop page must NOT be + // retried; it should be surfaced immediately. + srv, closeSrv := mock.NewTLSServer() + defer closeSrv() + + // Page 1: success + srv.AppendResponse( + mock.WithBody([]byte(`{"_rid":"testRID","PartitionKeyRanges":[{"_rid":"r0","id":"0","minInclusive":"","maxExclusive":"80","parents":[]}],"_count":1}`)), + mock.WithHeader(cosmosHeaderEtag, "p1"), + mock.WithStatusCode(200), + ) + // Page 2: 401 — fail fast, no retries + srv.AppendResponse( + mock.WithBody([]byte(`{"code":"Unauthorized"}`)), + mock.WithStatusCode(401), + ) + // Sentinel: if we DID retry, the next response would be a 200, which + // would mask the bug. Make it a 503 instead so an unintended retry + // also fails (test would still fail with the wrong status code). + srv.AppendResponse( + mock.WithBody([]byte(`{"code":"ServiceUnavailable"}`)), + mock.WithStatusCode(503), + ) + + client := createMockClientForPKRangeCacheNoRetry(srv) + entry := &pkRangeCacheEntry{} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + _, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client, nil) + require.Error(t, err) + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, http.StatusUnauthorized, respErr.StatusCode, "non-transient 401 must fail fast without retry") +} From 1aa5d57c237a9661e21336cf9e2e68f1d94df54a Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 20 May 2026 00:48:07 -0700 Subject: [PATCH 2/2] azcosmos: add PR links to CHANGELOG entries Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk/data/azcosmos/CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/data/azcosmos/CHANGELOG.md b/sdk/data/azcosmos/CHANGELOG.md index 49bcfd2eb33d..3881cf7e39b0 100644 --- a/sdk/data/azcosmos/CHANGELOG.md +++ b/sdk/data/azcosmos/CHANGELOG.md @@ -4,8 +4,8 @@ ### Features Added -* Partition key range cache now serves concurrent callers from a single in-flight refresh per container, and the cached routing map remains readable while a refresh is in progress. The refresh runs on a detached background context so one caller's cancellation no longer aborts the shared fetch for other waiters; each caller continues to honor its own context deadline. -* Partition key range cache change-feed pagination is now resilient to transient mid-pagination failures (5xx, 408, 429, network errors). The failing page is retried with linear backoff while preserving the pages already accumulated, instead of restarting the entire drain from page 1 on the next refresh. +* Partition key range cache now serves concurrent callers from a single in-flight refresh per container, and the cached routing map remains readable while a refresh is in progress. The refresh runs on a detached background context so one caller's cancellation no longer aborts the shared fetch for other waiters; each caller continues to honor its own context deadline. See [PR 1](https://github.com/tvaron3/azure-sdk-for-go/pull/1). +* Partition key range cache change-feed pagination is now resilient to transient mid-pagination failures (5xx, 408, 429, network errors). The failing page is retried with linear backoff while preserving the pages already accumulated, instead of restarting the entire drain from page 1 on the next refresh. See [PR 1](https://github.com/tvaron3/azure-sdk-for-go/pull/1). ### Breaking Changes