From 5ec030839cf65b5c6f8a3f9405596b6a58b7d804 Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Mon, 15 Jul 2024 11:02:30 -0700 Subject: [PATCH] add cancellation token to transition check lambda (#3132) --- dotnet/src/AutoGen.Core/GroupChat/Graph.cs | 45 +++++++++++++++------- dotnet/test/AutoGen.Tests/WorkflowTest.cs | 4 +- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/dotnet/src/AutoGen.Core/GroupChat/Graph.cs b/dotnet/src/AutoGen.Core/GroupChat/Graph.cs index d6b71e2a3f13..acff955a292c 100644 --- a/dotnet/src/AutoGen.Core/GroupChat/Graph.cs +++ b/dotnet/src/AutoGen.Core/GroupChat/Graph.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; namespace AutoGen.Core; @@ -12,11 +13,7 @@ public class Graph { private readonly List transitions = new List(); - public Graph() - { - } - - public Graph(IEnumerable? transitions) + public Graph(IEnumerable? transitions = null) { if (transitions != null) { @@ -40,13 +37,13 @@ public void AddTransition(Transition transition) /// the from agent /// messages /// A list of agents that the messages can be transit to - public async Task> TransitToNextAvailableAgentsAsync(IAgent fromAgent, IEnumerable messages) + public async Task> TransitToNextAvailableAgentsAsync(IAgent fromAgent, IEnumerable messages, CancellationToken ct = default) { var nextAgents = new List(); var availableTransitions = transitions.FindAll(t => t.From == fromAgent) ?? Enumerable.Empty(); foreach (var transition in availableTransitions) { - if (await transition.CanTransitionAsync(messages)) + if (await transition.CanTransitionAsync(messages, ct)) { nextAgents.Add(transition.To); } @@ -63,7 +60,7 @@ public class Transition { private readonly IAgent _from; private readonly IAgent _to; - private readonly Func, Task>? _canTransition; + private readonly Func, CancellationToken, Task>? _canTransition; /// /// Create a new instance of . @@ -73,22 +70,44 @@ public class Transition /// from agent /// to agent /// detect if the transition is allowed, default to be always true - internal Transition(IAgent from, IAgent to, Func, Task>? canTransitionAsync = null) + internal Transition(IAgent from, IAgent to, Func, CancellationToken, Task>? canTransitionAsync = null) { _from = from; _to = to; _canTransition = canTransitionAsync; } + /// + /// Create a new instance of without transition condition check. + /// + /// " + public static Transition Create(TFromAgent from, TToAgent to) + where TFromAgent : IAgent + where TToAgent : IAgent + { + return new Transition(from, to, (fromAgent, toAgent, messages, _) => Task.FromResult(true)); + } + /// /// Create a new instance of . /// /// " - public static Transition Create(TFromAgent from, TToAgent to, Func, Task>? canTransitionAsync = null) + public static Transition Create(TFromAgent from, TToAgent to, Func, Task> canTransitionAsync) + where TFromAgent : IAgent + where TToAgent : IAgent + { + return new Transition(from, to, (fromAgent, toAgent, messages, _) => canTransitionAsync.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages)); + } + + /// + /// Create a new instance of with cancellation token. + /// + /// " + public static Transition Create(TFromAgent from, TToAgent to, Func, CancellationToken, Task> canTransitionAsync) where TFromAgent : IAgent where TToAgent : IAgent { - return new Transition(from, to, (fromAgent, toAgent, messages) => canTransitionAsync?.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages) ?? Task.FromResult(true)); + return new Transition(from, to, (fromAgent, toAgent, messages, ct) => canTransitionAsync.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages, ct)); } public IAgent From => _from; @@ -99,13 +118,13 @@ public static Transition Create(TFromAgent from, TToAgent /// Check if the transition is allowed. /// /// messages - public Task CanTransitionAsync(IEnumerable messages) + public Task CanTransitionAsync(IEnumerable messages, CancellationToken ct = default) { if (_canTransition == null) { return Task.FromResult(true); } - return _canTransition(this.From, this.To, messages); + return _canTransition(this.From, this.To, messages, ct); } } diff --git a/dotnet/test/AutoGen.Tests/WorkflowTest.cs b/dotnet/test/AutoGen.Tests/WorkflowTest.cs index d1d12010e39f..1079ec95515a 100644 --- a/dotnet/test/AutoGen.Tests/WorkflowTest.cs +++ b/dotnet/test/AutoGen.Tests/WorkflowTest.cs @@ -17,7 +17,7 @@ public async Task TransitionTestAsync() var alice = new EchoAgent("alice"); var bob = new EchoAgent("bob"); - var aliceToBob = Transition.Create(alice, bob, async (from, to, messages) => + var aliceToBob = Transition.Create(alice, bob, async (from, to, messages, _) => { if (messages.Any(m => m.GetContent() == "Hello")) { @@ -30,7 +30,7 @@ public async Task TransitionTestAsync() var canTransit = await aliceToBob.CanTransitionAsync([]); canTransit.Should().BeFalse(); - canTransit = await aliceToBob.CanTransitionAsync(new[] { new TextMessage(Role.Assistant, "Hello") }); + canTransit = await aliceToBob.CanTransitionAsync([new TextMessage(Role.Assistant, "Hello")]); canTransit.Should().BeTrue(); // if no function is provided, it should always return true