Skip to content

Commit 4e95630

Browse files
[.Net] fix #2695 and #2884 (#3069)
* add round robin orchestrator * add constructor for orchestrators * add tests * revert change * return single orchestrator * address comment
1 parent f55a98f commit 4e95630

19 files changed

+901
-106
lines changed

dotnet/Directory.Build.props

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
3232
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
3333
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
34+
<PackageReference Include="Moq" Version="4.20.70" />
3435
</ItemGroup>
3536

3637
<ItemGroup Condition="'$(IsTestProject)' == 'true'">

dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs

+19-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ private ChatCompletionRequest CreateParameters(IEnumerable<IMessage> messages, G
6767
Stream = shouldStream,
6868
Temperature = (decimal?)options?.Temperature ?? _temperature,
6969
Tools = _tools?.ToList(),
70-
ToolChoice = _toolChoice ?? ToolChoice.Auto
70+
ToolChoice = _toolChoice ?? (_tools is { Length: > 0 } ? ToolChoice.Auto : null),
71+
StopSequences = options?.StopSequence?.ToArray(),
7172
};
7273

7374
chatCompletionRequest.Messages = BuildMessages(messages);
@@ -95,6 +96,22 @@ private List<ChatMessage> BuildMessages(IEnumerable<IMessage> messages)
9596
}
9697
}
9798

98-
return chatMessages;
99+
// merge messages with the same role
100+
// fixing #2884
101+
var mergedMessages = chatMessages.Aggregate(new List<ChatMessage>(), (acc, message) =>
102+
{
103+
if (acc.Count > 0 && acc.Last().Role == message.Role)
104+
{
105+
acc.Last().Content.AddRange(message.Content);
106+
}
107+
else
108+
{
109+
acc.Add(message);
110+
}
111+
112+
return acc;
113+
});
114+
115+
return mergedMessages;
99116
}
100117
}

dotnet/src/AutoGen.Anthropic/AnthropicClient.cs

+2-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) Microsoft Corporation. All rights reserved.
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// AnthropicClient.cs
33

44
using System;
@@ -90,13 +90,7 @@ public async IAsyncEnumerable<ChatCompletionResponse> StreamingChatCompletionsAs
9090
{
9191
var res = await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
9292
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)),
93-
cancellationToken: cancellationToken);
94-
95-
if (res == null)
96-
{
97-
throw new Exception("Failed to deserialize response");
98-
}
99-
93+
cancellationToken: cancellationToken) ?? throw new Exception("Failed to deserialize response");
10094
if (res.Delta?.Type == "input_json_delta" && !string.IsNullOrEmpty(res.Delta.PartialJson) &&
10195
currentEvent.ContentBlock != null)
10296
{

dotnet/src/AutoGen.Core/Agent/IAgent.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77
using System.Threading.Tasks;
88

99
namespace AutoGen.Core;
10-
public interface IAgent
10+
11+
public interface IAgentMetaInformation
1112
{
1213
public string Name { get; }
14+
}
1315

16+
public interface IAgent : IAgentMetaInformation
17+
{
1418
/// <summary>
1519
/// Generate reply
1620
/// </summary>

dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,7 @@ internal static IEnumerable<IMessage> ProcessConversationsForRolePlay(
100100
var msg = @$"From {x.From}:
101101
{x.GetContent()}
102102
<eof_msg>
103-
round #
104-
{i}";
103+
round # {i}";
105104

106105
return new TextMessage(Role.User, content: msg);
107106
});

dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs

+54-24
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ public class GroupChat : IGroupChat
1515
private List<IAgent> agents = new List<IAgent>();
1616
private IEnumerable<IMessage> initializeMessages = new List<IMessage>();
1717
private Graph? workflow = null;
18+
private readonly IOrchestrator orchestrator;
1819

1920
public IEnumerable<IMessage>? Messages { get; private set; }
2021

@@ -36,6 +37,37 @@ public GroupChat(
3637
this.initializeMessages = initializeMessages ?? new List<IMessage>();
3738
this.workflow = workflow;
3839

40+
if (admin is not null)
41+
{
42+
this.orchestrator = new RolePlayOrchestrator(admin, workflow);
43+
}
44+
else if (workflow is not null)
45+
{
46+
this.orchestrator = new WorkflowOrchestrator(workflow);
47+
}
48+
else
49+
{
50+
this.orchestrator = new RoundRobinOrchestrator();
51+
}
52+
53+
this.Validation();
54+
}
55+
56+
/// <summary>
57+
/// Create a group chat which uses the <paramref name="orchestrator"/> to decide the next speaker(s).
58+
/// </summary>
59+
/// <param name="members"></param>
60+
/// <param name="orchestrator"></param>
61+
/// <param name="initializeMessages"></param>
62+
public GroupChat(
63+
IEnumerable<IAgent> members,
64+
IOrchestrator orchestrator,
65+
IEnumerable<IMessage>? initializeMessages = null)
66+
{
67+
this.agents = members.ToList();
68+
this.initializeMessages = initializeMessages ?? new List<IMessage>();
69+
this.orchestrator = orchestrator;
70+
3971
this.Validation();
4072
}
4173

@@ -64,12 +96,6 @@ private void Validation()
6496
throw new Exception("All agents in the workflow must be in the group chat.");
6597
}
6698
}
67-
68-
// must provide one of admin or workflow
69-
if (this.admin == null && this.workflow == null)
70-
{
71-
throw new Exception("Must provide one of admin or workflow.");
72-
}
7399
}
74100

75101
/// <summary>
@@ -81,6 +107,7 @@ private void Validation()
81107
/// <param name="currentSpeaker">current speaker</param>
82108
/// <param name="conversationHistory">conversation history</param>
83109
/// <returns>next speaker.</returns>
110+
[Obsolete("Please use RolePlayOrchestrator or WorkflowOrchestrator")]
84111
public async Task<IAgent> SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumerable<IMessage> conversationHistory)
85112
{
86113
var agentNames = this.agents.Select(x => x.Name).ToList();
@@ -140,37 +167,40 @@ public void AddInitializeMessage(IMessage message)
140167
}
141168

142169
public async Task<IEnumerable<IMessage>> CallAsync(
143-
IEnumerable<IMessage>? conversationWithName = null,
170+
IEnumerable<IMessage>? chatHistory = null,
144171
int maxRound = 10,
145172
CancellationToken ct = default)
146173
{
147174
var conversationHistory = new List<IMessage>();
148-
if (conversationWithName != null)
175+
conversationHistory.AddRange(this.initializeMessages);
176+
if (chatHistory != null)
149177
{
150-
conversationHistory.AddRange(conversationWithName);
178+
conversationHistory.AddRange(chatHistory);
151179
}
180+
var roundLeft = maxRound;
152181

153-
var lastSpeaker = conversationHistory.LastOrDefault()?.From switch
182+
while (roundLeft > 0)
154183
{
155-
null => this.agents.First(),
156-
_ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"),
157-
};
158-
var round = 0;
159-
while (round < maxRound)
160-
{
161-
var currentSpeaker = await this.SelectNextSpeakerAsync(lastSpeaker, conversationHistory);
162-
var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory);
163-
var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned.");
184+
var orchestratorContext = new OrchestrationContext
185+
{
186+
Candidates = this.agents,
187+
ChatHistory = conversationHistory,
188+
};
189+
var nextSpeaker = await this.orchestrator.GetNextSpeakerAsync(orchestratorContext, ct);
190+
if (nextSpeaker == null)
191+
{
192+
break;
193+
}
194+
195+
var result = await nextSpeaker.GenerateReplyAsync(conversationHistory, cancellationToken: ct);
164196
conversationHistory.Add(result);
165197

166-
// if message is terminate message, then terminate the conversation
167-
if (result?.IsGroupChatTerminateMessage() ?? false)
198+
if (result.IsGroupChatTerminateMessage())
168199
{
169-
break;
200+
return conversationHistory;
170201
}
171202

172-
lastSpeaker = currentSpeaker;
173-
round++;
203+
roundLeft--;
174204
}
175205

176206
return conversationHistory;

dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs

+2-69
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@
33

44
using System;
55
using System.Collections.Generic;
6-
using System.Linq;
7-
using System.Threading;
8-
using System.Threading.Tasks;
96

107
namespace AutoGen.Core;
118

@@ -25,76 +22,12 @@ public SequentialGroupChat(IEnumerable<IAgent> agents, List<IMessage>? initializ
2522
/// <summary>
2623
/// A group chat that allows agents to talk in a round-robin manner.
2724
/// </summary>
28-
public class RoundRobinGroupChat : IGroupChat
25+
public class RoundRobinGroupChat : GroupChat
2926
{
30-
private readonly List<IAgent> agents = new List<IAgent>();
31-
private readonly List<IMessage> initializeMessages = new List<IMessage>();
32-
3327
public RoundRobinGroupChat(
3428
IEnumerable<IAgent> agents,
3529
List<IMessage>? initializeMessages = null)
30+
: base(agents, initializeMessages: initializeMessages)
3631
{
37-
this.agents.AddRange(agents);
38-
this.initializeMessages = initializeMessages ?? new List<IMessage>();
39-
}
40-
41-
/// <inheritdoc />
42-
public void AddInitializeMessage(IMessage message)
43-
{
44-
this.SendIntroduction(message);
45-
}
46-
47-
public async Task<IEnumerable<IMessage>> CallAsync(
48-
IEnumerable<IMessage>? conversationWithName = null,
49-
int maxRound = 10,
50-
CancellationToken ct = default)
51-
{
52-
var conversationHistory = new List<IMessage>();
53-
if (conversationWithName != null)
54-
{
55-
conversationHistory.AddRange(conversationWithName);
56-
}
57-
58-
var lastSpeaker = conversationHistory.LastOrDefault()?.From switch
59-
{
60-
null => this.agents.First(),
61-
_ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"),
62-
};
63-
var round = 0;
64-
while (round < maxRound)
65-
{
66-
var currentSpeaker = this.SelectNextSpeaker(lastSpeaker);
67-
var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory);
68-
var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned.");
69-
conversationHistory.Add(result);
70-
71-
// if message is terminate message, then terminate the conversation
72-
if (result?.IsGroupChatTerminateMessage() ?? false)
73-
{
74-
break;
75-
}
76-
77-
lastSpeaker = currentSpeaker;
78-
round++;
79-
}
80-
81-
return conversationHistory;
82-
}
83-
84-
public void SendIntroduction(IMessage message)
85-
{
86-
this.initializeMessages.Add(message);
87-
}
88-
89-
private IAgent SelectNextSpeaker(IAgent currentSpeaker)
90-
{
91-
var index = this.agents.IndexOf(currentSpeaker);
92-
if (index == -1)
93-
{
94-
throw new ArgumentException("The agent is not in the group chat", nameof(currentSpeaker));
95-
}
96-
97-
var nextIndex = (index + 1) % this.agents.Count;
98-
return this.agents[nextIndex];
9932
}
10033
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// IOrchestrator.cs
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
9+
namespace AutoGen.Core;
10+
11+
public class OrchestrationContext
12+
{
13+
public IEnumerable<IAgent> Candidates { get; set; } = Array.Empty<IAgent>();
14+
15+
public IEnumerable<IMessage> ChatHistory { get; set; } = Array.Empty<IMessage>();
16+
}
17+
18+
public interface IOrchestrator
19+
{
20+
/// <summary>
21+
/// Return the next agent as the next speaker. return null if no agent is selected.
22+
/// </summary>
23+
/// <param name="context">orchestration context, such as candidate agents and chat history.</param>
24+
/// <param name="cancellationToken">cancellation token</param>
25+
public Task<IAgent?> GetNextSpeakerAsync(
26+
OrchestrationContext context,
27+
CancellationToken cancellationToken = default);
28+
}

0 commit comments

Comments
 (0)