diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs index 4e78d03679..46654f5e2f 100644 --- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs @@ -7,11 +7,13 @@ namespace Microsoft.Azure.Cosmos using System; using System.Collections.Generic; using System.IO; + using System.Net; using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Cosmos.Core.Trace; using Microsoft.Azure.Cosmos.Tracing; using Microsoft.Azure.Documents; + using Microsoft.Azure.Documents.Collections; internal class DistributedTransactionCommitter { @@ -75,13 +77,20 @@ private async Task ExecuteCommitAsync( cancellationToken.ThrowIfCancellationRequested(); - return await DistributedTransactionResponse.FromResponseMessageAsync( + DistributedTransactionResponse response = await DistributedTransactionResponse.FromResponseMessageAsync( responseMessage, serverRequest, this.clientContext.SerializerCore, serverRequest.IdempotencyToken, trace, cancellationToken); + + DistributedTransactionCommitter.MergeSessionTokens( + response, + serverRequest, + this.clientContext.DocumentClient.sessionContainer); + + return response; } } } @@ -100,8 +109,59 @@ private static void EnrichRequestMessage(RequestMessage requestMessage, Distribu requestMessage.UseGatewayMode = true; } + internal static void MergeSessionTokens( + DistributedTransactionResponse response, + DistributedTransactionServerRequest serverRequest, + ISessionContainer sessionContainer) + { + // Mirror the pattern used by GatewayStoreModel.CaptureSessionTokenAndHandleSplitAsync. + // after a response is received, store each operation's session token in the SessionContainer + // so that subsequent Session-consistency reads on the affected collections can use the latest token + // without getting ReadSessionNotAvailable. + // + // DTC spans multiple collections so the server embeds per-operation session + // tokens in the JSON body; those are already parsed into DistributedTransactionOperationResult.SessionToken, + // but we must explicitly push them into the SessionContainer. + + if (response == null || response.Count == 0 || serverRequest == null || sessionContainer == null) + { + return; + } + + RequestNameValueCollection headers = new RequestNameValueCollection(); + + for (int i = 0; i < response.Count; i++) + { + DistributedTransactionOperationResult result = response[i]; + DistributedTransactionOperation operation = serverRequest.Operations[result.Index]; + + if (string.IsNullOrEmpty(result.SessionToken) || string.IsNullOrEmpty(operation.CollectionResourceId)) + { + continue; + } + + if (result.StatusCode == HttpStatusCode.NotFound + && result.SubStatusCode == SubStatusCodes.ReadSessionNotAvailable) + { + continue; + } + + // Note: each SetSessionToken call acquires a write lock on the SessionContainer. + // For a future optimization, consider a batch-update API on ISessionContainer to + // reduce lock acquisitions when multiple operations target the same collection. + headers.Clear(); + headers[HttpConstants.HttpHeaders.SessionToken] = result.SessionToken; + + sessionContainer.SetSessionToken( + operation.CollectionResourceId, + DistributedTransactionConstants.GetCollectionFullName(operation.Database, operation.Container), + headers); + } + } + private Task AbortTransactionAsync(CancellationToken cancellationToken) { + // TODO: Implement abort for the two-phase commit path. throw new NotImplementedException(); } } diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitterUtils.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitterUtils.cs index 8647dd904e..23c3174bee 100644 --- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitterUtils.cs +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitterUtils.cs @@ -20,7 +20,7 @@ public static async Task ResolveCollectionRidsAsync( CancellationToken cancellationToken) { IEnumerable> groupedOperations = operations - .GroupBy(op => $"/dbs/{op.Database}/colls/{op.Container}"); + .GroupBy(op => DistributedTransactionConstants.GetCollectionFullName(op.Database, op.Container)); foreach (IGrouping group in groupedOperations) { diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionConstants.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionConstants.cs index 218b1f2c9f..81896c0b8e 100644 --- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionConstants.cs +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionConstants.cs @@ -13,5 +13,10 @@ public static bool IsDistributedTransactionRequest(OperationType operationType, return operationType == OperationType.CommitDistributedTransaction && resourceType == ResourceType.DistributedTransactionBatch; } + + internal static string GetCollectionFullName(string database, string container) + { + return $"dbs/{database}/colls/{container}"; + } } } diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs index adedceabe5..6e32f3c9f6 100644 --- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs @@ -12,6 +12,7 @@ namespace Microsoft.Azure.Cosmos using System.Text.Json; using System.Threading; using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Core.Trace; using Microsoft.Azure.Cosmos.Tracing; using Microsoft.Azure.Documents; @@ -218,6 +219,10 @@ internal static async Task FromResponseMessageAs // Validate results count matches operations count if (response.results == null || response.results.Count != serverRequest.Operations.Count) { + DefaultTrace.TraceWarning( + $"DTC response: result count ({response.results?.Count ?? 0}) differs from " + + $"operation count ({serverRequest.Operations.Count})."); + if (responseMessage.IsSuccessStatusCode) { // Server should guarantee results count equals operations count on success diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/DistributedTransaction/DistributedTransactionTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/DistributedTransaction/DistributedTransactionTests.cs index 60e621ddba..868fb32c38 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/DistributedTransaction/DistributedTransactionTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/DistributedTransaction/DistributedTransactionTests.cs @@ -729,6 +729,61 @@ public async Task UpsertItemStream_ValidDocument_SerializedAsUpsertOperation() response.Dispose(); } + // Session token handling + + [TestMethod] + [Description("Session tokens returned in DTC operation responses are merged into the client's session container, preventing ReadSessionNotAvailable errors on subsequent reads.")] + public async Task ValidateSessionTokenMergedIntoDtcClient() + { + ToDoActivity seedDoc = ToDoActivity.CreateRandomToDoActivity(); + ItemResponse seedResponse = await this.container.CreateItemAsync(seedDoc, new PartitionKey(seedDoc.pk), cancellationToken: this.cancellationToken); + + string validSessionToken = seedResponse.Headers.Session; + Assert.IsFalse(string.IsNullOrEmpty(validSessionToken), "A valid session token must be obtained from the emulator for this test to be meaningful."); + + string dtcMockResponse = $@"{{""operationResponses"":[{{""index"":0,""statusCode"":201,""sessionToken"":""{validSessionToken}""}}]}}"; + + DistributedTransactionMockHandler handler = new DistributedTransactionMockHandler( + request => Task.FromResult(this.BuildMockResponse(HttpStatusCode.OK, dtcMockResponse))); + + using CosmosClient dtcClient = TestCommon.CreateCosmosClient( + clientOptions: new CosmosClientOptions + { + CustomHandlers = { handler }, + ConnectionMode = ConnectionMode.Gateway, + ConsistencyLevel = Cosmos.ConsistencyLevel.Session, + }); + + ToDoActivity newDoc = ToDoActivity.CreateRandomToDoActivity(); + DistributedTransactionResponse dtcResponse = await dtcClient + .CreateDistributedWriteTransaction() + .CreateItem(this.database.Id, this.container.Id, new PartitionKey(newDoc.pk), newDoc) + .CommitTransactionAsync(this.cancellationToken); + + Assert.IsTrue(dtcResponse.IsSuccessStatusCode, "The simulated DTC commit should appear successful to the client."); + + Container dtcContainer = dtcClient.GetContainer(this.database.Id, this.container.Id); + try + { + ItemResponse readResponse = await dtcContainer.ReadItemAsync( + seedDoc.id, + new PartitionKey(seedDoc.pk), + new ItemRequestOptions { ConsistencyLevel = Cosmos.ConsistencyLevel.Session }, + cancellationToken: this.cancellationToken); + + Assert.AreEqual(HttpStatusCode.OK, readResponse.StatusCode, "A Session-consistency read after a DTC commit should return 200 OK."); + } + catch (CosmosException ex) when (ex.StatusCode == HttpStatusCode.NotFound) + { + Assert.AreNotEqual( + (int)SubStatusCodes.ReadSessionNotAvailable, + ex.SubStatusCode, + "A Session-consistency read after a DTC commit must not fail with " + + "ReadSessionNotAvailable (404/1002). This indicates that session token " + + "merging in DistributedTransactionCommitter is broken."); + } + } + // Helpers private void ValidateValueKind(JsonElement operation, string property, JsonValueKind expectedValueKind, int operationIndex, bool isRequired) diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionCommitterTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionCommitterTests.cs new file mode 100644 index 0000000000..bb0a6d60c2 --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionCommitterTests.cs @@ -0,0 +1,349 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Tests.DistributedTransaction +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.Linq; + using System.Net; + using System.Text; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Common; + using Microsoft.Azure.Cosmos.Tests; + using Microsoft.Azure.Cosmos.Tracing; + using Microsoft.Azure.Documents; + using Microsoft.Azure.Documents.Collections; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using Moq; + using PartitionKey = Microsoft.Azure.Cosmos.PartitionKey; + + [TestClass] + public class DistributedTransactionCommitterTests + { + private const string DatabaseName = "testdb"; + private const string ContainerName = "testcontainer"; + + private static readonly string CollectionResourceId = + ResourceId.NewDocumentCollectionId(42, 129).DocumentCollectionId.ToString(); + + [TestMethod] + [Description("Verifies that when the DTC response carries a session token, the token is merged into the SessionContainer")] + public async Task CommitTransactionAsync_MergesSessionTokensIntoSessionContainer() + { + const string sessionToken = "0:1#9#4=8#5=7"; + + SessionContainer sessionContainer = new SessionContainer("testhost"); + + string responseJson = BuildDtcResponseJson( + new[] { (statusCode: 201, sessionToken: sessionToken) }); + + Mock mockContext = this.CreateMockContext( + sessionContainer, + responseContent: responseJson, + statusCode: HttpStatusCode.OK); + + List operations = new List + { + new DistributedTransactionOperation( + OperationType.Create, + operationIndex: 0, + DatabaseName, + ContainerName, + new PartitionKey("pk1")) + }; + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter( + operations, mockContext.Object); + + await committer.CommitTransactionAsync(CancellationToken.None); + + string storedToken = sessionContainer.GetSessionToken(DistributedTransactionConstants.GetCollectionFullName(DatabaseName, ContainerName)); + Assert.AreEqual(sessionToken, storedToken, + "Session token should be merged into SessionContainer after a successful DTC commit."); + } + + [TestMethod] + [Description("When a per-operation session token is absent, SetSessionToken is NOT called for that operation and the SessionContainer is not updated")] + public async Task CommitTransactionAsync_SkipsMerge_WhenSessionTokenIsNull() + { + // sessionToken: null omits the field from the JSON body entirely + string responseJson = BuildDtcResponseJson(new[] { (statusCode: 201, sessionToken: (string)null) }); + + SessionContainer sessionContainer = new SessionContainer("testhost"); + Mock mockContext = this.CreateMockContext( + sessionContainer, + responseContent: responseJson, + statusCode: HttpStatusCode.OK); + + List operations = new List + { + new DistributedTransactionOperation( + OperationType.Create, + operationIndex: 0, + DatabaseName, + ContainerName, + new PartitionKey("pk1")) + }; + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter( + operations, mockContext.Object); + + await committer.CommitTransactionAsync(CancellationToken.None); + + string storedToken = sessionContainer.GetSessionToken(DistributedTransactionConstants.GetCollectionFullName(DatabaseName, ContainerName)); + Assert.IsTrue(string.IsNullOrEmpty(storedToken), + "SessionContainer should not be updated when the operation result has no session token."); + } + + [TestMethod] + [Description("Verifies that the correct collectionRid and collectionFullname are passed to SetSessionToken for each operation")] + public async Task CommitTransactionAsync_PassesCorrectCollectionToSetSessionToken() + { + const string sessionToken = "0:1#5#4=3"; + const string container2 = "testcontainer2"; + + string collectionRid1 = ResourceId.NewDocumentCollectionId(42, 129).DocumentCollectionId.ToString(); + string collectionRid2 = ResourceId.NewDocumentCollectionId(42, 200).DocumentCollectionId.ToString(); + + Mock mockSessionContainer = new Mock(); + + MockDocumentClient documentClient = new MockDocumentClient(); + documentClient.sessionContainer = mockSessionContainer.Object; + + ContainerProperties containerProperties1 = ContainerProperties.CreateWithResourceId(collectionRid1); + containerProperties1.PartitionKeyPath = "/pk"; + ContainerProperties containerProperties2 = ContainerProperties.CreateWithResourceId(collectionRid2); + containerProperties2.PartitionKeyPath = "/pk"; + + Mock mockContext = new Mock(); + mockContext.Setup(c => c.DocumentClient).Returns(documentClient); + mockContext.Setup(c => c.SerializerCore).Returns(MockCosmosUtil.Serializer); + mockContext + .Setup(c => c.GetCachedContainerPropertiesAsync( + DistributedTransactionConstants.GetCollectionFullName(DatabaseName, ContainerName), + It.IsAny(), It.IsAny())) + .ReturnsAsync(containerProperties1); + mockContext + .Setup(c => c.GetCachedContainerPropertiesAsync( + DistributedTransactionConstants.GetCollectionFullName(DatabaseName, container2), + It.IsAny(), It.IsAny())) + .ReturnsAsync(containerProperties2); + + ResponseMessage responseMessage = new ResponseMessage(HttpStatusCode.OK); + responseMessage.Content = new MemoryStream( + Encoding.UTF8.GetBytes(BuildDtcResponseJson( + new[] + { + (statusCode: 200, sessionToken: sessionToken), + (statusCode: 200, sessionToken: sessionToken), + }))); + mockContext.Setup(c => c.ProcessResourceOperationStreamAsync( + It.IsAny(), + ResourceType.DistributedTransactionBatch, + OperationType.CommitDistributedTransaction, + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(responseMessage); + + List operations = new List + { + new DistributedTransactionOperation( + OperationType.Create, + operationIndex: 0, + DatabaseName, + ContainerName, + new PartitionKey("pk1")), + new DistributedTransactionOperation( + OperationType.Create, + operationIndex: 1, + DatabaseName, + container2, + new PartitionKey("pk2")), + }; + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter( + operations, mockContext.Object); + + await committer.CommitTransactionAsync(CancellationToken.None); + + // Verify SetSessionToken was called once per operation with the correct collection identity. + mockSessionContainer.Verify( + s => s.SetSessionToken( + collectionRid1, + DistributedTransactionConstants.GetCollectionFullName(DatabaseName, ContainerName), + It.Is(h => h[HttpConstants.HttpHeaders.SessionToken] == sessionToken)), + Times.Once, + "SetSessionToken should be called for the first operation with its collection RID and fullname."); + + mockSessionContainer.Verify( + s => s.SetSessionToken( + collectionRid2, + DistributedTransactionConstants.GetCollectionFullName(DatabaseName, container2), + It.Is(h => h[HttpConstants.HttpHeaders.SessionToken] == sessionToken)), + Times.Once, + "SetSessionToken should be called for the second operation with its collection RID and fullname."); + } + + [TestMethod] + [Description("Verifies that 404/1002 (ReadSessionNotAvailable) operation results are excluded from session token merging")] + public async Task CommitTransactionAsync_SkipsMerge_When404ReadSessionNotAvailable() + { + const string sessionToken = "0:1#9#4=8#5=7"; + const int readSessionNotAvailableSubStatus = 1002; + + SessionContainer sessionContainer = new SessionContainer("testhost"); + + Mock mockContext = this.CreateMockContext( + sessionContainer, + responseContent: BuildDtcResponseJson(new[] { (statusCode: 404, subStatusCode: (int?)readSessionNotAvailableSubStatus, sessionToken: sessionToken) }), + statusCode: HttpStatusCode.NotFound); + + List operations = new List + { + new DistributedTransactionOperation( + OperationType.Create, + operationIndex: 0, + DatabaseName, + ContainerName, + new PartitionKey("pk1")) + }; + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter( + operations, mockContext.Object); + + await committer.CommitTransactionAsync(CancellationToken.None); + + string storedToken = sessionContainer.GetSessionToken(DistributedTransactionConstants.GetCollectionFullName(DatabaseName, ContainerName)); + Assert.IsTrue(string.IsNullOrEmpty(storedToken), + "Session token should NOT be merged for 404/ReadSessionNotAvailable operation results."); + } + + [TestMethod] + [Description("Verifies that session tokens are still merged into the SessionContainer even when the DTC response indicates a failure")] + public async Task CommitTransactionAsync_MergesSessionTokens_OnFailureResponse() + { + // Deliberately distinct from the success-path token so a copy-paste regression would be caught. + const string sessionToken = "0:1#3#4=2#5=1"; + + SessionContainer sessionContainer = new SessionContainer("testhost"); + + Mock mockContext = this.CreateMockContext( + sessionContainer, + responseContent: BuildDtcResponseJson(new[] { (statusCode: 409, sessionToken: sessionToken) }), + statusCode: HttpStatusCode.Conflict); + + List operations = new List + { + new DistributedTransactionOperation( + OperationType.Create, + operationIndex: 0, + DatabaseName, + ContainerName, + new PartitionKey("pk1")) + }; + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter( + operations, mockContext.Object); + + DistributedTransactionResponse response = await committer.CommitTransactionAsync(CancellationToken.None); + Assert.IsNotNull(response); + + string storedToken = sessionContainer.GetSessionToken(DistributedTransactionConstants.GetCollectionFullName(DatabaseName, ContainerName)); + Assert.AreEqual(sessionToken, storedToken, + "Session token should still be merged even when the DTC response indicates a failure."); + } + + // ─── Helpers ─────────────────────────────────────────────────────────── + + private static string BuildDtcResponseJson( + (int statusCode, string sessionToken)[] operations) + { + return BuildDtcResponseJson( + operations.Select(o => (o.statusCode, subStatusCode: (int?)null, o.sessionToken)).ToArray()); + } + + private static string BuildDtcResponseJson( + (int statusCode, int? subStatusCode, string sessionToken)[] operations) + { + StringBuilder sb = new StringBuilder(); + sb.Append(@"{""operationResponses"":["); + for (int i = 0; i < operations.Length; i++) + { + if (i > 0) + { + sb.Append(','); + } + + sb.Append($@"{{""index"":{i},""statuscode"":{operations[i].statusCode}"); + if (operations[i].subStatusCode.HasValue) + { + sb.Append($@",""substatuscode"":{operations[i].subStatusCode.Value}"); + } + + if (operations[i].sessionToken != null) + { + sb.Append($@",""sessionToken"":""{operations[i].sessionToken}"""); + } + + sb.Append('}'); + } + + sb.Append("]}"); + return sb.ToString(); + } + + private Mock CreateMockContext( + ISessionContainer sessionContainer, + string responseContent, + HttpStatusCode statusCode) + { + MockDocumentClient documentClient = new MockDocumentClient(); + documentClient.sessionContainer = sessionContainer; + + ContainerProperties containerProperties = ContainerProperties.CreateWithResourceId(CollectionResourceId); + containerProperties.Id = "TestContainerId"; + containerProperties.PartitionKeyPath = "/pk"; + + Mock mockContext = new Mock(); + mockContext.Setup(c => c.DocumentClient).Returns(documentClient); + mockContext.Setup(c => c.SerializerCore).Returns(MockCosmosUtil.Serializer); + mockContext.Setup(c => c.GetCachedContainerPropertiesAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(containerProperties); + + ResponseMessage responseMessage = new ResponseMessage(statusCode); + if (responseContent != null) + { + responseMessage.Content = new MemoryStream(Encoding.UTF8.GetBytes(responseContent)); + } + + mockContext.Setup(c => c.ProcessResourceOperationStreamAsync( + It.IsAny(), + ResourceType.DistributedTransactionBatch, + OperationType.CommitDistributedTransaction, + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(responseMessage); + + return mockContext; + } + } +} diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionSerializerTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionSerializerTests.cs index 0f0b851813..1a19ce8312 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionSerializerTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionSerializerTests.cs @@ -11,6 +11,7 @@ namespace Microsoft.Azure.Cosmos.Tests using System.Text.Json; using System.Threading; using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Common; using Microsoft.Azure.Cosmos.Tracing; using Microsoft.Azure.Documents; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -436,8 +437,15 @@ private Mock BuildContextSetup() ContainerProperties containerProps = ContainerProperties.CreateWithResourceId("ccZ1ANCszwk="); containerProps.PartitionKeyPath = "/pk"; + MockDocumentClient documentClient = new MockDocumentClient(); + documentClient.sessionContainer = new SessionContainer("testhost"); + Mock contextMock = new Mock(); + contextMock + .Setup(c => c.DocumentClient) + .Returns(documentClient); + contextMock .Setup(c => c.SerializerCore) .Returns(MockCosmosUtil.Serializer); diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedWriteTransactionTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedWriteTransactionTests.cs index dc8d0f2e0b..69e3444959 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedWriteTransactionTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedWriteTransactionTests.cs @@ -194,6 +194,45 @@ public async Task CommitAsync_SetsIdempotencyTokenHeader() Assert.IsTrue(Guid.TryParse(capturedToken, out _), "Idempotency token must be a valid GUID."); } + [TestMethod] + [Description("The idempotency token echoed back in the server response header is surfaced on the DistributedTransactionResponse.")] + public async Task CommitAsync_ResponseContainsIdempotencyToken() + { + Mock contextMock = this.BuildContextSetup(); + contextMock + .Setup(c => c.ProcessResourceOperationStreamAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns, ITrace, CancellationToken>( + (uri, resType, opType, opts, container, pk, itemId, stream, enricher, trace, ct) => + { + // Capture the outgoing idempotency token and echo it back, simulating server behavior. + RequestMessage req = new RequestMessage(); + enricher?.Invoke(req); + string token = req.Headers[HttpConstants.HttpHeaders.IdempotencyToken] + ?? Guid.NewGuid().ToString(); + + ResponseMessage response = BuildSuccessResponse(1); + response.Headers[HttpConstants.HttpHeaders.IdempotencyToken] = token; + return Task.FromResult(response); + }); + + DistributedTransactionResponse response = await new DistributedWriteTransactionCore(contextMock.Object) + .CreateItem(Database, Container, new PartitionKey("pk"), new TestItem()) + .CommitTransactionAsync(CancellationToken.None); + + Assert.AreNotEqual(Guid.Empty, response.IdempotencyToken, "Response must carry the idempotency token."); + } + [TestMethod] public async Task CommitAsync_OperationIndexIsZeroBasedAndOrdered() { @@ -371,8 +410,14 @@ private Mock BuildContextSetup() ContainerProperties containerProps = ContainerProperties.CreateWithResourceId("ccZ1ANCszwk="); containerProps.PartitionKeyPath = "/pk"; + MockDocumentClient documentClient = new MockDocumentClient(); + Mock contextMock = new Mock(); + contextMock + .Setup(c => c.DocumentClient) + .Returns(documentClient); + contextMock .Setup(c => c.SerializerCore) .Returns(MockCosmosUtil.Serializer);