diff --git a/Microsoft.Azure.Cosmos/src/CosmosClient.cs b/Microsoft.Azure.Cosmos/src/CosmosClient.cs
index 23e40b2eb3..5174362bb6 100644
--- a/Microsoft.Azure.Cosmos/src/CosmosClient.cs
+++ b/Microsoft.Azure.Cosmos/src/CosmosClient.cs
@@ -1173,7 +1173,7 @@ public virtual FeedIterator GetDatabaseQueryStreamIterator(
#endif
virtual DistributedWriteTransaction CreateDistributedWriteTransaction()
{
- return new DistributedWriteTransactionCore();
+ return new DistributedWriteTransactionCore(this.ClientContext);
}
///
diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs
new file mode 100644
index 0000000000..692af6840f
--- /dev/null
+++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs
@@ -0,0 +1,104 @@
+// ------------------------------------------------------------
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ------------------------------------------------------------
+
+namespace Microsoft.Azure.Cosmos
+{
+ using System;
+ using System.Collections.Generic;
+ using System.IO;
+ using System.Threading;
+ using System.Threading.Tasks;
+ using Microsoft.Azure.Cosmos.Core.Trace;
+ using Microsoft.Azure.Cosmos.Tracing;
+ using Microsoft.Azure.Documents;
+
+ internal class DistributedTransactionCommitter
+ {
+ // TODO: Move to HttpConstants.HttpHeaders once DTC headers are added centrally
+ private const string IdempotencyTokenHeader = "x-ms-dtc-operation-id";
+
+ private readonly IReadOnlyList operations;
+ private readonly CosmosClientContext clientContext;
+
+ public DistributedTransactionCommitter(
+ IReadOnlyList operations,
+ CosmosClientContext clientContext)
+ {
+ this.operations = operations ?? throw new ArgumentNullException(nameof(operations));
+ this.clientContext = clientContext ?? throw new ArgumentNullException(nameof(clientContext));
+ }
+
+ public async Task CommitTransactionAsync(CancellationToken cancellationToken)
+ {
+ try
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+ await DistributedTransactionCommitterUtils.ResolveCollectionRidsAsync(
+ this.operations,
+ this.clientContext,
+ cancellationToken);
+
+ DistributedTransactionServerRequest serverRequest = await DistributedTransactionServerRequest.CreateAsync(
+ this.operations,
+ this.clientContext.SerializerCore,
+ cancellationToken);
+
+ return await this.ExecuteCommitAsync(serverRequest, cancellationToken);
+ }
+ catch (Exception ex)
+ {
+ DefaultTrace.TraceError($"Distributed transaction failed: {ex.Message}");
+ // await this.AbortTransactionAsync(cancellationToken);
+ throw;
+ }
+ }
+
+ private async Task ExecuteCommitAsync(
+ DistributedTransactionServerRequest serverRequest,
+ CancellationToken cancellationToken)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+ using (ITrace trace = Trace.GetRootTrace("Execute Distributed Transaction Commit", TraceComponent.Batch, TraceLevel.Info))
+ {
+ using (MemoryStream bodyStream = serverRequest.TransferBodyStream())
+ {
+ ResponseMessage responseMessage = await this.clientContext.ProcessResourceOperationStreamAsync(
+ resourceUri: "/dtc/ops",
+ resourceType: ResourceType.Document, // TODO: Update to a new ResourceType specific to DTC
+ operationType: OperationType.Batch, // TODO: Update to a new OperationType specific to DTC
+ requestOptions: null,
+ cosmosContainerCore: null,
+ partitionKey: null,
+ itemId: null,
+ streamPayload: bodyStream,
+ requestEnricher: requestMessage => this.EnrichRequestMessage(requestMessage, serverRequest),
+ trace: trace,
+ cancellationToken: cancellationToken);
+
+ cancellationToken.ThrowIfCancellationRequested();
+
+ return await DistributedTransactionResponse.FromResponseMessageAsync(
+ responseMessage,
+ serverRequest,
+ this.clientContext.SerializerCore,
+ serverRequest.IdempotencyToken,
+ trace,
+ cancellationToken);
+ }
+ }
+ }
+
+ private void EnrichRequestMessage(RequestMessage requestMessage, DistributedTransactionServerRequest serverRequest)
+ {
+ // Set DTC-specific headers
+ requestMessage.Headers.Add(IdempotencyTokenHeader, serverRequest.IdempotencyToken.ToString());
+ requestMessage.UseGatewayMode = true;
+ }
+
+ private Task AbortTransactionAsync(CancellationToken cancellationToken)
+ {
+ throw new NotImplementedException();
+ }
+ }
+}
diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitterUtils.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitterUtils.cs
new file mode 100644
index 0000000000..8647dd904e
--- /dev/null
+++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitterUtils.cs
@@ -0,0 +1,47 @@
+// ------------------------------------------------------------
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ------------------------------------------------------------
+
+namespace Microsoft.Azure.Cosmos
+{
+ using System;
+ using System.Collections.Generic;
+ using System.Linq;
+ using System.Threading;
+ using System.Threading.Tasks;
+ using Microsoft.Azure.Cosmos.Tracing;
+ using Microsoft.Azure.Documents;
+
+ internal class DistributedTransactionCommitterUtils
+ {
+ public static async Task ResolveCollectionRidsAsync(
+ IReadOnlyList operations,
+ CosmosClientContext clientContext,
+ CancellationToken cancellationToken)
+ {
+ IEnumerable> groupedOperations = operations
+ .GroupBy(op => $"/dbs/{op.Database}/colls/{op.Container}");
+
+ foreach (IGrouping group in groupedOperations)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ string collectionPath = group.Key;
+ ContainerProperties containerProperties = await clientContext.GetCachedContainerPropertiesAsync(
+ collectionPath,
+ NoOpTrace.Singleton,
+ cancellationToken);
+
+ string containerResourceId = containerProperties.ResourceId;
+ ResourceId resourceId = ResourceId.Parse(containerResourceId);
+ string databaseResourceId = resourceId.DatabaseId.ToString();
+
+ foreach (DistributedTransactionOperation operation in group)
+ {
+ operation.CollectionResourceId = containerResourceId;
+ operation.DatabaseResourceId = databaseResourceId;
+ }
+ }
+ }
+ }
+}
diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperation.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperation.cs
index 5843fdfcc7..6ad3351298 100644
--- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperation.cs
+++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperation.cs
@@ -6,16 +6,19 @@ namespace Microsoft.Azure.Cosmos
{
using System;
using System.IO;
-
- //using Microsoft.Azure.Documents;
+ using System.Threading;
+ using System.Threading.Tasks;
+ using Microsoft.Azure.Documents;
///
- /// Represents an operation on a document whichwill be executed as a part of a distributed transaction.
+ /// Represents an operation on a document which will be executed as a part of a distributed transaction.
///
internal class DistributedTransactionOperation
{
+ protected Memory body;
+
public DistributedTransactionOperation(
- Documents.OperationType operationType,
+ OperationType operationType,
int operationIndex,
string database,
string container,
@@ -36,11 +39,37 @@ public DistributedTransactionOperation(
public string Container { get; internal set; }
- public Documents.OperationType OperationType { get; internal set; }
+ public OperationType OperationType { get; internal set; }
- public int OperationIndex { get; internal set; }
+ public int OperationIndex { get; internal set; }
public string Id { get; internal set; }
+
+ public string CollectionResourceId { get; internal set; }
+
+ public string DatabaseResourceId { get; internal set; }
+
+ internal string PartitionKeyJson { get; set; }
+
+ internal string SessionToken { get; set; }
+
+ internal string ETag { get; set; }
+
+ internal Stream ResourceStream { get; set; }
+
+ internal Memory ResourceBody
+ {
+ get => this.body;
+ set => this.body = value;
+ }
+
+ internal virtual async Task MaterializeResourceAsync(CosmosSerializerCore serializerCore, CancellationToken cancellationToken)
+ {
+ if (this.body.IsEmpty && this.ResourceStream != null)
+ {
+ this.body = await BatchExecUtils.StreamToMemoryAsync(this.ResourceStream, cancellationToken);
+ }
+ }
}
internal class DistributedTransactionOperation : DistributedTransactionOperation
@@ -69,6 +98,18 @@ public DistributedTransactionOperation(
{
this.Resource = resource;
}
+
public T Resource { get; internal set; }
+
+ internal override Task MaterializeResourceAsync(CosmosSerializerCore serializerCore, CancellationToken cancellationToken)
+ {
+ if (this.body.IsEmpty && this.Resource != null)
+ {
+ this.ResourceStream = serializerCore.ToStream(this.Resource);
+ return base.MaterializeResourceAsync(serializerCore, cancellationToken);
+ }
+
+ return Task.CompletedTask;
+ }
}
}
diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperationResult.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperationResult.cs
index 79d866206e..3618c55a86 100644
--- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperationResult.cs
+++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionOperationResult.cs
@@ -7,6 +7,8 @@ namespace Microsoft.Azure.Cosmos
using System;
using System.IO;
using System.Net;
+ using System.Text.Json;
+ using System.Text.Json.Serialization;
using Microsoft.Azure.Cosmos.Tracing;
using Microsoft.Azure.Documents;
@@ -27,11 +29,14 @@ internal DistributedTransactionOperationResult(HttpStatusCode statusCode)
internal DistributedTransactionOperationResult(DistributedTransactionOperationResult other)
{
+ this.Index = other.Index;
this.StatusCode = other.StatusCode;
this.SubStatusCode = other.SubStatusCode;
this.ETag = other.ETag;
this.ResourceStream = other.ResourceStream;
this.SessionToken = other.SessionToken;
+ this.RequestCharge = other.RequestCharge;
+ this.ActivityId = other.ActivityId;
this.Trace = other.Trace;
}
@@ -39,39 +44,97 @@ internal DistributedTransactionOperationResult(DistributedTransactionOperationRe
/// Initializes a new instance of the class.
/// This protected constructor is intended for use by derived classes.
///
+ [JsonConstructor]
protected DistributedTransactionOperationResult()
{
}
+ ///
+ /// Gets the index of this operation within the distributed transaction.
+ ///
+ [JsonInclude]
+ [JsonPropertyName("index")]
+ public virtual int Index { get; internal set; }
+
///
/// Gets the HTTP status code returned by the operation.
///
+ [JsonInclude]
+ [JsonPropertyName("statuscode")]
public virtual HttpStatusCode StatusCode { get; internal set; }
///
/// Gets a value indicating whether the HTTP status code returned by the operation indicates success.
///
- public virtual bool IsSuccessStatusCode => ((int)this.StatusCode >= 200) && ((int)this.StatusCode <= 299);
+ [JsonIgnore]
+ public virtual bool IsSuccessStatusCode => (int)this.StatusCode >= 200 && (int)this.StatusCode <= 299;
///
/// Gets the entity tag (ETag) associated with the operation result.
/// The ETag is used for concurrency control and represents the version of the resource.
///
+ [JsonInclude]
+ [JsonPropertyName("etag")]
public virtual string ETag { get; internal set; }
///
/// Gets the session token associated with the operation result.
///
+ [JsonInclude]
+ [JsonPropertyName("sessionToken")]
public virtual string SessionToken { get; internal set; }
///
/// Gets the resource stream associated with the operation result.
/// The stream contains the raw response payload returned by the operation.
///
+ [JsonIgnore]
public virtual Stream ResourceStream { get; internal set; }
+ ///
+ /// Used for JSON deserialization of the base64-encoded resource body.
+ ///
+ [JsonInclude]
+ [JsonPropertyName("resourcebody")]
+ internal string ResourceBodyBase64
+ {
+ get => null; // Write-only for deserialization
+ set
+ {
+ if (!string.IsNullOrEmpty(value))
+ {
+ byte[] resourceBody = Convert.FromBase64String(value);
+ this.ResourceStream = new MemoryStream(resourceBody, 0, resourceBody.Length, writable: false, publiclyVisible: true);
+ }
+ }
+ }
+
+ ///
+ /// Request charge in request units for the operation.
+ ///
+ [JsonPropertyName("requestCharge")]
+ internal virtual double RequestCharge { get; set; }
+
+ [JsonPropertyName("substatuscode")]
internal virtual SubStatusCodes SubStatusCode { get; set; }
+ ///
+ /// ActivityId related to the operation.
+ ///
+ [JsonIgnore]
+ internal virtual string ActivityId { get; set; }
+
+ [JsonIgnore]
internal ITrace Trace { get; set; }
+
+ ///
+ /// Creates a from a JSON element.
+ ///
+ /// The JSON element containing the operation result.
+ /// The deserialized operation result.
+ internal static DistributedTransactionOperationResult FromJson(JsonElement json)
+ {
+ return JsonSerializer.Deserialize(json);
+ }
}
}
diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionRequest.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionRequest.cs
new file mode 100644
index 0000000000..3b923be8f9
--- /dev/null
+++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionRequest.cs
@@ -0,0 +1,32 @@
+// ------------------------------------------------------------
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ------------------------------------------------------------
+
+namespace Microsoft.Azure.Cosmos
+{
+ using System;
+ using System.Collections.Generic;
+ using Microsoft.Azure.Documents;
+
+ internal class DistributedTransactionRequest
+ {
+ public DistributedTransactionRequest(
+ IReadOnlyList operations,
+ OperationType operationType = OperationType.Batch,
+ ResourceType resourceType = ResourceType.Document)
+ {
+ this.Operations = operations ?? throw new ArgumentNullException(nameof(operations));
+ this.IdempotencyToken = Guid.NewGuid();
+ this.OperationType = operationType;
+ this.ResourceType = resourceType;
+ }
+
+ public Guid IdempotencyToken { get; set; }
+
+ public OperationType OperationType { get; set; }
+
+ public ResourceType ResourceType { get; set; }
+
+ public IReadOnlyList Operations { get; set; }
+ }
+}
diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs
index d4e16edc7a..490c782d9b 100644
--- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs
+++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs
@@ -7,7 +7,11 @@ namespace Microsoft.Azure.Cosmos
using System;
using System.Collections;
using System.Collections.Generic;
+ using System.IO;
using System.Net;
+ using System.Text.Json;
+ using System.Threading;
+ using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Tracing;
using Microsoft.Azure.Documents;
@@ -19,22 +23,33 @@ namespace Microsoft.Azure.Cosmos
#else
internal
#endif
- class DistributedTransactionResponse : IReadOnlyList
+ class DistributedTransactionResponse : IReadOnlyList, IDisposable
{
- private readonly List results = null;
+ private const string IdempotencyTokenHeader = "x-ms-dtc-operation-id";
+
+ private List results;
+ private bool isDisposed;
private DistributedTransactionResponse(
HttpStatusCode statusCode,
SubStatusCodes subStatusCode,
string errorMessage,
Headers headers,
- IReadOnlyList operations)
+ IReadOnlyList operations,
+ CosmosSerializerCore serializer,
+ ITrace trace,
+ Guid idempotencyToken,
+ string serverDiagnostics = null)
{
this.Headers = headers;
this.StatusCode = statusCode;
this.SubStatusCode = subStatusCode;
this.ErrorMessage = errorMessage;
this.Operations = operations;
+ this.SerializerCore = serializer;
+ this.Trace = trace;
+ this.IdempotencyToken = idempotencyToken;
+ this.ServerDiagnostics = serverDiagnostics;
}
///
@@ -53,6 +68,8 @@ public virtual DistributedTransactionOperationResult this[int index]
{
get
{
+ this.ThrowIfDisposed();
+
if (this.results == null || index < 0 || index >= this.results.Count)
{
throw new ArgumentOutOfRangeException(nameof(index), "Index is out of range.");
@@ -65,21 +82,32 @@ public virtual DistributedTransactionOperationResult this[int index]
///
/// Gets the headers associated with the distributed transaction response.
///
- public virtual Headers Headers { get; internal set; }
+ public virtual Headers Headers { get; }
///
- /// Gets the HTTP status code for the distributed transaction response.
+ /// Gets the ActivityId that identifies the server request made to execute the transaction.
///
- public virtual HttpStatusCode StatusCode { get; internal set; }
+ public virtual string ActivityId => this.Headers?.ActivityId;
- internal virtual SubStatusCodes SubStatusCode { get; }
+ ///
+ /// Gets the request charge for the distributed transaction request.
+ ///
+ public virtual double RequestCharge => this.Headers?.RequestCharge ?? 0;
- internal IReadOnlyList Operations { get; set; }
+ ///
+ /// Gets the HTTP status code for the distributed transaction response.
+ ///
+ public virtual HttpStatusCode StatusCode { get; }
+
+ ///
+ /// Gets a value indicating whether the transaction was processed successfully.
+ ///
+ public virtual bool IsSuccessStatusCode => (int)this.StatusCode >= 200 && (int)this.StatusCode <= 299;
///
/// Gets the error message associated with the distributed transaction response, if any.
///
- public virtual string ErrorMessage { get; internal set; }
+ public virtual string ErrorMessage { get; }
///
/// Gets the number of operation results in the distributed transaction response.
@@ -87,23 +115,268 @@ public virtual DistributedTransactionOperationResult this[int index]
public virtual int Count => this.results?.Count ?? 0;
///
- /// Returns an enumerator that iterates through the collection of distributed transaction operation results.
+ /// Gets the idempotency token associated with this distributed transaction.
///
- /// An enumerator for the collection of objects.
+ public virtual Guid IdempotencyToken { get; }
+
+ ///
+ /// Gets the server-side diagnostic information for the transaction.
+ ///
+ public virtual string ServerDiagnostics { get; }
+
+ internal virtual SubStatusCodes SubStatusCode { get; }
+
+ internal virtual CosmosSerializerCore SerializerCore { get; }
+
+ internal IReadOnlyList Operations { get; }
+
+ internal ITrace Trace { get; }
+
+ ///
+ /// Returns an enumerator that iterates through the operation results.
+ ///
+ /// An enumerator for the operation results.
public virtual IEnumerator GetEnumerator()
{
- return this.results.GetEnumerator();
+ return this.results?.GetEnumerator()
+ ?? ((IList)Array.Empty()).GetEnumerator();
}
- ///
+ ///
+ /// Returns an enumerator that iterates through the operation results.
+ ///
+ /// An enumerator for the operation results.
IEnumerator IEnumerable.GetEnumerator()
{
return this.GetEnumerator();
}
- private void CreateAndPopulateResults(IReadOnlyList operations, ITrace trace)
+ ///
+ /// Releases the unmanaged resources used by the and optionally releases the managed resources.
+ ///
+ public void Dispose()
+ {
+ this.Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ ///
+ /// Creates a from a .
+ ///
+ internal static async Task FromResponseMessageAsync(
+ ResponseMessage responseMessage,
+ DistributedTransactionServerRequest serverRequest,
+ CosmosSerializerCore serializer,
+ Guid requestIdempotencyToken,
+ ITrace trace,
+ CancellationToken cancellationToken)
+ {
+ using (ITrace createResponseTrace = trace.StartChild("Create Distributed Transaction Response", TraceComponent.Batch, TraceLevel.Info))
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // Extract idempotency token from response headers, fallback to request token if not present
+ Guid idempotencyToken = GetIdempotencyTokenFromHeaders(responseMessage.Headers, requestIdempotencyToken);
+
+ DistributedTransactionResponse response = null;
+ MemoryStream memoryStream = null;
+
+ try
+ {
+ if (responseMessage.Content != null)
+ {
+ Stream content = responseMessage.Content;
+
+ // Ensure the stream is seekable
+ if (!content.CanSeek)
+ {
+ memoryStream = new MemoryStream();
+ await responseMessage.Content.CopyToAsync(memoryStream);
+ memoryStream.Position = 0;
+ content = memoryStream;
+ }
+
+ response = await PopulateFromJsonContentAsync(
+ content,
+ responseMessage,
+ serverRequest,
+ serializer,
+ idempotencyToken,
+ createResponseTrace,
+ cancellationToken);
+ }
+
+ // If we couldn't parse JSON content or there was no content, create default response
+ response ??= new DistributedTransactionResponse(
+ responseMessage.StatusCode,
+ responseMessage.Headers.SubStatusCode,
+ responseMessage.ErrorMessage,
+ responseMessage.Headers,
+ serverRequest.Operations,
+ serializer,
+ createResponseTrace,
+ idempotencyToken);
+
+ // Validate results count matches operations count
+ if (response.results == null || response.results.Count != serverRequest.Operations.Count)
+ {
+ if (responseMessage.IsSuccessStatusCode)
+ {
+ // Server should guarantee results count equals operations count on success
+ return new DistributedTransactionResponse(
+ HttpStatusCode.InternalServerError,
+ SubStatusCodes.Unknown,
+ ClientResources.InvalidServerResponse,
+ responseMessage.Headers,
+ serverRequest.Operations,
+ serializer,
+ createResponseTrace,
+ idempotencyToken);
+ }
+
+ response.CreateAndPopulateResults(serverRequest.Operations, createResponseTrace);
+ }
+
+ return response;
+ }
+ finally
+ {
+ memoryStream?.Dispose();
+ }
+ }
+ }
+
+ ///
+ /// Disposes the disposable members held by this class.
+ ///
+ /// True to release both managed and unmanaged resources; false to release only unmanaged resources.
+ protected virtual void Dispose(bool disposing)
+ {
+ if (this.isDisposed)
+ {
+ return;
+ }
+
+ if (disposing && this.results != null)
+ {
+ foreach (DistributedTransactionOperationResult result in this.results)
+ {
+ result.ResourceStream?.Dispose();
+ }
+ }
+
+ this.results = null;
+ this.isDisposed = true;
+ }
+
+ private static Guid GetIdempotencyTokenFromHeaders(Headers headers, Guid fallbackToken)
+ {
+ if (headers != null &&
+ headers.TryGetValue(IdempotencyTokenHeader, out string tokenValue) &&
+ Guid.TryParse(tokenValue, out Guid idempotencyToken))
+ {
+ return idempotencyToken;
+ }
+
+ return fallbackToken;
+ }
+
+ private void ThrowIfDisposed()
+ {
+ if (this.isDisposed)
+ {
+ throw new ObjectDisposedException(nameof(DistributedTransactionResponse));
+ }
+ }
+
+ private static async Task PopulateFromJsonContentAsync(
+ Stream content,
+ ResponseMessage responseMessage,
+ DistributedTransactionServerRequest serverRequest,
+ CosmosSerializerCore serializer,
+ Guid idempotencyToken,
+ ITrace trace,
+ CancellationToken cancellationToken)
{
- throw new NotImplementedException();
+ List results = new List();
+
+ try
+ {
+ using (JsonDocument responseJson = await JsonDocument.ParseAsync(content, cancellationToken: cancellationToken))
+ {
+ JsonElement root = responseJson.RootElement;
+
+ // Parse operation results from "operationResponses" array
+ if (root.TryGetProperty("operationResponses", out JsonElement operationResponses) &&
+ operationResponses.ValueKind == JsonValueKind.Array)
+ {
+ foreach (JsonElement operationElement in operationResponses.EnumerateArray())
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ DistributedTransactionOperationResult operationResult = DistributedTransactionOperationResult.FromJson(operationElement);
+ operationResult.Trace = trace;
+ operationResult.SessionToken ??= responseMessage.Headers.Session;
+ operationResult.ActivityId = responseMessage.Headers.ActivityId;
+ results.Add(operationResult);
+ }
+ }
+ }
+ }
+ catch (JsonException)
+ {
+ // If JSON parsing fails, return null to fall back to default response
+ return null;
+ }
+
+ HttpStatusCode finalStatusCode = responseMessage.StatusCode;
+ SubStatusCodes finalSubStatusCode = responseMessage.Headers.SubStatusCode;
+
+ // Promote operation error status for MultiStatus responses
+ if ((int)finalStatusCode == (int)StatusCodes.MultiStatus)
+ {
+ foreach (DistributedTransactionOperationResult result in results)
+ {
+ if ((int)result.StatusCode != (int)StatusCodes.FailedDependency &&
+ (int)result.StatusCode >= (int)StatusCodes.StartingErrorCode)
+ {
+ finalStatusCode = result.StatusCode;
+ finalSubStatusCode = result.SubStatusCode;
+ break;
+ }
+ }
+ }
+
+ return new DistributedTransactionResponse(
+ finalStatusCode,
+ finalSubStatusCode,
+ responseMessage.ErrorMessage,
+ responseMessage.Headers,
+ serverRequest.Operations,
+ serializer,
+ trace,
+ idempotencyToken)
+ {
+ results = results
+ };
+ }
+
+ private void CreateAndPopulateResults(
+ IReadOnlyList operations,
+ ITrace trace)
+ {
+ this.results = new List(operations.Count);
+
+ for (int i = 0; i < operations.Count; i++)
+ {
+ this.results.Add(new DistributedTransactionOperationResult(this.StatusCode)
+ {
+ SubStatusCode = this.SubStatusCode,
+ SessionToken = this.Headers?.Session,
+ ActivityId = this.ActivityId,
+ Trace = trace
+ });
+ }
}
}
}
diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionSerializer.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionSerializer.cs
new file mode 100644
index 0000000000..2cbe7f5d1b
--- /dev/null
+++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionSerializer.cs
@@ -0,0 +1,124 @@
+// ------------------------------------------------------------
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ------------------------------------------------------------
+
+namespace Microsoft.Azure.Cosmos
+{
+ using System;
+ using System.Collections.Generic;
+ using System.IO;
+ using System.Text.Json;
+ using Microsoft.Azure.Documents;
+
+ ///
+ /// Serializes distributed transaction requests to JSON format.
+ ///
+ internal static class DistributedTransactionSerializer
+ {
+ ///
+ /// Serializes a distributed transaction request body to a JSON stream.
+ /// The body contains only the operations array. Other metadata like idempotencyToken,
+ /// operationType, and resourceType are sent as HTTP headers per the spec.
+ ///
+ /// The list of operations to include in the request.
+ /// A MemoryStream containing the JSON-serialized request body.
+ public static MemoryStream SerializeRequest(IReadOnlyList operations)
+ {
+ MemoryStream stream = new MemoryStream();
+
+ using (Utf8JsonWriter jsonWriter = new Utf8JsonWriter(stream, new JsonWriterOptions { Indented = false }))
+ {
+ jsonWriter.WriteStartObject();
+
+ // operations array
+ jsonWriter.WriteStartArray("operations");
+
+ foreach (DistributedTransactionOperation operation in operations)
+ {
+ WriteOperation(jsonWriter, operation);
+ }
+
+ jsonWriter.WriteEndArray();
+
+ jsonWriter.WriteEndObject();
+ jsonWriter.Flush();
+ }
+
+ stream.Position = 0;
+ return stream;
+ }
+
+ ///
+ /// Writes a single operation to the JSON writer.
+ /// Keys match the C++ DistributedTransactionOperation model.
+ ///
+ private static void WriteOperation(Utf8JsonWriter jsonWriter, DistributedTransactionOperation operation)
+ {
+ jsonWriter.WriteStartObject();
+
+ // databaseName
+ jsonWriter.WriteString("databaseName", operation.Database);
+
+ // collectionName
+ jsonWriter.WriteString("collectionName", operation.Container);
+
+ // collectionResourceId
+ if (operation.CollectionResourceId != null)
+ {
+ jsonWriter.WriteString("collectionResourceId", operation.CollectionResourceId);
+ }
+
+ // databaseResourceId
+ if (operation.DatabaseResourceId != null)
+ {
+ jsonWriter.WriteString("databaseResourceId", operation.DatabaseResourceId);
+ }
+
+ // id
+ if (operation.Id != null)
+ {
+ jsonWriter.WriteString("id", operation.Id);
+ }
+
+ // partitionKey
+ if (operation.PartitionKeyJson != null)
+ {
+ jsonWriter.WriteString("partitionKey", operation.PartitionKeyJson);
+ }
+
+ // index (uint32)
+ if (operation.OperationIndex < 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(operation.OperationIndex), "Operation index must be non-negative.");
+ }
+ jsonWriter.WriteNumber("index", (uint)operation.OperationIndex);
+
+ //resourceBody - written as nested JSON object
+ if (!operation.ResourceBody.IsEmpty)
+ {
+ jsonWriter.WritePropertyName("resourceBody");
+ jsonWriter.WriteRawValue(operation.ResourceBody.Span, skipInputValidation: true);
+ }
+
+ // sessionToken
+ if (operation.SessionToken != null)
+ {
+ jsonWriter.WriteString("sessionToken", operation.SessionToken);
+ }
+
+ // etag
+ if (operation.ETag != null)
+ {
+ jsonWriter.WriteString("etag", operation.ETag);
+ }
+
+ // operationType (uint16)
+ jsonWriter.WriteNumber("operationType", (ushort)operation.OperationType);
+
+ // resourceType (uint16)
+ jsonWriter.WriteNumber("resourceType", (ushort)ResourceType.Document);
+
+ jsonWriter.WriteEndObject();
+ }
+ }
+}
diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionServerRequest.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionServerRequest.cs
new file mode 100644
index 0000000000..6cbab1e5f8
--- /dev/null
+++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionServerRequest.cs
@@ -0,0 +1,59 @@
+// ------------------------------------------------------------
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ------------------------------------------------------------
+
+namespace Microsoft.Azure.Cosmos
+{
+ using System;
+ using System.Collections.Generic;
+ using System.IO;
+ using System.Threading;
+ using System.Threading.Tasks;
+
+ internal class DistributedTransactionServerRequest
+ {
+ private readonly CosmosSerializerCore serializerCore;
+ private MemoryStream bodyStream;
+
+ private DistributedTransactionServerRequest(
+ IReadOnlyList operations,
+ CosmosSerializerCore serializerCore)
+ {
+ this.Operations = operations ?? throw new ArgumentNullException(nameof(operations));
+ this.serializerCore = serializerCore ?? throw new ArgumentNullException(nameof(serializerCore));
+ this.IdempotencyToken = Guid.NewGuid();
+ }
+
+ public IReadOnlyList Operations { get; }
+
+ public Guid IdempotencyToken { get; private set; }
+
+ public static async Task CreateAsync(
+ IReadOnlyList operations,
+ CosmosSerializerCore serializerCore,
+ CancellationToken cancellationToken)
+ {
+ DistributedTransactionServerRequest request = new DistributedTransactionServerRequest(operations, serializerCore);
+ await request.CreateBodyStreamAsync(cancellationToken);
+ return request;
+ }
+
+ public MemoryStream TransferBodyStream()
+ {
+ MemoryStream bodyStream = this.bodyStream;
+ this.bodyStream = null;
+ return bodyStream;
+ }
+
+ private async Task CreateBodyStreamAsync(CancellationToken cancellationToken)
+ {
+ foreach (DistributedTransactionOperation operation in this.Operations)
+ {
+ await operation.MaterializeResourceAsync(this.serializerCore, cancellationToken);
+ operation.PartitionKeyJson ??= operation.PartitionKey.ToJsonString();
+ }
+
+ this.bodyStream = DistributedTransactionSerializer.SerializeRequest(this.Operations);
+ }
+ }
+}
diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedWriteTransactionCore.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedWriteTransactionCore.cs
index 1b14613587..c2b39a8c1b 100644
--- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedWriteTransactionCore.cs
+++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedWriteTransactionCore.cs
@@ -13,10 +13,12 @@ namespace Microsoft.Azure.Cosmos
internal class DistributedWriteTransactionCore : DistributedWriteTransaction
{
- protected List operations;
+ private readonly CosmosClientContext clientContext;
+ private readonly List operations;
- internal DistributedWriteTransactionCore()
+ internal DistributedWriteTransactionCore(CosmosClientContext clientContext)
{
+ this.clientContext = clientContext ?? throw new ArgumentNullException(nameof(clientContext));
this.operations = new List();
}
@@ -116,8 +118,11 @@ public override DistributedWriteTransaction UpsertItem(string database, strin
public override async Task CommitTransactionAsync(CancellationToken cancellationToken)
{
- await Task.CompletedTask;
- throw new NotImplementedException();
+ DistributedTransactionCommitter committer = new DistributedTransactionCommitter(
+ operations: this.operations,
+ clientContext: this.clientContext);
+
+ return await committer.CommitTransactionAsync(cancellationToken);
}
private static void ValidateContainerReference(string database, string collection)
diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/DistributedTransaction/DistributedTransactionE2ETests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/DistributedTransaction/DistributedTransactionE2ETests.cs
new file mode 100644
index 0000000000..4e44e622cd
--- /dev/null
+++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/DistributedTransaction/DistributedTransactionE2ETests.cs
@@ -0,0 +1,293 @@
+// ------------------------------------------------------------
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ------------------------------------------------------------
+
+namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
+{
+ using System;
+ using System.Collections.Generic;
+ using System.IO;
+ using System.Net;
+ using System.Text;
+ using System.Text.Json;
+ using System.Threading;
+ using System.Threading.Tasks;
+ using Microsoft.Azure.Cosmos;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+ using OperationType = Documents.OperationType;
+
+ [TestClass]
+ public class DistributedTransactionE2ETests : BaseCosmosClientHelper
+ {
+ private const string IdempotencyTokenHeader = "x-ms-dtc-operation-id";
+ private const string PartitionKeyPath = "/pk";
+
+ private Container container;
+
+ [TestInitialize]
+ public async Task TestInitialize()
+ {
+ await this.TestInit();
+
+ ContainerResponse response = await this.database.CreateContainerAsync(
+ new ContainerProperties(id: Guid.NewGuid().ToString(), partitionKeyPath: PartitionKeyPath),
+ cancellationToken: this.cancellationToken);
+
+ this.container = response.Container;
+ }
+
+ [TestCleanup]
+ public async Task Cleanup()
+ {
+ await base.TestCleanup();
+ }
+
+ [TestMethod]
+ public async Task ValidateHappyPathRequestAndResponse()
+ {
+ // Arrange
+ ToDoActivity doc1 = ToDoActivity.CreateRandomToDoActivity();
+ ToDoActivity doc2 = ToDoActivity.CreateRandomToDoActivity();
+
+ DistributedTransactionTestHandler handler = CreateMockHandler(
+ HttpStatusCode.OK,
+ CreateMockSuccessResponse(operationCount: 2));
+
+ using CosmosClient client = TestCommon.CreateCosmosClient(
+ clientOptions: new CosmosClientOptions
+ {
+ CustomHandlers = { handler },
+ ConnectionMode = ConnectionMode.Gateway
+ });
+
+ // Act
+ DistributedTransactionResponse response = await client.CreateDistributedWriteTransaction()
+ .CreateItem(this.database.Id, this.container.Id, new PartitionKey(doc1.pk), doc1)
+ .CreateItem(this.database.Id, this.container.Id, new PartitionKey(doc2.pk), doc2)
+ .CommitTransactionAsync(CancellationToken.None);
+
+ // Assert - Request
+ Assert.IsNotNull(handler.CapturedRequest);
+ Assert.IsNotNull(handler.CapturedRequest.Headers[IdempotencyTokenHeader]);
+ ValidateRequestBody(handler.CapturedRequestBody, doc1, doc2);
+
+ // Assert - Response
+ Assert.AreEqual(HttpStatusCode.OK, response.StatusCode);
+ Assert.IsTrue(response.IsSuccessStatusCode);
+ Assert.AreEqual(2, response.Count);
+
+ response.Dispose();
+ }
+
+ [TestMethod]
+ public async Task ValidateMixedOperationsRequestStructure()
+ {
+ // Arrange
+ ToDoActivity createDoc = ToDoActivity.CreateRandomToDoActivity();
+ ToDoActivity replaceDoc = ToDoActivity.CreateRandomToDoActivity();
+
+ DistributedTransactionTestHandler handler = CreateMockHandler(
+ HttpStatusCode.OK,
+ CreateMockSuccessResponse(operationCount: 3));
+
+ using CosmosClient client = TestCommon.CreateCosmosClient(
+ clientOptions: new CosmosClientOptions
+ {
+ CustomHandlers = { handler },
+ ConnectionMode = ConnectionMode.Gateway
+ });
+
+ // Act
+ DistributedTransactionResponse response = await client.CreateDistributedWriteTransaction()
+ .CreateItem(this.database.Id, this.container.Id, new PartitionKey(createDoc.pk), createDoc)
+ .ReplaceItem(this.database.Id, this.container.Id, new PartitionKey(replaceDoc.pk), replaceDoc.id, replaceDoc)
+ .DeleteItem(this.database.Id, this.container.Id, new PartitionKey("delete-pk"), "delete-id")
+ .CommitTransactionAsync(CancellationToken.None);
+
+ // Assert
+ using JsonDocument requestJson = JsonDocument.Parse(handler.CapturedRequestBody);
+ JsonElement operations = requestJson.RootElement.GetProperty("operations");
+
+ Assert.AreEqual(3, operations.GetArrayLength());
+ Assert.AreEqual((int)OperationType.Create, operations[0].GetProperty("operationType").GetInt32()); // Create
+ Assert.AreEqual((int)OperationType.Replace, operations[1].GetProperty("operationType").GetInt32()); // Replace
+ Assert.AreEqual((int)OperationType.Delete, operations[2].GetProperty("operationType").GetInt32()); // Delete
+
+ response.Dispose();
+ }
+
+ [TestMethod]
+ public async Task ValidateConflictResponseReturnsErrorStatus()
+ {
+ // Arrange
+ string mockErrorResponse = @"{
+ ""operationResponses"": [{
+ ""index"": 0,
+ ""statuscode"": 409,
+ ""substatuscode"": 0
+ }]
+ }";
+
+ DistributedTransactionTestHandler handler = CreateMockHandler(HttpStatusCode.Conflict, mockErrorResponse);
+ using CosmosClient client = TestCommon.CreateCosmosClient(
+ clientOptions: new CosmosClientOptions
+ {
+ CustomHandlers = { handler },
+ ConnectionMode = ConnectionMode.Gateway
+ });
+
+ ToDoActivity doc = ToDoActivity.CreateRandomToDoActivity();
+
+ // Act
+ DistributedTransactionResponse response = await client.CreateDistributedWriteTransaction()
+ .CreateItem(this.database.Id, this.container.Id, new PartitionKey(doc.pk), doc)
+ .CommitTransactionAsync(CancellationToken.None);
+
+ // Assert
+ Assert.AreEqual(HttpStatusCode.Conflict, response.StatusCode);
+ Assert.IsFalse(response.IsSuccessStatusCode);
+ Assert.AreEqual(1, response.Count);
+ Assert.AreEqual(HttpStatusCode.Conflict, response[0].StatusCode);
+
+ response.Dispose();
+ }
+
+ [TestMethod]
+ public async Task ValidateResponseDeserializesCorrectly()
+ {
+ // Arrange
+ ToDoActivity expectedDoc = ToDoActivity.CreateRandomToDoActivity();
+ string base64Body = Convert.ToBase64String(Encoding.UTF8.GetBytes(JsonSerializer.Serialize(expectedDoc)));
+
+ string mockResponse = $@"{{
+ ""operationResponses"": [{{
+ ""index"": 0,
+ ""statuscode"": 201,
+ ""etag"": ""\""test-etag\"""",
+ ""resourcebody"": ""{base64Body}""
+ }}]
+ }}";
+
+ DistributedTransactionTestHandler handler = CreateMockHandler(HttpStatusCode.OK, mockResponse);
+ using CosmosClient client = TestCommon.CreateCosmosClient(
+ clientOptions: new CosmosClientOptions
+ {
+ CustomHandlers = { handler },
+ ConnectionMode = ConnectionMode.Gateway
+ });
+
+ // Act
+ DistributedTransactionResponse response = await client.CreateDistributedWriteTransaction()
+ .CreateItem(this.database.Id, this.container.Id, new PartitionKey(expectedDoc.pk), expectedDoc)
+ .CommitTransactionAsync(CancellationToken.None);
+
+ // Assert
+ Assert.AreEqual(HttpStatusCode.Created, response[0].StatusCode);
+ Assert.AreEqual("\"test-etag\"", response[0].ETag);
+ Assert.IsNotNull(response[0].ResourceStream);
+
+ using StreamReader reader = new StreamReader(response[0].ResourceStream);
+ ToDoActivity returnedDoc = JsonSerializer.Deserialize(await reader.ReadToEndAsync());
+
+ Assert.AreEqual(expectedDoc.id, returnedDoc.id);
+ Assert.AreEqual(expectedDoc.pk, returnedDoc.pk);
+
+ response.Dispose();
+ }
+
+ #region Helper Methods
+
+ private static DistributedTransactionTestHandler CreateMockHandler(HttpStatusCode statusCode, string responseBody)
+ {
+ return new DistributedTransactionTestHandler
+ {
+ MockResponseFactory = request =>
+ {
+ ResponseMessage response = new ResponseMessage(statusCode, request, errorMessage: null)
+ {
+ Content = new MemoryStream(Encoding.UTF8.GetBytes(responseBody))
+ };
+ response.Headers["x-ms-activity-id"] = Guid.NewGuid().ToString();
+ response.Headers[IdempotencyTokenHeader] = request.Headers[IdempotencyTokenHeader] ?? Guid.NewGuid().ToString();
+ return Task.FromResult(response);
+ }
+ };
+ }
+
+ private static string CreateMockSuccessResponse(int operationCount)
+ {
+ List responses = new();
+ for (int i = 0; i < operationCount; i++)
+ {
+ responses.Add($@"{{""index"":{i},""statusCode"":201,""etag"":""\""etag-{i}\""""}}");
+ }
+ return $@"{{""operationResponses"":[{string.Join(",", responses)}]}}";
+ }
+
+ private static void ValidateRequestBody(string requestBody, params ToDoActivity[] expectedDocs)
+ {
+ using JsonDocument json = JsonDocument.Parse(requestBody);
+ JsonElement operations = json.RootElement.GetProperty("operations");
+
+ Assert.AreEqual(expectedDocs.Length, operations.GetArrayLength());
+
+ for (int i = 0; i < expectedDocs.Length; i++)
+ {
+ JsonElement op = operations[i];
+
+ Assert.AreEqual(i, op.GetProperty("index").GetInt32());
+ Assert.IsTrue(op.TryGetProperty("databaseName", out _));
+ Assert.IsTrue(op.TryGetProperty("collectionName", out _));
+ Assert.IsTrue(op.TryGetProperty("operationType", out _));
+
+ // resourceBody is now a nested JSON object, not a string
+ JsonElement resourceBody = op.GetProperty("resourceBody");
+ Assert.AreEqual(JsonValueKind.Object, resourceBody.ValueKind);
+
+ ToDoActivity actualDoc = JsonSerializer.Deserialize(resourceBody.GetRawText());
+ ToDoActivity expectedDoc = expectedDocs[i];
+
+ Assert.AreEqual(expectedDoc.id, actualDoc.id);
+ Assert.AreEqual(expectedDoc.pk, actualDoc.pk);
+ Assert.AreEqual(expectedDoc.taskNum, actualDoc.taskNum);
+ Assert.AreEqual(expectedDoc.cost, actualDoc.cost);
+ Assert.AreEqual(expectedDoc.description, actualDoc.description);
+ }
+ }
+
+ #endregion
+
+ #region Test Handler
+
+ private class DistributedTransactionTestHandler : RequestHandler
+ {
+ public RequestMessage CapturedRequest { get; private set; }
+ public string CapturedRequestBody { get; private set; }
+ public Func> MockResponseFactory { get; set; }
+
+ public override async Task SendAsync(RequestMessage request, CancellationToken cancellationToken)
+ {
+ if (request.RequestUriString?.StartsWith("/dtc/", StringComparison.OrdinalIgnoreCase) == true)
+ {
+ this.CapturedRequest = request;
+
+ if (request.Content != null)
+ {
+ using MemoryStream ms = new();
+ await request.Content.CopyToAsync(ms);
+ this.CapturedRequestBody = Encoding.UTF8.GetString(ms.ToArray());
+ request.Content.Position = 0;
+ }
+
+ return this.MockResponseFactory != null
+ ? await this.MockResponseFactory(request)
+ : new ResponseMessage(HttpStatusCode.OK, request, errorMessage: null);
+ }
+
+ return await base.SendAsync(request, cancellationToken);
+ }
+ }
+
+ #endregion
+ }
+}
\ No newline at end of file