Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;
using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Agents.AI;

/// <summary>
/// Contains extension methods to allow storing and retrieving properties using the type name of the property as the key.
/// </summary>
public static class AdditionalPropertiesExtensions
{
/// <summary>
/// Adds an additional property using the type name of the property as the key.
/// </summary>
/// <typeparam name="T">The type of the property to add.</typeparam>
/// <param name="additionalProperties">The dictionary of additional properties.</param>
/// <param name="value">The value to add.</param>
public static void Add<T>(this AdditionalPropertiesDictionary additionalProperties, T value)
{
_ = Throw.IfNull(additionalProperties);

additionalProperties.Add(typeof(T).FullName!, value);
}

/// <summary>
/// Attempts to add a property using the type name of the property as the key.
/// </summary>
/// <remarks>
/// This method uses the full name of the type parameter as the key. If the key already exists,
/// the value is not updated and the method returns <see langword="false"/>.
/// </remarks>
/// <typeparam name="T">The type of the property to add.</typeparam>
/// <param name="additionalProperties">The dictionary of additional properties.</param>
/// <param name="value">The value to add.</param>
/// <returns>
/// <see langword="true"/> if the value was added successfully; <see langword="false"/> if the key already exists.
/// </returns>
public static bool TryAdd<T>(this AdditionalPropertiesDictionary additionalProperties, T value)
{
_ = Throw.IfNull(additionalProperties);

return additionalProperties.TryAdd(typeof(T).FullName!, value);
}

/// <summary>
/// Attempts to retrieve a value from the additional properties dictionary using the type name of the property as the key.
/// </summary>
/// <remarks>
/// This method uses the full name of the type parameter as the key when searching the dictionary.
/// </remarks>
/// <typeparam name="T">The type of the property to be retrieved.</typeparam>
/// <param name="additionalProperties">The dictionary containing additional properties.</param>
/// <param name="value">
/// When this method returns, contains the value retrieved from the dictionary, if found and successfully converted to the requested type;
/// otherwise, the default value of <typeparamref name="T"/>.
/// </param>
/// <returns>
/// <see langword="true"/> if a non-<see langword="null"/> value was found
/// in the dictionary and converted to the requested type; otherwise, <see langword="false"/>.
/// </returns>
public static bool TryGetValue<T>(this AdditionalPropertiesDictionary additionalProperties, [NotNullWhen(true)] out T? value)
{
_ = Throw.IfNull(additionalProperties);

return additionalProperties.TryGetValue(typeof(T).FullName!, out value);
}

/// <summary>
/// Determines whether the additional properties dictionary contains a property with the name of the provided type as the key.
/// </summary>
/// <typeparam name="T">The type of the property to check for.</typeparam>
/// <param name="additionalProperties">The dictionary of additional properties.</param>
/// <returns>
/// <see langword="true"/> if the dictionary contains a property with the name of the provided type as the key; otherwise, <see langword="false"/>.
/// </returns>
public static bool Contains<T>(this AdditionalPropertiesDictionary additionalProperties)
{
_ = Throw.IfNull(additionalProperties);

return additionalProperties.ContainsKey(typeof(T).FullName!);
}

/// <summary>
/// Removes a property from the additional properties dictionary using the name of the provided type as the key.
/// </summary>
/// <typeparam name="T">The type of the property to remove.</typeparam>
/// <param name="additionalProperties">The dictionary of additional properties.</param>
/// <returns>
/// <see langword="true"/> if the property was successfully removed; otherwise, <see langword="false"/>.
/// </returns>
public static bool Remove<T>(this AdditionalPropertiesDictionary additionalProperties)
{
_ = Throw.IfNull(additionalProperties);

return additionalProperties.Remove(typeof(T).FullName!);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@ public sealed class InvokedContext
/// <param name="requestMessages">The caller provided messages that were used by the agent for this invocation.</param>
/// <param name="chatMessageStoreMessages">The messages retrieved from the <see cref="ChatMessageStore"/> for this invocation.</param>
/// <exception cref="ArgumentNullException"><paramref name="requestMessages"/> is <see langword="null"/>.</exception>
public InvokedContext(IEnumerable<ChatMessage> requestMessages, IEnumerable<ChatMessage> chatMessageStoreMessages)
public InvokedContext(IEnumerable<ChatMessage> requestMessages, IEnumerable<ChatMessage>? chatMessageStoreMessages)
{
this.RequestMessages = Throw.IfNull(requestMessages);
this.ChatMessageStoreMessages = Throw.IfNull(chatMessageStoreMessages);
this.ChatMessageStoreMessages = chatMessageStoreMessages;
}

/// <summary>
Expand All @@ -191,9 +191,9 @@ public InvokedContext(IEnumerable<ChatMessage> requestMessages, IEnumerable<Chat
/// </summary>
/// <value>
/// A collection of <see cref="ChatMessage"/> instances that were retrieved from the <see cref="ChatMessageStore"/>,
/// and were used by the agent as part of the invocation.
/// and were used by the agent as part of the invocation. May be null on the first run.
/// </value>
public IEnumerable<ChatMessage> ChatMessageStoreMessages { get; set { field = Throw.IfNull(value); } }
public IEnumerable<ChatMessage>? ChatMessageStoreMessages { get; set; }

/// <summary>
/// Gets or sets the messages provided by the <see cref="AIContextProvider"/> for this invocation, if any.
Expand Down
53 changes: 35 additions & 18 deletions dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingA
}
catch (Exception ex)
{
await NotifyMessageStoreOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
await NotifyMessageStoreOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
throw;
}
Expand All @@ -246,7 +246,7 @@ protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingA
}
catch (Exception ex)
{
await NotifyMessageStoreOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
await NotifyMessageStoreOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
throw;
}
Expand All @@ -273,7 +273,7 @@ protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingA
}
catch (Exception ex)
{
await NotifyMessageStoreOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
await NotifyMessageStoreOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
throw;
}
Expand All @@ -286,7 +286,7 @@ protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingA
await this.UpdateThreadWithTypeAndConversationIdAsync(safeThread, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false);

// To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request.
await NotifyMessageStoreOfNewMessagesAsync(safeThread, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
await NotifyMessageStoreOfNewMessagesAsync(safeThread, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false);

// Notify the AIContextProvider of all new messages.
await NotifyAIContextProviderOfSuccessAsync(safeThread, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -442,7 +442,7 @@ private async Task<TAgentResponse> RunCoreAsync<TAgentResponse, TChatClientRespo
}
catch (Exception ex)
{
await NotifyMessageStoreOfFailureAsync(safeThread, ex, inputMessages, chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
await NotifyMessageStoreOfFailureAsync(safeThread, ex, inputMessages, chatMessageStoreMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
throw;
}
Expand All @@ -460,7 +460,7 @@ private async Task<TAgentResponse> RunCoreAsync<TAgentResponse, TChatClientRespo
}

// Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent message state in the thread.
await NotifyMessageStoreOfNewMessagesAsync(safeThread, inputMessages, chatMessageStoreMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
await NotifyMessageStoreOfNewMessagesAsync(safeThread, inputMessages, chatMessageStoreMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false);

// Notify the AIContextProvider of all new messages.
await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -703,16 +703,18 @@ private async Task

List<ChatMessage> inputMessagesForChatClient = [];
IList<ChatMessage>? aiContextProviderMessages = null;
IList<ChatMessage>? chatMessageStoreMessages = [];
IList<ChatMessage>? chatMessageStoreMessages = null;

// Populate the thread messages only if we are not continuing an existing response as it's not allowed
if (chatOptions?.ContinuationToken is null)
{
// Add any existing messages from the thread to the messages to be sent to the chat client.
if (typedThread.MessageStore is not null)
ChatMessageStore? chatMessageStore = ResolveChatMessageStore(typedThread, chatOptions);

// Add any existing messages from the chatMessageStore to the messages to be sent to the chat client.
if (chatMessageStore is not null)
{
var invokingContext = new ChatMessageStore.InvokingContext(inputMessages);
var storeMessages = await typedThread.MessageStore.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false);
var storeMessages = await chatMessageStore.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false);
inputMessagesForChatClient.AddRange(storeMessages);
chatMessageStoreMessages = storeMessages as IList<ChatMessage> ?? storeMessages.ToList();
}
Expand Down Expand Up @@ -803,21 +805,22 @@ private static Task NotifyMessageStoreOfFailureAsync(
IEnumerable<ChatMessage> requestMessages,
IEnumerable<ChatMessage>? chatMessageStoreMessages,
IEnumerable<ChatMessage>? aiContextProviderMessages,
ChatOptions? chatOptions,
CancellationToken cancellationToken)
{
var messageStore = thread.MessageStore;
ChatMessageStore? chatMessageStore = ResolveChatMessageStore(thread, chatOptions);

// Only notify the message store if we have one.
// If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages.
if (messageStore is not null)
if (chatMessageStore is not null)
{
var invokedContext = new ChatMessageStore.InvokedContext(requestMessages, chatMessageStoreMessages!)
var invokedContext = new ChatMessageStore.InvokedContext(requestMessages, chatMessageStoreMessages)
{
AIContextProviderMessages = aiContextProviderMessages,
InvokeException = ex
};

return messageStore.InvokedAsync(invokedContext, cancellationToken).AsTask();
return chatMessageStore.InvokedAsync(invokedContext, cancellationToken).AsTask();
}

return Task.CompletedTask;
Expand All @@ -829,25 +832,39 @@ private static Task NotifyMessageStoreOfNewMessagesAsync(
IEnumerable<ChatMessage>? chatMessageStoreMessages,
IEnumerable<ChatMessage>? aiContextProviderMessages,
IEnumerable<ChatMessage> responseMessages,
ChatOptions? chatOptions,
CancellationToken cancellationToken)
{
var messageStore = thread.MessageStore;
ChatMessageStore? chatMessageStore = ResolveChatMessageStore(thread, chatOptions);

// Only notify the message store if we have one.
// If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages.
if (messageStore is not null)
if (chatMessageStore is not null)
{
var invokedContext = new ChatMessageStore.InvokedContext(requestMessages, chatMessageStoreMessages!)
var invokedContext = new ChatMessageStore.InvokedContext(requestMessages, chatMessageStoreMessages)
{
AIContextProviderMessages = aiContextProviderMessages,
ResponseMessages = responseMessages
};
return messageStore.InvokedAsync(invokedContext, cancellationToken).AsTask();
return chatMessageStore.InvokedAsync(invokedContext, cancellationToken).AsTask();
}

return Task.CompletedTask;
}

private static ChatMessageStore? ResolveChatMessageStore(ChatClientAgentThread thread, ChatOptions? chatOptions)
{
ChatMessageStore? chatMessageStore = thread.MessageStore;

// If someone provided an override ChatMessageStore via AdditionalProperties, we should use that instead of the one on the thread.
if (chatOptions?.AdditionalProperties?.TryGetValue(out ChatMessageStore? overrideChatMessageStore) is true)
{
chatMessageStore = overrideChatMessageStore;
}

return chatMessageStore;
}

private static ChatClientAgentContinuationToken? WrapContinuationToken(ResponseContinuationToken? continuationToken, IEnumerable<ChatMessage>? inputMessages = null, List<ChatResponseUpdate>? responseUpdates = null)
{
if (continuationToken is null)
Expand Down
Loading
Loading