Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,25 @@ namespace SampleApp
/// <summary>
/// Sample memory component that can remember a user's name and age.
/// </summary>
internal sealed class UserInfoMemory : AIContextProvider
internal sealed class UserInfoMemory : AIContextProvider<UserInfo>
{
private readonly IChatClient _chatClient;
private readonly Func<AgentSession?, UserInfo> _stateInitializer;

public UserInfoMemory(IChatClient chatClient, Func<AgentSession?, UserInfo>? stateInitializer = null)
: base(stateInitializer ?? (_ => new UserInfo()), null, null, null, null)
{
this._chatClient = chatClient;
this._stateInitializer = stateInitializer ?? (_ => new UserInfo());
}

public UserInfo GetUserInfo(AgentSession session)
=> session.StateBag.GetValue<UserInfo>(nameof(UserInfoMemory)) ?? new UserInfo();
=> this.GetOrInitializeState(session);

public void SetUserInfo(AgentSession session, UserInfo userInfo)
=> session.StateBag.SetValue(nameof(UserInfoMemory), userInfo);
=> this.SaveState(session, userInfo);

protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default)
protected override async ValueTask StoreAIContextAsync(InvokedContext context, CancellationToken cancellationToken = default)
{
var userInfo = context.Session?.StateBag.GetValue<UserInfo>(nameof(UserInfoMemory))
?? this._stateInitializer.Invoke(context.Session);
var userInfo = this.GetOrInitializeState(context.Session);

// Try and extract the user name and age from the message if we don't have it already and it's a user message.
if ((userInfo.UserName is null || userInfo.UserAge is null) && context.RequestMessages.Any(x => x.Role == ChatRole.User))
Expand All @@ -123,20 +121,14 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc
userInfo.UserAge ??= result.Result.UserAge;
}

context.Session?.StateBag.SetValue(nameof(UserInfoMemory), userInfo);
this.SaveState(context.Session, userInfo);
}

protected override ValueTask<AIContext> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
protected override ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
var inputContext = context.AIContext;
var userInfo = context.Session?.StateBag.GetValue<UserInfo>(nameof(UserInfoMemory))
?? this._stateInitializer.Invoke(context.Session);
var userInfo = this.GetOrInitializeState(context.Session);

StringBuilder instructions = new();
if (!string.IsNullOrEmpty(inputContext.Instructions))
{
instructions.AppendLine(inputContext.Instructions);
}

// If we don't already know the user's name and age, add instructions to ask for them, otherwise just provide what we have to the context.
instructions
Expand All @@ -151,9 +143,7 @@ userInfo.UserAge is null ?

return new ValueTask<AIContext>(new AIContext
{
Instructions = instructions.ToString(),
Messages = inputContext.Messages,
Tools = inputContext.Tools
Instructions = instructions.ToString()
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,45 +76,23 @@ namespace SampleApp
/// State (the session DB key) is stored in the <see cref="AgentSession.StateBag"/> so it roundtrips
/// automatically with session serialization.
/// </summary>
internal sealed class VectorChatHistoryProvider : ChatHistoryProvider
internal sealed class VectorChatHistoryProvider : ChatHistoryProvider<VectorChatHistoryProvider.State>
{
private readonly VectorStore _vectorStore;
private readonly Func<AgentSession?, State> _stateInitializer;
private readonly string _stateKey;

/// <inheritdoc />
public override string StateKey => this._stateKey;

public VectorChatHistoryProvider(
VectorStore vectorStore,
Func<AgentSession?, State>? stateInitializer = null,
string? stateKey = null)
: base(stateInitializer: stateInitializer ?? (_ => new State(Guid.NewGuid().ToString("N"))), stateKey: stateKey, jsonSerializerOptions: null, provideOutputMessageFilter: null, storeInputMessageFilter: null)
{
this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore));
this._stateInitializer = stateInitializer ?? (_ => new State(Guid.NewGuid().ToString("N")));
this._stateKey = stateKey ?? base.StateKey;
}

public string GetSessionDbKey(AgentSession session)
=> this.GetOrInitializeState(session).SessionDbKey;

private State GetOrInitializeState(AgentSession? session)
{
if (session?.StateBag.TryGetValue<State>(this._stateKey, out var state) is true && state is not null)
{
return state;
}

state = this._stateInitializer(session);
if (session is not null)
{
session.StateBag.SetValue(this._stateKey, state);
}

return state;
}

protected override async ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
protected override async ValueTask<IEnumerable<ChatMessage>> ProvideChatHistoryAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
var state = this.GetOrInitializeState(context.Session);
var collection = this._vectorStore.GetCollection<string, ChatHistoryItem>("ChatHistory");
Expand All @@ -129,29 +107,17 @@ protected override async ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(I

var messages = records.ConvertAll(x => JsonSerializer.Deserialize<ChatMessage>(x.SerializedMessage!)!);
messages.Reverse();
return messages
.Select(message => message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, this.GetType().FullName!))
.Concat(context.RequestMessages);
return messages;
}

protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default)
protected override async ValueTask StoreChatHistoryAsync(InvokedContext context, CancellationToken cancellationToken = default)
{
// Don't store messages if the request failed.
if (context.InvokeException is not null)
{
return;
}

var state = this.GetOrInitializeState(context.Session);

var collection = this._vectorStore.GetCollection<string, ChatHistoryItem>("ChatHistory");
await collection.EnsureCollectionExistsAsync(cancellationToken);

// Add both request and response messages to the store, excluding messages that came from chat history.
// Optionally messages produced by the AIContextProvider can also be persisted (not shown).
var allNewMessages = context.RequestMessages
.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory)
.Concat(context.ResponseMessages ?? []);
var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []);

await collection.UpsertAsync(allNewMessages.Select(x => new ChatHistoryItem()
{
Expand Down
159 changes: 154 additions & 5 deletions dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
Expand All @@ -10,7 +11,7 @@
namespace Microsoft.Agents.AI;

/// <summary>
/// Provides an abstract base class for components that enhance AI context management during agent invocations.
/// Provides an abstract base class for components that enhance AI context during agent invocations.
/// </summary>
/// <remarks>
/// <para>
Expand All @@ -30,6 +31,25 @@ namespace Microsoft.Agents.AI;
/// </remarks>
public abstract class AIContextProvider
{
private static IEnumerable<ChatMessage> DefaultExternalOnlyFilter(IEnumerable<ChatMessage> messages)
=> messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External);

private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _provideInputMessageFilter;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storeInputMessageFilter;

/// <summary>
/// Initializes a new instance of the <see cref="AIContextProvider"/> class.
/// </summary>
/// <param name="provideInputMessageFilter">An optional filter function to apply to input messages before providing context via <see cref="ProvideAIContextAsync"/>. If not set, defaults to including only <see cref="AgentRequestMessageSourceType.External"/> messages.</param>
/// <param name="storeInputMessageFilter">An optional filter function to apply to request messages before storing context via <see cref="StoreAIContextAsync"/>. If not set, defaults to including only <see cref="AgentRequestMessageSourceType.External"/> messages.</param>
protected AIContextProvider(
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? provideInputMessageFilter = null,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputMessageFilter = null)
{
this._provideInputMessageFilter = provideInputMessageFilter ?? DefaultExternalOnlyFilter;
this._storeInputMessageFilter = storeInputMessageFilter ?? DefaultExternalOnlyFilter;
}

/// <summary>
/// Gets the key used to store the provider state in the <see cref="AgentSession.StateBag"/>.
/// </summary>
Expand Down Expand Up @@ -58,7 +78,7 @@ public abstract class AIContextProvider
/// </para>
/// </remarks>
public ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
=> this.InvokingCoreAsync(context, cancellationToken);
=> this.InvokingCoreAsync(Throw.IfNull(context), cancellationToken);

/// <summary>
/// Called at the start of agent invocation to provide additional context.
Expand All @@ -76,8 +96,96 @@ public ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationT
/// <item><description>Injecting contextual messages from conversation history</description></item>
/// </list>
/// </para>
/// <para>
/// The default implementation of this method filters the input messages using the configured provide-input message filter
/// (which defaults to including only <see cref="AgentRequestMessageSourceType.External"/> messages),
/// then calls <see cref="ProvideAIContextAsync"/> to get additional context,
/// stamps any messages from the returned context with <see cref="AgentRequestMessageSourceType.AIContextProvider"/> source attribution,
/// and merges the returned context with the original (unfiltered) input context (concatenating instructions, messages, and tools).
/// For most scenarios, overriding <see cref="ProvideAIContextAsync"/> is sufficient to provide additional context,
/// while still benefiting from the default filtering, merging and source stamping behavior.
/// However, for scenarios that require more control over context filtering, merging or source stamping, overriding this method
/// allows you to directly control the full <see cref="AIContext"/> returned for the invocation.
/// </para>
/// </remarks>
protected abstract ValueTask<AIContext> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default);
protected virtual async ValueTask<AIContext> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
var inputContext = context.AIContext;

// Create a filtered context for ProvideAIContextAsync, filtering input messages
// to exclude non-external messages (e.g. chat history, other AI context provider messages).
var filteredContext = new InvokingContext(
context.Agent,
context.Session,
new AIContext
{
Instructions = inputContext.Instructions,
Messages = inputContext.Messages is not null ? this._provideInputMessageFilter(inputContext.Messages) : null,
Tools = inputContext.Tools
});

var provided = await this.ProvideAIContextAsync(filteredContext, cancellationToken).ConfigureAwait(false);

var mergedInstructions = (inputContext.Instructions, provided.Instructions) switch
{
(null, null) => null,
(string a, null) => a,
(null, string b) => b,
(string a, string b) => a + "\n" + b
};

var providedMessages = provided.Messages is not null
? provided.Messages.Select(m => m.WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, this.GetType().FullName!))
: null;

var mergedMessages = (inputContext.Messages, providedMessages) switch
{
(null, null) => null,
(var a, null) => a,
(null, var b) => b,
(var a, var b) => a.Concat(b)
};

var mergedTools = (inputContext.Tools, provided.Tools) switch
{
(null, null) => null,
(var a, null) => a,
(null, var b) => b,
(var a, var b) => a.Concat(b)
};

return new AIContext
{
Instructions = mergedInstructions,
Messages = mergedMessages,
Tools = mergedTools
};
}

/// <summary>
/// When overridden in a derived class, provides additional AI context to be merged with the input context for the current invocation.
/// </summary>
/// <remarks>
/// <para>
/// This method is called from <see cref="InvokingCoreAsync"/>.
/// Note that <see cref="InvokingCoreAsync"/> can be overridden to directly control context merging and source stamping, in which case
/// it is up to the implementer to call this method as needed to retrieve the additional context.
/// </para>
/// <para>
/// In contrast with <see cref="InvokingCoreAsync"/>, this method only returns additional context to be merged with the input,
/// while <see cref="InvokingCoreAsync"/> is responsible for returning the full merged <see cref="AIContext"/> for the invocation.
/// </para>
/// </remarks>
/// <param name="context">Contains the request context including the caller provided messages that will be used by the agent for this invocation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>
/// A task that represents the asynchronous operation. The task result contains an <see cref="AIContext"/>
/// with additional context to be merged with the input context.
/// </returns>
protected virtual ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
return new ValueTask<AIContext>(new AIContext());
}

/// <summary>
/// Called at the end of the agent invocation to process the invocation results.
Expand Down Expand Up @@ -106,7 +214,7 @@ public ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationT
/// </para>
/// </remarks>
public ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default)
=> this.InvokedCoreAsync(context, cancellationToken);
=> this.InvokedCoreAsync(Throw.IfNull(context), cancellationToken);

/// <summary>
/// Called at the end of the agent invocation to process the invocation results.
Expand All @@ -128,9 +236,50 @@ public ValueTask InvokedAsync(InvokedContext context, CancellationToken cancella
/// This method is called regardless of whether the invocation succeeded or failed.
/// To check if the invocation was successful, inspect the <see cref="InvokedContext.InvokeException"/> property.
/// </para>
/// <para>
/// The default implementation of this method skips execution for any invocation failures,
/// filters the request messages using the configured store-input message filter
/// (which defaults to including only <see cref="AgentRequestMessageSourceType.External"/> messages),
/// and calls <see cref="StoreAIContextAsync"/> to process the invocation results.
/// For most scenarios, overriding <see cref="StoreAIContextAsync"/> is sufficient to process invocation results,
/// while still benefiting from the default error handling and filtering behavior.
/// However, for scenarios that require more control over error handling or message filtering, overriding this method
/// allows you to directly control the processing of invocation results.
/// </para>
/// </remarks>
protected virtual ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default)
=> default;
{
if (context.InvokeException is not null)
{
return default;
}

var subContext = new InvokedContext(context.Agent, context.Session, this._storeInputMessageFilter(context.RequestMessages), context.ResponseMessages!);
return this.StoreAIContextAsync(subContext, cancellationToken);
}

/// <summary>
/// When overridden in a derived class, processes invocation results at the end of the agent invocation.
/// </summary>
/// <param name="context">Contains the invocation context including request messages, response messages, and any exception that occurred.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that represents the asynchronous operation.</returns>
/// <remarks>
/// <para>
/// This method is called from <see cref="InvokedCoreAsync"/>.
/// Note that <see cref="InvokedCoreAsync"/> can be overridden to directly control error handling, in which case
/// it is up to the implementer to call this method as needed to process the invocation results.
/// </para>
/// <para>
/// In contrast with <see cref="InvokedCoreAsync"/>, this method only processes the invocation results,
/// while <see cref="InvokedCoreAsync"/> is also responsible for error handling.
/// </para>
/// <para>
/// The default implementation of <see cref="InvokedCoreAsync"/> only calls this method if the invocation succeeded.
/// </para>
/// </remarks>
protected virtual ValueTask StoreAIContextAsync(InvokedContext context, CancellationToken cancellationToken = default) =>
default;

/// <summary>Asks the <see cref="AIContextProvider"/> for an object of the specified type <paramref name="serviceType"/>.</summary>
/// <param name="serviceType">The type of object being requested.</param>
Expand Down
Loading
Loading