diff --git a/sdk/data/azcosmos/CHANGELOG.md b/sdk/data/azcosmos/CHANGELOG.md index b33416c8c8f6..2fc4a603a779 100644 --- a/sdk/data/azcosmos/CHANGELOG.md +++ b/sdk/data/azcosmos/CHANGELOG.md @@ -4,6 +4,7 @@ ### Features Added * Added `NewClientFromConnectionString` function to create client from connection string +* Added support for parametrized queries through `QueryOptions.QueryParameters` ### Breaking Changes diff --git a/sdk/data/azcosmos/cosmos_client.go b/sdk/data/azcosmos/cosmos_client.go index d42f033eaa6c..3ab76e1e8443 100644 --- a/sdk/data/azcosmos/cosmos_client.go +++ b/sdk/data/azcosmos/cosmos_client.go @@ -203,6 +203,7 @@ func (c *Client) sendQueryRequest( path string, ctx context.Context, query string, + parameters []QueryParameter, operationContext pipelineRequestOptions, requestOptions cosmosRequestOptions, requestEnricher func(*policy.Request)) (*http.Response, error) { @@ -211,13 +212,11 @@ func (c *Client) sendQueryRequest( return nil, err } - type queryBody struct { - Query string `json:"query"` - } - err = azruntime.MarshalAsJSON(req, queryBody{ - Query: query, + Query: query, + Parameters: parameters, }) + if err != nil { return nil, err } diff --git a/sdk/data/azcosmos/cosmos_client_test.go b/sdk/data/azcosmos/cosmos_client_test.go index ccb0eccab9b3..b9b15dc7a9e1 100644 --- a/sdk/data/azcosmos/cosmos_client_test.go +++ b/sdk/data/azcosmos/cosmos_client_test.go @@ -388,7 +388,7 @@ func TestSendQuery(t *testing.T) { resourceAddress: "", } - _, err := client.sendQueryRequest("/", context.Background(), "SELECT * FROM c", operationContext, &DeleteDatabaseOptions{}, nil) + _, err := client.sendQueryRequest("/", context.Background(), "SELECT * FROM c", []QueryParameter{}, operationContext, &DeleteDatabaseOptions{}, nil) if err != nil { t.Fatal(err) } @@ -410,6 +410,48 @@ func TestSendQuery(t *testing.T) { } } +func TestSendQueryWithParameters(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.SetResponse( + mock.WithStatusCode(200)) + verifier := pipelineVerifier{} + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{&verifier}}, &policy.ClientOptions{Transport: srv}) + client := &Client{endpoint: srv.URL(), pipeline: pl} + operationContext := pipelineRequestOptions{ + resourceType: resourceTypeDatabase, + resourceAddress: "", + } + + parameters := []QueryParameter{ + {"@id", "1"}, + {"@status", "enabled"}, + } + + _, err := client.sendQueryRequest("/", context.Background(), "SELECT * FROM c WHERE c.id = @id and c.status = @status", parameters, operationContext, &DeleteDatabaseOptions{}, nil) + if err != nil { + t.Fatal(err) + } + + if verifier.requests[0].method != http.MethodPost { + t.Errorf("Expected %v, but got %v", http.MethodPost, verifier.requests[0].method) + } + + if verifier.requests[0].isQuery != true { + t.Errorf("Expected %v, but got %v", true, verifier.requests[0].isQuery) + } + + if verifier.requests[0].contentType != cosmosHeaderValuesQuery { + t.Errorf("Expected %v, but got %v", cosmosHeaderValuesQuery, verifier.requests[0].contentType) + } + + expectedSerializedQuery := "{\"query\":\"SELECT * FROM c WHERE c.id = @id and c.status = @status\",\"parameters\":[{\"name\":\"@id\",\"value\":\"1\"},{\"name\":\"@status\",\"value\":\"enabled\"}]}" + + if verifier.requests[0].body != expectedSerializedQuery { + t.Errorf("Expected %v, but got %v", expectedSerializedQuery, verifier.requests[0].body) + } +} + func TestSendBatch(t *testing.T) { srv, close := mock.NewTLSServer() defer close() diff --git a/sdk/data/azcosmos/cosmos_container.go b/sdk/data/azcosmos/cosmos_container.go index 79479b58802d..ec2e535bc2b4 100644 --- a/sdk/data/azcosmos/cosmos_container.go +++ b/sdk/data/azcosmos/cosmos_container.go @@ -444,6 +444,7 @@ func (c *ContainerClient) NewQueryItemsPager(query string, partitionKey Partitio path, ctx, query, + queryOptions.QueryParameters, operationContext, queryOptions, nil) diff --git a/sdk/data/azcosmos/cosmos_offers.go b/sdk/data/azcosmos/cosmos_offers.go index f5f841c80c07..ba89864ff4bf 100644 --- a/sdk/data/azcosmos/cosmos_offers.go +++ b/sdk/data/azcosmos/cosmos_offers.go @@ -38,6 +38,7 @@ func (c cosmosOffers) ReadThroughputIfExists( path, ctx, fmt.Sprintf(`SELECT * FROM c WHERE c.offerResourceId = '%s'`, targetRID), + nil, operationContext, requestOptions, nil) diff --git a/sdk/data/azcosmos/cosmos_query.go b/sdk/data/azcosmos/cosmos_query.go new file mode 100644 index 000000000000..1f73b8c7dce3 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_query.go @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +// QueryParameter represents a parameter for a parametrized query. +type QueryParameter struct { + // Name represents the name of the parameter in the parametrized query. + Name string `json:"name"` + // Value represents the value of the parameter in the parametrized query. + Value any `json:"value"` +} + +type queryBody struct { + Query string `json:"query"` + Parameters []QueryParameter `json:"parameters,omitempty"` +} diff --git a/sdk/data/azcosmos/cosmos_query_request_options.go b/sdk/data/azcosmos/cosmos_query_request_options.go index cba848f1438f..f0a23b702c97 100644 --- a/sdk/data/azcosmos/cosmos_query_request_options.go +++ b/sdk/data/azcosmos/cosmos_query_request_options.go @@ -31,6 +31,9 @@ type QueryOptions struct { // ContinuationToken to be used to continue a previous query execution. // Obtained from QueryItemsResponse.ContinuationToken. ContinuationToken string + // QueryParameters allows execution of parametrized queries. + // See https://docs.microsoft.com/azure/cosmos-db/sql/sql-query-parameterized-queries + QueryParameters []QueryParameter } func (options *QueryOptions) toHeaders() *map[string]string { diff --git a/sdk/data/azcosmos/doc.go b/sdk/data/azcosmos/doc.go index d8b5fb48bcbd..3a58fc0f8ac3 100644 --- a/sdk/data/azcosmos/doc.go +++ b/sdk/data/azcosmos/doc.go @@ -144,6 +144,26 @@ Querying items } } +Querying items with parametrized queries + &opt := azcosmos.QueryOptions{ + azcosmos.QueryParameters: []QueryParameter{ + {"@value", "2"}, + }, + } + pk := azcosmos.NewPartitionKeyString("myPartitionKeyValue") + queryPager := container.NewQueryItemsPager("select * from docs c where c.value = @value", pk, opt) + for queryPager.More() { + queryResponse, err := queryPager.NextPage(context) + if err != nil { + handle(err) + } + + for _, item := range queryResponse.Items { + var itemResponseBody map[string]interface{} + json.Unmarshal(item, &itemResponseBody) + } + } + Using Transactional batch pk := azcosmos.NewPartitionKeyString("myPartitionKeyValue") diff --git a/sdk/data/azcosmos/emulator_cosmos_query_test.go b/sdk/data/azcosmos/emulator_cosmos_query_test.go index 321cb5b19fd9..7b669ba0e695 100644 --- a/sdk/data/azcosmos/emulator_cosmos_query_test.go +++ b/sdk/data/azcosmos/emulator_cosmos_query_test.go @@ -169,6 +169,56 @@ func TestSinglePartitionQuery(t *testing.T) { } } +func TestSinglePartitionQueryWithParameters(t *testing.T) { + emulatorTests := newEmulatorTests(t) + client := emulatorTests.getClient(t) + + database := emulatorTests.createDatabase(t, context.TODO(), client, "queryTests") + defer emulatorTests.deleteDatabase(t, context.TODO(), database) + properties := ContainerProperties{ + ID: "aContainer", + PartitionKeyDefinition: PartitionKeyDefinition{ + Paths: []string{"/pk"}, + }, + } + + _, err := database.CreateContainer(context.TODO(), properties, nil) + if err != nil { + t.Fatalf("Failed to create container: %v", err) + } + + container, _ := database.NewContainer("aContainer") + documentsPerPk := 1 + createSampleItems(t, container, documentsPerPk) + + receivedIds := []string{} + opt := QueryOptions{ + QueryParameters: []QueryParameter{ + {"@prop", "2"}, + }, + } + queryPager := container.NewQueryItemsPager("select * from c where c.someProp = @prop", NewPartitionKeyString("1"), &opt) + for queryPager.More() { + queryResponse, err := queryPager.NextPage(context.TODO()) + if err != nil { + t.Fatalf("Failed to query items: %v", err) + } + + for _, item := range queryResponse.Items { + var itemResponseBody map[string]interface{} + err = json.Unmarshal(item, &itemResponseBody) + if err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + receivedIds = append(receivedIds, itemResponseBody["id"].(string)) + } + } + + if len(receivedIds) != 1 { + t.Fatalf("Expected 1 document, got %d", len(receivedIds)) + } +} + func createSampleItems(t *testing.T, container *ContainerClient, documentsPerPk int) { for i := 0; i < documentsPerPk; i++ { item := map[string]string{ diff --git a/sdk/data/azcosmos/example_test.go b/sdk/data/azcosmos/example_test.go index 7ffc986c27bc..3af6ff9dc5ce 100644 --- a/sdk/data/azcosmos/example_test.go +++ b/sdk/data/azcosmos/example_test.go @@ -615,6 +615,63 @@ func ExampleContainerClient_NewQueryItemsPager() { } } +// Azure Cosmos DB supports queries with parameters expressed by the familiar @ notation. +// Parameterized SQL provides robust handling and escaping of user input, and prevents accidental exposure of data through SQL injection. +func ExampleContainerClient_NewQueryItemsPager_parametrizedQueries() { + endpoint, ok := os.LookupEnv("AZURE_COSMOS_ENDPOINT") + if !ok { + panic("AZURE_COSMOS_ENDPOINT could not be found") + } + + key, ok := os.LookupEnv("AZURE_COSMOS_KEY") + if !ok { + panic("AZURE_COSMOS_KEY could not be found") + } + + cred, err := azcosmos.NewKeyCredential(key) + if err != nil { + panic(err) + } + + client, err := azcosmos.NewClientWithKey(endpoint, cred, nil) + if err != nil { + panic(err) + } + + container, err := client.NewContainer("databaseName", "aContainer") + if err != nil { + panic(err) + } + + opt := &azcosmos.QueryOptions{ + QueryParameters: []azcosmos.QueryParameter{ + {"@value", "2"}, + }, + } + + pk := azcosmos.NewPartitionKeyString("newPartitionKey") + + queryPager := container.NewQueryItemsPager("select * from docs c where c.value = @value", pk, opt) + for queryPager.More() { + queryResponse, err := queryPager.NextPage(context.Background()) + if err != nil { + var responseErr *azcore.ResponseError + errors.As(err, &responseErr) + panic(responseErr) + } + + for _, item := range queryResponse.Items { + var itemResponseBody map[string]interface{} + err = json.Unmarshal(item, &itemResponseBody) + if err != nil { + panic(err) + } + } + + fmt.Printf("Query page received with %v items. ActivityId %s consuming %v RU", len(queryResponse.Items), queryResponse.ActivityID, queryResponse.RequestCharge) + } +} + func ExampleContainerClient_NewTransactionalBatch() { endpoint, ok := os.LookupEnv("AZURE_COSMOS_ENDPOINT") if !ok {