Skip to content

Commit

Permalink
Initial distributed tracing implementation for .NET (#257)
Browse files Browse the repository at this point in the history
  • Loading branch information
ReubenBond authored Jul 25, 2024
1 parent b466259 commit e7ac11b
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 17 deletions.
6 changes: 4 additions & 2 deletions dotnet/samples/Greeter/Greeter.AgentWorker/AgentClient.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
using System.Diagnostics;
using Agents;
using Microsoft.AI.Agents.Worker.Client;
using AgentId = Microsoft.AI.Agents.Worker.Client.AgentId;

namespace Greeter.AgentWorker;

public sealed class AgentClient(ILogger<AgentClient> logger, AgentWorkerRuntime runtime) : AgentBase(new ClientContext(logger, runtime))
public sealed class AgentClient(ILogger<AgentClient> logger, AgentWorkerRuntime runtime, DistributedContextPropagator distributedContextPropagator) : AgentBase(new ClientContext(logger, runtime, distributedContextPropagator))
{
public async ValueTask PublishEventAsync(Event @event) => await PublishEvent(@event);
public async ValueTask<RpcResponse> SendRequestAsync(AgentId target, string method, Dictionary<string, string> parameters) => await RequestAsync(target, method, parameters);

private sealed class ClientContext(ILogger<AgentClient> logger, AgentWorkerRuntime runtime) : IAgentContext
private sealed class ClientContext(ILogger<AgentClient> logger, AgentWorkerRuntime runtime, DistributedContextPropagator distributedContextPropagator) : IAgentContext
{
public AgentId AgentId { get; } = new AgentId("client", Guid.NewGuid().ToString());
public AgentBase? AgentInstance { get; set; }
public ILogger Logger { get; } = logger;
public DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator;

public async ValueTask PublishEventAsync(Event @event)
{
Expand Down
165 changes: 156 additions & 9 deletions dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
using System.Threading.Channels;
using Microsoft.Extensions.Logging;
using System.Text.Json;
using System.Diagnostics;

namespace Microsoft.AI.Agents.Worker.Client;

public abstract class AgentBase
{
private static readonly ActivitySource s_source = new("Starfleet.Agent");
private readonly object _lock = new();
private readonly Dictionary<string, TaskCompletionSource<RpcResponse>> _pendingRequests = [];
private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
Expand Down Expand Up @@ -77,10 +79,24 @@ private async Task HandleRpcMessage(Message msg)
switch (msg.MessageCase)
{
case Message.MessageOneofCase.Event:
await HandleEvent(msg.Event).ConfigureAwait(false);
{
var activity = ExtractActivity(msg.Event.Type, msg.Event.Metadata);
await InvokeWithActivityAsync(
static ((AgentBase Agent, Event Item) state) => state.Agent.HandleEvent(state.Item),
(this, msg.Event),
activity,
msg.Event.Type).ConfigureAwait(false);
}
break;
case Message.MessageOneofCase.Request:
await OnRequestCore(msg.Request).ConfigureAwait(false);
{
var activity = ExtractActivity(msg.Request.Method, msg.Request.Metadata);
await InvokeWithActivityAsync(
static ((AgentBase Agent, RpcRequest Request) state) => state.Agent.OnRequestCore(state.Request),
(this, msg.Request),
activity,
msg.Request.Method).ConfigureAwait(false);
}
break;
case Message.MessageOneofCase.Response:
OnResponseCore(msg.Response);
Expand All @@ -103,7 +119,7 @@ private void OnResponseCore(RpcResponse response)
completion.SetResult(response);
}

private async ValueTask OnRequestCore(RpcRequest request)
private async Task OnRequestCore(RpcRequest request)
{
RpcResponse response;

Expand All @@ -130,22 +146,153 @@ protected async Task<RpcResponse> RequestAsync(AgentId target, string method, Di
Data = JsonSerializer.Serialize(parameters)
};

var activity = s_source.StartActivity($"Call '{method}'", ActivityKind.Client, Activity.Current?.Context ?? default);
activity?.SetTag("peer.service", target.ToString());

var completion = new TaskCompletionSource<RpcResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
lock (_lock)
{
_pendingRequests[requestId] = completion;
}
Context.DistributedContextPropagator.Inject(activity, request.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
await InvokeWithActivityAsync(
static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResponse>) state) =>
{
var (self, request, completion) = state;

lock (self._lock)
{
self._pendingRequests[request.RequestId] = completion;
}

await _context.SendRequestAsync(this, request).ConfigureAwait(false);
await state.Agent._context.SendRequestAsync(state.Agent, state.Request).ConfigureAwait(false);

await completion.Task.ConfigureAwait(false);
},
(this, request, completion),
activity,
method).ConfigureAwait(false);

// Return the result from the already-completed task
return await completion.Task.ConfigureAwait(false);
}

protected async ValueTask PublishEvent(Event item)
{
await _context.PublishEventAsync(item).ConfigureAwait(false);
var activity = s_source.StartActivity($"PublishEvent '{item.Type}'", ActivityKind.Client, Activity.Current?.Context ?? default);
activity?.SetTag("peer.service", $"{item.Type}/{item.Namespace}");

var completion = new TaskCompletionSource<RpcResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
Context.DistributedContextPropagator.Inject(activity, item.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
await InvokeWithActivityAsync(
static async ((AgentBase Agent, Event Event, TaskCompletionSource<RpcResponse>) state) =>
{
await state.Agent._context.PublishEventAsync(state.Event).ConfigureAwait(false);
},
(this, item, completion),
activity,
item.Type).ConfigureAwait(false);
}

protected virtual Task<RpcResponse> HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });

protected virtual Task HandleEvent(Event item) => Task.CompletedTask;

protected async Task InvokeWithActivityAsync<TState>(Func<TState, Task> func, TState state, Activity? activity, string methodName)
{
if (activity is not null)
{
activity.Start();

// rpc attributes from https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/rpc.md
activity.SetTag("rpc.system", "starfleet");
activity.SetTag("rpc.service", AgentId.ToString());
activity.SetTag("rpc.method", methodName);
}

try
{
await func(state).ConfigureAwait(false);
if (activity is not null && activity.IsAllDataRequested)
{
activity.SetStatus(ActivityStatusCode.Ok);
}
}
catch (Exception e)
{
if (activity is not null && activity.IsAllDataRequested)
{
activity.SetStatus(ActivityStatusCode.Error);

// exception attributes from https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/exceptions.md
activity.SetTag("exception.type", e.GetType().FullName);
activity.SetTag("exception.message", e.Message);

// Note that "exception.stacktrace" is the full exception detail, not just the StackTrace property.
// See https://opentelemetry.io/docs/specs/semconv/attributes-registry/exception/
// and https://github.com/open-telemetry/opentelemetry-specification/pull/697#discussion_r453662519
activity.SetTag("exception.stacktrace", e.ToString());
activity.SetTag("exception.escaped", true);
}

throw;
}
finally
{
activity?.Stop();
}
}

private Activity? ExtractActivity(string activityName, IDictionary<string, string> metadata)
{
Activity? activity = null;
Context.DistributedContextPropagator.ExtractTraceIdAndState(metadata,
static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
{
var metadata = (IDictionary<string, string>)carrier!;
fieldValues = null;
metadata.TryGetValue(fieldName, out fieldValue);
},
out var traceParent,
out var traceState);

if (!string.IsNullOrEmpty(traceParent))
{
if (ActivityContext.TryParse(traceParent, traceState, isRemote: true, out ActivityContext parentContext))
{
// traceParent is a W3CId
activity = s_source.CreateActivity(activityName, ActivityKind.Server, parentContext);
}
else
{
// Most likely, traceParent uses ActivityIdFormat.Hierarchical
activity = s_source.CreateActivity(activityName, ActivityKind.Server, traceParent);
}

if (activity is not null)
{
if (!string.IsNullOrEmpty(traceState))
{
activity.TraceStateString = traceState;
}

var baggage = Context.DistributedContextPropagator.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
{
var metadata = (IDictionary<string, string>)carrier!;
fieldValues = null;
metadata.TryGetValue(fieldName, out fieldValue);
});

if (baggage is not null)
{
foreach (var baggageItem in baggage)
{
activity.AddBaggage(baggageItem.Key, baggageItem.Value);
}
}
}
}
else
{
activity = s_source.CreateActivity(activityName, ActivityKind.Server);
}

return activity;
}
}
6 changes: 4 additions & 2 deletions dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentContext.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
using Agents;
using Agents;
using RpcEvent = Agents.Event;
using Microsoft.Extensions.Logging;
using System.Diagnostics;

namespace Microsoft.AI.Agents.Worker.Client;

internal sealed class AgentContext(AgentId agentId, AgentWorkerRuntime runtime, ILogger<AgentBase> logger) : IAgentContext
internal sealed class AgentContext(AgentId agentId, AgentWorkerRuntime runtime, ILogger<AgentBase> logger, DistributedContextPropagator distributedContextPropagator) : IAgentContext
{
private readonly AgentWorkerRuntime _runtime = runtime;

public AgentId AgentId { get; } = agentId;
public ILogger Logger { get; } = logger;
public AgentBase? AgentInstance { get; set; }
public DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator;

public async ValueTask SendResponseAsync(RpcRequest request, RpcResponse response)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
using Microsoft.Extensions.DependencyInjection;
using System.Threading.Channels;
using Grpc.Net.Client.Configuration;
using System.Diagnostics;
using Microsoft.Extensions.DependencyInjection.Extensions;

namespace Microsoft.AI.Agents.Worker.Client;

Expand Down Expand Up @@ -45,6 +47,7 @@ public static AgentApplicationBuilder AddAgentWorker(this IHostApplicationBuilde
channelOptions.ThrowOperationCanceledOnCancellation = true;
});
});
builder.Services.TryAddSingleton(DistributedContextPropagator.Current);
builder.Services.AddSingleton<AgentWorkerRuntime>();
builder.Services.AddSingleton<IHostedService>(sp => sp.GetRequiredService<AgentWorkerRuntime>());
return new AgentApplicationBuilder(builder);
Expand Down Expand Up @@ -77,6 +80,7 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable
private readonly IServiceProvider _serviceProvider;
private readonly IEnumerable<Tuple<string, Type>> _configuredAgentTypes;
private readonly ILogger<AgentWorkerRuntime> _logger;
private readonly DistributedContextPropagator _distributedContextPropagator;
private readonly CancellationTokenSource _shutdownCts;
private AsyncDuplexStreamingCall<Message, Message>? _channel;
private Task? _readTask;
Expand All @@ -87,12 +91,14 @@ public AgentWorkerRuntime(
IHostApplicationLifetime hostApplicationLifetime,
IServiceProvider serviceProvider,
[FromKeyedServices("AgentTypes")] IEnumerable<Tuple<string, Type>> configuredAgentTypes,
ILogger<AgentWorkerRuntime> logger)
ILogger<AgentWorkerRuntime> logger,
DistributedContextPropagator distributedContextPropagator)
{
_client = client;
_serviceProvider = serviceProvider;
_configuredAgentTypes = configuredAgentTypes;
_logger = logger;
_distributedContextPropagator = distributedContextPropagator;
_shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping);
}

Expand Down Expand Up @@ -195,7 +201,7 @@ private AgentBase GetOrActivateAgent(AgentId agentId)
{
if (_agentTypes.TryGetValue(agentId.Name, out var agentType))
{
var context = new AgentContext(agentId, this, _serviceProvider.GetRequiredService<ILogger<AgentBase>>());
var context = new AgentContext(agentId, this, _serviceProvider.GetRequiredService<ILogger<AgentBase>>(), _distributedContextPropagator);
agent = (AgentBase)ActivatorUtilities.CreateInstance(_serviceProvider, agentType, context);
_agents.TryAdd((agentId.Name, agentId.Namespace), agent);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
using Agents;
using Agents;
using RpcEvent = Agents.Event;
using Microsoft.Extensions.Logging;
using System.Diagnostics;

namespace Microsoft.AI.Agents.Worker.Client;

public interface IAgentContext
{
AgentId AgentId { get; }
AgentBase? AgentInstance { get; set; }
DistributedContextPropagator DistributedContextPropagator { get; }
ILogger Logger { get; }
ValueTask SendResponseAsync(RpcRequest request, RpcResponse response);
ValueTask SendRequestAsync(AgentBase agent, RpcRequest request);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.DependencyInjection;
using Orleans.Serialization;
using Microsoft.Extensions.DependencyInjection.Extensions;
using System.Diagnostics;

namespace Microsoft.AI.Agents.Worker;

Expand All @@ -14,6 +16,7 @@ public static IHostApplicationBuilder AddAgentService(this IHostApplicationBuild

// Ensure Orleans is added before the hosted service to guarantee that it starts first.
builder.UseOrleans();
builder.Services.TryAddSingleton(DistributedContextPropagator.Current);
builder.Services.AddSingleton<WorkerGateway>();
builder.Services.AddSingleton<IHostedService>(sp => sp.GetRequiredService<WorkerGateway>());

Expand Down

0 comments on commit e7ac11b

Please sign in to comment.