diff --git a/sdk/data/azcosmos/CHANGELOG.md b/sdk/data/azcosmos/CHANGELOG.md index 724526725eb6..5c33379ef3e6 100644 --- a/sdk/data/azcosmos/CHANGELOG.md +++ b/sdk/data/azcosmos/CHANGELOG.md @@ -4,6 +4,7 @@ ### Features Added +* Adds `PriorityLevel` and `ThroughputBucket` options at the client and per-request level for item, query, change-feed, batch, and read-many operations. See [PR 26750](https://github.com/Azure/azure-sdk-for-go/pull/26750) * Added client-level partition key range cache and container properties cache, reducing redundant metadata round-trips for ReadMany and query operations. See [PR 26723](https://github.com/Azure/azure-sdk-for-go/pull/26723) * Added operation diagnostics on responses and `DiagnosticsFromError` for retrieving diagnostics from failed operations. See [PR 26548](https://github.com/Azure/azure-sdk-for-go/pull/26548) diff --git a/sdk/data/azcosmos/cosmos_change_feed_request_options.go b/sdk/data/azcosmos/cosmos_change_feed_request_options.go index 3c54d877f522..a7ec55823469 100644 --- a/sdk/data/azcosmos/cosmos_change_feed_request_options.go +++ b/sdk/data/azcosmos/cosmos_change_feed_request_options.go @@ -29,6 +29,14 @@ type ChangeFeedOptions struct { // CompositeContinuation is used to continue reading the change feed from a specific point. Continuation *string + + // PriorityLevel overrides the client-level default priority for this operation. + // Valid values are PriorityLevelHigh and PriorityLevelLow. + PriorityLevel *PriorityLevel + // ThroughputBucket overrides the client-level default throughput bucket for this operation. + // For more information, see https://aka.ms/CosmosDB/ThroughputBuckets + // The valid range is 1 to 5 (inclusive). + ThroughputBucket *int32 } func (options *ChangeFeedOptions) toHeaders(partitionKeyRanges []partitionKeyRange) *map[string]string { diff --git a/sdk/data/azcosmos/cosmos_client.go b/sdk/data/azcosmos/cosmos_client.go index de94f76080a0..a160924e08c5 100644 --- a/sdk/data/azcosmos/cosmos_client.go +++ b/sdk/data/azcosmos/cosmos_client.go @@ -190,6 +190,8 @@ func newClient(authPolicy policy.Policy, gem *globalEndpointManager, options *Cl PerCall: []policy.Policy{ &headerPolicies{ enableContentResponseOnWrite: options.EnableContentResponseOnWrite, + priorityLevel: options.PriorityLevel, + throughputBucket: options.ThroughputBucket, }, &globalEndpointManagerPolicy{gem: gem}, }, @@ -692,5 +694,7 @@ func getAllowedHeaders() []string { cosmosHeaderIsPartitionKeyDeletePending, cosmosHeaderQueryExecutionInfo, headerXmsItemCount, + cosmosHeaderPriorityLevel, + cosmosHeaderThroughputBucket, } } diff --git a/sdk/data/azcosmos/cosmos_client_options.go b/sdk/data/azcosmos/cosmos_client_options.go index 882f29843b68..33b88bbc26f3 100644 --- a/sdk/data/azcosmos/cosmos_client_options.go +++ b/sdk/data/azcosmos/cosmos_client_options.go @@ -15,4 +15,14 @@ type ClientOptions struct { EnableContentResponseOnWrite bool // PreferredRegions is a list of regions to be used when initializing the client in case the default region fails. PreferredRegions []string + // PriorityLevel defines the default priority level for all requests made by this client. + // This feature is currently in preview. For more information, see https://aka.ms/CosmosDB/PriorityBasedExecution + // Valid values are PriorityLevelHigh and PriorityLevelLow. + // Can be overridden per-request via the operation options. + PriorityLevel *PriorityLevel + // ThroughputBucket defines the default throughput bucket for all requests made by this client. + // This feature is currently in preview. For more information, see https://aka.ms/CosmosDB/ThroughputBuckets + // The valid range is 1 to 5 (inclusive). + // Can be overridden per-request via the operation options. + ThroughputBucket *int32 } diff --git a/sdk/data/azcosmos/cosmos_container.go b/sdk/data/azcosmos/cosmos_container.go index 0627e73e8855..7f7cbd96a55d 100644 --- a/sdk/data/azcosmos/cosmos_container.go +++ b/sdk/data/azcosmos/cosmos_container.go @@ -278,6 +278,8 @@ func (c *ContainerClient) CreateItem( o = &ItemOptions{} } else { h.enableContentResponseOnWrite = &o.EnableContentResponseOnWrite + h.priorityLevel = o.PriorityLevel + h.throughputBucket = o.ThroughputBucket } operationContext := pipelineRequestOptions{ @@ -335,6 +337,8 @@ func (c *ContainerClient) UpsertItem( o = &ItemOptions{} } else { h.enableContentResponseOnWrite = &o.EnableContentResponseOnWrite + h.priorityLevel = o.PriorityLevel + h.throughputBucket = o.ThroughputBucket } operationContext := pipelineRequestOptions{ @@ -390,6 +394,8 @@ func (c *ContainerClient) ReplaceItem( o = &ItemOptions{} } else { h.enableContentResponseOnWrite = &o.EnableContentResponseOnWrite + h.priorityLevel = o.PriorityLevel + h.throughputBucket = o.ThroughputBucket } operationContext := pipelineRequestOptions{ @@ -442,6 +448,8 @@ func (c *ContainerClient) ReadItem( if o == nil { o = &ItemOptions{} } + h.priorityLevel = o.PriorityLevel + h.throughputBucket = o.ThroughputBucket operationContext := pipelineRequestOptions{ resourceType: resourceTypeDocument, @@ -493,9 +501,15 @@ func (c *ContainerClient) ReadManyItems( readManyOptions = &originalOptions } + h := headerOptionsOverride{ + priorityLevel: readManyOptions.PriorityLevel, + throughputBucket: readManyOptions.ThroughputBucket, + } + operationContext := pipelineRequestOptions{ - resourceType: resourceTypeDocument, - resourceAddress: c.link, + resourceType: resourceTypeDocument, + resourceAddress: c.link, + headerOptionsOverride: &h, } ctx, endTrace := ensureOperationTrace(ctx, fmt.Sprintf("read_many_items %s", c.id)) @@ -556,6 +570,8 @@ func (c *ContainerClient) DeleteItem( o = &ItemOptions{} } else { h.enableContentResponseOnWrite = &o.EnableContentResponseOnWrite + h.priorityLevel = o.PriorityLevel + h.throughputBucket = o.ThroughputBucket } operationContext := pipelineRequestOptions{ @@ -614,6 +630,8 @@ func (c *ContainerClient) NewQueryItemsPager(query string, partitionKey Partitio if o != nil { originalOptions := *o queryOptions = &originalOptions + h.priorityLevel = o.PriorityLevel + h.throughputBucket = o.ThroughputBucket } operationContext := pipelineRequestOptions{ @@ -695,6 +713,8 @@ func (c *ContainerClient) PatchItem( o = &ItemOptions{} } else { h.enableContentResponseOnWrite = &o.EnableContentResponseOnWrite + h.priorityLevel = o.PriorityLevel + h.throughputBucket = o.ThroughputBucket } operationContext := pipelineRequestOptions{ @@ -751,6 +771,8 @@ func (c *ContainerClient) ExecuteTransactionalBatch(ctx context.Context, b Trans o = &TransactionalBatchOptions{} } else { h.enableContentResponseOnWrite = &o.EnableContentResponseOnWrite + h.priorityLevel = o.PriorityLevel + h.throughputBucket = o.ThroughputBucket } // If contentResponseOnWrite is not enabled at the client level the @@ -857,9 +879,15 @@ func (c *ContainerClient) getChangeFeedForEPKRange( } } + h := headerOptionsOverride{ + priorityLevel: options.PriorityLevel, + throughputBucket: options.ThroughputBucket, + } + operationContext := pipelineRequestOptions{ - resourceType: resourceTypeDocument, - resourceAddress: c.link, + resourceType: resourceTypeDocument, + resourceAddress: c.link, + headerOptionsOverride: &h, } path, err := generatePathForNameBased(resourceTypeDocument, operationContext.resourceAddress, true) diff --git a/sdk/data/azcosmos/cosmos_container_test.go b/sdk/data/azcosmos/cosmos_container_test.go index 9663ff50b20e..5760351de29f 100644 --- a/sdk/data/azcosmos/cosmos_container_test.go +++ b/sdk/data/azcosmos/cosmos_container_test.go @@ -1210,3 +1210,503 @@ func TestContainerGetChangeFeedForEPKRange(t *testing.T) { h[cosmosHeaderChangeFeed], cosmosHeaderValuesChangeFeed) } } + +func TestCreateItemPriorityAndThroughputBucketHeaders(t *testing.T) { + jsonString := []byte(`{"id":"doc1","foo":"bar"}`) + srv, close := mock.NewTLSServer() + defaultEndpoint, _ := url.Parse(srv.URL()) + defer close() + srv.SetResponse( + mock.WithBody(jsonString), + mock.WithHeader(cosmosHeaderEtag, "someEtag"), + mock.WithHeader(cosmosHeaderActivityId, "someActivityId"), + mock.WithHeader(cosmosHeaderRequestCharge, "13.42"), + mock.WithStatusCode(200)) + + verifier := pipelineVerifier{} + clientPriority := PriorityLevelHigh + clientBucket := int32(2) + headerPolicy := &headerPolicies{ + priorityLevel: &clientPriority, + throughputBucket: &clientBucket, + } + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + // Test with per-request override + requestPriority := PriorityLevelLow + requestBucket := int32(5) + _, err := container.CreateItem(context.TODO(), NewPartitionKeyString("1"), jsonString, &ItemOptions{ + PriorityLevel: &requestPriority, + ThroughputBucket: &requestBucket, + }) + if err != nil { + t.Fatalf("Failed to create item: %v", err) + } + + h := verifier.requests[0].headers + if h.Get(cosmosHeaderPriorityLevel) != "Low" { + t.Errorf("Expected priority level header to be Low, got %v", h.Get(cosmosHeaderPriorityLevel)) + } + if h.Get(cosmosHeaderThroughputBucket) != "5" { + t.Errorf("Expected throughput bucket header to be 5, got %v", h.Get(cosmosHeaderThroughputBucket)) + } +} + +func TestUpsertItemPriorityAndThroughputBucketHeaders(t *testing.T) { + jsonString := []byte(`{"id":"doc1","foo":"bar"}`) + srv, close := mock.NewTLSServer() + defaultEndpoint, _ := url.Parse(srv.URL()) + defer close() + srv.SetResponse( + mock.WithBody(jsonString), + mock.WithHeader(cosmosHeaderEtag, "someEtag"), + mock.WithHeader(cosmosHeaderActivityId, "someActivityId"), + mock.WithHeader(cosmosHeaderRequestCharge, "13.42"), + mock.WithStatusCode(200)) + + verifier := pipelineVerifier{} + headerPolicy := &headerPolicies{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + requestPriority := PriorityLevelLow + requestBucket := int32(3) + _, err := container.UpsertItem(context.TODO(), NewPartitionKeyString("1"), jsonString, &ItemOptions{ + PriorityLevel: &requestPriority, + ThroughputBucket: &requestBucket, + }) + if err != nil { + t.Fatalf("Failed to upsert item: %v", err) + } + + h := verifier.requests[0].headers + if h.Get(cosmosHeaderPriorityLevel) != "Low" { + t.Errorf("Expected priority level header to be Low, got %v", h.Get(cosmosHeaderPriorityLevel)) + } + if h.Get(cosmosHeaderThroughputBucket) != "3" { + t.Errorf("Expected throughput bucket header to be 3, got %v", h.Get(cosmosHeaderThroughputBucket)) + } +} + +func TestReplaceItemPriorityAndThroughputBucketHeaders(t *testing.T) { + jsonString := []byte(`{"id":"doc1","foo":"bar"}`) + srv, close := mock.NewTLSServer() + defaultEndpoint, _ := url.Parse(srv.URL()) + defer close() + srv.SetResponse( + mock.WithBody(jsonString), + mock.WithHeader(cosmosHeaderEtag, "someEtag"), + mock.WithHeader(cosmosHeaderActivityId, "someActivityId"), + mock.WithHeader(cosmosHeaderRequestCharge, "13.42"), + mock.WithStatusCode(200)) + + verifier := pipelineVerifier{} + headerPolicy := &headerPolicies{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + requestPriority := PriorityLevelHigh + requestBucket := int32(2) + _, err := container.ReplaceItem(context.TODO(), NewPartitionKeyString("1"), "doc1", jsonString, &ItemOptions{ + PriorityLevel: &requestPriority, + ThroughputBucket: &requestBucket, + }) + if err != nil { + t.Fatalf("Failed to replace item: %v", err) + } + + h := verifier.requests[0].headers + if h.Get(cosmosHeaderPriorityLevel) != "High" { + t.Errorf("Expected priority level header to be High, got %v", h.Get(cosmosHeaderPriorityLevel)) + } + if h.Get(cosmosHeaderThroughputBucket) != "2" { + t.Errorf("Expected throughput bucket header to be 2, got %v", h.Get(cosmosHeaderThroughputBucket)) + } +} + +func TestReadItemPriorityAndThroughputBucketHeaders(t *testing.T) { + jsonString := []byte(`{"id":"doc1","foo":"bar"}`) + srv, close := mock.NewTLSServer() + defaultEndpoint, _ := url.Parse(srv.URL()) + defer close() + srv.SetResponse( + mock.WithBody(jsonString), + mock.WithHeader(cosmosHeaderEtag, "someEtag"), + mock.WithHeader(cosmosHeaderActivityId, "someActivityId"), + mock.WithHeader(cosmosHeaderRequestCharge, "13.42"), + mock.WithStatusCode(200)) + + verifier := pipelineVerifier{} + headerPolicy := &headerPolicies{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + requestPriority := PriorityLevelLow + requestBucket := int32(4) + _, err := container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", &ItemOptions{ + PriorityLevel: &requestPriority, + ThroughputBucket: &requestBucket, + }) + if err != nil { + t.Fatalf("Failed to read item: %v", err) + } + + h := verifier.requests[0].headers + if h.Get(cosmosHeaderPriorityLevel) != "Low" { + t.Errorf("Expected priority level header to be Low, got %v", h.Get(cosmosHeaderPriorityLevel)) + } + if h.Get(cosmosHeaderThroughputBucket) != "4" { + t.Errorf("Expected throughput bucket header to be 4, got %v", h.Get(cosmosHeaderThroughputBucket)) + } +} + +func TestDeleteItemPriorityAndThroughputBucketHeaders(t *testing.T) { + srv, close := mock.NewTLSServer() + defaultEndpoint, _ := url.Parse(srv.URL()) + defer close() + srv.SetResponse( + mock.WithHeader(cosmosHeaderEtag, "someEtag"), + mock.WithHeader(cosmosHeaderActivityId, "someActivityId"), + mock.WithHeader(cosmosHeaderRequestCharge, "13.42"), + mock.WithStatusCode(204)) + + verifier := pipelineVerifier{} + headerPolicy := &headerPolicies{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + requestPriority := PriorityLevelHigh + requestBucket := int32(1) + _, err := container.DeleteItem(context.TODO(), NewPartitionKeyString("1"), "doc1", &ItemOptions{ + PriorityLevel: &requestPriority, + ThroughputBucket: &requestBucket, + }) + if err != nil { + t.Fatalf("Failed to delete item: %v", err) + } + + h := verifier.requests[0].headers + if h.Get(cosmosHeaderPriorityLevel) != "High" { + t.Errorf("Expected priority level header to be High, got %v", h.Get(cosmosHeaderPriorityLevel)) + } + if h.Get(cosmosHeaderThroughputBucket) != "1" { + t.Errorf("Expected throughput bucket header to be 1, got %v", h.Get(cosmosHeaderThroughputBucket)) + } +} + +func TestPatchItemPriorityAndThroughputBucketHeaders(t *testing.T) { + jsonString := []byte(`{"id":"doc1","foo":"bar","hello":"world"}`) + srv, close := mock.NewTLSServer() + defaultEndpoint, _ := url.Parse(srv.URL()) + defer close() + srv.SetResponse( + mock.WithBody(jsonString), + mock.WithHeader(cosmosHeaderEtag, "someEtag"), + mock.WithHeader(cosmosHeaderActivityId, "someActivityId"), + mock.WithHeader(cosmosHeaderRequestCharge, "13.42"), + mock.WithStatusCode(200)) + + verifier := pipelineVerifier{} + headerPolicy := &headerPolicies{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + patchOpt := PatchOperations{} + patchOpt.AppendSet("/hello", "world") + + requestPriority := PriorityLevelLow + requestBucket := int32(5) + _, err := container.PatchItem(context.TODO(), NewPartitionKeyString("1"), "doc1", patchOpt, &ItemOptions{ + PriorityLevel: &requestPriority, + ThroughputBucket: &requestBucket, + }) + if err != nil { + t.Fatalf("Failed to patch item: %v", err) + } + + h := verifier.requests[0].headers + if h.Get(cosmosHeaderPriorityLevel) != "Low" { + t.Errorf("Expected priority level header to be Low, got %v", h.Get(cosmosHeaderPriorityLevel)) + } + if h.Get(cosmosHeaderThroughputBucket) != "5" { + t.Errorf("Expected throughput bucket header to be 5, got %v", h.Get(cosmosHeaderThroughputBucket)) + } +} + +func TestQueryItemsPriorityAndThroughputBucketHeaders(t *testing.T) { + jsonStringpage1 := []byte(`{"Documents":[{"id":"doc1"}]}`) + srv, close := mock.NewTLSServer() + defaultEndpoint, _ := url.Parse(srv.URL()) + defer close() + srv.SetResponse( + mock.WithBody(jsonStringpage1), + mock.WithHeader(cosmosHeaderEtag, "someEtag"), + mock.WithHeader(cosmosHeaderActivityId, "someActivityId"), + mock.WithHeader(cosmosHeaderRequestCharge, "13.42"), + mock.WithStatusCode(200)) + + verifier := pipelineVerifier{} + headerPolicy := &headerPolicies{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + requestPriority := PriorityLevelHigh + requestBucket := int32(3) + queryPager := container.NewQueryItemsPager("select * from c", NewPartitionKeyString("1"), &QueryOptions{ + PriorityLevel: &requestPriority, + ThroughputBucket: &requestBucket, + }) + _, err := queryPager.NextPage(context.TODO()) + if err != nil { + t.Fatalf("Failed to query items: %v", err) + } + + h := verifier.requests[0].headers + if h.Get(cosmosHeaderPriorityLevel) != "High" { + t.Errorf("Expected priority level header to be High, got %v", h.Get(cosmosHeaderPriorityLevel)) + } + if h.Get(cosmosHeaderThroughputBucket) != "3" { + t.Errorf("Expected throughput bucket header to be 3, got %v", h.Get(cosmosHeaderThroughputBucket)) + } +} + +func TestTransactionalBatchPriorityAndThroughputBucketHeaders(t *testing.T) { + batchResponseRaw := []map[string]interface{}{ + {"statusCode": 200, "requestCharge": 10.0, "eTag": "someETag", "resourceBody": "someBody"}, + } + jsonString, _ := json.Marshal(batchResponseRaw) + + srv, close := mock.NewTLSServer() + defaultEndpoint, _ := url.Parse(srv.URL()) + defer close() + srv.SetResponse( + mock.WithBody(jsonString), + mock.WithStatusCode(http.StatusOK), + mock.WithHeader(cosmosHeaderEtag, "someEtag"), + mock.WithHeader(cosmosHeaderActivityId, "someActivityId"), + mock.WithHeader(cosmosHeaderRequestCharge, "13.42")) + + verifier := pipelineVerifier{} + headerPolicy := &headerPolicies{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + pk := NewPartitionKeyString("pk") + batch := container.NewTransactionalBatch(pk) + batch.ReadItem("someId", nil) + + requestPriority := PriorityLevelLow + requestBucket := int32(1) + _, err := container.ExecuteTransactionalBatch(context.TODO(), batch, &TransactionalBatchOptions{ + PriorityLevel: &requestPriority, + ThroughputBucket: &requestBucket, + }) + if err != nil { + t.Fatalf("Failed to execute batch: %v", err) + } + + h := verifier.requests[0].headers + if h.Get(cosmosHeaderPriorityLevel) != "Low" { + t.Errorf("Expected priority level header to be Low, got %v", h.Get(cosmosHeaderPriorityLevel)) + } + if h.Get(cosmosHeaderThroughputBucket) != "1" { + t.Errorf("Expected throughput bucket header to be 1, got %v", h.Get(cosmosHeaderThroughputBucket)) + } +} + +func TestChangeFeedPriorityAndThroughputBucketHeaders(t *testing.T) { + changeFeedBody := []byte(`{ + "_rid": "test-resource-id", + "Documents": [{"id": "doc1"}], + "_count": 1 + }`) + + pkRangesBody := []byte(`{ + "_rid": "test-resource-id", + "PartitionKeyRanges": [{ + "_rid": "range-rid", + "id": "0", + "minInclusive": "00", + "maxExclusive": "FF" + }], + "_count": 1 + }`) + + srv, close := mock.NewTLSServer() + defaultEndpoint, _ := url.Parse(srv.URL()) + defer close() + + srv.AppendResponse( + mock.WithBody(pkRangesBody), + mock.WithHeader(cosmosHeaderActivityId, "pkRangesActivityId"), + mock.WithHeader(cosmosHeaderRequestCharge, "1.0"), + mock.WithStatusCode(200)) + srv.AppendResponse( + mock.WithBody(changeFeedBody), + mock.WithHeader(cosmosHeaderEtag, "\"etag-12345\""), + mock.WithHeader(cosmosHeaderActivityId, "changeFeedActivityId"), + mock.WithHeader(cosmosHeaderRequestCharge, "3.5"), + mock.WithStatusCode(200)) + + verifier := pipelineVerifier{} + headerPolicy := &headerPolicies{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + requestPriority := PriorityLevelHigh + requestBucket := int32(4) + feedRange := &FeedRange{ + MinInclusive: "00", + MaxExclusive: "FF", + } + _, err := container.GetChangeFeed(context.TODO(), &ChangeFeedOptions{ + MaxItemCount: 10, + FeedRange: feedRange, + PriorityLevel: &requestPriority, + ThroughputBucket: &requestBucket, + }) + if err != nil { + t.Fatalf("GetChangeFeed failed: %v", err) + } + + // The second request is the change feed request (first is pk ranges) + h := verifier.requests[1].headers + if h.Get(cosmosHeaderPriorityLevel) != "High" { + t.Errorf("Expected priority level header to be High, got %v", h.Get(cosmosHeaderPriorityLevel)) + } + if h.Get(cosmosHeaderThroughputBucket) != "4" { + t.Errorf("Expected throughput bucket header to be 4, got %v", h.Get(cosmosHeaderThroughputBucket)) + } +} + +func TestReadManyItemsPriorityAndThroughputBucketHeaders(t *testing.T) { + containerBody, _ := json.Marshal(ContainerProperties{ + ID: "containerId", + PartitionKeyDefinition: PartitionKeyDefinition{ + Paths: []string{"/pk"}, + }, + }) + + rangeBody, _ := json.Marshal(struct { + PartitionKeyRanges []partitionKeyRange `json:"PartitionKeyRanges"` + Count int `json:"_count"` + }{ + PartitionKeyRanges: []partitionKeyRange{ + { + ID: "0", + MinInclusive: "", + MaxExclusive: "", + }, + }, + Count: 1, + }) + + queryBody, _ := json.Marshal(map[string][]map[string]string{ + "Documents": { + {"id": "item1", "pk": "pk1"}, + }, + }) + + srv, close := mock.NewTLSServer() + defaultEndpoint, _ := url.Parse(srv.URL()) + defer close() + + // ReadMany needs: container properties, partition key ranges, query result + srv.AppendResponse( + mock.WithBody(containerBody), + mock.WithStatusCode(http.StatusOK), + mock.WithHeader(cosmosHeaderActivityId, "container-read"), + ) + srv.AppendResponse( + mock.WithBody(rangeBody), + mock.WithStatusCode(http.StatusOK), + mock.WithHeader(cosmosHeaderActivityId, "range-read"), + ) + srv.AppendResponse( + mock.WithBody(queryBody), + mock.WithStatusCode(http.StatusOK), + mock.WithHeader(cosmosHeaderActivityId, "query-read"), + mock.WithHeader(cosmosHeaderRequestCharge, "1.5"), + ) + + verifier := pipelineVerifier{} + headerPolicy := &headerPolicies{} + + internalClient, _ := azcore.NewClient("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + gem := &globalEndpointManager{preferredLocations: []string{}} + client := &Client{endpoint: srv.URL(), endpointUrl: defaultEndpoint, internal: internalClient, gem: gem} + database, _ := newDatabase("databaseId", client) + container, _ := newContainer("containerId", database) + + requestPriority := PriorityLevelLow + requestBucket := int32(2) + _, err := container.ReadManyItems(context.Background(), []ItemIdentity{ + { + ID: "item1", + PartitionKey: NewPartitionKeyString("pk1"), + }, + }, &ReadManyOptions{ + PriorityLevel: &requestPriority, + ThroughputBucket: &requestBucket, + }) + if err != nil { + t.Fatalf("ReadManyItems failed: %v", err) + } + + // The third request is the query (first is container read, second is pk ranges) + if len(verifier.requests) < 3 { + t.Fatalf("Expected at least 3 requests, got %d", len(verifier.requests)) + } + h := verifier.requests[2].headers + if h.Get(cosmosHeaderPriorityLevel) != "Low" { + t.Errorf("Expected priority level header to be Low, got %v", h.Get(cosmosHeaderPriorityLevel)) + } + if h.Get(cosmosHeaderThroughputBucket) != "2" { + t.Errorf("Expected throughput bucket header to be 2, got %v", h.Get(cosmosHeaderThroughputBucket)) + } +} diff --git a/sdk/data/azcosmos/cosmos_headers_policy.go b/sdk/data/azcosmos/cosmos_headers_policy.go index 2315dc0b4f50..3024a5c6c627 100644 --- a/sdk/data/azcosmos/cosmos_headers_policy.go +++ b/sdk/data/azcosmos/cosmos_headers_policy.go @@ -5,6 +5,7 @@ package azcosmos import ( "net/http" + "strconv" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" @@ -12,18 +13,24 @@ import ( type headerPolicies struct { enableContentResponseOnWrite bool + priorityLevel *PriorityLevel + throughputBucket *int32 } type headerOptionsOverride struct { enableContentResponseOnWrite *bool partitionKey *PartitionKey correlatedActivityId *uuid.UUID + priorityLevel *PriorityLevel + throughputBucket *int32 } func (p *headerPolicies) Do(req *policy.Request) (*http.Response, error) { o := pipelineRequestOptions{} if req.OperationValue(&o) { enableContentResponseOnWrite := p.enableContentResponseOnWrite + priorityLevel := p.priorityLevel + throughputBucket := p.throughputBucket if o.headerOptionsOverride != nil { if o.headerOptionsOverride.enableContentResponseOnWrite != nil { @@ -41,11 +48,27 @@ func (p *headerPolicies) Do(req *policy.Request) (*http.Response, error) { if o.headerOptionsOverride.correlatedActivityId != nil { req.Raw().Header.Add(cosmosHeaderCorrelatedActivityId, o.headerOptionsOverride.correlatedActivityId.String()) } + + if o.headerOptionsOverride.priorityLevel != nil { + priorityLevel = o.headerOptionsOverride.priorityLevel + } + + if o.headerOptionsOverride.throughputBucket != nil { + throughputBucket = o.headerOptionsOverride.throughputBucket + } } if o.isWriteOperation && o.resourceType == resourceTypeDocument && !enableContentResponseOnWrite { req.Raw().Header.Add(cosmosHeaderPrefer, cosmosHeaderValuesPreferMinimal) } + + if priorityLevel != nil { + req.Raw().Header.Add(cosmosHeaderPriorityLevel, string(*priorityLevel)) + } + + if throughputBucket != nil { + req.Raw().Header.Add(cosmosHeaderThroughputBucket, strconv.FormatInt(int64(*throughputBucket), 10)) + } } return req.Next() diff --git a/sdk/data/azcosmos/cosmos_headers_policy_test.go b/sdk/data/azcosmos/cosmos_headers_policy_test.go index d4492850d02d..6cc5de0489c6 100644 --- a/sdk/data/azcosmos/cosmos_headers_policy_test.go +++ b/sdk/data/azcosmos/cosmos_headers_policy_test.go @@ -258,12 +258,182 @@ type headerPoliciesVerify struct { isEnableContentResponseOnWriteHeaderSet bool isPartitionKeyHeaderSet string isCorrelatedActivityIdSet string + priorityLevelHeaderValue string + throughputBucketHeaderValue string } func (p *headerPoliciesVerify) Do(req *policy.Request) (*http.Response, error) { p.isEnableContentResponseOnWriteHeaderSet = req.Raw().Header.Get(cosmosHeaderPrefer) != "" p.isPartitionKeyHeaderSet = req.Raw().Header.Get(cosmosHeaderPartitionKey) p.isCorrelatedActivityIdSet = req.Raw().Header.Get(cosmosHeaderCorrelatedActivityId) + p.priorityLevelHeaderValue = req.Raw().Header.Get(cosmosHeaderPriorityLevel) + p.throughputBucketHeaderValue = req.Raw().Header.Get(cosmosHeaderThroughputBucket) return req.Next() } + +func TestPriorityLevelHeaderFromClientDefault(t *testing.T) { + priority := PriorityLevelLow + headerPolicy := &headerPolicies{ + priorityLevel: &priority, + } + srv, close := mock.NewTLSServer() + defer close() + srv.SetResponse(mock.WithStatusCode(http.StatusOK)) + + verifier := headerPoliciesVerify{} + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + req.SetOperationValue(pipelineRequestOptions{}) + + _, err = pl.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if verifier.priorityLevelHeaderValue != "Low" { + t.Fatalf("expected priority level header to be Low but got %v", verifier.priorityLevelHeaderValue) + } +} + +func TestPriorityLevelHeaderRequestOverridesClient(t *testing.T) { + clientPriority := PriorityLevelHigh + headerPolicy := &headerPolicies{ + priorityLevel: &clientPriority, + } + srv, close := mock.NewTLSServer() + defer close() + srv.SetResponse(mock.WithStatusCode(http.StatusOK)) + + verifier := headerPoliciesVerify{} + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + requestPriority := PriorityLevelLow + req.SetOperationValue(pipelineRequestOptions{ + headerOptionsOverride: &headerOptionsOverride{ + priorityLevel: &requestPriority, + }, + }) + + _, err = pl.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if verifier.priorityLevelHeaderValue != "Low" { + t.Fatalf("expected priority level header to be Low (request override) but got %v", verifier.priorityLevelHeaderValue) + } +} + +func TestPriorityLevelHeaderNotSetWhenNil(t *testing.T) { + headerPolicy := &headerPolicies{} + srv, close := mock.NewTLSServer() + defer close() + srv.SetResponse(mock.WithStatusCode(http.StatusOK)) + + verifier := headerPoliciesVerify{} + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + req.SetOperationValue(pipelineRequestOptions{}) + + _, err = pl.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if verifier.priorityLevelHeaderValue != "" { + t.Fatalf("expected no priority level header but got %v", verifier.priorityLevelHeaderValue) + } +} + +func TestThroughputBucketHeaderFromClientDefault(t *testing.T) { + bucket := int32(3) + headerPolicy := &headerPolicies{ + throughputBucket: &bucket, + } + srv, close := mock.NewTLSServer() + defer close() + srv.SetResponse(mock.WithStatusCode(http.StatusOK)) + + verifier := headerPoliciesVerify{} + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + req.SetOperationValue(pipelineRequestOptions{}) + + _, err = pl.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if verifier.throughputBucketHeaderValue != "3" { + t.Fatalf("expected throughput bucket header to be 3 but got %v", verifier.throughputBucketHeaderValue) + } +} + +func TestThroughputBucketHeaderRequestOverridesClient(t *testing.T) { + clientBucket := int32(1) + headerPolicy := &headerPolicies{ + throughputBucket: &clientBucket, + } + srv, close := mock.NewTLSServer() + defer close() + srv.SetResponse(mock.WithStatusCode(http.StatusOK)) + + verifier := headerPoliciesVerify{} + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + requestBucket := int32(5) + req.SetOperationValue(pipelineRequestOptions{ + headerOptionsOverride: &headerOptionsOverride{ + throughputBucket: &requestBucket, + }, + }) + + _, err = pl.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if verifier.throughputBucketHeaderValue != "5" { + t.Fatalf("expected throughput bucket header to be 5 (request override) but got %v", verifier.throughputBucketHeaderValue) + } +} + +func TestThroughputBucketHeaderNotSetWhenNil(t *testing.T) { + headerPolicy := &headerPolicies{} + srv, close := mock.NewTLSServer() + defer close() + srv.SetResponse(mock.WithStatusCode(http.StatusOK)) + + verifier := headerPoliciesVerify{} + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{headerPolicy, &verifier}}, &policy.ClientOptions{Transport: srv}) + req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + req.SetOperationValue(pipelineRequestOptions{}) + + _, err = pl.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if verifier.throughputBucketHeaderValue != "" { + t.Fatalf("expected no throughput bucket header but got %v", verifier.throughputBucketHeaderValue) + } +} diff --git a/sdk/data/azcosmos/cosmos_http_constants.go b/sdk/data/azcosmos/cosmos_http_constants.go index 65a521956465..afa9f4da8f8c 100644 --- a/sdk/data/azcosmos/cosmos_http_constants.go +++ b/sdk/data/azcosmos/cosmos_http_constants.go @@ -89,6 +89,8 @@ const ( headerXmsItemCount string = "x-ms-item-count" headerDedicatedGatewayMaxAge string = "x-ms-dedicatedgateway-max-age" headerDedicatedGatewayBypassCache string = "x-ms-dedicatedgateway-bypass-cache" + cosmosHeaderPriorityLevel string = "x-ms-cosmos-priority-level" + cosmosHeaderThroughputBucket string = "x-ms-cosmos-throughput-bucket" ) const ( diff --git a/sdk/data/azcosmos/cosmos_item_request_options.go b/sdk/data/azcosmos/cosmos_item_request_options.go index d6b49c5e16a9..7983adea38b0 100644 --- a/sdk/data/azcosmos/cosmos_item_request_options.go +++ b/sdk/data/azcosmos/cosmos_item_request_options.go @@ -37,6 +37,13 @@ type ItemOptions struct { IfMatchEtag *azcore.ETag // Options for operations in the dedicated gateway. DedicatedGatewayRequestOptions *DedicatedGatewayRequestOptions + // PriorityLevel overrides the client-level default priority for this operation. + // Valid values are PriorityLevelHigh and PriorityLevelLow. + PriorityLevel *PriorityLevel + // ThroughputBucket overrides the client-level default throughput bucket for this operation. + // For more information, see https://aka.ms/CosmosDB/ThroughputBuckets + // The valid range is 1 to 5 (inclusive). + ThroughputBucket *int32 } func (options *ItemOptions) toHeaders() *map[string]string { diff --git a/sdk/data/azcosmos/cosmos_priority_level.go b/sdk/data/azcosmos/cosmos_priority_level.go new file mode 100644 index 000000000000..da735a8ef495 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_priority_level.go @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +// PriorityLevel defines the priority level for a Cosmos DB request. +// When the total number of RU/s consumed exceeds the provisioned capacity, +// low-priority requests are throttled before high-priority requests. +// Valid values are PriorityLevelHigh and PriorityLevelLow. +// For more information, see https://aka.ms/CosmosDB/PriorityBasedExecution +type PriorityLevel string + +const ( + // PriorityLevelHigh is the default priority level. High-priority requests are served before low-priority requests. + PriorityLevelHigh PriorityLevel = "High" + // PriorityLevelLow marks a request as low priority. These requests are throttled first when over RU/s budget. + PriorityLevelLow PriorityLevel = "Low" +) + +// PriorityLevelValues returns a list of available priority levels. +func PriorityLevelValues() []PriorityLevel { + return []PriorityLevel{PriorityLevelHigh, PriorityLevelLow} +} + +// ToPtr returns a *PriorityLevel. +func (p PriorityLevel) ToPtr() *PriorityLevel { + return &p +} diff --git a/sdk/data/azcosmos/cosmos_priority_level_test.go b/sdk/data/azcosmos/cosmos_priority_level_test.go new file mode 100644 index 000000000000..14bc4510d677 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_priority_level_test.go @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPriorityLevelValues(t *testing.T) { + values := PriorityLevelValues() + require.Len(t, values, 2, "expected 2 priority levels") + require.Equal(t, PriorityLevelHigh, values[0], "expected first value to be High") + require.Equal(t, PriorityLevelLow, values[1], "expected second value to be Low") +} + +func TestPriorityLevelToPtr(t *testing.T) { + ptr := PriorityLevelHigh.ToPtr() + require.NotNil(t, ptr, "expected non-nil pointer") + require.Equal(t, PriorityLevelHigh, *ptr, "expected High") +} diff --git a/sdk/data/azcosmos/cosmos_query_request_options.go b/sdk/data/azcosmos/cosmos_query_request_options.go index 1ae23550a31b..d41a51c2dc14 100644 --- a/sdk/data/azcosmos/cosmos_query_request_options.go +++ b/sdk/data/azcosmos/cosmos_query_request_options.go @@ -49,6 +49,13 @@ type QueryOptions struct { // QueryEngine can be set to enable the use of an external query engine for processing cross-partition queries. // This is a preview feature, which is NOT SUPPORTED in production, and is subject to breaking changes. QueryEngine queryengine.QueryEngine + // PriorityLevel overrides the client-level default priority for this operation. + // Valid values are PriorityLevelHigh and PriorityLevelLow. + PriorityLevel *PriorityLevel + // ThroughputBucket overrides the client-level default throughput bucket for this operation. + // For more information, see https://aka.ms/CosmosDB/ThroughputBuckets + // The valid range is 1 to 5 (inclusive). + ThroughputBucket *int32 } func (options *QueryOptions) toHeaders() *map[string]string { diff --git a/sdk/data/azcosmos/cosmos_read_many_request_options.go b/sdk/data/azcosmos/cosmos_read_many_request_options.go index dc25d5304a88..1e8cec12e839 100644 --- a/sdk/data/azcosmos/cosmos_read_many_request_options.go +++ b/sdk/data/azcosmos/cosmos_read_many_request_options.go @@ -25,6 +25,13 @@ type ReadManyOptions struct { // MaxConcurrency indicates the maximum number of concurrent operations to use when reading many items. // If not set, the SDK will determine an optimal number of concurrent operations to use. MaxConcurrency *int32 + // PriorityLevel overrides the client-level default priority for this operation. + // Valid values are PriorityLevelHigh and PriorityLevelLow. + PriorityLevel *PriorityLevel + // ThroughputBucket overrides the client-level default throughput bucket for this operation. + // For more information, see https://aka.ms/CosmosDB/ThroughputBuckets + // The valid range is 1 to 5 (inclusive). + ThroughputBucket *int32 } func (options *ReadManyOptions) toHeaders() *map[string]string { diff --git a/sdk/data/azcosmos/cosmos_transactional_batch_options.go b/sdk/data/azcosmos/cosmos_transactional_batch_options.go index 22ed6ddd57c3..ae51c1d466d2 100644 --- a/sdk/data/azcosmos/cosmos_transactional_batch_options.go +++ b/sdk/data/azcosmos/cosmos_transactional_batch_options.go @@ -23,6 +23,13 @@ type TransactionalBatchOptions struct { // When EnableContentResponseOnWrite is false, the operations in the batch response will have no body, except when they are Read operations. // The default is false. EnableContentResponseOnWrite bool + // PriorityLevel overrides the client-level default priority for this operation. + // Valid values are PriorityLevelHigh and PriorityLevelLow. + PriorityLevel *PriorityLevel + // ThroughputBucket overrides the client-level default throughput bucket for this operation. + // For more information, see https://aka.ms/CosmosDB/ThroughputBuckets + // The valid range is 1 to 5 (inclusive). + ThroughputBucket *int32 } // TransactionalBatchItemOptions includes options for the specific operation inside a TransactionalBatch