Skip to content

Commit

Permalink
[.Net] Fix #3045 (#3047)
Browse files Browse the repository at this point in the history
* make IStreamingMessage obsolete

* update final reply message
  • Loading branch information
LittleLittleCloud authored and victordibia committed Jul 30, 2024
1 parent 540997c commit cd203b3
Show file tree
Hide file tree
Showing 30 changed files with 67 additions and 65 deletions.
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, G
return new MessageEnvelope<ChatCompletionResponse>(response, from: this.Name);
}

public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages,
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await foreach (var message in _anthropicClient.StreamingChatCompletionsAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
: response;
}

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = await ProcessMessageAsync(messages, agent);

await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
if (reply is IStreamingMessage<ChatCompletionResponse> chatMessage)
if (reply is IMessage<ChatCompletionResponse> chatMessage)
{
var response = ProcessChatCompletionResponse(chatMessage, agent);
if (response is not null)
Expand All @@ -52,7 +52,7 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext c
}
}

private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage<ChatCompletionResponse> chatMessage,
private IMessage? ProcessChatCompletionResponse(IMessage<ChatCompletionResponse> chatMessage,
IStreamingAgent agent)
{
if (chatMessage.Content.Content is { Count: 1 } &&
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace AutoGen.Core;
/// </summary>
public interface IStreamingAgent : IAgent
{
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default);
Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, Generat
return _agent.GenerateReplyAsync(messages, options, cancellationToken);
}

public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
return _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
}
Expand Down Expand Up @@ -83,7 +83,7 @@ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, Generat
return this.streamingMiddleware.InvokeAsync(context, (IAgent)innerAgent, cancellationToken);
}

public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
if (streamingMiddleware is null)
{
Expand Down
14 changes: 9 additions & 5 deletions dotnet/src/AutoGen.Core/Message/IMessage.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IMessage.cs

using System;
using System.Collections.Generic;

namespace AutoGen.Core;
Expand Down Expand Up @@ -35,37 +36,40 @@ namespace AutoGen.Core;
/// </item>
/// </list>
/// </summary>
public interface IMessage : IStreamingMessage
public interface IMessage
{
string? From { get; set; }
}

public interface IMessage<out T> : IMessage, IStreamingMessage<T>
public interface IMessage<out T> : IMessage
{
T Content { get; }
}

/// <summary>
/// The interface for messages that can get text content.
/// This interface will be used by <see cref="MessageExtension.GetContent(IMessage)"/> to get the content from the message.
/// </summary>
public interface ICanGetTextContent : IMessage, IStreamingMessage
public interface ICanGetTextContent : IMessage
{
public string? GetContent();
}

/// <summary>
/// The interface for messages that can get a list of <see cref="ToolCall"/>
/// </summary>
public interface ICanGetToolCalls : IMessage, IStreamingMessage
public interface ICanGetToolCalls : IMessage
{
public IEnumerable<ToolCall> GetToolCalls();
}


[Obsolete("Use IMessage instead")]
public interface IStreamingMessage
{
string? From { get; set; }
}

[Obsolete("Use IMessage<T> instead")]
public interface IStreamingMessage<out T> : IStreamingMessage
{
T Content { get; }
Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace AutoGen.Core;

public abstract class MessageEnvelope : IMessage, IStreamingMessage
public abstract class MessageEnvelope : IMessage
{
public MessageEnvelope(string? from = null, IDictionary<string, object>? metadata = null)
{
Expand All @@ -23,7 +23,7 @@ public static MessageEnvelope<TContent> Create<TContent>(TContent content, strin
public IDictionary<string, object> Metadata { get; set; }
}

public class MessageEnvelope<T> : MessageEnvelope, IMessage<T>, IStreamingMessage<T>
public class MessageEnvelope<T> : MessageEnvelope, IMessage<T>
{
public MessageEnvelope(T content, string? from = null, IDictionary<string, object>? metadata = null)
: base(from, metadata)
Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/AutoGen.Core/Message/TextMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace AutoGen.Core;

public class TextMessage : IMessage, IStreamingMessage, ICanGetTextContent
public class TextMessage : IMessage, ICanGetTextContent
{
public TextMessage(Role role, string content, string? from = null)
{
Expand Down Expand Up @@ -51,7 +51,7 @@ public override string ToString()
}
}

public class TextMessageUpdate : IStreamingMessage, ICanGetTextContent
public class TextMessageUpdate : IMessage, ICanGetTextContent
{
public TextMessageUpdate(Role role, string? content, string? from = null)
{
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public IEnumerable<ToolCall> GetToolCalls()
}
}

public class ToolCallMessageUpdate : IStreamingMessage
public class ToolCallMessageUpdate : IMessage
{
public ToolCallMessageUpdate(string functionName, string functionArgumentUpdate, string? from = null)
{
Expand Down
16 changes: 10 additions & 6 deletions dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
return reply;
}

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
public async IAsyncEnumerable<IMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand All @@ -86,16 +86,16 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
var combinedFunctions = this.functions?.Concat(options.Functions ?? []) ?? options.Functions;
options.Functions = combinedFunctions?.ToArray();

IStreamingMessage? initMessage = default;
IMessage? mergedFunctionCallMessage = default;
await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
{
if (message is ToolCallMessageUpdate toolCallMessageUpdate && this.functionMap != null)
{
if (initMessage is null)
if (mergedFunctionCallMessage is null)
{
initMessage = new ToolCallMessage(toolCallMessageUpdate);
mergedFunctionCallMessage = new ToolCallMessage(toolCallMessageUpdate);
}
else if (initMessage is ToolCallMessage toolCall)
else if (mergedFunctionCallMessage is ToolCallMessage toolCall)
{
toolCall.Update(toolCallMessageUpdate);
}
Expand All @@ -104,13 +104,17 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
throw new InvalidOperationException("The first message is ToolCallMessage, but the update message is not ToolCallMessageUpdate");
}
}
else if (message is ToolCallMessage toolCallMessage1)
{
mergedFunctionCallMessage = toolCallMessage1;
}
else
{
yield return message;
}
}

if (initMessage is ToolCallMessage toolCallMsg)
if (mergedFunctionCallMessage is ToolCallMessage toolCallMsg)
{
yield return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallMsg, agent);
}
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public interface IStreamingMiddleware : IMiddleware
/// <summary>
/// The streaming version of <see cref="IMiddleware.InvokeAsync(MiddlewareContext, IAgent, CancellationToken)"/>.
/// </summary>
public IAsyncEnumerable<IStreamingMessage> InvokeAsync(
public IAsyncEnumerable<IMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken = default);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
}
}

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
IMessage? recentUpdate = null;
await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken))
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, G
return MessageEnvelope.Create(response, this.Name);
}

public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var request = BuildChatRequest(messages, options);
var response = this.client.GenerateContentStreamAsync(request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public GeminiMessageConnector(bool strictMode = false)

public string Name => nameof(GeminiMessageConnector);

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = ProcessMessage(context.Messages, agent);

Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public async Task<IMessage> GenerateReplyAsync(
return new MessageEnvelope<ChatCompletionResponse>(response, from: this.Name);
}

public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ public class MistralChatMessageConnector : IStreamingMiddleware, IMiddleware
{
public string? Name => nameof(MistralChatMessageConnector);

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
var chunks = new List<ChatCompletionResponse>();
await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
if (reply is IStreamingMessage<ChatCompletionResponse> chatMessage)
if (reply is IMessage<ChatCompletionResponse> chatMessage)
{
chunks.Add(chatMessage.Content);
var response = ProcessChatCompletionResponse(chatMessage, agent);
Expand Down Expand Up @@ -167,7 +167,7 @@ private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from
}
}

private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage<ChatCompletionResponse> message, IAgent agent)
private IMessage? ProcessChatCompletionResponse(IMessage<ChatCompletionResponse> message, IAgent agent)
{
var response = message.Content;
if (response.VarObject != "chat.completion.chunk")
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public async Task<IMessage> GenerateReplyAsync(
}
}

public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
};
}

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = ProcessMessage(context.Messages, agent);
var chunks = new List<ChatResponseUpdate>();
await foreach (var update in agent.GenerateStreamingReplyAsync(messages, context.Options, cancellationToken))
{
if (update is IStreamingMessage<ChatResponseUpdate> chatResponseUpdate)
if (update is IMessage<ChatResponseUpdate> chatResponseUpdate)
{
var response = chatResponseUpdate.Content switch
{
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public async Task<IMessage> GenerateReplyAsync(
return await _innerAgent.GenerateReplyAsync(messages, options, cancellationToken);
}

public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public async Task<IMessage> GenerateReplyAsync(
return new MessageEnvelope<ChatCompletions>(reply, from: this.Name);
}

public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
return PostProcessMessage(reply);
}

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
public async IAsyncEnumerable<IMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand All @@ -57,7 +57,7 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
string? currentToolName = null;
await foreach (var reply in streamingReply)
{
if (reply is IStreamingMessage<StreamingChatCompletionsUpdate> update)
if (reply is IMessage<StreamingChatCompletionsUpdate> update)
{
if (update.Content.FunctionName is string functionName)
{
Expand Down Expand Up @@ -98,7 +98,7 @@ public IMessage PostProcessMessage(IMessage message)
};
}

public IStreamingMessage? PostProcessStreamingMessage(IStreamingMessage<StreamingChatCompletionsUpdate> update, string? currentToolName)
public IMessage? PostProcessStreamingMessage(IMessage<StreamingChatCompletionsUpdate> update, string? currentToolName)
{
if (update.Content.ContentUpdate is string contentUpdate)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
return PostProcessMessage(reply);
}

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var chatMessageContents = ProcessMessage(context.Messages, agent)
.Select(m => new MessageEnvelope<ChatMessageContent>(m));
Expand All @@ -67,11 +67,11 @@ private IMessage PostProcessMessage(IMessage input)
};
}

private IStreamingMessage PostProcessStreamingMessage(IStreamingMessage input)
private IMessage PostProcessStreamingMessage(IMessage input)
{
return input switch
{
IStreamingMessage<StreamingChatMessageContent> streamingMessage => PostProcessMessage(streamingMessage),
IMessage<StreamingChatMessageContent> streamingMessage => PostProcessMessage(streamingMessage),
IMessage msg => PostProcessMessage(msg),
_ => input,
};
Expand All @@ -98,7 +98,7 @@ private IMessage PostProcessMessage(IMessage<ChatMessageContent> messageEnvelope
}
}

private IStreamingMessage PostProcessMessage(IStreamingMessage<StreamingChatMessageContent> streamingMessage)
private IMessage PostProcessMessage(IMessage<StreamingChatMessageContent> streamingMessage)
{
var chatMessageContent = streamingMessage.Content;
if (chatMessageContent.ChoiceIndex > 0)
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, G
return new MessageEnvelope<ChatMessageContent>(reply.First(), from: this.Name);
}

public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand Down
Loading

0 comments on commit cd203b3

Please sign in to comment.