diff --git a/Microsoft.Azure.Cosmos/src/CosmosClient.cs b/Microsoft.Azure.Cosmos/src/CosmosClient.cs index 23e40b2eb3..5174362bb6 100644 --- a/Microsoft.Azure.Cosmos/src/CosmosClient.cs +++ b/Microsoft.Azure.Cosmos/src/CosmosClient.cs @@ -1173,7 +1173,7 @@ public virtual FeedIterator GetDatabaseQueryStreamIterator( #endif virtual DistributedWriteTransaction CreateDistributedWriteTransaction() { - return new DistributedWriteTransactionCore(); + return new DistributedWriteTransactionCore(this.ClientContext); } /// diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs new file mode 100644 index 0000000000..692af6840f --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs @@ -0,0 +1,104 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Core.Trace; + using Microsoft.Azure.Cosmos.Tracing; + using Microsoft.Azure.Documents; + + internal class DistributedTransactionCommitter + { + // TODO: Move to HttpConstants.HttpHeaders once DTC headers are added centrally + private const string IdempotencyTokenHeader = "x-ms-dtc-operation-id"; + + private readonly IReadOnlyList operations; + private readonly CosmosClientContext clientContext; + + public DistributedTransactionCommitter( + IReadOnlyList operations, + CosmosClientContext clientContext) + { + this.operations = operations ?? throw new ArgumentNullException(nameof(operations)); + this.clientContext = clientContext ?? throw new ArgumentNullException(nameof(clientContext)); + } + + public async Task CommitTransactionAsync(CancellationToken cancellationToken) + { + try + { + cancellationToken.ThrowIfCancellationRequested(); + await DistributedTransactionCommitterUtils.ResolveCollectionRidsAsync( + this.operations, + this.clientContext, + cancellationToken); + + DistributedTransactionServerRequest serverRequest = await DistributedTransactionServerRequest.CreateAsync( + this.operations, + this.clientContext.SerializerCore, + cancellationToken); + + return await this.ExecuteCommitAsync(serverRequest, cancellationToken); + } + catch (Exception ex) + { + DefaultTrace.TraceError($"Distributed transaction failed: {ex.Message}"); + // await this.AbortTransactionAsync(cancellationToken); + throw; + } + } + + private async Task ExecuteCommitAsync( + DistributedTransactionServerRequest serverRequest, + CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + using (ITrace trace = Trace.GetRootTrace("Execute Distributed Transaction Commit", TraceComponent.Batch, TraceLevel.Info)) + { + using (MemoryStream bodyStream = serverRequest.TransferBodyStream()) + { + ResponseMessage responseMessage = await this.clientContext.ProcessResourceOperationStreamAsync( + resourceUri: "/dtc/ops", + resourceType: ResourceType.Document, // TODO: Update to a new ResourceType specific to DTC + operationType: OperationType.Batch, // TODO: Update to a new OperationType specific to DTC + requestOptions: null, + cosmosContainerCore: null, + partitionKey: null, + itemId: null, + streamPayload: bodyStream, + requestEnricher: requestMessage => this.EnrichRequestMessage(requestMessage, serverRequest), + trace: trace, + cancellationToken: cancellationToken); + + cancellationToken.ThrowIfCancellationRequested(); + + return await DistributedTransactionResponse.FromResponseMessageAsync( + responseMessage, + serverRequest, + this.clientContext.SerializerCore, + serverRequest.IdempotencyToken, + trace, + cancellationToken); + } + } + } + + private void EnrichRequestMessage(RequestMessage requestMessage, DistributedTransactionServerRequest serverRequest) + { + // Set DTC-specific headers + requestMessage.Headers.Add(IdempotencyTokenHeader, serverRequest.IdempotencyToken.ToString()); + requestMessage.UseGatewayMode = true; + } + + private Task AbortTransactionAsync(CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitterUtils.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitterUtils.cs new file mode 100644 index 0000000000..8647dd904e --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitterUtils.cs @@ -0,0 +1,47 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos +{ + using System; + using System.Collections.Generic; + using System.Linq; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Tracing; + using Microsoft.Azure.Documents; + + internal class DistributedTransactionCommitterUtils + { + public static async Task ResolveCollectionRidsAsync( + IReadOnlyList operations, + CosmosClientContext clientContext, + CancellationToken cancellationToken) + { + IEnumerable> groupedOperations = operations + .GroupBy(op => $"/dbs/{op.Database}/colls/{op.Container}"); + + foreach (IGrouping group in groupedOperations) + { + cancellationToken.ThrowIfCancellationRequested(); + + string collectionPath = group.Key; + ContainerProperties containerProperties = await clientContext.GetCachedContainerPropertiesAsync( + collectionPath, + NoOpTrace.Singleton, + cancellationToken); + + string containerResourceId = containerProperties.ResourceId; + ResourceId resourceId = ResourceId.Parse(containerResourceId); + string databaseResourceId = resourceId.DatabaseId.ToString(); + + foreach (DistributedTransactionOperation operation in group) + { + operation.CollectionResourceId = containerResourceId; + operation.DatabaseResourceId = databaseResourceId; + } + } + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperation.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperation.cs index 5843fdfcc7..6ad3351298 100644 --- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperation.cs +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperation.cs @@ -6,16 +6,19 @@ namespace Microsoft.Azure.Cosmos { using System; using System.IO; - - //using Microsoft.Azure.Documents; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Documents; /// - /// Represents an operation on a document whichwill be executed as a part of a distributed transaction. + /// Represents an operation on a document which will be executed as a part of a distributed transaction. /// internal class DistributedTransactionOperation { + protected Memory body; + public DistributedTransactionOperation( - Documents.OperationType operationType, + OperationType operationType, int operationIndex, string database, string container, @@ -36,11 +39,37 @@ public DistributedTransactionOperation( public string Container { get; internal set; } - public Documents.OperationType OperationType { get; internal set; } + public OperationType OperationType { get; internal set; } - public int OperationIndex { get; internal set; } + public int OperationIndex { get; internal set; } public string Id { get; internal set; } + + public string CollectionResourceId { get; internal set; } + + public string DatabaseResourceId { get; internal set; } + + internal string PartitionKeyJson { get; set; } + + internal string SessionToken { get; set; } + + internal string ETag { get; set; } + + internal Stream ResourceStream { get; set; } + + internal Memory ResourceBody + { + get => this.body; + set => this.body = value; + } + + internal virtual async Task MaterializeResourceAsync(CosmosSerializerCore serializerCore, CancellationToken cancellationToken) + { + if (this.body.IsEmpty && this.ResourceStream != null) + { + this.body = await BatchExecUtils.StreamToMemoryAsync(this.ResourceStream, cancellationToken); + } + } } internal class DistributedTransactionOperation : DistributedTransactionOperation @@ -69,6 +98,18 @@ public DistributedTransactionOperation( { this.Resource = resource; } + public T Resource { get; internal set; } + + internal override Task MaterializeResourceAsync(CosmosSerializerCore serializerCore, CancellationToken cancellationToken) + { + if (this.body.IsEmpty && this.Resource != null) + { + this.ResourceStream = serializerCore.ToStream(this.Resource); + return base.MaterializeResourceAsync(serializerCore, cancellationToken); + } + + return Task.CompletedTask; + } } } diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperationResult.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperationResult.cs index 79d866206e..3618c55a86 100644 --- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperationResult.cs +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperationResult.cs @@ -7,6 +7,8 @@ namespace Microsoft.Azure.Cosmos using System; using System.IO; using System.Net; + using System.Text.Json; + using System.Text.Json.Serialization; using Microsoft.Azure.Cosmos.Tracing; using Microsoft.Azure.Documents; @@ -27,11 +29,14 @@ internal DistributedTransactionOperationResult(HttpStatusCode statusCode) internal DistributedTransactionOperationResult(DistributedTransactionOperationResult other) { + this.Index = other.Index; this.StatusCode = other.StatusCode; this.SubStatusCode = other.SubStatusCode; this.ETag = other.ETag; this.ResourceStream = other.ResourceStream; this.SessionToken = other.SessionToken; + this.RequestCharge = other.RequestCharge; + this.ActivityId = other.ActivityId; this.Trace = other.Trace; } @@ -39,39 +44,97 @@ internal DistributedTransactionOperationResult(DistributedTransactionOperationRe /// Initializes a new instance of the class. /// This protected constructor is intended for use by derived classes. /// + [JsonConstructor] protected DistributedTransactionOperationResult() { } + /// + /// Gets the index of this operation within the distributed transaction. + /// + [JsonInclude] + [JsonPropertyName("index")] + public virtual int Index { get; internal set; } + /// /// Gets the HTTP status code returned by the operation. /// + [JsonInclude] + [JsonPropertyName("statuscode")] public virtual HttpStatusCode StatusCode { get; internal set; } /// /// Gets a value indicating whether the HTTP status code returned by the operation indicates success. /// - public virtual bool IsSuccessStatusCode => ((int)this.StatusCode >= 200) && ((int)this.StatusCode <= 299); + [JsonIgnore] + public virtual bool IsSuccessStatusCode => (int)this.StatusCode >= 200 && (int)this.StatusCode <= 299; /// /// Gets the entity tag (ETag) associated with the operation result. /// The ETag is used for concurrency control and represents the version of the resource. /// + [JsonInclude] + [JsonPropertyName("etag")] public virtual string ETag { get; internal set; } /// /// Gets the session token associated with the operation result. /// + [JsonInclude] + [JsonPropertyName("sessionToken")] public virtual string SessionToken { get; internal set; } /// /// Gets the resource stream associated with the operation result. /// The stream contains the raw response payload returned by the operation. /// + [JsonIgnore] public virtual Stream ResourceStream { get; internal set; } + /// + /// Used for JSON deserialization of the base64-encoded resource body. + /// + [JsonInclude] + [JsonPropertyName("resourcebody")] + internal string ResourceBodyBase64 + { + get => null; // Write-only for deserialization + set + { + if (!string.IsNullOrEmpty(value)) + { + byte[] resourceBody = Convert.FromBase64String(value); + this.ResourceStream = new MemoryStream(resourceBody, 0, resourceBody.Length, writable: false, publiclyVisible: true); + } + } + } + + /// + /// Request charge in request units for the operation. + /// + [JsonPropertyName("requestCharge")] + internal virtual double RequestCharge { get; set; } + + [JsonPropertyName("substatuscode")] internal virtual SubStatusCodes SubStatusCode { get; set; } + /// + /// ActivityId related to the operation. + /// + [JsonIgnore] + internal virtual string ActivityId { get; set; } + + [JsonIgnore] internal ITrace Trace { get; set; } + + /// + /// Creates a from a JSON element. + /// + /// The JSON element containing the operation result. + /// The deserialized operation result. + internal static DistributedTransactionOperationResult FromJson(JsonElement json) + { + return JsonSerializer.Deserialize(json); + } } } diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionRequest.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionRequest.cs new file mode 100644 index 0000000000..3b923be8f9 --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionRequest.cs @@ -0,0 +1,32 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos +{ + using System; + using System.Collections.Generic; + using Microsoft.Azure.Documents; + + internal class DistributedTransactionRequest + { + public DistributedTransactionRequest( + IReadOnlyList operations, + OperationType operationType = OperationType.Batch, + ResourceType resourceType = ResourceType.Document) + { + this.Operations = operations ?? throw new ArgumentNullException(nameof(operations)); + this.IdempotencyToken = Guid.NewGuid(); + this.OperationType = operationType; + this.ResourceType = resourceType; + } + + public Guid IdempotencyToken { get; set; } + + public OperationType OperationType { get; set; } + + public ResourceType ResourceType { get; set; } + + public IReadOnlyList Operations { get; set; } + } +} diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs index d4e16edc7a..490c782d9b 100644 --- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs @@ -7,7 +7,11 @@ namespace Microsoft.Azure.Cosmos using System; using System.Collections; using System.Collections.Generic; + using System.IO; using System.Net; + using System.Text.Json; + using System.Threading; + using System.Threading.Tasks; using Microsoft.Azure.Cosmos.Tracing; using Microsoft.Azure.Documents; @@ -19,22 +23,33 @@ namespace Microsoft.Azure.Cosmos #else internal #endif - class DistributedTransactionResponse : IReadOnlyList + class DistributedTransactionResponse : IReadOnlyList, IDisposable { - private readonly List results = null; + private const string IdempotencyTokenHeader = "x-ms-dtc-operation-id"; + + private List results; + private bool isDisposed; private DistributedTransactionResponse( HttpStatusCode statusCode, SubStatusCodes subStatusCode, string errorMessage, Headers headers, - IReadOnlyList operations) + IReadOnlyList operations, + CosmosSerializerCore serializer, + ITrace trace, + Guid idempotencyToken, + string serverDiagnostics = null) { this.Headers = headers; this.StatusCode = statusCode; this.SubStatusCode = subStatusCode; this.ErrorMessage = errorMessage; this.Operations = operations; + this.SerializerCore = serializer; + this.Trace = trace; + this.IdempotencyToken = idempotencyToken; + this.ServerDiagnostics = serverDiagnostics; } /// @@ -53,6 +68,8 @@ public virtual DistributedTransactionOperationResult this[int index] { get { + this.ThrowIfDisposed(); + if (this.results == null || index < 0 || index >= this.results.Count) { throw new ArgumentOutOfRangeException(nameof(index), "Index is out of range."); @@ -65,21 +82,32 @@ public virtual DistributedTransactionOperationResult this[int index] /// /// Gets the headers associated with the distributed transaction response. /// - public virtual Headers Headers { get; internal set; } + public virtual Headers Headers { get; } /// - /// Gets the HTTP status code for the distributed transaction response. + /// Gets the ActivityId that identifies the server request made to execute the transaction. /// - public virtual HttpStatusCode StatusCode { get; internal set; } + public virtual string ActivityId => this.Headers?.ActivityId; - internal virtual SubStatusCodes SubStatusCode { get; } + /// + /// Gets the request charge for the distributed transaction request. + /// + public virtual double RequestCharge => this.Headers?.RequestCharge ?? 0; - internal IReadOnlyList Operations { get; set; } + /// + /// Gets the HTTP status code for the distributed transaction response. + /// + public virtual HttpStatusCode StatusCode { get; } + + /// + /// Gets a value indicating whether the transaction was processed successfully. + /// + public virtual bool IsSuccessStatusCode => (int)this.StatusCode >= 200 && (int)this.StatusCode <= 299; /// /// Gets the error message associated with the distributed transaction response, if any. /// - public virtual string ErrorMessage { get; internal set; } + public virtual string ErrorMessage { get; } /// /// Gets the number of operation results in the distributed transaction response. @@ -87,23 +115,268 @@ public virtual DistributedTransactionOperationResult this[int index] public virtual int Count => this.results?.Count ?? 0; /// - /// Returns an enumerator that iterates through the collection of distributed transaction operation results. + /// Gets the idempotency token associated with this distributed transaction. /// - /// An enumerator for the collection of objects. + public virtual Guid IdempotencyToken { get; } + + /// + /// Gets the server-side diagnostic information for the transaction. + /// + public virtual string ServerDiagnostics { get; } + + internal virtual SubStatusCodes SubStatusCode { get; } + + internal virtual CosmosSerializerCore SerializerCore { get; } + + internal IReadOnlyList Operations { get; } + + internal ITrace Trace { get; } + + /// + /// Returns an enumerator that iterates through the operation results. + /// + /// An enumerator for the operation results. public virtual IEnumerator GetEnumerator() { - return this.results.GetEnumerator(); + return this.results?.GetEnumerator() + ?? ((IList)Array.Empty()).GetEnumerator(); } - /// + /// + /// Returns an enumerator that iterates through the operation results. + /// + /// An enumerator for the operation results. IEnumerator IEnumerable.GetEnumerator() { return this.GetEnumerator(); } - private void CreateAndPopulateResults(IReadOnlyList operations, ITrace trace) + /// + /// Releases the unmanaged resources used by the and optionally releases the managed resources. + /// + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Creates a from a . + /// + internal static async Task FromResponseMessageAsync( + ResponseMessage responseMessage, + DistributedTransactionServerRequest serverRequest, + CosmosSerializerCore serializer, + Guid requestIdempotencyToken, + ITrace trace, + CancellationToken cancellationToken) + { + using (ITrace createResponseTrace = trace.StartChild("Create Distributed Transaction Response", TraceComponent.Batch, TraceLevel.Info)) + { + cancellationToken.ThrowIfCancellationRequested(); + + // Extract idempotency token from response headers, fallback to request token if not present + Guid idempotencyToken = GetIdempotencyTokenFromHeaders(responseMessage.Headers, requestIdempotencyToken); + + DistributedTransactionResponse response = null; + MemoryStream memoryStream = null; + + try + { + if (responseMessage.Content != null) + { + Stream content = responseMessage.Content; + + // Ensure the stream is seekable + if (!content.CanSeek) + { + memoryStream = new MemoryStream(); + await responseMessage.Content.CopyToAsync(memoryStream); + memoryStream.Position = 0; + content = memoryStream; + } + + response = await PopulateFromJsonContentAsync( + content, + responseMessage, + serverRequest, + serializer, + idempotencyToken, + createResponseTrace, + cancellationToken); + } + + // If we couldn't parse JSON content or there was no content, create default response + response ??= new DistributedTransactionResponse( + responseMessage.StatusCode, + responseMessage.Headers.SubStatusCode, + responseMessage.ErrorMessage, + responseMessage.Headers, + serverRequest.Operations, + serializer, + createResponseTrace, + idempotencyToken); + + // Validate results count matches operations count + if (response.results == null || response.results.Count != serverRequest.Operations.Count) + { + if (responseMessage.IsSuccessStatusCode) + { + // Server should guarantee results count equals operations count on success + return new DistributedTransactionResponse( + HttpStatusCode.InternalServerError, + SubStatusCodes.Unknown, + ClientResources.InvalidServerResponse, + responseMessage.Headers, + serverRequest.Operations, + serializer, + createResponseTrace, + idempotencyToken); + } + + response.CreateAndPopulateResults(serverRequest.Operations, createResponseTrace); + } + + return response; + } + finally + { + memoryStream?.Dispose(); + } + } + } + + /// + /// Disposes the disposable members held by this class. + /// + /// True to release both managed and unmanaged resources; false to release only unmanaged resources. + protected virtual void Dispose(bool disposing) + { + if (this.isDisposed) + { + return; + } + + if (disposing && this.results != null) + { + foreach (DistributedTransactionOperationResult result in this.results) + { + result.ResourceStream?.Dispose(); + } + } + + this.results = null; + this.isDisposed = true; + } + + private static Guid GetIdempotencyTokenFromHeaders(Headers headers, Guid fallbackToken) + { + if (headers != null && + headers.TryGetValue(IdempotencyTokenHeader, out string tokenValue) && + Guid.TryParse(tokenValue, out Guid idempotencyToken)) + { + return idempotencyToken; + } + + return fallbackToken; + } + + private void ThrowIfDisposed() + { + if (this.isDisposed) + { + throw new ObjectDisposedException(nameof(DistributedTransactionResponse)); + } + } + + private static async Task PopulateFromJsonContentAsync( + Stream content, + ResponseMessage responseMessage, + DistributedTransactionServerRequest serverRequest, + CosmosSerializerCore serializer, + Guid idempotencyToken, + ITrace trace, + CancellationToken cancellationToken) { - throw new NotImplementedException(); + List results = new List(); + + try + { + using (JsonDocument responseJson = await JsonDocument.ParseAsync(content, cancellationToken: cancellationToken)) + { + JsonElement root = responseJson.RootElement; + + // Parse operation results from "operationResponses" array + if (root.TryGetProperty("operationResponses", out JsonElement operationResponses) && + operationResponses.ValueKind == JsonValueKind.Array) + { + foreach (JsonElement operationElement in operationResponses.EnumerateArray()) + { + cancellationToken.ThrowIfCancellationRequested(); + + DistributedTransactionOperationResult operationResult = DistributedTransactionOperationResult.FromJson(operationElement); + operationResult.Trace = trace; + operationResult.SessionToken ??= responseMessage.Headers.Session; + operationResult.ActivityId = responseMessage.Headers.ActivityId; + results.Add(operationResult); + } + } + } + } + catch (JsonException) + { + // If JSON parsing fails, return null to fall back to default response + return null; + } + + HttpStatusCode finalStatusCode = responseMessage.StatusCode; + SubStatusCodes finalSubStatusCode = responseMessage.Headers.SubStatusCode; + + // Promote operation error status for MultiStatus responses + if ((int)finalStatusCode == (int)StatusCodes.MultiStatus) + { + foreach (DistributedTransactionOperationResult result in results) + { + if ((int)result.StatusCode != (int)StatusCodes.FailedDependency && + (int)result.StatusCode >= (int)StatusCodes.StartingErrorCode) + { + finalStatusCode = result.StatusCode; + finalSubStatusCode = result.SubStatusCode; + break; + } + } + } + + return new DistributedTransactionResponse( + finalStatusCode, + finalSubStatusCode, + responseMessage.ErrorMessage, + responseMessage.Headers, + serverRequest.Operations, + serializer, + trace, + idempotencyToken) + { + results = results + }; + } + + private void CreateAndPopulateResults( + IReadOnlyList operations, + ITrace trace) + { + this.results = new List(operations.Count); + + for (int i = 0; i < operations.Count; i++) + { + this.results.Add(new DistributedTransactionOperationResult(this.StatusCode) + { + SubStatusCode = this.SubStatusCode, + SessionToken = this.Headers?.Session, + ActivityId = this.ActivityId, + Trace = trace + }); + } } } } diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionSerializer.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionSerializer.cs new file mode 100644 index 0000000000..2cbe7f5d1b --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionSerializer.cs @@ -0,0 +1,124 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.Text.Json; + using Microsoft.Azure.Documents; + + /// + /// Serializes distributed transaction requests to JSON format. + /// + internal static class DistributedTransactionSerializer + { + /// + /// Serializes a distributed transaction request body to a JSON stream. + /// The body contains only the operations array. Other metadata like idempotencyToken, + /// operationType, and resourceType are sent as HTTP headers per the spec. + /// + /// The list of operations to include in the request. + /// A MemoryStream containing the JSON-serialized request body. + public static MemoryStream SerializeRequest(IReadOnlyList operations) + { + MemoryStream stream = new MemoryStream(); + + using (Utf8JsonWriter jsonWriter = new Utf8JsonWriter(stream, new JsonWriterOptions { Indented = false })) + { + jsonWriter.WriteStartObject(); + + // operations array + jsonWriter.WriteStartArray("operations"); + + foreach (DistributedTransactionOperation operation in operations) + { + WriteOperation(jsonWriter, operation); + } + + jsonWriter.WriteEndArray(); + + jsonWriter.WriteEndObject(); + jsonWriter.Flush(); + } + + stream.Position = 0; + return stream; + } + + /// + /// Writes a single operation to the JSON writer. + /// Keys match the C++ DistributedTransactionOperation model. + /// + private static void WriteOperation(Utf8JsonWriter jsonWriter, DistributedTransactionOperation operation) + { + jsonWriter.WriteStartObject(); + + // databaseName + jsonWriter.WriteString("databaseName", operation.Database); + + // collectionName + jsonWriter.WriteString("collectionName", operation.Container); + + // collectionResourceId + if (operation.CollectionResourceId != null) + { + jsonWriter.WriteString("collectionResourceId", operation.CollectionResourceId); + } + + // databaseResourceId + if (operation.DatabaseResourceId != null) + { + jsonWriter.WriteString("databaseResourceId", operation.DatabaseResourceId); + } + + // id + if (operation.Id != null) + { + jsonWriter.WriteString("id", operation.Id); + } + + // partitionKey + if (operation.PartitionKeyJson != null) + { + jsonWriter.WriteString("partitionKey", operation.PartitionKeyJson); + } + + // index (uint32) + if (operation.OperationIndex < 0) + { + throw new ArgumentOutOfRangeException(nameof(operation.OperationIndex), "Operation index must be non-negative."); + } + jsonWriter.WriteNumber("index", (uint)operation.OperationIndex); + + //resourceBody - written as nested JSON object + if (!operation.ResourceBody.IsEmpty) + { + jsonWriter.WritePropertyName("resourceBody"); + jsonWriter.WriteRawValue(operation.ResourceBody.Span, skipInputValidation: true); + } + + // sessionToken + if (operation.SessionToken != null) + { + jsonWriter.WriteString("sessionToken", operation.SessionToken); + } + + // etag + if (operation.ETag != null) + { + jsonWriter.WriteString("etag", operation.ETag); + } + + // operationType (uint16) + jsonWriter.WriteNumber("operationType", (ushort)operation.OperationType); + + // resourceType (uint16) + jsonWriter.WriteNumber("resourceType", (ushort)ResourceType.Document); + + jsonWriter.WriteEndObject(); + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionServerRequest.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionServerRequest.cs new file mode 100644 index 0000000000..6cbab1e5f8 --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionServerRequest.cs @@ -0,0 +1,59 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.Threading; + using System.Threading.Tasks; + + internal class DistributedTransactionServerRequest + { + private readonly CosmosSerializerCore serializerCore; + private MemoryStream bodyStream; + + private DistributedTransactionServerRequest( + IReadOnlyList operations, + CosmosSerializerCore serializerCore) + { + this.Operations = operations ?? throw new ArgumentNullException(nameof(operations)); + this.serializerCore = serializerCore ?? throw new ArgumentNullException(nameof(serializerCore)); + this.IdempotencyToken = Guid.NewGuid(); + } + + public IReadOnlyList Operations { get; } + + public Guid IdempotencyToken { get; private set; } + + public static async Task CreateAsync( + IReadOnlyList operations, + CosmosSerializerCore serializerCore, + CancellationToken cancellationToken) + { + DistributedTransactionServerRequest request = new DistributedTransactionServerRequest(operations, serializerCore); + await request.CreateBodyStreamAsync(cancellationToken); + return request; + } + + public MemoryStream TransferBodyStream() + { + MemoryStream bodyStream = this.bodyStream; + this.bodyStream = null; + return bodyStream; + } + + private async Task CreateBodyStreamAsync(CancellationToken cancellationToken) + { + foreach (DistributedTransactionOperation operation in this.Operations) + { + await operation.MaterializeResourceAsync(this.serializerCore, cancellationToken); + operation.PartitionKeyJson ??= operation.PartitionKey.ToJsonString(); + } + + this.bodyStream = DistributedTransactionSerializer.SerializeRequest(this.Operations); + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedWriteTransactionCore.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedWriteTransactionCore.cs index 1b14613587..c2b39a8c1b 100644 --- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedWriteTransactionCore.cs +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedWriteTransactionCore.cs @@ -13,10 +13,12 @@ namespace Microsoft.Azure.Cosmos internal class DistributedWriteTransactionCore : DistributedWriteTransaction { - protected List operations; + private readonly CosmosClientContext clientContext; + private readonly List operations; - internal DistributedWriteTransactionCore() + internal DistributedWriteTransactionCore(CosmosClientContext clientContext) { + this.clientContext = clientContext ?? throw new ArgumentNullException(nameof(clientContext)); this.operations = new List(); } @@ -116,8 +118,11 @@ public override DistributedWriteTransaction UpsertItem(string database, strin public override async Task CommitTransactionAsync(CancellationToken cancellationToken) { - await Task.CompletedTask; - throw new NotImplementedException(); + DistributedTransactionCommitter committer = new DistributedTransactionCommitter( + operations: this.operations, + clientContext: this.clientContext); + + return await committer.CommitTransactionAsync(cancellationToken); } private static void ValidateContainerReference(string database, string collection) diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/DistributedTransaction/DistributedTransactionE2ETests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/DistributedTransaction/DistributedTransactionE2ETests.cs new file mode 100644 index 0000000000..4e44e622cd --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/DistributedTransaction/DistributedTransactionE2ETests.cs @@ -0,0 +1,293 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.Net; + using System.Text; + using System.Text.Json; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using OperationType = Documents.OperationType; + + [TestClass] + public class DistributedTransactionE2ETests : BaseCosmosClientHelper + { + private const string IdempotencyTokenHeader = "x-ms-dtc-operation-id"; + private const string PartitionKeyPath = "/pk"; + + private Container container; + + [TestInitialize] + public async Task TestInitialize() + { + await this.TestInit(); + + ContainerResponse response = await this.database.CreateContainerAsync( + new ContainerProperties(id: Guid.NewGuid().ToString(), partitionKeyPath: PartitionKeyPath), + cancellationToken: this.cancellationToken); + + this.container = response.Container; + } + + [TestCleanup] + public async Task Cleanup() + { + await base.TestCleanup(); + } + + [TestMethod] + public async Task ValidateHappyPathRequestAndResponse() + { + // Arrange + ToDoActivity doc1 = ToDoActivity.CreateRandomToDoActivity(); + ToDoActivity doc2 = ToDoActivity.CreateRandomToDoActivity(); + + DistributedTransactionTestHandler handler = CreateMockHandler( + HttpStatusCode.OK, + CreateMockSuccessResponse(operationCount: 2)); + + using CosmosClient client = TestCommon.CreateCosmosClient( + clientOptions: new CosmosClientOptions + { + CustomHandlers = { handler }, + ConnectionMode = ConnectionMode.Gateway + }); + + // Act + DistributedTransactionResponse response = await client.CreateDistributedWriteTransaction() + .CreateItem(this.database.Id, this.container.Id, new PartitionKey(doc1.pk), doc1) + .CreateItem(this.database.Id, this.container.Id, new PartitionKey(doc2.pk), doc2) + .CommitTransactionAsync(CancellationToken.None); + + // Assert - Request + Assert.IsNotNull(handler.CapturedRequest); + Assert.IsNotNull(handler.CapturedRequest.Headers[IdempotencyTokenHeader]); + ValidateRequestBody(handler.CapturedRequestBody, doc1, doc2); + + // Assert - Response + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + Assert.IsTrue(response.IsSuccessStatusCode); + Assert.AreEqual(2, response.Count); + + response.Dispose(); + } + + [TestMethod] + public async Task ValidateMixedOperationsRequestStructure() + { + // Arrange + ToDoActivity createDoc = ToDoActivity.CreateRandomToDoActivity(); + ToDoActivity replaceDoc = ToDoActivity.CreateRandomToDoActivity(); + + DistributedTransactionTestHandler handler = CreateMockHandler( + HttpStatusCode.OK, + CreateMockSuccessResponse(operationCount: 3)); + + using CosmosClient client = TestCommon.CreateCosmosClient( + clientOptions: new CosmosClientOptions + { + CustomHandlers = { handler }, + ConnectionMode = ConnectionMode.Gateway + }); + + // Act + DistributedTransactionResponse response = await client.CreateDistributedWriteTransaction() + .CreateItem(this.database.Id, this.container.Id, new PartitionKey(createDoc.pk), createDoc) + .ReplaceItem(this.database.Id, this.container.Id, new PartitionKey(replaceDoc.pk), replaceDoc.id, replaceDoc) + .DeleteItem(this.database.Id, this.container.Id, new PartitionKey("delete-pk"), "delete-id") + .CommitTransactionAsync(CancellationToken.None); + + // Assert + using JsonDocument requestJson = JsonDocument.Parse(handler.CapturedRequestBody); + JsonElement operations = requestJson.RootElement.GetProperty("operations"); + + Assert.AreEqual(3, operations.GetArrayLength()); + Assert.AreEqual((int)OperationType.Create, operations[0].GetProperty("operationType").GetInt32()); // Create + Assert.AreEqual((int)OperationType.Replace, operations[1].GetProperty("operationType").GetInt32()); // Replace + Assert.AreEqual((int)OperationType.Delete, operations[2].GetProperty("operationType").GetInt32()); // Delete + + response.Dispose(); + } + + [TestMethod] + public async Task ValidateConflictResponseReturnsErrorStatus() + { + // Arrange + string mockErrorResponse = @"{ + ""operationResponses"": [{ + ""index"": 0, + ""statuscode"": 409, + ""substatuscode"": 0 + }] + }"; + + DistributedTransactionTestHandler handler = CreateMockHandler(HttpStatusCode.Conflict, mockErrorResponse); + using CosmosClient client = TestCommon.CreateCosmosClient( + clientOptions: new CosmosClientOptions + { + CustomHandlers = { handler }, + ConnectionMode = ConnectionMode.Gateway + }); + + ToDoActivity doc = ToDoActivity.CreateRandomToDoActivity(); + + // Act + DistributedTransactionResponse response = await client.CreateDistributedWriteTransaction() + .CreateItem(this.database.Id, this.container.Id, new PartitionKey(doc.pk), doc) + .CommitTransactionAsync(CancellationToken.None); + + // Assert + Assert.AreEqual(HttpStatusCode.Conflict, response.StatusCode); + Assert.IsFalse(response.IsSuccessStatusCode); + Assert.AreEqual(1, response.Count); + Assert.AreEqual(HttpStatusCode.Conflict, response[0].StatusCode); + + response.Dispose(); + } + + [TestMethod] + public async Task ValidateResponseDeserializesCorrectly() + { + // Arrange + ToDoActivity expectedDoc = ToDoActivity.CreateRandomToDoActivity(); + string base64Body = Convert.ToBase64String(Encoding.UTF8.GetBytes(JsonSerializer.Serialize(expectedDoc))); + + string mockResponse = $@"{{ + ""operationResponses"": [{{ + ""index"": 0, + ""statuscode"": 201, + ""etag"": ""\""test-etag\"""", + ""resourcebody"": ""{base64Body}"" + }}] + }}"; + + DistributedTransactionTestHandler handler = CreateMockHandler(HttpStatusCode.OK, mockResponse); + using CosmosClient client = TestCommon.CreateCosmosClient( + clientOptions: new CosmosClientOptions + { + CustomHandlers = { handler }, + ConnectionMode = ConnectionMode.Gateway + }); + + // Act + DistributedTransactionResponse response = await client.CreateDistributedWriteTransaction() + .CreateItem(this.database.Id, this.container.Id, new PartitionKey(expectedDoc.pk), expectedDoc) + .CommitTransactionAsync(CancellationToken.None); + + // Assert + Assert.AreEqual(HttpStatusCode.Created, response[0].StatusCode); + Assert.AreEqual("\"test-etag\"", response[0].ETag); + Assert.IsNotNull(response[0].ResourceStream); + + using StreamReader reader = new StreamReader(response[0].ResourceStream); + ToDoActivity returnedDoc = JsonSerializer.Deserialize(await reader.ReadToEndAsync()); + + Assert.AreEqual(expectedDoc.id, returnedDoc.id); + Assert.AreEqual(expectedDoc.pk, returnedDoc.pk); + + response.Dispose(); + } + + #region Helper Methods + + private static DistributedTransactionTestHandler CreateMockHandler(HttpStatusCode statusCode, string responseBody) + { + return new DistributedTransactionTestHandler + { + MockResponseFactory = request => + { + ResponseMessage response = new ResponseMessage(statusCode, request, errorMessage: null) + { + Content = new MemoryStream(Encoding.UTF8.GetBytes(responseBody)) + }; + response.Headers["x-ms-activity-id"] = Guid.NewGuid().ToString(); + response.Headers[IdempotencyTokenHeader] = request.Headers[IdempotencyTokenHeader] ?? Guid.NewGuid().ToString(); + return Task.FromResult(response); + } + }; + } + + private static string CreateMockSuccessResponse(int operationCount) + { + List responses = new(); + for (int i = 0; i < operationCount; i++) + { + responses.Add($@"{{""index"":{i},""statusCode"":201,""etag"":""\""etag-{i}\""""}}"); + } + return $@"{{""operationResponses"":[{string.Join(",", responses)}]}}"; + } + + private static void ValidateRequestBody(string requestBody, params ToDoActivity[] expectedDocs) + { + using JsonDocument json = JsonDocument.Parse(requestBody); + JsonElement operations = json.RootElement.GetProperty("operations"); + + Assert.AreEqual(expectedDocs.Length, operations.GetArrayLength()); + + for (int i = 0; i < expectedDocs.Length; i++) + { + JsonElement op = operations[i]; + + Assert.AreEqual(i, op.GetProperty("index").GetInt32()); + Assert.IsTrue(op.TryGetProperty("databaseName", out _)); + Assert.IsTrue(op.TryGetProperty("collectionName", out _)); + Assert.IsTrue(op.TryGetProperty("operationType", out _)); + + // resourceBody is now a nested JSON object, not a string + JsonElement resourceBody = op.GetProperty("resourceBody"); + Assert.AreEqual(JsonValueKind.Object, resourceBody.ValueKind); + + ToDoActivity actualDoc = JsonSerializer.Deserialize(resourceBody.GetRawText()); + ToDoActivity expectedDoc = expectedDocs[i]; + + Assert.AreEqual(expectedDoc.id, actualDoc.id); + Assert.AreEqual(expectedDoc.pk, actualDoc.pk); + Assert.AreEqual(expectedDoc.taskNum, actualDoc.taskNum); + Assert.AreEqual(expectedDoc.cost, actualDoc.cost); + Assert.AreEqual(expectedDoc.description, actualDoc.description); + } + } + + #endregion + + #region Test Handler + + private class DistributedTransactionTestHandler : RequestHandler + { + public RequestMessage CapturedRequest { get; private set; } + public string CapturedRequestBody { get; private set; } + public Func> MockResponseFactory { get; set; } + + public override async Task SendAsync(RequestMessage request, CancellationToken cancellationToken) + { + if (request.RequestUriString?.StartsWith("/dtc/", StringComparison.OrdinalIgnoreCase) == true) + { + this.CapturedRequest = request; + + if (request.Content != null) + { + using MemoryStream ms = new(); + await request.Content.CopyToAsync(ms); + this.CapturedRequestBody = Encoding.UTF8.GetString(ms.ToArray()); + request.Content.Position = 0; + } + + return this.MockResponseFactory != null + ? await this.MockResponseFactory(request) + : new ResponseMessage(HttpStatusCode.OK, request, errorMessage: null); + } + + return await base.SendAsync(request, cancellationToken); + } + } + + #endregion + } +} \ No newline at end of file