diff --git a/sdk/eventgrid/Azure.Messaging.EventGrid/src/Customization/CloudEventRequestContent.cs b/sdk/eventgrid/Azure.Messaging.EventGrid/src/Customization/CloudEventRequestContent.cs new file mode 100644 index 000000000000..05de0df53faf --- /dev/null +++ b/sdk/eventgrid/Azure.Messaging.EventGrid/src/Customization/CloudEventRequestContent.cs @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.Pipeline; + +namespace Azure.Messaging.EventGrid +{ + internal class CloudEventRequestContent : RequestContent + { + private IEnumerable _cloudEvents; + private const string TraceParentHeaderName = "traceparent"; + private const string TraceStateHeaderName = "tracestate"; + private byte[] _data; + + public CloudEventRequestContent(IEnumerable cloudEvents) + { + _cloudEvents = cloudEvents; + } + + public override void Dispose() + { + } + + public override bool TryComputeLength(out long length) + { + EnsureSerialized(); + length = _data.Length; + return true; + } + + public override void WriteTo(Stream stream, CancellationToken cancellationToken) + { + EnsureSerialized(); + stream.Write(_data, 0, _data.Length); + } + + public override async Task WriteToAsync(Stream stream, CancellationToken cancellationToken) + { + EnsureSerialized(); + await stream.WriteAsync(_data, 0, _data.Length, cancellationToken).ConfigureAwait(false); + } + + private void EnsureSerialized() + { + if (_data != null) + { + return; + } + + string currentActivityId = null; + string traceState = null; + Activity currentActivity = Activity.Current; + if (currentActivity != null && currentActivity.IsW3CFormat()) + { + currentActivityId = currentActivity.Id; + currentActivity.TryGetTraceState(out traceState); + } + + foreach (CloudEvent cloudEvent in _cloudEvents) + { + if (currentActivityId != null && + !cloudEvent.ExtensionAttributes.ContainsKey(TraceParentHeaderName) && + !cloudEvent.ExtensionAttributes.ContainsKey(TraceStateHeaderName)) + { + cloudEvent.ExtensionAttributes.Add(TraceParentHeaderName, currentActivityId); + if (traceState != null) + { + cloudEvent.ExtensionAttributes.Add(TraceStateHeaderName, traceState); + } + } + } + _data = JsonSerializer.SerializeToUtf8Bytes(_cloudEvents, typeof(List)); + } + } +} diff --git a/sdk/eventgrid/Azure.Messaging.EventGrid/src/Customization/EventGridPublisherClient.cs b/sdk/eventgrid/Azure.Messaging.EventGrid/src/Customization/EventGridPublisherClient.cs index fa78b262c6dd..7c71bb24eca0 100644 --- a/sdk/eventgrid/Azure.Messaging.EventGrid/src/Customization/EventGridPublisherClient.cs +++ b/sdk/eventgrid/Azure.Messaging.EventGrid/src/Customization/EventGridPublisherClient.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.ComponentModel; -using System.Diagnostics; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -28,9 +27,6 @@ public class EventGridPublisherClient private readonly HttpPipeline _pipeline; private readonly string _apiVersion; - private const string TraceParentHeaderName = "traceparent"; - private const string TraceStateHeaderName = "tracestate"; - private static readonly JsonObjectSerializer s_jsonSerializer = new JsonObjectSerializer(); /// Initalizes a new instance of the class for mocking. @@ -235,46 +231,9 @@ private async Task SendCloudEventsInternal(IEnumerable eve { // List of events cannot be null Argument.AssertNotNull(events, nameof(events)); - - string activityId = null; - string traceState = null; - Activity currentActivity = Activity.Current; - if (currentActivity != null && currentActivity.IsW3CFormat()) - { - activityId = currentActivity.Id; - currentActivity.TryGetTraceState(out traceState); - } - - foreach (CloudEvent cloudEvent in events) - { - // Individual events cannot be null - Argument.AssertNotNull(cloudEvent, nameof(cloudEvent)); - - if (activityId != null && - !cloudEvent.ExtensionAttributes.ContainsKey(TraceParentHeaderName) && - !cloudEvent.ExtensionAttributes.ContainsKey(TraceStateHeaderName)) - { - cloudEvent.ExtensionAttributes.Add(TraceParentHeaderName, activityId); - if (traceState != null) - { - cloudEvent.ExtensionAttributes.Add(TraceStateHeaderName, traceState); - } - } - } using HttpMessage message = _pipeline.CreateMessage(); Request request = CreateEventRequest(message, "application/cloudevents-batch+json; charset=utf-8"); - - BinaryData data; - if (async) - { - data = await s_jsonSerializer.SerializeAsync(events, typeof(List), cancellationToken).ConfigureAwait(false); - } - else - { - data = s_jsonSerializer.Serialize(events, typeof(List), cancellationToken); - } - - RequestContent content = RequestContent.Create(data.ToMemory()); + CloudEventRequestContent content = new CloudEventRequestContent(events); request.Content = content; if (async) diff --git a/sdk/eventgrid/Azure.Messaging.EventGrid/tests/CloudEventTests.cs b/sdk/eventgrid/Azure.Messaging.EventGrid/tests/CloudEventTests.cs index adc2ce952dc7..7cc218f6b3ed 100644 --- a/sdk/eventgrid/Azure.Messaging.EventGrid/tests/CloudEventTests.cs +++ b/sdk/eventgrid/Azure.Messaging.EventGrid/tests/CloudEventTests.cs @@ -9,9 +9,9 @@ using System.Text.Json; using System.Threading; using System.Threading.Tasks; -using Azure.Core; using Azure.Core.Pipeline; using Azure.Core.TestFramework; +using Azure.Core.Tests; using NUnit.Framework; namespace Azure.Messaging.EventGrid.Tests @@ -28,7 +28,122 @@ public class CloudEventTests [TestCase(false, true)] public async Task SetsTraceParentExtension(bool inclTraceparent, bool inclTracestate) { - var mockTransport = new MockTransport(new MockResponse(200)); + MockTransport mockTransport = CreateMockTransport(); + + var options = new EventGridPublisherClientOptions + { + Transport = mockTransport + }; + EventGridPublisherClient client = + new EventGridPublisherClient( + new Uri("http://localHost"), + new AzureKeyCredential("fakeKey"), + options); + + using ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener(s => s.StartsWith("Azure."), asyncLocal: true); + + var activity = new Activity($"{nameof(EventGridPublisherClient)}.{nameof(EventGridPublisherClient.SendEvents)}"); + activity.SetW3CFormat(); + activity.Start(); + activity.TraceStateString = "tracestatevalue"; + List eventsList = new List(); + for (int i = 0; i < 10; i++) + { + CloudEvent cloudEvent = new CloudEvent( + "record", + "Microsoft.MockPublisher.TestEvent", + JsonDocument.Parse("{\"property1\": \"abc\", \"property2\": 123}").RootElement) + { + Id = "id", + Subject = $"Subject-{i}", + Time = DateTimeOffset.UtcNow + }; + if (inclTraceparent && i % 2 == 0) + { + cloudEvent.ExtensionAttributes.Add("traceparent", "traceparentValue"); + } + if (inclTracestate && i % 2 == 0) + { + cloudEvent.ExtensionAttributes.Add("tracestate", "param:value"); + } + eventsList.Add(cloudEvent); + } + await client.SendEventsAsync(eventsList); + + // stop activity after extracting the events from the request as this is where the cloudEvents would actually + // be serialized + activity.Stop(); + + IEnumerator cloudEnum = eventsList.GetEnumerator(); + for (int i = 0; i < 10; i++) + { + cloudEnum.MoveNext(); + IDictionary cloudEventAttr = cloudEnum.Current.ExtensionAttributes; + if (inclTraceparent && inclTracestate && i % 2 == 0) + { + Assert.AreEqual( + "traceparentValue", + cloudEventAttr[TraceParentHeaderName]); + + Assert.AreEqual( + "param:value", + cloudEventAttr[TraceStateHeaderName]); + } + else if (inclTraceparent && i % 2 == 0) + { + Assert.AreEqual( + "traceparentValue", + cloudEventAttr[TraceParentHeaderName]); + } + else if (inclTracestate && i % 2 == 0) + { + Assert.AreEqual( + "param:value", + cloudEventAttr[TraceStateHeaderName]); + } + else + { + Assert.IsTrue(mockTransport.SingleRequest.Headers.TryGetValue(TraceParentHeaderName, out string requestHeader)); + + Assert.AreEqual( + requestHeader, + cloudEventAttr[TraceParentHeaderName]); + } + } + } + + private static MockTransport CreateMockTransport() + { + return new MockTransport((request) => + { + var stream = new MemoryStream(); + request.Content.WriteTo(stream, CancellationToken.None); + return new MockResponse(200); + }); + } + + [Test] + [TestCase(false, false)] + [TestCase(true, true)] + [TestCase(true, false)] + [TestCase(false, true)] + public async Task SetsTraceParentExtensionRetries(bool inclTraceparent, bool inclTracestate) + { + int requestCt = 0; + var mockTransport = new MockTransport((request) => + { + var stream = new MemoryStream(); + request.Content.WriteTo(stream, CancellationToken.None); + if (requestCt++ == 0) + { + return new MockResponse(500); + } + else + { + return new MockResponse(200); + } + }); + var options = new EventGridPublisherClientOptions { Transport = mockTransport @@ -38,9 +153,12 @@ public async Task SetsTraceParentExtension(bool inclTraceparent, bool inclTraces new Uri("http://localHost"), new AzureKeyCredential("fakeKey"), options); + using ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener(s => s.StartsWith("Azure."), asyncLocal: true); + var activity = new Activity($"{nameof(EventGridPublisherClient)}.{nameof(EventGridPublisherClient.SendEvents)}"); activity.SetW3CFormat(); activity.Start(); + activity.TraceStateString = "tracestatevalue"; List eventsList = new List(); for (int i = 0; i < 10; i++) { @@ -65,41 +183,48 @@ public async Task SetsTraceParentExtension(bool inclTraceparent, bool inclTraces } await client.SendEventsAsync(eventsList); + // stop activity after extracting the events from the request as this is where the cloudEvents would actually + // be serialized activity.Stop(); - List cloudEvents = DeserializeRequest(mockTransport.SingleRequest); + IEnumerator cloudEnum = eventsList.GetEnumerator(); - foreach (CloudEvent cloudEvent in cloudEvents) + for (int i = 0; i < 10; i++) { cloudEnum.MoveNext(); IDictionary cloudEventAttr = cloudEnum.Current.ExtensionAttributes; - if (cloudEventAttr.ContainsKey(TraceParentHeaderName) && - cloudEventAttr.ContainsKey(TraceStateHeaderName)) + if (inclTraceparent && inclTracestate && i % 2 == 0) { Assert.AreEqual( - cloudEventAttr[TraceParentHeaderName], - cloudEvent.ExtensionAttributes[TraceParentHeaderName]); + "traceparentValue", + cloudEventAttr[TraceParentHeaderName]); Assert.AreEqual( - cloudEventAttr[TraceStateHeaderName], - cloudEvent.ExtensionAttributes[TraceStateHeaderName]); + "param:value", + cloudEventAttr[TraceStateHeaderName]); } - else if (cloudEventAttr.ContainsKey(TraceParentHeaderName)) + else if (inclTraceparent && i % 2 == 0) { Assert.AreEqual( - cloudEventAttr[TraceParentHeaderName], - cloudEvent.ExtensionAttributes[TraceParentHeaderName]); + "traceparentValue", + cloudEventAttr[TraceParentHeaderName]); } - else if (cloudEventAttr.ContainsKey(TraceStateHeaderName)) + else if (inclTracestate && i % 2 == 0) { Assert.AreEqual( - cloudEventAttr[TraceStateHeaderName], - cloudEvent.ExtensionAttributes[TraceStateHeaderName]); + "param:value", + cloudEventAttr[TraceStateHeaderName]); } else { + Assert.IsTrue(mockTransport.Requests[1].Headers.TryGetValue(TraceParentHeaderName, out string traceParent)); Assert.AreEqual( - activity.Id, - cloudEvent.ExtensionAttributes[TraceParentHeaderName]); + traceParent, + cloudEventAttr[TraceParentHeaderName]); + + Assert.IsTrue(mockTransport.Requests[1].Headers.TryGetValue(TraceStateHeaderName, out string traceState)); + Assert.AreEqual( + traceState, + cloudEventAttr[TraceStateHeaderName]); } } } @@ -107,7 +232,7 @@ public async Task SetsTraceParentExtension(bool inclTraceparent, bool inclTraces [Test] public async Task SerializesExpectedProperties_BaseType() { - var mockTransport = new MockTransport(new MockResponse(200)); + var mockTransport = CreateMockTransport(); var options = new EventGridPublisherClientOptions { Transport = mockTransport @@ -137,14 +262,13 @@ public async Task SerializesExpectedProperties_BaseType() await client.SendEventsAsync(eventsList); - cloudEvent = DeserializeRequest(mockTransport.SingleRequest).First(); Assert.IsNull(cloudEvent.Data.ToObjectFromJson().DerivedProperty); } [Test] public async Task SerializesExpectedProperties_DerivedType() { - var mockTransport = new MockTransport(new MockResponse(200)); + var mockTransport = CreateMockTransport(); var options = new EventGridPublisherClientOptions { Transport = mockTransport @@ -173,18 +297,7 @@ public async Task SerializesExpectedProperties_DerivedType() await client.SendEventsAsync(eventsList); - cloudEvent = DeserializeRequest(mockTransport.SingleRequest).First(); Assert.AreEqual(5, cloudEvent.Data.ToObjectFromJson().DerivedProperty); } - - private static List DeserializeRequest(Request request) - { - var stream = new MemoryStream(); - request.Content.WriteTo(stream, CancellationToken.None); - stream.Position = 0; - using var reader = new StreamReader(stream); - CloudEvent[] cloudEvents = CloudEvent.ParseEvents(reader.ReadToEnd()); - return cloudEvents.ToList(); - } } }