Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ public override async Task<AgentRunResponse> RunAsync(IEnumerable<ChatMessage> m
// Create a thread if the user didn't supply one.
thread ??= this.GetNewThread();

if (thread is not CustomAgentThread typedThread)
{
throw new ArgumentException($"The provided thread is not of type {nameof(CustomAgentThread)}.", nameof(thread));
}

// Clone the input messages and turn them into response messages with upper case text.
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList();

// Notify the thread of the input and output messages.
await NotifyThreadOfNewMessagesAsync(thread, messages.Concat(responseMessages), cancellationToken);
await typedThread.MessageStore.AddMessagesAsync(messages.Concat(responseMessages), cancellationToken);

return new AgentRunResponse
{
Expand All @@ -58,11 +63,16 @@ public override async IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync
// Create a thread if the user didn't supply one.
thread ??= this.GetNewThread();

if (thread is not CustomAgentThread typedThread)
{
throw new ArgumentException($"The provided thread is not of type {nameof(CustomAgentThread)}.", nameof(thread));
}

// Clone the input messages and turn them into response messages with upper case text.
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList();

// Notify the thread of the input and output messages.
await NotifyThreadOfNewMessagesAsync(thread, messages.Concat(responseMessages), cancellationToken);
await typedThread.MessageStore.AddMessagesAsync(messages.Concat(responseMessages), cancellationToken);

foreach (var message in responseMessages)
{
Expand Down
24 changes: 0 additions & 24 deletions dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -328,28 +328,4 @@ public abstract IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync(
AgentThread? thread = null,
AgentRunOptions? options = null,
CancellationToken cancellationToken = default);

/// <summary>
/// Notifies the specified thread about new messages that have been added to the conversation.
/// </summary>
/// <param name="thread">The conversation thread to notify about the new messages.</param>
/// <param name="messages">The collection of new messages to report to the thread.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that represents the asynchronous notification operation.</returns>
/// <exception cref="ArgumentNullException"><paramref name="thread"/> or <paramref name="messages"/> is <see langword="null"/>.</exception>
/// <remarks>
/// <para>
/// This method ensures that conversation threads are kept informed about message additions, which
/// is important for threads that manage their own state, memory components, or derived context.
/// While all agent implementations should notify their threads, the specific actions taken by
/// each thread type may vary.
/// </para>
/// </remarks>
protected static async Task NotifyThreadOfNewMessagesAsync(AgentThread thread, IEnumerable<ChatMessage> messages, CancellationToken cancellationToken)
{
_ = Throw.IfNull(thread);
_ = Throw.IfNull(messages);

await thread.MessagesReceivedAsync(messages, cancellationToken).ConfigureAwait(false);
}
}
17 changes: 0 additions & 17 deletions dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Agents.AI;
Expand Down Expand Up @@ -65,19 +61,6 @@ protected AgentThread()
public virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null)
=> default;

/// <summary>
/// This method is called when new messages have been contributed to the chat by any participant.
/// </summary>
/// <remarks>
/// Inheritors can use this method to update their context based on the new message.
/// </remarks>
/// <param name="newMessages">The new messages.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that completes when the context has been updated.</returns>
/// <exception cref="InvalidOperationException">The thread has been deleted.</exception>
protected internal virtual Task MessagesReceivedAsync(IEnumerable<ChatMessage> newMessages, CancellationToken cancellationToken = default)
=> Task.CompletedTask;

/// <summary>Asks the <see cref="AgentThread"/> for an object of the specified type <paramref name="serviceType"/>.</summary>
/// <param name="serviceType">The type of object being requested.</param>
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

namespace Microsoft.Agents.AI;
Expand Down Expand Up @@ -116,10 +114,6 @@ public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptio
public override object? GetService(Type serviceType, object? serviceKey = null) =>
base.GetService(serviceType, serviceKey) ?? this.MessageStore?.GetService(serviceType, serviceKey);

/// <inheritdoc />
protected internal override Task MessagesReceivedAsync(IEnumerable<ChatMessage> newMessages, CancellationToken cancellationToken = default)
=> this.MessageStore.AddMessagesAsync(newMessages, cancellationToken);

[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private string DebuggerDisplay => $"Count = {this.MessageStore.Count}";

Expand Down
3 changes: 0 additions & 3 deletions dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowThread.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,6 @@ public WorkflowThread(Workflow workflow, JsonElement serializedThread, IWorkflow

public CheckpointInfo? LastCheckpoint { get; set; }

protected override Task MessagesReceivedAsync(IEnumerable<ChatMessage> newMessages, CancellationToken cancellationToken = default)
=> this.MessageStore.AddMessagesAsync(newMessages, cancellationToken);

public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null)
{
JsonMarshaller marshaller = new(jsonSerializerOptions);
Expand Down
24 changes: 19 additions & 5 deletions dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ public override async IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync
this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId);

// To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request.
await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages.Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false);
await NotifyMessageStoreOfNewMessagesAsync(safeThread, inputMessages.Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false);

// Notify the AIContextProvider of all new messages.
await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -413,7 +413,7 @@ private async Task<TAgentRunResponse> RunCoreAsync<TAgentRunResponse, TChatClien
}

// Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent message state in the thread.
await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages.Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false);
await NotifyMessageStoreOfNewMessagesAsync(safeThread, inputMessages.Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false);

// Notify the AIContextProvider of all new messages.
await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -711,12 +711,26 @@ private void UpdateThreadWithTypeAndConversationId(ChatClientAgentThread thread,
else
{
// If the service doesn't use service side thread storage (i.e. we got no id back from invocation), and
// the thread has no MessageStore yet, and we have a custom messages store, we should update the thread
// with the custom MessageStore so that it has somewhere to store the chat history.
thread.MessageStore ??= this._agentOptions?.ChatMessageStoreFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null });
// the thread has no MessageStore yet, we should update the thread with the custom MessageStore or
// default InMemoryMessageStore so that it has somewhere to store the chat history.
thread.MessageStore ??= this._agentOptions?.ChatMessageStoreFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }) ?? new InMemoryChatMessageStore();
}
}

private static Task NotifyMessageStoreOfNewMessagesAsync(ChatClientAgentThread thread, IEnumerable<ChatMessage> newMessages, CancellationToken cancellationToken)
{
var messageStore = thread.MessageStore;

// Only notify the message store if we have one.
// If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages.
if (messageStore is not null)
{
return messageStore.AddMessagesAsync(newMessages, cancellationToken);
}

return Task.CompletedTask;
}

private string GetLoggingAgentName() => this.Name ?? "UnnamedAgent";
#endregion
}
31 changes: 0 additions & 31 deletions dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentThread.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Agents.AI;
Expand Down Expand Up @@ -181,33 +177,6 @@ public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptio
?? this.AIContextProvider?.GetService(serviceType, serviceKey)
?? this.MessageStore?.GetService(serviceType, serviceKey);

/// <inheritdoc />
protected override async Task MessagesReceivedAsync(IEnumerable<ChatMessage> newMessages, CancellationToken cancellationToken = default)
{
switch (this)
{
case { ConversationId: not null }:
// If the thread messages are stored in the service
// there is nothing to do here, since invoking the
// service should already update the thread.
break;

case { MessageStore: null }:
// If there is no conversation id, and no store we can createa a default in memory store and add messages to it.
this._messageStore = new InMemoryChatMessageStore();
await this._messageStore!.AddMessagesAsync(newMessages, cancellationToken).ConfigureAwait(false);
break;

case { MessageStore: not null }:
// If a store has been provided, we need to add the messages to the store.
await this._messageStore!.AddMessagesAsync(newMessages, cancellationToken).ConfigureAwait(false);
break;

default:
throw new UnreachableException();
}
}

[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private string DebuggerDisplay =>
this.ConversationId is { } conversationId ? $"ConversationId = {conversationId}" :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Moq;
using Moq.Protected;

namespace Microsoft.Agents.AI.Abstractions.UnitTests;

Expand Down Expand Up @@ -222,21 +221,6 @@ public void ValidateAgentIDIsIdempotent()
Assert.Equal(id, agent.Id);
}

[Fact]
public async Task NotifyThreadOfNewMessagesNotifiesThreadAsync()
{
var cancellationToken = default(CancellationToken);

var messages = new[] { new ChatMessage(ChatRole.User, "msg1"), new ChatMessage(ChatRole.User, "msg2") };

var threadMock = new Mock<TestAgentThread> { CallBase = true };
threadMock.SetupAllProperties();

await MockAgent.NotifyThreadOfNewMessagesAsync(threadMock.Object, messages, cancellationToken);

threadMock.Protected().Verify("MessagesReceivedAsync", Times.Once(), messages, cancellationToken);
}

#region GetService Method Tests

/// <summary>
Expand Down Expand Up @@ -360,9 +344,6 @@ public abstract class TestAgentThread : AgentThread;

private sealed class MockAgent : AIAgent
{
public static new Task NotifyThreadOfNewMessagesAsync(AgentThread thread, IEnumerable<ChatMessage> messages, CancellationToken cancellationToken) =>
AIAgent.NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken);

public override AgentThread GetNewThread()
=> throw new NotImplementedException();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using Microsoft.Extensions.AI;

#pragma warning disable CA1861 // Avoid constant arrays as arguments

Expand All @@ -21,15 +19,6 @@ public void Serialize_ReturnsDefaultJsonElement()
Assert.Equal(default, result);
}

[Fact]
public void MessagesReceivedAsync_ReturnsCompletedTask()
{
var thread = new TestAgentThread();
var messages = new List<ChatMessage> { new(ChatRole.User, "hello") };
var result = thread.MessagesReceivedAsync(messages);
Assert.True(result.IsCompleted);
}

#region GetService Method Tests

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Moq;
Expand Down Expand Up @@ -91,50 +90,6 @@ public void SetChatMessageStoreThrowsWhenConversationIdIsSet()

#endregion Constructor and Property Tests

#region OnNewMessagesAsync Tests

[Fact]
public async Task OnNewMessagesAsyncDoesNothingWhenAgentServiceIdAsync()
{
// Arrange
var thread = new ChatClientAgentThread { ConversationId = "thread-123" };
var messages = new List<ChatMessage>
{
new(ChatRole.User, "Hello"),
new(ChatRole.Assistant, "Hi there!")
};
var agent = new MessageSendingAgent();

// Act
await agent.SendMessagesAsync(thread, messages, CancellationToken.None);
Assert.Equal("thread-123", thread.ConversationId);
Assert.Null(thread.MessageStore);
}

[Fact]
public async Task OnNewMessagesAsyncAddsMessagesToStoreAsync()
{
// Arrange
var store = new InMemoryChatMessageStore();
var thread = new ChatClientAgentThread { MessageStore = store };
var messages = new List<ChatMessage>
{
new(ChatRole.User, "Hello"),
new(ChatRole.Assistant, "Hi there!")
};
var agent = new MessageSendingAgent();

// Act
await agent.SendMessagesAsync(thread, messages, CancellationToken.None);

// Assert
Assert.Equal(2, store.Count);
Assert.Equal("Hello", store[0].Text);
Assert.Equal("Hi there!", store[1].Text);
}

#endregion OnNewMessagesAsync Tests

#region Deserialize Tests

[Fact]
Expand Down Expand Up @@ -372,22 +327,4 @@ public void GetService_RequestingChatMessageStore_ReturnsChatMessageStore()
}

#endregion

private sealed class MessageSendingAgent : AIAgent
{
public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null)
=> throw new NotImplementedException();

public override AgentThread GetNewThread()
=> throw new NotImplementedException();

public override Task<AgentRunResponse> RunAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
=> throw new NotImplementedException();

public override IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
=> throw new NotImplementedException();

public Task SendMessagesAsync(AgentThread thread, IEnumerable<ChatMessage> messages, CancellationToken cancellationToken = default)
=> NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken);
}
}
Loading