diff --git a/sdk/data/azcosmos/CHANGELOG.md b/sdk/data/azcosmos/CHANGELOG.md index 3793a39bf71e..724526725eb6 100644 --- a/sdk/data/azcosmos/CHANGELOG.md +++ b/sdk/data/azcosmos/CHANGELOG.md @@ -4,12 +4,16 @@ ### Features Added +* Added client-level partition key range cache and container properties cache, reducing redundant metadata round-trips for ReadMany and query operations. See [PR 26723](https://github.com/Azure/azure-sdk-for-go/pull/26723) * Added operation diagnostics on responses and `DiagnosticsFromError` for retrieving diagnostics from failed operations. See [PR 26548](https://github.com/Azure/azure-sdk-for-go/pull/26548) ### Breaking Changes ### Bugs Fixed +* Fixed V2 partition key routing: the top 2 bits of the first EPK byte are now masked to stay within the partition key range space [0x00, 0x3F]. Previously, items whose V2 hash started with a byte >= 0x40 could fail routing in ReadMany because the EPK lexicographically exceeded the "FF" range sentinel. See [PR 26723](https://github.com/Azure/azure-sdk-for-go/pull/26723) +* Fixed error handling for partition key range calls which would previously cause panics on any error. See [PR 26723](https://github.com/Azure/azure-sdk-for-go/pull/26723) + ### Other Changes ## 1.5.0-beta.5 (2026-03-09) diff --git a/sdk/data/azcosmos/cosmos_client.go b/sdk/data/azcosmos/cosmos_client.go index f06baf88f515..de94f76080a0 100644 --- a/sdk/data/azcosmos/cosmos_client.go +++ b/sdk/data/azcosmos/cosmos_client.go @@ -11,6 +11,7 @@ import ( "net/http" "net/url" "strings" + "sync" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -33,6 +34,24 @@ type Client struct { internal *azcore.Client gem *globalEndpointManager endpointUrl *url.URL + caches *sharedCacheSet + closeOnce sync.Once +} + +// getContainerCache returns the container properties cache for this client. +func (c *Client) getContainerCache() *containerPropertiesCache { + if c.caches == nil { + return nil + } + return c.caches.containerCache +} + +// getPKRangeCache returns the partition key range cache for this client. +func (c *Client) getPKRangeCache() *partitionKeyRangeCache { + if c.caches == nil { + return nil + } + return c.caches.pkRangeCache } // Endpoint used to create the client. @@ -40,6 +59,18 @@ func (c *Client) Endpoint() string { return c.endpoint } +// Close releases the shared cache reference for this client. The underlying +// caches are removed from the global registry once all clients to the same +// account endpoint have been closed. After Close, the client should not be used. +// Close is idempotent; calling it multiple times is safe. +func (c *Client) Close() { + c.closeOnce.Do(func() { + if c.endpoint != "" { + releaseCaches(c.endpoint) + } + }) +} + // NewClientWithKey creates a new instance of Cosmos client with shared key authentication. It uses the default pipeline configuration. // endpoint - The cosmos service endpoint to use. // cred - The credential used to authenticate with the cosmos service. @@ -64,7 +95,7 @@ func NewClientWithKey(endpoint string, cred KeyCredential, o *ClientOptions) (*C if err != nil { return nil, err } - return &Client{endpoint: endpoint, endpointUrl: endpointUrl, internal: internalClient, gem: gem}, nil + return &Client{endpoint: endpoint, endpointUrl: endpointUrl, internal: internalClient, gem: gem, caches: acquireCaches(endpoint)}, nil } // NewClient creates a new instance of Cosmos client with Azure AD access token authentication. It uses the default pipeline configuration. @@ -110,7 +141,7 @@ func NewClient(endpoint string, cred azcore.TokenCredential, o *ClientOptions) ( if err != nil { return nil, err } - return &Client{endpoint: endpoint, endpointUrl: endpointUrl, internal: internalClient, gem: gem}, nil + return &Client{endpoint: endpoint, endpointUrl: endpointUrl, internal: internalClient, gem: gem, caches: acquireCaches(endpoint)}, nil } // NewClientFromConnectionString creates a new instance of Cosmos client from connection string. It uses the default pipeline configuration. diff --git a/sdk/data/azcosmos/cosmos_collection_routing_map.go b/sdk/data/azcosmos/cosmos_collection_routing_map.go new file mode 100644 index 000000000000..834526549c5b --- /dev/null +++ b/sdk/data/azcosmos/cosmos_collection_routing_map.go @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "sort" + + "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos/internal/epk" +) + +// collectionRoutingMap holds an immutable snapshot of partition key ranges for a +// container, sorted for efficient EPK lookups. It supports incremental merging +// when partition splits or merges occur. +type collectionRoutingMap struct { + // orderedRanges are the partition key ranges sorted by MinInclusive ascending. + orderedRanges []partitionKeyRange + // rangeByID provides O(1) lookups of ranges by their ID. + rangeByID map[string]partitionKeyRange + // goneRanges tracks parent range IDs that have been replaced by children after splits. + goneRanges map[string]bool + // changeFeedETag is the ETag for incremental change-feed refreshes. + changeFeedETag string +} + +// newCollectionRoutingMap creates a new collectionRoutingMap from a set of ranges. +// It filters out "gone" parent ranges (identified via the Parents field on child ranges) +// and sorts the remaining ranges by MinInclusive. +func newCollectionRoutingMap(ranges []partitionKeyRange, changeFeedETag string) *collectionRoutingMap { + goneRanges := make(map[string]bool) + for _, r := range ranges { + for _, parent := range r.Parents { + goneRanges[parent] = true + } + } + + // Filter out gone ranges + filtered := make([]partitionKeyRange, 0, len(ranges)) + for _, r := range ranges { + if !goneRanges[r.ID] { + filtered = append(filtered, r) + } + } + + // Sort by MinInclusive using length-aware comparison for HPK boundaries + sort.Slice(filtered, func(i, j int) bool { + return epk.CompareEPK(filtered[i].MinInclusive, filtered[j].MinInclusive) < 0 + }) + + rangeByID := make(map[string]partitionKeyRange, len(filtered)) + for _, r := range filtered { + rangeByID[r.ID] = r + } + + return &collectionRoutingMap{ + orderedRanges: filtered, + rangeByID: rangeByID, + goneRanges: goneRanges, + changeFeedETag: changeFeedETag, + } +} + +// tryCombine merges new ranges (from an incremental change-feed refresh) into +// the existing routing map. Returns a new collectionRoutingMap if the merge +// succeeds (produces a complete covering), or nil if the result is incomplete +// (indicating a full refresh is needed). +func (crm *collectionRoutingMap) tryCombine(newRanges []partitionKeyRange, newETag string) *collectionRoutingMap { + // Accumulate gone ranges from both existing and new ranges + combinedGone := make(map[string]bool, len(crm.goneRanges)) + for id := range crm.goneRanges { + combinedGone[id] = true + } + for _, r := range newRanges { + for _, parent := range r.Parents { + combinedGone[parent] = true + } + } + + // Build a combined set: existing ranges (minus gone) plus new ranges (minus gone) + combinedByID := make(map[string]partitionKeyRange, len(crm.rangeByID)+len(newRanges)) + for id, r := range crm.rangeByID { + if !combinedGone[id] { + combinedByID[id] = r + } + } + for _, r := range newRanges { + if !combinedGone[r.ID] { + combinedByID[r.ID] = r + } + } + + // Build sorted slice + combined := make([]partitionKeyRange, 0, len(combinedByID)) + for _, r := range combinedByID { + combined = append(combined, r) + } + sort.Slice(combined, func(i, j int) bool { + return epk.CompareEPK(combined[i].MinInclusive, combined[j].MinInclusive) < 0 + }) + + // Validate completeness: ranges must form a contiguous covering + if !isCompleteSetOfRanges(combined) { + return nil + } + + return &collectionRoutingMap{ + orderedRanges: combined, + rangeByID: combinedByID, + goneRanges: combinedGone, + changeFeedETag: newETag, + } +} + +// isGone returns true if the given range ID has been replaced (by a split/merge). +func (crm *collectionRoutingMap) isGone(rangeID string) bool { + return crm.goneRanges[rangeID] +} + +// getOverlappingRanges returns all partition key ranges that overlap with the +// given EPK range [minInclusive, maxExclusive). Uses binary search for O(log n) +// lookups. The ranges must be sorted and contiguous (guaranteed by construction). +func (crm *collectionRoutingMap) getOverlappingRanges(minInclusive, maxExclusive string) []partitionKeyRange { + if len(crm.orderedRanges) == 0 { + return nil + } + + // Start: rightmost range whose MinInclusive <= minInclusive. + // Same logic as findPhysicalRangeForEPK. + startIdx := sort.Search(len(crm.orderedRanges), func(i int) bool { + return epk.CompareEPK(crm.orderedRanges[i].MinInclusive, minInclusive) > 0 + }) - 1 + if startIdx < 0 { + startIdx = 0 + } + + // End: first range whose MinInclusive >= maxExclusive. + // All ranges from startIdx up to (but not including) endIdx overlap. + endIdx := startIdx + sort.Search(len(crm.orderedRanges)-startIdx, func(i int) bool { + return epk.CompareEPK(crm.orderedRanges[startIdx+i].MinInclusive, maxExclusive) >= 0 + }) + + if endIdx <= startIdx { + // At minimum, include the range containing minInclusive + endIdx = startIdx + 1 + } + if endIdx > len(crm.orderedRanges) { + endIdx = len(crm.orderedRanges) + } + + result := make([]partitionKeyRange, endIdx-startIdx) + copy(result, crm.orderedRanges[startIdx:endIdx]) + return result +} + +// isCompleteSetOfRanges validates that the sorted ranges form a contiguous +// partition covering with no gaps or overlaps. The first range should start +// at "" and each subsequent range should start where the previous one ends. +func isCompleteSetOfRanges(ranges []partitionKeyRange) bool { + if len(ranges) == 0 { + return false + } + + // First range must start at "" + if ranges[0].MinInclusive != "" { + return false + } + + // Each range's MinInclusive must equal the previous range's MaxExclusive. + // Use CompareEPK for length-aware comparison — HPK containers can return + // mixed-length boundaries that are semantically equal (zero-padded). + for i := 1; i < len(ranges); i++ { + if epk.CompareEPK(ranges[i].MinInclusive, ranges[i-1].MaxExclusive) != 0 { + return false + } + } + + // Last range must end at "FF" (the maximum EPK boundary) or be unbounded ("") + lastMax := ranges[len(ranges)-1].MaxExclusive + if lastMax != "FF" && lastMax != "" { + return false + } + + return true +} diff --git a/sdk/data/azcosmos/cosmos_collection_routing_map_test.go b/sdk/data/azcosmos/cosmos_collection_routing_map_test.go new file mode 100644 index 000000000000..42aee7640a81 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_collection_routing_map_test.go @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_newCollectionRoutingMap_basic(t *testing.T) { + ranges := []partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "05C1CFFFFFFFF8"}, + {ID: "1", MinInclusive: "05C1CFFFFFFFF8", MaxExclusive: "FF"}, + } + + rm := newCollectionRoutingMap(ranges, "etag1") + require.NotNil(t, rm) + require.Equal(t, 2, len(rm.orderedRanges)) + require.Equal(t, "0", rm.orderedRanges[0].ID) + require.Equal(t, "1", rm.orderedRanges[1].ID) + require.Equal(t, "etag1", rm.changeFeedETag) + require.False(t, rm.isGone("0")) + require.False(t, rm.isGone("1")) +} + +func Test_newCollectionRoutingMap_sortsRanges(t *testing.T) { + // Provide ranges in reverse order + ranges := []partitionKeyRange{ + {ID: "1", MinInclusive: "05C1CFFFFFFFF8", MaxExclusive: "FF"}, + {ID: "0", MinInclusive: "", MaxExclusive: "05C1CFFFFFFFF8"}, + } + + rm := newCollectionRoutingMap(ranges, "") + require.Equal(t, "0", rm.orderedRanges[0].ID) + require.Equal(t, "1", rm.orderedRanges[1].ID) +} + +func Test_newCollectionRoutingMap_filtersGoneParents(t *testing.T) { + // Simulate a split: range "0" split into "2" and "3" + ranges := []partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "05C1CFFFFFFFF8"}, + {ID: "1", MinInclusive: "05C1CFFFFFFFF8", MaxExclusive: "FF"}, + {ID: "2", MinInclusive: "", MaxExclusive: "02E0", Parents: []string{"0"}}, + {ID: "3", MinInclusive: "02E0", MaxExclusive: "05C1CFFFFFFFF8", Parents: []string{"0"}}, + } + + rm := newCollectionRoutingMap(ranges, "etag2") + require.Equal(t, 3, len(rm.orderedRanges)) + require.True(t, rm.isGone("0")) + require.False(t, rm.isGone("1")) + require.False(t, rm.isGone("2")) + + // Verify order + require.Equal(t, "2", rm.orderedRanges[0].ID) + require.Equal(t, "3", rm.orderedRanges[1].ID) + require.Equal(t, "1", rm.orderedRanges[2].ID) +} + +func Test_newCollectionRoutingMap_rangeByID(t *testing.T) { + ranges := []partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "05C1CFFFFFFFF8"}, + {ID: "1", MinInclusive: "05C1CFFFFFFFF8", MaxExclusive: "FF"}, + } + + rm := newCollectionRoutingMap(ranges, "") + r, ok := rm.rangeByID["0"] + require.True(t, ok) + require.Equal(t, "", r.MinInclusive) + require.Equal(t, "05C1CFFFFFFFF8", r.MaxExclusive) +} + +func Test_tryCombine_successfulSplit(t *testing.T) { + // Initial state: two ranges + initial := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "05C1CFFFFFFFF8"}, + {ID: "1", MinInclusive: "05C1CFFFFFFFF8", MaxExclusive: "FF"}, + }, "etag1") + + // Split: range "0" splits into "2" and "3" + newRanges := []partitionKeyRange{ + {ID: "2", MinInclusive: "", MaxExclusive: "02E0", Parents: []string{"0"}}, + {ID: "3", MinInclusive: "02E0", MaxExclusive: "05C1CFFFFFFFF8", Parents: []string{"0"}}, + } + + merged := initial.tryCombine(newRanges, "etag2") + require.NotNil(t, merged) + require.Equal(t, 3, len(merged.orderedRanges)) + require.Equal(t, "etag2", merged.changeFeedETag) + require.True(t, merged.isGone("0")) + + // Verify ranges are sorted correctly + require.Equal(t, "2", merged.orderedRanges[0].ID) + require.Equal(t, "3", merged.orderedRanges[1].ID) + require.Equal(t, "1", merged.orderedRanges[2].ID) +} + +func Test_tryCombine_incompleteCovering(t *testing.T) { + initial := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "05C1CFFFFFFFF8"}, + {ID: "1", MinInclusive: "05C1CFFFFFFFF8", MaxExclusive: "FF"}, + }, "etag1") + + // Provide only one child — covering is incomplete + newRanges := []partitionKeyRange{ + {ID: "2", MinInclusive: "", MaxExclusive: "02E0", Parents: []string{"0"}}, + } + + merged := initial.tryCombine(newRanges, "etag2") + require.Nil(t, merged, "tryCombine should return nil for incomplete covering") +} + +func Test_isCompleteSetOfRanges_valid(t *testing.T) { + ranges := []partitionKeyRange{ + {MinInclusive: "", MaxExclusive: "05C1CFFFFFFFF8"}, + {MinInclusive: "05C1CFFFFFFFF8", MaxExclusive: "FF"}, + } + require.True(t, isCompleteSetOfRanges(ranges)) +} + +func Test_isCompleteSetOfRanges_empty(t *testing.T) { + require.False(t, isCompleteSetOfRanges(nil)) + require.False(t, isCompleteSetOfRanges([]partitionKeyRange{})) +} + +func Test_isCompleteSetOfRanges_doesNotStartAtEmpty(t *testing.T) { + ranges := []partitionKeyRange{ + {MinInclusive: "05C1CFFFFFFFF8", MaxExclusive: "FF"}, + } + require.False(t, isCompleteSetOfRanges(ranges)) +} + +func Test_isCompleteSetOfRanges_gap(t *testing.T) { + ranges := []partitionKeyRange{ + {MinInclusive: "", MaxExclusive: "03"}, + {MinInclusive: "05", MaxExclusive: "FF"}, // gap between 03 and 05 + } + require.False(t, isCompleteSetOfRanges(ranges)) +} + +func Test_isCompleteSetOfRanges_doesNotEndAtFF(t *testing.T) { + ranges := []partitionKeyRange{ + {MinInclusive: "", MaxExclusive: "05C1CFFFFFFFF8"}, + } + require.False(t, isCompleteSetOfRanges(ranges)) +} + +func Test_isCompleteSetOfRanges_singleRange(t *testing.T) { + ranges := []partitionKeyRange{ + {MinInclusive: "", MaxExclusive: "FF"}, + } + require.True(t, isCompleteSetOfRanges(ranges)) +} + +func Test_isCompleteSetOfRanges_emptyMaxExclusive(t *testing.T) { + // Some implementations use "" as unbounded end + ranges := []partitionKeyRange{ + {MinInclusive: "", MaxExclusive: ""}, + } + require.True(t, isCompleteSetOfRanges(ranges)) +} + +func Test_getOverlappingRanges_singleRange(t *testing.T) { + rm := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "FF"}, + }, "") + + result := rm.getOverlappingRanges("0000", "3FFF") + require.Len(t, result, 1) + require.Equal(t, "0", result[0].ID) +} + +func Test_getOverlappingRanges_multipleRanges(t *testing.T) { + rm := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "10"}, + {ID: "1", MinInclusive: "10", MaxExclusive: "20"}, + {ID: "2", MinInclusive: "20", MaxExclusive: "30"}, + {ID: "3", MinInclusive: "30", MaxExclusive: "FF"}, + }, "") + + // Range spanning partitions 1 and 2 + result := rm.getOverlappingRanges("10", "30") + require.Len(t, result, 2) + require.Equal(t, "1", result[0].ID) + require.Equal(t, "2", result[1].ID) +} + +func Test_getOverlappingRanges_allRanges(t *testing.T) { + rm := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "10"}, + {ID: "1", MinInclusive: "10", MaxExclusive: "20"}, + {ID: "2", MinInclusive: "20", MaxExclusive: "FF"}, + }, "") + + result := rm.getOverlappingRanges("", "FF") + require.Len(t, result, 3) +} + +func Test_getOverlappingRanges_pointInMiddle(t *testing.T) { + rm := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "10"}, + {ID: "1", MinInclusive: "10", MaxExclusive: "20"}, + {ID: "2", MinInclusive: "20", MaxExclusive: "FF"}, + }, "") + + // EPK range that starts and ends within range "1" + result := rm.getOverlappingRanges("15", "18") + require.Len(t, result, 1) + require.Equal(t, "1", result[0].ID) +} + +func Test_getOverlappingRanges_mixedLengthBoundaries(t *testing.T) { + // Simulate HPK container with mixed-length EPK boundaries + partial := "06AB34CFE4E482236BCACBBF50E234AB" + fullZero := partial + "00000000000000000000000000000000" + + rm := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: partial}, + {ID: "1", MinInclusive: fullZero, MaxExclusive: "FF"}, + }, "") + + // A query range spanning both should find both + result := rm.getOverlappingRanges("", "FF") + require.Len(t, result, 2) +} diff --git a/sdk/data/azcosmos/cosmos_container.go b/sdk/data/azcosmos/cosmos_container.go index d6958815bd9a..0627e73e8855 100644 --- a/sdk/data/azcosmos/cosmos_container.go +++ b/sdk/data/azcosmos/cosmos_container.go @@ -65,6 +65,22 @@ func (c *ContainerClient) Read( } ctx, endSpan := startSpan(ctx, spanName.name, c.database.client.internal.Tracer(), &spanName.options) defer func() { endSpan(err) }() + + response, err := c.readContainerRaw(ctx, o) + if err == nil && c.database.client.getContainerCache() != nil && response.ContainerProperties != nil { + // Populate the container properties cache on successful Read + c.database.client.getContainerCache().set(c.link, response.ContainerProperties) + } + return response, err +} + +// readContainerRaw performs the HTTP call to read container properties. +// It is the shared implementation used by both Read() and the container +// properties cache refresh, ensuring consistent request construction. +func (c *ContainerClient) readContainerRaw( + ctx context.Context, + o *ReadContainerOptions, +) (ContainerResponse, error) { if o == nil { o = &ReadContainerOptions{} } @@ -89,8 +105,7 @@ func (c *ContainerClient) Read( return ContainerResponse{}, err } - response, err := newContainerResponse(azResponse) - return response, err + return newContainerResponse(azResponse) } // Replace a Cosmos container. @@ -134,6 +149,9 @@ func (c *ContainerClient) Replace( } response, err := newContainerResponse(azResponse) + if err == nil && c.database.client.getContainerCache() != nil && response.ContainerProperties != nil { + c.database.client.getContainerCache().set(c.link, response.ContainerProperties) + } return response, err } @@ -880,6 +898,19 @@ func (c *ContainerClient) getRID(ctx context.Context) (string, error) { return containerResponse.ContainerProperties.ResourceID, nil } +// getContainerRID resolves the container's ResourceID, using the container +// properties cache if available, otherwise falling back to a direct Read. +func (c *ContainerClient) getContainerRID(ctx context.Context) (string, error) { + if c.database.client.getContainerCache() != nil { + props, err := c.database.client.getContainerCache().getProperties(ctx, c) + if err != nil { + return "", err + } + return props.ResourceID, nil + } + return c.getRID(ctx) +} + func (c *ContainerClient) getSpanForContainer(operationType operationType, resourceType resourceType, id string) (span, error) { return getSpanNameForContainers(c.database.client.accountEndpointUrl(), operationType, resourceType, c.database.id, id) } @@ -889,6 +920,7 @@ func (c *ContainerClient) getSpanForItems(operationType operationType) (span, er } func (c *ContainerClient) getPartitionKeyRanges(ctx context.Context, o *partitionKeyRangeOptions) (partitionKeyRangeResponse, error) { + var err error spanName, err := c.getSpanForContainer(operationTypeRead, resourceTypePartitionKeyRange, c.id) if err != nil { return partitionKeyRangeResponse{}, err @@ -896,6 +928,33 @@ func (c *ContainerClient) getPartitionKeyRanges(ctx context.Context, o *partitio ctx, endSpan := startSpan(ctx, spanName.name, c.database.client.internal.Tracer(), &spanName.options) defer func() { endSpan(err) }() + // Use the cache if available, otherwise fall back to direct fetch + if c.database.client.getPKRangeCache() != nil { + var containerRID string + containerRID, err = c.getContainerRID(ctx) + if err != nil { + return partitionKeyRangeResponse{}, err + } + + var routingMap *collectionRoutingMap + routingMap, err = c.database.client.getPKRangeCache().getRoutingMap(ctx, containerRID, c.link, c.database.client) + if err != nil { + return partitionKeyRangeResponse{}, err + } + + return partitionKeyRangeResponse{ + PartitionKeyRanges: routingMap.orderedRanges, + Count: len(routingMap.orderedRanges), + }, nil + } + + // Fallback: direct fetch without caching + return c.fetchPartitionKeyRangesDirect(ctx, o) +} + +// fetchPartitionKeyRangesDirect fetches partition key ranges directly from the service +// without using the cache. +func (c *ContainerClient) fetchPartitionKeyRangesDirect(ctx context.Context, o *partitionKeyRangeOptions) (partitionKeyRangeResponse, error) { operationContext := pipelineRequestOptions{ resourceType: resourceTypePartitionKeyRange, resourceAddress: c.link, @@ -916,6 +975,9 @@ func (c *ContainerClient) getPartitionKeyRanges(ctx context.Context, o *partitio operationContext, o, nil) + if err != nil { + return partitionKeyRangeResponse{}, err + } response, err := newPartitionKeyRangeResponse(azResponse) if err != nil { diff --git a/sdk/data/azcosmos/cosmos_container_properties_cache.go b/sdk/data/azcosmos/cosmos_container_properties_cache.go new file mode 100644 index 000000000000..0d3e17df8aa4 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_container_properties_cache.go @@ -0,0 +1,199 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "context" + "errors" + "sync" +) + +// containerPropertiesCache provides a client-level cache of container properties +// (specifically PartitionKeyDefinition). It maintains a dual index — by container +// link (name-based path) and by ResourceID (RID) — so that lookups succeed +// regardless of which identifier the caller has. When a reference is fetched or +// inserted, both indices are cross-populated. +type containerPropertiesCache struct { + mu sync.RWMutex + entries map[string]*containerPropsCacheEntry // keyed by container link + entriesByID map[string]*containerPropsCacheEntry // keyed by ResourceID +} + +type containerPropsCacheEntry struct { + mu sync.Mutex // single-flights refresh for this container + props *ContainerProperties +} + +func newContainerPropertiesCache() *containerPropertiesCache { + return &containerPropertiesCache{ + entries: make(map[string]*containerPropsCacheEntry), + entriesByID: make(map[string]*containerPropsCacheEntry), + } +} + +// getProperties returns the cached container properties for the given container. +// If the cache is empty, it fetches from the service using the provided ContainerClient. +func (c *containerPropertiesCache) getProperties( + ctx context.Context, + container *ContainerClient, +) (*ContainerProperties, error) { + containerLink := container.link + + // Fast path: read lock check + c.mu.RLock() + entry, exists := c.entries[containerLink] + c.mu.RUnlock() + + if exists { + entry.mu.Lock() + if entry.props != nil { + props := entry.props + entry.mu.Unlock() + return props, nil + } + // Entry exists but props is nil (invalidated) — refresh under lock + props, err := c.refreshEntry(ctx, container, entry) + entry.mu.Unlock() + if err == nil { + c.updateRIDIndex(entry, props) + } + return props, err + } + + // Slow path: create entry + c.mu.Lock() + entry, exists = c.entries[containerLink] + if !exists { + entry = &containerPropsCacheEntry{} + c.entries[containerLink] = entry + } + c.mu.Unlock() + + entry.mu.Lock() + if entry.props != nil { + props := entry.props + entry.mu.Unlock() + return props, nil + } + props, err := c.refreshEntry(ctx, container, entry) + entry.mu.Unlock() + if err == nil { + c.updateRIDIndex(entry, props) + } + return props, err +} + +// getPropertiesByRID looks up cached container properties by ResourceID. +// Returns nil if the RID is not in the cache. +func (c *containerPropertiesCache) getPropertiesByRID(resourceID string) *ContainerProperties { + if resourceID == "" { + return nil + } + c.mu.RLock() + entry, exists := c.entriesByID[resourceID] + c.mu.RUnlock() + + if !exists { + return nil + } + + entry.mu.Lock() + props := entry.props + entry.mu.Unlock() + return props +} + +// invalidate removes the cached properties for the given container link, +// forcing the next access to fetch fresh data. Also removes the RID index entry. +// This is called during 410/Gone retry paths to handle scenarios where a container +// has been deleted and recreated with the same name but a new ResourceID. +func (c *containerPropertiesCache) invalidate(containerLink string) { + c.mu.RLock() + entry, exists := c.entries[containerLink] + c.mu.RUnlock() + + if exists { + entry.mu.Lock() + props := entry.props + entry.props = nil + entry.mu.Unlock() + + // Remove the RID index entry if we had cached props + if props != nil && props.ResourceID != "" { + c.mu.Lock() + delete(c.entriesByID, props.ResourceID) + c.mu.Unlock() + } + } +} + +// set directly populates the cache with the given container properties. +// This is used when a Read() or Replace() call already fetched the properties. +// Cross-populates both the link-based and RID-based indices. +func (c *containerPropertiesCache) set(containerLink string, props *ContainerProperties) { + c.mu.RLock() + entry, exists := c.entries[containerLink] + c.mu.RUnlock() + + if exists { + entry.mu.Lock() + entry.props = props + entry.mu.Unlock() + } else { + // Slow path: upgrade to write lock and double-check. + // We release c.mu before touching entry.mu to maintain + // consistent lock order (c.mu → entry.mu, never reversed). + c.mu.Lock() + entry, exists = c.entries[containerLink] + if !exists { + entry = &containerPropsCacheEntry{props: props} + c.entries[containerLink] = entry + } + c.mu.Unlock() + + if exists { + entry.mu.Lock() + entry.props = props + entry.mu.Unlock() + } + } + + c.updateRIDIndex(entry, props) +} + +// updateRIDIndex cross-populates the RID-based index. +// Must be called WITHOUT entry.mu held to maintain lock order (c.mu → entry.mu). +func (c *containerPropertiesCache) updateRIDIndex(entry *containerPropsCacheEntry, props *ContainerProperties) { + if props != nil && props.ResourceID != "" { + c.mu.Lock() + c.entriesByID[props.ResourceID] = entry + c.mu.Unlock() + } +} + +// refreshEntry fetches container properties directly from the service. +// This bypasses container.Read() to avoid deadlock — the caller already holds entry.mu, +// and Read() calls cache.set() which would try to re-acquire the same lock. +// It uses readContainerRaw() to share the HTTP call logic with Read(). +// Caller must hold entry.mu. +// NOTE: This method must NOT acquire c.mu — callers update the RID index +// after releasing entry.mu via updateRIDIndex() to prevent lock-order inversion. +func (c *containerPropertiesCache) refreshEntry( + ctx context.Context, + container *ContainerClient, + entry *containerPropsCacheEntry, +) (*ContainerProperties, error) { + response, err := container.readContainerRaw(ctx, nil) + if err != nil { + return nil, err + } + + if response.ContainerProperties == nil { + return nil, errors.New("container properties response contained no properties") + } + + entry.props = response.ContainerProperties + + return entry.props, nil +} diff --git a/sdk/data/azcosmos/cosmos_container_properties_cache_test.go b/sdk/data/azcosmos/cosmos_container_properties_cache_test.go new file mode 100644 index 000000000000..0a5bcc96ff89 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_container_properties_cache_test.go @@ -0,0 +1,176 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_containerPropertiesCache_newCache(t *testing.T) { + cache := newContainerPropertiesCache() + require.NotNil(t, cache) + require.Empty(t, cache.entries) +} + +func Test_containerPropertiesCache_invalidate_noEntry(t *testing.T) { + cache := newContainerPropertiesCache() + + // Should not panic + cache.invalidate("dbs/db1/colls/col1") + + cache.mu.RLock() + _, exists := cache.entries["dbs/db1/colls/col1"] + cache.mu.RUnlock() + require.False(t, exists) +} + +func Test_containerPropertiesCache_invalidate_existingEntry(t *testing.T) { + cache := newContainerPropertiesCache() + + entry := &containerPropsCacheEntry{ + props: &ContainerProperties{ + ID: "col1", + PartitionKeyDefinition: PartitionKeyDefinition{ + Paths: []string{"/pk"}, + Version: 2, + }, + }, + } + cache.mu.Lock() + cache.entries["dbs/db1/colls/col1"] = entry + cache.mu.Unlock() + + // Verify populated + entry.mu.Lock() + require.NotNil(t, entry.props) + entry.mu.Unlock() + + // Invalidate + cache.invalidate("dbs/db1/colls/col1") + + // Verify nil + entry.mu.Lock() + require.Nil(t, entry.props) + entry.mu.Unlock() +} + +func Test_containerPropertiesCache_getProperties_cacheHit(t *testing.T) { + cache := newContainerPropertiesCache() + + expectedProps := &ContainerProperties{ + ID: "col1", + PartitionKeyDefinition: PartitionKeyDefinition{ + Paths: []string{"/pk"}, + Version: 2, + }, + } + entry := &containerPropsCacheEntry{props: expectedProps} + cache.mu.Lock() + cache.entries["dbs/db1/colls/col1"] = entry + cache.mu.Unlock() + + // Create a minimal ContainerClient with link matching the cache key + container := &ContainerClient{ + link: "dbs/db1/colls/col1", + } + + props, err := cache.getProperties(nil, container) //nolint:staticcheck // nil context is fine for cache hit + require.NoError(t, err) + require.Equal(t, expectedProps, props) +} + +func Test_containerPropertiesCache_set_crossPopulatesRID(t *testing.T) { + cache := newContainerPropertiesCache() + + props := &ContainerProperties{ + ID: "col1", + ResourceID: "rid123", + PartitionKeyDefinition: PartitionKeyDefinition{ + Paths: []string{"/pk"}, + Version: 2, + }, + } + + cache.set("dbs/db1/colls/col1", props) + + // Should be retrievable by link + cache.mu.RLock() + entry, exists := cache.entries["dbs/db1/colls/col1"] + cache.mu.RUnlock() + require.True(t, exists) + entry.mu.Lock() + require.Equal(t, props, entry.props) + entry.mu.Unlock() + + // Should also be retrievable by RID + result := cache.getPropertiesByRID("rid123") + require.NotNil(t, result) + require.Equal(t, "col1", result.ID) +} + +func Test_containerPropertiesCache_getPropertiesByRID_miss(t *testing.T) { + cache := newContainerPropertiesCache() + + result := cache.getPropertiesByRID("nonexistent") + require.Nil(t, result) +} + +func Test_containerPropertiesCache_getPropertiesByRID_emptyRID(t *testing.T) { + cache := newContainerPropertiesCache() + + result := cache.getPropertiesByRID("") + require.Nil(t, result) +} + +func Test_containerPropertiesCache_invalidate_removesRIDIndex(t *testing.T) { + cache := newContainerPropertiesCache() + + props := &ContainerProperties{ + ID: "col1", + ResourceID: "rid456", + PartitionKeyDefinition: PartitionKeyDefinition{ + Paths: []string{"/pk"}, + Version: 2, + }, + } + + cache.set("dbs/db1/colls/col1", props) + + // Verify RID lookup works + require.NotNil(t, cache.getPropertiesByRID("rid456")) + + // Invalidate by link + cache.invalidate("dbs/db1/colls/col1") + + // RID index should be removed + require.Nil(t, cache.getPropertiesByRID("rid456")) +} + +func Test_containerPropertiesCache_set_multipleContainers(t *testing.T) { + cache := newContainerPropertiesCache() + + props1 := &ContainerProperties{ + ID: "col1", + ResourceID: "rid1", + PartitionKeyDefinition: PartitionKeyDefinition{Paths: []string{"/pk1"}}, + } + props2 := &ContainerProperties{ + ID: "col2", + ResourceID: "rid2", + PartitionKeyDefinition: PartitionKeyDefinition{Paths: []string{"/pk2"}}, + } + + cache.set("dbs/db1/colls/col1", props1) + cache.set("dbs/db1/colls/col2", props2) + + // Both should be retrievable by RID + r1 := cache.getPropertiesByRID("rid1") + r2 := cache.getPropertiesByRID("rid2") + require.NotNil(t, r1) + require.NotNil(t, r2) + require.Equal(t, "col1", r1.ID) + require.Equal(t, "col2", r2.ID) +} diff --git a/sdk/data/azcosmos/cosmos_container_read_many.go b/sdk/data/azcosmos/cosmos_container_read_many.go index 92d60f875450..84421b8f0027 100644 --- a/sdk/data/azcosmos/cosmos_container_read_many.go +++ b/sdk/data/azcosmos/cosmos_container_read_many.go @@ -6,13 +6,17 @@ package azcosmos import ( "context" "errors" + "net/http" "sort" "sync" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos/internal/epk" ) const maxItemsPerQuery = 1000 +const maxPKRangeGoneRetries = 3 // queryChunk is a single parameterized query targeting one physical partition key range. type queryChunk struct { @@ -33,15 +37,17 @@ type chunkResult struct { // range ID and true if found, or ("", false) otherwise. func findPhysicalRangeForEPK(epkValue string, ranges []partitionKeyRange) (string, bool) { // Binary search: find the last range whose MinInclusive <= epkValue. + // Uses length-aware comparison for HPK containers with mixed-length EPK boundaries. idx := sort.Search(len(ranges), func(i int) bool { - return ranges[i].MinInclusive > epkValue + return epk.CompareEPK(ranges[i].MinInclusive, epkValue) > 0 }) - 1 if idx < 0 { return "", false } - // Verify epkValue < MaxExclusive (empty MaxExclusive means unbounded). + // Verify epkValue < MaxExclusive. + // Empty MaxExclusive or "FF" means unbounded (the last partition). r := ranges[idx] - if r.MaxExclusive != "" && epkValue >= r.MaxExclusive { + if r.MaxExclusive != "" && r.MaxExclusive != "FF" && epk.CompareEPK(epkValue, r.MaxExclusive) >= 0 { return "", false } return r.ID, true @@ -50,35 +56,49 @@ func findPhysicalRangeForEPK(epkValue string, ranges []partitionKeyRange) (strin // groupItemsByPhysicalRange computes the EPK for each item and groups them by // physical partition range. It returns the range IDs in first-seen order and // the groups keyed by range ID. +// +// For MultiHash containers, items with partial partition keys (fewer components +// than paths) are fanned out to all overlapping physical partitions via EPK +// range computation. func groupItemsByPhysicalRange(items []ItemIdentity, pkDef PartitionKeyDefinition, ranges []partitionKeyRange) ([]string, map[string][]ItemIdentity, error) { - // Sort ranges by MinInclusive for binary search. - sorted := make([]partitionKeyRange, len(ranges)) - copy(sorted, ranges) - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].MinInclusive < sorted[j].MinInclusive - }) + // Build a routing map for efficient lookups. + routingMap := newCollectionRoutingMap(ranges, "") order := make([]string, 0) seen := make(map[string]bool) groups := make(map[string][]ItemIdentity) - pkVersion := pkDef.Version - if pkVersion == 0 { - pkVersion = 1 - } - for _, item := range items { - epkVal := item.PartitionKey.computeEffectivePartitionKey(pkDef.Kind, pkVersion) - rangeID, ok := findPhysicalRangeForEPK(epkVal.EPK, sorted) - if !ok { - return nil, nil, errors.New("could not find physical partition range for item EPK") + epkR, err := computeEPKRange(&item.PartitionKey, pkDef) + if err != nil { + return nil, nil, err } - if !seen[rangeID] { - order = append(order, rangeID) - seen[rangeID] = true + if epkR.isRange() { + // Prefix key: fan out to all overlapping ranges + overlapping := routingMap.getOverlappingRanges(epkR.Min, epkR.Max) + if len(overlapping) == 0 { + return nil, nil, errors.New("could not find physical partition range for item EPK range") + } + for _, r := range overlapping { + if !seen[r.ID] { + order = append(order, r.ID) + seen[r.ID] = true + } + groups[r.ID] = append(groups[r.ID], item) + } + } else { + // Point key: direct lookup + rangeID, ok := findPhysicalRangeForEPK(epkR.Min, routingMap.orderedRanges) + if !ok { + return nil, nil, errors.New("could not find physical partition range for item EPK") + } + if !seen[rangeID] { + order = append(order, rangeID) + seen[rangeID] = true + } + groups[rangeID] = append(groups[rangeID], item) } - groups[rangeID] = append(groups[rangeID], item) } return order, groups, nil @@ -246,6 +266,19 @@ func collectChunkResults(results []chunkResult) (ReadManyItemsResponse, error) { return ReadManyItemsResponse{RequestCharge: totalCharge, Items: allItems}, nil } +// hasAnyPKRangeGoneError scans all chunk results for a partition key range gone error. +// This is needed because concurrent chunk cancellation can cause a context.Canceled +// error to appear at a lower index than the actual 410/Gone error, masking it from +// collectChunkResults which returns the first error it encounters. +func hasAnyPKRangeGoneError(results []chunkResult) bool { + for _, res := range results { + if res.err != nil && isPKRangeGoneResponseError(res.err) { + return true + } + } + return false +} + // executeReadManyWithQueries groups items by physical partition range using EPK // hashing, builds parameterized SQL queries (one per physical range, chunked at // maxItemsPerQuery), and executes them concurrently. This replaces the previous @@ -257,27 +290,6 @@ func (c *ContainerClient) executeReadManyWithQueries( readManyOptions *ReadManyOptions, operationContext pipelineRequestOptions, ) (ReadManyItemsResponse, error) { - containerResp, err := c.Read(ctx, nil) - if err != nil { - return ReadManyItemsResponse{}, err - } - pkDef := containerResp.ContainerProperties.PartitionKeyDefinition - - pkRangeResp, err := c.getPartitionKeyRanges(ctx, nil) - if err != nil { - return ReadManyItemsResponse{}, err - } - - orderedRangeIDs, groups, err := groupItemsByPhysicalRange(items, pkDef, pkRangeResp.PartitionKeyRanges) - if err != nil { - return ReadManyItemsResponse{}, err - } - - chunks, err := buildQueryChunksForRanges(orderedRangeIDs, groups, pkDef) - if err != nil { - return ReadManyItemsResponse{}, err - } - concurrency := determineConcurrency(nil) if readManyOptions != nil { concurrency = determineConcurrency(readManyOptions.MaxConcurrency) @@ -290,10 +302,110 @@ func (c *ContainerClient) executeReadManyWithQueries( queryOpts.DedicatedGatewayRequestOptions = readManyOptions.DedicatedGatewayRequestOptions } - results, err := c.executeQueryChunks(ctx, chunks, queryOpts, operationContext, concurrency) - if err != nil { - // An error here means we failed to even start executing the queries, so we return it directly. Errors from individual chunks are handled in collectChunkResults. - return ReadManyItemsResponse{}, err + // Retry loop for partition key range gone (splits/merges). + // pkDef is resolved inside the loop so that after a 410-triggered cache + // invalidation (e.g., container deleted and recreated with a different + // partition key definition), we pick up the new schema — not just the new ranges. + for attempt := 0; attempt <= maxPKRangeGoneRetries; attempt++ { + var pkDef PartitionKeyDefinition + if c.database.client.getContainerCache() != nil { + containerProps, err := c.database.client.getContainerCache().getProperties(ctx, c) + if err != nil { + return ReadManyItemsResponse{}, err + } + pkDef = containerProps.PartitionKeyDefinition + } else { + // Fallback: direct fetch without caching + containerResp, err := c.Read(ctx, nil) + if err != nil { + return ReadManyItemsResponse{}, err + } + pkDef = containerResp.ContainerProperties.PartitionKeyDefinition + } + + pkRangeResp, err := c.getPartitionKeyRanges(ctx, nil) + if err != nil { + return ReadManyItemsResponse{}, err + } + + orderedRangeIDs, groups, err := groupItemsByPhysicalRange(items, pkDef, pkRangeResp.PartitionKeyRanges) + if err != nil { + return ReadManyItemsResponse{}, err + } + + chunks, err := buildQueryChunksForRanges(orderedRangeIDs, groups, pkDef) + if err != nil { + return ReadManyItemsResponse{}, err + } + + results, err := c.executeQueryChunks(ctx, chunks, queryOpts, operationContext, concurrency) + if err != nil { + if attempt < maxPKRangeGoneRetries && isPKRangeGoneResponseError(err) { + if refreshErr := c.refreshPKRangeCache(ctx); refreshErr != nil { + return ReadManyItemsResponse{}, refreshErr + } + continue + } + return ReadManyItemsResponse{}, err + } + + resp, err := collectChunkResults(results) + if err != nil { + // Check all results for 410/Gone, not just the first error returned by + // collectChunkResults. Concurrent chunk cancellation can cause a + // context.Canceled error at a lower index to mask the actual 410 error. + if attempt < maxPKRangeGoneRetries && (isPKRangeGoneResponseError(err) || hasAnyPKRangeGoneError(results)) { + if refreshErr := c.refreshPKRangeCache(ctx); refreshErr != nil { + return ReadManyItemsResponse{}, refreshErr + } + continue + } + return ReadManyItemsResponse{}, err + } + return resp, nil + } + + return ReadManyItemsResponse{}, errors.New("exhausted retries for partition key range gone") +} + +// refreshPKRangeCache forces a refresh of the partition key range cache for this container. +// It also invalidates the container properties cache so that getContainerRID fetches the +// current RID, which is necessary when a container is deleted and recreated with the same name. +// Returns an error if the refresh fails, allowing the caller to fail fast. +func (c *ContainerClient) refreshPKRangeCache(ctx context.Context) error { + if c.database.client.getContainerCache() != nil { + c.database.client.getContainerCache().invalidate(c.link) + } + if c.database.client.getPKRangeCache() != nil { + containerRID, err := c.getContainerRID(ctx) + if err != nil { + return err + } + _, err = c.database.client.getPKRangeCache().forceRefresh(ctx, containerRID, c.link, c.database.client) + if err != nil { + return err + } + } + return nil +} + +// isPKRangeGoneResponseError checks if an error is an azcore.ResponseError +// indicating a partition key range gone condition (HTTP 410 with split-related substatus). +func isPKRangeGoneResponseError(err error) bool { + var respErr *azcore.ResponseError + if !errors.As(err, &respErr) { + return false + } + if respErr.StatusCode != http.StatusGone { + return false + } + // Extract substatus from the raw response if available + if respErr.RawResponse != nil { + subStatus := respErr.RawResponse.Header.Get(cosmosHeaderSubstatus) + return isPartitionKeyRangeGoneError(respErr.StatusCode, subStatus) } - return collectChunkResults(results) + // If no raw response (unusual — typically means the HTTP response couldn't be read), + // conservatively treat any 410 Gone as a PKRange gone. The worst case is + // unnecessary cache refresh + retry (3x), which is preferable to failing permanently. + return true } diff --git a/sdk/data/azcosmos/cosmos_container_read_many_test.go b/sdk/data/azcosmos/cosmos_container_read_many_test.go index 34339a572cc3..1049cdb3cbf3 100644 --- a/sdk/data/azcosmos/cosmos_container_read_many_test.go +++ b/sdk/data/azcosmos/cosmos_container_read_many_test.go @@ -5,8 +5,10 @@ package azcosmos import ( "context" + "io" "net/http" "net/url" + "strings" "sync/atomic" "testing" @@ -165,6 +167,31 @@ func TestGroupItemsByPhysicalRange_DefaultVersion(t *testing.T) { require.Len(t, groups["0"], 1) } +func TestFindPhysicalRangeForEPK_FFSentinel(t *testing.T) { + // "FF" is the sentinel max boundary for the last partition. + // EPK values like "FFF697..." are longer strings that lexicographically + // exceed "FF" but must still match the last partition range. + ranges := []partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "3C3C3C3C"}, + {ID: "1", MinInclusive: "3C3C3C3C", MaxExclusive: "FF"}, + } + + // EPK that starts with "FF..." — lexicographically > "FF" but should match range 1 + rangeID, ok := findPhysicalRangeForEPK("FFF697AF545B8770396E3626B83A20AE", ranges) + require.True(t, ok) + require.Equal(t, "1", rangeID) + + // EPK at very start + rangeID, ok = findPhysicalRangeForEPK("0000000000000000", ranges) + require.True(t, ok) + require.Equal(t, "0", rangeID) + + // EPK at boundary + rangeID, ok = findPhysicalRangeForEPK("3C3C3C3C", ranges) + require.True(t, ok) + require.Equal(t, "1", rangeID) +} + func TestBuildQueryChunksForRanges_SingleRange(t *testing.T) { pkDef := PartitionKeyDefinition{Paths: []string{"/pk"}} orderedIDs := []string{"0"} @@ -297,3 +324,417 @@ func TestExecuteQueryChunks_CancelledContext(t *testing.T) { require.ErrorIs(t, err, context.Canceled) require.Empty(t, resp.Items) } + +func TestComputeEPKRange_FullKeyHash(t *testing.T) { + pkDef := PartitionKeyDefinition{ + Paths: []string{"/pk"}, + Kind: PartitionKeyKindHash, + Version: 2, + } + pk := NewPartitionKeyString("test") + r, err := computeEPKRange(&pk, pkDef) + require.NoError(t, err) + require.False(t, r.isRange(), "full key should be a point, not a range") + require.Equal(t, r.Min, r.Max) +} + +func TestComputeEPKRange_FullKeyMultiHash(t *testing.T) { + pkDef := PartitionKeyDefinition{ + Paths: []string{"/tenantId", "/userId"}, + Kind: PartitionKeyKindMultiHash, + Version: 2, + } + pk := NewPartitionKeyString("tenant1").AppendString("user1") + r, err := computeEPKRange(&pk, pkDef) + require.NoError(t, err) + require.False(t, r.isRange(), "full multi-hash key should be a point") +} + +func TestComputeEPKRange_PrefixKeyMultiHash(t *testing.T) { + pkDef := PartitionKeyDefinition{ + Paths: []string{"/tenantId", "/userId"}, + Kind: PartitionKeyKindMultiHash, + Version: 2, + } + pk := NewPartitionKeyString("tenant1") + r, err := computeEPKRange(&pk, pkDef) + require.NoError(t, err) + require.True(t, r.isRange(), "prefix key should produce a range") + require.Equal(t, r.Min+"FF", r.Max) +} + +func TestComputeEPKRange_TooManyComponents(t *testing.T) { + pkDef := PartitionKeyDefinition{ + Paths: []string{"/pk"}, + Kind: PartitionKeyKindHash, + Version: 2, + } + pk := NewPartitionKeyString("a").AppendString("b") // 2 components for 1 path + _, err := computeEPKRange(&pk, pkDef) + require.Error(t, err) + require.Contains(t, err.Error(), "more partition key components") +} + +func TestComputeEPKRange_NonMultiHashPartialKey(t *testing.T) { + pkDef := PartitionKeyDefinition{ + Paths: []string{"/a", "/b"}, + Kind: PartitionKeyKindHash, + Version: 2, + } + pk := NewPartitionKeyString("only-one") + _, err := computeEPKRange(&pk, pkDef) + require.Error(t, err) + require.Contains(t, err.Error(), "non-MultiHash") +} + +func TestComputeEPKRange_UndefinedPK(t *testing.T) { + pkDef := PartitionKeyDefinition{ + Paths: []string{"/pk"}, + Kind: PartitionKeyKindHash, + Version: 2, + } + pk := NewPartitionKey() + r, err := computeEPKRange(&pk, pkDef) + require.NoError(t, err) + require.False(t, r.isRange(), "undefined PK should be a point, not a range") + require.NotEmpty(t, r.Min) + require.Equal(t, r.Min, r.Max) +} + +func TestComputeEPKRange_UndefinedPK_MultiHash(t *testing.T) { + pkDef := PartitionKeyDefinition{ + Paths: []string{"/tenantId", "/userId"}, + Kind: PartitionKeyKindMultiHash, + Version: 2, + } + pk := NewPartitionKey() + r, err := computeEPKRange(&pk, pkDef) + require.NoError(t, err) + require.False(t, r.isRange(), "undefined PK should be a point even for MultiHash") + require.NotEmpty(t, r.Min) + require.Equal(t, r.Min, r.Max) +} + +func TestFindPhysicalRangeForEPK_MixedLengthBoundaries(t *testing.T) { + // Simulate HPK boundaries where one is 32-char partial and the next is 64-char zero-padded + partial := "06AB34CFE4E482236BCACBBF50E234AB" + fullZero := partial + "00000000000000000000000000000000" + + ranges := []partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: partial}, + {ID: "1", MinInclusive: fullZero, MaxExclusive: "FF"}, + } + + // An EPK exactly at the boundary should go to range 1 (length-aware: partial == fullZero) + id, ok := findPhysicalRangeForEPK(fullZero, ranges) + require.True(t, ok) + require.Equal(t, "1", id) + + // An EPK just below the boundary should go to range 0 + id, ok = findPhysicalRangeForEPK("06AB34CFE4E482236BCACBBF50E234AA", ranges) + require.True(t, ok) + require.Equal(t, "0", id) +} + +func TestGroupItemsByPhysicalRange_MultiHashPrefixFanout(t *testing.T) { + pkDef := PartitionKeyDefinition{ + Paths: []string{"/tenantId", "/userId"}, + Kind: PartitionKeyKindMultiHash, + Version: 2, + } + + // Compute the EPK range for "tenant1" prefix + prefixPK := NewPartitionKeyString("tenant1") + prefixRange, err := computeEPKRange(&prefixPK, pkDef) + require.NoError(t, err) + require.True(t, prefixRange.isRange()) + + // The prefix EPK is a 32-char hash. The range is [epk, epk+"FF"). + // Create a split point INSIDE that range by appending a mid-range suffix + // to the prefix EPK. E.g., if EPK is "AABB...", split at "AABB...80..." + splitPoint := prefixRange.Min + "80000000000000000000000000000000" + + ranges := []partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: splitPoint}, + {ID: "1", MinInclusive: splitPoint, MaxExclusive: "FF"}, + } + + items := []ItemIdentity{ + {ID: "1", PartitionKey: prefixPK}, + } + + orderedIDs, groups, err := groupItemsByPhysicalRange(items, pkDef, ranges) + require.NoError(t, err) + + // The prefix key should fan out to both partitions + require.Len(t, orderedIDs, 2, "prefix key should fan out to 2 partitions") + require.Contains(t, groups, "0") + require.Contains(t, groups, "1") + require.Len(t, groups["0"], 1) + require.Len(t, groups["1"], 1) +} + +func TestRefreshPKRangeCache_InvalidatesContainerCache(t *testing.T) { + cache := newContainerPropertiesCache() + + // Pre-populate the container cache with a stale entry + staleProps := &ContainerProperties{ + ID: "containerId", + ResourceID: "staleRID", + PartitionKeyDefinition: PartitionKeyDefinition{ + Paths: []string{"/pk"}, + Version: 2, + }, + } + containerLink := "dbs/databaseId/colls/containerId" + cache.set(containerLink, staleProps) + + // Verify stale RID is in the cache + require.NotNil(t, cache.getPropertiesByRID("staleRID")) + + // Set up a mock server that serves container properties for the re-fetch after invalidation + srv, close := mock.NewTLSServer() + defer close() + + containerPropsResponse := []byte(`{ + "id": "containerId", + "_rid": "newRID", + "_self": "dbs/db1/colls/containerId/", + "partitionKey": {"paths": ["/pk"], "kind": "Hash", "version": 2} + }`) + + pkRangeResponse := []byte(`{ + "_rid": "newRID", + "PartitionKeyRanges": [{ + "_rid": "newRID", + "id": "0", + "minInclusive": "", + "maxExclusive": "FF" + }], + "_count": 1 + }`) + + // First request: container props re-fetch (after invalidation) + // Second request: PK range fetch + srv.AppendResponse(mock.WithBody(containerPropsResponse), mock.WithStatusCode(200)) + srv.AppendResponse(mock.WithBody(pkRangeResponse), mock.WithStatusCode(200)) + + defaultEndpoint, _ := url.Parse(srv.URL()) + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + client := &Client{ + endpoint: srv.URL(), + endpointUrl: defaultEndpoint, + internal: internalClient, + gem: gem, + caches: &sharedCacheSet{ + pkRangeCache: newPartitionKeyRangeCache(), + containerCache: cache, + }, + } + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + // Call refreshPKRangeCache — should invalidate container cache, re-fetch props, then refresh PK ranges + err := container.refreshPKRangeCache(context.TODO()) + require.NoError(t, err) + + // The stale RID entry should have been invalidated and replaced with the new one + require.Nil(t, cache.getPropertiesByRID("staleRID"), "stale RID should be invalidated from container cache") + require.NotNil(t, cache.getPropertiesByRID("newRID"), "new RID should be in container cache after refresh") +} + +// returnGoneOnQueryPolicy is a pipeline policy that returns a 410/Gone response +// with a configurable substatus on query requests. After maxGone 410s have been +// returned, subsequent queries pass through to the server normally. +type returnGoneOnQueryPolicy struct { + maxGone int32 // how many queries to fail with 410 + substatus string // substatus header value (e.g. "1002") + count atomic.Int32 // queries seen so far +} + +func (p *returnGoneOnQueryPolicy) Do(req *policy.Request) (*http.Response, error) { + isQuery := req.Raw().Header.Get(cosmosHeaderQuery) == "True" + if !isQuery { + return req.Next() + } + n := p.count.Add(1) + if n <= p.maxGone { + // Return a synthetic 410 response with substatus header + headers := http.Header{} + headers.Set(cosmosHeaderSubstatus, p.substatus) + return &http.Response{ + StatusCode: http.StatusGone, + Status: "410 Gone", + Header: headers, + Body: io.NopCloser(strings.NewReader(`{"message":"Gone"}`)), + Request: req.Raw(), + }, nil + } + return req.Next() +} + +// createReadManyTestClient builds a *Client with pre-populated container and PK range +// caches, wired to the given mock server with the specified per-call policies. +func createReadManyTestClient(t *testing.T, srv *mock.Server, policies []policy.Policy) *Client { + t.Helper() + defaultEndpoint, err := url.Parse(srv.URL()) + require.NoError(t, err) + + internalClient, err := azcore.NewClient("azcosmostest", "v1.0.0", + azruntime.PipelineOptions{PerCall: policies}, + &policy.ClientOptions{Transport: srv}) + require.NoError(t, err) + + gem := &globalEndpointManager{preferredLocations: []string{}} + containerCache := newContainerPropertiesCache() + pkRangeCache := newPartitionKeyRangeCache() + + client := &Client{ + endpoint: srv.URL(), + endpointUrl: defaultEndpoint, + internal: internalClient, + gem: gem, + caches: &sharedCacheSet{ + containerCache: containerCache, + pkRangeCache: pkRangeCache, + }, + } + + // Pre-populate container cache + containerLink := "dbs/databaseId/colls/containerId" + containerCache.set(containerLink, &ContainerProperties{ + ID: "containerId", + ResourceID: "testRID", + PartitionKeyDefinition: PartitionKeyDefinition{ + Paths: []string{"/pk"}, + Kind: PartitionKeyKindHash, + Version: 2, + }, + }) + + // Pre-populate PK range cache with one range covering the full key space + pkRangeCache.entries["testRID"] = &pkRangeCacheEntry{ + routingMap: newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "FF", ResourceID: "testRID"}, + }, "etag1"), + } + + return client +} + +func TestReadMany_410_retrySucceeds(t *testing.T) { + srv, closeSrv := mock.NewTLSServer() + defer closeSrv() + + gonePolicy := &returnGoneOnQueryPolicy{maxGone: 1, substatus: subStatusPartitionKeyRangeGone} + client := createReadManyTestClient(t, srv, []policy.Policy{gonePolicy}) + + // After the 410, refreshPKRangeCache will: + // 1. Invalidate container cache → re-fetch container props + // 2. Force refresh PK range cache → fetch PK ranges (incremental loop needs 304 to stop) + // Then the retry will execute the query again (passes through policy since maxGone=1) + containerPropsResp := []byte(`{ + "id": "containerId", + "_rid": "testRID", + "_self": "dbs/db1/colls/containerId/", + "partitionKey": {"paths": ["/pk"], "kind": "Hash", "version": 2} + }`) + pkRangeResp := []byte(`{ + "_rid": "testRID", + "PartitionKeyRanges": [{"_rid": "testRID", "id": "0", "minInclusive": "", "maxExclusive": "FF"}], + "_count": 1 + }`) + queryResp := []byte(`{"Documents":[{"id":"item1","pk":"pkA"}]}`) + + // Sequence: container props (re-fetch) → PK ranges (refresh) → 304 (end incremental loop) → query (retry succeeds) + srv.AppendResponse(mock.WithBody(containerPropsResp), mock.WithStatusCode(200)) + srv.AppendResponse(mock.WithBody(pkRangeResp), mock.WithStatusCode(200), + mock.WithHeader(cosmosHeaderEtag, "etag2")) + srv.AppendResponse(mock.WithStatusCode(304)) // End incremental refresh loop + srv.AppendResponse(mock.WithBody(queryResp), mock.WithStatusCode(200), + mock.WithHeader(cosmosHeaderRequestCharge, "1.0")) + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + items := []ItemIdentity{{ID: "item1", PartitionKey: NewPartitionKeyString("pkA")}} + resp, err := container.ReadManyItems(context.Background(), items, nil) + require.NoError(t, err) + require.Len(t, resp.Items, 1) + require.Equal(t, int32(2), gonePolicy.count.Load(), "expected 2 query attempts: 1 failed + 1 retry") +} + +func TestReadMany_410_exhaustedRetries(t *testing.T) { + srv, closeSrv := mock.NewTLSServer() + defer closeSrv() + + // Return 410 on ALL query attempts (maxPKRangeGoneRetries+1 = 4) + gonePolicy := &returnGoneOnQueryPolicy{maxGone: 100, substatus: subStatusPartitionKeyRangeGone} + client := createReadManyTestClient(t, srv, []policy.Policy{gonePolicy}) + + containerPropsResp := []byte(`{ + "id": "containerId", + "_rid": "testRID", + "_self": "dbs/db1/colls/containerId/", + "partitionKey": {"paths": ["/pk"], "kind": "Hash", "version": 2} + }`) + pkRangeResp := []byte(`{ + "_rid": "testRID", + "PartitionKeyRanges": [{"_rid": "testRID", "id": "0", "minInclusive": "", "maxExclusive": "FF"}], + "_count": 1 + }`) + + // Each retry needs container props + PK range re-fetch (with ETag) + 304 (end incremental loop). + // The PK range response must include an ETag header so changeFeedETag stays non-empty, + // ensuring subsequent retries use the incremental refresh path consistently. + // maxPKRangeGoneRetries=3, so we get 3 retries (attempts 0-2 trigger refresh), attempt 3 fails. + for i := 0; i < maxPKRangeGoneRetries; i++ { + srv.AppendResponse(mock.WithBody(containerPropsResp), mock.WithStatusCode(200)) + srv.AppendResponse(mock.WithBody(pkRangeResp), mock.WithStatusCode(200), + mock.WithHeader(cosmosHeaderEtag, "etag-refresh")) + srv.AppendResponse(mock.WithStatusCode(304)) // End incremental refresh loop + } + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + items := []ItemIdentity{{ID: "item1", PartitionKey: NewPartitionKeyString("pkA")}} + _, err := container.ReadManyItems(context.Background(), items, nil) + require.Error(t, err) + + // On the last attempt (attempt == maxPKRangeGoneRetries), the 410 error is + // returned directly since there are no retries remaining. + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, http.StatusGone, respErr.StatusCode) + + // Verify all 4 attempts were made (initial + 3 retries) + require.Equal(t, int32(maxPKRangeGoneRetries+1), gonePolicy.count.Load(), + "expected initial attempt + maxPKRangeGoneRetries retry attempts") +} + +func TestReadMany_410_nonGoneSubstatus_notRetried(t *testing.T) { + srv, closeSrv := mock.NewTLSServer() + defer closeSrv() + + // Return 410 with substatus "0" (NameCacheIsStale) — not a PKRange gone condition + gonePolicy := &returnGoneOnQueryPolicy{maxGone: 1, substatus: "0"} + client := createReadManyTestClient(t, srv, []policy.Policy{gonePolicy}) + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + items := []ItemIdentity{{ID: "item1", PartitionKey: NewPartitionKeyString("pkA")}} + _, err := container.ReadManyItems(context.Background(), items, nil) + require.Error(t, err) + + // Should have made only 1 query attempt (no retry for non-gone substatus) + require.Equal(t, int32(1), gonePolicy.count.Load(), "should not retry for non-PKRangeGone substatus") + + // Verify it's a 410 error + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, http.StatusGone, respErr.StatusCode) +} diff --git a/sdk/data/azcosmos/cosmos_feed_range_test.go b/sdk/data/azcosmos/cosmos_feed_range_test.go index cf25dd6a8fa0..8f6f82f7f8ba 100644 --- a/sdk/data/azcosmos/cosmos_feed_range_test.go +++ b/sdk/data/azcosmos/cosmos_feed_range_test.go @@ -12,6 +12,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/stretchr/testify/require" ) func TestContainerGetFeedRanges(t *testing.T) { @@ -142,3 +143,94 @@ func TestContainerGetFeedRangesEmpty(t *testing.T) { t.Fatalf("Expected 0 feed ranges, got %d", len(feedRanges)) } } + +func TestContainerGetFeedRanges_UsesCache(t *testing.T) { + containerResponse := []byte(`{ + "id": "containerId", + "_rid": "testRID", + "_self": "dbs/db1/colls/containerId/", + "partitionKey": { + "paths": ["/pk"], + "kind": "Hash", + "version": 2 + } + }`) + + pkRangeResponse := []byte(`{ + "_rid": "testRID", + "PartitionKeyRanges": [ + { + "_rid": "testRID_range0", + "id": "0", + "_etag": "\"etag0\"", + "minInclusive": "", + "maxExclusive": "05C1E18D2D7F08", + "status": "online", + "parents": [] + }, + { + "_rid": "testRID_range1", + "id": "1", + "_etag": "\"etag1\"", + "minInclusive": "05C1E18D2D7F08", + "maxExclusive": "FF", + "status": "online", + "parents": [] + } + ], + "_count": 2 + }`) + + srv, close := mock.NewTLSServer() + defer close() + + // First call will need: container read (for RID) + PK range fetch + srv.AppendResponse( + mock.WithBody(containerResponse), + mock.WithStatusCode(200), + ) + srv.AppendResponse( + mock.WithBody(pkRangeResponse), + mock.WithHeader(cosmosHeaderEtag, "changeFeedEtag1"), + mock.WithStatusCode(200), + ) + + defaultEndpoint, _ := url.Parse(srv.URL()) + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + client := &Client{ + endpoint: srv.URL(), + endpointUrl: defaultEndpoint, + internal: internalClient, + gem: gem, + caches: &sharedCacheSet{ + pkRangeCache: newPartitionKeyRangeCache(), + containerCache: newContainerPropertiesCache(), + }, + } + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + // First call: populates caches, makes 2 HTTP requests + feedRanges, err := container.GetFeedRanges(context.TODO()) + require.NoError(t, err) + require.Equal(t, 2, len(feedRanges)) + require.Equal(t, "", feedRanges[0].MinInclusive) + require.Equal(t, "05C1E18D2D7F08", feedRanges[0].MaxExclusive) + require.Equal(t, "05C1E18D2D7F08", feedRanges[1].MinInclusive) + require.Equal(t, "FF", feedRanges[1].MaxExclusive) + + requestsAfterFirstCall := srv.Requests() + require.Equal(t, 2, requestsAfterFirstCall, "first call should make 2 HTTP requests (container read + PK ranges)") + + // Second call: should use caches, no additional HTTP requests + // (no more responses queued — would panic if a request was made) + feedRanges2, err := container.GetFeedRanges(context.TODO()) + require.NoError(t, err) + require.Equal(t, 2, len(feedRanges2)) + require.Equal(t, feedRanges[0], feedRanges2[0]) + require.Equal(t, feedRanges[1], feedRanges2[1]) + + require.Equal(t, requestsAfterFirstCall, srv.Requests(), "second call should make 0 HTTP requests (cache hit)") +} diff --git a/sdk/data/azcosmos/cosmos_global_cache_registry.go b/sdk/data/azcosmos/cosmos_global_cache_registry.go new file mode 100644 index 000000000000..3b12944697f7 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_global_cache_registry.go @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "net/url" + "strings" + "sync" + "sync/atomic" +) + +// sharedCacheSet groups the metadata caches that are shared across all Client +// instances targeting the same Cosmos DB account endpoint. This ensures that +// partition key range and container property metadata is fetched once per +// account regardless of how many Client instances exist. +type sharedCacheSet struct { + containerCache *containerPropertiesCache + pkRangeCache *partitionKeyRangeCache + refCount atomic.Int64 +} + +// globalCacheRegistry is a process-level registry of shared cache sets keyed +// by normalized account endpoint. It ensures singleton caches per account. +var globalCacheRegistry sync.Map // map[string]*sharedCacheSet + +// normalizeEndpoint returns a canonical form of the endpoint for use as a +// registry key. It lowercases the host and strips ports and paths so that +// endpoints like "https://account.documents.azure.com:443/" and +// "https://account.documents.azure.com" resolve to the same key. Ports are +// stripped entirely because the same account is the same account regardless +// of the port used to reach it. +func normalizeEndpoint(endpoint string) string { + u, err := url.Parse(strings.TrimSpace(endpoint)) + if err != nil || u.Host == "" { + // Fallback for malformed input + return strings.TrimRight(strings.ToLower(endpoint), "/") + } + return u.Scheme + "://" + strings.ToLower(u.Hostname()) +} + +// acquireCaches returns the shared cache set for the given endpoint, creating +// one if it doesn't exist. The caller must call releaseCaches when the Client +// is closed to allow cleanup. +// +// The implementation uses a verify-after-increment pattern to prevent a TOCTOU +// race with releaseCaches: after incrementing the refCount, it confirms the +// entry is still in the registry. If a concurrent release evicted it, the +// increment is undone and the loop retries. +func acquireCaches(endpoint string) *sharedCacheSet { + key := normalizeEndpoint(endpoint) + + for { + if val, ok := globalCacheRegistry.Load(key); ok { + set := val.(*sharedCacheSet) + set.refCount.Add(1) + // Verify the entry is still registered after we incremented. + // A concurrent releaseCaches could have evicted it between our + // Load and Add, leaving us with an orphaned cache set. + if val2, ok2 := globalCacheRegistry.Load(key); ok2 && val2 == val { + return set + } + // Entry was evicted — undo our increment and retry. + set.refCount.Add(-1) + continue + } + + // Slow path: create new and use LoadOrStore to avoid races + newSet := &sharedCacheSet{ + containerCache: newContainerPropertiesCache(), + pkRangeCache: newPartitionKeyRangeCache(), + } + newSet.refCount.Store(1) + + _, loaded := globalCacheRegistry.LoadOrStore(key, newSet) + if loaded { + // Another goroutine created it first — use theirs via the retry loop + // to get the verify-after-increment safety. + continue + } + return newSet + } +} + +// releaseCaches decrements the reference count for the given endpoint's cache +// set and removes it from the registry when no clients remain. +func releaseCaches(endpoint string) { + key := normalizeEndpoint(endpoint) + if val, ok := globalCacheRegistry.Load(key); ok { + set := val.(*sharedCacheSet) + if set.refCount.Add(-1) <= 0 { + globalCacheRegistry.CompareAndDelete(key, val) + } + } +} + +// resetGlobalCacheRegistry clears the global cache registry. +// This is intended for test isolation only. +func resetGlobalCacheRegistry() { + globalCacheRegistry.Range(func(key, _ any) bool { + globalCacheRegistry.Delete(key) + return true + }) +} diff --git a/sdk/data/azcosmos/cosmos_global_cache_registry_test.go b/sdk/data/azcosmos/cosmos_global_cache_registry_test.go new file mode 100644 index 000000000000..448b51fa13d9 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_global_cache_registry_test.go @@ -0,0 +1,336 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "context" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNormalizeEndpoint(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"https://MyAccount.documents.azure.com/", "https://myaccount.documents.azure.com"}, + {"https://MyAccount.documents.azure.com", "https://myaccount.documents.azure.com"}, + {"https://MYACCOUNT.DOCUMENTS.AZURE.COM///", "https://myaccount.documents.azure.com"}, + {"https://myaccount.documents.azure.com:443/", "https://myaccount.documents.azure.com"}, + {"https://myaccount.documents.azure.com:443", "https://myaccount.documents.azure.com"}, + {"http://myaccount.documents.azure.com:80", "http://myaccount.documents.azure.com"}, + {"https://localhost:8081/", "https://localhost"}, + {"https://localhost:8081", "https://localhost"}, + {" https://myaccount.documents.azure.com ", "https://myaccount.documents.azure.com"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + require.Equal(t, tt.expected, normalizeEndpoint(tt.input)) + }) + } +} + +func TestAcquireCaches_SameEndpoint_ReturnsSameInstance(t *testing.T) { + resetGlobalCacheRegistry() + defer resetGlobalCacheRegistry() + + set1 := acquireCaches("https://account1.documents.azure.com/") + set2 := acquireCaches("https://account1.documents.azure.com/") + set3 := acquireCaches("https://Account1.Documents.Azure.Com") // different case, no trailing slash + + require.Same(t, set1, set2, "same endpoint should return same cache set") + require.Same(t, set1, set3, "normalization should make these equivalent") + require.Equal(t, int64(3), set1.refCount.Load()) +} + +func TestAcquireCaches_DifferentEndpoints_ReturnDifferentInstances(t *testing.T) { + resetGlobalCacheRegistry() + defer resetGlobalCacheRegistry() + + set1 := acquireCaches("https://account1.documents.azure.com") + set2 := acquireCaches("https://account2.documents.azure.com") + + require.NotSame(t, set1, set2, "different endpoints should return different cache sets") + require.NotSame(t, set1.containerCache, set2.containerCache) + require.NotSame(t, set1.pkRangeCache, set2.pkRangeCache) +} + +func TestReleaseCaches_RemovesEntryWhenZeroRefs(t *testing.T) { + resetGlobalCacheRegistry() + defer resetGlobalCacheRegistry() + + endpoint := "https://account1.documents.azure.com" + set1 := acquireCaches(endpoint) + _ = acquireCaches(endpoint) // refCount = 2 + + releaseCaches(endpoint) // refCount = 1 + require.Equal(t, int64(1), set1.refCount.Load()) + + // Entry should still be in the registry + val, ok := globalCacheRegistry.Load(normalizeEndpoint(endpoint)) + require.True(t, ok) + require.Same(t, set1, val.(*sharedCacheSet)) + + releaseCaches(endpoint) // refCount = 0 → removed + _, ok = globalCacheRegistry.Load(normalizeEndpoint(endpoint)) + require.False(t, ok, "entry should be removed when refCount reaches 0") +} + +func TestReleaseCaches_NewAcquireAfterFullRelease_CreatesNew(t *testing.T) { + resetGlobalCacheRegistry() + defer resetGlobalCacheRegistry() + + endpoint := "https://account1.documents.azure.com" + set1 := acquireCaches(endpoint) + releaseCaches(endpoint) // refCount = 0, removed + + set2 := acquireCaches(endpoint) // new instance + require.NotSame(t, set1, set2, "should create a fresh cache set after full release") +} + +func TestSharedCaches_CrossClientCacheHit(t *testing.T) { + resetGlobalCacheRegistry() + defer resetGlobalCacheRegistry() + + endpoint := "https://account1.documents.azure.com" + + // Simulate two clients targeting the same endpoint + clientA := &Client{endpoint: endpoint, caches: acquireCaches(endpoint)} + clientB := &Client{endpoint: endpoint, caches: acquireCaches(endpoint)} + defer clientA.Close() + defer clientB.Close() + + // Verify they share the exact same cache instances + require.Same(t, clientA.getContainerCache(), clientB.getContainerCache()) + require.Same(t, clientA.getPKRangeCache(), clientB.getPKRangeCache()) + + // Client A populates the container cache + containerLink := "dbs/db1/colls/col1" + props := &ContainerProperties{ + ID: "col1", + ResourceID: "rid-abc", + PartitionKeyDefinition: PartitionKeyDefinition{ + Paths: []string{"/pk"}, + Version: 2, + }, + } + clientA.getContainerCache().set(containerLink, props) + + // Client B reads from cache — gets the same value with zero HTTP calls + containerB := &ContainerClient{link: containerLink} + result, err := clientB.getContainerCache().getProperties(context.Background(), containerB) + require.NoError(t, err) + require.Equal(t, props, result) + require.Equal(t, "rid-abc", result.ResourceID) +} + +func TestSharedCaches_CrossClientInvalidationVisibility(t *testing.T) { + resetGlobalCacheRegistry() + defer resetGlobalCacheRegistry() + + endpoint := "https://account1.documents.azure.com" + + clientA := &Client{endpoint: endpoint, caches: acquireCaches(endpoint)} + clientB := &Client{endpoint: endpoint, caches: acquireCaches(endpoint)} + defer clientA.Close() + defer clientB.Close() + + containerLink := "dbs/db1/colls/col1" + props := &ContainerProperties{ + ID: "col1", + ResourceID: "rid-abc", + PartitionKeyDefinition: PartitionKeyDefinition{ + Paths: []string{"/pk"}, + Version: 2, + }, + } + clientA.getContainerCache().set(containerLink, props) + + // Client B invalidates (e.g., got a 410) + clientB.getContainerCache().invalidate(containerLink) + + // Client A should also see the invalidation since they share the cache + containerA := &ContainerClient{link: containerLink} + // getProperties will try to refresh — but no pipeline is wired, so + // we check the entry directly to confirm the nil state is visible + clientA.getContainerCache().mu.RLock() + entry := clientA.getContainerCache().entries[containerLink] + clientA.getContainerCache().mu.RUnlock() + require.NotNil(t, entry) + entry.mu.Lock() + require.Nil(t, entry.props, "invalidation by Client B should be visible to Client A") + entry.mu.Unlock() + _ = containerA +} + +func TestSharedCaches_ConcurrentClientsRefreshSingleFlight(t *testing.T) { + resetGlobalCacheRegistry() + defer resetGlobalCacheRegistry() + + endpoint := "https://account1.documents.azure.com" + caches := acquireCaches(endpoint) + defer releaseCaches(endpoint) + + containerLink := "dbs/db1/colls/col1" + props := &ContainerProperties{ + ID: "col1", + ResourceID: "rid-abc", + PartitionKeyDefinition: PartitionKeyDefinition{ + Paths: []string{"/pk"}, + Version: 2, + }, + } + + // Pre-populate cache, then invalidate to simulate a 410 + caches.containerCache.set(containerLink, props) + caches.containerCache.invalidate(containerLink) + + // Now manually set the entry.props to simulate a refresh completing + // while many goroutines are waiting on the entry lock. + // This tests that all waiters get the value once ONE sets it. + const numClients = 50 + var wg sync.WaitGroup + var fetchCount atomic.Int64 + results := make([]*ContainerProperties, numClients) + + // Grab the entry and manually control the lock to simulate the race + caches.containerCache.mu.RLock() + entry := caches.containerCache.entries[containerLink] + caches.containerCache.mu.RUnlock() + require.NotNil(t, entry) + + // Lock the entry before spawning goroutines — simulates a refresh in progress + entry.mu.Lock() + + wg.Add(numClients) + for i := 0; i < numClients; i++ { + go func(idx int) { + defer wg.Done() + // Each goroutine tries to acquire entry lock (simulating getProperties) + entry.mu.Lock() + if entry.props == nil { + // "I'm the one refreshing" — simulate HTTP call + fetchCount.Add(1) + entry.props = &ContainerProperties{ + ID: "col1", + ResourceID: "rid-new", + } + } + results[idx] = entry.props + entry.mu.Unlock() + }(i) + } + + // Release the lock — one goroutine will "win" and set the value, + // all others will see it already set. + entry.mu.Unlock() + wg.Wait() + + // Exactly one goroutine should have done the "fetch" + require.Equal(t, int64(1), fetchCount.Load(), + "only one goroutine should refresh; others should see the populated value") + + // All goroutines should have gotten the same result + for i := 0; i < numClients; i++ { + require.NotNil(t, results[i]) + require.Equal(t, "rid-new", results[i].ResourceID) + } +} + +func TestAcquireCaches_ConcurrentSafe(t *testing.T) { + resetGlobalCacheRegistry() + defer resetGlobalCacheRegistry() + + endpoint := "https://account1.documents.azure.com" + const goroutines = 100 + + var wg sync.WaitGroup + results := make([]*sharedCacheSet, goroutines) + + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + results[idx] = acquireCaches(endpoint) + }(i) + } + wg.Wait() + + // All goroutines should have gotten the same instance + for i := 1; i < goroutines; i++ { + require.Same(t, results[0], results[i], "concurrent acquires should return same instance") + } + require.Equal(t, int64(goroutines), results[0].refCount.Load()) +} + +func TestAcquireCaches_ResilientToInterleavedRelease(t *testing.T) { + resetGlobalCacheRegistry() + defer resetGlobalCacheRegistry() + + endpoint := "https://account1.documents.azure.com" + + // Simulate the TOCTOU scenario: acquire, then concurrent release+acquire. + // After the race, every acquire must still return the same registered instance. + set1 := acquireCaches(endpoint) // refCount = 1 + + // Release fully (refCount -> 0, entry removed) + releaseCaches(endpoint) + + // Now a new acquire should create a fresh set (the old one is gone) + set2 := acquireCaches(endpoint) + require.NotSame(t, set1, set2, "after full release, a new set should be created") + require.Equal(t, int64(1), set2.refCount.Load()) + + // Stress test: interleave acquires and releases concurrently + const goroutines = 100 + var wg sync.WaitGroup + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + s := acquireCaches(endpoint) + // Verify the returned set is actually in the registry + val, ok := globalCacheRegistry.Load(normalizeEndpoint(endpoint)) + require.True(t, ok, "acquired set must be in registry") + require.Same(t, s, val.(*sharedCacheSet), "acquired set must match registry entry") + releaseCaches(endpoint) + }() + } + wg.Wait() + + // set2 should still be valid (we haven't released it) + val, ok := globalCacheRegistry.Load(normalizeEndpoint(endpoint)) + require.True(t, ok) + require.Same(t, set2, val.(*sharedCacheSet)) + require.Equal(t, int64(1), set2.refCount.Load()) + releaseCaches(endpoint) +} + +func TestClientClose_Idempotent(t *testing.T) { + resetGlobalCacheRegistry() + defer resetGlobalCacheRegistry() + + endpoint := "https://account1.documents.azure.com" + client := &Client{endpoint: endpoint, caches: acquireCaches(endpoint)} + + // Also acquire a second reference so we can observe the refCount + _ = acquireCaches(endpoint) // refCount = 2 + + client.Close() // refCount -> 1 + client.Close() // idempotent — should NOT decrement again + client.Close() // still idempotent + + // RefCount should be 1 (the second acquire), not -1 or 0 + val, ok := globalCacheRegistry.Load(normalizeEndpoint(endpoint)) + require.True(t, ok, "entry should still exist — second reference is alive") + set := val.(*sharedCacheSet) + require.Equal(t, int64(1), set.refCount.Load()) + + // Clean up the second reference + releaseCaches(endpoint) +} diff --git a/sdk/data/azcosmos/cosmos_http_constants.go b/sdk/data/azcosmos/cosmos_http_constants.go index 73a81157b05a..65a521956465 100644 --- a/sdk/data/azcosmos/cosmos_http_constants.go +++ b/sdk/data/azcosmos/cosmos_http_constants.go @@ -3,6 +3,8 @@ package azcosmos +import "net/http" + // Headers const ( cosmosHeaderRequestCharge string = "x-ms-request-charge" @@ -96,8 +98,26 @@ const ( ) // Substatus Codes +// NOTE: Some substatus values are reused across different HTTP status codes. +// Always check the HTTP status code first before interpreting the substatus. +// - 1002: ReadSessionNotAvailable (404) / PartitionKeyRangeGone (410) +// - 1008: DatabaseAccountNotFound (403) / CompletingPartitionMigration (410) const ( - subStatusWriteForbidden string = "3" - subStatusDatabaseAccountNotFound string = "1008" - subStatusReadSessionNotAvailable string = "1002" + subStatusWriteForbidden string = "3" + subStatusDatabaseAccountNotFound string = "1008" + subStatusReadSessionNotAvailable string = "1002" + subStatusPartitionKeyRangeGone string = "1002" + subStatusCompletingSplit string = "1007" + subStatusCompletingPartitionMigration string = "1008" ) + +// isPartitionKeyRangeGoneError checks if an error response indicates a +// partition key range gone condition (HTTP 410 with split-related substatus). +func isPartitionKeyRangeGoneError(statusCode int, subStatus string) bool { + if statusCode != http.StatusGone { + return false + } + return subStatus == subStatusPartitionKeyRangeGone || + subStatus == subStatusCompletingSplit || + subStatus == subStatusCompletingPartitionMigration +} diff --git a/sdk/data/azcosmos/cosmos_partition_key_range_cache.go b/sdk/data/azcosmos/cosmos_partition_key_range_cache.go new file mode 100644 index 000000000000..a9dc06e7265b --- /dev/null +++ b/sdk/data/azcosmos/cosmos_partition_key_range_cache.go @@ -0,0 +1,253 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" +) + +// partitionKeyRangeCache provides a client-level cache of partition key ranges +// for containers. It is keyed by container ResourceID (RID) and uses event-driven +// invalidation (no TTL). Refreshes are incremental using the change-feed ETag. +// Keying by RID (rather than name-based link) ensures the cache survives +// container renames and matches the service's partition key range addressing. +type partitionKeyRangeCache struct { + mu sync.RWMutex + entries map[string]*pkRangeCacheEntry // keyed by container ResourceID +} + +type pkRangeCacheEntry struct { + mu sync.Mutex // single-flights refresh for this container + routingMap *collectionRoutingMap +} + +func newPartitionKeyRangeCache() *partitionKeyRangeCache { + return &partitionKeyRangeCache{ + entries: make(map[string]*pkRangeCacheEntry), + } +} + +// getRoutingMap returns the cached routing map for the given container RID. +// If the cache is empty for this container, it fetches from the service. +// containerLink is the name-based path used for the HTTP request. +func (c *partitionKeyRangeCache) getRoutingMap( + ctx context.Context, + containerRID string, + containerLink string, + client *Client, +) (*collectionRoutingMap, error) { + // Fast path: read lock check + c.mu.RLock() + entry, exists := c.entries[containerRID] + c.mu.RUnlock() + + if exists { + entry.mu.Lock() + if entry.routingMap != nil { + rm := entry.routingMap + entry.mu.Unlock() + return rm, nil + } + // Cache entry exists but routing map is nil (invalidated) — refresh under lock + rm, err := c.refreshEntry(ctx, containerLink, entry, client) + entry.mu.Unlock() + return rm, err + } + + // Slow path: create entry + c.mu.Lock() + // Double check after acquiring write lock + entry, exists = c.entries[containerRID] + if !exists { + entry = &pkRangeCacheEntry{} + c.entries[containerRID] = entry + } + c.mu.Unlock() + + entry.mu.Lock() + if entry.routingMap != nil { + rm := entry.routingMap + entry.mu.Unlock() + return rm, nil + } + rm, err := c.refreshEntry(ctx, containerLink, entry, client) + entry.mu.Unlock() + return rm, err +} + +// forceRefresh triggers an incremental refresh of the routing map for the given +// container. If the incremental merge fails (incomplete covering), it falls back +// to a full refresh. containerRID is the cache key; containerLink is used for HTTP requests. +func (c *partitionKeyRangeCache) forceRefresh( + ctx context.Context, + containerRID string, + containerLink string, + client *Client, +) (*collectionRoutingMap, error) { + c.mu.RLock() + entry, exists := c.entries[containerRID] + c.mu.RUnlock() + + if !exists { + // No entry yet — just do a normal get which will create and populate + return c.getRoutingMap(ctx, containerRID, containerLink, client) + } + + entry.mu.Lock() + defer entry.mu.Unlock() + return c.refreshEntry(ctx, containerLink, entry, client) +} + +// invalidate removes the cached routing map for the given container RID, +// forcing the next access to fetch fresh data. +func (c *partitionKeyRangeCache) invalidate(containerRID string) { + c.mu.RLock() + entry, exists := c.entries[containerRID] + c.mu.RUnlock() + + if exists { + entry.mu.Lock() + entry.routingMap = nil + entry.mu.Unlock() + } +} + +// maxIncrementalRefreshIterations caps the number of incremental fetch loops +// to prevent runaway requests during large-scale splits. Aligned with the Rust SDK. +const maxIncrementalRefreshIterations = 1000 + +// refreshEntry fetches PK ranges from the service and populates the entry. +// It attempts an incremental refresh if a previous routing map with an ETag exists, +// looping until 304 Not Modified (capped at maxIncrementalRefreshIterations). +// Falls back to a full refresh if the incremental merge is incomplete. +// Caller must hold entry.mu. +func (c *partitionKeyRangeCache) refreshEntry( + ctx context.Context, + containerLink string, + entry *pkRangeCacheEntry, + client *Client, +) (*collectionRoutingMap, error) { + previousMap := entry.routingMap + + if previousMap != nil && previousMap.changeFeedETag != "" { + // Incremental refresh loop: keep fetching until 304 or iteration cap + currentMap := previousMap + for i := 0; i < maxIncrementalRefreshIterations; i++ { + if err := ctx.Err(); err != nil { + return nil, err + } + ranges, newETag, err := fetchPartitionKeyRanges(ctx, containerLink, currentMap, client) + if err != nil { + return nil, err + } + + if len(ranges) == 0 { + // 304 Not Modified — no more changes + if newETag != "" && newETag != currentMap.changeFeedETag { + currentMap = &collectionRoutingMap{ + orderedRanges: currentMap.orderedRanges, + rangeByID: currentMap.rangeByID, + goneRanges: currentMap.goneRanges, + changeFeedETag: newETag, + } + } + entry.routingMap = currentMap + return currentMap, nil + } + + merged := currentMap.tryCombine(ranges, newETag) + if merged == nil { + // Incremental merge failed — fall through to full refresh + break + } + currentMap = merged + } + + // Loop exited without 304 — either iteration cap or merge failure. + // Fall through to full refresh to guarantee consistency with the service. + } + + // Full refresh: fetch all ranges without ETag + ranges, newETag, err := fetchPartitionKeyRanges(ctx, containerLink, nil, client) + if err != nil { + return nil, err + } + + newMap := newCollectionRoutingMap(ranges, newETag) + if !isCompleteSetOfRanges(newMap.orderedRanges) { + return nil, fmt.Errorf("incomplete partition key range set after full refresh for %s", containerLink) + } + + entry.routingMap = newMap + return newMap, nil +} + +// fetchPartitionKeyRanges fetches partition key ranges from the service. +// If previousMap is non-nil and has an ETag, it uses incremental feed mode. +func fetchPartitionKeyRanges( + ctx context.Context, + containerLink string, + previousMap *collectionRoutingMap, + client *Client, +) ([]partitionKeyRange, string, error) { + operationContext := pipelineRequestOptions{ + resourceType: resourceTypePartitionKeyRange, + resourceAddress: containerLink, + } + + path, err := generatePathForNameBased(resourceTypePartitionKeyRange, operationContext.resourceAddress, true) + if err != nil { + return nil, "", err + } + + var changeFeedETag string + if previousMap != nil { + changeFeedETag = previousMap.changeFeedETag + } + + o := &partitionKeyRangeOptions{} + + azResponse, err := client.sendGetRequest( + path, + ctx, + operationContext, + o, + func(req *policy.Request) { + if changeFeedETag != "" { + req.Raw().Header.Set(cosmosHeaderChangeFeed, cosmosHeaderValuesChangeFeed) + req.Raw().Header.Set(headerIfNoneMatch, changeFeedETag) + } + }) + if err != nil { + return nil, "", err + } + + newETag := azResponse.Header.Get(cosmosHeaderEtag) + + // 304 Not Modified means no changes + if azResponse.StatusCode == http.StatusNotModified { + _ = azResponse.Body.Close() + return nil, newETag, nil + } + + body, err := azruntime.Payload(azResponse) + if err != nil { + return nil, "", err + } + _ = azResponse.Body.Close() + + var response partitionKeyRangeResponse + if err := json.Unmarshal(body, &response); err != nil { + return nil, "", err + } + + return response.PartitionKeyRanges, newETag, nil +} diff --git a/sdk/data/azcosmos/cosmos_partition_key_range_cache_test.go b/sdk/data/azcosmos/cosmos_partition_key_range_cache_test.go new file mode 100644 index 000000000000..8ca53e5aece5 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_partition_key_range_cache_test.go @@ -0,0 +1,379 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "context" + "net/url" + "sync" + "sync/atomic" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/stretchr/testify/require" +) + +func Test_partitionKeyRangeCache_newCache_startsEmpty(t *testing.T) { + cache := newPartitionKeyRangeCache() + + cache.mu.RLock() + require.Empty(t, cache.entries) + cache.mu.RUnlock() +} + +func Test_partitionKeyRangeCache_invalidate_nilEntry(t *testing.T) { + cache := newPartitionKeyRangeCache() + + // Invalidating a non-existent entry should not panic + cache.invalidate("rid1") + + cache.mu.RLock() + _, exists := cache.entries["rid1"] + cache.mu.RUnlock() + require.False(t, exists) +} + +func Test_partitionKeyRangeCache_invalidate_existingEntry(t *testing.T) { + cache := newPartitionKeyRangeCache() + + // Manually populate + entry := &pkRangeCacheEntry{ + routingMap: newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "FF"}, + }, "etag1"), + } + cache.mu.Lock() + cache.entries["rid1"] = entry + cache.mu.Unlock() + + // Verify populated + entry.mu.Lock() + require.NotNil(t, entry.routingMap) + entry.mu.Unlock() + + // Invalidate + cache.invalidate("rid1") + + // Verify nil + entry.mu.Lock() + require.Nil(t, entry.routingMap) + entry.mu.Unlock() +} + +func Test_partitionKeyRangeCache_getRoutingMap_cacheHit(t *testing.T) { + cache := newPartitionKeyRangeCache() + + expectedRM := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "05C1"}, + {ID: "1", MinInclusive: "05C1", MaxExclusive: "FF"}, + }, "etag1") + + entry := &pkRangeCacheEntry{routingMap: expectedRM} + cache.mu.Lock() + cache.entries["rid1"] = entry + cache.mu.Unlock() + + // getRoutingMap with a nil client should return cached value without calling service + rm, err := cache.getRoutingMap(context.Background(), "rid1", "dbs/db1/colls/col1", nil) + require.NoError(t, err) + require.Equal(t, expectedRM, rm) +} + +func Test_partitionKeyRangeCache_singleFlight(t *testing.T) { + cache := newPartitionKeyRangeCache() + + expectedRM := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "FF"}, + }, "etag1") + + // Pre-populate entry with nil routingMap (simulates invalidated state) + entry := &pkRangeCacheEntry{routingMap: nil} + cache.mu.Lock() + cache.entries["rid1"] = entry + cache.mu.Unlock() + + // Since we can't mock the HTTP call easily, we'll simulate by manually setting + // the routing map after a "concurrent" lock acquisition. + // This test verifies the per-entry mutex protects single-flight. + var callCount int32 + + // Simulate multiple goroutines trying to get the routing map + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + entry.mu.Lock() + if entry.routingMap == nil { + atomic.AddInt32(&callCount, 1) + entry.routingMap = expectedRM + } + entry.mu.Unlock() + }() + } + wg.Wait() + + // Only one goroutine should have populated the map + require.Equal(t, int32(1), callCount) + require.Equal(t, expectedRM, entry.routingMap) +} + +func Test_partitionKeyRangeCache_entryMutex_noDeadlock(t *testing.T) { + cache := newPartitionKeyRangeCache() + + // Pre-populate an entry and verify we can acquire its mutex + // without deadlocking against the cache-level mutex. + rm := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "FF"}, + }, "etag1") + entry := &pkRangeCacheEntry{routingMap: rm} + cache.mu.Lock() + cache.entries["rid1"] = entry + cache.mu.Unlock() + + entry.mu.Lock() + require.NotNil(t, entry.routingMap) + entry.mu.Unlock() +} + +// createMockClientForPKRangeCache creates a *Client wired to the given mock server +// with both caches initialized, suitable for testing PK range cache refresh flows. +func createMockClientForPKRangeCache(srv *mock.Server) *Client { + defaultEndpoint, _ := url.Parse(srv.URL()) + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + return &Client{ + endpoint: srv.URL(), + endpointUrl: defaultEndpoint, + internal: internalClient, + gem: gem, + caches: &sharedCacheSet{ + pkRangeCache: newPartitionKeyRangeCache(), + containerCache: newContainerPropertiesCache(), + }, + } +} + +func Test_partitionKeyRangeCache_cacheMiss_fullRefresh(t *testing.T) { + // Scenario: Empty cache, getRoutingMap triggers a full refresh from the service. + pkRangeResponse := []byte(`{ + "_rid": "testRID", + "PartitionKeyRanges": [ + {"_rid": "r0", "id": "0", "minInclusive": "", "maxExclusive": "05C1E18D2D7F08", "parents": []}, + {"_rid": "r1", "id": "1", "minInclusive": "05C1E18D2D7F08", "maxExclusive": "FF", "parents": []} + ], + "_count": 2 + }`) + + srv, close := mock.NewTLSServer() + defer close() + + // Container properties response (for getContainerRID) + srv.AppendResponse( + mock.WithBody([]byte(`{"id": "col1", "_rid": "testRID", "partitionKey": {"paths": ["/pk"], "kind": "Hash", "version": 2}}`)), + mock.WithStatusCode(200), + ) + // PK range response (full refresh) + srv.AppendResponse( + mock.WithBody(pkRangeResponse), + mock.WithHeader(cosmosHeaderEtag, "etag1"), + mock.WithStatusCode(200), + ) + + client := createMockClientForPKRangeCache(srv) + database, _ := newDatabase("db1", client) + container, _ := newContainer("col1", database) + + // Use getPartitionKeyRanges which goes through getContainerRID → cache → getRoutingMap + resp, err := container.getPartitionKeyRanges(context.Background(), nil) + require.NoError(t, err) + require.Equal(t, 2, resp.Count) + require.Equal(t, "", resp.PartitionKeyRanges[0].MinInclusive) + require.Equal(t, "05C1E18D2D7F08", resp.PartitionKeyRanges[0].MaxExclusive) + require.Equal(t, "05C1E18D2D7F08", resp.PartitionKeyRanges[1].MinInclusive) + require.Equal(t, "FF", resp.PartitionKeyRanges[1].MaxExclusive) + + // Verify cache is populated + rm, err := client.caches.pkRangeCache.getRoutingMap(context.Background(), "testRID", container.link, client) + require.NoError(t, err) + require.Equal(t, 2, len(rm.orderedRanges)) + require.Equal(t, "etag1", rm.changeFeedETag) +} + +func Test_partitionKeyRangeCache_incrementalRefresh_success(t *testing.T) { + // Scenario: Cache has 1 range with ETag. Server returns 2 child ranges (split), + // then 304. forceRefresh merges them incrementally. + srv, close := mock.NewTLSServer() + defer close() + + // Incremental feed response: 2 children replacing parent "0" + srv.AppendResponse( + mock.WithBody([]byte(`{ + "_rid": "testRID", + "PartitionKeyRanges": [ + {"_rid": "r1", "id": "1", "minInclusive": "", "maxExclusive": "05C1E18D2D7F08", "parents": ["0"]}, + {"_rid": "r2", "id": "2", "minInclusive": "05C1E18D2D7F08", "maxExclusive": "FF", "parents": ["0"]} + ], + "_count": 2 + }`)), + mock.WithHeader(cosmosHeaderEtag, "etag2"), + mock.WithStatusCode(200), + ) + // 304 Not Modified — no more changes + srv.AppendResponse( + mock.WithStatusCode(304), + mock.WithHeader(cosmosHeaderEtag, "etag2"), + ) + + client := createMockClientForPKRangeCache(srv) + + // Pre-populate cache with 1 range + ETag + initialMap := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "FF"}, + }, "etag1") + entry := &pkRangeCacheEntry{routingMap: initialMap} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + // forceRefresh should do incremental refresh + rm, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client) + require.NoError(t, err) + require.Equal(t, 2, len(rm.orderedRanges)) + require.Equal(t, "", rm.orderedRanges[0].MinInclusive) + require.Equal(t, "05C1E18D2D7F08", rm.orderedRanges[0].MaxExclusive) + require.Equal(t, "05C1E18D2D7F08", rm.orderedRanges[1].MinInclusive) + require.Equal(t, "FF", rm.orderedRanges[1].MaxExclusive) + require.Equal(t, "etag2", rm.changeFeedETag) + // Parent "0" should be marked as gone + require.True(t, rm.isGone("0")) +} + +func Test_partitionKeyRangeCache_incrementalRefresh_304_immediate(t *testing.T) { + // Scenario: No changes since last fetch — 304 immediately, map preserved. + srv, close := mock.NewTLSServer() + defer close() + + // 304 Not Modified with updated ETag + srv.AppendResponse( + mock.WithStatusCode(304), + mock.WithHeader(cosmosHeaderEtag, "etag2"), + ) + + client := createMockClientForPKRangeCache(srv) + + // Pre-populate cache with 2 ranges + ETag + initialMap := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "05C1E18D2D7F08"}, + {ID: "1", MinInclusive: "05C1E18D2D7F08", MaxExclusive: "FF"}, + }, "etag1") + entry := &pkRangeCacheEntry{routingMap: initialMap} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + rm, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client) + require.NoError(t, err) + // Ranges should be preserved + require.Equal(t, 2, len(rm.orderedRanges)) + require.Equal(t, "0", rm.orderedRanges[0].ID) + require.Equal(t, "1", rm.orderedRanges[1].ID) + // ETag should be updated + require.Equal(t, "etag2", rm.changeFeedETag) +} + +func Test_partitionKeyRangeCache_incrementalRefresh_mergeFailure_fullRefresh(t *testing.T) { + // Scenario: tryCombine returns nil (incomplete ranges after merge) → falls through to full refresh. + srv, close := mock.NewTLSServer() + defer close() + + // Incremental response: only 1 child range that leaves a gap (parent "0" is gone + // but child only covers half the range → tryCombine fails) + srv.AppendResponse( + mock.WithBody([]byte(`{ + "_rid": "testRID", + "PartitionKeyRanges": [ + {"_rid": "r1", "id": "1", "minInclusive": "", "maxExclusive": "05C1E18D2D7F08", "parents": ["0"]} + ], + "_count": 1 + }`)), + mock.WithHeader(cosmosHeaderEtag, "etag2"), + mock.WithStatusCode(200), + ) + // Full refresh response: complete set of 3 ranges + srv.AppendResponse( + mock.WithBody([]byte(`{ + "_rid": "testRID", + "PartitionKeyRanges": [ + {"_rid": "r1", "id": "1", "minInclusive": "", "maxExclusive": "05C1E18D2D7F08", "parents": []}, + {"_rid": "r2", "id": "2", "minInclusive": "05C1E18D2D7F08", "maxExclusive": "0BC1", "parents": []}, + {"_rid": "r3", "id": "3", "minInclusive": "0BC1", "maxExclusive": "FF", "parents": []} + ], + "_count": 3 + }`)), + mock.WithHeader(cosmosHeaderEtag, "etag3"), + mock.WithStatusCode(200), + ) + + client := createMockClientForPKRangeCache(srv) + + // Pre-populate cache with 1 range spanning the full space + initialMap := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "FF"}, + }, "etag1") + entry := &pkRangeCacheEntry{routingMap: initialMap} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + rm, err := client.caches.pkRangeCache.forceRefresh(context.Background(), "testRID", "dbs/db1/colls/col1", client) + require.NoError(t, err) + // Should have the 3 ranges from full refresh + require.Equal(t, 3, len(rm.orderedRanges)) + require.Equal(t, "1", rm.orderedRanges[0].ID) + require.Equal(t, "2", rm.orderedRanges[1].ID) + require.Equal(t, "3", rm.orderedRanges[2].ID) + require.Equal(t, "etag3", rm.changeFeedETag) +} + +func Test_partitionKeyRangeCache_incrementalRefresh_contextCancelled(t *testing.T) { + // Scenario: Context is cancelled before the HTTP call in the incremental loop. + srv, close := mock.NewTLSServer() + defer close() + + // Queue a response that should never be reached + srv.AppendResponse( + mock.WithBody([]byte(`{"_rid": "testRID", "PartitionKeyRanges": [], "_count": 0}`)), + mock.WithStatusCode(200), + ) + + client := createMockClientForPKRangeCache(srv) + + // Pre-populate cache with a range + ETag so it takes the incremental path + initialMap := newCollectionRoutingMap([]partitionKeyRange{ + {ID: "0", MinInclusive: "", MaxExclusive: "FF"}, + }, "etag1") + entry := &pkRangeCacheEntry{routingMap: initialMap} + client.caches.pkRangeCache.mu.Lock() + client.caches.pkRangeCache.entries["testRID"] = entry + client.caches.pkRangeCache.mu.Unlock() + + // Cancel context before calling forceRefresh + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := client.caches.pkRangeCache.forceRefresh(ctx, "testRID", "dbs/db1/colls/col1", client) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + + // Cache entry should still have the original routing map (unchanged) + entry.mu.Lock() + require.NotNil(t, entry.routingMap) + require.Equal(t, "etag1", entry.routingMap.changeFeedETag) + require.Equal(t, 1, len(entry.routingMap.orderedRanges)) + entry.mu.Unlock() +} diff --git a/sdk/data/azcosmos/cosmos_pk_range_gone_test.go b/sdk/data/azcosmos/cosmos_pk_range_gone_test.go new file mode 100644 index 000000000000..4b909b4af65b --- /dev/null +++ b/sdk/data/azcosmos/cosmos_pk_range_gone_test.go @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/stretchr/testify/require" +) + +func Test_isPartitionKeyRangeGoneError_410WithSplitSubstatus(t *testing.T) { + require.True(t, isPartitionKeyRangeGoneError(http.StatusGone, subStatusPartitionKeyRangeGone)) + require.True(t, isPartitionKeyRangeGoneError(http.StatusGone, subStatusCompletingSplit)) + require.True(t, isPartitionKeyRangeGoneError(http.StatusGone, subStatusCompletingPartitionMigration)) +} + +func Test_isPartitionKeyRangeGoneError_410WithOtherSubstatus(t *testing.T) { + require.False(t, isPartitionKeyRangeGoneError(http.StatusGone, "9999")) + require.False(t, isPartitionKeyRangeGoneError(http.StatusGone, "")) +} + +func Test_isPartitionKeyRangeGoneError_non410(t *testing.T) { + require.False(t, isPartitionKeyRangeGoneError(http.StatusNotFound, subStatusPartitionKeyRangeGone)) + require.False(t, isPartitionKeyRangeGoneError(http.StatusOK, subStatusCompletingSplit)) +} + +func Test_isPKRangeGoneResponseError_nonResponseError(t *testing.T) { + err := errors.New("some random error") + require.False(t, isPKRangeGoneResponseError(err)) +} + +func Test_isPKRangeGoneResponseError_non410ResponseError(t *testing.T) { + err := &azcore.ResponseError{ + StatusCode: http.StatusNotFound, + } + require.False(t, isPKRangeGoneResponseError(err)) +} + +func Test_isPKRangeGoneResponseError_410WithSubstatus(t *testing.T) { + header := http.Header{} + header.Set(cosmosHeaderSubstatus, subStatusCompletingSplit) + resp := &http.Response{ + StatusCode: http.StatusGone, + Header: header, + } + err := &azcore.ResponseError{ + StatusCode: http.StatusGone, + RawResponse: resp, + } + require.True(t, isPKRangeGoneResponseError(err)) +} + +func Test_isPKRangeGoneResponseError_410WithoutSubstatus(t *testing.T) { + header := http.Header{} + header.Set(cosmosHeaderSubstatus, "9999") + resp := &http.Response{ + StatusCode: http.StatusGone, + Header: header, + } + err := &azcore.ResponseError{ + StatusCode: http.StatusGone, + RawResponse: resp, + } + require.False(t, isPKRangeGoneResponseError(err)) +} + +func Test_isPKRangeGoneResponseError_410WithNilRawResponse(t *testing.T) { + err := &azcore.ResponseError{ + StatusCode: http.StatusGone, + RawResponse: nil, + } + // Conservative: any 410 without a raw response is treated as PKRange gone + require.True(t, isPKRangeGoneResponseError(err)) +} + +func Test_hasAnyPKRangeGoneError_findsGoneAmongCancelled(t *testing.T) { + // Simulates the concurrent chunk cancellation scenario: a context.Canceled + // error at a lower index masks the actual 410 at a higher index. + header := http.Header{} + header.Set(cosmosHeaderSubstatus, subStatusCompletingSplit) + results := []chunkResult{ + {err: nil}, + {err: context.Canceled}, + {err: &azcore.ResponseError{ + StatusCode: http.StatusGone, + RawResponse: &http.Response{StatusCode: http.StatusGone, Header: header}, + }}, + } + require.True(t, hasAnyPKRangeGoneError(results)) +} + +func Test_hasAnyPKRangeGoneError_noGone(t *testing.T) { + results := []chunkResult{ + {err: nil}, + {err: context.Canceled}, + {err: errors.New("some other error")}, + } + require.False(t, hasAnyPKRangeGoneError(results)) +} + +func Test_hasAnyPKRangeGoneError_allNil(t *testing.T) { + results := []chunkResult{ + {err: nil}, + {err: nil}, + } + require.False(t, hasAnyPKRangeGoneError(results)) +} diff --git a/sdk/data/azcosmos/internal/epk/epk.go b/sdk/data/azcosmos/internal/epk/epk.go index 7bf4a0455e0a..7f8449fc8786 100644 --- a/sdk/data/azcosmos/internal/epk/epk.go +++ b/sdk/data/azcosmos/internal/epk/epk.go @@ -258,8 +258,9 @@ func ComputeV1(values []interface{}) string { return sb.String() } -// ComputeV2Hash computes the V2 EPK for Hash partitioning. -func ComputeV2Hash(values []interface{}) string { +// computeV2Hash computes the raw V2 EPK for Hash partitioning (without top-bit masking). +// Do not use for partition routing — use ComputeV2HashForRouting instead. +func computeV2Hash(values []interface{}) string { var hashBuf []byte for _, comp := range values { writeForHashingV2(comp, &hashBuf) @@ -269,8 +270,9 @@ func ComputeV2Hash(values []interface{}) string { return hash128ToEPK(low, high) } -// ComputeV2MultiHash computes the V2 EPK for MultiHash partitioning. -func ComputeV2MultiHash(values []interface{}) string { +// computeV2MultiHash computes the raw V2 EPK for MultiHash partitioning (without top-bit masking). +// Do not use for partition routing — use ComputeV2MultiHashForRouting instead. +func computeV2MultiHash(values []interface{}) string { var sb strings.Builder for _, comp := range values { var hashBuf []byte @@ -282,6 +284,27 @@ func ComputeV2MultiHash(values []interface{}) string { return sb.String() } +// ComputeV2HashForRouting computes the V2 EPK for routing (with top-2-bit masking). +// The service uses "FF" as the maximum exclusive partition key range sentinel, +// so all valid EPK values must have their first byte masked to [0x00, 0x3F]. +func ComputeV2HashForRouting(values []interface{}) string { + return maskTopBitsForRouting(computeV2Hash(values)) +} + +// ComputeV2MultiHashForRouting computes the V2 MultiHash EPK for routing. +// Each per-component hash has its top 2 bits masked independently. +func ComputeV2MultiHashForRouting(values []interface{}) string { + var sb strings.Builder + for _, comp := range values { + var hashBuf []byte + writeForHashingV2(comp, &hashBuf) + + low, high := murmurhash3_128(hashBuf, 0, 0) + sb.WriteString(maskTopBitsForRouting(hash128ToEPK(low, high))) + } + return sb.String() +} + // hash128ToEPK converts a 128-bit hash (low, high) to an EPK hex string. // The byte array is [low_LE, high_LE] reversed, producing big-endian order. func hash128ToEPK(low, high uint64) string { @@ -297,6 +320,87 @@ func hash128ToEPK(low, high uint64) string { return toHexUpper(bytes[:]) } +// maskTopBitsForRouting clears the two most significant bits of the first byte +// in an EPK hex string. The Cosmos DB partition key range space uses "FF" as the +// maximum exclusive sentinel, so all valid EPK values must have their first byte +// in the range [0x00, 0x3F]. This matches the behavior of other Cosmos DB SDKs. +func maskTopBitsForRouting(epkHex string) string { + if len(epkHex) < 2 { + return epkHex + } + // Parse the first byte (2 hex chars) + firstByte := hexCharToNibble(epkHex[0])<<4 | hexCharToNibble(epkHex[1]) + firstByte &= 0x3F + // Replace the first two hex characters + return fmt.Sprintf("%02X", firstByte) + epkHex[2:] +} + +func hexCharToNibble(c byte) byte { + switch { + case c >= '0' && c <= '9': + return c - '0' + case c >= 'A' && c <= 'F': + return c - 'A' + 10 + case c >= 'a' && c <= 'f': + return c - 'a' + 10 + default: + return 0 + } +} + +// CompareEPK performs length-aware comparison of two effective partition key hex strings. +// +// For hierarchical partition key (HPK) containers, the service may return +// partition key ranges with mixed-length boundaries (e.g., 32-char partial +// vs 64-char fully specified, zero-padded). This function treats two EPKs +// as equal when one is a prefix of the other and the remainder is all '0' characters. +// +// Returns -1 if a < b, 0 if a == b, +1 if a > b. +// +// See: https://github.com/Azure/azure-cosmos-dotnet-v3/pull/5260 +func CompareEPK(a, b string) int { + commonLen := len(a) + if len(b) < commonLen { + commonLen = len(b) + } + + // Compare the common prefix + prefixA := a[:commonLen] + prefixB := b[:commonLen] + if prefixA < prefixB { + return -1 + } + if prefixA > prefixB { + return 1 + } + + // Common prefixes are equal — check the tail of the longer string + var tail string + if len(a) > len(b) { + tail = a[commonLen:] + } else { + tail = b[commonLen:] + } + + // If the tail is all zeros (or empty), the EPKs are equal + allZeros := true + for i := 0; i < len(tail); i++ { + if tail[i] != '0' { + allZeros = false + break + } + } + if allZeros { + return 0 + } + + // Non-zero tail: the longer string is greater + if len(a) > len(b) { + return 1 + } + return -1 +} + // toHexUpper returns uppercase hex encoding of data with no separators. func toHexUpper(data []byte) string { var sb strings.Builder diff --git a/sdk/data/azcosmos/internal/epk/epk_test.go b/sdk/data/azcosmos/internal/epk/epk_test.go index 6d6d07c8163e..c73b70dab8c0 100644 --- a/sdk/data/azcosmos/internal/epk/epk_test.go +++ b/sdk/data/azcosmos/internal/epk/epk_test.go @@ -125,9 +125,9 @@ func TestComputeEPK_Baseline(t *testing.T) { t.Run(category+"/V2/"+tc.Input.Description, func(t *testing.T) { var actual string if spec.multiHash { - actual = ComputeV2MultiHash(values) + actual = computeV2MultiHash(values) } else { - actual = ComputeV2Hash(values) + actual = computeV2Hash(values) } require.Equal(t, expectedV2, actual, "V2 hash mismatch for %s (value: %s)", tc.Input.Description, tc.Input.PartitionKeyValue) @@ -135,3 +135,81 @@ func TestComputeEPK_Baseline(t *testing.T) { } } } + +func TestCompareEPK_EqualSameLength(t *testing.T) { + require.Equal(t, 0, CompareEPK("06AB34CFE4E48223", "06AB34CFE4E48223")) +} + +func TestCompareEPK_LessThan(t *testing.T) { + require.Equal(t, -1, CompareEPK("06AB34CFE4E48223", "06AB34CFE4E48224")) +} + +func TestCompareEPK_GreaterThan(t *testing.T) { + require.Equal(t, 1, CompareEPK("06AB34CFE4E48224", "06AB34CFE4E48223")) +} + +func TestCompareEPK_ZeroPaddedTailEqual(t *testing.T) { + // A 32-char partial EPK should equal its 64-char zero-padded equivalent + partial := "06AB34CFE4E482236BCACBBF50E234AB" + full := "06AB34CFE4E482236BCACBBF50E234AB00000000000000000000000000000000" + require.Equal(t, 0, CompareEPK(partial, full)) + require.Equal(t, 0, CompareEPK(full, partial)) +} + +func TestCompareEPK_NonZeroTailNotEqual(t *testing.T) { + partial := "06AB34CFE4E482236BCACBBF50E234AB" + full := "06AB34CFE4E482236BCACBBF50E234AB00000000000000000000000000000001" + require.Equal(t, -1, CompareEPK(partial, full)) + require.Equal(t, 1, CompareEPK(full, partial)) +} + +func TestCompareEPK_EmptyStrings(t *testing.T) { + require.Equal(t, 0, CompareEPK("", "")) + require.Equal(t, 0, CompareEPK("", "00000")) + require.Equal(t, -1, CompareEPK("", "00001")) +} + +func TestCompareEPK_FFSentinel(t *testing.T) { + // "FF" should be greater than any masked EPK (first hex digit in [0-3]) + require.Equal(t, 1, CompareEPK("FF", "3FFFFFFFFFFFFFFF")) + require.Equal(t, -1, CompareEPK("3FFFFFFFFFFFFFFF", "FF")) +} + +func TestMaskTopBitsForRouting(t *testing.T) { + // Already valid (first byte ≤ 0x3F) — unchanged + require.Equal(t, "3FAABBCC", maskTopBitsForRouting("3FAABBCC")) + // 0xFF & 0x3F = 0x3F + require.Equal(t, "3FAABBCC", maskTopBitsForRouting("FFAABBCC")) + // 0xC0 & 0x3F = 0x00 + require.Equal(t, "00AABBCC", maskTopBitsForRouting("C0AABBCC")) + // 0x80 & 0x3F = 0x00 + require.Equal(t, "00112233", maskTopBitsForRouting("80112233")) + // 0x40 & 0x3F = 0x00 + require.Equal(t, "00112233", maskTopBitsForRouting("40112233")) + // Edge: empty string + require.Equal(t, "", maskTopBitsForRouting("")) + // Edge: single char + require.Equal(t, "A", maskTopBitsForRouting("A")) +} + +func TestComputeV2HashForRouting_MaskingApplied(t *testing.T) { + // null produces a raw V2 hash with first byte 0x77, which should be masked to 0x37 + result := ComputeV2HashForRouting([]interface{}{nil}) + require.True(t, len(result) >= 2, "result should be at least 2 hex chars") + // After masking, the first hex digit must be in [0-3] + firstDigit := result[0] + require.True(t, firstDigit >= '0' && firstDigit <= '3', + "first hex digit should be in [0-3] after masking, got %c", firstDigit) +} + +func TestComputeV2MultiHashForRouting_MaskingApplied(t *testing.T) { + // Each 32-char component should have its first byte masked independently + result := ComputeV2MultiHashForRouting([]interface{}{"hello", "world"}) + require.Equal(t, 64, len(result), "two components should produce 64 hex chars") + // Check first byte of each 32-char component is in [0x00, 0x3F] + for i := 0; i < 2; i++ { + firstDigit := result[i*32] + require.True(t, firstDigit >= '0' && firstDigit <= '3', + "component %d first hex digit should be in [0-3] after masking, got %c", i, firstDigit) + } +} diff --git a/sdk/data/azcosmos/partition_key.go b/sdk/data/azcosmos/partition_key.go index feb9056bcba7..1c6530330a5c 100644 --- a/sdk/data/azcosmos/partition_key.go +++ b/sdk/data/azcosmos/partition_key.go @@ -120,10 +120,71 @@ func (pk *PartitionKey) computeEffectivePartitionKey(kind PartitionKeyKind, vers case version == 1: epkStr = epk.ComputeV1(values) case kind == PartitionKeyKindMultiHash: - epkStr = epk.ComputeV2MultiHash(values) + epkStr = epk.ComputeV2MultiHashForRouting(values) default: - epkStr = epk.ComputeV2Hash(values) + epkStr = epk.ComputeV2HashForRouting(values) } return epk.EffectivePartitionKey{EPK: epkStr} } + +// epkRange represents an effective partition key range for routing. +// For a point key, Min == Max (both equal to the EPK). +// For a prefix key on a MultiHash container, Min is the prefix EPK and +// Max is prefix EPK + "FF" (an exclusive upper bound). +type epkRange struct { + Min string // inclusive + Max string // exclusive (empty means same as Min, i.e., point) +} + +// isRange returns true if this represents a range (prefix key) rather than a point. +func (r epkRange) isRange() bool { + return r.Max != "" && r.Min != r.Max +} + +// computeEPKRange computes the EPK range for a partition key given the container's +// partition key definition. For full keys it returns a point range. For prefix keys +// on MultiHash containers it returns a range [prefix_epk, prefix_epk+"FF"). +// Non-MultiHash containers require exactly the right number of components. +func computeEPKRange(pk *PartitionKey, pkDef PartitionKeyDefinition) (epkRange, error) { + pkVersion := pkDef.Version + if pkVersion == 0 { + pkVersion = 1 + } + + // Undefined PK (no components) is a concrete value, not a prefix. + // It hashes to a single deterministic EPK regardless of the number of + // definition paths, so always return a point range. + if len(pk.values) == 0 { + epkVal := pk.computeEffectivePartitionKey(pkDef.Kind, pkVersion) + return epkRange{Min: epkVal.EPK, Max: epkVal.EPK}, nil + } + + componentCount := len(pk.values) + pathCount := len(pkDef.Paths) + + if componentCount > pathCount { + return epkRange{}, fmt.Errorf("more partition key components (%d) than definition paths (%d)", componentCount, pathCount) + } + + if pkDef.Kind != PartitionKeyKindMultiHash && componentCount != pathCount { + return epkRange{}, fmt.Errorf("non-MultiHash containers require exactly %d components, got %d", pathCount, componentCount) + } + + epkVal := pk.computeEffectivePartitionKey(pkDef.Kind, pkVersion) + + isPrefix := pkDef.Kind == PartitionKeyKindMultiHash && componentCount < pathCount + if isPrefix { + // "FF" is safe as an upper-bound sentinel because maskTopBitsForRouting + // ensures every EPK component's first hex digit is in [0-3]. + return epkRange{ + Min: epkVal.EPK, + Max: epkVal.EPK + "FF", + }, nil + } + + return epkRange{ + Min: epkVal.EPK, + Max: epkVal.EPK, + }, nil +} diff --git a/sdk/data/azcosmos/partition_key_test.go b/sdk/data/azcosmos/partition_key_test.go index 371e1d836b43..883cfdbfc861 100644 --- a/sdk/data/azcosmos/partition_key_test.go +++ b/sdk/data/azcosmos/partition_key_test.go @@ -115,14 +115,14 @@ func TestComputeEffectivePartitionKey(t *testing.T) { result := pk.computeEffectivePartitionKey(PartitionKeyKindHash, 1) require.Equal(t, "000000000000000000000000FF69187F", result.EPK) - // V2 Hash: null → known hash + // V2 Hash: null → known hash (top 2 bits masked for routing: 0x77 & 0x3F = 0x37) result = NullPartitionKey.computeEffectivePartitionKey(PartitionKeyKindHash, 2) - require.Equal(t, "778867E4430E67857ACE5C908374FE16", result.EPK) + require.Equal(t, "378867E4430E67857ACE5C908374FE16", result.EPK) - // V2 MultiHash: ["a", "b"] → per-component hashes concatenated + // V2 MultiHash: ["a", "b"] → per-component hashes concatenated (each masked) multiPK := NewPartitionKey().AppendString("a").AppendString("b") result = multiPK.computeEffectivePartitionKey(PartitionKeyKindMultiHash, 2) - require.Equal(t, "FA5381E1114EB8D3FCC90795045B49B7D95644569A78B1E22D200348AF9416CE", result.EPK) + require.Equal(t, "3A5381E1114EB8D3FCC90795045B49B7195644569A78B1E22D200348AF9416CE", result.EPK) // Undefined partition key emptyPK := NewPartitionKey()