Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.Net] Fix #3045 #3047

Merged
merged 3 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, G
return new MessageEnvelope<ChatCompletionResponse>(response, from: this.Name);
}

public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages,
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await foreach (var message in _anthropicClient.StreamingChatCompletionsAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
: response;
}

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = await ProcessMessageAsync(messages, agent);

await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
if (reply is IStreamingMessage<ChatCompletionResponse> chatMessage)
if (reply is IMessage<ChatCompletionResponse> chatMessage)
{
var response = ProcessChatCompletionResponse(chatMessage, agent);
if (response is not null)
Expand All @@ -52,7 +52,7 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext c
}
}

private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage<ChatCompletionResponse> chatMessage,
private IMessage? ProcessChatCompletionResponse(IMessage<ChatCompletionResponse> chatMessage,
IStreamingAgent agent)
{
if (chatMessage.Content.Content is { Count: 1 } &&
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace AutoGen.Core;
/// </summary>
public interface IStreamingAgent : IAgent
{
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default);
Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, Generat
return _agent.GenerateReplyAsync(messages, options, cancellationToken);
}

public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
return _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
}
Expand Down Expand Up @@ -83,7 +83,7 @@ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, Generat
return this.streamingMiddleware.InvokeAsync(context, (IAgent)innerAgent, cancellationToken);
}

public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
if (streamingMiddleware is null)
{
Expand Down
14 changes: 9 additions & 5 deletions dotnet/src/AutoGen.Core/Message/IMessage.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IMessage.cs

using System;
using System.Collections.Generic;

namespace AutoGen.Core;
Expand Down Expand Up @@ -35,37 +36,40 @@ namespace AutoGen.Core;
/// </item>
/// </list>
/// </summary>
public interface IMessage : IStreamingMessage
public interface IMessage
{
string? From { get; set; }
}

public interface IMessage<out T> : IMessage, IStreamingMessage<T>
public interface IMessage<out T> : IMessage
{
T Content { get; }
}

/// <summary>
/// The interface for messages that can get text content.
/// This interface will be used by <see cref="MessageExtension.GetContent(IMessage)"/> to get the content from the message.
/// </summary>
public interface ICanGetTextContent : IMessage, IStreamingMessage
public interface ICanGetTextContent : IMessage
{
public string? GetContent();
}

/// <summary>
/// The interface for messages that can get a list of <see cref="ToolCall"/>
/// </summary>
public interface ICanGetToolCalls : IMessage, IStreamingMessage
public interface ICanGetToolCalls : IMessage
{
public IEnumerable<ToolCall> GetToolCalls();
}


[Obsolete("Use IMessage instead")]
public interface IStreamingMessage
{
string? From { get; set; }
}

[Obsolete("Use IMessage<T> instead")]
public interface IStreamingMessage<out T> : IStreamingMessage
{
T Content { get; }
Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace AutoGen.Core;

public abstract class MessageEnvelope : IMessage, IStreamingMessage
public abstract class MessageEnvelope : IMessage
{
public MessageEnvelope(string? from = null, IDictionary<string, object>? metadata = null)
{
Expand All @@ -23,7 +23,7 @@ public static MessageEnvelope<TContent> Create<TContent>(TContent content, strin
public IDictionary<string, object> Metadata { get; set; }
}

public class MessageEnvelope<T> : MessageEnvelope, IMessage<T>, IStreamingMessage<T>
public class MessageEnvelope<T> : MessageEnvelope, IMessage<T>
{
public MessageEnvelope(T content, string? from = null, IDictionary<string, object>? metadata = null)
: base(from, metadata)
Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/AutoGen.Core/Message/TextMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace AutoGen.Core;

public class TextMessage : IMessage, IStreamingMessage, ICanGetTextContent
public class TextMessage : IMessage, ICanGetTextContent
{
public TextMessage(Role role, string content, string? from = null)
{
Expand Down Expand Up @@ -51,7 +51,7 @@ public override string ToString()
}
}

public class TextMessageUpdate : IStreamingMessage, ICanGetTextContent
public class TextMessageUpdate : IMessage, ICanGetTextContent
{
public TextMessageUpdate(Role role, string? content, string? from = null)
{
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public IEnumerable<ToolCall> GetToolCalls()
}
}

public class ToolCallMessageUpdate : IStreamingMessage
public class ToolCallMessageUpdate : IMessage
{
public ToolCallMessageUpdate(string functionName, string functionArgumentUpdate, string? from = null)
{
Expand Down
16 changes: 10 additions & 6 deletions dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
return reply;
}

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
public async IAsyncEnumerable<IMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand All @@ -86,16 +86,16 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
var combinedFunctions = this.functions?.Concat(options.Functions ?? []) ?? options.Functions;
options.Functions = combinedFunctions?.ToArray();

IStreamingMessage? initMessage = default;
IMessage? mergedFunctionCallMessage = default;
await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
{
if (message is ToolCallMessageUpdate toolCallMessageUpdate && this.functionMap != null)
{
if (initMessage is null)
if (mergedFunctionCallMessage is null)
{
initMessage = new ToolCallMessage(toolCallMessageUpdate);
mergedFunctionCallMessage = new ToolCallMessage(toolCallMessageUpdate);
}
else if (initMessage is ToolCallMessage toolCall)
else if (mergedFunctionCallMessage is ToolCallMessage toolCall)
{
toolCall.Update(toolCallMessageUpdate);
}
Expand All @@ -104,13 +104,17 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
throw new InvalidOperationException("The first message is ToolCallMessage, but the update message is not ToolCallMessageUpdate");
}
}
else if (message is ToolCallMessage toolCallMessage1)
{
mergedFunctionCallMessage = toolCallMessage1;
}
else
{
yield return message;
}
}

if (initMessage is ToolCallMessage toolCallMsg)
if (mergedFunctionCallMessage is ToolCallMessage toolCallMsg)
{
yield return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallMsg, agent);
}
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public interface IStreamingMiddleware : IMiddleware
/// <summary>
/// The streaming version of <see cref="IMiddleware.InvokeAsync(MiddlewareContext, IAgent, CancellationToken)"/>.
/// </summary>
public IAsyncEnumerable<IStreamingMessage> InvokeAsync(
public IAsyncEnumerable<IMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken = default);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
}
}

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
IMessage? recentUpdate = null;
await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken))
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, G
return MessageEnvelope.Create(response, this.Name);
}

public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var request = BuildChatRequest(messages, options);
var response = this.client.GenerateContentStreamAsync(request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public GeminiMessageConnector(bool strictMode = false)

public string Name => nameof(GeminiMessageConnector);

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = ProcessMessage(context.Messages, agent);

Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public async Task<IMessage> GenerateReplyAsync(
return new MessageEnvelope<ChatCompletionResponse>(response, from: this.Name);
}

public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ public class MistralChatMessageConnector : IStreamingMiddleware, IMiddleware
{
public string? Name => nameof(MistralChatMessageConnector);

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
var chunks = new List<ChatCompletionResponse>();
await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
if (reply is IStreamingMessage<ChatCompletionResponse> chatMessage)
if (reply is IMessage<ChatCompletionResponse> chatMessage)
{
chunks.Add(chatMessage.Content);
var response = ProcessChatCompletionResponse(chatMessage, agent);
Expand Down Expand Up @@ -167,7 +167,7 @@ private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from
}
}

private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage<ChatCompletionResponse> message, IAgent agent)
private IMessage? ProcessChatCompletionResponse(IMessage<ChatCompletionResponse> message, IAgent agent)
{
var response = message.Content;
if (response.VarObject != "chat.completion.chunk")
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public async Task<IMessage> GenerateReplyAsync(
}
}

public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
};
}

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = ProcessMessage(context.Messages, agent);
var chunks = new List<ChatResponseUpdate>();
await foreach (var update in agent.GenerateStreamingReplyAsync(messages, context.Options, cancellationToken))
{
if (update is IStreamingMessage<ChatResponseUpdate> chatResponseUpdate)
if (update is IMessage<ChatResponseUpdate> chatResponseUpdate)
{
var response = chatResponseUpdate.Content switch
{
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public async Task<IMessage> GenerateReplyAsync(
return await _innerAgent.GenerateReplyAsync(messages, options, cancellationToken);
}

public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public async Task<IMessage> GenerateReplyAsync(
return new MessageEnvelope<ChatCompletions>(reply, from: this.Name);
}

public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
return PostProcessMessage(reply);
}

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
public async IAsyncEnumerable<IMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand All @@ -57,7 +57,7 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
string? currentToolName = null;
await foreach (var reply in streamingReply)
{
if (reply is IStreamingMessage<StreamingChatCompletionsUpdate> update)
if (reply is IMessage<StreamingChatCompletionsUpdate> update)
{
if (update.Content.FunctionName is string functionName)
{
Expand Down Expand Up @@ -98,7 +98,7 @@ public IMessage PostProcessMessage(IMessage message)
};
}

public IStreamingMessage? PostProcessStreamingMessage(IStreamingMessage<StreamingChatCompletionsUpdate> update, string? currentToolName)
public IMessage? PostProcessStreamingMessage(IMessage<StreamingChatCompletionsUpdate> update, string? currentToolName)
{
if (update.Content.ContentUpdate is string contentUpdate)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
return PostProcessMessage(reply);
}

public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var chatMessageContents = ProcessMessage(context.Messages, agent)
.Select(m => new MessageEnvelope<ChatMessageContent>(m));
Expand All @@ -67,11 +67,11 @@ private IMessage PostProcessMessage(IMessage input)
};
}

private IStreamingMessage PostProcessStreamingMessage(IStreamingMessage input)
private IMessage PostProcessStreamingMessage(IMessage input)
{
return input switch
{
IStreamingMessage<StreamingChatMessageContent> streamingMessage => PostProcessMessage(streamingMessage),
IMessage<StreamingChatMessageContent> streamingMessage => PostProcessMessage(streamingMessage),
IMessage msg => PostProcessMessage(msg),
_ => input,
};
Expand All @@ -98,7 +98,7 @@ private IMessage PostProcessMessage(IMessage<ChatMessageContent> messageEnvelope
}
}

private IStreamingMessage PostProcessMessage(IStreamingMessage<StreamingChatMessageContent> streamingMessage)
private IMessage PostProcessMessage(IMessage<StreamingChatMessageContent> streamingMessage)
{
var chatMessageContent = streamingMessage.Content;
if (chatMessageContent.ChoiceIndex > 0)
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, G
return new MessageEnvelope<ChatMessageContent>(reply.First(), from: this.Name);
}

public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand Down
Loading
Loading