Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@ namespace Azure.Messaging.EventGrid
{
internal class EventGridKeyCredentialPolicy : HttpPipelineSynchronousPolicy
{
private readonly string _name;
private readonly AzureKeyCredential _credential;
public const string SystemPublisherKey = "AzureSystemPublisher";

public EventGridKeyCredentialPolicy(AzureKeyCredential credential, string name)
public EventGridKeyCredentialPolicy(AzureKeyCredential credential)
{
Argument.AssertNotNull(credential, nameof(credential));
Argument.AssertNotNullOrEmpty(name, nameof(name));
_credential = credential;
_name = name;
}

public override void OnSendingRequest(HttpMessage message)
Expand All @@ -29,7 +26,7 @@ public override void OnSendingRequest(HttpMessage message)
// in the request in this case.
if (_credential.Key != SystemPublisherKey)
{
message.Request.Headers.SetValue(_name, _credential.Key);
message.Request.Headers.SetValue(Constants.SasKeyName, _credential.Key);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;
using Azure.Core.Serialization;
using Azure.Messaging.EventGrid.Models;

namespace Azure.Messaging.EventGrid
Expand All @@ -19,15 +18,9 @@ namespace Azure.Messaging.EventGrid
/// </summary>
public class EventGridPublisherClient
{
private readonly EventGridRestClient _serviceRestClient;
private readonly ClientDiagnostics _clientDiagnostics;
private string _hostName => _endpoint.Authority;
private readonly Uri _endpoint;
private readonly AzureKeyCredential _key;
private readonly RequestUriBuilder _uriBuilder;
private readonly HttpPipeline _pipeline;
private readonly string _apiVersion;

private static readonly JsonObjectSerializer s_jsonSerializer = new JsonObjectSerializer();

/// <summary>Initalizes a new instance of the <see cref="EventGridPublisherClient"/> class for mocking.</summary>
protected EventGridPublisherClient()
Expand All @@ -50,11 +43,10 @@ public EventGridPublisherClient(Uri endpoint, AzureKeyCredential credential, Eve
{
Argument.AssertNotNull(credential, nameof(credential));
options ??= new EventGridPublisherClientOptions();
_apiVersion = options.Version.GetVersionString();
_endpoint = endpoint;
_key = credential;
_pipeline = HttpPipelineBuilder.Build(options, new EventGridKeyCredentialPolicy(credential, Constants.SasKeyName));
_serviceRestClient = new EventGridRestClient(new ClientDiagnostics(options), _pipeline, options.Version.GetVersionString());
_uriBuilder = new RequestUriBuilder();
_uriBuilder.Reset(endpoint);
_uriBuilder.AppendQuery("api-version", options.Version.GetVersionString(), true);
_pipeline = HttpPipelineBuilder.Build(options, new EventGridKeyCredentialPolicy(credential));
_clientDiagnostics = new ClientDiagnostics(options);
}

Expand All @@ -69,9 +61,10 @@ public EventGridPublisherClient(Uri endpoint, AzureSasCredential credential, Eve
{
Argument.AssertNotNull(credential, nameof(credential));
options ??= new EventGridPublisherClientOptions();
_endpoint = endpoint;
HttpPipeline pipeline = HttpPipelineBuilder.Build(options, new EventGridSharedAccessSignatureCredentialPolicy(credential));
_serviceRestClient = new EventGridRestClient(new ClientDiagnostics(options), pipeline, options.Version.GetVersionString());
_uriBuilder = new RequestUriBuilder();
_uriBuilder.Reset(endpoint);
_uriBuilder.AppendQuery("api-version", options.Version.GetVersionString(), true);
_pipeline = HttpPipelineBuilder.Build(options, new EventGridSharedAccessSignatureCredentialPolicy(credential));
_clientDiagnostics = new ClientDiagnostics(options);
}

Expand Down Expand Up @@ -123,20 +116,6 @@ private async Task<Response> SendCloudNativeCloudEventsInternalAsync(ReadOnlyMem
}
}

private Request CreateEventRequest(HttpMessage message, string contentType)
{
Request request = message.Request;
request.Method = RequestMethod.Post;
var uri = new RawRequestUriBuilder();
uri.AppendRaw("https://", false);
uri.AppendRaw(_hostName, false);
uri.AppendPath("/api/events", false);
uri.AppendQuery("api-version", _apiVersion, true);
request.Uri = uri;
request.Headers.Add("Content-Type", contentType);
return request;
}

/// <summary> Publishes a set of EventGridEvents to an Event Grid topic. </summary>
/// <param name="eventGridEvent"> The event to be published to Event Grid. </param>
/// <param name="cancellationToken"> An optional cancellation token instance to signal the request to cancel the operation.</param>
Expand Down Expand Up @@ -177,13 +156,13 @@ private async Task<Response> SendEventsInternal(IEnumerable<EventGridEvent> even
// List of events cannot be null
Argument.AssertNotNull(events, nameof(events));

List<EventGridEventInternal> eventsWithSerializedPayloads = new List<EventGridEventInternal>();
using HttpMessage message = _pipeline.CreateMessage();
Request request = CreateEventRequest(message, "application/json");
var content = new Utf8JsonRequestContent();
content.JsonWriter.WriteStartArray();
foreach (EventGridEvent egEvent in events)
{
// Individual events cannot be null
Argument.AssertNotNull(egEvent, nameof(egEvent));
JsonDocument data = JsonDocument.Parse(egEvent.Data.ToStream());

EventGridEventInternal newEGEvent = new EventGridEventInternal(
egEvent.Id,
egEvent.Subject,
Expand All @@ -194,24 +173,27 @@ private async Task<Response> SendEventsInternal(IEnumerable<EventGridEvent> even
{
Topic = egEvent.Topic
};

eventsWithSerializedPayloads.Add(newEGEvent);
content.JsonWriter.WriteObjectValue(newEGEvent);
}

content.JsonWriter.WriteEndArray();
request.Content = content;

if (async)
{
// Publish asynchronously if called via an async path
return await _serviceRestClient.PublishEventsAsync(
_hostName,
eventsWithSerializedPayloads,
cancellationToken).ConfigureAwait(false);
await _pipeline.SendAsync(message, cancellationToken).ConfigureAwait(false);
}
else
{
return _serviceRestClient.PublishEvents(
_hostName,
eventsWithSerializedPayloads,
cancellationToken);
_pipeline.Send(message, cancellationToken);
}
return message.Response.Status switch
{
200 => message.Response,
_ => async ?
throw await _clientDiagnostics.CreateRequestFailedExceptionAsync(message.Response).ConfigureAwait(false) :
throw _clientDiagnostics.CreateRequestFailedException(message.Response)
};
}
catch (Exception e)
{
Expand Down Expand Up @@ -313,44 +295,6 @@ public virtual async Task<Response> SendEventsAsync(IEnumerable<BinaryData> cust
public virtual Response SendEvents(IEnumerable<BinaryData> customEvents, CancellationToken cancellationToken = default)
=> PublishCustomEventsInternal(customEvents, false /*async*/, cancellationToken).EnsureCompleted();

private async Task<Response> PublishCustomEventsInternal(IEnumerable<object> events, bool async, CancellationToken cancellationToken = default)
{
using DiagnosticScope scope = _clientDiagnostics.CreateScope($"{nameof(EventGridPublisherClient)}.{nameof(SendEvents)}");
scope.Start();

try
{
List<CustomModelSerializer> serializedEvents = new List<CustomModelSerializer>();
foreach (object customEvent in events)
{
serializedEvents.Add(
new CustomModelSerializer(
customEvent,
s_jsonSerializer,
cancellationToken));
}
if (async)
{
return await _serviceRestClient.PublishCustomEventEventsAsync(
_hostName,
serializedEvents,
cancellationToken).ConfigureAwait(false);
}
else
{
return _serviceRestClient.PublishCustomEventEvents(
_hostName,
serializedEvents,
cancellationToken);
}
}
catch (Exception e)
{
scope.Failed(e);
throw;
}
}

private async Task<Response> PublishCustomEventsInternal(IEnumerable<BinaryData> events, bool async, CancellationToken cancellationToken = default)
{
using DiagnosticScope scope = _clientDiagnostics.CreateScope($"{nameof(EventGridPublisherClient)}.{nameof(SendEvents)}");
Expand Down Expand Up @@ -394,5 +338,14 @@ private async Task<Response> PublishCustomEventsInternal(IEnumerable<BinaryData>
throw;
}
}

private Request CreateEventRequest(HttpMessage message, string contentType)
{
Request request = message.Request;
request.Method = RequestMethod.Post;
request.Uri = _uriBuilder;
request.Headers.Add("Content-Type", contentType);
return request;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,48 @@ public async Task CanPublishEvent()
await client.SendEventsAsync(GetEventsList());
}

[RecordedTest]
public void CannotPublishEventMissingApiEventsPathFromUri()
{
EventGridPublisherClientOptions options = InstrumentClientOptions(new EventGridPublisherClientOptions());
Uri host = new UriBuilder("https", new Uri(TestEnvironment.TopicHost).Host).Uri;
EventGridPublisherClient client = InstrumentClient(
new EventGridPublisherClient(
host,
new AzureKeyCredential(TestEnvironment.TopicKey),
options));

Assert.ThrowsAsync<RequestFailedException>(async () => await client.SendEventAsync(new BinaryData(jsonSerializable: "data")));
}

[RecordedTest]
public void CannotPublishCloudEventMissingApiEventsPathFromUri()
{
EventGridPublisherClientOptions options = InstrumentClientOptions(new EventGridPublisherClientOptions());
Uri host = new UriBuilder("https", new Uri(TestEnvironment.TopicHost).Host).Uri;
EventGridPublisherClient client = InstrumentClient(
new EventGridPublisherClient(
host,
new AzureKeyCredential(TestEnvironment.TopicKey),
options));

Assert.ThrowsAsync<RequestFailedException>(async () => await client.SendEventAsync(new BinaryData(jsonSerializable: "data")));
}

[RecordedTest]
public void CannotPublishCustomEventMissingApiEventsPathFromUri()
{
EventGridPublisherClientOptions options = InstrumentClientOptions(new EventGridPublisherClientOptions());
Uri host = new UriBuilder("https", new Uri(TestEnvironment.TopicHost).Host).Uri;
EventGridPublisherClient client = InstrumentClient(
new EventGridPublisherClient(
host,
new AzureKeyCredential(TestEnvironment.TopicKey),
options));

Assert.ThrowsAsync<RequestFailedException>(async () => await client.SendEventAsync(new BinaryData(jsonSerializable: "data")));
}

[RecordedTest]
public async Task CanPublishSingleEvent()
{
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading