diff --git a/.github/upgrades/prompts/SemanticKernelToAgentFramework.md b/.github/upgrades/prompts/SemanticKernelToAgentFramework.md
index 44985bba98..6ff0984609 100644
--- a/.github/upgrades/prompts/SemanticKernelToAgentFramework.md
+++ b/.github/upgrades/prompts/SemanticKernelToAgentFramework.md
@@ -105,7 +105,7 @@ After completing migration, verify these specific items:
 1. **Compilation**: Execute `dotnet build` on all modified projects - zero errors required
 2. **Namespace Updates**: Confirm all `using Microsoft.SemanticKernel.Agents` statements are replaced
 3. **Method Calls**: Verify all `InvokeAsync` calls are changed to `RunAsync`
-4. **Return Types**: Confirm handling of `AgentRunResponse` instead of `IAsyncEnumerable<AgentResponseItem<ChatMessageContent>>`
+4. **Return Types**: Confirm handling of `AgentResponse` instead of `IAsyncEnumerable<AgentResponseItem<ChatMessageContent>>`
 5. **Thread Creation**: Validate all thread creation uses `agent.GetNewThread()` pattern
 6. **Tool Registration**: Ensure `[KernelFunction]` attributes are removed and `AIFunctionFactory.Create()` is used
 7. **Options Configuration**: Verify `AgentRunOptions` or `ChatClientAgentRunOptions` replaces `AgentInvokeOptions`
@@ -119,7 +119,7 @@ Agent Framework provides functionality for creating and managing AI agents throu
 Key API differences:
 - Agent creation: Remove Kernel dependency, use direct client-based creation
 - Method names: `InvokeAsync` → `RunAsync`, `InvokeStreamingAsync` → `RunStreamingAsync`
-- Return types: `IAsyncEnumerable<AgentResponseItem<ChatMessageContent>>` → `AgentRunResponse`
+- Return types: `IAsyncEnumerable<AgentResponseItem<ChatMessageContent>>` → `AgentResponse`
 - Thread creation: Provider-specific constructors → `agent.GetNewThread()`
 - Tool registration: `KernelPlugin` system → Direct `AIFunction` registration
 - Options: `AgentInvokeOptions` → Provider-specific run options (e.g., `ChatClientAgentRunOptions`)
@@ -166,8 +166,8 @@ Replace these method calls:
 | `thread.DeleteAsync()` | Provider-specific cleanup | Use provider client directly |
 
 Return type changes:
-- `IAsyncEnumerable<AgentResponseItem<ChatMessageContent>>` → `AgentRunResponse`
-- `IAsyncEnumerable<StreamingChatMessageContent>` → `IAsyncEnumerable<AgentRunResponseUpdate>`
+- `IAsyncEnumerable<AgentResponseItem<ChatMessageContent>>` → `AgentResponse`
+- `IAsyncEnumerable<StreamingChatMessageContent>` → `IAsyncEnumerable<AgentResponseUpdate>`
@@ -191,8 +191,8 @@ Agent Framework changes these behaviors compared to Semantic Kernel Agents:
 
 1. **Thread Management**: Agent Framework automatically manages thread state. Semantic Kernel required manual thread updates in some scenarios (e.g., OpenAI Responses).
 2. **Return Types**:
-   - Non-streaming: Returns single `AgentRunResponse` instead of `IAsyncEnumerable<AgentResponseItem<ChatMessageContent>>`
-   - Streaming: Returns `IAsyncEnumerable<AgentRunResponseUpdate>` instead of `IAsyncEnumerable<StreamingChatMessageContent>`
+   - Non-streaming: Returns single `AgentResponse` instead of `IAsyncEnumerable<AgentResponseItem<ChatMessageContent>>`
+   - Streaming: Returns `IAsyncEnumerable<AgentResponseUpdate>` instead of `IAsyncEnumerable<StreamingChatMessageContent>`
 3. **Tool Registration**: Agent Framework uses direct function registration without requiring `[KernelFunction]` attributes.
@@ -397,7 +397,7 @@ await foreach (AgentResponseItem<ChatMessageContent> item in agent.InvokeAsync(u
 
 **With this Agent Framework non-streaming pattern:**
 
 ```csharp
-AgentRunResponse result = await agent.RunAsync(userInput, thread, options);
+AgentResponse result = await agent.RunAsync(userInput, thread, options);
 Console.WriteLine(result);
 ```
@@ -411,7 +411,7 @@ await foreach (StreamingChatMessageContent update in agent.InvokeStreamingAsync(
 
 **With this Agent Framework streaming pattern:**
 
 ```csharp
-await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(userInput, thread, options))
+await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(userInput, thread, options))
 {
     Console.Write(update);
 }
@@ -420,8 +420,8 @@ await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(userInpu
 
 **Required changes:**
 1. Replace `agent.InvokeAsync()` with `agent.RunAsync()`
 2. Replace `agent.InvokeStreamingAsync()` with `agent.RunStreamingAsync()`
-3. Change return type handling from `IAsyncEnumerable<AgentResponseItem<ChatMessageContent>>` to `AgentRunResponse`
-4. Change streaming type from `StreamingChatMessageContent` to `AgentRunResponseUpdate`
+3. Change return type handling from `IAsyncEnumerable<AgentResponseItem<ChatMessageContent>>` to `AgentResponse`
+4. Change streaming type from `StreamingChatMessageContent` to `AgentResponseUpdate`
 5. Remove `await foreach` for non-streaming calls
 6. Access message content directly from result object instead of iterating
@@ -661,7 +661,7 @@ await foreach (var result in agent.InvokeAsync(input, thread, options))
 
 ```csharp
 ChatClientAgentRunOptions options = new(new ChatOptions { MaxOutputTokens = 1000 });
-AgentRunResponse result = await agent.RunAsync(input, thread, options);
+AgentResponse result = await agent.RunAsync(input, thread, options);
 Console.WriteLine(result);
 
 // Access underlying content when needed:
@@ -689,7 +689,7 @@ await foreach (var result in agent.InvokeAsync(input, thread, options))
 
 **With this Agent Framework non-streaming usage pattern:**
 
 ```csharp
-AgentRunResponse result = await agent.RunAsync(input, thread, options);
+AgentResponse result = await agent.RunAsync(input, thread, options);
 Console.WriteLine($"Tokens: {result.Usage.TotalTokenCount}");
 ```
@@ -709,7 +709,7 @@ await foreach (StreamingChatMessageContent response in agent.InvokeStreamingAsyn
 
 **With this Agent Framework streaming usage pattern:**
 
 ```csharp
-await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(input, thread, options))
+await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(input, thread, options))
 {
     if (update.Contents.OfType<UsageContent>().FirstOrDefault() is { } usageContent)
     {
diff --git a/.github/workflows/merge-gatekeeper.yml b/.github/workflows/merge-gatekeeper.yml
index ea0f27ce92..49d04183d5 100644
--- a/.github/workflows/merge-gatekeeper.yml
+++ b/.github/workflows/merge-gatekeeper.yml
@@ -29,3 +29,4 @@ jobs:
           token: ${{ secrets.GITHUB_TOKEN }}
           timeout: 3600
           interval: 30
+          ignored: CodeQL
diff --git a/.github/workflows/python-merge-tests.yml b/.github/workflows/python-merge-tests.yml
index eb4f123f42..0dafc1266f 100644
--- a/.github/workflows/python-merge-tests.yml
+++ b/.github/workflows/python-merge-tests.yml
@@ -97,7 +97,7 @@ jobs:
         id: azure-functions-setup
       - name: Test with pytest
         timeout-minutes: 10
-        run: uv run poe all-tests -n logical --dist loadfile --dist worksteal --timeout 300 --retries 3 --retry-delay 10
+        run: uv run poe all-tests -n logical --dist loadfile --dist worksteal --timeout 600 --retries 3 --retry-delay 10
         working-directory: ./python
       - name: Test core samples
        timeout-minutes: 10
diff --git a/.gitignore b/.gitignore
index 0672e15083..f3b78125fd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -206,6 +206,7 @@ agents.md
 WARP.md
 **/memory-bank/
 **/projectBrief.md
+**/tmpclaude*
 
 # Azurite storage emulator files
 */__azurite_db_blob__.json*
@@ -227,3 +228,4 @@ local.settings.json
 
 # Database files
 *.db
+python/dotnet-ref
diff --git a/docs/decisions/0001-agent-run-response.md b/docs/decisions/0001-agent-run-response.md
index 9f13af787c..6f3385e1a1 100644
--- a/docs/decisions/0001-agent-run-response.md
+++ b/docs/decisions/0001-agent-run-response.md
@@ -163,8 +163,8 @@ foreach (var update in response.Messages)
 
 ### Option 2 Run: Container with Primary and Secondary Properties, RunStreaming: Stream of Primary + Secondary
 
 Run returns a new response type that has separate properties for the Primary Content and the Secondary Updates leading up to it.
-The Primary content is available in the `AgentRunResponse.Messages` property while Secondary updates are in a new `AgentRunResponse.Updates` property.
-`AgentRunResponse.Text` returns the Primary content text.
+The Primary content is available in the `AgentResponse.Messages` property while Secondary updates are in a new `AgentResponse.Updates` property.
+`AgentResponse.Text` returns the Primary content text.
 
 Since streaming would still need to return an `IAsyncEnumerable` of updates, the design would differ from non-streaming.
 With non-streaming Primary and Secondary content is split into separate lists, while with streaming it's combined in one stream.
@@ -232,24 +232,24 @@ await foreach (var update in responses)
 
 ```csharp
 class Agent
 {
-    public abstract Task<AgentRunResponse> RunAsync(
+    public abstract Task<AgentResponse> RunAsync(
         IReadOnlyCollection<ChatMessage> messages,
         AgentThread? thread = null,
         AgentRunOptions? options = null,
         CancellationToken cancellationToken = default);
 
-    public abstract IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync(
+    public abstract IAsyncEnumerable<AgentResponseUpdate> RunStreamingAsync(
         IReadOnlyCollection<ChatMessage> messages,
         AgentThread? thread = null,
         AgentRunOptions? options = null,
         CancellationToken cancellationToken = default);
 }
 
-class AgentRunResponse : ChatResponse
+class AgentResponse : ChatResponse
 {
 }
 
-public class AgentRunResponseUpdate : ChatResponseUpdate
+public class AgentResponseUpdate : ChatResponseUpdate
 {
 }
 ```
@@ -265,20 +265,20 @@ The new types could also exclude properties that make less sense for agents, lik
 
 ```csharp
 class Agent
 {
-    public abstract Task<AgentRunResponse> RunAsync(
+    public abstract Task<AgentResponse> RunAsync(
         IReadOnlyCollection<ChatMessage> messages,
         AgentThread? thread = null,
         AgentRunOptions? options = null,
         CancellationToken cancellationToken = default);
 
-    public abstract IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync(
+    public abstract IAsyncEnumerable<AgentResponseUpdate> RunStreamingAsync(
         IReadOnlyCollection<ChatMessage> messages,
         AgentThread? thread = null,
         AgentRunOptions? options = null,
         CancellationToken cancellationToken = default);
 }
 
-class AgentRunResponse // Compare with ChatResponse
+class AgentResponse // Compare with ChatResponse
 {
     public string Text { get; } // Aggregation of TextContent from messages.
@@ -294,12 +294,12 @@ class AgentRunResponse // Compare with ChatResponse
 
     public AdditionalPropertiesDictionary? AdditionalProperties { get; set; }
 }
 
-// Not Included in AgentRunResponse compared to ChatResponse
+// Not Included in AgentResponse compared to ChatResponse
 public ChatFinishReason? FinishReason { get; set; }
 public string? ConversationId { get; set; }
 public string?
ModelId { get; set; } -public class AgentRunResponseUpdate // Compare with ChatResponseUpdate +public class AgentResponseUpdate // Compare with ChatResponseUpdate { public string Text { get; } // Aggregation of TextContent from Contents. @@ -317,7 +317,7 @@ public class AgentRunResponseUpdate // Compare with ChatResponseUpdate public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } } -// Not Included in AgentRunResponseUpdate compared to ChatResponseUpdate +// Not Included in AgentResponseUpdate compared to ChatResponseUpdate public ChatFinishReason? FinishReason { get; set; } public string? ConversationId { get; set; } public string? ModelId { get; set; } @@ -360,7 +360,7 @@ public class ChatFinishReason ### Option 2: Add another property on responses for AgentRun ```csharp -class AgentRunResponse +class AgentResponse { ... public AgentRun RunReference { get; set; } // Reference to long running process @@ -368,7 +368,7 @@ class AgentRunResponse } -public class AgentRunResponseUpdate +public class AgentResponseUpdate { ... public AgentRun RunReference { get; set; } // Reference to long running process @@ -424,7 +424,7 @@ Note that where an agent doesn't support structured output, it may also be possi See [Structured Outputs Support](#structured-outputs-support) for a comparison on what other agent frameworks and protocols support. To support a good user experience for structured outputs, I'm proposing that we follow the pattern used by MEAI. -We would add a generic version of `AgentRunResponse`, that allows us to get the agent result already deserialized into our preferred type. +We would add a generic version of `AgentResponse`, that allows us to get the agent result already deserialized into our preferred type. This would be coupled with generic overload extension methods for Run that automatically builds a schema from the supplied type and updates the run options. @@ -438,14 +438,14 @@ class Movie public int ReleaseYear { get; set; } } -AgentRunResponse response = agent.RunAsync("What are the top 3 children's movies of the 80s."); +AgentResponse response = agent.RunAsync("What are the top 3 children's movies of the 80s."); Movie[] movies = response.Result ``` If we only support requesting a schema at agent creation time or where an agent has a built in schema, the following would be the preferred approach: ```csharp -AgentRunResponse response = agent.RunAsync("What are the top 3 children's movies of the 80s."); +AgentResponse response = agent.RunAsync("What are the top 3 children's movies of the 80s."); Movie[] movies = response.TryParseStructuredOutput(); ``` @@ -463,7 +463,7 @@ Option 2 chosen so that we can vary Agent responses independently of Chat Client ### StructuredOutputs Decision We will not support structured output per run request, but individual agents are free to allow this on the concrete implementation or at construction time. -We will however add support for easily extracting a structured output type from the `AgentRunResponse`. +We will however add support for easily extracting a structured output type from the `AgentResponse`. 
## Addendum 1: AIContext Derived Types for different response types / Gap Analysis (Work in progress) diff --git a/docs/decisions/0005-python-naming-conventions.md b/docs/decisions/0005-python-naming-conventions.md index d82cad16ab..3a79b98f91 100644 --- a/docs/decisions/0005-python-naming-conventions.md +++ b/docs/decisions/0005-python-naming-conventions.md @@ -54,7 +54,7 @@ The table below represents the majority of the naming changes discussed in issue | *Mcp* & *Http* | *MCP* & *HTTP* | accepted | Acronyms should be uppercased in class names, according to PEP 8. | None | | `agent.run_streaming` | `agent.run_stream` | accepted | Shorter and more closely aligns with AutoGen and Semantic Kernel names for the same methods. | None | | `workflow.run_streaming` | `workflow.run_stream` | accepted | In sync with `agent.run_stream` and shorter and more closely aligns with AutoGen and Semantic Kernel names for the same methods. | None | -| AgentRunResponse & AgentRunResponseUpdate | AgentResponse & AgentResponseUpdate | rejected | Rejected, because it is the response to a run invocation and AgentResponse is too generic. | None | +| AgentResponse & AgentResponseUpdate | AgentResponse & AgentResponseUpdate | rejected | Rejected, because it is the response to a run invocation and AgentResponse is too generic. | None | | *Content | * | rejected | Rejected other content type renames (removing `Content` suffix) because it would reduce clarity and discoverability. | Item was also considered, but rejected as it is very similar to Content, but would be inconsistent with dotnet. | | ChatResponse & ChatResponseUpdate | Response & ResponseUpdate | rejected | Rejected, because Response is too generic. | None | diff --git a/docs/decisions/0006-userapproval.md b/docs/decisions/0006-userapproval.md index 63ca8bc0fb..7823ab4de4 100644 --- a/docs/decisions/0006-userapproval.md +++ b/docs/decisions/0006-userapproval.md @@ -161,11 +161,11 @@ while (response.ApprovalRequests.Count > 0) response = await agent.RunAsync(messages, thread); } -class AgentRunResponse +class AgentResponse { ... - // A new property on AgentRunResponse to aggregate the ApprovalRequestContent items from + // A new property on AgentResponse to aggregate the ApprovalRequestContent items from // the response messages (Similar to the Text property). public IEnumerable ApprovalRequests { get; set; } @@ -251,11 +251,11 @@ while (response.UserInputRequests.Any()) response = await agent.RunAsync(messages, thread); } -class AgentRunResponse +class AgentResponse { ... - // A new property on AgentRunResponse to aggregate the UserInputRequestContent items from + // A new property on AgentResponse to aggregate the UserInputRequestContent items from // the response messages (Similar to the Text property). public IReadOnlyList UserInputRequests { get; set; } @@ -366,11 +366,11 @@ while (response.UserInputRequests.Any()) response = await agent.RunAsync(messages, thread); } -class AgentRunResponse +class AgentResponse { ... - // A new property on AgentRunResponse to aggregate the UserInputRequestContent items from + // A new property on AgentResponse to aggregate the UserInputRequestContent items from // the response messages (Similar to the Text property). 
public IEnumerable UserInputRequests { get; set; } diff --git a/docs/decisions/0007-agent-filtering-middleware.md b/docs/decisions/0007-agent-filtering-middleware.md index 3855e8a9c8..dbdd6d37d1 100644 --- a/docs/decisions/0007-agent-filtering-middleware.md +++ b/docs/decisions/0007-agent-filtering-middleware.md @@ -115,7 +115,7 @@ public class AIAgent } } - public async Task RunAsync( + public async Task RunAsync( IReadOnlyCollection messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -135,7 +135,7 @@ public class AIAgent return context.Response ?? throw new InvalidOperationException("Agent execution did not produce a response"); } - protected abstract Task ExecuteCoreLogicAsync( + protected abstract Task ExecuteCoreLogicAsync( IReadOnlyCollection messages, AgentThread? thread, AgentRunOptions? options, @@ -190,7 +190,7 @@ internal sealed class GuardrailCallbackAgent : DelegatingAIAgent public GuardrailCallbackAgent(AIAgent innerAgent) : base(innerAgent) { } - public override async Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + public override async Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { var filteredMessages = this.FilterMessages(messages); Console.WriteLine($"Guardrail Middleware - Filtered messages: {new ChatResponse(filteredMessages).Text}"); @@ -202,14 +202,14 @@ internal sealed class GuardrailCallbackAgent : DelegatingAIAgent return response; } - public override async IAsyncEnumerable RunStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public override async IAsyncEnumerable RunStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var filteredMessages = this.FilterMessages(messages); await foreach (var update in this.InnerAgent.RunStreamingAsync(filteredMessages, thread, options, cancellationToken)) { if (update.Text != null) { - yield return new AgentRunResponseUpdate(update.Role, this.FilterContent(update.Text)); + yield return new AgentResponseUpdate(update.Role, this.FilterContent(update.Text)); } else { @@ -252,7 +252,7 @@ internal sealed class RunningCallbackHandlerAgent : DelegatingAIAgent this._func = func; } - public override async Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + public override async Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { var context = new AgentInvokeCallbackContext(this, messages, thread, options, isStreaming: false, cancellationToken); @@ -469,7 +469,7 @@ public sealed class CallbackEnabledAgent : DelegatingAIAgent this._callbacksProcessor = callbackMiddlewareProcessor ?? new(); } - public override async Task RunAsync( + public override async Task RunAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -541,7 +541,7 @@ public abstract class AgentContext public class AgentRunContext : AgentContext { public IList Messages { get; set; } - public AgentRunResponse? Response { get; set; } + public AgentResponse? Response { get; set; } public AgentThread? 
Thread { get; } public AgentRunContext(AIAgent agent, IList messages, AgentThread? thread, AgentRunOptions? options) diff --git a/docs/decisions/0009-support-long-running-operations.md b/docs/decisions/0009-support-long-running-operations.md index 7227840c8f..a62a038553 100644 --- a/docs/decisions/0009-support-long-running-operations.md +++ b/docs/decisions/0009-support-long-running-operations.md @@ -687,7 +687,7 @@ This section considers different options for exposing the `RunId`, `Status`, and #### 4.1. As AIContent The `AsyncRunContent` class will represent a long-running operation initiated and managed by an agent/LLM. -Items of this content type will be returned in a chat message as part of the `AgentRunResponse` or `ChatResponse` +Items of this content type will be returned in a chat message as part of the `AgentResponse` or `ChatResponse` response to represent the long-running operation. The `AsyncRunContent` class has two properties: `RunId` and `Status`. The `RunId` identifies the @@ -1162,29 +1162,29 @@ For cancellation and deletion of long-running operations, new methods will be ad public abstract class AIAgent { // Existing methods... - public Task RunAsync(string message, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { ... } - public IAsyncEnumerable RunStreamingAsync(string message, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { ... } + public Task RunAsync(string message, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { ... } + public IAsyncEnumerable RunStreamingAsync(string message, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { ... } // New methods for uncommon operations - public virtual Task CancelRunAsync(string id, AgentCancelRunOptions? options = null, CancellationToken cancellationToken = default) + public virtual Task CancelRunAsync(string id, AgentCancelRunOptions? options = null, CancellationToken cancellationToken = default) { - return Task.FromResult(null); + return Task.FromResult(null); } - public virtual Task DeleteRunAsync(string id, AgentDeleteRunOptions? options = null, CancellationToken cancellationToken = default) + public virtual Task DeleteRunAsync(string id, AgentDeleteRunOptions? options = null, CancellationToken cancellationToken = default) { - return Task.FromResult(null); + return Task.FromResult(null); } } // Agent that supports update and cancellation public class CustomAgent : AIAgent { - public override async Task CancelRunAsync(string id, AgentCancelRunOptions? options = null, CancellationToken cancellationToken = default) + public override async Task CancelRunAsync(string id, AgentCancelRunOptions? 
options = null, CancellationToken cancellationToken = default) { var response = await this._client.CancelRunAsync(id, options?.Thread?.ConversationId); - return ConvertToAgentRunResponse(response); + return ConvertToAgentResponse(response); } // No overload for DeleteRunAsync as it's not supported by the underlying API @@ -1195,7 +1195,7 @@ AIAgent agent = new CustomAgent(); AgentThread thread = agent.GetNewThread(); -AgentRunResponse response = await agent.RunAsync("What is the capital of France?"); +AgentResponse response = await agent.RunAsync("What is the capital of France?"); response = await agent.CancelRunAsync(response.ResponseId, new AgentCancelRunOptions { Thread = thread }); ``` @@ -1251,10 +1251,10 @@ public class AgentRunOptions AIAgent agent = ...; // Get an instance of an AIAgent // Start a long-running execution for the prompt if supported by the underlying API -AgentRunResponse response = await agent.RunAsync("", new AgentRunOptions { AllowLongRunningResponses = true }); +AgentResponse response = await agent.RunAsync("", new AgentRunOptions { AllowLongRunningResponses = true }); // Start a quick prompt -AgentRunResponse response = await agent.RunAsync(""); +AgentResponse response = await agent.RunAsync(""); ``` **Pros:** @@ -1279,7 +1279,7 @@ Below are the details of the option selected for chat clients that is also selec #### 3.1 Continuation Token of a Custom Type This option suggests using `ContinuationToken` to encapsulate all properties representing a long-running operation. The continuation token will be returned by agents in the -`ContinuationToken` property of the `AgentRunResponse` and `AgentRunResponseUpdate` responses to indicate that the response is part of a long-running operation. A null value +`ContinuationToken` property of the `AgentResponse` and `AgentResponseUpdate` responses to indicate that the response is part of a long-running operation. A null value of the property will indicate that the response is not part of a long-running operation or the long-running operation has been completed. Callers will set the token in the `ContinuationToken` property of the `AgentRunOptions` class in follow-up calls to the `Run{Streaming}Async` methods to indicate that they want to "continue" the long-running operation identified by the token. @@ -1313,18 +1313,18 @@ public class AgentRunOptions public ResponseContinuationToken? ContinuationToken { get; set; } } -public class AgentRunResponse +public class AgentResponse { public ResponseContinuationToken? ContinuationToken { get; } } -public class AgentRunResponseUpdate +public class AgentResponseUpdate { public ResponseContinuationToken? 
ContinuationToken { get; } } // Usage example -AgentRunResponse response = await agent.RunAsync("What is the capital of France?"); +AgentResponse response = await agent.RunAsync("What is the capital of France?"); AgentRunOptions options = new() { ContinuationToken = response.ContinuationToken }; diff --git a/docs/decisions/0010-ag-ui-support.md b/docs/decisions/0010-ag-ui-support.md index 8d9475bb5a..e1d46e9eff 100644 --- a/docs/decisions/0010-ag-ui-support.md +++ b/docs/decisions/0010-ag-ui-support.md @@ -36,7 +36,7 @@ Chosen option: "Current approach with internal event types and framework-native - Protects consumers from protocol changes by keeping AG-UI events internal - Maintains framework abstractions through conversion at boundaries -- Uses existing framework types (AgentRunResponseUpdate, ChatMessage) for public API +- Uses existing framework types (AgentResponseUpdate, ChatMessage) for public API - Focuses on core text streaming functionality - Leverages existing properties (ConversationId, ResponseId, ErrorContent) instead of custom types - Provides bidirectional client and server support @@ -69,7 +69,7 @@ Chosen option: "Current approach with internal event types and framework-native 3. **Agent Factory Pattern** - `MapAGUIAgent` uses factory function `(messages) => AIAgent` to allow request-specific agent configuration supporting multi-tenancy -4. **Bidirectional Conversion Architecture** - Symmetric conversion logic in shared namespace compiled into both packages for server (`AgentRunResponseUpdate` → AG-UI events) and client (AG-UI events → `AgentRunResponseUpdate`) +4. **Bidirectional Conversion Architecture** - Symmetric conversion logic in shared namespace compiled into both packages for server (`AgentResponseUpdate` → AG-UI events) and client (AG-UI events → `AgentResponseUpdate`) 5. **Thread Management** - `AGUIAgentThread` stores only `ThreadId` with thread ID communicated via `ConversationId`; applications manage persistence for parity with other implementations and to be compliant with the protocol. Future extensions will support having the server manage the conversation. diff --git a/docs/decisions/0011-create-get-agent-api.md b/docs/decisions/0011-create-get-agent-api.md new file mode 100644 index 0000000000..4703c1271d --- /dev/null +++ b/docs/decisions/0011-create-get-agent-api.md @@ -0,0 +1,368 @@ +--- +status: proposed +contact: dmytrostruk +date: 2025-12-12 +deciders: dmytrostruk, markwallace-microsoft, eavanvalkenburg, giles17 +--- + +# Create/Get Agent API + +## Context and Problem Statement + +There is a misalignment between the create/get agent API in the .NET and Python implementations. + +In .NET, the `CreateAIAgent` method can create either a local instance of an agent or a remote instance if the backend provider supports it. For remote agents, once the agent is created, you can retrieve an existing remote agent by using the `GetAIAgent` method. If a backend provider doesn't support remote agents, `CreateAIAgent` just initializes a new local agent instance and `GetAIAgent` is not available. There is also a `BuildAIAgent` method, which is an extension for the `ChatClientBuilder` class from `Microsoft.Extensions.AI`. It builds pipelines of `IChatClient` instances with an `IServiceProvider`. This functionality does not exist in Python, so `BuildAIAgent` is out of scope. + +In Python, there is only one `create_agent` method, which always creates a local instance of the agent. 
If the backend provider supports remote agents, the remote agent is created only on the first `agent.run()` invocation.
+
+Below is a short summary of different providers and their APIs in .NET:
+
+| Package | Method | Behavior | Python support |
+|---|---|---|---|
+| Microsoft.Agents.AI | `CreateAIAgent` (based on `IChatClient`) | Creates a local instance of `ChatClientAgent`. | Yes (`create_agent` in `BaseChatClient`). |
+| Microsoft.Agents.AI.Anthropic | `CreateAIAgent` (based on `IBetaService` and `IAnthropicClient`) | Creates a local instance of `ChatClientAgent`. | Yes (`AnthropicClient` inherits `BaseChatClient`, which exposes `create_agent`). |
+| Microsoft.Agents.AI.AzureAI (V2) | `GetAIAgent` (based on `AIProjectClient` with `AgentReference`) | Creates a local instance of `ChatClientAgent`. | Partial (Python uses `create_agent` from `BaseChatClient`). |
+| Microsoft.Agents.AI.AzureAI (V2) | `GetAIAgent`/`GetAIAgentAsync` (with `Name`/`ChatClientAgentOptions`) | Fetches `AgentRecord` via HTTP, then creates a local `ChatClientAgent` instance. | No |
+| Microsoft.Agents.AI.AzureAI (V2) | `CreateAIAgent`/`CreateAIAgentAsync` (based on `AIProjectClient`) | Creates a remote agent first, then wraps it into a local `ChatClientAgent` instance. | No |
+| Microsoft.Agents.AI.AzureAI.Persistent (V1) | `GetAIAgent` (based on `PersistentAgentsClient` with `PersistentAgent`) | Creates a local instance of `ChatClientAgent`. | Partial (Python uses `create_agent` from `BaseChatClient`). |
+| Microsoft.Agents.AI.AzureAI.Persistent (V1) | `GetAIAgent`/`GetAIAgentAsync` (with `AgentId`) | Fetches `PersistentAgent` via HTTP, then creates a local `ChatClientAgent` instance. | No |
+| Microsoft.Agents.AI.AzureAI.Persistent (V1) | `CreateAIAgent`/`CreateAIAgentAsync` | Creates a remote agent first, then wraps it into a local `ChatClientAgent` instance. | No |
+| Microsoft.Agents.AI.OpenAI | `GetAIAgent` (based on `AssistantClient` with `Assistant`) | Creates a local instance of `ChatClientAgent`. | Partial (Python uses `create_agent` from `BaseChatClient`). |
+| Microsoft.Agents.AI.OpenAI | `GetAIAgent`/`GetAIAgentAsync` (with `AgentId`) | Fetches `Assistant` via HTTP, then creates a local `ChatClientAgent` instance. | No |
+| Microsoft.Agents.AI.OpenAI | `CreateAIAgent`/`CreateAIAgentAsync` (based on `AssistantClient`) | Creates a remote agent first, then wraps it into a local `ChatClientAgent` instance. | No |
+| Microsoft.Agents.AI.OpenAI | `CreateAIAgent` (based on `ChatClient`) | Creates a local instance of `ChatClientAgent`. | Yes (`create_agent` in `BaseChatClient`). |
+| Microsoft.Agents.AI.OpenAI | `CreateAIAgent` (based on `OpenAIResponseClient`) | Creates a local instance of `ChatClientAgent`. | Yes (`create_agent` in `BaseChatClient`). |
+
+Another difference between the Python and .NET implementations is that in .NET the `CreateAIAgent`/`GetAIAgent` methods are implemented as extension methods on the underlying SDK client, like `AIProjectClient` from Azure AI or `AssistantClient` from OpenAI:
+
+```csharp
+// Definition
+public static ChatClientAgent CreateAIAgent(
+    this AIProjectClient aiProjectClient,
+    string name,
+    string model,
+    string instructions,
+    string? description = null,
+    IList<AITool>? tools = null,
+    Func<IChatClient, IChatClient>? clientFactory = null,
+    IServiceProvider? services = null,
+    CancellationToken cancellationToken = default)
+{ }
+
+// Usage
+AIProjectClient aiProjectClient = new(new Uri(endpoint), new AzureCliCredential()); // Initialization of underlying SDK client
+
+var newAgent = await aiProjectClient.CreateAIAgentAsync(name: AgentName, model: deploymentName, instructions: AgentInstructions, tools: [tool]); // ChatClientAgent creation from underlying SDK client
+
+// Alternative usage (same as extension method, just explicit syntax)
+var newAgent = await AzureAIProjectChatClientExtensions.CreateAIAgentAsync(
+    aiProjectClient,
+    name: AgentName,
+    model: deploymentName,
+    instructions: AgentInstructions,
+    tools: [tool]);
+```
+
+Python doesn't support extension methods. Currently, the `create_agent` method is defined on `BaseChatClient`, but it only creates a local instance of `ChatAgent` and can't create remote agents for providers that support them, for a couple of reasons:
+
+- It's defined as non-async.
+- The `BaseChatClient` implementation is stateful for providers like Azure AI or OpenAI Assistants. The implementation stores agent/assistant metadata like `AgentId` and `AgentName`, so it currently isn't possible to create different instances of `ChatAgent` from a single `BaseChatClient` when the implementation is stateful.
+
+## Decision Drivers
+
+- API should be aligned between .NET and Python.
+- API should be intuitive and consistent between backend providers in .NET and Python.
+
+## Considered Options
+
+Add missing implementations on the Python side. This should include the following:
+
+### agent-framework-azure-ai (both V1 and V2)
+
+- Add a `get_agent` method that accepts an underlying SDK agent instance and creates a local instance of `ChatAgent`.
+- Add a `get_agent` method that accepts an agent identifier, performs an additional HTTP request to fetch agent data, and then creates a local instance of `ChatAgent`.
+- Override the `create_agent` method from `BaseChatClient` to create a remote agent instance and wrap it into a local `ChatAgent`.
+
+.NET:
+
+```csharp
+var agent1 = new AIProjectClient(...).GetAIAgent(agentInstanceFromSdkType); // Creates a local ChatClientAgent instance from Azure.AI.Projects.OpenAI.AgentReference
+var agent2 = new AIProjectClient(...).GetAIAgent(agentName); // Fetches agent data, creates a local ChatClientAgent instance
+var agent3 = new AIProjectClient(...).CreateAIAgent(...); // Creates a remote agent, returns a local ChatClientAgent instance
+```
+
+### agent-framework-core (OpenAI Assistants)
+
+- Add a `get_agent` method that accepts an underlying SDK agent instance and creates a local instance of `ChatAgent`.
+- Add a `get_agent` method that accepts an agent name, performs an additional HTTP request to fetch agent data, and then creates a local instance of `ChatAgent`.
+- Override the `create_agent` method from `BaseChatClient` to create a remote agent instance and wrap it into a local `ChatAgent`.
+
+.NET:
+
+```csharp
+var agent1 = new AssistantClient(...).GetAIAgent(agentInstanceFromSdkType); // Creates a local ChatClientAgent instance from OpenAI.Assistants.Assistant
+var agent2 = new AssistantClient(...).GetAIAgent(agentId); // Fetches agent data, creates a local ChatClientAgent instance
+var agent3 = new AssistantClient(...).CreateAIAgent(...); // Creates a remote agent, returns a local ChatClientAgent instance
+```
+
+### Possible Python implementations
+
+Methods like `create_agent` and `get_agent` should be implemented separately or defined on a stateless component that allows creating multiple agents from the same instance/place.
+
+Possible options:
+
+#### Option 1: Module-level functions
+
+Implement free functions in the provider package that accept the underlying SDK client as the first argument (similar to .NET extension methods, but expressed in Python).
+
+Example:
+
+```python
+from agent_framework.azure import create_agent, get_agent
+
+ai_project_client = AIProjectClient(...)
+
+# Creates a remote agent first, then returns a local ChatAgent wrapper
+created_agent = await create_agent(
+    ai_project_client,
+    name="",
+    instructions="",
+    tools=[tool],
+)
+
+# Gets an existing remote agent and returns a local ChatAgent wrapper
+first_agent = await get_agent(ai_project_client, agent_id=agent_id)
+
+# Wraps an SDK agent instance (no extra HTTP call)
+second_agent = get_agent(ai_project_client, agent_reference)
+```
+
+Pros:
+
+- Naturally supports async `create_agent` / `get_agent`.
+- Supports multiple agents per SDK client.
+- Closest conceptual match to .NET extension methods while staying Pythonic.
+
+Cons:
+
+- Discoverability is lower (users need to know where the functions live).
+- Verbose when creating multiple agents (client must be passed every time):
+
+  ```python
+  agent1 = await azure_agents.create_agent(client, name="Agent1", ...)
+  agent2 = await azure_agents.create_agent(client, name="Agent2", ...)
+  ```
+
+#### Option 2: Provider object
+
+Introduce a dedicated provider type that is constructed from the underlying SDK client and exposes async `create_agent` / `get_agent` methods.
+
+Example:
+
+```python
+from agent_framework.azure import AzureAIAgentProvider
+
+ai_project_client = AIProjectClient(...)
+provider = AzureAIAgentProvider(ai_project_client)
+
+agent = await provider.create_agent(
+    name="",
+    instructions="",
+    tools=[tool],
+)
+
+agent = await provider.get_agent(agent_id=agent_id)
+agent = provider.get_agent(agent_reference=agent_reference)
+```
+
+Pros:
+
+- High discoverability and clear grouping of related behavior.
+- Keeps SDK clients unchanged and supports multiple agents per SDK client.
+- Concise when creating multiple agents (client passed once):
+
+  ```python
+  provider = AzureAIAgentProvider(ai_project_client)
+  agent1 = await provider.create_agent(name="Agent1", ...)
+  agent2 = await provider.create_agent(name="Agent2", ...)
+  ```
+
+Cons:
+
+- Adds a new public concept/type for users to learn.
+
+#### Option 3: Inheritance (SDK client subclass)
+
+Create a subclass of the underlying SDK client and add `create_agent` / `get_agent` methods.
+
+Example:
+
+```python
+class ExtendedAIProjectClient(AIProjectClient):
+    async def create_agent(self, *, name: str, model: str, instructions: str, **kwargs) -> ChatAgent:
+        ...
+
+    async def get_agent(self, *, agent_id: str | None = None, sdk_agent=None, **kwargs) -> ChatAgent:
+        ...
+
+client = ExtendedAIProjectClient(...)
+agent = await client.create_agent(name="", instructions="") +``` + +Pros: + +- Discoverable and ergonomic call sites. +- Mirrors the .NET “methods on the client” feeling. + +Cons: + +- Many SDK clients are not designed for inheritance; SDK upgrades can break subclasses. +- Users must opt into subclass everywhere. +- Typing/initialization can be tricky if the SDK client has non-trivial constructors. + +#### Option 4: Monkey patching + +Attach `create_agent` / `get_agent` methods to an SDK client class (or instance) at runtime. + +Example: + +```python +def _create_agent(self, *, name: str, model: str, instructions: str, **kwargs) -> ChatAgent: + ... + +AIProjectClient.create_agent = _create_agent # monkey patch +``` + +Pros: + +- Produces “extension method-like” call sites without wrappers or subclasses. + +Cons: + +- Fragile across SDK updates and difficult to type-check. +- Surprising behavior (global side effects), potential conflicts across packages. +- Harder to support/debug, especially in larger apps and test suites. + +## Decision Outcome + +Implement `create_agent`/`get_agent`/`as_agent` API via **Option 2: Provider object**. + +### Rationale + +| Aspect | Option 1 (Functions) | Option 2 (Provider) | +|--------|----------------------|---------------------| +| Multiple implementations | One package may contain V1, V2, and other agent types. Function names like `create_agent` become ambiguous - which agent type does it create? | Each provider class is explicit: `AzureAIAgentsProvider` vs `AzureAIProjectAgentProvider` | +| Discoverability | Users must know to import specific functions from the package | IDE autocomplete on provider instance shows all available methods | +| Client reuse | SDK client must be passed to every function call: `create_agent(client, ...)`, `get_agent(client, ...)` | SDK client passed once at construction: `provider = Provider(client)` | + +**Option 1 example:** +```python +from agent_framework.azure import create_agent, get_agent +agent1 = await create_agent(client, name="Agent1", ...) # Which agent type, V1 or V2? +agent2 = await create_agent(client, name="Agent2", ...) # Repetitive client passing +``` + +**Option 2 example:** +```python +from agent_framework.azure import AzureAIProjectAgentProvider +provider = AzureAIProjectAgentProvider(client) # Clear which service, client passed once +agent1 = await provider.create_agent(name="Agent1", ...) +agent2 = await provider.create_agent(name="Agent2", ...) +``` + +### Method Naming + +| Operation | Python | .NET | Async | +|-----------|--------|------|-------| +| Create on service | `create_agent()` | `CreateAIAgent()` | Yes | +| Get from service | `get_agent(id=...)` | `GetAIAgent(agentId)` | Yes | +| Wrap SDK object | `as_agent(reference)` | `AsAIAgent(agentInstance)` | No | + +The method names (`create_agent`, `get_agent`) do not explicitly mention "service" or "remote" because: +- In Python, the provider class name explicitly identifies the service (`AzureAIAgentsProvider`, `OpenAIAssistantProvider`), making additional qualifiers in method names redundant. +- In .NET, these are extension methods on `AIProjectClient` or `AssistantClient`, which already imply service operations. 
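+
+As a rough illustration of the naming table above, the common surface shared by the provider classes could look like the following sketch. The `Protocol` shape and parameter names here are illustrative assumptions, not final signatures:
+
+```python
+from typing import Protocol
+
+from agent_framework import ChatAgent
+
+
+class AgentProvider(Protocol):
+    """Hypothetical common shape of a provider object."""
+
+    async def create_agent(self, *, name: str, model: str, instructions: str) -> ChatAgent:
+        """Create a remote agent on the service, then wrap it in a local ChatAgent."""
+        ...
+
+    async def get_agent(self, *, agent_id: str) -> ChatAgent:
+        """Fetch an existing remote agent (HTTP call), then wrap it in a local ChatAgent."""
+        ...
+
+    def as_agent(self, sdk_agent: object) -> ChatAgent:
+        """Wrap an already-fetched SDK object; no HTTP call, hence synchronous."""
+        ...
+```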
+ +### Provider Class Naming + +| Package | Provider Class | SDK Client | Service | +|---------|---------------|------------|---------| +| `agent_framework.azure` | `AzureAIProjectAgentProvider` | `AIProjectClient` | Azure AI Agent Service, based on Responses API (V2) | +| `agent_framework.azure` | `AzureAIAgentsProvider` | `AgentsClient` | Azure AI Agent Service (V1) | +| `agent_framework.openai` | `OpenAIAssistantProvider` | `AsyncOpenAI` | OpenAI Assistants API | + +> **Note:** Azure AI naming is temporary. Final naming will be updated according to Azure AI / Microsoft Foundry renaming decisions. + +### Usage Examples + +#### Azure AI Agent Service V2 (based on Responses API) + +```python +from agent_framework.azure import AzureAIProjectAgentProvider +from azure.ai.projects import AIProjectClient + +client = AIProjectClient(endpoint, credential) +provider = AzureAIProjectAgentProvider(client) + +# Create new agent on service +agent = await provider.create_agent(name="MyAgent", model="gpt-4", instructions="...") + +# Get existing agent by name +agent = await provider.get_agent(agent_name="MyAgent") + +# Wrap already-fetched SDK object (no HTTP calls) +agent_ref = await client.agents.get("MyAgent") +agent = provider.as_agent(agent_ref) +``` + +#### Azure AI Persistent Agents V1 + +```python +from agent_framework.azure import AzureAIAgentsProvider +from azure.ai.agents import AgentsClient + +client = AgentsClient(endpoint, credential) +provider = AzureAIAgentsProvider(client) + +agent = await provider.create_agent(name="MyAgent", model="gpt-4", instructions="...") +agent = await provider.get_agent(agent_id="persistent-agent-456") +agent = provider.as_agent(persistent_agent) +``` + +#### OpenAI Assistants + +```python +from agent_framework.openai import OpenAIAssistantProvider +from openai import OpenAI + +client = OpenAI() +provider = OpenAIAssistantProvider(client) + +agent = await provider.create_agent(name="MyAssistant", model="gpt-4", instructions="...") +agent = await provider.get_agent(assistant_id="asst_123") +agent = provider.as_agent(assistant) +``` + +#### Local-Only Agents (No Provider) + +Current method `create_agent` (python) / `CreateAIAgent` (.NET) can be renamed to `as_agent` (python) / `AsAIAgent` (.NET) to emphasize the conversion logic rather than creation/initialization logic and to avoid collision with `create_agent` method for remote calls. + +```python +from agent_framework import ChatAgent +from agent_framework.openai import OpenAIChatClient + +# Convert chat client to ChatAgent (no remote service involved) +client = OpenAIChatClient(model="gpt-4") +agent = client.as_agent(name="LocalAgent", instructions="...") # instead of create_agent +``` + +### Adding New Agent Types + +Python: + +1. Create provider class in appropriate package. +2. Implement `create_agent`, `get_agent`, `as_agent` as applicable. + +.NET: + +1. Create static class for extension methods. +2. Implement `CreateAIAgentAsync`, `GetAIAgentAsync`, `AsAIAgent` as applicable. diff --git a/docs/decisions/0011-python-typeddict-options.md b/docs/decisions/0011-python-typeddict-options.md new file mode 100644 index 0000000000..09657b2cfb --- /dev/null +++ b/docs/decisions/0011-python-typeddict-options.md @@ -0,0 +1,129 @@ +--- +# These are optional elements. Feel free to remove any of them. 
+status: proposed +contact: eavanvalkenburg +date: 2026-01-08 +deciders: eavanvalkenburg, markwallace-microsoft, sphenry, alliscode, johanst, brettcannon +consulted: taochenosu, moonbox3, dmytrostruk, giles17 +--- + +# Leveraging TypedDict and Generic Options in Python Chat Clients + +## Context and Problem Statement + +The Agent Framework Python SDK provides multiple chat client implementations for different providers (OpenAI, Anthropic, Azure AI, Bedrock, Ollama, etc.). Each provider has unique configuration options beyond the common parameters defined in `ChatOptions`. Currently, developers using these clients lack type safety and IDE autocompletion for provider-specific options, leading to runtime errors and a poor developer experience. + +How can we provide type-safe, discoverable options for each chat client while maintaining a consistent API across all implementations? + +## Decision Drivers + +- **Type Safety**: Developers should get compile-time/static analysis errors when using invalid options +- **IDE Support**: Full autocompletion and inline documentation for all available options +- **Extensibility**: Users should be able to define custom options that extend provider-specific options +- **Consistency**: All chat clients should follow the same pattern for options handling +- **Provider Flexibility**: Each provider can expose its unique options without affecting the common interface + +## Considered Options + +- **Option 1: Status Quo - Class `ChatOptions` with `**kwargs`** +- **Option 2: TypedDict with Generic Type Parameters** + +### Option 1: Status Quo - Class `ChatOptions` with `**kwargs` + +The current approach uses a base `ChatOptions` Class with common parameters, and provider-specific options are passed via `**kwargs` or loosely typed dictionaries. + +```python +# Current usage - no type safety for provider-specific options +response = await client.get_response( + messages=messages, + temperature=0.7, + top_k=40, + random=42, # No validation +) +``` + +**Pros:** +- Simple implementation +- Maximum flexibility + +**Cons:** +- No type checking for provider-specific options +- No IDE autocompletion for available options +- Runtime errors for typos or invalid options +- Documentation must be consulted for each provider + +### Option 2: TypedDict with Generic Type Parameters (Chosen) + +Each chat client is parameterized with a TypeVar bound to a provider-specific `TypedDict` that extends `ChatOptions`. This enables full type safety and IDE support. + +```python +# Provider-specific TypedDict +class AnthropicChatOptions(ChatOptions, total=False): + """Anthropic-specific chat options.""" + top_k: int + thinking: ThinkingConfig + # ... other Anthropic-specific options + +# Generic chat client +class AnthropicChatClient(ChatClientBase[TAnthropicChatOptions]): + ... + +client = AnthropicChatClient(...) + +# Usage with full type safety +response = await client.get_response( + messages=messages, + options={ + "temperature": 0.7, + "top_k": 40, + "random": 42, # fails type checking and IDE would flag this + } +) + +# Users can extend for custom options +class MyAnthropicOptions(AnthropicChatOptions, total=False): + custom_field: str + + +client = AnthropicChatClient[MyAnthropicOptions](...) 
+
+# Usage of custom options with full type safety
+response = await client.get_response(
+    messages=messages,
+    options={
+        "temperature": 0.7,
+        "top_k": 40,
+        "custom_field": "value",
+    }
+)
+
+```
+
+**Pros:**
+- Full type safety with static analysis
+- IDE autocompletion for all options
+- Compile-time error detection
+- Self-documenting through type hints
+- Users can extend options for their specific needs or for new model capabilities
+
+**Cons:**
+- More complex implementation
+- Some `type: ignore` comments needed for TypedDict field overrides
+- Minor: Requires TypeVar with default (Python 3.13+ or typing_extensions)
+
+> [!NOTE]
+> In .NET this is already achieved through overloads on the `GetResponseAsync` method for each provider-specific options class, e.g., `AnthropicChatOptions`, `OpenAIChatOptions`, etc. So this does not apply to .NET.
+
+### Implementation Details
+
+1. **Base Protocol**: `ChatClientProtocol[TOptions]` is generic over the options type, with the default set to `ChatOptions` (the new TypedDict).
+2. **Provider TypedDicts**: Each provider defines its options extending `ChatOptions`. They can even override fields with type `None` to indicate they are not supported.
+3. **TypeVar Pattern**: `TProviderOptions = TypeVar("TProviderOptions", bound=TypedDict, default=ProviderChatOptions, contravariant=True)`
+4. **Option Translation**: Common options are kept in place, and how each one is used is explicitly documented in the options class. Provider-specific mapping (e.g., `user` → `metadata.user_id` for Anthropic) happens in `_prepare_options`, preserving easy use of the common options.
+
+## Decision Outcome
+
+Chosen option: **"Option 2: TypedDict with Generic Type Parameters"**, because it provides full type safety, excellent IDE support with autocompletion, and allows users to extend provider-specific options for their use cases. The same generic is extended to `ChatAgent` so that the options used in agent construction and run methods are properly typed as well.
+
+See [typed_options.py](../../python/samples/getting_started/chat_client/typed_options.py) for a complete example demonstrating the usage of typed options with custom extensions.
diff --git a/docs/decisions/0012-python-get-response-simplification.md b/docs/decisions/0012-python-get-response-simplification.md
new file mode 100644
index 0000000000..2c3965ecd8
--- /dev/null
+++ b/docs/decisions/0012-python-get-response-simplification.md
@@ -0,0 +1,258 @@
+---
+status: Accepted
+contact: eavanvalkenburg
+date: 2026-01-06
+deciders: markwallace-microsoft, dmytrostruk, taochenosu, alliscode, moonbox3, sphenry
+consulted: sergeymenshykh, rbarreto, dmytrostruk, westey-m
+informed:
+---
+
+# Simplify Python Get Response API into a single method
+
+## Context and Problem Statement
+
+Currently chat clients must implement two separate methods to get responses: one for streaming and one for non-streaming. This adds complexity to the client implementations and increases the maintenance burden. The split likely exists because the .NET version cannot express proper typing with a single method. In Python this is possible, and it is also how the OpenAI Python client works (sketched below). A single method would likewise make the Python version simpler to work with, because there is only one method to learn instead of two.
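+
+For reference, a minimal sketch of the OpenAI Python client pattern mentioned above, where one method serves both modes (the model name is illustrative):
+
+```python
+import asyncio
+
+from openai import AsyncOpenAI
+
+
+async def main() -> None:
+    client = AsyncOpenAI()
+    messages = [{"role": "user", "content": "Say hello."}]
+
+    # Non-streaming: returns a single ChatCompletion.
+    completion = await client.chat.completions.create(model="gpt-4o-mini", messages=messages)
+    print(completion.choices[0].message.content)
+
+    # Streaming: the same method with stream=True returns an async stream of chunks.
+    stream = await client.chat.completions.create(model="gpt-4o-mini", messages=messages, stream=True)
+    async for chunk in stream:
+        print(chunk.choices[0].delta.content or "", end="")
+
+
+asyncio.run(main())
+```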
+ +## Implications of this change + +### Current Architecture Overview + +The current design has **two separate methods** at each layer: + +| Layer | Non-streaming | Streaming | +|-------|---------------|-----------| +| **Protocol** | `get_response()` → `ChatResponse` | `get_streaming_response()` → `AsyncIterable[ChatResponseUpdate]` | +| **BaseChatClient** | `get_response()` (public) | `get_streaming_response()` (public) | +| **Implementation** | `_inner_get_response()` (private) | `_inner_get_streaming_response()` (private) | + +### Key Usage Areas Identified + +#### 1. **ChatAgent** (_agents.py) +- `run()` → calls `self.chat_client.get_response()` +- `run_stream()` → calls `self.chat_client.get_streaming_response()` + +These are parallel methods on the agent, so consolidating the client methods would **not break** the agent API. You could keep `agent.run()` and `agent.run_stream()` unchanged while internally calling `get_response(stream=True/False)`. + +#### 2. **Function Invocation Decorator** (_tools.py) +This is **the most impacted area**. Currently: +- `_handle_function_calls_response()` decorates `get_response` +- `_handle_function_calls_streaming_response()` decorates `get_streaming_response` +- The `use_function_invocation` class decorator wraps **both methods separately** + +**Impact**: The decorator logic is almost identical (~200 lines each) with small differences: +- Non-streaming collects response, returns it +- Streaming yields updates, returns async iterable + +With a unified method, you'd need **one decorator** that: +- Checks the `stream` parameter +- Uses `@overload` to determine return type +- Handles both paths with conditional logic +- The new decorator could be applied just on the method, instead of the whole class. + +This would **reduce code duplication** but add complexity to a single function. + +#### 3. **Observability/Instrumentation** (observability.py) +Same pattern as function invocation: +- `_trace_get_response()` wraps `get_response` +- `_trace_get_streaming_response()` wraps `get_streaming_response` +- `use_instrumentation` decorator applies both + +**Impact**: Would need consolidation into a single tracing wrapper. + +#### 4. **Chat Middleware** (_middleware.py) +The `use_chat_middleware` decorator also wraps both methods separately with similar logic. + +#### 5. **AG-UI Client** (_client.py) +Wraps both methods to unwrap server function calls: +```python +original_get_streaming_response = chat_client.get_streaming_response +original_get_response = chat_client.get_response +``` + +#### 6. **Provider Implementations** (all subpackages) +All subclasses implement both `_inner_*` methods, except: +- OpenAI Assistants Client (and similar clients, such as Foundry Agents V1) - it implements `_inner_get_response` by calling `_inner_get_streaming_response` + +### Implications of Consolidation + +| Aspect | Impact | +|--------|--------| +| **Type Safety** | Overloads work well: `@overload` with `Literal[True]` → `AsyncIterable`, `Literal[False]` → `ChatResponse`. Runtime return type based on `stream` param. | +| **Breaking Change** | **Major breaking change** for anyone implementing custom chat clients. They'd need to update from 2 methods to 1 (or 2 inner methods to 1). | +| **Decorator Complexity** | All 3 decorator systems (function invocation, middleware, observability) would need refactoring to handle both paths in one wrapper. | +| **Code Reduction** | Significant reduction in _tools.py (~200 lines of near-duplicate code) and other decorators. 
| **Samples/Tests** | Many samples call `get_streaming_response()` directly - would need updates. |
+| **Protocol Simplification** | `ChatClientProtocol` goes from 2 methods + 1 property to 1 method + 1 property. |
+
+### Recommendation
+
+The consolidation makes sense architecturally, but consider:
+
+1. **The overload pattern with `stream: bool`** works well in Python typing (the non-streaming overload carries the default; the streaming overload requires an explicit `stream=True`):
+   ```python
+   @overload
+   async def get_response(self, messages, *, stream: Literal[False] = False, ...) -> ChatResponse: ...
+   @overload
+   async def get_response(self, messages, *, stream: Literal[True], ...) -> AsyncIterable[ChatResponseUpdate]: ...
+   ```
+
+2. **The decorator complexity** is the biggest concern. The current approach of separate decorators for separate methods is cleaner than conditional logic inside one wrapper.
+
+## Decision Drivers
+
+- Reduce code needed to implement a Chat Client, simplify the public API for chat clients
+- Reduce code duplication in decorators and middleware
+- Maintain type safety and clarity in method signatures
+
+## Considered Options
+
+1. Status quo: Keep separate methods for streaming and non-streaming
+2. Consolidate into a single `get_response` method with a `stream` parameter
+3. Option 2 plus merging `agent.run` and `agent.run_stream` into a single method with a `stream` parameter as well
+
+## Option 1: Status Quo
+
+- Good: Clear separation of streaming vs non-streaming logic
+- Good: Aligned with .NET design, although it is already `run` for Python and `RunAsync` for .NET
+- Bad: Code duplication in decorators and middleware
+- Bad: More complex client implementations
+
+## Option 2: Consolidate into Single Method
+
+- Good: Simplified public API for chat clients
+- Good: Reduced code duplication in decorators
+- Good: Smaller API footprint for users to get familiar with
+- Good: People using OpenAI directly already expect this pattern
+- Bad: Increased complexity in decorators and middleware
+- Bad: Less alignment with .NET design (`get_response(stream=True)` vs `GetStreamingResponseAsync`)
+
+## Option 3: Consolidate + Merge Agent and Workflow Methods
+
+- Good: Further simplifies agent and workflow implementation
+- Good: Single method for all chat interactions
+- Good: Smaller API footprint for users to get familiar with
+- Good: People using OpenAI directly already expect this pattern
+- Good: Workflows internally already use a single method (`_run_workflow_with_tracing`), so this would eliminate public API duplication as well, with hardly any code changes
+- Bad: More breaking changes for agent users
+- Bad: Increased complexity in agent implementation
+- Bad: More extensive misalignment with .NET design (`run(stream=True)` vs `RunStreamingAsync`, in addition to the `get_response` change)
+
+## Misc
+
+Smaller questions to consider:
+
+- Should the default be `stream=False` or `stream=True`? (Current is False)
+  - Defaulting to `False` makes it simpler for new users, as non-streaming is easier to handle.
+  - Defaulting to `False` aligns with existing behavior.
+  - Streaming tends to be faster, so defaulting to `True` could improve performance for common use cases.
+  - Should this differ between ChatClient, Agent and Workflows? (e.g., Agent and Workflow default to streaming, ChatClient to non-streaming)
+
+## Decision Outcome
+
+Chosen Option: **Option 3: Consolidate + Merge Agent and Workflow Methods**
+
+Since this is the most Pythonic option and it reduces the API surface and code duplication the most, we will go with this option.
+
+# Appendix
+## Code Samples for Consolidated Method
+
+### Python - Option 3: Direct ChatClient + Agent with Single Method
+
+```python
+# Copyright (c) Microsoft. All rights reserved.
+
+import asyncio
+from random import randint
+from typing import Annotated
+
+from agent_framework import ChatAgent
+from agent_framework.openai import OpenAIChatClient
+from pydantic import Field
+
+
+def get_weather(
+    location: Annotated[str, Field(description="The location to get the weather for.")],
+) -> str:
+    """Get the weather for a given location."""
+    conditions = ["sunny", "cloudy", "rainy", "stormy"]
+    return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
+
+
+async def main() -> None:
+    # Example 1: Direct ChatClient usage with single method
+    client = OpenAIChatClient()
+    message = "What's the weather in Amsterdam and in Paris?"
+
+    # Non-streaming usage
+    print(f"User: {message}")
+    response = await client.get_response(message, tools=get_weather)
+    print(f"Assistant: {response.text}")
+
+    # Streaming usage - same method, different parameter
+    print(f"\nUser: {message}")
+    print("Assistant: ", end="")
+    async for chunk in client.get_response(message, tools=get_weather, stream=True):
+        if chunk.text:
+            print(chunk.text, end="")
+    print("")
+
+    # Example 2: Agent usage with single method
+    agent = ChatAgent(
+        chat_client=client,
+        tools=get_weather,
+        name="WeatherAgent",
+        instructions="You are a weather assistant.",
+    )
+    thread = agent.get_new_thread()
+
+    # Non-streaming agent
+    print(f"\nUser: {message}")
+    result = await agent.run(message, thread=thread)  # default would be stream=False
+    print(f"{agent.name}: {result.text}")
+
+    # Streaming agent - same method, different parameter
+    print(f"\nUser: {message}")
+    print(f"{agent.name}: ", end="")
+    async for update in agent.run(message, thread=thread, stream=True):
+        if update.text:
+            print(update.text, end="")
+    print("")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
+```
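+
+### Python - Option 3: Workflow with Single Method (sketch)
+
+The sample above covers ChatClient and Agent; the sketch below applies the same pattern to workflows. The `stream` parameter on `workflow.run` is the proposed API and does not exist yet, and the builder calls mirror the current `WorkflowBuilder` surface:
+
+```python
+# Proposed shape only: run(..., stream=True) is not implemented yet.
+from agent_framework import WorkflowBuilder
+
+
+async def run_workflow(writer_agent, reviewer_agent) -> None:
+    workflow = (
+        WorkflowBuilder()
+        .set_start_executor(writer_agent)
+        .add_edge(writer_agent, reviewer_agent)
+        .build()
+    )
+
+    # Non-streaming: single consolidated result
+    result = await workflow.run("Write a tagline for an eco-friendly water bottle.")
+    print(result)
+
+    # Streaming: same method, stream=True yields workflow events as they occur
+    async for event in workflow.run("Write a tagline for an eco-friendly water bottle.", stream=True):
+        print(event)
+```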
"gpt-4o-mini"; + +AIAgent agent = new AzureOpenAIClient( + new Uri(endpoint), + new AzureCliCredential()) + .GetChatClient(deploymentName) + .CreateAIAgent( + instructions: "You are good at telling jokes about pirates.", + name: "PirateJoker"); + +// Non-streaming: Returns a string directly +Console.WriteLine("=== Non-streaming ==="); +string result = await agent.RunAsync("Tell me a joke about a pirate."); +Console.WriteLine(result); + +// Streaming: Returns IAsyncEnumerable +Console.WriteLine("\n=== Streaming ==="); +await foreach (AgentUpdate update in agent.RunStreamingAsync("Tell me a joke about a pirate.")) +{ + Console.Write(update); +} +Console.WriteLine(); + +``` diff --git a/docs/specs/001-foundry-sdk-alignment.md b/docs/specs/001-foundry-sdk-alignment.md index 1bbe879be8..b7b780c35f 100644 --- a/docs/specs/001-foundry-sdk-alignment.md +++ b/docs/specs/001-foundry-sdk-alignment.md @@ -125,7 +125,7 @@ The proposed solution is to add helper methods which allow developers to either - [Foundry SDK] Create a `PersistentAgentsClient` - [Foundry SDK] Create a `PersistentAgent` using the `PersistentAgentsClient` - [Foundry SDK] Retrieve an `AIAgent` using the `PersistentAgentsClient` -- [Agent Framework SDK] Invoke the `AIAgent` instance and access response from the `AgentRunResponse` +- [Agent Framework SDK] Invoke the `AIAgent` instance and access response from the `AgentResponse` - [Foundry SDK] Clean up the agent @@ -156,7 +156,7 @@ await persistentAgentsClient.Administration.DeleteAgentAsync(agent.Id); - [Foundry SDK] Create a `PersistentAgentsClient` - [Foundry SDK] Create a `AIAgent` using the `PersistentAgentsClient` -- [Agent Framework SDK] Invoke the `AIAgent` instance and access response from the `AgentRunResponse` +- [Agent Framework SDK] Invoke the `AIAgent` instance and access response from the `AgentResponse` - [Foundry SDK] Clean up the agent ```csharp @@ -184,7 +184,7 @@ await persistentAgentsClient.Administration.DeleteAgentAsync(agent.Id); - [Foundry SDK] Create a `PersistentAgentsClient` - [Foundry SDK] Create a `AIAgent` using the `PersistentAgentsClient` - [Agent Framework SDK] Optionally create an `AgentThread` for the agent run -- [Agent Framework SDK] Invoke the `AIAgent` instance and access response from the `AgentRunResponse` +- [Agent Framework SDK] Invoke the `AIAgent` instance and access response from the `AgentResponse` - [Foundry SDK] Clean up the agent and the agent thread ```csharp @@ -227,7 +227,7 @@ await persistentAgentsClient.Administration.DeleteAgentAsync(agent.Id); - [Foundry SDK] Create a `PersistentAgentsClient` - [Foundry SDK] Create multiple `AIAgent` instances using the `PersistentAgentsClient` - [Agent Framework SDK] Create a `SequentialOrchestration` and add all of the agents to it -- [Agent Framework SDK] Invoke the `SequentialOrchestration` instance and access response from the `AgentRunResponse` +- [Agent Framework SDK] Invoke the `SequentialOrchestration` instance and access response from the `AgentResponse` - [Foundry SDK] Clean up the agents ```csharp @@ -281,7 +281,7 @@ SequentialOrchestration orchestration = // Run the orchestration string input = "An eco-friendly stainless steel water bottle that keeps drinks cold for 24 hours"; Console.WriteLine($"\n# INPUT: {input}\n"); -AgentRunResponse result = await orchestration.RunAsync(input); +AgentResponse result = await orchestration.RunAsync(input); Console.WriteLine($"\n# RESULT: {result}"); // Cleanup diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 
5ee419114b..c7e53ea256 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -26,7 +26,7 @@ - + diff --git a/dotnet/README.md b/dotnet/README.md index 1d29dbbc2a..4e52260f56 100644 --- a/dotnet/README.md +++ b/dotnet/README.md @@ -21,7 +21,7 @@ var deploymentName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOYMENT var agent = new AzureOpenAIClient(new Uri(endpoint), new AzureCliCredential()) .GetOpenAIResponseClient(deploymentName) - .CreateAIAgent(name: "HaikuBot", instructions: "You are an upbeat assistant that writes beautifully."); + .AsAIAgent(name: "HaikuBot", instructions: "You are an upbeat assistant that writes beautifully."); Console.WriteLine(await agent.RunAsync("Write a haiku about Microsoft Agent Framework.")); ``` diff --git a/dotnet/samples/A2AClientServer/A2AClient/HostClientAgent.cs b/dotnet/samples/A2AClientServer/A2AClient/HostClientAgent.cs index 5ebae80ffe..4daf2c542b 100644 --- a/dotnet/samples/A2AClientServer/A2AClient/HostClientAgent.cs +++ b/dotnet/samples/A2AClientServer/A2AClient/HostClientAgent.cs @@ -29,7 +29,7 @@ internal async Task InitializeAgentAsync(string modelId, string apiKey, string[] // Create the agent that uses the remote agents as tools this.Agent = new OpenAIClient(new ApiKeyCredential(apiKey)) .GetChatClient(modelId) - .CreateAIAgent(instructions: "You specialize in handling queries for users and using your tools to provide answers.", name: "HostClient", tools: tools); + .AsAIAgent(instructions: "You specialize in handling queries for users and using your tools to provide answers.", name: "HostClient", tools: tools); } catch (Exception ex) { diff --git a/dotnet/samples/A2AClientServer/A2AClient/Program.cs b/dotnet/samples/A2AClientServer/A2AClient/Program.cs index 838cbaaef8..b701ea7441 100644 --- a/dotnet/samples/A2AClientServer/A2AClient/Program.cs +++ b/dotnet/samples/A2AClientServer/A2AClient/Program.cs @@ -42,7 +42,7 @@ private static async Task HandleCommandsAsync(CancellationToken cancellationToke // Create the Host agent var hostAgent = new HostClientAgent(loggerFactory); await hostAgent.InitializeAgentAsync(modelId, apiKey, agentUrls!.Split(";")); - AgentThread thread = hostAgent.Agent!.GetNewThread(); + AgentThread thread = await hostAgent.Agent!.GetNewThreadAsync(cancellationToken); try { while (true) diff --git a/dotnet/samples/A2AClientServer/A2AServer/HostAgentFactory.cs b/dotnet/samples/A2AClientServer/A2AServer/HostAgentFactory.cs index 9c4fbaaf2c..8af2b01daf 100644 --- a/dotnet/samples/A2AClientServer/A2AServer/HostAgentFactory.cs +++ b/dotnet/samples/A2AClientServer/A2AServer/HostAgentFactory.cs @@ -35,7 +35,7 @@ internal static class HostAgentFactory { AIAgent agent = new OpenAIClient(apiKey) .GetChatClient(model) - .CreateAIAgent(instructions, name, tools: tools); + .AsAIAgent(instructions, name, tools: tools); AgentCard agentCard = agentType.ToUpperInvariant() switch { diff --git a/dotnet/samples/AGUIClientServer/AGUIClient/Program.cs b/dotnet/samples/AGUIClientServer/AGUIClient/Program.cs index 3079bf1451..1906b4d913 100644 --- a/dotnet/samples/AGUIClientServer/AGUIClient/Program.cs +++ b/dotnet/samples/AGUIClientServer/AGUIClient/Program.cs @@ -83,12 +83,12 @@ private static async Task HandleCommandsAsync(CancellationToken cancellationToke serverUrl, jsonSerializerOptions: AGUIClientSerializerContext.Default.Options); - AIAgent agent = chatClient.CreateAIAgent( + AIAgent agent = chatClient.AsAIAgent( name: "agui-client", description: "AG-UI Client Agent", tools: [changeBackground, 
readClientClimateSensors]); - AgentThread thread = agent.GetNewThread(); + AgentThread thread = await agent.GetNewThreadAsync(cancellationToken); List messages = [new(ChatRole.System, "You are a helpful assistant.")]; try { @@ -114,7 +114,7 @@ private static async Task HandleCommandsAsync(CancellationToken cancellationToke bool isFirstUpdate = true; string? threadId = null; var updates = new List(); - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(messages, thread, cancellationToken: cancellationToken)) + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages, thread, cancellationToken: cancellationToken)) { // Use AsChatResponseUpdate to access ChatResponseUpdate properties ChatResponseUpdate chatUpdate = update.AsChatResponseUpdate(); diff --git a/dotnet/samples/AGUIClientServer/AGUIDojoServer/AgenticUI/AgenticUIAgent.cs b/dotnet/samples/AGUIClientServer/AGUIDojoServer/AgenticUI/AgenticUIAgent.cs index d79787d260..da082483db 100644 --- a/dotnet/samples/AGUIClientServer/AGUIDojoServer/AgenticUI/AgenticUIAgent.cs +++ b/dotnet/samples/AGUIClientServer/AGUIDojoServer/AgenticUI/AgenticUIAgent.cs @@ -19,12 +19,12 @@ public AgenticUIAgent(AIAgent innerAgent, JsonSerializerOptions jsonSerializerOp this._jsonSerializerOptions = jsonSerializerOptions; } - protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { - return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken).ToAgentRunResponseAsync(cancellationToken); + return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken).ToAgentResponseAsync(cancellationToken); } - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? 
options = null, @@ -69,7 +69,7 @@ protected override async IAsyncEnumerable RunCoreStreami yield return update; - yield return new AgentRunResponseUpdate( + yield return new AgentResponseUpdate( new ChatResponseUpdate(role: ChatRole.System, stateEventsToEmit) { MessageId = "delta_" + Guid.NewGuid().ToString("N"), diff --git a/dotnet/samples/AGUIClientServer/AGUIDojoServer/ChatClientAgentFactory.cs b/dotnet/samples/AGUIClientServer/AGUIDojoServer/ChatClientAgentFactory.cs index 58f5ad4ae9..d14755db3f 100644 --- a/dotnet/samples/AGUIClientServer/AGUIDojoServer/ChatClientAgentFactory.cs +++ b/dotnet/samples/AGUIClientServer/AGUIDojoServer/ChatClientAgentFactory.cs @@ -33,7 +33,7 @@ public static ChatClientAgent CreateAgenticChat() { ChatClient chatClient = s_azureOpenAIClient!.GetChatClient(s_deploymentName!); - return chatClient.AsIChatClient().CreateAIAgent( + return chatClient.AsIChatClient().AsAIAgent( name: "AgenticChat", description: "A simple chat agent using Azure OpenAI"); } @@ -42,7 +42,7 @@ public static ChatClientAgent CreateBackendToolRendering() { ChatClient chatClient = s_azureOpenAIClient!.GetChatClient(s_deploymentName!); - return chatClient.AsIChatClient().CreateAIAgent( + return chatClient.AsIChatClient().AsAIAgent( name: "BackendToolRenderer", description: "An agent that can render backend tools using Azure OpenAI", tools: [AIFunctionFactory.Create( @@ -56,7 +56,7 @@ public static ChatClientAgent CreateHumanInTheLoop() { ChatClient chatClient = s_azureOpenAIClient!.GetChatClient(s_deploymentName!); - return chatClient.AsIChatClient().CreateAIAgent( + return chatClient.AsIChatClient().AsAIAgent( name: "HumanInTheLoopAgent", description: "An agent that involves human feedback in its decision-making process using Azure OpenAI"); } @@ -65,7 +65,7 @@ public static ChatClientAgent CreateToolBasedGenerativeUI() { ChatClient chatClient = s_azureOpenAIClient!.GetChatClient(s_deploymentName!); - return chatClient.AsIChatClient().CreateAIAgent( + return chatClient.AsIChatClient().AsAIAgent( name: "ToolBasedGenerativeUIAgent", description: "An agent that uses tools to generate user interfaces using Azure OpenAI"); } @@ -73,7 +73,7 @@ public static ChatClientAgent CreateToolBasedGenerativeUI() public static AIAgent CreateAgenticUI(JsonSerializerOptions options) { ChatClient chatClient = s_azureOpenAIClient!.GetChatClient(s_deploymentName!); - var baseAgent = chatClient.AsIChatClient().CreateAIAgent(new ChatClientAgentOptions + var baseAgent = chatClient.AsIChatClient().AsAIAgent(new ChatClientAgentOptions { Name = "AgenticUIAgent", Description = "An agent that generates agentic user interfaces using Azure OpenAI", @@ -116,7 +116,7 @@ public static AIAgent CreateSharedState(JsonSerializerOptions options) { ChatClient chatClient = s_azureOpenAIClient!.GetChatClient(s_deploymentName!); - var baseAgent = chatClient.AsIChatClient().CreateAIAgent( + var baseAgent = chatClient.AsIChatClient().AsAIAgent( name: "SharedStateAgent", description: "An agent that demonstrates shared state patterns using Azure OpenAI"); @@ -127,7 +127,7 @@ public static AIAgent CreatePredictiveStateUpdates(JsonSerializerOptions options { ChatClient chatClient = s_azureOpenAIClient!.GetChatClient(s_deploymentName!); - var baseAgent = chatClient.AsIChatClient().CreateAIAgent(new ChatClientAgentOptions + var baseAgent = chatClient.AsIChatClient().AsAIAgent(new ChatClientAgentOptions { Name = "PredictiveStateUpdatesAgent", Description = "An agent that demonstrates predictive state updates using Azure OpenAI", diff --git 
a/dotnet/samples/AGUIClientServer/AGUIDojoServer/PredictiveStateUpdates/PredictiveStateUpdatesAgent.cs b/dotnet/samples/AGUIClientServer/AGUIDojoServer/PredictiveStateUpdates/PredictiveStateUpdatesAgent.cs index ab9ca2fca3..2e994d8e29 100644 --- a/dotnet/samples/AGUIClientServer/AGUIDojoServer/PredictiveStateUpdates/PredictiveStateUpdatesAgent.cs +++ b/dotnet/samples/AGUIClientServer/AGUIDojoServer/PredictiveStateUpdates/PredictiveStateUpdatesAgent.cs @@ -20,12 +20,12 @@ public PredictiveStateUpdatesAgent(AIAgent innerAgent, JsonSerializerOptions jso this._jsonSerializerOptions = jsonSerializerOptions; } - protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { - return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken).ToAgentRunResponseAsync(cancellationToken); + return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken).ToAgentResponseAsync(cancellationToken); } - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -79,7 +79,7 @@ protected override async IAsyncEnumerable RunCoreStreami stateUpdate, this._jsonSerializerOptions.GetTypeInfo(typeof(DocumentState))); - yield return new AgentRunResponseUpdate( + yield return new AgentResponseUpdate( new ChatResponseUpdate(role: ChatRole.Assistant, [new DataContent(stateBytes, "application/json")]) { MessageId = "snapshot" + Guid.NewGuid().ToString("N"), diff --git a/dotnet/samples/AGUIClientServer/AGUIDojoServer/SharedState/SharedStateAgent.cs b/dotnet/samples/AGUIClientServer/AGUIDojoServer/SharedState/SharedStateAgent.cs index 1a1e58860a..36a629dd56 100644 --- a/dotnet/samples/AGUIClientServer/AGUIDojoServer/SharedState/SharedStateAgent.cs +++ b/dotnet/samples/AGUIClientServer/AGUIDojoServer/SharedState/SharedStateAgent.cs @@ -19,12 +19,12 @@ public SharedStateAgent(AIAgent innerAgent, JsonSerializerOptions jsonSerializer this._jsonSerializerOptions = jsonSerializerOptions; } - protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { - return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken).ToAgentRunResponseAsync(cancellationToken); + return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken).ToAgentResponseAsync(cancellationToken); } - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? 
options = null, @@ -63,7 +63,7 @@ protected override async IAsyncEnumerable RunCoreStreami var firstRunMessages = messages.Append(stateUpdateMessage); - var allUpdates = new List(); + var allUpdates = new List(); await foreach (var update in this.InnerAgent.RunStreamingAsync(firstRunMessages, thread, firstRunOptions, cancellationToken).ConfigureAwait(false)) { allUpdates.Add(update); @@ -76,14 +76,14 @@ protected override async IAsyncEnumerable RunCoreStreami } } - var response = allUpdates.ToAgentRunResponse(); + var response = allUpdates.ToAgentResponse(); if (response.TryDeserialize(this._jsonSerializerOptions, out JsonElement stateSnapshot)) { byte[] stateBytes = JsonSerializer.SerializeToUtf8Bytes( stateSnapshot, this._jsonSerializerOptions.GetTypeInfo(typeof(JsonElement))); - yield return new AgentRunResponseUpdate + yield return new AgentResponseUpdate { Contents = [new DataContent(stateBytes, "application/json")] }; diff --git a/dotnet/samples/AGUIClientServer/AGUIServer/Program.cs b/dotnet/samples/AGUIClientServer/AGUIServer/Program.cs index bcfd86e60d..418f72ad43 100644 --- a/dotnet/samples/AGUIClientServer/AGUIServer/Program.cs +++ b/dotnet/samples/AGUIClientServer/AGUIServer/Program.cs @@ -23,7 +23,7 @@ new Uri(endpoint), new DefaultAzureCredential()) .GetChatClient(deploymentName) - .CreateAIAgent( + .AsAIAgent( name: "AGUIAssistant", tools: [ AIFunctionFactory.Create( diff --git a/dotnet/samples/AGUIClientServer/README.md b/dotnet/samples/AGUIClientServer/README.md index b0ad2265d0..2e4887cde9 100644 --- a/dotnet/samples/AGUIClientServer/README.md +++ b/dotnet/samples/AGUIClientServer/README.md @@ -119,7 +119,7 @@ The `AGUIServer` uses the `MapAGUI` extension method to expose an agent through ```csharp AIAgent agent = new OpenAIClient(apiKey) .GetChatClient(model) - .CreateAIAgent( + .AsAIAgent( instructions: "You are a helpful assistant.", name: "AGUIAssistant"); @@ -144,16 +144,16 @@ var chatClient = new AGUIChatClient( modelId: "agui-client", jsonSerializerOptions: null); -AIAgent agent = chatClient.CreateAIAgent( +AIAgent agent = chatClient.AsAIAgent( instructions: null, name: "agui-client", description: "AG-UI Client Agent", tools: []); bool isFirstUpdate = true; -AgentRunResponseUpdate? currentUpdate = null; +AgentResponseUpdate? currentUpdate = null; -await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(messages, thread)) +await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages, thread)) { // First update indicates run started if (isFirstUpdate) @@ -190,19 +190,19 @@ if (currentUpdate != null) The `RunStreamingAsync` method: 1. Sends messages to the server via HTTP POST 2. Receives server-sent events (SSE) stream -3. Parses events into `AgentRunResponseUpdate` objects +3. Parses events into `AgentResponseUpdate` objects 4. Yields updates as they arrive for real-time display ## Key Concepts - **Thread**: Represents a conversation context that persists across multiple runs (accessed via `ConversationId` property) - **Run**: A single execution of the agent for a given set of messages (identified by `ResponseId` property) -- **AgentRunResponseUpdate**: Contains the response data with: +- **AgentResponseUpdate**: Contains the response data with: - `ResponseId`: The unique run identifier - `ConversationId`: The thread/conversation identifier - `Contents`: Collection of content items (TextContent, ErrorContent, etc.) 
- **Run Lifecycle**: - - The **first** `AgentRunResponseUpdate` in a run indicates the run has started + - The **first** `AgentResponseUpdate` in a run indicates the run has started - Subsequent updates contain streaming content as the agent processes - - The **last** `AgentRunResponseUpdate` in a run indicates the run has finished + - The **last** `AgentResponseUpdate` in a run indicates the run has finished - If an error occurs, the update will contain `ErrorContent` \ No newline at end of file diff --git a/dotnet/samples/AGUIWebChat/README.md b/dotnet/samples/AGUIWebChat/README.md index 75af0872c1..bdb8ae25d2 100644 --- a/dotnet/samples/AGUIWebChat/README.md +++ b/dotnet/samples/AGUIWebChat/README.md @@ -74,7 +74,7 @@ AzureOpenAIClient azureOpenAIClient = new AzureOpenAIClient( ChatClient chatClient = azureOpenAIClient.GetChatClient(deploymentName); // Create AI agent -ChatClientAgent agent = chatClient.AsIChatClient().CreateAIAgent( +ChatClientAgent agent = chatClient.AsIChatClient().AsAIAgent( name: "ChatAssistant", instructions: "You are a helpful assistant."); @@ -162,7 +162,7 @@ dotnet run Edit the instructions in `Server/Program.cs`: ```csharp -ChatClientAgent agent = chatClient.AsIChatClient().CreateAIAgent( +ChatClientAgent agent = chatClient.AsIChatClient().AsAIAgent( name: "ChatAssistant", instructions: "You are a helpful coding assistant specializing in C# and .NET."); ``` diff --git a/dotnet/samples/AGUIWebChat/Server/Program.cs b/dotnet/samples/AGUIWebChat/Server/Program.cs index 1683a7e3ed..eb5b259016 100644 --- a/dotnet/samples/AGUIWebChat/Server/Program.cs +++ b/dotnet/samples/AGUIWebChat/Server/Program.cs @@ -25,7 +25,7 @@ ChatClient chatClient = azureOpenAIClient.GetChatClient(deploymentName); -ChatClientAgent agent = chatClient.AsIChatClient().CreateAIAgent( +ChatClientAgent agent = chatClient.AsIChatClient().AsAIAgent( name: "ChatAssistant", instructions: "You are a helpful assistant."); diff --git a/dotnet/samples/AgentWebChat/AgentWebChat.Web/A2AAgentClient.cs b/dotnet/samples/AgentWebChat/AgentWebChat.Web/A2AAgentClient.cs index 08dafea129..1f87597122 100644 --- a/dotnet/samples/AgentWebChat/AgentWebChat.Web/A2AAgentClient.cs +++ b/dotnet/samples/AgentWebChat/AgentWebChat.Web/A2AAgentClient.cs @@ -25,7 +25,7 @@ public A2AAgentClient(ILogger logger, Uri baseUri) this._uri = baseUri; } - public override async IAsyncEnumerable RunStreamingAsync( + public override async IAsyncEnumerable RunStreamingAsync( string agentName, IList messages, string? threadId = null, @@ -37,7 +37,7 @@ public override async IAsyncEnumerable RunStreamingAsync var contextId = threadId ?? 
Guid.NewGuid().ToString("N"); // Convert and send messages via A2A without try-catch in yield method - var results = new List(); + var results = new List(); try { @@ -60,7 +60,7 @@ public override async IAsyncEnumerable RunStreamingAsync var responseMessage = message.ToChatMessage(); if (responseMessage is { Contents.Count: > 0 }) { - results.Add(new AgentRunResponseUpdate(responseMessage.Role, responseMessage.Contents) + results.Add(new AgentResponseUpdate(responseMessage.Role, responseMessage.Contents) { MessageId = message.MessageId, CreatedAt = DateTimeOffset.UtcNow @@ -90,7 +90,7 @@ public override async IAsyncEnumerable RunStreamingAsync RawRepresentation = artifact, }; - results.Add(new AgentRunResponseUpdate(chatMessage.Role, chatMessage.Contents) + results.Add(new AgentResponseUpdate(chatMessage.Role, chatMessage.Contents) { MessageId = agentTask.Id, CreatedAt = DateTimeOffset.UtcNow @@ -108,7 +108,7 @@ public override async IAsyncEnumerable RunStreamingAsync { this._logger.LogError(ex, "Error running agent {AgentName} via A2A", agentName); - results.Add(new AgentRunResponseUpdate(ChatRole.Assistant, $"Error: {ex.Message}") + results.Add(new AgentResponseUpdate(ChatRole.Assistant, $"Error: {ex.Message}") { MessageId = Guid.NewGuid().ToString("N"), CreatedAt = DateTimeOffset.UtcNow diff --git a/dotnet/samples/AgentWebChat/AgentWebChat.Web/IAgentClient.cs b/dotnet/samples/AgentWebChat/AgentWebChat.Web/IAgentClient.cs index 2d08ef5e45..2d22413f0d 100644 --- a/dotnet/samples/AgentWebChat/AgentWebChat.Web/IAgentClient.cs +++ b/dotnet/samples/AgentWebChat/AgentWebChat.Web/IAgentClient.cs @@ -19,7 +19,7 @@ internal abstract class AgentClientBase /// Optional thread identifier for conversation continuity. /// Cancellation token. /// An asynchronous enumerable of agent response updates. - public abstract IAsyncEnumerable RunStreamingAsync( + public abstract IAsyncEnumerable RunStreamingAsync( string agentName, IList messages, string? threadId = null, diff --git a/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIChatCompletionsAgentClient.cs b/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIChatCompletionsAgentClient.cs index 95e3d16fd4..a5b522a892 100644 --- a/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIChatCompletionsAgentClient.cs +++ b/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIChatCompletionsAgentClient.cs @@ -16,7 +16,7 @@ namespace AgentWebChat.Web; /// internal sealed class OpenAIChatCompletionsAgentClient(HttpClient httpClient) : AgentClientBase { - public override async IAsyncEnumerable RunStreamingAsync( + public override async IAsyncEnumerable RunStreamingAsync( string agentName, IList messages, string? 
threadId = null, @@ -31,7 +31,7 @@ public override async IAsyncEnumerable RunStreamingAsync var openAiClient = new ChatClient(model: "myModel!", credential: new ApiKeyCredential("dummy-key"), options: options).AsIChatClient(); await foreach (var update in openAiClient.GetStreamingResponseAsync(messages, cancellationToken: cancellationToken)) { - yield return new AgentRunResponseUpdate(update); + yield return new AgentResponseUpdate(update); } } } diff --git a/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIResponsesAgentClient.cs b/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIResponsesAgentClient.cs index d0121a6165..7594468398 100644 --- a/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIResponsesAgentClient.cs +++ b/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIResponsesAgentClient.cs @@ -15,7 +15,7 @@ namespace AgentWebChat.Web; /// internal sealed class OpenAIResponsesAgentClient(HttpClient httpClient) : AgentClientBase { - public override async IAsyncEnumerable RunStreamingAsync( + public override async IAsyncEnumerable RunStreamingAsync( string agentName, IList messages, string? threadId = null, @@ -35,7 +35,7 @@ public override async IAsyncEnumerable RunStreamingAsync await foreach (var update in openAiClient.GetStreamingResponseAsync(messages, chatOptions, cancellationToken: cancellationToken)) { - yield return new AgentRunResponseUpdate(update); + yield return new AgentResponseUpdate(update); } } } diff --git a/dotnet/samples/AzureFunctions/01_SingleAgent/Program.cs b/dotnet/samples/AzureFunctions/01_SingleAgent/Program.cs index 39b020e137..609ba162c9 100644 --- a/dotnet/samples/AzureFunctions/01_SingleAgent/Program.cs +++ b/dotnet/samples/AzureFunctions/01_SingleAgent/Program.cs @@ -25,7 +25,7 @@ const string JokerName = "Joker"; const string JokerInstructions = "You are good at telling jokes."; -AIAgent agent = client.GetChatClient(deploymentName).CreateAIAgent(JokerInstructions, JokerName); +AIAgent agent = client.GetChatClient(deploymentName).AsAIAgent(JokerInstructions, JokerName); // Configure the function app to host the AI agent. // This will automatically generate HTTP API endpoints for the agent. 
diff --git a/dotnet/samples/AzureFunctions/02_AgentOrchestration_Chaining/FunctionTriggers.cs b/dotnet/samples/AzureFunctions/02_AgentOrchestration_Chaining/FunctionTriggers.cs index a631e7715c..8ac9ea8d51 100644 --- a/dotnet/samples/AzureFunctions/02_AgentOrchestration_Chaining/FunctionTriggers.cs +++ b/dotnet/samples/AzureFunctions/02_AgentOrchestration_Chaining/FunctionTriggers.cs @@ -19,13 +19,13 @@ public sealed record TextResponse(string Text); public static async Task RunOrchestrationAsync([OrchestrationTrigger] TaskOrchestrationContext context) { DurableAIAgent writer = context.GetAgent("WriterAgent"); - AgentThread writerThread = writer.GetNewThread(); + AgentThread writerThread = await writer.GetNewThreadAsync(); - AgentRunResponse initial = await writer.RunAsync( + AgentResponse initial = await writer.RunAsync( message: "Write a concise inspirational sentence about learning.", thread: writerThread); - AgentRunResponse refined = await writer.RunAsync( + AgentResponse refined = await writer.RunAsync( message: $"Improve this further while keeping it under 25 words: {initial.Result.Text}", thread: writerThread); diff --git a/dotnet/samples/AzureFunctions/02_AgentOrchestration_Chaining/Program.cs b/dotnet/samples/AzureFunctions/02_AgentOrchestration_Chaining/Program.cs index 3776ecd062..ba16578935 100644 --- a/dotnet/samples/AzureFunctions/02_AgentOrchestration_Chaining/Program.cs +++ b/dotnet/samples/AzureFunctions/02_AgentOrchestration_Chaining/Program.cs @@ -29,7 +29,7 @@ when given an improved sentence you polish it further. """; -AIAgent writerAgent = client.GetChatClient(deploymentName).CreateAIAgent(WriterInstructions, WriterName); +AIAgent writerAgent = client.GetChatClient(deploymentName).AsAIAgent(WriterInstructions, WriterName); using IHost app = FunctionsApplication .CreateBuilder(args) diff --git a/dotnet/samples/AzureFunctions/03_AgentOrchestration_Concurrency/FunctionTriggers.cs b/dotnet/samples/AzureFunctions/03_AgentOrchestration_Concurrency/FunctionTriggers.cs index 2d15dd585c..241faf6df7 100644 --- a/dotnet/samples/AzureFunctions/03_AgentOrchestration_Concurrency/FunctionTriggers.cs +++ b/dotnet/samples/AzureFunctions/03_AgentOrchestration_Concurrency/FunctionTriggers.cs @@ -26,9 +26,9 @@ public static async Task RunOrchestrationAsync([OrchestrationTrigger] Ta DurableAIAgent chemist = context.GetAgent("ChemistAgent"); // Start both agent runs concurrently - Task> physicistTask = physicist.RunAsync(prompt); + Task> physicistTask = physicist.RunAsync(prompt); - Task> chemistTask = chemist.RunAsync(prompt); + Task> chemistTask = chemist.RunAsync(prompt); // Wait for both tasks to complete using Task.WhenAll await Task.WhenAll(physicistTask, chemistTask); diff --git a/dotnet/samples/AzureFunctions/03_AgentOrchestration_Concurrency/Program.cs b/dotnet/samples/AzureFunctions/03_AgentOrchestration_Concurrency/Program.cs index dfc2049d45..b180c8139f 100644 --- a/dotnet/samples/AzureFunctions/03_AgentOrchestration_Concurrency/Program.cs +++ b/dotnet/samples/AzureFunctions/03_AgentOrchestration_Concurrency/Program.cs @@ -28,8 +28,8 @@ const string ChemistName = "ChemistAgent"; const string ChemistInstructions = "You are an expert in chemistry. 
You answer questions from a chemistry perspective."; -AIAgent physicistAgent = client.GetChatClient(deploymentName).CreateAIAgent(PhysicistInstructions, PhysicistName); -AIAgent chemistAgent = client.GetChatClient(deploymentName).CreateAIAgent(ChemistInstructions, ChemistName); +AIAgent physicistAgent = client.GetChatClient(deploymentName).AsAIAgent(PhysicistInstructions, PhysicistName); +AIAgent chemistAgent = client.GetChatClient(deploymentName).AsAIAgent(ChemistInstructions, ChemistName); using IHost app = FunctionsApplication .CreateBuilder(args) diff --git a/dotnet/samples/AzureFunctions/04_AgentOrchestration_Conditionals/FunctionTriggers.cs b/dotnet/samples/AzureFunctions/04_AgentOrchestration_Conditionals/FunctionTriggers.cs index 14a91185f8..f09579978f 100644 --- a/dotnet/samples/AzureFunctions/04_AgentOrchestration_Conditionals/FunctionTriggers.cs +++ b/dotnet/samples/AzureFunctions/04_AgentOrchestration_Conditionals/FunctionTriggers.cs @@ -21,10 +21,10 @@ public static async Task RunOrchestrationAsync([OrchestrationTrigger] Ta // Get the spam detection agent DurableAIAgent spamDetectionAgent = context.GetAgent("SpamDetectionAgent"); - AgentThread spamThread = spamDetectionAgent.GetNewThread(); + AgentThread spamThread = await spamDetectionAgent.GetNewThreadAsync(); // Step 1: Check if the email is spam - AgentRunResponse spamDetectionResponse = await spamDetectionAgent.RunAsync( + AgentResponse spamDetectionResponse = await spamDetectionAgent.RunAsync( message: $""" Analyze this email for spam content and return a JSON response with 'is_spam' (boolean) and 'reason' (string) fields: @@ -43,9 +43,9 @@ public static async Task RunOrchestrationAsync([OrchestrationTrigger] Ta // Generate and send response for legitimate email DurableAIAgent emailAssistantAgent = context.GetAgent("EmailAssistantAgent"); - AgentThread emailThread = emailAssistantAgent.GetNewThread(); + AgentThread emailThread = await emailAssistantAgent.GetNewThreadAsync(); - AgentRunResponse emailAssistantResponse = await emailAssistantAgent.RunAsync( + AgentResponse emailAssistantResponse = await emailAssistantAgent.RunAsync( message: $""" Draft a professional response to this email. 
Return a JSON response with a 'response' field containing the reply: diff --git a/dotnet/samples/AzureFunctions/04_AgentOrchestration_Conditionals/Program.cs b/dotnet/samples/AzureFunctions/04_AgentOrchestration_Conditionals/Program.cs index a04b4c3e70..07dcd302cc 100644 --- a/dotnet/samples/AzureFunctions/04_AgentOrchestration_Conditionals/Program.cs +++ b/dotnet/samples/AzureFunctions/04_AgentOrchestration_Conditionals/Program.cs @@ -29,10 +29,10 @@ const string EmailAssistantInstructions = "You are an email assistant that helps users draft responses to emails with professionalism."; AIAgent spamDetectionAgent = client.GetChatClient(deploymentName) - .CreateAIAgent(SpamDetectionInstructions, SpamDetectionName); + .AsAIAgent(SpamDetectionInstructions, SpamDetectionName); AIAgent emailAssistantAgent = client.GetChatClient(deploymentName) - .CreateAIAgent(EmailAssistantInstructions, EmailAssistantName); + .AsAIAgent(EmailAssistantInstructions, EmailAssistantName); using IHost app = FunctionsApplication .CreateBuilder(args) diff --git a/dotnet/samples/AzureFunctions/05_AgentOrchestration_HITL/FunctionTriggers.cs b/dotnet/samples/AzureFunctions/05_AgentOrchestration_HITL/FunctionTriggers.cs index 001a52c105..6dcbb50c01 100644 --- a/dotnet/samples/AzureFunctions/05_AgentOrchestration_HITL/FunctionTriggers.cs +++ b/dotnet/samples/AzureFunctions/05_AgentOrchestration_HITL/FunctionTriggers.cs @@ -24,13 +24,13 @@ public static async Task RunOrchestrationAsync( // Get the writer agent DurableAIAgent writerAgent = context.GetAgent("WriterAgent"); - AgentThread writerThread = writerAgent.GetNewThread(); + AgentThread writerThread = await writerAgent.GetNewThreadAsync(); // Set initial status context.SetCustomStatus($"Starting content generation for topic: {input.Topic}"); // Step 1: Generate initial content - AgentRunResponse writerResponse = await writerAgent.RunAsync( + AgentResponse writerResponse = await writerAgent.RunAsync( message: $"Write a short article about '{input.Topic}'.", thread: writerThread); GeneratedContent content = writerResponse.Result; diff --git a/dotnet/samples/AzureFunctions/05_AgentOrchestration_HITL/Program.cs b/dotnet/samples/AzureFunctions/05_AgentOrchestration_HITL/Program.cs index 741e5407e0..77e2dfa2d4 100644 --- a/dotnet/samples/AzureFunctions/05_AgentOrchestration_HITL/Program.cs +++ b/dotnet/samples/AzureFunctions/05_AgentOrchestration_HITL/Program.cs @@ -29,7 +29,7 @@ You are a professional content writer who creates high-quality articles on vario You write engaging, informative, and well-structured content that follows best practices for readability and accuracy. 
"""; -AIAgent writerAgent = client.GetChatClient(deploymentName).CreateAIAgent(WriterInstructions, WriterName); +AIAgent writerAgent = client.GetChatClient(deploymentName).AsAIAgent(WriterInstructions, WriterName); using IHost app = FunctionsApplication .CreateBuilder(args) diff --git a/dotnet/samples/AzureFunctions/06_LongRunningTools/FunctionTriggers.cs b/dotnet/samples/AzureFunctions/06_LongRunningTools/FunctionTriggers.cs index b5f81276b8..9f73cff18e 100644 --- a/dotnet/samples/AzureFunctions/06_LongRunningTools/FunctionTriggers.cs +++ b/dotnet/samples/AzureFunctions/06_LongRunningTools/FunctionTriggers.cs @@ -20,13 +20,13 @@ public static async Task RunOrchestrationAsync( // Get the writer agent DurableAIAgent writerAgent = context.GetAgent("Writer"); - AgentThread writerThread = writerAgent.GetNewThread(); + AgentThread writerThread = await writerAgent.GetNewThreadAsync(); // Set initial status context.SetCustomStatus($"Starting content generation for topic: {input.Topic}"); // Step 1: Generate initial content - AgentRunResponse writerResponse = await writerAgent.RunAsync( + AgentResponse writerResponse = await writerAgent.RunAsync( message: $"Write a short article about '{input.Topic}'.", thread: writerThread); GeneratedContent content = writerResponse.Result; diff --git a/dotnet/samples/AzureFunctions/06_LongRunningTools/Program.cs b/dotnet/samples/AzureFunctions/06_LongRunningTools/Program.cs index 581b2bce11..e4d88d3ae7 100644 --- a/dotnet/samples/AzureFunctions/06_LongRunningTools/Program.cs +++ b/dotnet/samples/AzureFunctions/06_LongRunningTools/Program.cs @@ -33,7 +33,7 @@ You are a professional content writer who creates high-quality articles on vario You write engaging, informative, and well-structured content that follows best practices for readability and accuracy. """; -AIAgent writerAgent = client.GetChatClient(deploymentName).CreateAIAgent(WriterAgentInstructions, WriterAgentName); +AIAgent writerAgent = client.GetChatClient(deploymentName).AsAIAgent(WriterAgentInstructions, WriterAgentName); // Agent that can start content generation workflows using tools const string PublisherAgentName = "Publisher"; @@ -57,7 +57,7 @@ You are a publishing agent that can manage content generation workflows. // Initialize the tools to be used by the agent. Tools publisherTools = new(sp.GetRequiredService>()); - return client.GetChatClient(deploymentName).CreateAIAgent( + return client.GetChatClient(deploymentName).AsAIAgent( instructions: PublisherAgentInstructions, name: PublisherAgentName, services: sp, diff --git a/dotnet/samples/AzureFunctions/07_AgentAsMcpTool/Program.cs b/dotnet/samples/AzureFunctions/07_AgentAsMcpTool/Program.cs index bc0a69cbf2..2503037a8c 100644 --- a/dotnet/samples/AzureFunctions/07_AgentAsMcpTool/Program.cs +++ b/dotnet/samples/AzureFunctions/07_AgentAsMcpTool/Program.cs @@ -28,13 +28,13 @@ : new AzureOpenAIClient(new Uri(endpoint), new AzureCliCredential()); // Define three AI agents we are going to use in this application. 
-AIAgent agent1 = client.GetChatClient(deploymentName).CreateAIAgent("You are good at telling jokes.", "Joker"); +AIAgent agent1 = client.GetChatClient(deploymentName).AsAIAgent("You are good at telling jokes.", "Joker"); AIAgent agent2 = client.GetChatClient(deploymentName) - .CreateAIAgent("Check stock prices.", "StockAdvisor"); + .AsAIAgent("Check stock prices.", "StockAdvisor"); AIAgent agent3 = client.GetChatClient(deploymentName) - .CreateAIAgent("Recommend plants.", "PlantAdvisor", description: "Get plant recommendations."); + .AsAIAgent("Recommend plants.", "PlantAdvisor", description: "Get plant recommendations."); using IHost app = FunctionsApplication .CreateBuilder(args) diff --git a/dotnet/samples/AzureFunctions/08_ReliableStreaming/FunctionTriggers.cs b/dotnet/samples/AzureFunctions/08_ReliableStreaming/FunctionTriggers.cs index a6d3e9db55..94905f8156 100644 --- a/dotnet/samples/AzureFunctions/08_ReliableStreaming/FunctionTriggers.cs +++ b/dotnet/samples/AzureFunctions/08_ReliableStreaming/FunctionTriggers.cs @@ -95,7 +95,7 @@ public async Task CreateAsync( AIAgent agentProxy = durableClient.AsDurableAgentProxy(context, "TravelPlanner"); // Create a new agent thread - AgentThread thread = agentProxy.GetNewThread(); + AgentThread thread = await agentProxy.GetNewThreadAsync(cancellationToken); string agentSessionId = thread.GetService().ToString(); this._logger.LogInformation("Creating new agent session: {AgentSessionId}", agentSessionId); diff --git a/dotnet/samples/AzureFunctions/08_ReliableStreaming/Program.cs b/dotnet/samples/AzureFunctions/08_ReliableStreaming/Program.cs index 6c48ed4177..c279b968a3 100644 --- a/dotnet/samples/AzureFunctions/08_ReliableStreaming/Program.cs +++ b/dotnet/samples/AzureFunctions/08_ReliableStreaming/Program.cs @@ -70,7 +70,7 @@ to make the itinerary easy to scan and visually appealing. // Define the Travel Planner agent with tools for weather and events options.AddAIAgentFactory(TravelPlannerName, sp => { - return client.GetChatClient(deploymentName).CreateAIAgent( + return client.GetChatClient(deploymentName).AsAIAgent( instructions: TravelPlannerInstructions, name: TravelPlannerName, services: sp, diff --git a/dotnet/samples/AzureFunctions/08_ReliableStreaming/README.md b/dotnet/samples/AzureFunctions/08_ReliableStreaming/README.md index f1c68c2339..fd13f23ecd 100644 --- a/dotnet/samples/AzureFunctions/08_ReliableStreaming/README.md +++ b/dotnet/samples/AzureFunctions/08_ReliableStreaming/README.md @@ -196,7 +196,7 @@ The `id` field is the Redis stream entry ID - use it as the `cursor` parameter t 2. **Agent invoked**: The durable entity (`AgentEntity`) is signaled to run the travel planner agent. This is fire-and-forget from the HTTP request's perspective. -3. **Responses captured**: As the agent generates responses, `RedisStreamResponseHandler` (implementing `IAgentResponseHandler`) extracts the text from each `AgentRunResponseUpdate` and publishes it to a Redis Stream keyed by session ID. +3. **Responses captured**: As the agent generates responses, `RedisStreamResponseHandler` (implementing `IAgentResponseHandler`) extracts the text from each `AgentResponseUpdate` and publishes it to a Redis Stream keyed by session ID. 4. **Client polls Redis**: The HTTP response streams events by polling the Redis Stream. For SSE format, each event includes the Redis entry ID as the `id` field. 
diff --git a/dotnet/samples/AzureFunctions/08_ReliableStreaming/RedisStreamResponseHandler.cs b/dotnet/samples/AzureFunctions/08_ReliableStreaming/RedisStreamResponseHandler.cs index 21f944338a..e13c685a08 100644 --- a/dotnet/samples/AzureFunctions/08_ReliableStreaming/RedisStreamResponseHandler.cs +++ b/dotnet/samples/AzureFunctions/08_ReliableStreaming/RedisStreamResponseHandler.cs @@ -29,7 +29,7 @@ namespace ReliableStreaming; /// /// /// Each agent session gets its own Redis Stream, keyed by session ID. The stream entries -/// contain text chunks extracted from objects. +/// contain text chunks extracted from objects. /// /// public sealed class RedisStreamResponseHandler : IAgentResponseHandler @@ -53,7 +53,7 @@ public RedisStreamResponseHandler(IConnectionMultiplexer redis, TimeSpan streamT /// public async ValueTask OnStreamingResponseUpdateAsync( - IAsyncEnumerable messageStream, + IAsyncEnumerable messageStream, CancellationToken cancellationToken) { // Get the current session ID from the DurableAgentContext @@ -73,7 +73,7 @@ public async ValueTask OnStreamingResponseUpdateAsync( IDatabase db = this._redis.GetDatabase(); int sequenceNumber = 0; - await foreach (AgentRunResponseUpdate update in messageStream.WithCancellation(cancellationToken)) + await foreach (AgentResponseUpdate update in messageStream.WithCancellation(cancellationToken)) { // Extract just the text content - this avoids serialization round-trip issues string text = update.Text; @@ -112,7 +112,7 @@ public async ValueTask OnStreamingResponseUpdateAsync( } /// - public ValueTask OnAgentResponseAsync(AgentRunResponse message, CancellationToken cancellationToken) + public ValueTask OnAgentResponseAsync(AgentResponse message, CancellationToken cancellationToken) { // This handler is optimized for streaming responses. // For non-streaming responses, we don't need to store in Redis since diff --git a/dotnet/samples/GettingStarted/A2A/A2AAgent_AsFunctionTools/Program.cs b/dotnet/samples/GettingStarted/A2A/A2AAgent_AsFunctionTools/Program.cs index c813464d0c..d1384d2c21 100644 --- a/dotnet/samples/GettingStarted/A2A/A2AAgent_AsFunctionTools/Program.cs +++ b/dotnet/samples/GettingStarted/A2A/A2AAgent_AsFunctionTools/Program.cs @@ -23,14 +23,14 @@ AgentCard agentCard = await agentCardResolver.GetAgentCardAsync(); // Create an instance of the AIAgent for an existing A2A agent specified by the agent card. -AIAgent a2aAgent = agentCard.GetAIAgent(); +AIAgent a2aAgent = agentCard.AsAIAgent(); // Create the main agent, and provide the a2a agent skills as a function tools. AIAgent agent = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) .GetChatClient(deploymentName) - .CreateAIAgent( + .AsAIAgent( instructions: "You are a helpful assistant that helps people with travel planning.", tools: [.. CreateFunctionTools(a2aAgent, agentCard)] ); diff --git a/dotnet/samples/GettingStarted/A2A/A2AAgent_PollingForTaskCompletion/Program.cs b/dotnet/samples/GettingStarted/A2A/A2AAgent_PollingForTaskCompletion/Program.cs index 7b5934575c..de5cc79ebf 100644 --- a/dotnet/samples/GettingStarted/A2A/A2AAgent_PollingForTaskCompletion/Program.cs +++ b/dotnet/samples/GettingStarted/A2A/A2AAgent_PollingForTaskCompletion/Program.cs @@ -14,12 +14,12 @@ AgentCard agentCard = await agentCardResolver.GetAgentCardAsync(); // Create an instance of the AIAgent for an existing A2A agent specified by the agent card. 
-AIAgent agent = agentCard.GetAIAgent(); +AIAgent agent = agentCard.AsAIAgent(); -AgentThread thread = agent.GetNewThread(); +AgentThread thread = await agent.GetNewThreadAsync(); // Start the initial run with a long-running task. -AgentRunResponse response = await agent.RunAsync("Conduct a comprehensive analysis of quantum computing applications in cryptography, including recent breakthroughs, implementation challenges, and future roadmap. Please include diagrams and visual representations to illustrate complex concepts.", thread); +AgentResponse response = await agent.RunAsync("Conduct a comprehensive analysis of quantum computing applications in cryptography, including recent breakthroughs, implementation challenges, and future roadmap. Please include diagrams and visual representations to illustrate complex concepts.", thread); // Poll until the response is complete. while (response.ContinuationToken is { } token) diff --git a/dotnet/samples/GettingStarted/AGUI/README.md b/dotnet/samples/GettingStarted/AGUI/README.md index a624fe81f3..f55e317e36 100644 --- a/dotnet/samples/GettingStarted/AGUI/README.md +++ b/dotnet/samples/GettingStarted/AGUI/README.md @@ -212,7 +212,7 @@ dotnet run 1. `AGUIAgent` sends HTTP POST request to server 2. Server responds with SSE stream -3. Client parses events into `AgentRunResponseUpdate` objects +3. Client parses events into `AgentResponseUpdate` objects 4. Updates are displayed based on content type 5. `ConversationId` maintains conversation context diff --git a/dotnet/samples/GettingStarted/AGUI/Step01_GettingStarted/Client/Program.cs b/dotnet/samples/GettingStarted/AGUI/Step01_GettingStarted/Client/Program.cs index d942314806..b3e74e7efd 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step01_GettingStarted/Client/Program.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step01_GettingStarted/Client/Program.cs @@ -16,11 +16,11 @@ AGUIChatClient chatClient = new(httpClient, serverUrl); -AIAgent agent = chatClient.CreateAIAgent( +AIAgent agent = chatClient.AsAIAgent( name: "agui-client", description: "AG-UI Client Agent"); -AgentThread thread = agent.GetNewThread(); +AgentThread thread = await agent.GetNewThreadAsync(); List messages = [ new(ChatRole.System, "You are a helpful assistant.") @@ -51,7 +51,7 @@ bool isFirstUpdate = true; string? 
threadId = null; - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(messages, thread)) + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages, thread)) { ChatResponseUpdate chatUpdate = update.AsChatResponseUpdate(); diff --git a/dotnet/samples/GettingStarted/AGUI/Step01_GettingStarted/Server/Program.cs b/dotnet/samples/GettingStarted/AGUI/Step01_GettingStarted/Server/Program.cs index 1bfb9a97aa..fb3cbe401e 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step01_GettingStarted/Server/Program.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step01_GettingStarted/Server/Program.cs @@ -24,7 +24,7 @@ new DefaultAzureCredential()) .GetChatClient(deploymentName); -AIAgent agent = chatClient.AsIChatClient().CreateAIAgent( +AIAgent agent = chatClient.AsIChatClient().AsAIAgent( name: "AGUIAssistant", instructions: "You are a helpful assistant."); diff --git a/dotnet/samples/GettingStarted/AGUI/Step02_BackendTools/Client/Program.cs b/dotnet/samples/GettingStarted/AGUI/Step02_BackendTools/Client/Program.cs index 1919a9565f..9544d4286b 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step02_BackendTools/Client/Program.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step02_BackendTools/Client/Program.cs @@ -16,11 +16,11 @@ AGUIChatClient chatClient = new(httpClient, serverUrl); -AIAgent agent = chatClient.CreateAIAgent( +AIAgent agent = chatClient.AsAIAgent( name: "agui-client", description: "AG-UI Client Agent"); -AgentThread thread = agent.GetNewThread(); +AgentThread thread = await agent.GetNewThreadAsync(); List messages = [ new(ChatRole.System, "You are a helpful assistant.") @@ -51,7 +51,7 @@ bool isFirstUpdate = true; string? threadId = null; - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(messages, thread)) + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages, thread)) { ChatResponseUpdate chatUpdate = update.AsChatResponseUpdate(); diff --git a/dotnet/samples/GettingStarted/AGUI/Step02_BackendTools/Server/Program.cs b/dotnet/samples/GettingStarted/AGUI/Step02_BackendTools/Server/Program.cs index 2867721d02..73ece031fc 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step02_BackendTools/Server/Program.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step02_BackendTools/Server/Program.cs @@ -79,7 +79,7 @@ static RestaurantSearchResponse SearchRestaurants( new DefaultAzureCredential()) .GetChatClient(deploymentName); -ChatClientAgent agent = chatClient.AsIChatClient().CreateAIAgent( +ChatClientAgent agent = chatClient.AsIChatClient().AsAIAgent( name: "AGUIAssistant", instructions: "You are a helpful assistant with access to restaurant information.", tools: tools); diff --git a/dotnet/samples/GettingStarted/AGUI/Step03_FrontendTools/Client/Program.cs b/dotnet/samples/GettingStarted/AGUI/Step03_FrontendTools/Client/Program.cs index d295ed7116..fa760e9a6e 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step03_FrontendTools/Client/Program.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step03_FrontendTools/Client/Program.cs @@ -28,12 +28,12 @@ static string GetUserLocation() AGUIChatClient chatClient = new(httpClient, serverUrl); -AIAgent agent = chatClient.CreateAIAgent( +AIAgent agent = chatClient.AsAIAgent( name: "agui-client", description: "AG-UI Client Agent", tools: frontendTools); -AgentThread thread = agent.GetNewThread(); +AgentThread thread = await agent.GetNewThreadAsync(); List messages = [ new(ChatRole.System, "You are a helpful assistant.") @@ -64,7 +64,7 @@ static string GetUserLocation() bool isFirstUpdate 
= true; string? threadId = null; - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(messages, thread)) + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages, thread)) { ChatResponseUpdate chatUpdate = update.AsChatResponseUpdate(); diff --git a/dotnet/samples/GettingStarted/AGUI/Step03_FrontendTools/Server/Program.cs b/dotnet/samples/GettingStarted/AGUI/Step03_FrontendTools/Server/Program.cs index 1bfb9a97aa..fb3cbe401e 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step03_FrontendTools/Server/Program.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step03_FrontendTools/Server/Program.cs @@ -24,7 +24,7 @@ new DefaultAzureCredential()) .GetChatClient(deploymentName); -AIAgent agent = chatClient.AsIChatClient().CreateAIAgent( +AIAgent agent = chatClient.AsIChatClient().AsAIAgent( name: "AGUIAssistant", instructions: "You are a helpful assistant."); diff --git a/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Client/Program.cs b/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Client/Program.cs index 656989458d..e66087bad7 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Client/Program.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Client/Program.cs @@ -16,7 +16,7 @@ AGUIChatClient chatClient = new(httpClient, serverUrl); // Create agent -ChatClientAgent baseAgent = chatClient.CreateAIAgent( +ChatClientAgent baseAgent = chatClient.AsAIAgent( name: "AGUIAssistant", instructions: "You are a helpful assistant."); @@ -51,8 +51,8 @@ { approvalResponses.Clear(); - List chatResponseUpdates = []; - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(messages, thread, cancellationToken: default)) + List chatResponseUpdates = []; + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages, thread, cancellationToken: default)) { chatResponseUpdates.Add(update); foreach (AIContent content in update.Contents) @@ -111,7 +111,7 @@ } } - AgentRunResponse response = chatResponseUpdates.ToAgentRunResponse(); + AgentResponse response = chatResponseUpdates.ToAgentResponse(); messages.AddRange(response.Messages); foreach (AIContent approvalResponse in approvalResponses) { diff --git a/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Client/ServerFunctionApprovalClientAgent.cs b/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Client/ServerFunctionApprovalClientAgent.cs index 9f7812cc50..ef84b85281 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Client/ServerFunctionApprovalClientAgent.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Client/ServerFunctionApprovalClientAgent.cs @@ -22,17 +22,17 @@ public ServerFunctionApprovalClientAgent(AIAgent innerAgent, JsonSerializerOptio this._jsonSerializerOptions = jsonSerializerOptions; } - protected override Task RunCoreAsync( + protected override Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken) - .ToAgentRunResponseAsync(cancellationToken); + .ToAgentResponseAsync(cancellationToken); } - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -166,8 +166,8 @@ private static List ProcessOutgoingServerFunctionApprovals( return result ?? 
messages; } - private static AgentRunResponseUpdate ProcessIncomingServerApprovalRequests( - AgentRunResponseUpdate update, + private static AgentResponseUpdate ProcessIncomingServerApprovalRequests( + AgentResponseUpdate update, JsonSerializerOptions jsonSerializerOptions) { IList? updatedContents = null; @@ -215,7 +215,7 @@ private static AgentRunResponseUpdate ProcessIncomingServerApprovalRequests( if (updatedContents is not null) { var chatUpdate = update.AsChatResponseUpdate(); - return new AgentRunResponseUpdate(new ChatResponseUpdate() + return new AgentResponseUpdate(new ChatResponseUpdate() { Role = chatUpdate.Role, Contents = updatedContents, diff --git a/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Server/Program.cs b/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Server/Program.cs index 1af163435a..023b3327ba 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Server/Program.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Server/Program.cs @@ -57,7 +57,7 @@ static string ApproveExpenseReport(string expenseReportId) new DefaultAzureCredential()) .GetChatClient(deploymentName); -ChatClientAgent baseAgent = openAIChatClient.AsIChatClient().CreateAIAgent( +ChatClientAgent baseAgent = openAIChatClient.AsIChatClient().AsAIAgent( name: "AGUIAssistant", instructions: "You are a helpful assistant in charge of approving expenses", tools: tools); diff --git a/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Server/ServerFunctionApprovalServerAgent.cs b/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Server/ServerFunctionApprovalServerAgent.cs index 69e3db58c7..01649084ac 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Server/ServerFunctionApprovalServerAgent.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step04_HumanInLoop/Server/ServerFunctionApprovalServerAgent.cs @@ -22,17 +22,17 @@ public ServerFunctionApprovalAgent(AIAgent innerAgent, JsonSerializerOptions jso this._jsonSerializerOptions = jsonSerializerOptions; } - protected override Task RunCoreAsync( + protected override Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken) - .ToAgentRunResponseAsync(cancellationToken); + .ToAgentResponseAsync(cancellationToken); } - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -172,8 +172,8 @@ private static List ProcessIncomingFunctionApprovals( return result ?? messages; } - private static AgentRunResponseUpdate ProcessOutgoingApprovalRequests( - AgentRunResponseUpdate update, + private static AgentResponseUpdate ProcessOutgoingApprovalRequests( + AgentResponseUpdate update, JsonSerializerOptions jsonSerializerOptions) { IList? 
updatedContents = null; @@ -207,7 +207,7 @@ private static AgentRunResponseUpdate ProcessOutgoingApprovalRequests( { var chatUpdate = update.AsChatResponseUpdate(); // Yield a tool call update that represents the approval request - return new AgentRunResponseUpdate(new ChatResponseUpdate() + return new AgentResponseUpdate(new ChatResponseUpdate() { Role = chatUpdate.Role, Contents = updatedContents, diff --git a/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Client/Program.cs b/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Client/Program.cs index 49ffa0587d..0072f62845 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Client/Program.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Client/Program.cs @@ -19,7 +19,7 @@ AGUIChatClient chatClient = new(httpClient, serverUrl); -AIAgent baseAgent = chatClient.CreateAIAgent( +AIAgent baseAgent = chatClient.AsAIAgent( name: "recipe-client", description: "AG-UI Recipe Client Agent"); @@ -30,7 +30,7 @@ }; StatefulAgent agent = new(baseAgent, jsonOptions, new AgentState()); -AgentThread thread = agent.GetNewThread(); +AgentThread thread = await agent.GetNewThreadAsync(); List messages = [ new(ChatRole.System, "You are a helpful recipe assistant.") @@ -70,7 +70,7 @@ Console.WriteLine(); - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(messages, thread)) + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages, thread)) { ChatResponseUpdate chatUpdate = update.AsChatResponseUpdate(); diff --git a/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Client/StatefulAgent.cs b/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Client/StatefulAgent.cs index d5fd9f187b..8eca890e60 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Client/StatefulAgent.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Client/StatefulAgent.cs @@ -35,18 +35,18 @@ public StatefulAgent(AIAgent innerAgent, JsonSerializerOptions jsonSerializerOpt } /// - protected override Task RunCoreAsync( + protected override Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken) - .ToAgentRunResponseAsync(cancellationToken); + .ToAgentResponseAsync(cancellationToken); } /// - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? 
options = null, @@ -64,7 +64,7 @@ protected override async IAsyncEnumerable RunCoreStreami messagesWithState.Add(stateMessage); // Stream the response and update state when received - await foreach (AgentRunResponseUpdate update in this.InnerAgent.RunStreamingAsync(messagesWithState, thread, options, cancellationToken)) + await foreach (AgentResponseUpdate update in this.InnerAgent.RunStreamingAsync(messagesWithState, thread, options, cancellationToken)) { // Check if this update contains a state snapshot foreach (AIContent content in update.Contents) diff --git a/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Server/Program.cs b/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Server/Program.cs index 40c51887d1..a6bd6f5ef6 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Server/Program.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Server/Program.cs @@ -34,7 +34,7 @@ new DefaultAzureCredential()) .GetChatClient(deploymentName); -AIAgent baseAgent = chatClient.AsIChatClient().CreateAIAgent( +AIAgent baseAgent = chatClient.AsIChatClient().AsAIAgent( name: "RecipeAgent", instructions: """ You are a helpful recipe assistant. When users ask you to create or suggest a recipe, diff --git a/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Server/SharedStateAgent.cs b/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Server/SharedStateAgent.cs index 603698b579..1ac21adfce 100644 --- a/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Server/SharedStateAgent.cs +++ b/dotnet/samples/GettingStarted/AGUI/Step05_StateManagement/Server/SharedStateAgent.cs @@ -17,17 +17,17 @@ public SharedStateAgent(AIAgent innerAgent, JsonSerializerOptions jsonSerializer this._jsonSerializerOptions = jsonSerializerOptions; } - protected override Task RunCoreAsync( + protected override Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken) - .ToAgentRunResponseAsync(cancellationToken); + .ToAgentResponseAsync(cancellationToken); } - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? 
options = null, @@ -91,7 +91,7 @@ stateObj is not JsonElement state || var firstRunMessages = messages.Append(stateUpdateMessage); // Collect all updates from first run - var allUpdates = new List(); + var allUpdates = new List(); await foreach (var update in this.InnerAgent.RunStreamingAsync(firstRunMessages, thread, firstRunOptions, cancellationToken).ConfigureAwait(false)) { allUpdates.Add(update); @@ -104,7 +104,7 @@ stateObj is not JsonElement state || } } - var response = allUpdates.ToAgentRunResponse(); + var response = allUpdates.ToAgentResponse(); // Try to deserialize the structured state response if (response.TryDeserialize(this._jsonSerializerOptions, out JsonElement stateSnapshot)) @@ -113,7 +113,7 @@ stateObj is not JsonElement state || byte[] stateBytes = JsonSerializer.SerializeToUtf8Bytes( stateSnapshot, this._jsonSerializerOptions.GetTypeInfo(typeof(JsonElement))); - yield return new AgentRunResponseUpdate + yield return new AgentResponseUpdate { Contents = [new DataContent(stateBytes, "application/json")] }; diff --git a/dotnet/samples/GettingStarted/AgentOpenTelemetry/Program.cs b/dotnet/samples/GettingStarted/AgentOpenTelemetry/Program.cs index dd5c6f9c7d..abef6ee30f 100644 --- a/dotnet/samples/GettingStarted/AgentOpenTelemetry/Program.cs +++ b/dotnet/samples/GettingStarted/AgentOpenTelemetry/Program.cs @@ -128,7 +128,7 @@ static async Task GetWeatherAsync([Description("The location to get the .UseOpenTelemetry(SourceName, configure: (cfg) => cfg.EnableSensitiveData = true) // enable telemetry at the agent level .Build(); -var thread = agent.GetNewThread(); +var thread = await agent.GetNewThreadAsync(); appLogger.LogInformation("Agent created successfully with ID: {AgentId}", agent.Id); diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_A2A/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_A2A/Program.cs index 46ac8a55fa..3d72a82c11 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_A2A/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_A2A/Program.cs @@ -14,5 +14,5 @@ AIAgent agent = await agentCardResolver.GetAIAgentAsync(); // Invoke the agent and output the text result. 
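Every decorator agent updated above (ServerFunctionApprovalClientAgent, ServerFunctionApprovalServerAgent, StatefulAgent, SharedStateAgent) implements its non-streaming path the same way: drain the streaming path and fold the updates into one response. A minimal sketch of that renamed override shape, assuming the `Microsoft.Agents.AI` signatures shown in these hunks (the generic type arguments are stripped in this diff rendering):

```csharp
// Non-streaming runs delegate to the streaming implementation and collapse it.
protected override Task<AgentResponse> RunCoreAsync(
    IEnumerable<ChatMessage> messages,
    AgentThread? thread = null,
    AgentRunOptions? options = null,
    CancellationToken cancellationToken = default) =>
    this.RunCoreStreamingAsync(messages, thread, options, cancellationToken)
        .ToAgentResponseAsync(cancellationToken); // was ToAgentRunResponseAsync

// Buffered variant, as in SharedStateAgent above:
//   List<AgentResponseUpdate> allUpdates = [];            (was AgentRunResponseUpdate)
//   AgentResponse response = allUpdates.ToAgentResponse(); (was ToAgentRunResponse)
```

The A2A invocation below picks up the same `AgentRunResponse` to `AgentResponse` rename: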
-AgentRunResponse response = await agent.RunAsync("Tell me a joke about a pirate."); +AgentResponse response = await agent.RunAsync("Tell me a joke about a pirate."); Console.WriteLine(response); diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_A2A/README.md b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_A2A/README.md index 536514306e..f76af52f02 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_A2A/README.md +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_A2A/README.md @@ -26,9 +26,9 @@ using Microsoft.Agents.AI.A2A; A2AClient a2aClient = new(new Uri("https://your-a2a-agent-host/echo")); // Create an AIAgent from the A2AClient -AIAgent agent = a2aClient.GetAIAgent(); +AIAgent agent = a2aClient.AsAIAgent(); // Run the agent -AgentRunResponse response = await agent.RunAsync("Tell me a joke about a pirate."); +AgentResponse response = await agent.RunAsync("Tell me a joke about a pirate."); Console.WriteLine(response); ``` \ No newline at end of file diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_Anthropic/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_Anthropic/Program.cs index df070c335b..ad49c9229e 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_Anthropic/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_Anthropic/Program.cs @@ -26,7 +26,7 @@ ? new AnthropicFoundryClient(new AnthropicFoundryApiKeyCredentials(apiKey, resource)) // If an apiKey is provided, use Foundry with ApiKey authentication : new AnthropicFoundryClient(new AnthropicAzureTokenCredential(new AzureCliCredential(), resource)); // Otherwise, use Foundry with Azure Client authentication -AIAgent agent = client.CreateAIAgent(model: deploymentName, instructions: JokerInstructions, name: JokerName); +AIAgent agent = client.AsAIAgent(model: deploymentName, instructions: JokerInstructions, name: JokerName); // Invoke the agent and output the text result. Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.")); diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureAIAgentsPersistent/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureAIAgentsPersistent/Program.cs index 31f18ee7ae..e3d37a39d7 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureAIAgentsPersistent/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureAIAgentsPersistent/Program.cs @@ -31,7 +31,7 @@ instructions: JokerInstructions); // You can then invoke the agent like any other AIAgent. -AgentThread thread = agent1.GetNewThread(); +AgentThread thread = await agent1.GetNewThreadAsync(); Console.WriteLine(await agent1.RunAsync("Tell me a joke about a pirate.", thread)); // Cleanup for sample purposes. diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureAIProject/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureAIProject/Program.cs index 2c2b9d1969..4ca52aa268 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureAIProject/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureAIProject/Program.cs @@ -27,7 +27,7 @@ // agentVersion.Name = // You can retrieve an AIAgent for an already created server side agent version. 
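A naming convention runs through all of these provider hunks: wrapper and adapter conversions are renamed to `AsAIAgent` (matching the existing `AsIChatClient`), while methods that actually create a server-side resource keep their `Create` names. A composite sketch using values from the hunks above:

```csharp
// Adapter-style conversions after this diff:
AIAgent fromChatClient = chatClient.AsIChatClient().AsAIAgent(   // was CreateAIAgent
    name: "AGUIAssistant", instructions: "You are a helpful assistant.");
AIAgent fromA2A = a2aClient.AsAIAgent();                         // was GetAIAgent
```

The next hunk renames exactly such a retrieval on `AIProjectClient`, while `aiProjectClient.CreateAIAgent(...)`, which creates a new server-side agent version rather than wrapping an existing one, deliberately keeps its name: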
-AIAgent existingJokerAgent = aiProjectClient.GetAIAgent(createdAgentVersion); +AIAgent existingJokerAgent = aiProjectClient.AsAIAgent(createdAgentVersion); // You can also create another AIAgent version by providing the same name with a different definition. AIAgent newJokerAgent = aiProjectClient.CreateAIAgent(name: JokerName, model: deploymentName, instructions: "You are extremely hilarious at telling jokes."); @@ -40,7 +40,7 @@ Console.WriteLine($"Latest agent version id: {latestAgentVersion.Id}"); // Once you have the AIAgent, you can invoke it like any other AIAgent. -AgentThread thread = jokerAgentLatest.GetNewThread(); +AgentThread thread = await jokerAgentLatest.GetNewThreadAsync(); Console.WriteLine(await jokerAgentLatest.RunAsync("Tell me a joke about a pirate.", thread)); // This will use the same thread to continue the conversation. diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureFoundryModel/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureFoundryModel/Program.cs index 5ed66d3e4b..d22fc627ff 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureFoundryModel/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureFoundryModel/Program.cs @@ -25,7 +25,7 @@ AIAgent agent = client .GetChatClient(model) - .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); + .AsAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); // Invoke the agent and output the text result. Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.")); diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIChatCompletion/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIChatCompletion/Program.cs index cf717550d2..ea647c2d4f 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIChatCompletion/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIChatCompletion/Program.cs @@ -14,7 +14,7 @@ new Uri(endpoint), new AzureCliCredential()) .GetChatClient(deploymentName) - .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); + .AsAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); // Invoke the agent and output the text result. Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.")); diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIResponses/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIResponses/Program.cs index 5ce85b2b91..31a24b6585 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIResponses/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIResponses/Program.cs @@ -14,7 +14,7 @@ new Uri(endpoint), new AzureCliCredential()) .GetResponsesClient(deploymentName) - .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); + .AsAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); // Invoke the agent and output the text result. 
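Thread creation and deserialization also become asynchronous. For callers the change is mechanical (`agent.GetNewThread()` becomes `await agent.GetNewThreadAsync()`), but custom agents such as the `UpperCaseParrotAgent` below override the new `ValueTask`-returning virtuals instead. A sketch of the override shape, with the `ValueTask<AgentThread>` type argument assumed since generics are stripped in this diff rendering:

```csharp
public override ValueTask<AgentThread> GetNewThreadAsync(
    CancellationToken cancellationToken = default) =>
    new(new CustomAgentThread()); // synchronous result wrapped in a ValueTask

public override ValueTask<AgentThread> DeserializeThreadAsync(
    JsonElement serializedThread,
    JsonSerializerOptions? jsonSerializerOptions = null,
    CancellationToken cancellationToken = default) =>
    new(new CustomAgentThread(serializedThread, jsonSerializerOptions));

// Inside RunCoreAsync/RunCoreStreamingAsync the default-thread fallback becomes:
// thread ??= await this.GetNewThreadAsync(cancellationToken);
```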
Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.")); diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index 6beef64405..3e52a5f01a 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -28,16 +28,16 @@ internal sealed class UpperCaseParrotAgent : AIAgent { public override string? Name => "UpperCaseParrotAgent"; - public override AgentThread GetNewThread() - => new CustomAgentThread(); + public override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) + => new(new CustomAgentThread()); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) - => new CustomAgentThread(serializedThread, jsonSerializerOptions); + public override ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + => new(new CustomAgentThread(serializedThread, jsonSerializerOptions)); - protected override async Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + protected override async Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { // Create a thread if the user didn't supply one. - thread ??= this.GetNewThread(); + thread ??= await this.GetNewThreadAsync(cancellationToken); if (thread is not CustomAgentThread typedThread) { @@ -58,7 +58,7 @@ protected override async Task RunCoreAsync(IEnumerable RunCoreAsync(IEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + protected override async IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { // Create a thread if the user didn't supply one. - thread ??= this.GetNewThread(); + thread ??= await this.GetNewThreadAsync(cancellationToken); if (thread is not CustomAgentThread typedThread) { @@ -92,7 +92,7 @@ protected override async IAsyncEnumerable RunCoreStreami foreach (var message in responseMessages) { - yield return new AgentRunResponseUpdate + yield return new AgentResponseUpdate { AgentId = this.Id, AuthorName = message.AuthorName, diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/GeminiChatClient.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/GeminiChatClient.cs deleted file mode 100644 index 28f6f26013..0000000000 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/GeminiChatClient.cs +++ /dev/null @@ -1,558 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Runtime.CompilerServices; -using Google.Apis.Util; -using Google.GenAI; -using Google.GenAI.Types; - -namespace Microsoft.Extensions.AI; - -/// Provides an implementation based on . -internal sealed class GoogleGenAIChatClient : IChatClient -{ - /// The wrapped instance (optional). - private readonly Client? 
_client; - - /// The wrapped instance. - private readonly Models _models; - - /// The default model that should be used when no override is specified. - private readonly string? _defaultModelId; - - /// Lazily-initialized metadata describing the implementation. - private ChatClientMetadata? _metadata; - - /// Initializes a new instance. - public GoogleGenAIChatClient(Client client, string? defaultModelId) - { - this._client = client; - this._models = client.Models; - this._defaultModelId = defaultModelId; - } - - /// Initializes a new instance. - public GoogleGenAIChatClient(Models client, string? defaultModelId) - { - this._models = client; - this._defaultModelId = defaultModelId; - } - - /// - public async Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) - { - Utilities.ThrowIfNull(messages, nameof(messages)); - - // Create the request. - (string? modelId, List contents, GenerateContentConfig config) = this.CreateRequest(messages, options); - - // Send it. - GenerateContentResponse generateResult = await this._models.GenerateContentAsync(modelId!, contents, config).ConfigureAwait(false); - - // Create the response. - ChatResponse chatResponse = new(new ChatMessage(ChatRole.Assistant, [])) - { - CreatedAt = generateResult.CreateTime is { } dt ? new DateTimeOffset(dt) : null, - ModelId = !string.IsNullOrWhiteSpace(generateResult.ModelVersion) ? generateResult.ModelVersion : modelId, - RawRepresentation = generateResult, - ResponseId = generateResult.ResponseId, - }; - - // Populate the response messages. - chatResponse.FinishReason = PopulateResponseContents(generateResult, chatResponse.Messages[0].Contents); - - // Populate usage information if there is any. - if (generateResult.UsageMetadata is { } usageMetadata) - { - chatResponse.Usage = ExtractUsageDetails(usageMetadata); - } - - // Return the response. - return chatResponse; - } - - /// - public async IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Utilities.ThrowIfNull(messages, nameof(messages)); - - // Create the request. - (string? modelId, List contents, GenerateContentConfig config) = this.CreateRequest(messages, options); - - // Send it, and process the results. - await foreach (GenerateContentResponse generateResult in this._models.GenerateContentStreamAsync(modelId!, contents, config).WithCancellation(cancellationToken).ConfigureAwait(false)) - { - // Create a response update for each result in the stream. - ChatResponseUpdate responseUpdate = new(ChatRole.Assistant, []) - { - CreatedAt = generateResult.CreateTime is { } dt ? new DateTimeOffset(dt) : null, - ModelId = !string.IsNullOrWhiteSpace(generateResult.ModelVersion) ? generateResult.ModelVersion : modelId, - RawRepresentation = generateResult, - ResponseId = generateResult.ResponseId, - }; - - // Populate the response update contents. - responseUpdate.FinishReason = PopulateResponseContents(generateResult, responseUpdate.Contents); - - // Populate usage information if there is any. - if (generateResult.UsageMetadata is { } usageMetadata) - { - responseUpdate.Contents.Add(new UsageContent(ExtractUsageDetails(usageMetadata))); - } - - // Yield the update. - yield return responseUpdate; - } - } - - /// - public object? GetService(System.Type serviceType, object? 
serviceKey = null) - { - Utilities.ThrowIfNull(serviceType, nameof(serviceType)); - - if (serviceKey is null) - { - // If there's a request for metadata, lazily-initialize it and return it. We don't need to worry about race conditions, - // as there's no requirement that the same instance be returned each time, and creation is idempotent. - if (serviceType == typeof(ChatClientMetadata)) - { - return this._metadata ??= new("gcp.gen_ai", new("https://generativelanguage.googleapis.com/"), defaultModelId: this._defaultModelId); - } - - // Allow a consumer to "break glass" and access the underlying client if they need it. - if (serviceType.IsInstanceOfType(this._models)) - { - return this._models; - } - - if (this._client is not null && serviceType.IsInstanceOfType(this._client)) - { - return this._client; - } - - if (serviceType.IsInstanceOfType(this)) - { - return this; - } - } - - return null; - } - - /// - void IDisposable.Dispose() { /* nop */ } - - /// Creates the message parameters for from and . - private (string? ModelId, List Contents, GenerateContentConfig Config) CreateRequest(IEnumerable messages, ChatOptions? options) - { - // Create the GenerateContentConfig object. If the options contains a RawRepresentationFactory, try to use it to - // create the request instance, allowing the caller to populate it with GenAI-specific options. Otherwise, create - // a new instance directly. - string? model = this._defaultModelId; - List contents = []; - GenerateContentConfig config = options?.RawRepresentationFactory?.Invoke(this) as GenerateContentConfig ?? new(); - - if (options is not null) - { - if (options.FrequencyPenalty is { } frequencyPenalty) - { - config.FrequencyPenalty ??= frequencyPenalty; - } - - if (options.Instructions is { } instructions) - { - ((config.SystemInstruction ??= new()).Parts ??= []).Add(new() { Text = instructions }); - } - - if (options.MaxOutputTokens is { } maxOutputTokens) - { - config.MaxOutputTokens ??= maxOutputTokens; - } - - if (!string.IsNullOrWhiteSpace(options.ModelId)) - { - model = options.ModelId; - } - - if (options.PresencePenalty is { } presencePenalty) - { - config.PresencePenalty ??= presencePenalty; - } - - if (options.Seed is { } seed) - { - config.Seed ??= (int)seed; - } - - if (options.StopSequences is { } stopSequences) - { - (config.StopSequences ??= []).AddRange(stopSequences); - } - - if (options.Temperature is { } temperature) - { - config.Temperature ??= temperature; - } - - if (options.TopP is { } topP) - { - config.TopP ??= topP; - } - - if (options.TopK is { } topK) - { - config.TopK ??= topK; - } - - // Populate tools. Each kind of tool is added on its own, except for function declarations, - // which are grouped into a single FunctionDeclaration. - List? functionDeclarations = null; - if (options.Tools is { } tools) - { - foreach (var tool in tools) - { - switch (tool) - { - case AIFunctionDeclaration af: - functionDeclarations ??= []; - functionDeclarations.Add(new() - { - Name = af.Name, - Description = af.Description ?? 
"", - ParametersJsonSchema = af.JsonSchema, - }); - break; - - case HostedCodeInterpreterTool: - (config.Tools ??= []).Add(new() { CodeExecution = new() }); - break; - - case HostedFileSearchTool: - (config.Tools ??= []).Add(new() { Retrieval = new() }); - break; - - case HostedWebSearchTool: - (config.Tools ??= []).Add(new() { GoogleSearch = new() }); - break; - } - } - } - - if (functionDeclarations is { Count: > 0 }) - { - Tool functionTools = new(); - (functionTools.FunctionDeclarations ??= []).AddRange(functionDeclarations); - (config.Tools ??= []).Add(functionTools); - } - - // Transfer over the tool mode if there are any tools. - if (options.ToolMode is { } toolMode && config.Tools?.Count > 0) - { - switch (toolMode) - { - case NoneChatToolMode: - config.ToolConfig = new() { FunctionCallingConfig = new() { Mode = FunctionCallingConfigMode.NONE } }; - break; - - case AutoChatToolMode: - config.ToolConfig = new() { FunctionCallingConfig = new() { Mode = FunctionCallingConfigMode.AUTO } }; - break; - - case RequiredChatToolMode required: - config.ToolConfig = new() { FunctionCallingConfig = new() { Mode = FunctionCallingConfigMode.ANY } }; - if (required.RequiredFunctionName is not null) - { - ((config.ToolConfig.FunctionCallingConfig ??= new()).AllowedFunctionNames ??= []).Add(required.RequiredFunctionName); - } - break; - } - } - - // Set the response format if specified. - if (options.ResponseFormat is ChatResponseFormatJson responseFormat) - { - config.ResponseMimeType = "application/json"; - if (responseFormat.Schema is { } schema) - { - config.ResponseJsonSchema = schema; - } - } - } - - // Transfer messages to request, handling system messages specially - Dictionary? callIdToFunctionNames = null; - foreach (var message in messages) - { - if (message.Role == ChatRole.System) - { - string instruction = message.Text; - if (!string.IsNullOrWhiteSpace(instruction)) - { - ((config.SystemInstruction ??= new()).Parts ??= []).Add(new() { Text = instruction }); - } - - continue; - } - - Content content = new() { Role = message.Role == ChatRole.Assistant ? "model" : "user" }; - content.Parts ??= []; - AddPartsForAIContents(ref callIdToFunctionNames, message.Contents, content.Parts); - - contents.Add(content); - } - - // Make sure the request contains at least one content part (the request would always fail if empty). - if (!contents.SelectMany(c => c.Parts ?? Enumerable.Empty()).Any()) - { - contents.Add(new() { Role = "user", Parts = new() { { new() { Text = "" } } } }); - } - - return (model, contents, config); - } - - /// Creates s for and adds them to . - private static void AddPartsForAIContents(ref Dictionary? callIdToFunctionNames, IList contents, List parts) - { - for (int i = 0; i < contents.Count; i++) - { - var content = contents[i]; - - byte[]? thoughtSignature = null; - if (content is not TextReasoningContent { ProtectedData: not null } && - i + 1 < contents.Count && - contents[i + 1] is TextReasoningContent nextReasoning && - string.IsNullOrWhiteSpace(nextReasoning.Text) && - nextReasoning.ProtectedData is { } protectedData) - { - i++; - thoughtSignature = Convert.FromBase64String(protectedData); - } - - Part? part = null; - switch (content) - { - case TextContent textContent: - part = new() { Text = textContent.Text }; - break; - - case TextReasoningContent reasoningContent: - part = new() - { - Thought = true, - Text = !string.IsNullOrWhiteSpace(reasoningContent.Text) ? reasoningContent.Text : null, - ThoughtSignature = reasoningContent.ProtectedData is not null ? 
Convert.FromBase64String(reasoningContent.ProtectedData) : null, - }; - break; - - case DataContent dataContent: - part = new() - { - InlineData = new() - { - MimeType = dataContent.MediaType, - Data = dataContent.Data.ToArray(), - DisplayName = dataContent.Name, - } - }; - break; - - case UriContent uriContent: - part = new() - { - FileData = new() - { - FileUri = uriContent.Uri.AbsoluteUri, - MimeType = uriContent.MediaType, - } - }; - break; - - case FunctionCallContent functionCallContent: - (callIdToFunctionNames ??= [])[functionCallContent.CallId] = functionCallContent.Name; - callIdToFunctionNames[""] = functionCallContent.Name; // track last function name in case calls don't have IDs - - part = new() - { - FunctionCall = new() - { - Id = functionCallContent.CallId, - Name = functionCallContent.Name, - Args = functionCallContent.Arguments is null ? null : functionCallContent.Arguments as Dictionary ?? new(functionCallContent.Arguments!), - } - }; - break; - - case FunctionResultContent functionResultContent: - part = new() - { - FunctionResponse = new() - { - Id = functionResultContent.CallId, - Name = callIdToFunctionNames?.TryGetValue(functionResultContent.CallId, out string? functionName) is true || callIdToFunctionNames?.TryGetValue("", out functionName) is true ? - functionName : - null, - Response = functionResultContent.Result is null ? null : new() { ["result"] = functionResultContent.Result }, - } - }; - break; - } - - if (part is not null) - { - part.ThoughtSignature ??= thoughtSignature; - parts.Add(part); - } - } - } - - /// Creates s for and adds them to . - private static void AddAIContentsForParts(List parts, IList contents) - { - foreach (var part in parts) - { - AIContent? content = null; - - if (!string.IsNullOrEmpty(part.Text)) - { - content = part.Thought is true ? - new TextReasoningContent(part.Text) : - new TextContent(part.Text); - } - else if (part.InlineData is { } inlineData) - { - content = new DataContent(inlineData.Data, inlineData.MimeType ?? "application/octet-stream") - { - Name = inlineData.DisplayName, - }; - } - else if (part.FileData is { FileUri: not null } fileData) - { - content = new UriContent(new Uri(fileData.FileUri), fileData.MimeType ?? "application/octet-stream"); - } - else if (part.FunctionCall is { Name: not null } functionCall) - { - content = new FunctionCallContent(functionCall.Id ?? "", functionCall.Name, functionCall.Args!); - } - else if (part.FunctionResponse is { } functionResponse) - { - content = new FunctionResultContent( - functionResponse.Id ?? "", - functionResponse.Response?.TryGetValue("output", out var output) is true ? output : - functionResponse.Response?.TryGetValue("error", out var error) is true ? error : - null); - } - - if (content is not null) - { - content.RawRepresentation = part; - contents.Add(content); - - if (part.ThoughtSignature is { } thoughtSignature) - { - contents.Add(new TextReasoningContent(null) - { - ProtectedData = Convert.ToBase64String(thoughtSignature), - }); - } - } - } - } - - private static ChatFinishReason? PopulateResponseContents(GenerateContentResponse generateResult, IList responseContents) - { - ChatFinishReason? finishReason = null; - - // Populate the response messages. There should only be at most one candidate, but if there are more, ignore all but the first. - if (generateResult.Candidates is { Count: > 0 } && - generateResult.Candidates[0] is { Content: { } candidateContent } candidate) - { - // Grab the finish reason if one exists. 
- finishReason = ConvertFinishReason(candidate.FinishReason); - - // Add all of the response content parts as AIContents. - if (candidateContent.Parts is { } parts) - { - AddAIContentsForParts(parts, responseContents); - } - - // Add any citation metadata. - if (candidate.CitationMetadata is { Citations: { Count: > 0 } citations } && - responseContents.OfType().FirstOrDefault() is TextContent textContent) - { - foreach (var citation in citations) - { - textContent.Annotations = - [ - new CitationAnnotation() - { - Title = citation.Title, - Url = Uri.TryCreate(citation.Uri, UriKind.Absolute, out Uri? uri) ? uri : null, - AnnotatedRegions = - [ - new TextSpanAnnotatedRegion() - { - StartIndex = citation.StartIndex, - EndIndex = citation.EndIndex, - } - ], - } - ]; - } - } - } - - // Populate error information if there is any. - if (generateResult.PromptFeedback is { } promptFeedback) - { - responseContents.Add(new ErrorContent(promptFeedback.BlockReasonMessage)); - } - - return finishReason; - } - - /// Creates an M.E.AI from a Google . - private static ChatFinishReason? ConvertFinishReason(FinishReason? finishReason) - { - return finishReason switch - { - null => null, - - FinishReason.MAX_TOKENS => - ChatFinishReason.Length, - - FinishReason.MALFORMED_FUNCTION_CALL or - FinishReason.UNEXPECTED_TOOL_CALL => - ChatFinishReason.ToolCalls, - - FinishReason.FINISH_REASON_UNSPECIFIED or - FinishReason.STOP => - ChatFinishReason.Stop, - - _ => ChatFinishReason.ContentFilter, - }; - } - - /// Creates a populated from the supplied . - private static UsageDetails ExtractUsageDetails(GenerateContentResponseUsageMetadata usageMetadata) - { - UsageDetails details = new() - { - InputTokenCount = usageMetadata.PromptTokenCount, - OutputTokenCount = usageMetadata.CandidatesTokenCount, - TotalTokenCount = usageMetadata.TotalTokenCount, - }; - - AddIfPresent(nameof(usageMetadata.CachedContentTokenCount), usageMetadata.CachedContentTokenCount); - AddIfPresent(nameof(usageMetadata.ThoughtsTokenCount), usageMetadata.ThoughtsTokenCount); - AddIfPresent(nameof(usageMetadata.ToolUsePromptTokenCount), usageMetadata.ToolUsePromptTokenCount); - - return details; - - void AddIfPresent(string key, int? value) - { - if (value is int i) - { - (details.AdditionalCounts ??= [])[key] = i; - } - } - } -} diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/GoogleGenAIExtensions.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/GoogleGenAIExtensions.cs deleted file mode 100644 index b1044fa373..0000000000 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/GoogleGenAIExtensions.cs +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Google.Apis.Util; -using Google.GenAI; - -namespace Microsoft.Extensions.AI; - -/// Provides implementations of Microsoft.Extensions.AI abstractions based on . -public static class GoogleGenAIExtensions -{ - /// - /// Creates an wrapper around the specified . - /// - /// The to wrap. - /// The default model ID to use for chat requests if not specified in . - /// An that wraps the specified client. - /// is . - public static IChatClient AsIChatClient(this Client client, string? defaultModelId = null) - { - Utilities.ThrowIfNull(client, nameof(client)); - return new GoogleGenAIChatClient(client, defaultModelId); - } - - /// - /// Creates an wrapper around the specified . - /// - /// The client to wrap. - /// The default model ID to use for chat requests if not specified in . 
- /// An that wraps the specified client. - /// is . - public static IChatClient AsIChatClient(this Models models, string? defaultModelId = null) - { - Utilities.ThrowIfNull(models, nameof(models)); - return new GoogleGenAIChatClient(models, defaultModelId); - } -} diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/Program.cs index 89c86d5c56..4f478baf22 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/Program.cs @@ -14,15 +14,13 @@ string model = Environment.GetEnvironmentVariable("GOOGLE_GENAI_MODEL") ?? "gemini-2.5-flash"; // Using a Google GenAI IChatClient implementation -// Until the PR https://github.com/googleapis/dotnet-genai/pull/81 is not merged this option -// requires usage of also both GeminiChatClient.cs and GoogleGenAIExtensions.cs polyfills to work. ChatClientAgent agentGenAI = new( new Client(vertexAI: false, apiKey: apiKey).AsIChatClient(model), name: JokerName, instructions: JokerInstructions); -AgentRunResponse response = await agentGenAI.RunAsync("Tell me a joke about a pirate."); +AgentResponse response = await agentGenAI.RunAsync("Tell me a joke about a pirate."); Console.WriteLine($"Google GenAI client based agent response:\n{response}"); // Using a community driven Mscc.GenerativeAI.Microsoft package diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/README.md b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/README.md index bc3a3592e6..d4c8d1097b 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/README.md +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_GoogleGemini/README.md @@ -25,12 +25,7 @@ $env:GOOGLE_GENAI_MODEL="gemini-2.5-fast" # Optional, defaults to gemini-2.5-fa ### Google GenAI (Official) -The official Google GenAI package provides direct access to Google's Generative AI models. This sample uses an extension method to convert the Google client to an `IChatClient`. - -> [!NOTE] -> Until PR [googleapis/dotnet-genai#81](https://github.com/googleapis/dotnet-genai/pull/81) is merged, this option requires the additional `GeminiChatClient.cs` and `GoogleGenAIExtensions.cs` files included in this sample. -> -> We appreciate any community push by liking and commenting in the above PR to get it merged and release as part of official Google GenAI package. +The official Google GenAI package provides direct access to Google's Generative AI models. This sample uses the `AsIChatClient()` extension method to convert the Google client to an `IChatClient`. ### Mscc.GenerativeAI.Microsoft (Community) diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_ONNX/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_ONNX/Program.cs index d6c306bfd1..5385aab7a1 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_ONNX/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_ONNX/Program.cs @@ -12,7 +12,7 @@ // Get a chat client for ONNX and use it to construct an AIAgent. using OnnxRuntimeGenAIChatClient chatClient = new(modelPath); -AIAgent agent = chatClient.CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); +AIAgent agent = chatClient.AsAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); // Invoke the agent and output the text result. 
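Worth pausing on the Gemini hunks above: the two polyfill files (`GeminiChatClient.cs`, `GoogleGenAIExtensions.cs`) are deleted outright, and the README drops its note about googleapis/dotnet-genai#81, which implies the official package now ships `AsIChatClient` itself. What remains of that sample is roughly this (the `apiKey` and `model` values come from the environment-variable setup earlier in the file):

```csharp
// Gemini agent without the local polyfills; AsIChatClient is assumed to come
// from the official Google.GenAI package after this diff.
ChatClientAgent agentGenAI = new(
    new Client(vertexAI: false, apiKey: apiKey).AsIChatClient(model),
    name: JokerName,
    instructions: JokerInstructions);

AgentResponse response = await agentGenAI.RunAsync("Tell me a joke about a pirate.");
Console.WriteLine($"Google GenAI client based agent response:\n{response}");
```

From here the ONNX, Ollama, and OpenAI samples each need only the one-line `CreateAIAgent` to `AsAIAgent` rename: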
Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.")); diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_Ollama/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_Ollama/Program.cs index 8cacfef3ef..89f92a98a4 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_Ollama/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_Ollama/Program.cs @@ -11,7 +11,7 @@ // Get a chat client for Ollama and use it to construct an AIAgent. AIAgent agent = new OllamaApiClient(new Uri(endpoint), modelName) - .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); + .AsAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); // Invoke the agent and output the text result. Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.")); diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIAssistants/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIAssistants/Program.cs index 3079bd103c..eb194badfe 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIAssistants/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIAssistants/Program.cs @@ -33,7 +33,7 @@ instructions: JokerInstructions); // You can invoke the agent like any other AIAgent. -AgentThread thread = agent1.GetNewThread(); +AgentThread thread = await agent1.GetNewThreadAsync(); Console.WriteLine(await agent1.RunAsync("Tell me a joke about a pirate.", thread)); // Cleanup for sample purposes. diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIChatCompletion/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIChatCompletion/Program.cs index b4c6d626fc..3b22c21b3b 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIChatCompletion/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIChatCompletion/Program.cs @@ -13,7 +13,7 @@ AIAgent agent = new OpenAIClient( apiKey) .GetChatClient(model) - .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); + .AsAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); // Invoke the agent and output the text result. Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.")); diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIResponses/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIResponses/Program.cs index b0d0285928..1e5883c67d 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIResponses/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIResponses/Program.cs @@ -12,7 +12,7 @@ AIAgent agent = new OpenAIClient( apiKey) .GetResponsesClient(model) - .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); + .AsAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); // Invoke the agent and output the text result. 
Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.")); diff --git a/dotnet/samples/GettingStarted/AgentWithAnthropic/Agent_Anthropic_Step01_Running/Program.cs b/dotnet/samples/GettingStarted/AgentWithAnthropic/Agent_Anthropic_Step01_Running/Program.cs index cf7e29c2fe..085fbfd989 100644 --- a/dotnet/samples/GettingStarted/AgentWithAnthropic/Agent_Anthropic_Step01_Running/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithAnthropic/Agent_Anthropic_Step01_Running/Program.cs @@ -11,7 +11,7 @@ var model = Environment.GetEnvironmentVariable("ANTHROPIC_MODEL") ?? "claude-haiku-4-5"; AIAgent agent = new AnthropicClient(new ClientOptions { APIKey = apiKey }) - .CreateAIAgent(model: model, instructions: "You are good at telling jokes.", name: "Joker"); + .AsAIAgent(model: model, instructions: "You are good at telling jokes.", name: "Joker"); // Invoke the agent and output the text result. var response = await agent.RunAsync("Tell me a joke about a pirate."); diff --git a/dotnet/samples/GettingStarted/AgentWithAnthropic/Agent_Anthropic_Step02_Reasoning/Program.cs b/dotnet/samples/GettingStarted/AgentWithAnthropic/Agent_Anthropic_Step02_Reasoning/Program.cs index d362a9dd0d..120402ee14 100644 --- a/dotnet/samples/GettingStarted/AgentWithAnthropic/Agent_Anthropic_Step02_Reasoning/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithAnthropic/Agent_Anthropic_Step02_Reasoning/Program.cs @@ -14,7 +14,7 @@ var thinkingTokens = 2048; var agent = new AnthropicClient(new ClientOptions { APIKey = apiKey }) - .CreateAIAgent( + .AsAIAgent( model: model, clientFactory: (chatClient) => chatClient .AsBuilder() diff --git a/dotnet/samples/GettingStarted/AgentWithAnthropic/Agent_Anthropic_Step03_UsingFunctionTools/Program.cs b/dotnet/samples/GettingStarted/AgentWithAnthropic/Agent_Anthropic_Step03_UsingFunctionTools/Program.cs index a56db8d4a2..4253e8819f 100644 --- a/dotnet/samples/GettingStarted/AgentWithAnthropic/Agent_Anthropic_Step03_UsingFunctionTools/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithAnthropic/Agent_Anthropic_Step03_UsingFunctionTools/Program.cs @@ -23,15 +23,15 @@ static string GetWeather([Description("The location to get the weather for.")] s // Get anthropic client to create agents. AIAgent agent = new AnthropicClient { APIKey = apiKey } - .CreateAIAgent(model: model, instructions: AssistantInstructions, name: AssistantName, tools: [tool]); + .AsAIAgent(model: model, instructions: AssistantInstructions, name: AssistantName, tools: [tool]); // Non-streaming agent interaction with function tools. -AgentThread thread = agent.GetNewThread(); +AgentThread thread = await agent.GetNewThreadAsync(); Console.WriteLine(await agent.RunAsync("What is the weather like in Amsterdam?", thread)); // Streaming agent interaction with function tools. 
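The Anthropic function-tools sample shows the whole renamed flow in one place; the streaming hunk that follows completes it. The consolidated after-state, with values taken verbatim from these hunks:

```csharp
// Agent_Anthropic_Step03_UsingFunctionTools after this diff.
AIAgent agent = new AnthropicClient { APIKey = apiKey }
    .AsAIAgent(model: model, instructions: AssistantInstructions,
               name: AssistantName, tools: [tool]);

// Non-streaming agent interaction with function tools.
AgentThread thread = await agent.GetNewThreadAsync();
Console.WriteLine(await agent.RunAsync("What is the weather like in Amsterdam?", thread));

// Streaming agent interaction, with the renamed update type.
thread = await agent.GetNewThreadAsync();
await foreach (AgentResponseUpdate update in
    agent.RunStreamingAsync("What is the weather like in Amsterdam?", thread))
{
    Console.WriteLine(update);
}
```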
-thread = agent.GetNewThread(); -await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync("What is the weather like in Amsterdam?", thread)) +thread = await agent.GetNewThreadAsync(); +await foreach (AgentResponseUpdate update in agent.RunStreamingAsync("What is the weather like in Amsterdam?", thread)) { Console.WriteLine(update); } diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs index a11edafabc..b8fe566dc9 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs @@ -30,11 +30,11 @@ new Uri(endpoint), new AzureCliCredential()) .GetChatClient(deploymentName) - .CreateAIAgent(new ChatClientAgentOptions + .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - AIContextProviderFactory = (ctx) => new ChatHistoryMemoryProvider( + AIContextProviderFactory = (ctx, ct) => new ValueTask(new ChatHistoryMemoryProvider( vectorStore, collectionName: "chathistory", vectorDimensions: 3072, @@ -43,18 +43,18 @@ storageScope: new() { UserId = "UID1", ThreadId = Guid.NewGuid().ToString() }, // Configure the scope which would be used to search for relevant prior messages. // In this case, we are searching for any messages for the user across all threads. - searchScope: new() { UserId = "UID1" }) + searchScope: new() { UserId = "UID1" })) }); // Start a new thread for the agent conversation. -AgentThread thread = agent.GetNewThread(); +AgentThread thread = await agent.GetNewThreadAsync(); // Run the agent with the thread that stores conversation history in the vector store. Console.WriteLine(await agent.RunAsync("I like jokes about Pirates. Tell me a joke about a pirate.", thread)); // Start a second thread. Since we configured the search scope to be across all threads for the user, // the agent should remember that the user likes pirate jokes. -AgentThread thread2 = agent.GetNewThread(); +AgentThread thread2 = await agent.GetNewThreadAsync(); // Run the agent with the second thread. Console.WriteLine(await agent.RunAsync("Tell me a joke that I might like.", thread2)); diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs index 739c5e3f13..da0e816448 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs @@ -28,19 +28,19 @@ new Uri(endpoint), new AzureCliCredential()) .GetChatClient(deploymentName) - .CreateAIAgent(new ChatClientAgentOptions() + .AsAIAgent(new ChatClientAgentOptions() { ChatOptions = new() { Instructions = "You are a friendly travel assistant. Use known memories about the user when responding, and do not invent details." }, - AIContextProviderFactory = ctx => ctx.SerializedState.ValueKind is not JsonValueKind.Null and not JsonValueKind.Undefined + AIContextProviderFactory = (ctx, ct) => new ValueTask(ctx.SerializedState.ValueKind is not JsonValueKind.Null and not JsonValueKind.Undefined // If each thread should have its own Mem0 scope, you can create a new id per thread here: // ?
new Mem0Provider(mem0HttpClient, new Mem0ProviderScope() { ThreadId = Guid.NewGuid().ToString() }) // In this case we are storing memories scoped by application and user instead so that memories are retained across threads. ? new Mem0Provider(mem0HttpClient, new Mem0ProviderScope() { ApplicationId = "getting-started-agents", UserId = "sample-user" }) // For cases where we are restoring from serialized state: - : new Mem0Provider(mem0HttpClient, ctx.SerializedState, ctx.JsonSerializerOptions) + : new Mem0Provider(mem0HttpClient, ctx.SerializedState, ctx.JsonSerializerOptions)) }); -AgentThread thread = agent.GetNewThread(); +AgentThread thread = await agent.GetNewThreadAsync(); // Clear any existing memories for this scope to demonstrate fresh behavior. Mem0Provider mem0Provider = thread.GetService()!; @@ -56,9 +56,9 @@ Console.WriteLine("\n>> Serialize and deserialize the thread to demonstrate persisted state\n"); JsonElement serializedThread = thread.Serialize(); -AgentThread restoredThread = agent.DeserializeThread(serializedThread); +AgentThread restoredThread = await agent.DeserializeThreadAsync(serializedThread); Console.WriteLine(await agent.RunAsync("Can you recap the personal details you remember?", restoredThread)); Console.WriteLine("\n>> Start a new thread that shares the same Mem0 scope\n"); -AgentThread newThread = agent.GetNewThread(); +AgentThread newThread = await agent.GetNewThreadAsync(); Console.WriteLine(await agent.RunAsync("Summarize what you already know about me.", newThread)); diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs index 5727e8ca3c..4e84a4b53c 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs @@ -30,14 +30,14 @@ // and preferably shared between multiple threads used by the same user, ensure that the // factory reads the user id from the current context and scopes the memory component // and its storage to that user id. -AIAgent agent = chatClient.CreateAIAgent(new ChatClientAgentOptions() +AIAgent agent = chatClient.AsAIAgent(new ChatClientAgentOptions() { ChatOptions = new() { Instructions = "You are a friendly assistant. Always address the user by their name." }, - AIContextProviderFactory = ctx => new UserInfoMemory(chatClient.AsIChatClient(), ctx.SerializedState, ctx.JsonSerializerOptions) + AIContextProviderFactory = (ctx, ct) => new ValueTask(new UserInfoMemory(chatClient.AsIChatClient(), ctx.SerializedState, ctx.JsonSerializerOptions)) }); // Create a new thread for the conversation. -AgentThread thread = agent.GetNewThread(); +AgentThread thread = await agent.GetNewThreadAsync(); Console.WriteLine(">> Use thread with blank memory\n"); @@ -52,7 +52,7 @@ Console.WriteLine("\n>> Use deserialized thread with previously created memories\n"); // Later we can deserialize the thread and continue the conversation with the previous memory component state. -var deserializedThread = agent.DeserializeThread(threadElement); +var deserializedThread = await agent.DeserializeThreadAsync(threadElement); Console.WriteLine(await agent.RunAsync("What is my name and age?", deserializedThread)); Console.WriteLine("\n>> Read memories from memory component\n"); @@ -68,7 +68,7 @@ // It is also possible to set the memories in a memory component on an individual thread. 
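Note the factory signature change running through these memory samples: `AIContextProviderFactory` callbacks gain a `CancellationToken` and return a `ValueTask`, and saved threads are now restored with an awaited call. A sketch, with the `ValueTask<AIContextProvider>` type argument assumed since generics are stripped in this diff rendering:

```csharp
AIContextProviderFactory = (ctx, ct) =>                  // was: ctx => provider
    new ValueTask<AIContextProvider>(
        new Mem0Provider(mem0HttpClient, new Mem0ProviderScope()
        {
            ApplicationId = "getting-started-agents",
            UserId = "sample-user",
        })),

// Restoring a serialized thread is async as well:
// AgentThread restoredThread = await agent.DeserializeThreadAsync(serializedThread);
```

The hunk continuing below makes the matching `GetNewThreadAsync` change for the shared-memory thread: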
 // This is useful if we want to start a new thread, but have it share the same memories as a previous thread.
-var newThread = agent.GetNewThread();
+var newThread = await agent.GetNewThreadAsync();
 if (userInfo is not null && newThread.GetService<UserInfoMemory>() is UserInfoMemory newThreadMemory)
 {
     newThreadMemory.UserInfo = userInfo;
diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step01_Running/Program.cs b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step01_Running/Program.cs
index ccd42a2007..78ea76e03f 100644
--- a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step01_Running/Program.cs
+++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step01_Running/Program.cs
@@ -12,7 +12,7 @@
 AIAgent agent = new OpenAIClient(apiKey)
     .GetChatClient(model)
-    .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker");
+    .AsAIAgent(instructions: "You are good at telling jokes.", name: "Joker");
 
 UserChatMessage chatMessage = new("Tell me a joke about a pirate.");
diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step03_CreateFromChatClient/OpenAIChatClientAgent.cs b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step03_CreateFromChatClient/OpenAIChatClientAgent.cs
index a0b59d1053..3694f70b11 100644
--- a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step03_CreateFromChatClient/OpenAIChatClientAgent.cs
+++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step03_CreateFromChatClient/OpenAIChatClientAgent.cs
@@ -87,10 +87,10 @@ public virtual IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync
     }
 
     /// <inheritdoc/>
-    protected sealed override Task<AgentRunResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
+    protected sealed override Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
         base.RunCoreAsync(messages, thread, options, cancellationToken);
 
     /// <inheritdoc/>
-    protected override IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
+    protected override IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
         base.RunCoreStreamingAsync(messages, thread, options, cancellationToken);
 }
diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/OpenAIResponseClientAgent.cs b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/OpenAIResponseClientAgent.cs
index f894a5434c..e0c4f36356 100644
--- a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/OpenAIResponseClientAgent.cs
+++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/OpenAIResponseClientAgent.cs
@@ -105,10 +105,10 @@ public virtual async IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync
     }
 
     /// <inheritdoc/>
-    protected sealed override Task<AgentRunResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
+    protected sealed override Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
         base.RunCoreAsync(messages, thread, options, cancellationToken);
 
     /// <inheritdoc/>
-    protected sealed override IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
+    protected sealed override IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
         base.RunCoreStreamingAsync(messages, thread, options, cancellationToken);
 }
diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Program.cs b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Program.cs
index 8aebebdfa0..07a67edae4 100644
--- a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Program.cs
+++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Program.cs
@@ -30,7 +30,7 @@
 string conversationId = createConversationResultAsJson.RootElement.GetProperty("id"u8)!.GetString()!;
 
 // Create a thread for the conversation - this enables conversation state management for subsequent turns
-AgentThread thread = agent.GetNewThread(conversationId);
+AgentThread thread = await agent.GetNewThreadAsync(conversationId);
 
 Console.WriteLine("=== Multi-turn Conversation Demo ===\n");
diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/README.md b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/README.md
index c279ba2c17..5b999955b2 100644
--- a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/README.md
+++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/README.md
@@ -33,7 +33,7 @@ The `AgentThread` works with `ChatClientAgentRunOptions` to link the agent to a
 ChatClientAgentRunOptions agentRunOptions = new() { ChatOptions = new ChatOptions() { ConversationId = conversationId } };
 
 // Create a thread for the conversation
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 
 // First call links the thread to the conversation
 ChatCompletion firstResponse = await agent.RunAsync([firstMessage], thread, agentRunOptions);
@@ -59,7 +59,7 @@ foreach (ClientResult result in getConversationItemsResults.GetRawPages())
 1. **Create an OpenAI Client**: Initialize an `OpenAIClient` with your API key
 2. **Create a Conversation**: Use `ConversationClient` to create a server-side conversation
 3. **Create an Agent**: Initialize an `OpenAIResponseClientAgent` with the desired model and instructions
-4. **Create a Thread**: Call `agent.GetNewThread()` to create a new conversation thread
+4. **Create a Thread**: Call `agent.GetNewThreadAsync()` to create a new conversation thread
 5. **Link Thread to Conversation**: Pass `ChatClientAgentRunOptions` with the `ConversationId` on the first call
 6. **Send Messages**: Subsequent calls to `agent.RunAsync()` only need the thread - context is maintained
 7. **Cleanup**: Delete the conversation when done using `conversationClient.DeleteConversation()`
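The conversation sample now awaits thread creation and binds the thread to the server-side conversation on the first call only. A minimal sketch of the migrated flow, assuming the `agent` and `conversationId` set up earlier in that sample (message text is illustrative):

```csharp
// Options for the first turn carry the conversation id, binding the thread to it.
ChatClientAgentRunOptions firstRunOptions = new(new ChatOptions { ConversationId = conversationId });
AgentThread thread = await agent.GetNewThreadAsync();

// First turn: links the thread to the server-side conversation.
AgentResponse first = await agent.RunAsync("Hello!", thread, firstRunOptions);

// Later turns: the thread alone is enough; conversation state is maintained server-side.
AgentResponse second = await agent.RunAsync("What did I just say?", thread);
Console.WriteLine(second.Text);
```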
diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs
index 9207a08182..a4904ecf77 100644
--- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs
+++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs
@@ -59,18 +59,18 @@
 // Create the AI agent with the TextSearchProvider as the AI context provider.
 AIAgent agent = azureOpenAIClient
     .GetChatClient(deploymentName)
-    .CreateAIAgent(new ChatClientAgentOptions
+    .AsAIAgent(new ChatClientAgentOptions
     {
         ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." },
-        AIContextProviderFactory = ctx => new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions),
+        AIContextProviderFactory = (ctx, ct) => new ValueTask<AIContextProvider>(new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)),
         // Since we are using ChatCompletion which stores chat history locally, we can also add a message removal policy
         // that removes messages produced by the TextSearchProvider before they are added to the chat history, so that
         // we don't bloat chat history with all the search result messages.
-        ChatMessageStoreFactory = ctx => new InMemoryChatMessageStore(ctx.SerializedState, ctx.JsonSerializerOptions)
-            .WithAIContextProviderMessageRemoval(),
+        ChatMessageStoreFactory = (ctx, ct) => new ValueTask<ChatMessageStore>(new InMemoryChatMessageStore(ctx.SerializedState, ctx.JsonSerializerOptions)
+            .WithAIContextProviderMessageRemoval()),
     });
 
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 
 Console.WriteLine(">> Asking about returns\n");
 Console.WriteLine(await agent.RunAsync("Hi! I need help understanding the return policy.", thread));
diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs
index f20e42f01d..40e2834317 100644
--- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs
+++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs
@@ -68,13 +68,13 @@
 // Create the AI agent with the TextSearchProvider as the AI context provider.
 AIAgent agent = azureOpenAIClient
     .GetChatClient(deploymentName)
-    .CreateAIAgent(new ChatClientAgentOptions
+    .AsAIAgent(new ChatClientAgentOptions
     {
         ChatOptions = new() { Instructions = "You are a helpful support specialist for the Microsoft Agent Framework. Answer questions using the provided context and cite the source document when available. Keep responses brief." },
-        AIContextProviderFactory = ctx => new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)
+        AIContextProviderFactory = (ctx, ct) => new ValueTask<AIContextProvider>(new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions))
     });
 
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 
 Console.WriteLine(">> Asking about SK threads\n");
 Console.WriteLine(await agent.RunAsync("Hi! How do I create a thread in Semantic Kernel?", thread));
diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs
index e9a62e382f..a8b21ec287 100644
--- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs
+++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs
@@ -26,13 +26,13 @@
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(new ChatClientAgentOptions
+    .AsAIAgent(new ChatClientAgentOptions
     {
         ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." },
-        AIContextProviderFactory = ctx => new TextSearchProvider(MockSearchAsync, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)
+        AIContextProviderFactory = (ctx, ct) => new ValueTask<AIContextProvider>(new TextSearchProvider(MockSearchAsync, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions))
     });
 
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 
 Console.WriteLine(">> Asking about returns\n");
 Console.WriteLine(await agent.RunAsync("Hi! I need help understanding the return policy.", thread));
diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step04_FoundryServiceRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step04_FoundryServiceRAG/Program.cs
index 0989394185..e93fd474f6 100644
--- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step04_FoundryServiceRAG/Program.cs
+++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step04_FoundryServiceRAG/Program.cs
@@ -43,7 +43,7 @@
     instructions: "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available.",
     tools: [fileSearchTool]);
 
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 
 Console.WriteLine(">> Asking about returns\n");
 Console.WriteLine(await agent.RunAsync("Hi! I need help understanding the return policy.", thread));
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step01_Running/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step01_Running/Program.cs
index 889045c228..3ce20975b9 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step01_Running/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step01_Running/Program.cs
@@ -14,7 +14,7 @@
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker");
+    .AsAIAgent(instructions: "You are good at telling jokes.", name: "Joker");
 
 // Invoke the agent and output the text result.
 Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate."));
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step02_MultiturnConversation/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step02_MultiturnConversation/Program.cs
index e27a5bb36d..22e5078498 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step02_MultiturnConversation/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step02_MultiturnConversation/Program.cs
@@ -14,15 +14,15 @@
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker");
+    .AsAIAgent(instructions: "You are good at telling jokes.", name: "Joker");
 
 // Invoke the agent with a multi-turn conversation, where the context is preserved in the thread object.
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread));
 Console.WriteLine(await agent.RunAsync("Now add some emojis to the joke and tell it in the voice of a pirate's parrot.", thread));
 
 // Invoke the agent with a multi-turn conversation and streaming, where the context is preserved in the thread object.
-thread = agent.GetNewThread();
+thread = await agent.GetNewThreadAsync();
 await foreach (var update in agent.RunStreamingAsync("Tell me a joke about a pirate.", thread))
 {
     Console.WriteLine(update);
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step03_UsingFunctionTools/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step03_UsingFunctionTools/Program.cs
index ae41572cc2..87cc021fb3 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step03_UsingFunctionTools/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step03_UsingFunctionTools/Program.cs
@@ -22,7 +22,7 @@ static string GetWeather([Description("The location to get the weather for.")] s
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(instructions: "You are a helpful assistant", tools: [AIFunctionFactory.Create(GetWeather)]);
+    .AsAIAgent(instructions: "You are a helpful assistant", tools: [AIFunctionFactory.Create(GetWeather)]);
 
 // Non-streaming agent interaction with function tools.
 Console.WriteLine(await agent.RunAsync("What is the weather like in Amsterdam?"));
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step04_UsingFunctionToolsWithApprovals/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step04_UsingFunctionToolsWithApprovals/Program.cs
index be2a4801ae..b12bf48e6e 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step04_UsingFunctionToolsWithApprovals/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step04_UsingFunctionToolsWithApprovals/Program.cs
@@ -27,10 +27,10 @@ static string GetWeather([Description("The location to get the weather for.")] s
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(instructions: "You are a helpful assistant", tools: [new ApprovalRequiredAIFunction(AIFunctionFactory.Create(GetWeather))]);
+    .AsAIAgent(instructions: "You are a helpful assistant", tools: [new ApprovalRequiredAIFunction(AIFunctionFactory.Create(GetWeather))]);
 
 // Call the agent and check if there are any user input requests to handle.
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 var response = await agent.RunAsync("What is the weather like in Amsterdam?", thread);
 
 var userInputRequests = response.UserInputRequests.ToList();
@@ -64,4 +64,4 @@ static string GetWeather([Description("The location to get the weather for.")] s
 Console.WriteLine($"\nAgent: {response}");
 
 // For streaming use:
-// Console.WriteLine($"\nAgent: {updates.ToAgentRunResponse()}");
+// Console.WriteLine($"\nAgent: {updates.ToAgentResponse()}");
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step05_StructuredOutput/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step05_StructuredOutput/Program.cs
index 3b923069f4..38762ebfd1 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step05_StructuredOutput/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step05_StructuredOutput/Program.cs
@@ -21,10 +21,10 @@
     .GetChatClient(deploymentName);
 
 // Create the ChatClientAgent with the specified name and instructions.
-ChatClientAgent agent = chatClient.CreateAIAgent(name: "HelpfulAssistant", instructions: "You are a helpful assistant.");
+ChatClientAgent agent = chatClient.AsAIAgent(name: "HelpfulAssistant", instructions: "You are a helpful assistant.");
 
 // Set PersonInfo as the type parameter of RunAsync method to specify the expected structured output from the agent and invoke the agent with some unstructured input.
-AgentRunResponse<PersonInfo> response = await agent.RunAsync<PersonInfo>("Please provide information about John Smith, who is a 35-year-old software engineer.");
+AgentResponse<PersonInfo> response = await agent.RunAsync<PersonInfo>("Please provide information about John Smith, who is a 35-year-old software engineer.");
 
 // Access the structured output via the Result property of the agent response.
 Console.WriteLine("Assistant Output:");
@@ -33,7 +33,7 @@
 Console.WriteLine($"Occupation: {response.Result.Occupation}");
 
 // Create the ChatClientAgent with the specified name, instructions, and expected structured output the agent should produce.
-ChatClientAgent agentWithPersonInfo = chatClient.CreateAIAgent(new ChatClientAgentOptions()
+ChatClientAgent agentWithPersonInfo = chatClient.AsAIAgent(new ChatClientAgentOptions()
 {
     Name = "HelpfulAssistant",
     ChatOptions = new() { Instructions = "You are a helpful assistant.", ResponseFormat = Microsoft.Extensions.AI.ChatResponseFormat.ForJsonSchema() }
@@ -44,7 +44,7 @@
 // Assemble all the parts of the streamed output, since we can only deserialize once we have the full json,
 // then deserialize the response into the PersonInfo class.
-PersonInfo personInfo = (await updates.ToAgentRunResponseAsync()).Deserialize<PersonInfo>(JsonSerializerOptions.Web);
+PersonInfo personInfo = (await updates.ToAgentResponseAsync()).Deserialize<PersonInfo>(JsonSerializerOptions.Web);
 
 Console.WriteLine("Assistant Output:");
 Console.WriteLine($"Name: {personInfo.Name}");
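The structured-output sample relies on two patterns: a typed `RunAsync<T>` whose `Result` property carries the deserialized object, and a streaming variant that buffers updates before deserializing. A minimal sketch, assuming the `agent`, `agentWithPersonInfo`, and `PersonInfo` definitions from the sample above (the second prompt is illustrative):

```csharp
// Non-streaming: request typed output and read it from Result.
AgentResponse<PersonInfo> typed = await agent.RunAsync<PersonInfo>(
    "Please provide information about John Smith, who is a 35-year-old software engineer.");
Console.WriteLine(typed.Result.Name);

// Streaming: collect all updates, then deserialize once the full JSON has arrived.
IAsyncEnumerable<AgentResponseUpdate> updates =
    agentWithPersonInfo.RunStreamingAsync("Tell me about Jane Roe, a 29-year-old data scientist.");
PersonInfo person = (await updates.ToAgentResponseAsync()).Deserialize<PersonInfo>(JsonSerializerOptions.Web);
Console.WriteLine(person.Occupation);
```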
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step06_PersistedConversations/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step06_PersistedConversations/Program.cs
index 5d3247b69c..b23e1abe78 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step06_PersistedConversations/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step06_PersistedConversations/Program.cs
@@ -16,10 +16,10 @@
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker");
+    .AsAIAgent(instructions: "You are good at telling jokes.", name: "Joker");
 
 // Start a new thread for the agent conversation.
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 
 // Run the agent with a new thread.
 Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread));
@@ -35,7 +35,7 @@
 JsonElement reloadedSerializedThread = JsonElement.Parse(await File.ReadAllTextAsync(tempFilePath));
 
 // Deserialize the thread state after loading from storage.
-AgentThread resumedThread = agent.DeserializeThread(reloadedSerializedThread);
+AgentThread resumedThread = await agent.DeserializeThreadAsync(reloadedSerializedThread);
 
 // Run the agent again with the resumed thread.
 Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread));
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs
index 280c84dc0d..a03b3bb349 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs
@@ -27,21 +27,19 @@
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(new ChatClientAgentOptions
+    .AsAIAgent(new ChatClientAgentOptions
     {
         ChatOptions = new() { Instructions = "You are good at telling jokes." },
         Name = "Joker",
-        ChatMessageStoreFactory = ctx =>
-        {
+        ChatMessageStoreFactory = (ctx, ct) => new ValueTask<ChatMessageStore>(
             // Create a new chat message store for this agent that stores the messages in a vector store.
             // Each thread must get its own copy of the VectorChatMessageStore, since the store
             // also contains the id that the thread is stored under.
-            return new VectorChatMessageStore(vectorStore, ctx.SerializedState, ctx.JsonSerializerOptions);
-        }
+            new VectorChatMessageStore(vectorStore, ctx.SerializedState, ctx.JsonSerializerOptions))
     });
 
 // Start a new thread for the agent conversation.
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 
 // Run the agent with the thread that stores conversation history in the vector store.
 Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread));
@@ -58,7 +56,7 @@
 // and loaded again later.
 
 // Deserialize the thread state after loading from storage.
-AgentThread resumedThread = agent.DeserializeThread(serializedThread);
+AgentThread resumedThread = await agent.DeserializeThreadAsync(serializedThread);
 
 // Run the agent with the thread that stores conversation history in the vector store a second time.
 Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread));
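Both persistence samples share one serialize/rehydrate cycle. A minimal sketch, assuming an `agent` built as above and that the thread's synchronous `Serialize()` method remains the counterpart to `DeserializeThreadAsync` (an assumption; the file name is illustrative):

```csharp
// Run a turn, then snapshot the thread state as JSON.
AgentThread thread = await agent.GetNewThreadAsync();
await agent.RunAsync("Tell me a joke about a pirate.", thread);
JsonElement serialized = thread.Serialize(); // assumed API, used by the pre-migration samples
await File.WriteAllTextAsync("thread-state.json", serialized.GetRawText());

// Later, possibly in a new process: reload the JSON and resume the conversation.
JsonElement reloaded = JsonElement.Parse(await File.ReadAllTextAsync("thread-state.json"));
AgentThread resumed = await agent.DeserializeThreadAsync(reloaded);
await agent.RunAsync("Tell it again as a limerick.", resumed);
```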
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step08_Observability/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step08_Observability/Program.cs
index e43e80d664..6a969d7512 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step08_Observability/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step08_Observability/Program.cs
@@ -29,7 +29,7 @@
 // Create the agent, and enable OpenTelemetry instrumentation.
 AIAgent agent = new AzureOpenAIClient(new Uri(endpoint), new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker")
+    .AsAIAgent(instructions: "You are good at telling jokes.", name: "Joker")
     .AsBuilder()
     .UseOpenTelemetry(sourceName: sourceName)
     .Build();
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step09_DependencyInjection/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step09_DependencyInjection/Program.cs
index d1b75d2fe5..ab0ac64e99 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step09_DependencyInjection/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step09_DependencyInjection/Program.cs
@@ -49,7 +49,7 @@ internal sealed class SampleService(AIAgent agent, IHostApplicationLifetime appL
     public async Task StartAsync(CancellationToken cancellationToken)
     {
         // Create a thread that will be used for the entirety of the service lifetime so that the user can ask follow up questions.
-        this._thread = agent.GetNewThread();
+        this._thread = await agent.GetNewThreadAsync(cancellationToken);
 
         _ = this.RunAsync(appLifetime.ApplicationStopping);
     }
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step11_UsingImages/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step11_UsingImages/Program.cs
index f534e4edd7..b517e3ee95 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step11_UsingImages/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step11_UsingImages/Program.cs
@@ -13,7 +13,7 @@
 var agent = new AzureOpenAIClient(new Uri(endpoint), new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(
+    .AsAIAgent(
         name: "VisionAgent",
         instructions: "You are a helpful agent that can analyze images");
 
@@ -22,7 +22,7 @@
     new UriContent("https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "image/jpeg")
 ]);
 
-var thread = agent.GetNewThread();
+var thread = await agent.GetNewThreadAsync();
 
 await foreach (var update in agent.RunStreamingAsync(message, thread))
 {
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step12_AsFunctionTool/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step12_AsFunctionTool/Program.cs
index 5e37ff4039..765174072f 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step12_AsFunctionTool/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step12_AsFunctionTool/Program.cs
@@ -21,7 +21,7 @@ static string GetWeather([Description("The location to get the weather for.")] s
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(
+    .AsAIAgent(
         instructions: "You answer questions about the weather.",
         name: "WeatherAgent",
         description: "An agent that answers questions about the weather.",
@@ -32,7 +32,7 @@ static string GetWeather([Description("The location to get the weather for.")] s
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(instructions: "You are a helpful assistant who responds in French.", tools: [weatherAgent.AsAIFunction()]);
+    .AsAIAgent(instructions: "You are a helpful assistant who responds in French.", tools: [weatherAgent.AsAIFunction()]);
 
 // Invoke the agent and output the text result.
 Console.WriteLine(await agent.RunAsync("What is the weather like in Amsterdam?"));
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step13_BackgroundResponsesWithToolsAndPersistence/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step13_BackgroundResponsesWithToolsAndPersistence/Program.cs
index 29dc347b4a..fee7b2900c 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step13_BackgroundResponsesWithToolsAndPersistence/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step13_BackgroundResponsesWithToolsAndPersistence/Program.cs
@@ -23,7 +23,7 @@
     new Uri(endpoint),
     new AzureCliCredential())
     .GetResponsesClient(deploymentName)
-    .CreateAIAgent(
+    .AsAIAgent(
         name: "SpaceNovelWriter",
         instructions: "You are a space novel writer. Always research relevant facts and generate character profiles for the main characters before writing novels." +
             "Write complete chapters without asking for approval or feedback. Do not ask the user about tone, style, pace, or format preferences - just write the novel based on the request.",
@@ -32,10 +32,10 @@
 // Enable background responses (only supported by {Azure}OpenAI Responses at this time).
 AgentRunOptions options = new() { AllowBackgroundResponses = true };
 
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 
 // Start the initial run.
-AgentRunResponse response = await agent.RunAsync("Write a very long novel about a team of astronauts exploring an uncharted galaxy.", thread, options);
+AgentResponse response = await agent.RunAsync("Write a very long novel about a team of astronauts exploring an uncharted galaxy.", thread, options);
 
 // Poll for background responses until complete.
 while (response.ContinuationToken is not null)
@@ -44,10 +44,10 @@
 
     await Task.Delay(TimeSpan.FromSeconds(10));
 
-    RestoreAgentState(agent, out thread, out ResponseContinuationToken? continuationToken);
+    var (restoredThread, continuationToken) = await RestoreAgentState(agent);
     options.ContinuationToken = continuationToken;
 
-    response = await agent.RunAsync(thread, options);
+    response = await agent.RunAsync(restoredThread, options);
 }
 
 Console.WriteLine(response.Text);
@@ -58,13 +58,15 @@ void PersistAgentState(AgentThread thread, ResponseContinuationToken? continuati
     stateStore["continuationToken"] = JsonSerializer.SerializeToElement(continuationToken, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ResponseContinuationToken)));
 }
 
-void RestoreAgentState(AIAgent agent, out AgentThread thread, out ResponseContinuationToken? continuationToken)
+async Task<(AgentThread Thread, ResponseContinuationToken? ContinuationToken)> RestoreAgentState(AIAgent agent)
 {
     JsonElement serializedThread = stateStore["thread"] ?? throw new InvalidOperationException("No serialized thread found in state store.");
     JsonElement? serializedToken = stateStore["continuationToken"];
 
-    thread = agent.DeserializeThread(serializedThread);
-    continuationToken = (ResponseContinuationToken?)serializedToken?.Deserialize(AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ResponseContinuationToken)));
+    AgentThread thread = await agent.DeserializeThreadAsync(serializedThread);
+    ResponseContinuationToken? continuationToken = (ResponseContinuationToken?)serializedToken?.Deserialize(AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ResponseContinuationToken)));
+
+    return (thread, continuationToken);
 }
 
 [Description("Researches relevant space facts and scientific information for writing a science fiction novel")]
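The background-response hunks all follow the same polling shape: start a run with `AllowBackgroundResponses`, then re-run with the `ContinuationToken` until it clears. A condensed sketch, assuming an `agent` created against a Responses client as in the sample above:

```csharp
// Kick off a long-running background response and poll until the continuation token clears.
AgentRunOptions options = new() { AllowBackgroundResponses = true };
AgentThread thread = await agent.GetNewThreadAsync();

AgentResponse response = await agent.RunAsync("Write a very long novel about otters in space.", thread, options);
while (response.ContinuationToken is { } token)
{
    await Task.Delay(TimeSpan.FromSeconds(10)); // give the service time to make progress
    options.ContinuationToken = token;          // resume from where the service left off
    response = await agent.RunAsync(thread, options);
}

Console.WriteLine(response.Text);
```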
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step14_Middleware/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step14_Middleware/Program.cs
index a0ca338297..7e8f9de058 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step14_Middleware/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step14_Middleware/Program.cs
@@ -45,7 +45,7 @@ static string GetDateTime()
     .Use(GuardrailMiddleware, null)
     .Build();
 
-var thread = middlewareEnabledAgent.GetNewThread();
+var thread = await middlewareEnabledAgent.GetNewThreadAsync();
 
 Console.WriteLine("\n\n=== Example 1: Wording Guardrail ===");
 var guardRailedResponse = await middlewareEnabledAgent.RunAsync("Tell me something harmful.");
@@ -131,7 +131,7 @@ static string GetDateTime()
 }
 
 // This middleware redacts PII information from input and output messages.
-async Task<AgentRunResponse> PIIMiddleware(IEnumerable<ChatMessage> messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
+async Task<AgentResponse> PIIMiddleware(IEnumerable<ChatMessage> messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
 {
     // Redact PII information from input messages
     var filteredMessages = FilterMessages(messages);
@@ -171,7 +171,7 @@ static string FilterPii(string content)
 }
 
 // This middleware enforces guardrails by redacting certain keywords from input and output messages.
-async Task<AgentRunResponse> GuardrailMiddleware(IEnumerable<ChatMessage> messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
+async Task<AgentResponse> GuardrailMiddleware(IEnumerable<ChatMessage> messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
 {
     // Redact keywords from input messages
     var filteredMessages = FilterMessages(messages);
@@ -208,7 +208,7 @@ static string FilterContent(string content)
 }
 
 // This middleware handles Human in the loop console interaction for any user approval required during function calling.
-async Task<AgentRunResponse> ConsolePromptingApprovalMiddleware(IEnumerable<ChatMessage> messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
+async Task<AgentResponse> ConsolePromptingApprovalMiddleware(IEnumerable<ChatMessage> messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
 {
     var response = await innerAgent.RunAsync(messages, thread, options, cancellationToken);
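After migration, an agent-run middleware is a delegate returning `Task<AgentResponse>` that wraps an inner agent, attached via `AsBuilder().Use(...)`. A minimal pass-through sketch, assuming an `agent` built as in these samples (the logging middleware and its name are illustrative):

```csharp
// A pass-through agent-run middleware with the post-migration signature: it can inspect or
// rewrite the input messages before delegating to the inner agent, and the response after.
async Task<AgentResponse> LoggingMiddleware(
    IEnumerable<ChatMessage> messages,
    AgentThread? thread,
    AgentRunOptions? options,
    AIAgent innerAgent,
    CancellationToken cancellationToken)
{
    Console.WriteLine($"Running with {messages.Count()} input message(s).");
    AgentResponse response = await innerAgent.RunAsync(messages, thread, options, cancellationToken);
    Console.WriteLine($"Agent produced {response.Messages.Count} message(s).");
    return response;
}

AIAgent loggingAgent = agent.AsBuilder().Use(LoggingMiddleware, null).Build();
```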
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step15_Plugins/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step15_Plugins/Program.cs
index 38cd20b8d6..54f977352b 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step15_Plugins/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step15_Plugins/Program.cs
@@ -31,7 +31,7 @@
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(
+    .AsAIAgent(
         instructions: "You are a helpful assistant that helps people find information.",
         name: "Assistant",
         tools: [.. serviceProvider.GetRequiredService().AsAITools()],
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs
index decf0de25a..a80dd0fed0 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs
@@ -20,14 +20,14 @@
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(new ChatClientAgentOptions
+    .AsAIAgent(new ChatClientAgentOptions
     {
         ChatOptions = new() { Instructions = "You are good at telling jokes." },
         Name = "Joker",
-        ChatMessageStoreFactory = ctx => new InMemoryChatMessageStore(new MessageCountingChatReducer(2), ctx.SerializedState, ctx.JsonSerializerOptions)
+        ChatMessageStoreFactory = (ctx, ct) => new ValueTask<ChatMessageStore>(new InMemoryChatMessageStore(new MessageCountingChatReducer(2), ctx.SerializedState, ctx.JsonSerializerOptions))
     });
 
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 
 // Invoke the agent and output the text result.
 Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread));
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step17_BackgroundResponses/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step17_BackgroundResponses/Program.cs
index 3e172a95b5..ae0151c8ab 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step17_BackgroundResponses/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step17_BackgroundResponses/Program.cs
@@ -14,15 +14,15 @@
     new Uri(endpoint),
     new AzureCliCredential())
     .GetResponsesClient(deploymentName)
-    .CreateAIAgent();
+    .AsAIAgent();
 
 // Enable background responses (only supported by OpenAI Responses at this time).
 AgentRunOptions options = new() { AllowBackgroundResponses = true };
 
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 
 // Start the initial run.
-AgentRunResponse response = await agent.RunAsync("Write a very long novel about otters in space.", thread, options);
+AgentResponse response = await agent.RunAsync("Write a very long novel about otters in space.", thread, options);
 
 // Poll until the response is complete.
 while (response.ContinuationToken is { } token)
@@ -41,11 +41,11 @@
 
 // Reset options and thread for streaming.
 options = new() { AllowBackgroundResponses = true };
-thread = agent.GetNewThread();
+thread = await agent.GetNewThreadAsync();
 
-AgentRunResponseUpdate? lastReceivedUpdate = null;
+AgentResponseUpdate? lastReceivedUpdate = null;
 
 // Start streaming.
-await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync("Write a very long novel about otters in space.", thread, options))
+await foreach (AgentResponseUpdate update in agent.RunStreamingAsync("Write a very long novel about otters in space.", thread, options))
 {
     // Output each update.
     Console.Write(update.Text);
@@ -63,7 +63,7 @@
 
 // Resume from interruption point.
 options.ContinuationToken = lastReceivedUpdate?.ContinuationToken;
-await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(thread, options))
+await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(thread, options))
 {
     // Output each update.
     Console.Write(update.Text);
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step18_DeepResearch/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step18_DeepResearch/Program.cs
index f6aa825a54..e36612d89b 100644
--- a/dotnet/samples/GettingStarted/Agents/Agent_Step18_DeepResearch/Program.cs
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step18_DeepResearch/Program.cs
@@ -39,7 +39,7 @@
 
 try
 {
-    AgentThread thread = agent.GetNewThread();
+    AgentThread thread = await agent.GetNewThreadAsync();
 
     await foreach (var response in agent.RunStreamingAsync(Task, thread))
     {
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step01.2_Running/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step01.2_Running/Program.cs
index 4d840d54ff..8ebceaea26 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step01.2_Running/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step01.2_Running/Program.cs
@@ -27,7 +27,7 @@
 AIAgent jokerAgent = aiProjectClient.GetAIAgent(agentVersion);
 
 // Invoke the agent with streaming support.
-await foreach (AgentRunResponseUpdate update in jokerAgent.RunStreamingAsync("Tell me a joke about a pirate."))
+await foreach (AgentResponseUpdate update in jokerAgent.RunStreamingAsync("Tell me a joke about a pirate."))
 {
     Console.WriteLine(update);
 }
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step02_MultiturnConversation/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step02_MultiturnConversation/Program.cs
index 3cbb0099ea..4ad3d86255 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step02_MultiturnConversation/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step02_MultiturnConversation/Program.cs
@@ -26,17 +26,17 @@
 AIAgent jokerAgent = aiProjectClient.GetAIAgent(agentVersion);
 
 // Invoke the agent with a multi-turn conversation, where the context is preserved in the thread object.
-AgentThread thread = jokerAgent.GetNewThread();
+AgentThread thread = await jokerAgent.GetNewThreadAsync();
 Console.WriteLine(await jokerAgent.RunAsync("Tell me a joke about a pirate.", thread));
 Console.WriteLine(await jokerAgent.RunAsync("Now add some emojis to the joke and tell it in the voice of a pirate's parrot.", thread));
 
 // Invoke the agent with a multi-turn conversation and streaming, where the context is preserved in the thread object.
-thread = jokerAgent.GetNewThread();
-await foreach (AgentRunResponseUpdate update in jokerAgent.RunStreamingAsync("Tell me a joke about a pirate.", thread))
+thread = await jokerAgent.GetNewThreadAsync();
+await foreach (AgentResponseUpdate update in jokerAgent.RunStreamingAsync("Tell me a joke about a pirate.", thread))
 {
     Console.WriteLine(update);
 }
-await foreach (AgentRunResponseUpdate update in jokerAgent.RunStreamingAsync("Now add some emojis to the joke and tell it in the voice of a pirate's parrot.", thread))
+await foreach (AgentResponseUpdate update in jokerAgent.RunStreamingAsync("Now add some emojis to the joke and tell it in the voice of a pirate's parrot.", thread))
 {
     Console.WriteLine(update);
 }
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step03_UsingFunctionTools/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step03_UsingFunctionTools/Program.cs
index 38c5a15d75..0f51bb8364 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step03_UsingFunctionTools/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step03_UsingFunctionTools/Program.cs
@@ -37,12 +37,12 @@ static string GetWeather([Description("The location to get the weather for.")] s
 var existingAgent = await aiProjectClient.GetAIAgentAsync(name: AssistantName, tools: [tool]);
 
 // Non-streaming agent interaction with function tools.
-AgentThread thread = existingAgent.GetNewThread();
+AgentThread thread = await existingAgent.GetNewThreadAsync();
 Console.WriteLine(await existingAgent.RunAsync("What is the weather like in Amsterdam?", thread));
 
 // Streaming agent interaction with function tools.
-thread = existingAgent.GetNewThread();
-await foreach (AgentRunResponseUpdate update in existingAgent.RunStreamingAsync("What is the weather like in Amsterdam?", thread))
+thread = await existingAgent.GetNewThreadAsync();
+await foreach (AgentResponseUpdate update in existingAgent.RunStreamingAsync("What is the weather like in Amsterdam?", thread))
 {
     Console.WriteLine(update);
 }
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step04_UsingFunctionToolsWithApprovals/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step04_UsingFunctionToolsWithApprovals/Program.cs
index 1b51d210cf..1eb140cedc 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step04_UsingFunctionToolsWithApprovals/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step04_UsingFunctionToolsWithApprovals/Program.cs
@@ -32,8 +32,8 @@ static string GetWeather([Description("The location to get the weather for.")] s
 
 // Call the agent with approval-required function tools.
 // The agent will request approval before invoking the function.
-AgentThread thread = agent.GetNewThread();
-AgentRunResponse response = await agent.RunAsync("What is the weather like in Amsterdam?", thread);
+AgentThread thread = await agent.GetNewThreadAsync();
+AgentResponse response = await agent.RunAsync("What is the weather like in Amsterdam?", thread);
 
 // Check if there are any user input requests (approvals needed).
 List<UserInputRequestContent> userInputRequests = response.UserInputRequests.ToList();
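The approval flow is two runs: the first surfaces `UserInputRequests`, the second sends the answers back on the same thread. A sketch of that loop, assuming the `agent` from the sample above; the `CreateResponse(true)` helper on the approval request content is an assumption based on how these samples answer approvals:

```csharp
// Run once, then answer any function-approval requests and run again with the answers.
AgentThread thread = await agent.GetNewThreadAsync();
AgentResponse response = await agent.RunAsync("What is the weather like in Amsterdam?", thread);

List<UserInputRequestContent> userInputRequests = response.UserInputRequests.ToList();
if (userInputRequests.Count > 0)
{
    // Approve every requested function call (a real app would prompt the user here).
    List<ChatMessage> approvals = userInputRequests
        .OfType<FunctionApprovalRequestContent>()
        .Select(request => new ChatMessage(ChatRole.User, [request.CreateResponse(true)])) // assumed helper
        .ToList();

    response = await agent.RunAsync(approvals, thread);
}

Console.WriteLine($"\nAgent: {response}");
```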
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step05_StructuredOutput/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step05_StructuredOutput/Program.cs
index ac05565836..aaeb90fafa 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step05_StructuredOutput/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step05_StructuredOutput/Program.cs
@@ -35,7 +35,7 @@
 });
 
 // Set PersonInfo as the type parameter of RunAsync method to specify the expected structured output from the agent and invoke the agent with some unstructured input.
-AgentRunResponse<PersonInfo> response = await agent.RunAsync<PersonInfo>("Please provide information about John Smith, who is a 35-year-old software engineer.");
+AgentResponse<PersonInfo> response = await agent.RunAsync<PersonInfo>("Please provide information about John Smith, who is a 35-year-old software engineer.");
 
 // Access the structured output via the Result property of the agent response.
 Console.WriteLine("Assistant Output:");
@@ -57,11 +57,11 @@
 });
 
 // Invoke the agent with some unstructured input while streaming, to extract the structured information from.
-IAsyncEnumerable<AgentRunResponseUpdate> updates = agentWithPersonInfo.RunStreamingAsync("Please provide information about John Smith, who is a 35-year-old software engineer.");
+IAsyncEnumerable<AgentResponseUpdate> updates = agentWithPersonInfo.RunStreamingAsync("Please provide information about John Smith, who is a 35-year-old software engineer.");
 
 // Assemble all the parts of the streamed output, since we can only deserialize once we have the full json,
 // then deserialize the response into the PersonInfo class.
-PersonInfo personInfo = (await updates.ToAgentRunResponseAsync()).Deserialize<PersonInfo>(JsonSerializerOptions.Web);
+PersonInfo personInfo = (await updates.ToAgentResponseAsync()).Deserialize<PersonInfo>(JsonSerializerOptions.Web);
 
 Console.WriteLine("Assistant Output:");
 Console.WriteLine($"Name: {personInfo.Name}");
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step06_PersistedConversations/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step06_PersistedConversations/Program.cs
index d404a814c0..7c1c65c01d 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step06_PersistedConversations/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step06_PersistedConversations/Program.cs
@@ -19,7 +19,7 @@
 AIAgent agent = await aiProjectClient.CreateAIAgentAsync(name: JokerName, model: deploymentName, instructions: JokerInstructions);
 
 // Start a new thread for the agent conversation.
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 
 // Run the agent with a new thread.
 Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread));
@@ -35,7 +35,7 @@
 JsonElement reloadedSerializedThread = JsonElement.Parse(await File.ReadAllTextAsync(tempFilePath))!;
 
 // Deserialize the thread state after loading from storage.
-AgentThread resumedThread = agent.DeserializeThread(reloadedSerializedThread);
+AgentThread resumedThread = await agent.DeserializeThreadAsync(reloadedSerializedThread);
 
 // Run the agent again with the resumed thread.
 Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread));
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step07_Observability/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step07_Observability/Program.cs
index eb011ba064..4ba3ee4d34 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step07_Observability/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step07_Observability/Program.cs
@@ -38,12 +38,12 @@
     .Build();
 
 // Invoke the agent and output the text result.
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread));
 
 // Invoke the agent with streaming support.
-thread = agent.GetNewThread();
-await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync("Tell me a joke about a pirate.", thread))
+thread = await agent.GetNewThreadAsync();
+await foreach (AgentResponseUpdate update in agent.RunStreamingAsync("Tell me a joke about a pirate.", thread))
 {
     Console.WriteLine(update);
 }
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step08_DependencyInjection/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step08_DependencyInjection/Program.cs
index 4bf4843d66..d6d1dd8d9f 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step08_DependencyInjection/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step08_DependencyInjection/Program.cs
@@ -42,7 +42,7 @@ internal sealed class SampleService(AIProjectClient client, AIAgent agent, IHost
     public async Task StartAsync(CancellationToken cancellationToken)
     {
         // Create a thread that will be used for the entirety of the service lifetime so that the user can ask follow up questions.
-        this._thread = agent.GetNewThread();
+        this._thread = await agent.GetNewThreadAsync(cancellationToken);
 
         _ = this.RunAsync(appLifetime.ApplicationStopping);
     }
@@ -65,7 +65,7 @@ public async Task RunAsync(CancellationToken cancellationToken)
         }
 
         // Stream the output to the console as it is generated.
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(input, this._thread, cancellationToken: cancellationToken))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(input, this._thread, cancellationToken: cancellationToken))
         {
             Console.Write(update);
         }
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step10_UsingImages/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step10_UsingImages/Program.cs
index a799fe46fb..fa841ca913 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step10_UsingImages/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step10_UsingImages/Program.cs
@@ -24,9 +24,9 @@
     new DataContent(File.ReadAllBytes("assets/walkway.jpg"), "image/jpeg")
 ]);
 
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 
-await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(message, thread))
+await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(message, thread))
 {
     Console.WriteLine(update);
 }
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step11_AsFunctionTool/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step11_AsFunctionTool/Program.cs
index 9fb589f5ce..e00b05d9cb 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step11_AsFunctionTool/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step11_AsFunctionTool/Program.cs
@@ -39,7 +39,7 @@ static string GetWeather([Description("The location to get the weather for.")] s
     tools: [weatherAgent.AsAIFunction()]);
 
 // Invoke the agent and output the text result.
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 Console.WriteLine(await agent.RunAsync("What is the weather like in Amsterdam?", thread));
 
 // Cleanup by agent name removes the agent versions created.
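`AsAIFunction()` lets one agent act as a function tool of another, as both `AsFunctionTool` samples show. A condensed sketch, assuming `endpoint` and `deploymentName` as in the samples above (the coordinator agent is illustrative):

```csharp
// Expose one agent as a function tool of another: the coordinator can call the
// weather agent as if it were any other AIFunction.
AIAgent weatherAgent = new AzureOpenAIClient(new Uri(endpoint), new AzureCliCredential())
    .GetChatClient(deploymentName)
    .AsAIAgent(
        instructions: "You answer questions about the weather.",
        name: "WeatherAgent",
        description: "An agent that answers questions about the weather.");

AIAgent coordinator = new AzureOpenAIClient(new Uri(endpoint), new AzureCliCredential())
    .GetChatClient(deploymentName)
    .AsAIAgent(
        instructions: "You are a helpful assistant who responds in French.",
        tools: [weatherAgent.AsAIFunction()]);

Console.WriteLine(await coordinator.RunAsync("What is the weather like in Amsterdam?"));
```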
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step12_Middleware/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step12_Middleware/Program.cs
index 0a00e9107c..0c6b76612c 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step12_Middleware/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step12_Middleware/Program.cs
@@ -49,21 +49,21 @@ static string GetDateTime()
     .Use(GuardrailMiddleware, null)
     .Build();
 
-AgentThread thread = middlewareEnabledAgent.GetNewThread();
+AgentThread thread = await middlewareEnabledAgent.GetNewThreadAsync();
 
 Console.WriteLine("\n\n=== Example 1: Wording Guardrail ===");
 
-AgentRunResponse guardRailedResponse = await middlewareEnabledAgent.RunAsync("Tell me something harmful.");
+AgentResponse guardRailedResponse = await middlewareEnabledAgent.RunAsync("Tell me something harmful.");
 Console.WriteLine($"Guard railed response: {guardRailedResponse}");
 
 Console.WriteLine("\n\n=== Example 2: PII detection ===");
 
-AgentRunResponse piiResponse = await middlewareEnabledAgent.RunAsync("My name is John Doe, call me at 123-456-7890 or email me at john@something.com");
+AgentResponse piiResponse = await middlewareEnabledAgent.RunAsync("My name is John Doe, call me at 123-456-7890 or email me at john@something.com");
 Console.WriteLine($"Pii filtered response: {piiResponse}");
 
 Console.WriteLine("\n\n=== Example 3: Agent function middleware ===");
 
 // Agent function middleware support is limited to agents that wraps a upstream ChatClientAgent or derived from it.
-AgentRunResponse functionCallResponse = await middlewareEnabledAgent.RunAsync("What's the current time and the weather in Seattle?", thread);
+AgentResponse functionCallResponse = await middlewareEnabledAgent.RunAsync("What's the current time and the weather in Seattle?", thread);
 Console.WriteLine($"Function calling response: {functionCallResponse}");
 
 // Special per-request middleware agent.
@@ -78,7 +78,7 @@ static string GetDateTime()
     tools: [new ApprovalRequiredAIFunction(AIFunctionFactory.Create(GetWeather, name: nameof(GetWeather)))]);
 
 // Using the ConsolePromptingApprovalMiddleware for a specific request to handle user approval during function calls.
-AgentRunResponse response = await humanInTheLoopAgent
+AgentResponse response = await humanInTheLoopAgent
     .AsBuilder()
     .Use(ConsolePromptingApprovalMiddleware, null)
     .Build()
@@ -113,7 +113,7 @@ static string GetDateTime()
 }
 
 // This middleware redacts PII information from input and output messages.
-async Task<AgentRunResponse> PIIMiddleware(IEnumerable<ChatMessage> messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
+async Task<AgentResponse> PIIMiddleware(IEnumerable<ChatMessage> messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
 {
     // Redact PII information from input messages
     var filteredMessages = FilterMessages(messages);
@@ -152,7 +152,7 @@ static string FilterPii(string content)
 }
 
 // This middleware enforces guardrails by redacting certain keywords from input and output messages.
-async Task<AgentRunResponse> GuardrailMiddleware(IEnumerable<ChatMessage> messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
+async Task<AgentResponse> GuardrailMiddleware(IEnumerable<ChatMessage> messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
 {
     // Redact keywords from input messages
     var filteredMessages = FilterMessages(messages);
@@ -189,9 +189,9 @@ static string FilterContent(string content)
 }
 
 // This middleware handles Human in the loop console interaction for any user approval required during function calling.
-async Task<AgentRunResponse> ConsolePromptingApprovalMiddleware(IEnumerable<ChatMessage> messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
+async Task<AgentResponse> ConsolePromptingApprovalMiddleware(IEnumerable<ChatMessage> messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
 {
-    AgentRunResponse response = await innerAgent.RunAsync(messages, thread, options, cancellationToken);
+    AgentResponse response = await innerAgent.RunAsync(messages, thread, options, cancellationToken);
 
     List<UserInputRequestContent> userInputRequests = response.UserInputRequests.ToList();
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step13_Plugins/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step13_Plugins/Program.cs
index b55f38b66b..0cd9674770 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step13_Plugins/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step13_Plugins/Program.cs
@@ -42,7 +42,7 @@
     services: serviceProvider);
 
 // Invoke the agent and output the text result.
-AgentThread thread = agent.GetNewThread();
+AgentThread thread = await agent.GetNewThreadAsync();
 Console.WriteLine(await agent.RunAsync("Tell me current time and weather in Seattle.", thread));
 
 // Cleanup by agent name removes the agent version created.
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step14_CodeInterpreter/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step14_CodeInterpreter/Program.cs
index 0f6f6ef2d9..858c678528 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step14_CodeInterpreter/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step14_CodeInterpreter/Program.cs
@@ -49,10 +49,10 @@
 // Either invoke option1 or option2 agent, should have same result
 // Option 1
-AgentRunResponse response = await agentOption1.RunAsync("I need to solve the equation sin(x) + x^2 = 42");
+AgentResponse response = await agentOption1.RunAsync("I need to solve the equation sin(x) + x^2 = 42");
 
 // Option 2
-// AgentRunResponse response = await agentOption2.RunAsync("I need to solve the equation sin(x) + x^2 = 42");
+// AgentResponse response = await agentOption2.RunAsync("I need to solve the equation sin(x) + x^2 = 42");
 
 // Get the CodeInterpreterToolCallContent
 CodeInterpreterToolCallContent? toolCallContent = response.Messages.SelectMany(m => m.Contents).OfType<CodeInterpreterToolCallContent>().FirstOrDefault();
diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step15_ComputerUse/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step15_ComputerUse/Program.cs
index ff4f57924a..9d8cab17a0 100644
--- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step15_ComputerUse/Program.cs
+++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step15_ComputerUse/Program.cs
@@ -83,7 +83,7 @@ private static async Task InvokeComputerUseAgentAsync(AIAgent agent)
         AllowBackgroundResponses = true,
     };
 
-    AgentThread thread = agent.GetNewThread();
+    AgentThread thread = await agent.GetNewThreadAsync();
 
     ChatMessage message = new(ChatRole.User,
     [
         new TextContent("I need you to help me search for 'OpenAI news'. Please type 'OpenAI news' and submit the search. Once you see search results, the task is complete."),
@@ -93,7 +93,7 @@ private static async Task InvokeComputerUseAgentAsync(AIAgent agent)
 
     // Initial request with screenshot - start with Bing search page
     Console.WriteLine("Starting computer automation session (initial screenshot: cua_browser_search.png)...");
-    AgentRunResponse runResponse = await agent.RunAsync(message, thread: thread, options: runOptions);
+    AgentResponse response = await agent.RunAsync(message, thread: thread, options: runOptions);
 
     // Main interaction loop
     const int MaxIterations = 10;
@@ -105,7 +105,7 @@ private static async Task InvokeComputerUseAgentAsync(AIAgent agent)
     while (true)
     {
         // Poll until the response is complete.
-        while (runResponse.ContinuationToken is { } token)
+        while (response.ContinuationToken is { } token)
         {
             // Wait before polling again.
             await Task.Delay(TimeSpan.FromSeconds(2));
@@ -113,10 +113,10 @@ private static async Task InvokeComputerUseAgentAsync(AIAgent agent)
 
             // Continue with the token.
             runOptions.ContinuationToken = token;
-            runResponse = await agent.RunAsync(thread, runOptions);
+            response = await agent.RunAsync(thread, runOptions);
         }
 
-        Console.WriteLine($"Agent response received (ID: {runResponse.ResponseId})");
+        Console.WriteLine($"Agent response received (ID: {response.ResponseId})");
 
         if (iteration >= MaxIterations)
         {
@@ -128,7 +128,7 @@ private static async Task InvokeComputerUseAgentAsync(AIAgent agent)
         Console.WriteLine($"\n--- Iteration {iteration} ---");
 
         // Check for computer calls in the response
-        IEnumerable<ComputerCallResponseItem> computerCallResponseItems = runResponse.Messages
+        IEnumerable<ComputerCallResponseItem> computerCallResponseItems = response.Messages
             .SelectMany(x => x.Contents)
             .Where(c => c.RawRepresentation is ComputerCallResponseItem and not null)
             .Select(c => (ComputerCallResponseItem)c.RawRepresentation!);
@@ -137,7 +137,7 @@ private static async Task InvokeComputerUseAgentAsync(AIAgent agent)
         if (firstComputerCall is null)
         {
             Console.WriteLine("No computer call actions found. Ending interaction.");
-            Console.WriteLine($"Final Response: {runResponse}");
+            Console.WriteLine($"Final Response: {response}");
             break;
         }
@@ -168,7 +168,7 @@ private static async Task InvokeComputerUseAgentAsync(AIAgent agent)
 
         // Follow-up message with action result and new screenshot
         message = new(ChatRole.User, [content]);
-        runResponse = await agent.RunAsync(message, thread: thread, options: runOptions);
+        response = await agent.RunAsync(message, thread: thread, options: runOptions);
     }
 }
diff --git a/dotnet/samples/GettingStarted/ModelContextProtocol/Agent_MCP_Server/Program.cs b/dotnet/samples/GettingStarted/ModelContextProtocol/Agent_MCP_Server/Program.cs
index 774c33ed58..568830bb04 100644
--- a/dotnet/samples/GettingStarted/ModelContextProtocol/Agent_MCP_Server/Program.cs
+++ b/dotnet/samples/GettingStarted/ModelContextProtocol/Agent_MCP_Server/Program.cs
@@ -27,7 +27,7 @@
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(instructions: "You answer questions related to GitHub repositories only.", tools: [.. mcpTools.Cast<AITool>()]);
+    .AsAIAgent(instructions: "You answer questions related to GitHub repositories only.", tools: [.. mcpTools.Cast<AITool>()]);
 
 // Invoke the agent and output the text result.
 Console.WriteLine(await agent.RunAsync("Summarize the last four commits to the microsoft/semantic-kernel repository?"));
diff --git a/dotnet/samples/GettingStarted/ModelContextProtocol/Agent_MCP_Server_Auth/Program.cs b/dotnet/samples/GettingStarted/ModelContextProtocol/Agent_MCP_Server_Auth/Program.cs
index c9197c573c..1a08945680 100644
--- a/dotnet/samples/GettingStarted/ModelContextProtocol/Agent_MCP_Server_Auth/Program.cs
+++ b/dotnet/samples/GettingStarted/ModelContextProtocol/Agent_MCP_Server_Auth/Program.cs
@@ -50,7 +50,7 @@
     new Uri(endpoint),
     new AzureCliCredential())
     .GetChatClient(deploymentName)
-    .CreateAIAgent(instructions: "You answer questions related to the weather.", tools: [.. mcpTools]);
+    .AsAIAgent(instructions: "You answer questions related to the weather.", tools: [.. mcpTools]);
 
 // Invoke the agent and output the text result.
 Console.WriteLine(await agent.RunAsync("Get current weather alerts for New York?"));
-var threadWithRequiredApproval = agentWithRequiredApproval.GetNewThread(); +var threadWithRequiredApproval = await agentWithRequiredApproval.GetNewThreadAsync(); var response = await agentWithRequiredApproval.RunAsync("Please summarize the Azure AI Agent documentation related to MCP Tool calling?", threadWithRequiredApproval); var userInputRequests = response.UserInputRequests.ToList(); diff --git a/dotnet/samples/GettingStarted/ModelContextProtocol/ResponseAgent_Hosted_MCP/Program.cs b/dotnet/samples/GettingStarted/ModelContextProtocol/ResponseAgent_Hosted_MCP/Program.cs index 13ee28d6a1..986e7d0977 100644 --- a/dotnet/samples/GettingStarted/ModelContextProtocol/ResponseAgent_Hosted_MCP/Program.cs +++ b/dotnet/samples/GettingStarted/ModelContextProtocol/ResponseAgent_Hosted_MCP/Program.cs @@ -31,13 +31,13 @@ new Uri(endpoint), new AzureCliCredential()) .GetResponsesClient(deploymentName) - .CreateAIAgent( + .AsAIAgent( instructions: "You answer questions by searching the Microsoft Learn content only.", name: "MicrosoftLearnAgent", tools: [mcpTool]); // You can then invoke the agent like any other AIAgent. -AgentThread thread = agent.GetNewThread(); +AgentThread thread = await agent.GetNewThreadAsync(); Console.WriteLine(await agent.RunAsync("Please summarize the Azure AI Agent documentation related to MCP Tool calling?", thread)); // **** MCP Tool with Approval Required **** @@ -58,13 +58,13 @@ new Uri(endpoint), new AzureCliCredential()) .GetResponsesClient(deploymentName) - .CreateAIAgent( + .AsAIAgent( instructions: "You answer questions by searching the Microsoft Learn content only.", name: "MicrosoftLearnAgentWithApproval", tools: [mcpToolWithApproval]); // You can then invoke the agent like any other AIAgent. -var threadWithRequiredApproval = agentWithRequiredApproval.GetNewThread(); +var threadWithRequiredApproval = await agentWithRequiredApproval.GetNewThreadAsync(); var response = await agentWithRequiredApproval.RunAsync("Please summarize the Azure AI Agent documentation related to MCP Tool calling?", threadWithRequiredApproval); var userInputRequests = response.UserInputRequests.ToList(); diff --git a/dotnet/samples/GettingStarted/Workflows/Agents/CustomAgentExecutors/Program.cs b/dotnet/samples/GettingStarted/Workflows/Agents/CustomAgentExecutors/Program.cs index 91f58f460e..594f447e8c 100644 --- a/dotnet/samples/GettingStarted/Workflows/Agents/CustomAgentExecutors/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/Agents/CustomAgentExecutors/Program.cs @@ -109,7 +109,7 @@ internal sealed class SloganGeneratedEvent(SloganResult sloganResult) : Workflow internal sealed class SloganWriterExecutor : Executor { private readonly AIAgent _agent; - private readonly AgentThread _thread; + private AgentThread? _thread; /// /// Initializes a new instance of the class. 
@@ -128,7 +128,6 @@ public SloganWriterExecutor(string id, IChatClient chatClient) : base(id) }; this._agent = new ChatClientAgent(chatClient, agentOptions); - this._thread = this._agent.GetNewThread(); } protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => @@ -137,6 +136,8 @@ protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => public async ValueTask HandleAsync(string message, IWorkflowContext context, CancellationToken cancellationToken = default) { + this._thread ??= await this._agent.GetNewThreadAsync(cancellationToken); + var result = await this._agent.RunAsync(message, this._thread, cancellationToken: cancellationToken); var sloganResult = JsonSerializer.Deserialize<SloganResult>(result.Text) ?? throw new InvalidOperationException("Failed to deserialize slogan result."); @@ -179,7 +180,7 @@ internal sealed class FeedbackEvent(FeedbackResult feedbackResult) : WorkflowEve internal sealed class FeedbackExecutor : Executor { private readonly AIAgent _agent; - private readonly AgentThread _thread; + private AgentThread? _thread; public int MinimumRating { get; init; } = 8; @@ -204,11 +205,12 @@ public FeedbackExecutor(string id, IChatClient chatClient) : base(id) }; this._agent = new ChatClientAgent(chatClient, agentOptions); - this._thread = this._agent.GetNewThread(); } public override async ValueTask HandleAsync(SloganResult message, IWorkflowContext context, CancellationToken cancellationToken = default) { + this._thread ??= await this._agent.GetNewThreadAsync(cancellationToken); + var sloganMessage = $""" Here is a slogan for the task '{message.Task}': Slogan: {message.Slogan} diff --git a/dotnet/samples/GettingStarted/Workflows/Agents/FoundryAgent/Program.cs b/dotnet/samples/GettingStarted/Workflows/Agents/FoundryAgent/Program.cs index 9f1de87438..35809685ea 100644 --- a/dotnet/samples/GettingStarted/Workflows/Agents/FoundryAgent/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/Agents/FoundryAgent/Program.cs @@ -45,7 +45,7 @@ private static async Task Main() await run.TrySendMessageAsync(new TurnToken(emitEvents: true)); await foreach (WorkflowEvent evt in run.WatchStreamAsync()) { - if (evt is AgentRunUpdateEvent executorComplete) + if (evt is AgentResponseUpdateEvent executorComplete) { Console.WriteLine($"{executorComplete.ExecutorId}: {executorComplete.Data}"); } diff --git a/dotnet/samples/GettingStarted/Workflows/Agents/WorkflowAsAnAgent/Program.cs b/dotnet/samples/GettingStarted/Workflows/Agents/WorkflowAsAnAgent/Program.cs index 6aa65d56b5..4ba04bd87f 100644 --- a/dotnet/samples/GettingStarted/Workflows/Agents/WorkflowAsAnAgent/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/Agents/WorkflowAsAnAgent/Program.cs @@ -37,7 +37,7 @@ private static async Task Main() // Create the workflow and turn it into an agent var workflow = WorkflowFactory.BuildWorkflow(chatClient); var agent = workflow.AsAgent("workflow-agent", "Workflow Agent"); - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); // Start an interactive loop to interact with the workflow as if it were an agent while (true) @@ -58,8 +58,8 @@ private static async Task Main() // re-render all messages on each update.
static async Task ProcessInputAsync(AIAgent agent, AgentThread thread, string input) { - Dictionary<string, List<AgentRunResponseUpdate>> buffer = []; - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(input, thread)) + Dictionary<string, List<AgentResponseUpdate>> buffer = []; + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(input, thread)) { if (update.MessageId is null || string.IsNullOrEmpty(update.Text)) { @@ -68,7 +68,7 @@ static async Task ProcessInputAsync(AIAgent agent, AgentThread thread, string in } Console.Clear(); - if (!buffer.TryGetValue(update.MessageId, out List<AgentRunResponseUpdate>? value)) + if (!buffer.TryGetValue(update.MessageId, out List<AgentResponseUpdate>? value)) { value = []; buffer[update.MessageId] = value; diff --git a/dotnet/samples/GettingStarted/Workflows/Declarative/ExecuteCode/Generated.cs b/dotnet/samples/GettingStarted/Workflows/Declarative/ExecuteCode/Generated.cs index 6fc508064e..49a6ced2b7 100644 --- a/dotnet/samples/GettingStarted/Workflows/Declarative/ExecuteCode/Generated.cs +++ b/dotnet/samples/GettingStarted/Workflows/Declarative/ExecuteCode/Generated.cs @@ -65,7 +65,7 @@ internal sealed class QuestionStudentExecutor(FormulaSession session, WorkflowAg bool autoSend = true; IList<ChatMessage>? inputMessages = null; - AgentRunResponse agentResponse = + AgentResponse agentResponse = await InvokeAgentAsync( context, agentName, @@ -76,7 +76,7 @@ await InvokeAgentAsync( if (autoSend) { - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, agentResponse)).ConfigureAwait(false); + await context.AddEventAsync(new AgentResponseEvent(this.Id, agentResponse)).ConfigureAwait(false); } return default; @@ -102,7 +102,7 @@ internal sealed class QuestionTeacherExecutor(FormulaSession session, WorkflowAg bool autoSend = false; IList<ChatMessage>? inputMessages = null; - AgentRunResponse agentResponse = + AgentResponse agentResponse = await InvokeAgentAsync( context, agentName, @@ -113,7 +113,7 @@ await InvokeAgentAsync( if (autoSend) { - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, agentResponse)).ConfigureAwait(false); + await context.AddEventAsync(new AgentResponseEvent(this.Id, agentResponse)).ConfigureAwait(false); } await context.QueueStateUpdateAsync(key: "TeacherResponse", value: agentResponse.Messages, scopeName: "Local").ConfigureAwait(false); @@ -175,8 +175,8 @@ await context.FormatTemplateAsync( GOLD STAR! """ ); - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false); + AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); + await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false); return default; } @@ -196,8 +196,8 @@ await context.FormatTemplateAsync( Let's try again later...
""" ); - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false); + AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); + await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false); return default; } diff --git a/dotnet/samples/GettingStarted/Workflows/Declarative/HostedWorkflow/Program.cs b/dotnet/samples/GettingStarted/Workflows/Declarative/HostedWorkflow/Program.cs index ff45cbc0c2..ee5c229f3d 100644 --- a/dotnet/samples/GettingStarted/Workflows/Declarative/HostedWorkflow/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/Declarative/HostedWorkflow/Program.cs @@ -47,7 +47,7 @@ public static async Task Main(string[] args) AIAgent agent = aiProjectClient.GetAIAgent(agentVersion); - AgentThread thread = agent.GetNewThread(); + AgentThread thread = await agent.GetNewThreadAsync(); ProjectConversation conversation = await aiProjectClient @@ -65,10 +65,10 @@ await aiProjectClient }; ChatClientAgentRunOptions runOptions = new(chatOptions); - IAsyncEnumerable agentResponseUpdates = agent.RunStreamingAsync(workflowInput, thread, runOptions); + IAsyncEnumerable agentResponseUpdates = agent.RunStreamingAsync(workflowInput, thread, runOptions); string? lastMessageId = null; - await foreach (AgentRunResponseUpdate responseUpdate in agentResponseUpdates) + await foreach (AgentResponseUpdate responseUpdate in agentResponseUpdates) { if (responseUpdate.MessageId != lastMessageId) { diff --git a/dotnet/samples/GettingStarted/Workflows/Observability/WorkflowAsAnAgent/Program.cs b/dotnet/samples/GettingStarted/Workflows/Observability/WorkflowAsAnAgent/Program.cs index 17d7d03b3f..bf9f17ac80 100644 --- a/dotnet/samples/GettingStarted/Workflows/Observability/WorkflowAsAnAgent/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/Observability/WorkflowAsAnAgent/Program.cs @@ -90,7 +90,7 @@ private static async Task Main() { EnableSensitiveData = true // enable sensitive data at the agent level such as prompts and responses }; - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); // Start an interactive loop to interact with the workflow as if it were an agent while (true) @@ -111,8 +111,8 @@ private static async Task Main() // re-render all messages on each update. static async Task ProcessInputAsync(AIAgent agent, AgentThread thread, string input) { - Dictionary> buffer = []; - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(input, thread)) + Dictionary> buffer = []; + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(input, thread)) { if (update.MessageId is null || string.IsNullOrEmpty(update.Text)) { @@ -121,7 +121,7 @@ static async Task ProcessInputAsync(AIAgent agent, AgentThread thread, string in } Console.Clear(); - if (!buffer.TryGetValue(update.MessageId, out List? value)) + if (!buffer.TryGetValue(update.MessageId, out List? 
value)) { value = []; buffer[update.MessageId] = value; diff --git a/dotnet/samples/GettingStarted/Workflows/_Foundational/03_AgentsInWorkflows/Program.cs b/dotnet/samples/GettingStarted/Workflows/_Foundational/03_AgentsInWorkflows/Program.cs index 0a8ee0d6ee..4e61b5def6 100644 --- a/dotnet/samples/GettingStarted/Workflows/_Foundational/03_AgentsInWorkflows/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/_Foundational/03_AgentsInWorkflows/Program.cs @@ -52,7 +52,7 @@ private static async Task Main() await run.TrySendMessageAsync(new TurnToken(emitEvents: true)); await foreach (WorkflowEvent evt in run.WatchStreamAsync()) { - if (evt is AgentRunUpdateEvent executorComplete) + if (evt is AgentResponseUpdateEvent executorComplete) { Console.WriteLine($"{executorComplete.ExecutorId}: {executorComplete.Data}"); } diff --git a/dotnet/samples/GettingStarted/Workflows/_Foundational/04_AgentWorkflowPatterns/Program.cs b/dotnet/samples/GettingStarted/Workflows/_Foundational/04_AgentWorkflowPatterns/Program.cs index 1fa3aabb5c..225f11b59a 100644 --- a/dotnet/samples/GettingStarted/Workflows/_Foundational/04_AgentWorkflowPatterns/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/_Foundational/04_AgentWorkflowPatterns/Program.cs @@ -88,7 +88,7 @@ static async Task> RunWorkflowAsync(Workflow workflow, List HandleAsyncCoreAsync( Console.WriteLine($"\n=== Writer (Iteration {state.Iteration}) ===\n"); StringBuilder sb = new(); - await foreach (AgentRunResponseUpdate update in this._agent.RunStreamingAsync(message, cancellationToken: cancellationToken)) + await foreach (AgentResponseUpdate update in this._agent.RunStreamingAsync(message, cancellationToken: cancellationToken)) { if (!string.IsNullOrEmpty(update.Text)) { @@ -313,10 +313,10 @@ public override async ValueTask HandleAsync( Console.WriteLine($"=== Critic (Iteration {state.Iteration}) ===\n"); // Use RunStreamingAsync to get streaming updates, then deserialize at the end - IAsyncEnumerable<AgentRunResponseUpdate> updates = this._agent.RunStreamingAsync(message, cancellationToken: cancellationToken); + IAsyncEnumerable<AgentResponseUpdate> updates = this._agent.RunStreamingAsync(message, cancellationToken: cancellationToken); // Stream the output in real-time (for any rationale/explanation) - await foreach (AgentRunResponseUpdate update in updates) + await foreach (AgentResponseUpdate update in updates) { if (!string.IsNullOrEmpty(update.Text)) { @@ -326,7 +326,7 @@ public override async ValueTask HandleAsync( Console.WriteLine("\n"); // Convert the stream to a response and deserialize the structured output - AgentRunResponse response = await updates.ToAgentRunResponseAsync(cancellationToken); + AgentResponse response = await updates.ToAgentResponseAsync(cancellationToken); CriticDecision decision = response.Deserialize<CriticDecision>(JsonSerializerOptions.Web); Console.WriteLine($"Decision: {(decision.Approved ? 
"✅ APPROVED" : "❌ NEEDS REVISION")}"); @@ -394,7 +394,7 @@ public override async ValueTask HandleAsync( string prompt = $"Present this approved content:\n\n{message.Content}"; StringBuilder sb = new(); - await foreach (AgentRunResponseUpdate update in this._agent.RunStreamingAsync(new ChatMessage(ChatRole.User, prompt), cancellationToken: cancellationToken)) + await foreach (AgentResponseUpdate update in this._agent.RunStreamingAsync(new ChatMessage(ChatRole.User, prompt), cancellationToken: cancellationToken)) { if (!string.IsNullOrEmpty(update.Text)) { diff --git a/dotnet/samples/M365Agent/AFAgentApplication.cs b/dotnet/samples/M365Agent/AFAgentApplication.cs index 04aabf96a8..0962c1d6d2 100644 --- a/dotnet/samples/M365Agent/AFAgentApplication.cs +++ b/dotnet/samples/M365Agent/AFAgentApplication.cs @@ -43,22 +43,22 @@ private async Task MessageActivityAsync(ITurnContext turnContext, ITurnState tur // Deserialize the conversation history into an AgentThread, or create a new one if none exists. AgentThread agentThread = threadElementStart.ValueKind is not JsonValueKind.Undefined and not JsonValueKind.Null - ? this._agent.DeserializeThread(threadElementStart, JsonUtilities.DefaultOptions) - : this._agent.GetNewThread(); + ? await this._agent.DeserializeThreadAsync(threadElementStart, JsonUtilities.DefaultOptions, cancellationToken) + : await this._agent.GetNewThreadAsync(cancellationToken); ChatMessage chatMessage = HandleUserInput(turnContext); // Invoke the WeatherForecastAgent to process the message - AgentRunResponse agentRunResponse = await this._agent.RunAsync(chatMessage, agentThread, cancellationToken: cancellationToken); + AgentResponse agentResponse = await this._agent.RunAsync(chatMessage, agentThread, cancellationToken: cancellationToken); // Check for any user input requests in the response // and turn them into adaptive cards in the streaming response. List? attachments = null; - HandleUserInputRequests(agentRunResponse, ref attachments); + HandleUserInputRequests(agentResponse, ref attachments); // Check for Adaptive Card content in the response messages // and return them appropriately in the response. - var adaptiveCards = agentRunResponse.Messages.SelectMany(x => x.Contents).OfType().ToList(); + var adaptiveCards = agentResponse.Messages.SelectMany(x => x.Contents).OfType().ToList(); if (adaptiveCards.Count > 0) { attachments ??= []; @@ -70,7 +70,7 @@ private async Task MessageActivityAsync(ITurnContext turnContext, ITurnState tur } else { - turnContext.StreamingResponse.QueueTextChunk(agentRunResponse.Text); + turnContext.StreamingResponse.QueueTextChunk(agentResponse.Text); } // If created any adaptive cards, add them to the final message. @@ -134,9 +134,9 @@ private static ChatMessage HandleUserInput(ITurnContext turnContext) /// When the agent returns any user input requests, this method converts them into adaptive cards that /// asks the user to approve or deny the requests. /// - /// The that may contain the user input requests. + /// The that may contain the user input requests. /// The list of to which the adaptive cards will be added. - private static void HandleUserInputRequests(AgentRunResponse response, ref List? attachments) + private static void HandleUserInputRequests(AgentResponse response, ref List? 
attachments) { var userInputRequests = response.UserInputRequests.ToList(); if (userInputRequests.Count > 0) diff --git a/dotnet/samples/M365Agent/Agents/WeatherForecastAgent.cs b/dotnet/samples/M365Agent/Agents/WeatherForecastAgent.cs index ff7af20ba9..a4023a2c58 100644 --- a/dotnet/samples/M365Agent/Agents/WeatherForecastAgent.cs +++ b/dotnet/samples/M365Agent/Agents/WeatherForecastAgent.cs @@ -48,7 +48,7 @@ public WeatherForecastAgent(IChatClient chatClient) { } - protected override async Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + protected override async Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { var response = await base.RunCoreAsync(messages, thread, options, cancellationToken); diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs index 96a8856dea..01e7c8f146 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs @@ -52,27 +52,27 @@ public A2AAgent(A2AClient a2aClient, string? id = null, string? name = null, str } /// - public sealed override AgentThread GetNewThread() - => new A2AAgentThread(); + public sealed override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) + => new(new A2AAgentThread()); /// /// Get a new instance using an existing context id, to continue that conversation. /// /// The context id to continue. - /// A new instance. - public AgentThread GetNewThread(string contextId) - => new A2AAgentThread() { ContextId = contextId }; + /// A value task representing the asynchronous operation. The task result contains a new instance. + public ValueTask GetNewThreadAsync(string contextId) + => new(new A2AAgentThread() { ContextId = contextId }); /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) - => new A2AAgentThread(serializedThread, jsonSerializerOptions); + public override ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + => new(new A2AAgentThread(serializedThread, jsonSerializerOptions)); /// - protected override async Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + protected override async Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { _ = Throw.IfNull(messages); - A2AAgentThread typedThread = this.GetA2AThread(thread, options); + A2AAgentThread typedThread = await this.GetA2AThreadAsync(thread, options, cancellationToken).ConfigureAwait(false); this._logger.LogA2AAgentInvokingAgent(nameof(RunAsync), this.Id, this.Name); @@ -99,7 +99,7 @@ protected override async Task RunCoreAsync(IEnumerable RunCoreAsync(IEnumerable RunCoreAsync(IEnumerable - protected override async IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + protected override async IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? 
options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { _ = Throw.IfNull(messages); - A2AAgentThread typedThread = this.GetA2AThread(thread, options); + A2AAgentThread typedThread = await this.GetA2AThreadAsync(thread, options, cancellationToken).ConfigureAwait(false); this._logger.LogA2AAgentInvokingAgent(nameof(RunStreamingAsync), this.Id, this.Name); @@ -211,7 +211,7 @@ protected override async IAsyncEnumerable RunCoreStreami /// public override string? Description => this._description; - private A2AAgentThread GetA2AThread(AgentThread? thread, AgentRunOptions? options) + private async ValueTask GetA2AThreadAsync(AgentThread? thread, AgentRunOptions? options, CancellationToken cancellationToken) { // Aligning with other agent implementations that support background responses, where // a thread is required for background responses to prevent inconsistent experience @@ -221,7 +221,7 @@ private A2AAgentThread GetA2AThread(AgentThread? thread, AgentRunOptions? option throw new InvalidOperationException("A thread must be provided when AllowBackgroundResponses is enabled."); } - thread ??= this.GetNewThread(); + thread ??= await this.GetNewThreadAsync(cancellationToken).ConfigureAwait(false); if (thread is not A2AAgentThread typedThread) { @@ -291,9 +291,9 @@ private static AgentMessage CreateA2AMessage(A2AAgentThread typedThread, IEnumer return null; } - private AgentRunResponseUpdate ConvertToAgentResponseUpdate(AgentMessage message) + private AgentResponseUpdate ConvertToAgentResponseUpdate(AgentMessage message) { - return new AgentRunResponseUpdate + return new AgentResponseUpdate { AgentId = this.Id, ResponseId = message.MessageId, @@ -305,9 +305,9 @@ private AgentRunResponseUpdate ConvertToAgentResponseUpdate(AgentMessage message }; } - private AgentRunResponseUpdate ConvertToAgentResponseUpdate(AgentTask task) + private AgentResponseUpdate ConvertToAgentResponseUpdate(AgentTask task) { - return new AgentRunResponseUpdate + return new AgentResponseUpdate { AgentId = this.Id, ResponseId = task.Id, @@ -318,9 +318,9 @@ private AgentRunResponseUpdate ConvertToAgentResponseUpdate(AgentTask task) }; } - private AgentRunResponseUpdate ConvertToAgentResponseUpdate(TaskUpdateEvent taskUpdateEvent) + private AgentResponseUpdate ConvertToAgentResponseUpdate(TaskUpdateEvent taskUpdateEvent) { - AgentRunResponseUpdate responseUpdate = new() + AgentResponseUpdate responseUpdate = new() { AgentId = this.Id, ResponseId = taskUpdateEvent.TaskId, diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2AAgentCardExtensions.cs b/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2AAgentCardExtensions.cs index 39d7107430..1998d020b5 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2AAgentCardExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2AAgentCardExtensions.cs @@ -27,11 +27,11 @@ public static class A2AAgentCardExtensions /// The to use for HTTP requests. /// The logger factory for enabling logging within the agent. /// An instance backed by the A2A agent. - public static AIAgent GetAIAgent(this AgentCard card, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) + public static AIAgent AsAIAgent(this AgentCard card, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) { // Create the A2A client using the agent URL from the card. 
var a2aClient = new A2AClient(new Uri(card.Url), httpClient); - return a2aClient.GetAIAgent(name: card.Name, description: card.Description, loggerFactory: loggerFactory); + return a2aClient.AsAIAgent(name: card.Name, description: card.Description, loggerFactory: loggerFactory); } } diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2ACardResolverExtensions.cs b/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2ACardResolverExtensions.cs index 2da58222b8..6a32822fea 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2ACardResolverExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2ACardResolverExtensions.cs @@ -42,6 +42,6 @@ public static async Task GetAIAgentAsync(this A2ACardResolver resolver, // Obtain the agent card from the resolver. var agentCard = await resolver.GetAgentCardAsync(cancellationToken).ConfigureAwait(false); - return agentCard.GetAIAgent(httpClient, loggerFactory); + return agentCard.AsAIAgent(httpClient, loggerFactory); } } diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2AClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2AClientExtensions.cs index d57ed4cb42..cd93ca0bac 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2AClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2AClientExtensions.cs @@ -35,6 +35,6 @@ public static class A2AClientExtensions /// The description of the agent. /// Optional logger factory for enabling logging within the agent. /// An instance backed by the A2A agent. - public static AIAgent GetAIAgent(this A2AClient client, string? id = null, string? name = null, string? description = null, ILoggerFactory? loggerFactory = null) => + public static AIAgent AsAIAgent(this A2AClient client, string? id = null, string? name = null, string? description = null, ILoggerFactory? loggerFactory = null) => new A2AAgent(client, id, name, description, loggerFactory); } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs index afed5d1518..10284dc25a 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs @@ -105,7 +105,8 @@ public abstract class AIAgent /// /// Creates a new conversation thread that is compatible with this agent. /// - /// A new instance ready for use with this agent. + /// The to monitor for cancellation requests. The default is . + /// A value task that represents the asynchronous operation. The task result contains a new instance ready for use with this agent. /// /// /// This method creates a fresh conversation thread that can be used to maintain state @@ -118,14 +119,15 @@ public abstract class AIAgent /// may be deferred until first use to optimize performance. /// /// - public abstract AgentThread GetNewThread(); + public abstract ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default); /// /// Deserializes an agent thread from its JSON serialized representation. /// /// A containing the serialized thread state. /// Optional settings to customize the deserialization process. - /// A restored instance with the state from . + /// The to monitor for cancellation requests. The default is . + /// A value task that represents the asynchronous operation. The task result contains a restored instance with the state from . /// The is not in the expected format. /// The serialized data is invalid or cannot be deserialized. 
/// @@ -133,7 +135,7 @@ public abstract class AIAgent /// allowing conversations to resume across application restarts or be migrated between /// different agent instances. /// - public abstract AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null); + public abstract ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default); /// /// Run the agent with no message assuming that all required instructions are already provided to the agent or on the thread. @@ -144,12 +146,12 @@ public abstract class AIAgent /// /// Optional configuration parameters for controlling the agent's invocation behavior. /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation. The task result contains an with the agent's output. + /// A task that represents the asynchronous operation. The task result contains an with the agent's output. /// /// This overload is useful when the agent has sufficient context from previous messages in the thread /// or from its initial configuration to generate a meaningful response without additional input. /// - public Task RunAsync( + public Task RunAsync( AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) => @@ -165,13 +167,13 @@ public Task RunAsync( /// /// Optional configuration parameters for controlling the agent's invocation behavior. /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation. The task result contains an with the agent's output. + /// A task that represents the asynchronous operation. The task result contains an with the agent's output. /// is , empty, or contains only whitespace. /// /// The provided text will be wrapped in a with the role /// before being sent to the agent. This is a convenience method for simple text-based interactions. /// - public Task RunAsync( + public Task RunAsync( string message, AgentThread? thread = null, AgentRunOptions? options = null, @@ -192,9 +194,9 @@ public Task RunAsync( /// /// Optional configuration parameters for controlling the agent's invocation behavior. /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation. The task result contains an with the agent's output. + /// A task that represents the asynchronous operation. The task result contains an with the agent's output. /// is . - public Task RunAsync( + public Task RunAsync( ChatMessage message, AgentThread? thread = null, AgentRunOptions? options = null, @@ -215,7 +217,7 @@ public Task RunAsync( /// /// Optional configuration parameters for controlling the agent's invocation behavior. /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation. The task result contains an with the agent's output. + /// A task that represents the asynchronous operation. The task result contains an with the agent's output. /// /// /// This method delegates to to perform the actual agent invocation. It handles collections of messages, @@ -227,7 +229,7 @@ public Task RunAsync( /// The agent's response will also be added to if one is provided. /// /// - public Task RunAsync( + public Task RunAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? 
options = null, @@ -244,7 +246,7 @@ public Task RunAsync( /// /// Optional configuration parameters for controlling the agent's invocation behavior. /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation. The task result contains an with the agent's output. + /// A task that represents the asynchronous operation. The task result contains an with the agent's output. /// /// /// This is the primary invocation method that implementations must override. It handles collections of messages, @@ -256,7 +258,7 @@ public Task RunAsync( /// The agent's response will also be added to if one is provided. /// /// - protected abstract Task RunCoreAsync( + protected abstract Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -271,8 +273,8 @@ protected abstract Task RunCoreAsync( /// /// Optional configuration parameters for controlling the agent's invocation behavior. /// The to monitor for cancellation requests. The default is . - /// An asynchronous enumerable of instances representing the streaming response. - public IAsyncEnumerable RunStreamingAsync( + /// An asynchronous enumerable of instances representing the streaming response. + public IAsyncEnumerable RunStreamingAsync( AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) => @@ -288,13 +290,13 @@ public IAsyncEnumerable RunStreamingAsync( /// /// Optional configuration parameters for controlling the agent's invocation behavior. /// The to monitor for cancellation requests. The default is . - /// An asynchronous enumerable of instances representing the streaming response. + /// An asynchronous enumerable of instances representing the streaming response. /// is , empty, or contains only whitespace. /// /// The provided text will be wrapped in a with the role. /// Streaming invocation provides real-time updates as the agent generates its response. /// - public IAsyncEnumerable RunStreamingAsync( + public IAsyncEnumerable RunStreamingAsync( string message, AgentThread? thread = null, AgentRunOptions? options = null, @@ -315,9 +317,9 @@ public IAsyncEnumerable RunStreamingAsync( /// /// Optional configuration parameters for controlling the agent's invocation behavior. /// The to monitor for cancellation requests. The default is . - /// An asynchronous enumerable of instances representing the streaming response. + /// An asynchronous enumerable of instances representing the streaming response. /// is . - public IAsyncEnumerable RunStreamingAsync( + public IAsyncEnumerable RunStreamingAsync( ChatMessage message, AgentThread? thread = null, AgentRunOptions? options = null, @@ -338,18 +340,18 @@ public IAsyncEnumerable RunStreamingAsync( /// /// Optional configuration parameters for controlling the agent's invocation behavior. /// The to monitor for cancellation requests. The default is . - /// An asynchronous enumerable of instances representing the streaming response. + /// An asynchronous enumerable of instances representing the streaming response. /// /// /// This method delegates to to perform the actual streaming invocation. It provides real-time /// updates as the agent processes the input and generates its response, enabling more responsive user experiences. 
/// /// - /// Each represents a portion of the complete response, allowing consumers + /// Each represents a portion of the complete response, allowing consumers /// to display partial results, implement progressive loading, or provide immediate feedback to users. /// /// - public IAsyncEnumerable RunStreamingAsync( + public IAsyncEnumerable RunStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -366,18 +368,18 @@ public IAsyncEnumerable RunStreamingAsync( /// /// Optional configuration parameters for controlling the agent's invocation behavior. /// The to monitor for cancellation requests. The default is . - /// An asynchronous enumerable of instances representing the streaming response. + /// An asynchronous enumerable of instances representing the streaming response. /// /// /// This is the primary streaming invocation method that implementations must override. It provides real-time /// updates as the agent processes the input and generates its response, enabling more responsive user experiences. /// /// - /// Each represents a portion of the complete response, allowing consumers + /// Each represents a portion of the complete response, allowing consumers /// to display partial results, implement progressive loading, or provide immediate feedback to users. /// /// - protected abstract IAsyncEnumerable RunCoreStreamingAsync( + protected abstract IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index fd3ff10fc2..f104f12890 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -142,7 +142,7 @@ public InvokingContext(IEnumerable requestMessages) /// /// A collection of instances representing new messages that were provided by the caller. /// - public IEnumerable RequestMessages { get; } + public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } } /// @@ -174,7 +174,7 @@ public InvokedContext(IEnumerable requestMessages, IEnumerable instances representing new messages that were provided by the caller. /// This does not include any supplied messages. /// - public IEnumerable RequestMessages { get; } + public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } /// /// Gets the messages provided by the for this invocation, if any. @@ -183,7 +183,7 @@ public InvokedContext(IEnumerable requestMessages, IEnumerable instances that were provided by the , /// and were used by the agent as part of the invocation. /// - public IEnumerable? AIContextProviderMessages { get; } + public IEnumerable? AIContextProviderMessages { get; set; } /// /// Gets the collection of response messages generated during this invocation if the invocation succeeded. 
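A consumer-facing sketch of the renames these hunks apply across the `AIAgent` surface: `AgentRunResponse`/`AgentRunResponseUpdate` become `AgentResponse`/`AgentResponseUpdate`, and the synchronous `GetNewThread`/`DeserializeThread` become the awaitable `GetNewThreadAsync`/`DeserializeThreadAsync`. The snippet below is illustrative only; the `agent` instance and message strings are placeholders, not code from this diff:

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.Agents.AI;

internal static class RenameSketch
{
    public static async Task RunAsync(AIAgent agent)
    {
        // Thread creation is now asynchronous (previously: agent.GetNewThread()).
        AgentThread thread = await agent.GetNewThreadAsync();

        // Non-streaming: AgentRunResponse -> AgentResponse.
        AgentResponse response = await agent.RunAsync("Hello", thread);
        Console.WriteLine(response.Text);

        // Streaming: AgentRunResponseUpdate -> AgentResponseUpdate.
        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync("And again", thread))
        {
            Console.Write(update);
        }
    }
}
```
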
diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs index d5003cace0..937d871c56 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs @@ -76,10 +76,10 @@ private static JsonSerializerOptions CreateDefaultOptions() // Agent abstraction types [JsonSerializable(typeof(AgentRunOptions))] - [JsonSerializable(typeof(AgentRunResponse))] - [JsonSerializable(typeof(AgentRunResponse[]))] - [JsonSerializable(typeof(AgentRunResponseUpdate))] - [JsonSerializable(typeof(AgentRunResponseUpdate[]))] + [JsonSerializable(typeof(AgentResponse))] + [JsonSerializable(typeof(AgentResponse[]))] + [JsonSerializable(typeof(AgentResponseUpdate))] + [JsonSerializable(typeof(AgentResponseUpdate[]))] [JsonSerializable(typeof(ServiceIdAgentThread.ServiceIdAgentThreadState))] [JsonSerializable(typeof(InMemoryAgentThread.InMemoryAgentThreadState))] [JsonSerializable(typeof(InMemoryChatMessageStore.StoreState))] diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponse.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponse.cs similarity index 92% rename from dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponse.cs rename to dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponse.cs index 7828b5c62d..dbded1ef88 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponse.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponse.cs @@ -24,30 +24,30 @@ namespace Microsoft.Agents.AI; /// /// /// -/// provides one or more response messages and metadata about the response. +/// provides one or more response messages and metadata about the response. /// A typical response will contain a single message, however a response may contain multiple messages /// in a variety of scenarios. For example, if the agent internally invokes functions or tools, performs /// RAG retrievals or has other complex logic, a single run by the agent may produce many messages showing /// the intermediate progress that the agent made towards producing the agent result. /// /// -/// To get the text result of the response, use the property or simply call on the . +/// To get the text result of the response, use the property or simply call on the . /// /// -public class AgentRunResponse +public class AgentResponse { /// The response messages. private IList? _messages; - /// Initializes a new instance of the class. - public AgentRunResponse() + /// Initializes a new instance of the class. + public AgentResponse() { } - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// The response message to include in this response. /// is . - public AgentRunResponse(ChatMessage message) + public AgentResponse(ChatMessage message) { _ = Throw.IfNull(message); @@ -55,16 +55,16 @@ public AgentRunResponse(ChatMessage message) } /// - /// Initializes a new instance of the class from an existing . + /// Initializes a new instance of the class from an existing . /// - /// The from which to populate this . + /// The from which to populate this . /// is . /// /// This constructor creates an agent response that wraps an existing , preserving all /// metadata and storing the original response in for access to /// the underlying implementation details. 
/// - public AgentRunResponse(ChatResponse response) + public AgentResponse(ChatResponse response) { _ = Throw.IfNull(response); @@ -78,10 +78,10 @@ public AgentRunResponse(ChatResponse response) } /// - /// Initializes a new instance of the class with the specified collection of messages. + /// Initializes a new instance of the class with the specified collection of messages. /// /// The collection of response messages, or to create an empty response. - public AgentRunResponse(IList? messages) + public AgentResponse(IList? messages) { this._messages = messages; } @@ -201,7 +201,7 @@ public IList Messages /// Gets or sets the raw representation of the run response from an underlying implementation. /// - /// If a is created to represent some underlying object from another object + /// If a is created to represent some underlying object from another object /// model, this property can be used to store that original object. This can be useful for debugging or /// for enabling a consumer to access the underlying object model if needed. /// @@ -226,11 +226,11 @@ public IList Messages public override string ToString() => this.Text; /// - /// Converts this into a collection of instances + /// Converts this into a collection of instances /// suitable for streaming scenarios. /// /// - /// An array of instances that collectively represent + /// An array of instances that collectively represent /// the same information as this response. /// /// @@ -245,12 +245,12 @@ public IList Messages /// original message sequence. /// /// - public AgentRunResponseUpdate[] ToAgentRunResponseUpdates() + public AgentResponseUpdate[] ToAgentResponseUpdates() { - AgentRunResponseUpdate? extra = null; + AgentResponseUpdate? extra = null; if (this.AdditionalProperties is not null || this.Usage is not null) { - extra = new AgentRunResponseUpdate + extra = new AgentResponseUpdate { AdditionalProperties = this.AdditionalProperties, }; @@ -262,13 +262,13 @@ public AgentRunResponseUpdate[] ToAgentRunResponseUpdates() } int messageCount = this._messages?.Count ?? 0; - var updates = new AgentRunResponseUpdate[messageCount + (extra is not null ? 1 : 0)]; + var updates = new AgentResponseUpdate[messageCount + (extra is not null ? 1 : 0)]; int i; for (i = 0; i < messageCount; i++) { ChatMessage message = this._messages![i]; - updates[i] = new AgentRunResponseUpdate + updates[i] = new AgentResponseUpdate { AdditionalProperties = message.AdditionalProperties, AuthorName = message.AuthorName, diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponseExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponseExtensions.cs similarity index 71% rename from dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponseExtensions.cs rename to dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponseExtensions.cs index cb3ad7ec74..75ff6fb359 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponseExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponseExtensions.cs @@ -11,24 +11,24 @@ namespace Microsoft.Agents.AI; /// -/// Provides extension methods for working with and instances. +/// Provides extension methods for working with and instances. /// -public static class AgentRunResponseExtensions +public static class AgentResponseExtensions { /// - /// Creates a from an instance. + /// Creates a from an instance. /// - /// The to convert. + /// The to convert. /// A built from the specified . /// is . 
/// - /// If the 's is already a + /// If the 's is already a /// instance, that instance is returned directly. /// Otherwise, a new is created and populated with the data from the . - /// The resulting instance is a shallow copy; any reference-type members (e.g. ) + /// The resulting instance is a shallow copy; any reference-type members (e.g. ) /// will be shared between the two instances. /// - public static ChatResponse AsChatResponse(this AgentRunResponse response) + public static ChatResponse AsChatResponse(this AgentResponse response) { Throw.IfNull(response); @@ -47,19 +47,19 @@ response.RawRepresentation as ChatResponse ?? } /// - /// Creates a from an instance. + /// Creates a from an instance. /// - /// The to convert. + /// The to convert. /// A built from the specified . /// is . /// - /// If the 's is already a + /// If the 's is already a /// instance, that instance is returned directly. /// Otherwise, a new is created and populated with the data from the . - /// The resulting instance is a shallow copy; any reference-type members (e.g. ) + /// The resulting instance is a shallow copy; any reference-type members (e.g. ) /// will be shared between the two instances. /// - public static ChatResponseUpdate AsChatResponseUpdate(this AgentRunResponseUpdate responseUpdate) + public static ChatResponseUpdate AsChatResponseUpdate(this AgentResponseUpdate responseUpdate) { Throw.IfNull(responseUpdate); @@ -81,17 +81,17 @@ responseUpdate.RawRepresentation as ChatResponseUpdate ?? /// /// Creates an asynchronous enumerable of instances from an asynchronous - /// enumerable of instances. + /// enumerable of instances. /// - /// The sequence of instances to convert. + /// The sequence of instances to convert. /// An asynchronous enumerable of instances built from . /// is . /// - /// Each is converted to a using + /// Each is converted to a using /// . /// public static async IAsyncEnumerable AsChatResponseUpdatesAsync( - this IAsyncEnumerable responseUpdates) + this IAsyncEnumerable responseUpdates) { Throw.IfNull(responseUpdates); @@ -102,71 +102,71 @@ public static async IAsyncEnumerable AsChatResponseUpdatesAs } /// - /// Combines a sequence of instances into a single . + /// Combines a sequence of instances into a single . /// /// The sequence of updates to be combined into a single response. - /// A single that represents the combined state of all the updates. + /// A single that represents the combined state of all the updates. /// is . /// - /// As part of combining into a single , the method will attempt to reconstruct - /// instances. This includes using to determine + /// As part of combining into a single , the method will attempt to reconstruct + /// instances. This includes using to determine /// message boundaries, as well as coalescing contiguous items where applicable, e.g. multiple /// instances in a row may be combined into a single . /// - public static AgentRunResponse ToAgentRunResponse( - this IEnumerable updates) + public static AgentResponse ToAgentResponse( + this IEnumerable updates) { _ = Throw.IfNull(updates); - AgentRunResponseDetails additionalDetails = new(); + AgentResponseDetails additionalDetails = new(); ChatResponse chatResponse = AsChatResponseUpdatesWithAdditionalDetails(updates, additionalDetails) .ToChatResponse(); - return new AgentRunResponse(chatResponse) + return new AgentResponse(chatResponse) { AgentId = additionalDetails.AgentId, }; } /// - /// Asynchronously combines a sequence of instances into a single . 
+ /// Asynchronously combines a sequence of instances into a single . /// /// The asynchronous sequence of updates to be combined into a single response. /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation. The task result contains a single that represents the combined state of all the updates. + /// A task that represents the asynchronous operation. The task result contains a single that represents the combined state of all the updates. /// is . /// /// - /// This is the asynchronous version of . + /// This is the asynchronous version of . /// It performs the same combining logic but operates on an asynchronous enumerable of updates. /// /// - /// As part of combining into a single , the method will attempt to reconstruct - /// instances. This includes using to determine + /// As part of combining into a single , the method will attempt to reconstruct + /// instances. This includes using to determine /// message boundaries, as well as coalescing contiguous items where applicable, e.g. multiple /// instances in a row may be combined into a single . /// /// - public static Task ToAgentRunResponseAsync( - this IAsyncEnumerable updates, + public static Task ToAgentResponseAsync( + this IAsyncEnumerable updates, CancellationToken cancellationToken = default) { _ = Throw.IfNull(updates); - return ToAgentRunResponseAsync(updates, cancellationToken); + return ToAgentResponseAsync(updates, cancellationToken); - static async Task ToAgentRunResponseAsync( - IAsyncEnumerable updates, + static async Task ToAgentResponseAsync( + IAsyncEnumerable updates, CancellationToken cancellationToken) { - AgentRunResponseDetails additionalDetails = new(); + AgentResponseDetails additionalDetails = new(); ChatResponse chatResponse = await AsChatResponseUpdatesWithAdditionalDetailsAsync(updates, additionalDetails, cancellationToken) .ToChatResponseAsync(cancellationToken) .ConfigureAwait(false); - return new AgentRunResponse(chatResponse) + return new AgentResponse(chatResponse) { AgentId = additionalDetails.AgentId, }; @@ -174,8 +174,8 @@ static async Task ToAgentRunResponseAsync( } private static IEnumerable AsChatResponseUpdatesWithAdditionalDetails( - IEnumerable updates, - AgentRunResponseDetails additionalDetails) + IEnumerable updates, + AgentResponseDetails additionalDetails) { foreach (var update in updates) { @@ -185,8 +185,8 @@ private static IEnumerable AsChatResponseUpdatesWithAddition } private static async IAsyncEnumerable AsChatResponseUpdatesWithAdditionalDetailsAsync( - IAsyncEnumerable updates, - AgentRunResponseDetails additionalDetails, + IAsyncEnumerable updates, + AgentResponseDetails additionalDetails, [EnumeratorCancellation] CancellationToken cancellationToken) { await foreach (var update in updates.WithCancellation(cancellationToken).ConfigureAwait(false)) @@ -196,7 +196,7 @@ private static async IAsyncEnumerable AsChatResponseUpdatesW } } - private static void UpdateAdditionalDetails(AgentRunResponseUpdate update, AgentRunResponseDetails details) + private static void UpdateAdditionalDetails(AgentResponseUpdate update, AgentResponseDetails details) { if (update.AgentId is { Length: > 0 }) { @@ -204,7 +204,7 @@ private static void UpdateAdditionalDetails(AgentRunResponseUpdate update, Agent } } - private sealed class AgentRunResponseDetails + private sealed class AgentResponseDetails { public string? 
AgentId { get; set; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponseUpdate.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponseUpdate.cs similarity index 82% rename from dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponseUpdate.cs rename to dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponseUpdate.cs index ccf3deae54..041af06593 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponseUpdate.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponseUpdate.cs @@ -16,54 +16,54 @@ namespace Microsoft.Agents.AI; /// /// /// -/// is so named because it represents updates +/// is so named because it represents updates /// that layer on each other to form a single agent response. Conceptually, this combines the roles of -/// and in streaming output. +/// and in streaming output. /// /// -/// To get the text result of this response chunk, use the property or simply call on the . +/// To get the text result of this response chunk, use the property or simply call on the . /// /// -/// The relationship between and is -/// codified in the and -/// , which enable bidirectional conversions +/// The relationship between and is +/// codified in the and +/// , which enable bidirectional conversions /// between the two. Note, however, that the provided conversions may be lossy, for example if multiple /// updates all have different objects whereas there's only one slot for -/// such an object available in . +/// such an object available in . /// /// [DebuggerDisplay("[{Role}] {ContentForDebuggerDisplay}{EllipsesForDebuggerDisplay,nq}")] -public class AgentRunResponseUpdate +public class AgentResponseUpdate { /// The response update content items. private IList? _contents; - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. [JsonConstructor] - public AgentRunResponseUpdate() + public AgentResponseUpdate() { } - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// The role of the author of the update. /// The text content of the update. - public AgentRunResponseUpdate(ChatRole? role, string? content) + public AgentResponseUpdate(ChatRole? role, string? content) : this(role, content is null ? null : [new TextContent(content)]) { } - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// The role of the author of the update. /// The contents of the update. - public AgentRunResponseUpdate(ChatRole? role, IList? contents) + public AgentResponseUpdate(ChatRole? role, IList? contents) { this.Role = role; this._contents = contents; } - /// Initializes a new instance of the class. - /// The from which to seed this . - public AgentRunResponseUpdate(ChatResponseUpdate chatResponseUpdate) + /// Initializes a new instance of the class. + /// The from which to seed this . + public AgentResponseUpdate(ChatResponseUpdate chatResponseUpdate) { _ = Throw.IfNull(chatResponseUpdate); @@ -112,7 +112,7 @@ public IList Contents /// Gets or sets the raw representation of the response update from an underlying implementation. /// - /// If a is created to represent some underlying object from another object + /// If a is created to represent some underlying object from another object /// model, this property can be used to store that original object. This can be useful for debugging or /// for enabling a consumer to access the underlying object model if needed. 
/// @@ -136,8 +136,8 @@ public IList Contents /// Some providers may consider streaming responses to be a single message, and in that case /// the value of this property may be the same as the response ID. /// - /// This value is used when - /// groups instances into instances. + /// This value is used when + /// groups instances into instances. /// The value must be unique to each call to the underlying provider, and must be shared by /// all updates that are part of the same logical message within a streaming response. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponse{T}.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponse{T}.cs similarity index 63% rename from dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponse{T}.cs rename to dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponse{T}.cs index 9bac7df6fe..2a18aadb37 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunResponse{T}.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentResponse{T}.cs @@ -8,18 +8,18 @@ namespace Microsoft.Agents.AI; /// Represents the response of the specified type to an run request. /// /// The type of value expected from the agent. -public abstract class AgentRunResponse : AgentRunResponse +public abstract class AgentResponse : AgentResponse { - /// Initializes a new instance of the class. - protected AgentRunResponse() + /// Initializes a new instance of the class. + protected AgentResponse() { } /// - /// Initializes a new instance of the class from an existing . + /// Initializes a new instance of the class from an existing . /// - /// The from which to populate this . - protected AgentRunResponse(ChatResponse response) : base(response) + /// The from which to populate this . + protected AgentResponse(ChatResponse response) : base(response) { } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunOptions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunOptions.cs index 9cd6d51680..5fea157b75 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunOptions.cs @@ -43,10 +43,10 @@ public AgentRunOptions(AgentRunOptions options) /// This property is used for background responses that can be activated via the /// property if the implementation supports them. /// Streamed background responses, such as those returned by default by - /// can be resumed if interrupted. This means that a continuation token obtained from the + /// can be resumed if interrupted. This means that a continuation token obtained from the /// of an update just before the interruption occurred can be passed to this property to resume the stream from the point of interruption. /// Non-streamed background responses, such as those returned by , - /// can be polled for completion by obtaining the token from the property + /// can be polled for completion by obtaining the token from the property /// and passing it via this property on subsequent calls to . /// public ResponseContinuationToken? ContinuationToken { get; set; } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs index 0a3301d05f..318307ec43 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs @@ -26,8 +26,8 @@ namespace Microsoft.Agents.AI; /// Chat history reduction, e.g. where messages needs to be summarized or truncated to reduce the size. 
/// /// An is always constructed by an so that the -/// can attach any necessary behaviors to the . See the -/// and methods for more information. +/// can attach any necessary behaviors to the . See the +/// and methods for more information. /// /// /// Because of these behaviors, an may not be reusable across different agents, since each agent @@ -37,13 +37,13 @@ namespace Microsoft.Agents.AI; /// To support conversations that may need to survive application restarts or separate service requests, an can be serialized /// and deserialized, so that it can be saved in a persistent store. /// The provides the method to serialize the thread to a -/// and the method +/// and the method /// can be used to deserialize the thread. /// /// /// -/// -/// +/// +/// public abstract class AgentThread { /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStore.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStore.cs index d28cd191b7..54cee063d7 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStore.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStore.cs @@ -152,7 +152,7 @@ public InvokingContext(IEnumerable requestMessages) /// /// A collection of instances representing new messages that were provided by the caller. /// - public IEnumerable RequestMessages { get; } + public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } } /// @@ -174,7 +174,7 @@ public sealed class InvokedContext public InvokedContext(IEnumerable requestMessages, IEnumerable chatMessageStoreMessages) { this.RequestMessages = Throw.IfNull(requestMessages); - this.ChatMessageStoreMessages = chatMessageStoreMessages; + this.ChatMessageStoreMessages = Throw.IfNull(chatMessageStoreMessages); } /// @@ -184,7 +184,7 @@ public InvokedContext(IEnumerable requestMessages, IEnumerable instances representing new messages that were provided by the caller. /// This does not include any supplied messages. /// - public IEnumerable RequestMessages { get; } + public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } /// /// Gets the messages retrieved from the for this invocation, if any. @@ -193,7 +193,7 @@ public InvokedContext(IEnumerable requestMessages, IEnumerable instances that were retrieved from the , /// and were used by the agent as part of the invocation. /// - public IEnumerable ChatMessageStoreMessages { get; } + public IEnumerable ChatMessageStoreMessages { get; set { field = Throw.IfNull(value); } } /// /// Gets or sets the messages provided by the for this invocation, if any. diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs index e7bf58f39f..e9a3d5bc7a 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs @@ -74,14 +74,14 @@ protected DelegatingAIAgent(AIAgent innerAgent) } /// - public override AgentThread GetNewThread() => this.InnerAgent.GetNewThread(); + public override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) => this.InnerAgent.GetNewThreadAsync(cancellationToken); /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) - => this.InnerAgent.DeserializeThread(serializedThread, jsonSerializerOptions); + public override ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? 
jsonSerializerOptions = null, CancellationToken cancellationToken = default) + => this.InnerAgent.DeserializeThreadAsync(serializedThread, jsonSerializerOptions, cancellationToken); /// - protected override Task RunCoreAsync( + protected override Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -89,7 +89,7 @@ protected override Task RunCoreAsync( => this.InnerAgent.RunAsync(messages, thread, options, cancellationToken); /// - protected override IAsyncEnumerable RunCoreStreamingAsync( + protected override IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatMessageStore.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatMessageStore.cs index f7f4522f8f..1fb1b568ae 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatMessageStore.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatMessageStore.cs @@ -29,7 +29,7 @@ namespace Microsoft.Agents.AI; /// [DebuggerDisplay("Count = {Count}")] [DebuggerTypeProxy(typeof(DebugView))] -public sealed class InMemoryChatMessageStore : ChatMessageStore, IList +public sealed class InMemoryChatMessageStore : ChatMessageStore, IList, IReadOnlyList { private List _messages; diff --git a/dotnet/src/Microsoft.Agents.AI.Anthropic/AnthropicBetaServiceExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Anthropic/AnthropicBetaServiceExtensions.cs index 6b4f872a63..06c7cbaf15 100644 --- a/dotnet/src/Microsoft.Agents.AI.Anthropic/AnthropicBetaServiceExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Anthropic/AnthropicBetaServiceExtensions.cs @@ -31,7 +31,7 @@ public static class AnthropicBetaServiceExtensions /// Optional logger factory for enabling logging within the agent. /// An optional to use for resolving services required by the instances being invoked. /// The created AI agent. - public static ChatClientAgent CreateAIAgent( + public static ChatClientAgent AsAIAgent( this IBetaService betaService, string model, string? instructions = null, @@ -81,7 +81,7 @@ public static ChatClientAgent CreateAIAgent( /// An optional to use for resolving services required by the instances being invoked. /// An instance backed by the Anthropic Chat Completion service. /// Thrown when or is . - public static ChatClientAgent CreateAIAgent( + public static ChatClientAgent AsAIAgent( this IBetaService betaService, ChatClientAgentOptions options, Func? clientFactory = null, diff --git a/dotnet/src/Microsoft.Agents.AI.Anthropic/AnthropicClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Anthropic/AnthropicClientExtensions.cs index b4b8e2bc1e..c0bbd4715d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Anthropic/AnthropicClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Anthropic/AnthropicClientExtensions.cs @@ -31,7 +31,7 @@ public static class AnthropicClientExtensions /// Optional logger factory for enabling logging within the agent. /// An optional to use for resolving services required by the instances being invoked. /// The created AI agent. - public static ChatClientAgent CreateAIAgent( + public static ChatClientAgent AsAIAgent( this IAnthropicClient client, string model, string? instructions = null, @@ -81,7 +81,7 @@ public static ChatClientAgent CreateAIAgent( /// An optional to use for resolving services required by the instances being invoked. /// An instance backed by the Anthropic Chat Completion service. /// Thrown when or is . 
- public static ChatClientAgent CreateAIAgent( + public static ChatClientAgent AsAIAgent( this IAnthropicClient client, ChatClientAgentOptions options, Func? clientFactory = null, diff --git a/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs index 5ca1436587..55c3c4f0bf 100644 --- a/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs @@ -19,7 +19,7 @@ public static class PersistentAgentsClientExtensions /// Provides a way to customize the creation of the underlying used by the agent. /// An optional to use for resolving services required by the instances being invoked. /// A instance that can be used to perform operations on the persistent agent. - public static ChatClientAgent GetAIAgent( + public static ChatClientAgent AsAIAgent( this PersistentAgentsClient persistentAgentsClient, Response persistentAgentResponse, ChatOptions? chatOptions = null, @@ -31,7 +31,7 @@ public static ChatClientAgent GetAIAgent( throw new ArgumentNullException(nameof(persistentAgentResponse)); } - return GetAIAgent(persistentAgentsClient, persistentAgentResponse.Value, chatOptions, clientFactory, services); + return AsAIAgent(persistentAgentsClient, persistentAgentResponse.Value, chatOptions, clientFactory, services); } /// @@ -43,7 +43,7 @@ public static ChatClientAgent GetAIAgent( /// Provides a way to customize the creation of the underlying used by the agent. /// An optional to use for resolving services required by the instances being invoked. /// A instance that can be used to perform operations on the persistent agent. - public static ChatClientAgent GetAIAgent( + public static ChatClientAgent AsAIAgent( this PersistentAgentsClient persistentAgentsClient, PersistentAgent persistentAgentMetadata, ChatOptions? chatOptions = null, @@ -112,7 +112,7 @@ public static ChatClientAgent GetAIAgent( } var persistentAgentResponse = persistentAgentsClient.Administration.GetAgent(agentId, cancellationToken); - return persistentAgentsClient.GetAIAgent(persistentAgentResponse, chatOptions, clientFactory, services); + return persistentAgentsClient.AsAIAgent(persistentAgentResponse, chatOptions, clientFactory, services); } /// @@ -145,7 +145,7 @@ public static async Task GetAIAgentAsync( } var persistentAgentResponse = await persistentAgentsClient.Administration.GetAgentAsync(agentId, cancellationToken).ConfigureAwait(false); - return persistentAgentsClient.GetAIAgent(persistentAgentResponse, chatOptions, clientFactory, services); + return persistentAgentsClient.AsAIAgent(persistentAgentResponse, chatOptions, clientFactory, services); } /// @@ -158,7 +158,7 @@ public static async Task GetAIAgentAsync( /// An optional to use for resolving services required by the instances being invoked. /// A instance that can be used to perform operations on the persistent agent. /// Thrown when or is . 
- public static ChatClientAgent GetAIAgent( + public static ChatClientAgent AsAIAgent( this PersistentAgentsClient persistentAgentsClient, Response persistentAgentResponse, ChatClientAgentOptions options, @@ -170,7 +170,7 @@ public static ChatClientAgent GetAIAgent( throw new ArgumentNullException(nameof(persistentAgentResponse)); } - return GetAIAgent(persistentAgentsClient, persistentAgentResponse.Value, options, clientFactory, services); + return AsAIAgent(persistentAgentsClient, persistentAgentResponse.Value, options, clientFactory, services); } /// @@ -183,7 +183,7 @@ public static ChatClientAgent GetAIAgent( /// An optional to use for resolving services required by the instances being invoked. /// A instance that can be used to perform operations on the persistent agent. /// Thrown when or is . - public static ChatClientAgent GetAIAgent( + public static ChatClientAgent AsAIAgent( this PersistentAgentsClient persistentAgentsClient, PersistentAgent persistentAgentMetadata, ChatClientAgentOptions options, @@ -268,7 +268,7 @@ public static ChatClientAgent GetAIAgent( } var persistentAgentResponse = persistentAgentsClient.Administration.GetAgent(agentId, cancellationToken); - return persistentAgentsClient.GetAIAgent(persistentAgentResponse, options, clientFactory, services); + return persistentAgentsClient.AsAIAgent(persistentAgentResponse, options, clientFactory, services); } /// @@ -307,7 +307,7 @@ public static async Task GetAIAgentAsync( } var persistentAgentResponse = await persistentAgentsClient.Administration.GetAgentAsync(agentId, cancellationToken).ConfigureAwait(false); - return persistentAgentsClient.GetAIAgent(persistentAgentResponse, options, clientFactory, services); + return persistentAgentsClient.AsAIAgent(persistentAgentResponse, options, clientFactory, services); } /// diff --git a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs index 7319bb13eb..8e03a33be3 100644 --- a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs @@ -91,7 +91,7 @@ public static ChatClientAgent GetAIAgent( AgentRecord agentRecord = GetAgentRecordByName(aiProjectClient, name, cancellationToken); - return GetAIAgent( + return AsAIAgent( aiProjectClient, agentRecord, tools, @@ -125,7 +125,7 @@ public static async Task GetAIAgentAsync( AgentRecord agentRecord = await GetAgentRecordByNameAsync(aiProjectClient, name, cancellationToken).ConfigureAwait(false); - return GetAIAgent( + return AsAIAgent( aiProjectClient, agentRecord, tools, @@ -143,7 +143,7 @@ public static async Task GetAIAgentAsync( /// An optional to use for resolving services required by the instances being invoked. /// A instance that can be used to perform operations based on the latest version of the Azure AI Agent. /// Thrown when or is . - public static ChatClientAgent GetAIAgent( + public static ChatClientAgent AsAIAgent( this AIProjectClient aiProjectClient, AgentRecord agentRecord, IList? tools = null, @@ -174,7 +174,7 @@ public static ChatClientAgent GetAIAgent( /// An optional to use for resolving services required by the instances being invoked. /// A instance that can be used to perform operations based on the provided version of the Azure AI Agent. /// Thrown when or is . 
- public static ChatClientAgent GetAIAgent( + public static ChatClientAgent AsAIAgent( this AIProjectClient aiProjectClient, AgentVersion agentVersion, IList? tools = null, diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs index 203bab21ed..6b69975f50 100644 --- a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs @@ -42,23 +42,23 @@ public CopilotStudioAgent(CopilotClient client, ILoggerFactory? loggerFactory = } /// - public sealed override AgentThread GetNewThread() - => new CopilotStudioAgentThread(); + public sealed override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) + => new(new CopilotStudioAgentThread()); /// /// Get a new instance using an existing conversation id, to continue that conversation. /// /// The conversation id to continue. /// A new instance. - public AgentThread GetNewThread(string conversationId) - => new CopilotStudioAgentThread() { ConversationId = conversationId }; + public ValueTask GetNewThreadAsync(string conversationId) + => new(new CopilotStudioAgentThread() { ConversationId = conversationId }); /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) - => new CopilotStudioAgentThread(serializedThread, jsonSerializerOptions); + public override ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + => new(new CopilotStudioAgentThread(serializedThread, jsonSerializerOptions)); /// - protected override async Task RunCoreAsync( + protected override async Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -68,7 +68,7 @@ protected override async Task RunCoreAsync( // Ensure that we have a valid thread to work with. // If the thread ID is null, we need to start a new conversation and set the thread ID accordingly. - thread ??= this.GetNewThread(); + thread ??= await this.GetNewThreadAsync(cancellationToken).ConfigureAwait(false); if (thread is not CopilotStudioAgentThread typedThread) { throw new InvalidOperationException("The provided thread is not compatible with the agent. Only threads created by the agent can be used."); } @@ -88,7 +88,7 @@ protected override async Task RunCoreAsync( // TODO: Review list of ChatResponse properties to ensure we set all available values. // Setting ResponseId and MessageId ends up being particularly important for streaming consumers // so that they can tell things like response boundaries. - return new AgentRunResponse(responseMessagesList) + return new AgentResponse(responseMessagesList) { AgentId = this.Id, ResponseId = responseMessagesList.LastOrDefault()?.MessageId, @@ -96,7 +96,7 @@ protected override async Task RunCoreAsync( } /// - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -106,7 +106,8 @@ protected override async IAsyncEnumerable RunCoreStreami // Ensure that we have a valid thread to work with. // If the thread ID is null, we need to start a new conversation and set the thread ID accordingly.
- thread ??= this.GetNewThread(); + + thread ??= await this.GetNewThreadAsync(cancellationToken).ConfigureAwait(false); if (thread is not CopilotStudioAgentThread typedThread) { throw new InvalidOperationException("The provided thread is not compatible with the agent. Only threads created by the agent can be used."); } @@ -124,7 +125,7 @@ protected override async IAsyncEnumerable RunCoreStreami // TODO: Review list of ChatResponse properties to ensure we set all available values. // Setting ResponseId and MessageId ends up being particularly important for streaming consumers // so that they can tell things like response boundaries. - yield return new AgentRunResponseUpdate(message.Role, message.Contents) + yield return new AgentResponseUpdate(message.Role, message.Contents) { AgentId = this.Id, AdditionalProperties = message.AdditionalProperties, diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs index 4e3b66fd54..45c0d09536 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs @@ -2,6 +2,7 @@ using System; using System.Diagnostics.CodeAnalysis; +using System.Threading.Tasks; using Azure.Identity; using Microsoft.Azure.Cosmos; @@ -35,7 +36,7 @@ public static ChatClientAgentOptions WithCosmosDBMessageStore( throw new ArgumentNullException(nameof(options)); } - options.ChatMessageStoreFactory = context => new CosmosChatMessageStore(connectionString, databaseId, containerId); + options.ChatMessageStoreFactory = (context, ct) => new ValueTask(new CosmosChatMessageStore(connectionString, databaseId, containerId)); return options; } @@ -62,7 +63,7 @@ public static ChatClientAgentOptions WithCosmosDBMessageStoreUsingManagedIdentit throw new ArgumentNullException(nameof(options)); } - options.ChatMessageStoreFactory = context => new CosmosChatMessageStore(accountEndpoint, new DefaultAzureCredential(), databaseId, containerId); + options.ChatMessageStoreFactory = (context, ct) => new ValueTask(new CosmosChatMessageStore(accountEndpoint, new DefaultAzureCredential(), databaseId, containerId)); return options; } @@ -89,7 +90,7 @@ public static ChatClientAgentOptions WithCosmosDBMessageStore( throw new ArgumentNullException(nameof(options)); } - options.ChatMessageStoreFactory = context => new CosmosChatMessageStore(cosmosClient, databaseId, containerId); + options.ChatMessageStoreFactory = (context, ct) => new ValueTask(new CosmosChatMessageStore(cosmosClient, databaseId, containerId)); return options; } } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs index ec4ba3acf6..15432804f3 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs @@ -21,7 +21,7 @@ internal class AgentEntity(IServiceProvider services, CancellationToken cancella ? cancellationToken : services.GetService()?.ApplicationStopping ?? CancellationToken.None; - public Task RunAgentAsync(RunRequest request) + public Task RunAgentAsync(RunRequest request) { return this.Run(request); } @@ -29,7 +29,7 @@ public Task RunAgentAsync(RunRequest request) // IDE1006 and VSTHRD200 disabled to allow method name to match the common cross-platform entity operation name.
#pragma warning disable IDE1006 #pragma warning disable VSTHRD200 - public async Task Run(RunRequest request) + public async Task Run(RunRequest request) #pragma warning restore VSTHRD200 #pragma warning restore IDE1006 { @@ -43,7 +43,7 @@ public async Task Run(RunRequest request) if (request.Messages.Count == 0) { logger.LogInformation("Ignoring empty request"); - return new AgentRunResponse(); + return new AgentResponse(); } this.State.Data.ConversationHistory.Add(DurableAgentStateRequest.FromRunRequest(request)); @@ -65,29 +65,29 @@ public async Task Run(RunRequest request) try { // Start the agent response stream - IAsyncEnumerable responseStream = agentWrapper.RunStreamingAsync( + IAsyncEnumerable responseStream = agentWrapper.RunStreamingAsync( this.State.Data.ConversationHistory.SelectMany(e => e.Messages).Select(m => m.ToChatMessage()), - agentWrapper.GetNewThread(), + await agentWrapper.GetNewThreadAsync(cancellationToken).ConfigureAwait(false), options: null, this._cancellationToken); - AgentRunResponse response; + AgentResponse response; if (this._messageHandler is null) { // If no message handler is provided, we can just get the full response at once. // This is expected to be the common case for non-interactive agents. - response = await responseStream.ToAgentRunResponseAsync(this._cancellationToken); + response = await responseStream.ToAgentResponseAsync(this._cancellationToken); } else { - List responseUpdates = []; + List responseUpdates = []; // To support interactive chat agents, we need to stream the responses to an IAgentMessageHandler. // The user-provided message handler can be implemented to send the responses to the user. // We assume that only non-empty text updates are useful for the user. - async IAsyncEnumerable StreamResultsAsync() + async IAsyncEnumerable StreamResultsAsync() { - await foreach (AgentRunResponseUpdate update in responseStream) + await foreach (AgentResponseUpdate update in responseStream) { // We need the full response further down, so we piece it together as we go. responseUpdates.Add(update); @@ -98,12 +98,12 @@ async IAsyncEnumerable StreamResultsAsync() } await this._messageHandler.OnStreamingResponseUpdateAsync(StreamResultsAsync(), this._cancellationToken); - response = responseUpdates.ToAgentRunResponse(); + response = responseUpdates.ToAgentResponse(); } // Persist the agent response to the entity state for client polling this.State.Data.ConversationHistory.Add( - DurableAgentStateResponse.FromRunResponse(request.CorrelationId, response)); + DurableAgentStateResponse.FromResponse(request.CorrelationId, response)); string responseText = response.Text; diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentRunHandle.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentRunHandle.cs index e4fe08dbf2..0ff329153f 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentRunHandle.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentRunHandle.cs @@ -44,7 +44,7 @@ internal AgentRunHandle( /// The cancellation token. /// The agent response corresponding to this request. /// Thrown when the response is not found after polling. 
- public async Task ReadAgentResponseAsync(CancellationToken cancellationToken = default) + public async Task ReadAgentResponseAsync(CancellationToken cancellationToken = default) { TimeSpan pollInterval = TimeSpan.FromMilliseconds(50); // Start with 50ms TimeSpan maxPollInterval = TimeSpan.FromSeconds(3); // Maximum 3 seconds @@ -69,7 +69,7 @@ public async Task ReadAgentResponseAsync(CancellationToken can if (response is not null) { this._logger.LogDonePollingForResponse(this.SessionId, this.CorrelationId); - return response.ToRunResponse(); + return response.ToResponse(); } } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs index d841a80ddd..dd598e2618 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs @@ -32,11 +32,12 @@ internal DurableAIAgent(TaskOrchestrationContext context, string agentName) /// /// Creates a new agent thread for this agent using a random session ID. /// - /// A new agent thread. - public override AgentThread GetNewThread() + /// The cancellation token. + /// A value task that represents the asynchronous operation. The task result contains a new agent thread. + public override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) { AgentSessionId sessionId = this._context.NewAgentSessionId(this._agentName); - return new DurableAgentThread(sessionId); + return ValueTask.FromResult(new DurableAgentThread(sessionId)); } /// @@ -44,12 +45,13 @@ public override AgentThread GetNewThread() /// /// The serialized thread data. /// Optional JSON serializer options. - /// The deserialized agent thread. - public override AgentThread DeserializeThread( + /// The cancellation token. + /// A value task that represents the asynchronous operation. The task result contains the deserialized agent thread. + public override ValueTask DeserializeThreadAsync( JsonElement serializedThread, - JsonSerializerOptions? jsonSerializerOptions = null) + JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - return DurableAgentThread.Deserialize(serializedThread, jsonSerializerOptions); + return ValueTask.FromResult(DurableAgentThread.Deserialize(serializedThread, jsonSerializerOptions)); } /// @@ -63,7 +65,7 @@ public override AgentThread DeserializeThread( /// Thrown when the agent has not been registered. /// Thrown when the provided thread is not valid for a durable agent. /// Thrown when cancellation is requested (cancellation is not supported for durable agents). - protected override async Task RunCoreAsync( + protected override async Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -74,12 +76,12 @@ protected override async Task RunCoreAsync( throw new NotSupportedException("Cancellation is not supported for durable agents."); } - thread ??= this.GetNewThread(); + thread ??= await this.GetNewThreadAsync(cancellationToken).ConfigureAwait(false); if (thread is not DurableAgentThread durableThread) { throw new ArgumentException( "The provided thread is not valid for a durable agent. 
" + - "Create a new thread using GetNewThread or provide a thread previously created by this agent.", + "Create a new thread using GetNewThreadAsync or provide a thread previously created by this agent.", paramName: nameof(thread)); } @@ -105,7 +107,7 @@ protected override async Task RunCoreAsync( try { - return await this._context.Entities.CallEntityAsync( + return await this._context.Entities.CallEntityAsync( durableThread.SessionId, nameof(AgentEntity.Run), request); @@ -128,7 +130,7 @@ protected override async Task RunCoreAsync( /// Optional run options. /// The cancellation token. /// A streaming response enumerable. - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -136,8 +138,8 @@ protected override async IAsyncEnumerable RunCoreStreami { // Streaming is not supported for durable agents, so we just return the full response // as a single update. - AgentRunResponse response = await this.RunAsync(messages, thread, options, cancellationToken); - foreach (AgentRunResponseUpdate update in response.ToAgentRunResponseUpdates()) + AgentResponse response = await this.RunAsync(messages, thread, options, cancellationToken); + foreach (AgentResponseUpdate update in response.ToAgentResponseUpdates()) { yield return update; } @@ -160,7 +162,7 @@ protected override async IAsyncEnumerable RunCoreStreami /// Thrown when the agent response is empty or cannot be deserialized. /// /// The output from the agent. - public async Task> RunAsync( + public async Task> RunAsync( string message, AgentThread? thread = null, JsonSerializerOptions? serializerOptions = null, @@ -194,7 +196,7 @@ public async Task> RunAsync( /// The output from the agent. [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Fallback to reflection-based deserialization is intentional for library flexibility with user-defined types.")] [UnconditionalSuppressMessage("ReflectionAnalysis", "IL3050", Justification = "Fallback to reflection-based deserialization is intentional for library flexibility with user-defined types.")] - public async Task> RunAsync( + public async Task> RunAsync( IEnumerable messages, AgentThread? thread = null, JsonSerializerOptions? serializerOptions = null, @@ -221,7 +223,7 @@ public async Task> RunAsync( // Create the JSON schema for the response type durableOptions.ResponseFormat = ChatResponseFormat.ForJsonSchema(); - AgentRunResponse response = await this.RunAsync(messages, thread, durableOptions, cancellationToken); + AgentResponse response = await this.RunAsync(messages, thread, durableOptions, cancellationToken); // Deserialize the response text to the requested type if (string.IsNullOrEmpty(response.Text)) @@ -240,11 +242,11 @@ public async Task> RunAsync( : JsonSerializer.Deserialize(response.Text, serializerOptions)) ?? 
throw new InvalidOperationException($"Failed to deserialize agent response to type {typeof(T).Name}."); - return new DurableAIAgentRunResponse(response, result); + return new DurableAIAgentResponse(response, result); } - private sealed class DurableAIAgentRunResponse(AgentRunResponse response, T result) - : AgentRunResponse(response.AsChatResponse()) + private sealed class DurableAIAgentResponse(AgentResponse response, T result) + : AgentResponse(response.AsChatResponse()) { public override T Result { get; } = result; } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgentProxy.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgentProxy.cs index ecff2d5c90..2461302e6e 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgentProxy.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgentProxy.cs @@ -11,25 +11,25 @@ internal class DurableAIAgentProxy(string name, IDurableAgentClient agentClient) public override string? Name { get; } = name; - public override AgentThread DeserializeThread( + public override ValueTask DeserializeThreadAsync( JsonElement serializedThread, - JsonSerializerOptions? jsonSerializerOptions = null) + JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - return DurableAgentThread.Deserialize(serializedThread, jsonSerializerOptions); + return ValueTask.FromResult(DurableAgentThread.Deserialize(serializedThread, jsonSerializerOptions)); } - public override AgentThread GetNewThread() + public override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) { - return new DurableAgentThread(AgentSessionId.WithRandomKey(this.Name!)); + return ValueTask.FromResult(new DurableAgentThread(AgentSessionId.WithRandomKey(this.Name!))); } - protected override async Task RunCoreAsync( + protected override async Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { - thread ??= this.GetNewThread(); + thread ??= await this.GetNewThreadAsync(cancellationToken).ConfigureAwait(false); if (thread is not DurableAgentThread durableThread) { throw new ArgumentException( @@ -64,13 +64,13 @@ protected override async Task RunCoreAsync( if (isFireAndForget) { // If the request is fire and forget, return an empty response. - return new AgentRunResponse(); + return new AgentResponse(); } return await agentRunHandle.ReadAgentResponseAsync(cancellationToken); } - protected override IAsyncEnumerable RunCoreStreamingAsync( + protected override IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentJsonUtilities.cs index e3864e9ad4..966218058c 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentJsonUtilities.cs @@ -20,7 +20,7 @@ namespace Microsoft.Agents.AI.DurableTask; /// baseline defaults. /// for default null-value suppression. /// to tolerate numbers encoded as strings. -/// Chained type info resolvers from shared agent abstractions to cover cross-package types (e.g. , ). +/// Chained type info resolvers from shared agent abstractions to cover cross-package types (e.g. , ). 
/// /// /// Keep the list of [JsonSerializable] types in sync with the Durable Agent data model anytime new state or request/response diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/EntityAgentWrapper.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/EntityAgentWrapper.cs index 4a6074fcb6..e58db5e9b4 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/EntityAgentWrapper.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/EntityAgentWrapper.cs @@ -21,13 +21,13 @@ internal sealed class EntityAgentWrapper( // The ID of the agent is always the entity ID. protected override string? IdCore => this._entityContext.Id.ToString(); - protected override async Task RunCoreAsync( + protected override async Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { - AgentRunResponse response = await base.RunCoreAsync( + AgentResponse response = await base.RunCoreAsync( messages, thread, this.GetAgentEntityRunOptions(options), @@ -37,13 +37,13 @@ protected override async Task RunCoreAsync( return response; } - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - await foreach (AgentRunResponseUpdate update in base.RunCoreStreamingAsync( + await foreach (AgentResponseUpdate update in base.RunCoreStreamingAsync( messages, thread, this.GetAgentEntityRunOptions(options), diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/IAgentResponseHandler.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/IAgentResponseHandler.cs index 45a4e9f258..c12a765e00 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/IAgentResponseHandler.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/IAgentResponseHandler.cs @@ -17,7 +17,7 @@ public interface IAgentResponseHandler /// Signals that the operation should be cancelled. /// ValueTask OnStreamingResponseUpdateAsync( - IAsyncEnumerable messageStream, + IAsyncEnumerable messageStream, CancellationToken cancellationToken); /// @@ -30,6 +30,6 @@ ValueTask OnStreamingResponseUpdateAsync( /// Signals that the operation should be cancelled. /// ValueTask OnAgentResponseAsync( - AgentRunResponse message, + AgentResponse message, CancellationToken cancellationToken); } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/Microsoft.Agents.AI.DurableTask.csproj b/dotnet/src/Microsoft.Agents.AI.DurableTask/Microsoft.Agents.AI.DurableTask.csproj index 41284e1085..43ebe9c61f 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/Microsoft.Agents.AI.DurableTask.csproj +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/Microsoft.Agents.AI.DurableTask.csproj @@ -4,7 +4,7 @@ $(TargetFrameworksCore) enable - + $(NoWarn);CA2007;MEAI001 diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/State/DurableAgentStateResponse.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/State/DurableAgentStateResponse.cs index 216bb6e05c..612ff4b48f 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/State/DurableAgentStateResponse.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/State/DurableAgentStateResponse.cs @@ -17,12 +17,12 @@ internal sealed class DurableAgentStateResponse : DurableAgentStateEntry public DurableAgentStateUsage? Usage { get; init; } /// - /// Creates a from an . + /// Creates a from an . 
/// /// The correlation ID linking this response to its request. - /// The to convert. + /// The to convert. /// A representing the original response. - public static DurableAgentStateResponse FromRunResponse(string correlationId, AgentRunResponse response) + public static DurableAgentStateResponse FromResponse(string correlationId, AgentResponse response) { return new DurableAgentStateResponse() { @@ -34,12 +34,12 @@ public static DurableAgentStateResponse FromRunResponse(string correlationId, Ag } /// - /// Converts this back to an . + /// Converts this back to an . /// - /// A representing this response. - public AgentRunResponse ToRunResponse() + /// A representing this response. + public AgentResponse ToResponse() { - return new AgentRunResponse() + return new AgentResponse() { CreatedAt = this.CreatedAt, Messages = this.Messages.Select(m => m.ToChatMessage()).ToList(), diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.A2A/AIAgentExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.A2A/AIAgentExtensions.cs index 499d724b1a..a2cb300687 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.A2A/AIAgentExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.A2A/AIAgentExtensions.cs @@ -90,13 +90,7 @@ public static ITaskManager MapA2A( // we can help the user if they did not set Url explicitly. if (string.IsNullOrEmpty(agentCard.Url)) { - var agentCardUrl = context.TrimEnd('/'); - if (!context.EndsWith("/v1/card", StringComparison.Ordinal)) - { - agentCardUrl += "/v1/card"; - } - - agentCard.Url = agentCardUrl; + agentCard.Url = context.TrimEnd('/'); } return Task.FromResult(agentCard); diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.AzureFunctions/BuiltInFunctions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.AzureFunctions/BuiltInFunctions.cs index 3d824994e9..edde523271 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.AzureFunctions/BuiltInFunctions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.AzureFunctions/BuiltInFunctions.cs @@ -119,7 +119,7 @@ public static async Task RunAgentHttpAsync( if (waitForResponse) { - AgentRunResponse agentResponse = await agentProxy.RunAsync( + AgentResponse agentResponse = await agentProxy.RunAsync( message: new ChatMessage(ChatRole.User, message), thread: new DurableAgentThread(sessionId), options: options, @@ -170,7 +170,7 @@ await agentProxy.RunAsync( AIAgent agentProxy = client.AsDurableAgentProxy(functionContext, agentName); - AgentRunResponse agentResponse = await agentProxy.RunAsync( + AgentResponse agentResponse = await agentProxy.RunAsync( message: new ChatMessage(ChatRole.User, query), thread: new DurableAgentThread(sessionId), options: null); @@ -224,7 +224,7 @@ private static async Task CreateSuccessResponseAsync( FunctionContext context, HttpStatusCode statusCode, string threadId, - AgentRunResponse agentResponse) + AgentResponse agentResponse) { HttpResponseData response = req.CreateResponse(statusCode); response.Headers.Add("x-ms-thread-id", threadId); @@ -321,7 +321,7 @@ private sealed record ErrorResponse( private sealed record AgentRunSuccessResponse( [property: JsonPropertyName("status")] int Status, [property: JsonPropertyName("thread_id")] string ThreadId, - [property: JsonPropertyName("response")] AgentRunResponse Response); + [property: JsonPropertyName("response")] AgentResponse Response); /// /// Represents an accepted (fire-and-forget) agent run response.
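Taken together, the renames in this change set reshape the basic calling pattern for agent consumers. The following is a minimal sketch of the post-rename surface, assuming an Azure OpenAI-backed agent; `endpoint`, `deploymentName`, and the agent name below are illustrative placeholders, not part of this change set:

```csharp
// Illustrative sketch only; endpoint and deploymentName are placeholders.
AIAgent agent = new AzureOpenAIClient(new Uri(endpoint), new AzureCliCredential())
    .GetChatClient(deploymentName)
    .AsAIAgent( // formerly CreateAIAgent
        instructions: "You are a helpful assistant.",
        name: "SampleAgent");

// Thread creation is now asynchronous and must be awaited (formerly agent.GetNewThread()).
AgentThread thread = await agent.GetNewThreadAsync();

// Non-streaming runs return AgentResponse (formerly AgentRunResponse).
AgentResponse response = await agent.RunAsync("Hello!", thread);
Console.WriteLine(response.Text);

// Streaming runs yield AgentResponseUpdate items (formerly AgentRunResponseUpdate).
await foreach (AgentResponseUpdate update in agent.RunStreamingAsync("Tell me more.", thread))
{
    Console.Write(update);
}
```

This is also why the `thread ??= this.GetNewThread();` fallbacks throughout this diff become `thread ??= await this.GetNewThreadAsync(cancellationToken).ConfigureAwait(false);`: any code path that creates a thread on demand now has to await it.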
diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.AzureFunctions/README.md b/dotnet/src/Microsoft.Agents.AI.Hosting.AzureFunctions/README.md index 4e819e5985..2b3c87c348 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.AzureFunctions/README.md +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.AzureFunctions/README.md @@ -36,13 +36,13 @@ This package provides a `ConfigureDurableAgents` extension method on the `Functi // Invocable via HTTP via http://localhost:7071/api/agents/SpamDetectionAgent/run AIAgent spamDetector = new AzureOpenAIClient(new Uri(endpoint), new AzureCliCredential()) .GetChatClient(deploymentName) - .CreateAIAgent( + .AsAIAgent( instructions: "You are a spam detection assistant that identifies spam emails.", name: "SpamDetectionAgent"); AIAgent emailAssistant = new AzureOpenAIClient(new Uri(endpoint), new AzureCliCredential()) .GetChatClient(deploymentName) - .CreateAIAgent( + .AsAIAgent( instructions: "You are an email assistant that helps users draft responses to emails with professionalism.", name: "EmailAssistantAgent"); @@ -74,10 +74,10 @@ public static async Task SpamDetectionOrchestration( // Get the spam detection agent DurableAIAgent spamDetectionAgent = context.GetAgent("SpamDetectionAgent"); - AgentThread spamThread = spamDetectionAgent.GetNewThread(); + AgentThread spamThread = await spamDetectionAgent.GetNewThreadAsync(); // Step 1: Check if the email is spam - AgentRunResponse spamDetectionResponse = await spamDetectionAgent.RunAsync( + AgentResponse spamDetectionResponse = await spamDetectionAgent.RunAsync( message: $""" Analyze this email for spam content and return a JSON response with 'is_spam' (boolean) and 'reason' (string) fields: @@ -97,9 +97,9 @@ public static async Task SpamDetectionOrchestration( { // Generate and send response for legitimate email DurableAIAgent emailAssistantAgent = context.GetAgent("EmailAssistantAgent"); - AgentThread emailThread = emailAssistantAgent.GetNewThread(); + AgentThread emailThread = await emailAssistantAgent.GetNewThreadAsync(); - AgentRunResponse emailAssistantResponse = await emailAssistantAgent.RunAsync( + AgentResponse emailAssistantResponse = await emailAssistantAgent.RunAsync( message: $""" Draft a professional response to this email. Return a JSON response with a 'response' field containing the reply: @@ -156,7 +156,7 @@ These tools are registered with the agent using the `tools` parameter when creat Tools tools = new(); AIAgent agent = new AzureOpenAIClient(new Uri(endpoint), new AzureCliCredential()) .GetChatClient(deploymentName) - .CreateAIAgent( + .AsAIAgent( instructions: "You are a content generation assistant that helps users generate content.", name: "ContentGenerationAgent", tools: [ diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/ChatCompletions/AIAgentChatCompletionsProcessor.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/ChatCompletions/AIAgentChatCompletionsProcessor.cs index d9e51b6aa2..42443dc2ca 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/ChatCompletions/AIAgentChatCompletionsProcessor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/ChatCompletions/AIAgentChatCompletionsProcessor.cs @@ -70,18 +70,18 @@ private async IAsyncEnumerable> GetStreamingChunksA DateTimeOffset? 
createdAt = null; var chunkId = IdGenerator.NewId(prefix: "chatcmpl", delimiter: "-", stringLength: 13); - await foreach (var agentRunResponseUpdate in agent.RunStreamingAsync(chatMessages, options: options, cancellationToken: cancellationToken).WithCancellation(cancellationToken)) + await foreach (var agentResponseUpdate in agent.RunStreamingAsync(chatMessages, options: options, cancellationToken: cancellationToken).WithCancellation(cancellationToken)) { - var finishReason = (agentRunResponseUpdate.RawRepresentation is ChatResponseUpdate { FinishReason: not null } chatResponseUpdate) + var finishReason = (agentResponseUpdate.RawRepresentation is ChatResponseUpdate { FinishReason: not null } chatResponseUpdate) ? chatResponseUpdate.FinishReason.ToString() : "stop"; var choiceChunks = new List(); CompletionUsage? usageDetails = null; - createdAt ??= agentRunResponseUpdate.CreatedAt; + createdAt ??= agentResponseUpdate.CreatedAt; - foreach (var content in agentRunResponseUpdate.Contents) + foreach (var content in agentResponseUpdate.Contents) { // usage content is handled separately if (content is UsageContent usageContent && usageContent.Details != null) @@ -124,7 +124,7 @@ private async IAsyncEnumerable> GetStreamingChunksA continue; } - delta.Role = agentRunResponseUpdate.Role?.Value ?? "user"; + delta.Role = agentResponseUpdate.Role?.Value ?? "user"; var choiceChunk = new ChatCompletionChoiceChunk { diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/ChatCompletions/AgentRunResponseExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/ChatCompletions/AgentResponseExtensions.cs similarity index 92% rename from dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/ChatCompletions/AgentRunResponseExtensions.cs rename to dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/ChatCompletions/AgentResponseExtensions.cs index f50aa44d4d..95d7df0231 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/ChatCompletions/AgentRunResponseExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/ChatCompletions/AgentResponseExtensions.cs @@ -12,33 +12,33 @@ namespace Microsoft.Agents.AI.Hosting.OpenAI.ChatCompletions; /// /// Extension methods for converting agent responses to ChatCompletion models. /// -internal static class AgentRunResponseExtensions +internal static class AgentResponseExtensions { - public static ChatCompletion ToChatCompletion(this AgentRunResponse agentRunResponse, CreateChatCompletion request) + public static ChatCompletion ToChatCompletion(this AgentResponse agentResponse, CreateChatCompletion request) { - IList choices = agentRunResponse.ToChoices(); + IList choices = agentResponse.ToChoices(); return new ChatCompletion { Id = IdGenerator.NewId(prefix: "chatcmpl", delimiter: "-", stringLength: 13), Choices = choices, - Created = (agentRunResponse.CreatedAt ?? DateTimeOffset.UtcNow).ToUnixTimeSeconds(), + Created = (agentResponse.CreatedAt ?? DateTimeOffset.UtcNow).ToUnixTimeSeconds(), Model = request.Model, - Usage = agentRunResponse.Usage.ToCompletionUsage(), + Usage = agentResponse.Usage.ToCompletionUsage(), ServiceTier = request.ServiceTier ?? 
"default" }; } - public static List ToChoices(this AgentRunResponse agentRunResponse) + public static List ToChoices(this AgentResponse agentResponse) { var chatCompletionChoices = new List(); var index = 0; - var finishReason = (agentRunResponse.RawRepresentation is ChatResponse { FinishReason: not null } chatResponse) + var finishReason = (agentResponse.RawRepresentation is ChatResponse { FinishReason: not null } chatResponse) ? chatResponse.FinishReason.ToString() : "stop"; // "stop" is a natural stop point; returning this by-default - foreach (var message in agentRunResponse.Messages) + foreach (var message in agentResponse.Messages) { foreach (var content in message.Contents) { diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/AgentRunResponseExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/AgentResponseExtensions.cs similarity index 96% rename from dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/AgentRunResponseExtensions.cs rename to dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/AgentResponseExtensions.cs index 97dcf9740f..2734fad427 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/AgentRunResponseExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/AgentResponseExtensions.cs @@ -13,19 +13,19 @@ namespace Microsoft.Agents.AI.Hosting.OpenAI.Responses; /// /// Extension methods for converting agent responses to Response models. /// -internal static class AgentRunResponseExtensions +internal static class AgentResponseExtensions { private static ChatRole s_DeveloperRole => new("developer"); /// - /// Converts an AgentRunResponse to a Response model. + /// Converts an AgentResponse to a Response model. /// - /// The agent run response to convert. + /// The agent response to convert. /// The original create response request. /// The agent invocation context. /// A Response model. public static Response ToResponse( - this AgentRunResponse agentRunResponse, + this AgentResponse agentResponse, CreateResponse request, AgentInvocationContext context) { @@ -41,7 +41,7 @@ public static Response ToResponse( }); } - output.AddRange(agentRunResponse.Messages + output.AddRange(agentResponse.Messages .SelectMany(msg => msg.ToItemResource(context.IdGenerator, context.JsonSerializerOptions))); return new Response @@ -49,7 +49,7 @@ public static Response ToResponse( Agent = request.Agent?.ToAgentId(), Background = request.Background, Conversation = request.Conversation ?? (context.ConversationId != null ? new ConversationReference { Id = context.ConversationId } : null), - CreatedAt = (agentRunResponse.CreatedAt ?? DateTimeOffset.UtcNow).ToUnixTimeSeconds(), + CreatedAt = (agentResponse.CreatedAt ?? DateTimeOffset.UtcNow).ToUnixTimeSeconds(), Error = null, Id = context.ResponseId, Instructions = request.Instructions, @@ -74,7 +74,7 @@ public static Response ToResponse( TopLogprobs = request.TopLogprobs, TopP = request.TopP ?? 
1.0, Truncation = request.Truncation, - Usage = agentRunResponse.Usage.ToResponseUsage(), + Usage = agentResponse.Usage.ToResponseUsage(), #pragma warning disable CS0618 // Type or member is obsolete User = request.User, #pragma warning restore CS0618 // Type or member is obsolete diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/AgentRunResponseUpdateExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/AgentResponseUpdateExtensions.cs similarity index 97% rename from dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/AgentRunResponseUpdateExtensions.cs rename to dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/AgentResponseUpdateExtensions.cs index 628b80b340..f4c1e3c7a0 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/AgentRunResponseUpdateExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/AgentResponseUpdateExtensions.cs @@ -16,12 +16,12 @@ namespace Microsoft.Agents.AI.Hosting.OpenAI.Responses; /// -/// Extension methods for . +/// Extension methods for . /// -internal static class AgentRunResponseUpdateExtensions +internal static class AgentResponseUpdateExtensions { /// - /// Converts a stream of to stream of . + /// Converts a stream of to stream of . /// /// The agent run response updates. /// The create response request. @@ -29,7 +29,7 @@ internal static class AgentRunResponseUpdateExtensions /// The cancellation token. /// A stream of response events. public static async IAsyncEnumerable ToStreamingResponseAsync( - this IAsyncEnumerable updates, + this IAsyncEnumerable updates, CreateResponse request, AgentInvocationContext context, [EnumeratorCancellation] CancellationToken cancellationToken = default) @@ -48,7 +48,7 @@ public static async IAsyncEnumerable ToStreamingResponse // Track active item IDs by executor ID to pair invoked/completed/failed events Dictionary executorItemIds = []; - AgentRunResponseUpdate? previousUpdate = null; + AgentResponseUpdate? previousUpdate = null; StreamingEventGenerator? generator = null; while (await updateEnumerator.MoveNextAsync().ConfigureAwait(false)) { @@ -279,7 +279,7 @@ Response CreateResponse(ResponseStatus status = ResponseStatus.Completed, IEnume } } - private static bool IsSameMessage(AgentRunResponseUpdate? first, AgentRunResponseUpdate? second) + private static bool IsSameMessage(AgentResponseUpdate? first, AgentResponseUpdate? second) { return IsSameValue(first?.MessageId, second?.MessageId) && IsSameValue(first?.AuthorName, second?.AuthorName) diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs index e12d017343..733a7af9a7 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs @@ -1,8 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. 
using System; -using System.Collections.Generic; -using Microsoft.Agents.AI.Hosting.Local; +using System.Linq; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Shared.Diagnostics; @@ -29,7 +28,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s return services.AddAIAgent(name, (sp, key) => { var chatClient = sp.GetRequiredService(); - var tools = GetRegisteredToolsForAgent(sp, name); + var tools = sp.GetKeyedServices(name).ToList(); return new ChatClientAgent(chatClient, instructions, key, tools: tools); }); } @@ -49,7 +48,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s Throw.IfNullOrEmpty(name); return services.AddAIAgent(name, (sp, key) => { - var tools = GetRegisteredToolsForAgent(sp, name); + var tools = sp.GetKeyedServices(name).ToList(); return new ChatClientAgent(chatClient, instructions, key, tools: tools); }); } @@ -70,7 +69,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s return services.AddAIAgent(name, (sp, key) => { var chatClient = chatClientServiceKey is null ? sp.GetRequiredService() : sp.GetRequiredKeyedService(chatClientServiceKey); - var tools = GetRegisteredToolsForAgent(sp, name); + var tools = sp.GetKeyedServices(name).ToList(); return new ChatClientAgent(chatClient, instructions, key, tools: tools); }); } @@ -92,7 +91,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s return services.AddAIAgent(name, (sp, key) => { var chatClient = chatClientServiceKey is null ? sp.GetRequiredService() : sp.GetRequiredKeyedService(chatClientServiceKey); - var tools = GetRegisteredToolsForAgent(sp, name); + var tools = sp.GetKeyedServices(name).ToList(); return new ChatClientAgent(chatClient, instructions: instructions, name: key, description: description, tools: tools); }); } @@ -127,10 +126,4 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s return new HostedAgentBuilder(name, services); } - - private static IList GetRegisteredToolsForAgent(IServiceProvider serviceProvider, string agentName) - { - var registry = serviceProvider.GetService(); - return registry?.GetTools(agentName) ?? []; - } } diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting/HostedAgentBuilderExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting/HostedAgentBuilderExtensions.cs index d3a437663a..e2c52ff9e0 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting/HostedAgentBuilderExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting/HostedAgentBuilderExtensions.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. 
using System; -using System.Linq; -using Microsoft.Agents.AI.Hosting.Local; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Shared.Diagnostics; @@ -70,18 +68,7 @@ public static IHostedAgentBuilder WithAITool(this IHostedAgentBuilder builder, A Throw.IfNull(builder); Throw.IfNull(tool); - var agentName = builder.Name; - var services = builder.ServiceCollection; - - // Get or create the agent tool registry - var descriptor = services.FirstOrDefault(sd => !sd.IsKeyedService && sd.ServiceType.Equals(typeof(LocalAgentToolRegistry))); - if (descriptor?.ImplementationInstance is not LocalAgentToolRegistry toolRegistry) - { - toolRegistry = new(); - services.Add(ServiceDescriptor.Singleton(toolRegistry)); - } - - toolRegistry.AddTool(agentName, tool); + builder.ServiceCollection.AddKeyedSingleton(builder.Name, tool); return builder; } @@ -105,4 +92,19 @@ public static IHostedAgentBuilder WithAITools(this IHostedAgentBuilder builder, return builder; } + + /// + /// Adds an AI tool to an agent being configured with the service collection. + /// + /// The hosted agent builder. + /// A factory function that creates an AI tool using the provided service provider. + public static IHostedAgentBuilder WithAITool(this IHostedAgentBuilder builder, Func factory) + { + Throw.IfNull(builder); + Throw.IfNull(factory); + + builder.ServiceCollection.AddKeyedSingleton(builder.Name, (sp, name) => factory(sp)); + + return builder; + } } diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting/Local/InMemoryAgentThreadStore.cs b/dotnet/src/Microsoft.Agents.AI.Hosting/Local/InMemoryAgentThreadStore.cs index 74bbe279fb..febbf4c06a 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting/Local/InMemoryAgentThreadStore.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting/Local/InMemoryAgentThreadStore.cs @@ -38,15 +38,15 @@ public override ValueTask SaveThreadAsync(AIAgent agent, string conversationId, } /// - public override ValueTask GetThreadAsync(AIAgent agent, string conversationId, CancellationToken cancellationToken = default) + public override async ValueTask GetThreadAsync(AIAgent agent, string conversationId, CancellationToken cancellationToken = default) { var key = GetKey(conversationId, agent.Id); JsonElement? threadContent = this._threads.TryGetValue(key, out var existingThread) ? existingThread : null; return threadContent switch { - null => new ValueTask(agent.GetNewThread()), - _ => new ValueTask(agent.DeserializeThread(threadContent.Value)), + null => await agent.GetNewThreadAsync(cancellationToken).ConfigureAwait(false), + _ => await agent.DeserializeThreadAsync(threadContent.Value, cancellationToken: cancellationToken).ConfigureAwait(false), }; } diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting/Local/LocalAgentToolRegistry.cs b/dotnet/src/Microsoft.Agents.AI.Hosting/Local/LocalAgentToolRegistry.cs deleted file mode 100644 index 8c87803db3..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Hosting/Local/LocalAgentToolRegistry.cs +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved.
-
-using System.Collections.Generic;
-using Microsoft.Extensions.AI;
-
-namespace Microsoft.Agents.AI.Hosting.Local;
-
-internal sealed class LocalAgentToolRegistry
-{
-    private readonly Dictionary<string, List<AITool>> _toolsByAgentName = [];
-
-    public void AddTool(string agentName, AITool tool)
-    {
-        if (!this._toolsByAgentName.TryGetValue(agentName, out var tools))
-        {
-            tools = [];
-            this._toolsByAgentName[agentName] = tools;
-        }
-
-        tools.Add(tool);
-    }
-
-    public IList<AITool> GetTools(string agentName)
-    {
-        return this._toolsByAgentName.TryGetValue(agentName, out var tools) ? tools : [];
-    }
-}
diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting/NoopAgentThreadStore.cs b/dotnet/src/Microsoft.Agents.AI.Hosting/NoopAgentThreadStore.cs
index c94489d0b0..02c78178a0 100644
--- a/dotnet/src/Microsoft.Agents.AI.Hosting/NoopAgentThreadStore.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Hosting/NoopAgentThreadStore.cs
@@ -20,6 +20,6 @@ public override ValueTask SaveThreadAsync(AIAgent agent, string conversationId,
     /// <inheritdoc/>
     public override ValueTask<AgentThread> GetThreadAsync(AIAgent agent, string conversationId, CancellationToken cancellationToken = default)
     {
-        return new ValueTask<AgentThread>(agent.GetNewThread());
+        return agent.GetNewThreadAsync(cancellationToken);
     }
 }
diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/ChatClient/AsyncStreamingChatCompletionUpdateCollectionResult.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/ChatClient/AsyncStreamingChatCompletionUpdateCollectionResult.cs
index 17c9c2d95a..db0c7a8673 100644
--- a/dotnet/src/Microsoft.Agents.AI.OpenAI/ChatClient/AsyncStreamingChatCompletionUpdateCollectionResult.cs
+++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/ChatClient/AsyncStreamingChatCompletionUpdateCollectionResult.cs
@@ -7,9 +7,9 @@ namespace Microsoft.Agents.AI.OpenAI;

 internal sealed class AsyncStreamingChatCompletionUpdateCollectionResult : AsyncCollectionResult
 {
-    private readonly IAsyncEnumerable<AgentRunResponseUpdate> _updates;
+    private readonly IAsyncEnumerable<AgentResponseUpdate> _updates;

-    internal AsyncStreamingChatCompletionUpdateCollectionResult(IAsyncEnumerable<AgentRunResponseUpdate> updates)
+    internal AsyncStreamingChatCompletionUpdateCollectionResult(IAsyncEnumerable<AgentResponseUpdate> updates)
     {
         this._updates = updates;
     }
@@ -23,7 +23,7 @@ public override async IAsyncEnumerable GetRawPagesAsync()

     protected override IAsyncEnumerable GetValuesFromPageAsync(ClientResult page)
     {
-        var updates = ((ClientResult<IAsyncEnumerable<AgentRunResponseUpdate>>)page).Value;
+        var updates = ((ClientResult<IAsyncEnumerable<AgentResponseUpdate>>)page).Value;

         return updates.AsChatResponseUpdatesAsync().AsOpenAIStreamingChatCompletionUpdatesAsync();
     }
diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/ChatClient/AsyncStreamingResponseUpdateCollectionResult.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/ChatClient/AsyncStreamingResponseUpdateCollectionResult.cs
index c67f4d1462..77400b3377 100644
--- a/dotnet/src/Microsoft.Agents.AI.OpenAI/ChatClient/AsyncStreamingResponseUpdateCollectionResult.cs
+++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/ChatClient/AsyncStreamingResponseUpdateCollectionResult.cs
@@ -7,9 +7,9 @@ namespace Microsoft.Agents.AI.OpenAI;

 internal sealed class AsyncStreamingResponseUpdateCollectionResult : AsyncCollectionResult
 {
-    private readonly IAsyncEnumerable<AgentRunResponseUpdate> _updates;
+    private readonly IAsyncEnumerable<AgentResponseUpdate> _updates;

-    internal AsyncStreamingResponseUpdateCollectionResult(IAsyncEnumerable<AgentRunResponseUpdate> updates)
+    internal AsyncStreamingResponseUpdateCollectionResult(IAsyncEnumerable<AgentResponseUpdate> updates)
     {
         this._updates = updates;
     }
@@ -23,7 +23,7 @@ public override async IAsyncEnumerable GetRawPagesAsync()

     protected async override IAsyncEnumerable GetValuesFromPageAsync(ClientResult page)
     {
-        var updates = ((ClientResult<IAsyncEnumerable<AgentRunResponseUpdate>>)page).Value;
+        var updates = ((ClientResult<IAsyncEnumerable<AgentResponseUpdate>>)page).Value;

         await foreach (var update in updates.ConfigureAwait(false))
         {
diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/ChatClient/StreamingUpdatePipelineResponse.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/ChatClient/StreamingUpdatePipelineResponse.cs
index e999ad04e7..3114464675 100644
--- a/dotnet/src/Microsoft.Agents.AI.OpenAI/ChatClient/StreamingUpdatePipelineResponse.cs
+++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/ChatClient/StreamingUpdatePipelineResponse.cs
@@ -55,7 +55,7 @@ public override void Dispose()
         // No resources to dispose.
     }

-    internal StreamingUpdatePipelineResponse(IAsyncEnumerable<AgentRunResponseUpdate> updates)
+    internal StreamingUpdatePipelineResponse(IAsyncEnumerable<AgentResponseUpdate> updates)
     {
     }
diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AIAgentWithOpenAIExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AIAgentWithOpenAIExtensions.cs
index d487ba00e1..defc934b30 100644
--- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AIAgentWithOpenAIExtensions.cs
+++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AIAgentWithOpenAIExtensions.cs
@@ -16,7 +16,7 @@ namespace Microsoft.Agents.AI;
 /// These extensions bridge the gap between the Microsoft Extensions AI framework and the OpenAI SDK,
 /// allowing developers to work with native OpenAI types while leveraging the AI Agent framework.
 /// The methods handle the conversion between OpenAI chat message types and Microsoft Extensions AI types,
-/// and return OpenAI objects directly from the agent's <see cref="AgentRunResponse"/>.
+/// and return OpenAI objects directly from the agent's <see cref="AgentResponse"/>.
 ///
 public static class AIAgentWithOpenAIExtensions
 {
@@ -34,7 +34,7 @@ public static class AIAgentWithOpenAIExtensions
     /// Thrown when any message in has a type that is not supported by the message conversion method.
     ///
     /// This method converts the OpenAI chat messages to the Microsoft Extensions AI format using the appropriate conversion method,
-    /// runs the agent with the converted message collection, and then extracts the native OpenAI from the response using .
+    /// runs the agent with the converted message collection, and then extracts the native OpenAI from the response using .
     ///
     public static async Task RunAsync(this AIAgent agent, IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
     {
@@ -60,14 +60,14 @@ public static async Task RunAsync(this AIAgent agent, IEnumerabl
     /// Thrown when the type is not supported by the message conversion method.
     ///
     /// This method converts the OpenAI chat messages to the Microsoft Extensions AI format using the appropriate conversion method,
-    /// runs the agent, and then extracts the native OpenAI from the response using .
+    /// runs the agent, and then extracts the native OpenAI from the response using .
     ///
     public static AsyncCollectionResult RunStreamingAsync(this AIAgent agent, IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
     {
         Throw.IfNull(agent);
         Throw.IfNull(messages);

-        IAsyncEnumerable<AgentRunResponseUpdate> response = agent.RunStreamingAsync([.. messages.AsChatMessages()], thread, options, cancellationToken);
+        IAsyncEnumerable<AgentResponseUpdate> response = agent.RunStreamingAsync([.. messages.AsChatMessages()], thread, options, cancellationToken);

         return new AsyncStreamingChatCompletionUpdateCollectionResult(response);
     }
@@ -86,7 +86,7 @@ public static AsyncCollectionResult RunStreamingA
     /// Thrown when any message in has a type that is not supported by the message conversion method.
     ///
     /// This method converts the OpenAI response items to the Microsoft Extensions AI format using the appropriate conversion method,
-    /// runs the agent with the converted message collection, and then extracts the native OpenAI from the response using .
+    /// runs the agent with the converted message collection, and then extracts the native OpenAI from the response using .
     ///
     public static async Task RunAsync(this AIAgent agent, IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
     {
@@ -121,7 +121,7 @@ public static AsyncCollectionResult RunStreamingAsync(t
         Throw.IfNull(agent);
         Throw.IfNull(messages);

-        IAsyncEnumerable<AgentRunResponseUpdate> response = agent.RunStreamingAsync([.. messages.AsChatMessages()], thread, options, cancellationToken);
+        IAsyncEnumerable<AgentResponseUpdate> response = agent.RunStreamingAsync([.. messages.AsChatMessages()], thread, options, cancellationToken);

         return new AsyncStreamingResponseUpdateCollectionResult(response);
     }
diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentRunResponseExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentResponseExtensions.cs
similarity index 79%
rename from dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentRunResponseExtensions.cs
rename to dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentResponseExtensions.cs
index 44844e64f5..e855aaef56 100644
--- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentRunResponseExtensions.cs
+++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentResponseExtensions.cs
@@ -8,18 +8,18 @@ namespace Microsoft.Agents.AI;

 /// <summary>
-/// Provides extension methods for <see cref="AgentRunResponse"/> and <see cref="AgentRunResponseUpdate"/> instances to
+/// Provides extension methods for <see cref="AgentResponse"/> and <see cref="AgentResponseUpdate"/> instances to
 /// create or extract native OpenAI response objects from the Microsoft Agent Framework responses.
 /// </summary>
-public static class AgentRunResponseExtensions
+public static class AgentResponseExtensions
 {
     /// <summary>
-    /// Creates or extracts a native OpenAI <see cref="ChatCompletion"/> object from an <see cref="AgentRunResponse"/>.
+    /// Creates or extracts a native OpenAI <see cref="ChatCompletion"/> object from an <see cref="AgentResponse"/>.
     /// </summary>
     /// <param name="response">The agent response.</param>
     /// <returns>The OpenAI <see cref="ChatCompletion"/> object.</returns>
     /// <exception cref="ArgumentNullException"><paramref name="response"/> is <see langword="null"/>.</exception>
-    public static ChatCompletion AsOpenAIChatCompletion(this AgentRunResponse response)
+    public static ChatCompletion AsOpenAIChatCompletion(this AgentResponse response)
     {
         Throw.IfNull(response);

@@ -29,12 +29,12 @@ response.RawRepresentation as ChatCompletion ??
     }

     /// <summary>
-    /// Creates or extracts a native OpenAI <see cref="ResponseResult"/> object from an <see cref="AgentRunResponse"/>.
+    /// Creates or extracts a native OpenAI <see cref="ResponseResult"/> object from an <see cref="AgentResponse"/>.
     /// </summary>
     /// <param name="response">The agent response.</param>
     /// <returns>The OpenAI <see cref="ResponseResult"/> object.</returns>
     /// <exception cref="ArgumentNullException"><paramref name="response"/> is <see langword="null"/>.</exception>
-    public static ResponseResult AsOpenAIResponse(this AgentRunResponse response)
+    public static ResponseResult AsOpenAIResponse(this AgentResponse response)
     {
         Throw.IfNull(response);
diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs
index 881266fe8b..291ff56091 100644
--- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs
+++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs
@@ -30,7 +30,7 @@ public static class OpenAIAssistantClientExtensions
     /// An optional to use for resolving services required by the instances being invoked.
     /// A instance that can be used to perform operations on the assistant.
     [Obsolete("The Assistants API has been deprecated. Please use the Responses API instead.")]
-    public static ChatClientAgent GetAIAgent(
+    public static ChatClientAgent AsAIAgent(
         this AssistantClient assistantClient,
         ClientResult assistantClientResult,
         ChatOptions? chatOptions = null,
@@ -42,7 +42,7 @@ public static ChatClientAgent GetAIAgent(
             throw new ArgumentNullException(nameof(assistantClientResult));
         }

-        return assistantClient.GetAIAgent(assistantClientResult.Value, chatOptions, clientFactory, services);
+        return assistantClient.AsAIAgent(assistantClientResult.Value, chatOptions, clientFactory, services);
     }

     ///
@@ -55,7 +55,7 @@ public static ChatClientAgent GetAIAgent(
     /// An optional to use for resolving services required by the instances being invoked.
     /// A instance that can be used to perform operations on the assistant.
     [Obsolete("The Assistants API has been deprecated. Please use the Responses API instead.")]
-    public static ChatClientAgent GetAIAgent(
+    public static ChatClientAgent AsAIAgent(
         this AssistantClient assistantClient,
         Assistant assistantMetadata,
         ChatOptions? chatOptions = null,
@@ -123,7 +123,7 @@ public static ChatClientAgent GetAIAgent(
         }

         var assistant = assistantClient.GetAssistant(agentId, cancellationToken);
-        return assistantClient.GetAIAgent(assistant, chatOptions, clientFactory, services);
+        return assistantClient.AsAIAgent(assistant, chatOptions, clientFactory, services);
     }

     ///
@@ -156,7 +156,7 @@ public static async Task GetAIAgentAsync(
         }

         var assistantResponse = await assistantClient.GetAssistantAsync(agentId, cancellationToken).ConfigureAwait(false);
-        return assistantClient.GetAIAgent(assistantResponse, chatOptions, clientFactory, services);
+        return assistantClient.AsAIAgent(assistantResponse, chatOptions, clientFactory, services);
     }

     ///
@@ -170,7 +170,7 @@ public static async Task GetAIAgentAsync(
     /// A instance that can be used to perform operations on the assistant.
     /// or is .
     [Obsolete("The Assistants API has been deprecated. Please use the Responses API instead.")]
-    public static ChatClientAgent GetAIAgent(
+    public static ChatClientAgent AsAIAgent(
         this AssistantClient assistantClient,
         ClientResult assistantClientResult,
         ChatClientAgentOptions options,
@@ -182,7 +182,7 @@ public static ChatClientAgent GetAIAgent(
             throw new ArgumentNullException(nameof(assistantClientResult));
         }

-        return assistantClient.GetAIAgent(assistantClientResult.Value, options, clientFactory, services);
+        return assistantClient.AsAIAgent(assistantClientResult.Value, options, clientFactory, services);
     }

     ///
@@ -196,7 +196,7 @@ public static ChatClientAgent GetAIAgent(
     /// A instance that can be used to perform operations on the assistant.
     /// or is .
[Obsolete("The Assistants API has been deprecated. Please use the Responses API instead.")] - public static ChatClientAgent GetAIAgent( + public static ChatClientAgent AsAIAgent( this AssistantClient assistantClient, Assistant assistantMetadata, ChatClientAgentOptions options, @@ -282,7 +282,7 @@ public static ChatClientAgent GetAIAgent( } var assistant = assistantClient.GetAssistant(agentId, cancellationToken); - return assistantClient.GetAIAgent(assistant, options, clientFactory, services); + return assistantClient.AsAIAgent(assistant, options, clientFactory, services); } /// @@ -322,7 +322,7 @@ public static async Task GetAIAgentAsync( } var assistantResponse = await assistantClient.GetAssistantAsync(agentId, cancellationToken).ConfigureAwait(false); - return assistantClient.GetAIAgent(assistantResponse, options, clientFactory, services); + return assistantClient.AsAIAgent(assistantResponse, options, clientFactory, services); } /// diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIChatClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIChatClientExtensions.cs index aa4f38e5f4..be3216083c 100644 --- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIChatClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIChatClientExtensions.cs @@ -32,7 +32,7 @@ public static class OpenAIChatClientExtensions /// An optional to use for resolving services required by the instances being invoked. /// An instance backed by the OpenAI Chat Completion service. /// Thrown when is . - public static ChatClientAgent CreateAIAgent( + public static ChatClientAgent AsAIAgent( this ChatClient client, string? instructions = null, string? name = null, @@ -41,7 +41,7 @@ public static ChatClientAgent CreateAIAgent( Func? clientFactory = null, ILoggerFactory? loggerFactory = null, IServiceProvider? services = null) => - client.CreateAIAgent( + client.AsAIAgent( new ChatClientAgentOptions() { Name = name, @@ -66,7 +66,7 @@ public static ChatClientAgent CreateAIAgent( /// An optional to use for resolving services required by the instances being invoked. /// An instance backed by the OpenAI Chat Completion service. /// Thrown when or is . - public static ChatClientAgent CreateAIAgent( + public static ChatClientAgent AsAIAgent( this ChatClient client, ChatClientAgentOptions options, Func? clientFactory = null, diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIResponseClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIResponseClientExtensions.cs index 224bf5db95..bc9f28c37d 100644 --- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIResponseClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIResponseClientExtensions.cs @@ -32,7 +32,7 @@ public static class OpenAIResponseClientExtensions /// An optional to use for resolving services required by the instances being invoked. /// An instance backed by the OpenAI Response service. /// Thrown when is . - public static ChatClientAgent CreateAIAgent( + public static ChatClientAgent AsAIAgent( this ResponsesClient client, string? instructions = null, string? name = null, @@ -44,7 +44,7 @@ public static ChatClientAgent CreateAIAgent( { Throw.IfNull(client); - return client.CreateAIAgent( + return client.AsAIAgent( new ChatClientAgentOptions() { Name = name, @@ -70,7 +70,7 @@ public static ChatClientAgent CreateAIAgent( /// An optional to use for resolving services required by the instances being invoked. 
     /// An instance backed by the OpenAI Response service.
     /// Thrown when or is .
-    public static ChatClientAgent CreateAIAgent(
+    public static ChatClientAgent AsAIAgent(
         this ResponsesClient client,
         ChatClientAgentOptions options,
         Func? clientFactory = null,
diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewAgent.cs b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewAgent.cs
index 6907fe8889..c30c089198 100644
--- a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewAgent.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewAgent.cs
@@ -30,28 +30,28 @@ public PurviewAgent(AIAgent innerAgent, PurviewWrapper purviewWrapper)
     }

     /// <inheritdoc/>
-    public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null)
+    public override ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
     {
-        return this._innerAgent.DeserializeThread(serializedThread, jsonSerializerOptions);
+        return this._innerAgent.DeserializeThreadAsync(serializedThread, jsonSerializerOptions, cancellationToken);
     }

     /// <inheritdoc/>
-    public override AgentThread GetNewThread()
+    public override ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default)
     {
-        return this._innerAgent.GetNewThread();
+        return this._innerAgent.GetNewThreadAsync(cancellationToken);
     }

     /// <inheritdoc/>
-    protected override Task<AgentRunResponse> RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
+    protected override Task<AgentResponse> RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
     {
         return this._purviewWrapper.ProcessAgentContentAsync(messages, thread, options, this._innerAgent, cancellationToken);
     }

     /// <inheritdoc/>
-    protected override async IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
         var response = await this._purviewWrapper.ProcessAgentContentAsync(messages, thread, options, this._innerAgent, cancellationToken).ConfigureAwait(false);
-        foreach (var update in response.ToAgentRunResponseUpdates())
+        foreach (var update in response.ToAgentResponseUpdates())
         {
             yield return update;
         }
diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewWrapper.cs b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewWrapper.cs
index a818fb264f..835b0146a8 100644
--- a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewWrapper.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewWrapper.cs
@@ -134,7 +134,7 @@ public async Task ProcessChatContentAsync(IEnumerable
     /// The wrapped agent.
     /// The cancellation token used to interrupt async operations.
     /// The agent's response. This could be the response from the agent or a message indicating that Purview has blocked the prompt or response.
-    public async Task<AgentRunResponse> ProcessAgentContentAsync(IEnumerable messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
+    public async Task<AgentResponse> ProcessAgentContentAsync(IEnumerable messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
     {
         string threadId = GetThreadIdFromAgentThread(thread, messages);
@@ -151,7 +151,7 @@ public async Task ProcessAgentContentAsync(IEnumerable
ProcessAgentContentAsync(IEnumerable
ProcessAgentContentAsync(IEnumerable
GetResponseItems()
     }

     /// <inheritdoc/>
-    public override async IAsyncEnumerable<AgentRunResponseUpdate> InvokeAgentAsync(
+    public override async IAsyncEnumerable<AgentResponseUpdate> InvokeAgentAsync(
         string agentId,
         string? agentVersion,
         string? conversationId,
@@ -120,12 +120,12 @@
         ChatClientAgentRunOptions runOptions = new(chatOptions);

-        IAsyncEnumerable<AgentRunResponseUpdate> agentResponse =
+        IAsyncEnumerable<AgentResponseUpdate> agentResponse =
             messages is not null
                 ? agent.RunStreamingAsync([.. messages], null, runOptions, cancellationToken)
                 : agent.RunStreamingAsync([new ChatMessage(ChatRole.User, string.Empty)], null, runOptions, cancellationToken);

-        await foreach (AgentRunResponseUpdate update in agentResponse.ConfigureAwait(false))
+        await foreach (AgentResponseUpdate update in agentResponse.ConfigureAwait(false))
         {
             update.AuthorName = agentVersionResult.Name;
             yield return update;
@@ -174,7 +174,7 @@ private async Task GetAgentAsync(AgentVersion agentVersion, Cancellatio

         AIProjectClient client = this.GetAgentClient();

-        agent = client.GetAIAgent(agentVersion, tools: null, clientFactory: null, services: null);
+        agent = client.AsAIAgent(agentVersion, tools: null, clientFactory: null, services: null);

         FunctionInvokingChatClient? functionInvokingClient = agent.GetService();
         if (functionInvokingClient is not null)
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/InvokeAzureAgentTemplate.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/InvokeAzureAgentTemplate.cs
index 3704d38b21..af8728f9c1 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/InvokeAzureAgentTemplate.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/InvokeAzureAgentTemplate.cs
@@ -64,7 +64,7 @@ public override string TransformText()
         EvaluateListExpression(this.Model.Input?.Messages, "inputMessages");

         this.Write(@"
-        AgentRunResponse agentResponse =
+        AgentResponse agentResponse =
             await InvokeAgentAsync(
                 context,
                 agentName,
@@ -75,7 +75,7 @@ await InvokeAgentAsync(

         if (autoSend)
         {
-            await context.AddEventAsync(new AgentRunResponseEvent(this.Id, agentResponse)).ConfigureAwait(false);
+            await context.AddEventAsync(new AgentResponseEvent(this.Id, agentResponse)).ConfigureAwait(false);
         }
 ");
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/InvokeAzureAgentTemplate.tt b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/InvokeAzureAgentTemplate.tt
index 48c4acb859..b4ca34174d 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/InvokeAzureAgentTemplate.tt
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/InvokeAzureAgentTemplate.tt
@@ -21,7 +21,7 @@ internal sealed class <#= this.Name #>Executor(FormulaSession session, WorkflowA
     EvaluateBoolExpression(this.Model.Output?.AutoSend, "autoSend", defaultValue: true);
     EvaluateListExpression(this.Model.Input?.Messages, "inputMessages");#>
-        AgentRunResponse agentResponse =
+        AgentResponse agentResponse =
             await InvokeAgentAsync(
                 context,
                 agentName,
@@ -32,7 +32,7 @@ internal sealed class <#= this.Name #>Executor(FormulaSession session, WorkflowA

         if (autoSend)
         {
-            await context.AddEventAsync(new AgentRunResponseEvent(this.Id, agentResponse)).ConfigureAwait(false);
+            await context.AddEventAsync(new AgentResponseEvent(this.Id, agentResponse)).ConfigureAwait(false);
         }
 <# AssignVariable(this.Messages, "agentResponse.Messages"); #>
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/SendActivityTemplate.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/SendActivityTemplate.cs
index 1d85f885b8..b7b6724069 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/SendActivityTemplate.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/SendActivityTemplate.cs
@@ -71,8 +71,8 @@ public override string TransformText()
         }

-        this.Write("\n        );\n        AgentRunResponse response = new([new ChatMessage(ChatRole" +
-                ".Assistant, activityText)]);\n        await context.AddEventAsync(new AgentRunRes" +
+        this.Write("\n        );\n        AgentResponse response = new([new ChatMessage(ChatRole" +
+                ".Assistant, activityText)]);\n        await context.AddEventAsync(new AgentRes" +
                 "ponseEvent(this.Id, response)).ConfigureAwait(false);");
     }
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/SendActivityTemplate.tt b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/SendActivityTemplate.tt
index a1f0e3191f..f11d2181b4 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/SendActivityTemplate.tt
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/CodeGen/SendActivityTemplate.tt
@@ -25,8 +25,8 @@ if (this.Model.Activity is MessageActivityTemplate messageActivity)
 }
 #>
         );
-        AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
-        await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false);<#
+        AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
+        await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false);<#
 }
 #>
         return default;
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Events/ExternalInputRequest.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Events/ExternalInputRequest.cs
index 8caf374b70..6cee3d308e 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Events/ExternalInputRequest.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Events/ExternalInputRequest.cs
@@ -13,21 +13,21 @@ public sealed class ExternalInputRequest
     /// <summary>
     /// The source message that triggered the request for external input.
     /// </summary>
-    public AgentRunResponse AgentResponse { get; }
+    public AgentResponse AgentResponse { get; }

     [JsonConstructor]
-    internal ExternalInputRequest(AgentRunResponse agentResponse)
+    internal ExternalInputRequest(AgentResponse agentResponse)
     {
         this.AgentResponse = agentResponse;
     }

     internal ExternalInputRequest(ChatMessage message)
     {
-        this.AgentResponse = new AgentRunResponse(message);
+        this.AgentResponse = new AgentResponse(message);
     }

     internal ExternalInputRequest(string text)
     {
-        this.AgentResponse = new AgentRunResponse(new ChatMessage(ChatRole.User, text));
+        this.AgentResponse = new AgentResponse(new ChatMessage(ChatRole.User, text));
     }
 }
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Extensions/AgentProviderExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Extensions/AgentProviderExtensions.cs
index 037665e8b8..19dd4aae75 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Extensions/AgentProviderExtensions.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Extensions/AgentProviderExtensions.cs
@@ -9,7 +9,7 @@ namespace Microsoft.Agents.AI.Workflows.Declarative.Extensions;

 internal static class AgentProviderExtensions
 {
-    public static async ValueTask<AgentRunResponse> InvokeAgentAsync(
+    public static async ValueTask<AgentResponse> InvokeAgentAsync(
         this WorkflowAgentProvider agentProvider,
         string executorId,
         IWorkflowContext context,
@@ -20,15 +20,15 @@ public static async ValueTask InvokeAgentAsync(
         IDictionary? inputArguments = null,
         CancellationToken cancellationToken = default)
     {
-        IAsyncEnumerable<AgentRunResponseUpdate> agentUpdates = agentProvider.InvokeAgentAsync(agentName, null, conversationId, inputMessages, inputArguments, cancellationToken);
+        IAsyncEnumerable<AgentResponseUpdate> agentUpdates = agentProvider.InvokeAgentAsync(agentName, null, conversationId, inputMessages, inputArguments, cancellationToken);

         // Enable "autoSend" behavior if this is the workflow conversation.
         bool isWorkflowConversation = context.IsWorkflowConversation(conversationId, out string? workflowConversationId);
         autoSend |= isWorkflowConversation;

         // Process the agent response updates.
-        List<AgentRunResponseUpdate> updates = [];
-        await foreach (AgentRunResponseUpdate update in agentUpdates.ConfigureAwait(false))
+        List<AgentResponseUpdate> updates = [];
+        await foreach (AgentResponseUpdate update in agentUpdates.ConfigureAwait(false))
         {
             await AssignConversationIdAsync(((ChatResponseUpdate?)update.RawRepresentation)?.ConversationId).ConfigureAwait(false);
@@ -36,15 +36,15 @@ public static async ValueTask InvokeAgentAsync(

             if (autoSend)
             {
-                await context.AddEventAsync(new AgentRunUpdateEvent(executorId, update), cancellationToken).ConfigureAwait(false);
+                await context.AddEventAsync(new AgentResponseUpdateEvent(executorId, update), cancellationToken).ConfigureAwait(false);
             }
         }

-        AgentRunResponse response = updates.ToAgentRunResponse();
+        AgentResponse response = updates.ToAgentResponse();

         if (autoSend)
         {
-            await context.AddEventAsync(new AgentRunResponseEvent(executorId, response), cancellationToken).ConfigureAwait(false);
+            await context.AddEventAsync(new AgentResponseEvent(executorId, response), cancellationToken).ConfigureAwait(false);
         }

         // If autoSend is enabled and this is not the workflow conversation, copy messages to the workflow conversation.
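The `AgentProviderExtensions` hunk above captures the consumption pattern that most of this patch migrates to: streamed `AgentResponseUpdate` values are observed one by one and then folded into a single `AgentResponse` via `ToAgentResponse()`. A minimal caller-side sketch of that pattern, assuming only the renamed `Microsoft.Agents.AI` surface visible in these hunks (the helper name `CollectAsync` is illustrative, not part of the patch):

```csharp
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Agents.AI;

internal static class StreamingPatternSketch
{
    // Streams an agent run and folds the updates into one AgentResponse,
    // mirroring what AgentProviderExtensions.InvokeAgentAsync does above.
    public static async Task<AgentResponse> CollectAsync(
        AIAgent agent, string prompt, CancellationToken cancellationToken = default)
    {
        List<AgentResponseUpdate> updates = [];

        await foreach (AgentResponseUpdate update in
            agent.RunStreamingAsync(prompt, cancellationToken: cancellationToken).ConfigureAwait(false))
        {
            updates.Add(update); // each update could also be surfaced as an event here
        }

        // ToAgentResponse() merges the accumulated updates into a single response.
        return updates.ToAgentResponse();
    }
}
```

The same accumulate-then-fold shape recurs in `AgentRunStreamingExecutor`, `HandoffAgentExecutor`, and `AIAgentHostExecutor` in the hunks that follow.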
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Kit/AgentExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Kit/AgentExecutor.cs
index 45a5b47bd8..9545b9d1ff 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Kit/AgentExecutor.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/Kit/AgentExecutor.cs
@@ -26,7 +26,7 @@ public abstract class AgentExecutor(string id, FormulaSession session, WorkflowA
     /// Optional messages to add to the conversation prior to invocation.
     /// A token that can be used to observe cancellation.
     ///
-    protected ValueTask<AgentRunResponse> InvokeAgentAsync(
+    protected ValueTask<AgentResponse> InvokeAgentAsync(
         IWorkflowContext context,
         string agentName,
         string? conversationId,
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/AddConversationMessageExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/AddConversationMessageExecutor.cs
index 632f462758..922a3f61b1 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/AddConversationMessageExecutor.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/AddConversationMessageExecutor.cs
@@ -30,7 +30,7 @@ internal sealed class AddConversationMessageExecutor(AddConversationMessage mode

         if (isWorkflowConversation)
         {
-            await context.AddEventAsync(new AgentRunResponseEvent(this.Id, new AgentRunResponse(newMessage)), cancellationToken).ConfigureAwait(false);
+            await context.AddEventAsync(new AgentResponseEvent(this.Id, new AgentResponse(newMessage)), cancellationToken).ConfigureAwait(false);
         }

         return default;
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/CopyConversationMessagesExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/CopyConversationMessagesExecutor.cs
index 01b1bab496..2d27408645 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/CopyConversationMessagesExecutor.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/CopyConversationMessagesExecutor.cs
@@ -33,7 +33,7 @@ internal sealed class CopyConversationMessagesExecutor(CopyConversationMessages

         if (isWorkflowConversation)
         {
-            await context.AddEventAsync(new AgentRunResponseEvent(this.Id, new AgentRunResponse([.. inputMessages])), cancellationToken).ConfigureAwait(false);
+            await context.AddEventAsync(new AgentResponseEvent(this.Id, new AgentResponse([.. inputMessages])), cancellationToken).ConfigureAwait(false);
         }
     }
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeAzureAgentExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeAzureAgentExecutor.cs
index ed069a9b78..0cd6fee77b 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeAzureAgentExecutor.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeAzureAgentExecutor.cs
@@ -61,12 +61,12 @@ private async ValueTask InvokeAgentAsync(IWorkflowContext context, IEnumerable?
         inputParameters = this.GetStructuredInputs();

-        AgentRunResponse agentResponse = await agentProvider.InvokeAgentAsync(this.Id, context, agentName, conversationId, autoSend, messages, inputParameters, cancellationToken).ConfigureAwait(false);
+        AgentResponse agentResponse = await agentProvider.InvokeAgentAsync(this.Id, context, agentName, conversationId, autoSend, messages, inputParameters, cancellationToken).ConfigureAwait(false);

         ChatMessage[] actionableMessages = FilterActionableContent(agentResponse).ToArray();
         if (actionableMessages.Length > 0)
         {
-            AgentRunResponse filteredResponse =
+            AgentResponse filteredResponse =
                 new(actionableMessages)
                 {
                     AdditionalProperties = agentResponse.AdditionalProperties,
@@ -137,7 +137,7 @@ private async ValueTask InvokeAgentAsync(IWorkflowContext context, IEnumerable

-    private static IEnumerable FilterActionableContent(AgentRunResponse agentResponse)
+    private static IEnumerable FilterActionableContent(AgentResponse agentResponse)
     {
         HashSet functionResultIds =
             [.. agentResponse.Messages
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/RequestExternalInputExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/RequestExternalInputExecutor.cs
index 2c35dc18e9..1b7a348ede 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/RequestExternalInputExecutor.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/RequestExternalInputExecutor.cs
@@ -26,7 +26,7 @@ public static class Steps

     protected override async ValueTask ExecuteAsync(IWorkflowContext context, CancellationToken cancellationToken = default)
     {
-        ExternalInputRequest inputRequest = new(new AgentRunResponse());
+        ExternalInputRequest inputRequest = new(new AgentResponse());

         await context.SendMessageAsync(inputRequest, cancellationToken).ConfigureAwait(false);
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/WorkflowAgentProvider.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/WorkflowAgentProvider.cs
index e0967ab376..cfd75d1ca4 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/WorkflowAgentProvider.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/WorkflowAgentProvider.cs
@@ -91,8 +91,8 @@ public abstract class WorkflowAgentProvider
     /// The messages to include in the invocation.
     /// Optional input arguments for agents that provide support.
     /// A token that propagates notification when operation should be canceled.
-    /// Asynchronous set of <see cref="AgentRunResponseUpdate"/>.
-    public abstract IAsyncEnumerable<AgentRunResponseUpdate> InvokeAgentAsync(
+    /// Asynchronous set of <see cref="AgentResponseUpdate"/>.
+    public abstract IAsyncEnumerable<AgentResponseUpdate> InvokeAgentAsync(
         string agentId,
         string? agentVersion,
         string? conversationId,
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs
index 9f9906270e..54653f306a 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs
@@ -8,7 +8,7 @@ namespace Microsoft.Agents.AI.Workflows;

 internal static class AIAgentsAbstractionsExtensions
 {
-    public static ChatMessage ToChatMessage(this AgentRunResponseUpdate update) =>
+    public static ChatMessage ToChatMessage(this AgentResponseUpdate update) =>
         new()
         {
             AuthorName = update.AuthorName,
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/AgentResponseEvent.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/AgentResponseEvent.cs
new file mode 100644
index 0000000000..a6c0b22525
--- /dev/null
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows/AgentResponseEvent.cs
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Agents.AI.Workflows;
+
+/// <summary>
+/// Represents an event triggered when an agent produces a response.
+/// </summary>
+public class AgentResponseEvent : ExecutorEvent
+{
+    /// <summary>
+    /// Initializes a new instance of the <see cref="AgentResponseEvent"/> class.
+    /// </summary>
+    /// <param name="executorId">The identifier of the executor that generated this event.</param>
+    /// <param name="response">The agent response.</param>
+    public AgentResponseEvent(string executorId, AgentResponse response) : base(executorId, data: response)
+    {
+        this.Response = Throw.IfNull(response);
+    }
+
+    /// <summary>
+    /// Gets the agent response.
+    /// </summary>
+    public AgentResponse Response { get; }
+}
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/AgentRunUpdateEvent.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/AgentResponseUpdateEvent.cs
similarity index 55%
rename from dotnet/src/Microsoft.Agents.AI.Workflows/AgentRunUpdateEvent.cs
rename to dotnet/src/Microsoft.Agents.AI.Workflows/AgentResponseUpdateEvent.cs
index 9fbf16b602..939e7a67e8 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows/AgentRunUpdateEvent.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows/AgentResponseUpdateEvent.cs
@@ -8,14 +8,14 @@ namespace Microsoft.Agents.AI.Workflows;

 /// <summary>
 /// Represents an event triggered when an agent run produces an update.
 /// </summary>
-public class AgentRunUpdateEvent : ExecutorEvent
+public class AgentResponseUpdateEvent : ExecutorEvent
 {
     /// <summary>
-    /// Initializes a new instance of the <see cref="AgentRunUpdateEvent"/> class.
+    /// Initializes a new instance of the <see cref="AgentResponseUpdateEvent"/> class.
     /// </summary>
     /// <param name="executorId">The identifier of the executor that generated this event.</param>
     /// <param name="update">The agent run response update.</param>
-    public AgentRunUpdateEvent(string executorId, AgentRunResponseUpdate update) : base(executorId, data: update)
+    public AgentResponseUpdateEvent(string executorId, AgentResponseUpdate update) : base(executorId, data: update)
     {
         this.Update = Throw.IfNull(update);
     }
@@ -23,15 +23,15 @@ public AgentRunUpdateEvent(string executorId, AgentRunResponseUpdate update) : b

     /// <summary>
     /// Gets the agent run response update.
     /// </summary>
-    public AgentRunResponseUpdate Update { get; }
+    public AgentResponseUpdate Update { get; }

     /// <summary>
-    /// Converts this event to an <see cref="AgentRunResponse"/> containing just this update.
+    /// Converts this event to an <see cref="AgentResponse"/> containing just this update.
     /// </summary>
     ///
-    public AgentRunResponse AsResponse()
+    public AgentResponse AsResponse()
     {
-        IEnumerable<AgentRunResponseUpdate> updates = [this.Update];
-        return updates.ToAgentRunResponse();
+        IEnumerable<AgentResponseUpdate> updates = [this.Update];
+        return updates.ToAgentResponse();
     }
 }
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/AgentRunResponseEvent.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/AgentRunResponseEvent.cs
deleted file mode 100644
index 3f0013a88c..0000000000
--- a/dotnet/src/Microsoft.Agents.AI.Workflows/AgentRunResponseEvent.cs
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright (c) Microsoft. All rights reserved.
-
-using Microsoft.Shared.Diagnostics;
-
-namespace Microsoft.Agents.AI.Workflows;
-
-/// <summary>
-/// Represents an event triggered when an agent run produces an update.
-/// </summary>
-public class AgentRunResponseEvent : ExecutorEvent
-{
-    /// <summary>
-    /// Initializes a new instance of the <see cref="AgentRunResponseEvent"/> class.
-    /// </summary>
-    /// <param name="executorId">The identifier of the executor that generated this event.</param>
-    /// <param name="response">The agent run response.</param>
-    public AgentRunResponseEvent(string executorId, AgentRunResponse response) : base(executorId, data: response)
-    {
-        this.Response = Throw.IfNull(response);
-    }
-
-    /// <summary>
-    /// Gets the agent run response.
-    /// </summary>
-    public AgentRunResponse Response { get; }
-}
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/MessageMerger.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/MessageMerger.cs
index 4560074dd2..de4a8b89f7 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows/MessageMerger.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows/MessageMerger.cs
@@ -14,10 +14,10 @@ private sealed class ResponseMergeState(string? responseId)
     {
         public string? ResponseId { get; } = responseId;

-        public Dictionary<string, List<AgentRunResponseUpdate>> UpdatesByMessageId { get; } = [];
-        public List<AgentRunResponseUpdate> DanglingUpdates { get; } = [];
+        public Dictionary<string, List<AgentResponseUpdate>> UpdatesByMessageId { get; } = [];
+        public List<AgentResponseUpdate> DanglingUpdates { get; } = [];

-        public void AddUpdate(AgentRunResponseUpdate update)
+        public void AddUpdate(AgentResponseUpdate update)
         {
             if (update.MessageId is null)
             {
@@ -25,7 +25,7 @@ public void AddUpdate(AgentRunResponseUpdate update)
             }
             else
             {
-                if (!this.UpdatesByMessageId.TryGetValue(update.MessageId, out List<AgentRunResponseUpdate>? updates))
+                if (!this.UpdatesByMessageId.TryGetValue(update.MessageId, out List<AgentResponseUpdate>? updates))
                 {
                     this.UpdatesByMessageId[update.MessageId] = updates = [];
                 }
@@ -34,24 +34,24 @@
             }
         }

-        public AgentRunResponse ComputeMerged(string messageId)
+        public AgentResponse ComputeMerged(string messageId)
         {
-            if (this.UpdatesByMessageId.TryGetValue(Throw.IfNull(messageId), out List<AgentRunResponseUpdate>? updates))
+            if (this.UpdatesByMessageId.TryGetValue(Throw.IfNull(messageId), out List<AgentResponseUpdate>? updates))
             {
-                return updates.ToAgentRunResponse();
+                return updates.ToAgentResponse();
             }

             throw new KeyNotFoundException($"No updates found for message ID '{messageId}' in response '{this.ResponseId}'.");
         }

-        public AgentRunResponse ComputeDangling()
+        public AgentResponse ComputeDangling()
         {
             if (this.DanglingUpdates.Count == 0)
             {
                 throw new InvalidOperationException("No dangling updates to compute a response from.");
             }

-            return this.DanglingUpdates.ToAgentRunResponse();
+            return this.DanglingUpdates.ToAgentResponse();
         }

         public List ComputeFlattened()
@@ -66,7 +66,7 @@ public List ComputeFlattened()

             IList AggregateUpdatesToMessage(string messageId)
             {
-                List<AgentRunResponseUpdate> updates = this.UpdatesByMessageId[messageId];
+                List<AgentResponseUpdate> updates = this.UpdatesByMessageId[messageId];
                 if (updates.Count == 0)
                 {
                     throw new InvalidOperationException($"No updates found for message ID '{messageId}' in response '{this.ResponseId}'.");
@@ -80,7 +80,7 @@ IList AggregateUpdatesToMessage(string messageId)
     private readonly Dictionary _mergeStates = [];
     private readonly ResponseMergeState _danglingState = new(null);

-    public void AddUpdate(AgentRunResponseUpdate update)
+    public void AddUpdate(AgentResponseUpdate update)
     {
         if (update.ResponseId is null)
         {
@@ -97,7 +97,7 @@ public void AddUpdate(AgentRunResponseUpdate update)
         }
     }

-    private int CompareByDateTimeOffset(AgentRunResponse left, AgentRunResponse right)
+    private int CompareByDateTimeOffset(AgentResponse left, AgentResponse right)
     {
         const int LESS = -1, EQ = 0, GREATER = 1;
@@ -119,17 +119,17 @@ private int CompareByDateTimeOffset(AgentRunResponse left, AgentRunResponse righ
         return left.CreatedAt.Value.CompareTo(right.CreatedAt.Value);
     }

-    public AgentRunResponse ComputeMerged(string primaryResponseId, string? primaryAgentId = null, string? primaryAgentName = null)
+    public AgentResponse ComputeMerged(string primaryResponseId, string? primaryAgentId = null, string? primaryAgentName = null)
     {
         List messages = [];
-        Dictionary<string, AgentRunResponse> responses = [];
+        Dictionary<string, AgentResponse> responses = [];
         HashSet agentIds = [];

         foreach (string responseId in this._mergeStates.Keys)
         {
             ResponseMergeState mergeState = this._mergeStates[responseId];

-            List<AgentRunResponse> responseList = mergeState.UpdatesByMessageId.Keys.Select(mergeState.ComputeMerged).ToList();
+            List<AgentResponse> responseList = mergeState.UpdatesByMessageId.Keys.Select(mergeState.ComputeMerged).ToList();
             if (mergeState.DanglingUpdates.Count > 0)
             {
                 responseList.Add(mergeState.ComputeDangling());
@@ -144,7 +144,7 @@
         AdditionalPropertiesDictionary? additionalProperties = null;
         HashSet createdTimes = [];

-        foreach (AgentRunResponse response in responses.Values)
+        foreach (AgentResponse response in responses.Values)
         {
             if (response.AgentId is not null)
             {
@@ -176,7 +176,7 @@
         }

         messages.RemoveAll(m => m.Contents.Count == 0);
-        return new AgentRunResponse(messages)
+        return new AgentResponse(messages)
         {
             ResponseId = primaryResponseId,
             AgentId = primaryAgentId
@@ -187,7 +187,7 @@
             AdditionalProperties = additionalProperties
         };

-        static AgentRunResponse MergeResponses(AgentRunResponse? current, AgentRunResponse incoming)
+        static AgentResponse MergeResponses(AgentResponse? current, AgentResponse incoming)
         {
             if (current is null)
             {
@@ -214,7 +214,7 @@ static AgentRunResponse MergeResponses(AgentRunResponse? current, AgentRunRespon
             };
         }

-        static IEnumerable GetMessagesWithCreatedAt(AgentRunResponse response)
+        static IEnumerable GetMessagesWithCreatedAt(AgentResponse response)
         {
             if (response.Messages.Count == 0)
             {
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs
index 0a887013a3..42217ce7bc 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs
@@ -20,8 +20,8 @@ public AIAgentHostExecutor(AIAgent agent, bool emitEvents = false) : base(id: ag
         this._emitEvents = emitEvents;
     }

-    private AgentThread EnsureThread(IWorkflowContext context) =>
-        this._thread ??= this._agent.GetNewThread();
+    private async Task<AgentThread> EnsureThreadAsync(IWorkflowContext context, CancellationToken cancellationToken) =>
+        this._thread ??= await this._agent.GetNewThreadAsync(cancellationToken).ConfigureAwait(false);

     private const string ThreadStateKey = nameof(_thread);
     protected internal override async ValueTask OnCheckpointingAsync(IWorkflowContext context, CancellationToken cancellationToken = default)
@@ -43,7 +43,7 @@ protected internal override async ValueTask OnCheckpointRestoredAsync(IWorkflowC
         JsonElement? threadValue = await context.ReadStateAsync(ThreadStateKey, cancellationToken: cancellationToken).ConfigureAwait(false);
         if (threadValue.HasValue)
         {
-            this._thread = this._agent.DeserializeThread(threadValue.Value);
+            this._thread = await this._agent.DeserializeThreadAsync(threadValue.Value, cancellationToken: cancellationToken).ConfigureAwait(false);
         }

         await base.OnCheckpointRestoredAsync(context, cancellationToken).ConfigureAwait(false);
@@ -54,13 +54,16 @@ protected override async ValueTask TakeTurnAsync(List messages, IWo
         if (emitEvents ?? this._emitEvents)
         {
             // Run the agent in streaming mode only when agent run update events are to be emitted.
-            IAsyncEnumerable<AgentRunResponseUpdate> agentStream = this._agent.RunStreamingAsync(messages, this.EnsureThread(context), cancellationToken: cancellationToken);
+            IAsyncEnumerable<AgentResponseUpdate> agentStream = this._agent.RunStreamingAsync(
+                messages,
+                await this.EnsureThreadAsync(context, cancellationToken).ConfigureAwait(false),
+                cancellationToken: cancellationToken);

-            List<AgentRunResponseUpdate> updates = [];
+            List<AgentResponseUpdate> updates = [];

-            await foreach (AgentRunResponseUpdate update in agentStream.ConfigureAwait(false))
+            await foreach (AgentResponseUpdate update in agentStream.ConfigureAwait(false))
             {
-                await context.AddEventAsync(new AgentRunUpdateEvent(this.Id, update), cancellationToken).ConfigureAwait(false);
+                await context.AddEventAsync(new AgentResponseUpdateEvent(this.Id, update), cancellationToken).ConfigureAwait(false);

                 // TODO: FunctionCall request handling, and user info request handling.
                 // In some sense: We should just let it be handled as a ChatMessage, though we should consider
@@ -69,12 +72,15 @@
                 updates.Add(update);
             }

-            await context.SendMessageAsync(updates.ToAgentRunResponse().Messages, cancellationToken: cancellationToken).ConfigureAwait(false);
+            await context.SendMessageAsync(updates.ToAgentResponse().Messages, cancellationToken: cancellationToken).ConfigureAwait(false);
         }
         else
         {
             // Otherwise, run the agent in non-streaming mode.
-            AgentRunResponse response = await this._agent.RunAsync(messages, this.EnsureThread(context), cancellationToken: cancellationToken).ConfigureAwait(false);
+            AgentResponse response = await this._agent.RunAsync(
+                messages,
+                await this.EnsureThreadAsync(context, cancellationToken).ConfigureAwait(false),
+                cancellationToken: cancellationToken).ConfigureAwait(false);

             await context.SendMessageAsync(response.Messages, cancellationToken: cancellationToken).ConfigureAwait(false);
         }
     }
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AgentRunStreamingExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AgentRunStreamingExecutor.cs
index ae3a932feb..e28870729c 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AgentRunStreamingExecutor.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AgentRunStreamingExecutor.cs
@@ -22,20 +22,20 @@ protected override async ValueTask TakeTurnAsync(List messages, IWo
     {
         List? roleChanged = messages.ChangeAssistantToUserForOtherParticipants(agent.Name ?? agent.Id);

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];
         await foreach (var update in agent.RunStreamingAsync(messages, cancellationToken: cancellationToken).ConfigureAwait(false))
         {
             updates.Add(update);
             if (emitEvents is true)
             {
-                await context.AddEventAsync(new AgentRunUpdateEvent(this.Id, update), cancellationToken).ConfigureAwait(false);
+                await context.AddEventAsync(new AgentResponseUpdateEvent(this.Id, update), cancellationToken).ConfigureAwait(false);
             }
         }

         roleChanged.ResetUserToAssistantForChangedRoles();

         List result = includeInputInOutput ? [.. messages] : [];
-        result.AddRange(updates.ToAgentRunResponse().Messages);
+        result.AddRange(updates.ToAgentResponse().Messages);

         await context.SendMessageAsync(result, cancellationToken: cancellationToken).ConfigureAwait(false);
     }
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs
index 24e0eea3cb..8c608090f3 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs
@@ -64,7 +64,7 @@ protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) =>
         routeBuilder.AddHandler(async (handoffState, context, cancellationToken) =>
         {
             string? requestedHandoff = null;
-            List<AgentRunResponseUpdate> updates = [];
+            List<AgentResponseUpdate> updates = [];

             List allMessages = handoffState.Messages;
             List? roleChanges = allMessages.ChangeAssistantToUserForOtherParticipants(this._agent.Name ?? this._agent.Id);
@@ -82,7 +82,7 @@
                 {
                     requestedHandoff = fcc.Name;
                     await AddUpdateAsync(
-                        new AgentRunResponseUpdate
+                        new AgentResponseUpdate
                         {
                             AgentId = this._agent.Id,
                             AuthorName = this._agent.Name ?? this._agent.Id,
@@ -98,18 +98,18 @@ await AddUpdateAsync(
                 }
             }

-            allMessages.AddRange(updates.ToAgentRunResponse().Messages);
+            allMessages.AddRange(updates.ToAgentResponse().Messages);

             roleChanges.ResetUserToAssistantForChangedRoles();

             await context.SendMessageAsync(new HandoffState(handoffState.TurnToken, requestedHandoff, allMessages), cancellationToken: cancellationToken).ConfigureAwait(false);

-            async Task AddUpdateAsync(AgentRunResponseUpdate update, CancellationToken cancellationToken)
+            async Task AddUpdateAsync(AgentResponseUpdate update, CancellationToken cancellationToken)
             {
                 updates.Add(update);
                 if (handoffState.TurnToken.EmitEvents is true)
                 {
-                    await context.AddEventAsync(new AgentRunUpdateEvent(this.Id, update), cancellationToken).ConfigureAwait(false);
+                    await context.AddEventAsync(new AgentResponseUpdateEvent(this.Id, update), cancellationToken).ConfigureAwait(false);
                 }
             }
         });
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs
index 4e5ee86070..f20660bc51 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs
@@ -63,14 +63,15 @@ private async ValueTask ValidateWorkflowAsync()
         protocol.ThrowIfNotChatProtocol();
     }

-    public override AgentThread GetNewThread() => new WorkflowThread(this._workflow, this.GenerateNewId(), this._executionEnvironment, this._checkpointManager, this._includeExceptionDetails);
+    public override ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default)
+        => new(new WorkflowThread(this._workflow, this.GenerateNewId(), this._executionEnvironment, this._checkpointManager, this._includeExceptionDetails));

-    public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null)
-        => new WorkflowThread(this._workflow, serializedThread, this._executionEnvironment, this._checkpointManager, this._includeExceptionDetails, jsonSerializerOptions);
+    public override ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
+        => new(new WorkflowThread(this._workflow, serializedThread, this._executionEnvironment, this._checkpointManager, this._includeExceptionDetails, jsonSerializerOptions));

-    private ValueTask<WorkflowThread> UpdateThreadAsync(IEnumerable messages, AgentThread? thread = null, CancellationToken cancellationToken = default)
+    private async ValueTask<WorkflowThread> UpdateThreadAsync(IEnumerable messages, AgentThread? thread = null, CancellationToken cancellationToken = default)
     {
-        thread ??= this.GetNewThread();
+        thread ??= await this.GetNewThreadAsync(cancellationToken).ConfigureAwait(false);

         if (thread is not WorkflowThread workflowThread)
         {
@@ -80,11 +81,11 @@
         // For workflow threads, messages are added directly via the internal AddMessages method
         // The MessageStore methods are used for agent invocation scenarios
         workflowThread.MessageStore.AddMessages(messages);
-        return new ValueTask<WorkflowThread>(workflowThread);
+        return workflowThread;
     }

     protected override async
-        Task<AgentRunResponse> RunCoreAsync(
+        Task<AgentResponse> RunCoreAsync(
         IEnumerable messages,
         AgentThread? thread = null,
         AgentRunOptions? options = null,
@@ -95,7 +96,7 @@ Task RunCoreAsync(
         WorkflowThread workflowThread = await this.UpdateThreadAsync(messages, thread, cancellationToken).ConfigureAwait(false);

         MessageMerger merger = new();
-        await foreach (AgentRunResponseUpdate update in workflowThread.InvokeStageAsync(cancellationToken)
+        await foreach (AgentResponseUpdate update in workflowThread.InvokeStageAsync(cancellationToken)
                            .ConfigureAwait(false)
                            .WithCancellation(cancellationToken))
         {
@@ -106,7 +107,7 @@
     }

     protected override async
-        IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(
+        IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(
         IEnumerable messages,
         AgentThread? thread = null,
         AgentRunOptions? options = null,
@@ -115,7 +116,7 @@
         await this.ValidateWorkflowAsync().ConfigureAwait(false);
         WorkflowThread workflowThread = await this.UpdateThreadAsync(messages, thread, cancellationToken).ConfigureAwait(false);

-        await foreach (AgentRunResponseUpdate update in workflowThread.InvokeStageAsync(cancellationToken)
+        await foreach (AgentResponseUpdate update in workflowThread.InvokeStageAsync(cancellationToken)
                            .ConfigureAwait(false)
                            .WithCancellation(cancellationToken))
         {
diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowThread.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowThread.cs
index 94144831e0..6f00566b95 100644
--- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowThread.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowThread.cs
@@ -83,11 +83,11 @@ public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptio
         return marshaller.Marshal(info);
     }

-    public AgentRunResponseUpdate CreateUpdate(string responseId, object raw, params AIContent[] parts)
+    public AgentResponseUpdate CreateUpdate(string responseId, object raw, params AIContent[] parts)
     {
         Throw.IfNullOrEmpty(parts);

-        AgentRunResponseUpdate update = new(ChatRole.Assistant, parts)
+        AgentResponseUpdate update = new(ChatRole.Assistant, parts)
         {
             CreatedAt = DateTimeOffset.UtcNow,
             MessageId = Guid.NewGuid().ToString("N"),
@@ -130,7 +130,7 @@ await this._executionEnvironment
     }

     internal async
-        IAsyncEnumerable<AgentRunResponseUpdate> InvokeStageAsync(
+        IAsyncEnumerable<AgentResponseUpdate> InvokeStageAsync(
         [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
         try
@@ -151,13 +151,13 @@
         {
             switch (evt)
             {
-                case AgentRunUpdateEvent agentUpdate:
+                case AgentResponseUpdateEvent agentUpdate:
                     yield return agentUpdate.Update;
                     break;

                 case RequestInfoEvent requestInfo:
                     FunctionCallContent fcContent = requestInfo.Request.ToFunctionCall();
-                    AgentRunResponseUpdate update = this.CreateUpdate(this.LastResponseId, evt, fcContent);
+                    AgentResponseUpdate update = this.CreateUpdate(this.LastResponseId, evt, fcContent);
                     yield return update;
                     break;
@@ -186,7 +186,7 @@
                 default:
                     // Emit all other workflow events for observability (DevUI, logging, etc.)
-                    yield return new AgentRunResponseUpdate(ChatRole.Assistant, [])
+                    yield return new AgentResponseUpdate(ChatRole.Assistant, [])
                     {
                         CreatedAt = DateTimeOffset.UtcNow,
                         MessageId = Guid.NewGuid().ToString("N"),
diff --git a/dotnet/src/Microsoft.Agents.AI/AIAgentBuilder.cs b/dotnet/src/Microsoft.Agents.AI/AIAgentBuilder.cs
index a74da3ed20..7d629c42b1 100644
--- a/dotnet/src/Microsoft.Agents.AI/AIAgentBuilder.cs
+++ b/dotnet/src/Microsoft.Agents.AI/AIAgentBuilder.cs
@@ -143,8 +143,8 @@ public AIAgentBuilder Use(Func, AgentThread?, AgentRunO
     ///
     /// Both and are .
diff --git a/dotnet/src/Microsoft.Agents.AI/AIAgentBuilder.cs b/dotnet/src/Microsoft.Agents.AI/AIAgentBuilder.cs
index a74da3ed20..7d629c42b1 100644
--- a/dotnet/src/Microsoft.Agents.AI/AIAgentBuilder.cs
+++ b/dotnet/src/Microsoft.Agents.AI/AIAgentBuilder.cs
@@ -143,8 +143,8 @@ public AIAgentBuilder Use(Func, AgentThread?, AgentRunO
    ///
    /// Both and are .
    public AIAgentBuilder Use(
-        Func, AgentThread?, AgentRunOptions?, AIAgent, CancellationToken, Task>? runFunc,
-        Func, AgentThread?, AgentRunOptions?, AIAgent, CancellationToken, IAsyncEnumerable>? runStreamingFunc)
+        Func, AgentThread?, AgentRunOptions?, AIAgent, CancellationToken, Task>? runFunc,
+        Func, AgentThread?, AgentRunOptions?, AIAgent, CancellationToken, IAsyncEnumerable>? runStreamingFunc)
    {
        AnonymousDelegatingAIAgent.ThrowIfBothDelegatesNull(runFunc, runStreamingFunc);
diff --git a/dotnet/src/Microsoft.Agents.AI/AnonymousDelegatingAIAgent.cs b/dotnet/src/Microsoft.Agents.AI/AnonymousDelegatingAIAgent.cs
index 542bafdbf4..48de303bc1 100644
--- a/dotnet/src/Microsoft.Agents.AI/AnonymousDelegatingAIAgent.cs
+++ b/dotnet/src/Microsoft.Agents.AI/AnonymousDelegatingAIAgent.cs
@@ -18,7 +18,7 @@ namespace Microsoft.Agents.AI;
internal sealed class AnonymousDelegatingAIAgent : DelegatingAIAgent
{
    /// The delegate to use as the implementation of .
-    private readonly Func, AgentThread?, AgentRunOptions?, AIAgent, CancellationToken, Task>? _runFunc;
+    private readonly Func, AgentThread?, AgentRunOptions?, AIAgent, CancellationToken, Task>? _runFunc;

    /// The delegate to use as the implementation of .
    ///
@@ -26,7 +26,7 @@ internal sealed class AnonymousDelegatingAIAgent : DelegatingAIAgent
    /// will be invoked with the same arguments as the method itself.
    /// When , will delegate directly to the inner agent.
    ///
-    private readonly Func, AgentThread?, AgentRunOptions?, AIAgent, CancellationToken, IAsyncEnumerable>? _runStreamingFunc;
+    private readonly Func, AgentThread?, AgentRunOptions?, AIAgent, CancellationToken, IAsyncEnumerable>? _runStreamingFunc;

    /// The delegate to use as the implementation of both and .
    private readonly Func, AgentThread?, AgentRunOptions?, Func, AgentThread?, AgentRunOptions?, CancellationToken, Task>, CancellationToken, Task>? _sharedFunc;
@@ -74,8 +74,8 @@ public AnonymousDelegatingAIAgent(
    /// Both and are .
    public AnonymousDelegatingAIAgent(
        AIAgent innerAgent,
-        Func, AgentThread?, AgentRunOptions?, AIAgent, CancellationToken, Task>? runFunc,
-        Func, AgentThread?, AgentRunOptions?, AIAgent, CancellationToken, IAsyncEnumerable>? runStreamingFunc)
+        Func, AgentThread?, AgentRunOptions?, AIAgent, CancellationToken, Task>? runFunc,
+        Func, AgentThread?, AgentRunOptions?, AIAgent, CancellationToken, IAsyncEnumerable>? runStreamingFunc)
        : base(innerAgent)
    {
        ThrowIfBothDelegatesNull(runFunc, runStreamingFunc);
@@ -85,7 +85,7 @@ public AnonymousDelegatingAIAgent(
    }

    ///
-    protected override Task RunCoreAsync(
+    protected override Task RunCoreAsync(
        IEnumerable messages,
        AgentThread? thread = null,
        AgentRunOptions? options = null,
@@ -97,10 +97,10 @@ protected override Task RunCoreAsync(
        {
            return GetRunViaSharedAsync(messages, thread, options, cancellationToken);

-            async Task GetRunViaSharedAsync(
+            async Task GetRunViaSharedAsync(
                IEnumerable messages, AgentThread? thread, AgentRunOptions? options, CancellationToken cancellationToken)
            {
-                AgentRunResponse? response = null;
+                AgentResponse? response = null;

                await this._sharedFunc(
                    messages,
@@ -113,7 +113,7 @@ await this._sharedFunc(
                if (response is null)
                {
-                    Throw.InvalidOperationException("The shared delegate completed successfully without producing an AgentRunResponse.");
+                    Throw.InvalidOperationException("The shared delegate completed successfully without producing an AgentResponse.");
                }

                return response;
@@ -127,12 +127,12 @@ await this._sharedFunc(
        {
            Debug.Assert(this._runStreamingFunc is not null, "Expected non-null streaming delegate.");
            return this._runStreamingFunc!(messages, thread, options, this.InnerAgent, cancellationToken)
-                .ToAgentRunResponseAsync(cancellationToken);
+                .ToAgentResponseAsync(cancellationToken);
        }
    }

    ///
-    protected override IAsyncEnumerable RunCoreStreamingAsync(
+    protected override IAsyncEnumerable RunCoreStreamingAsync(
        IEnumerable messages,
        AgentThread? thread = null,
        AgentRunOptions? options = null,
@@ -142,7 +142,7 @@ protected override IAsyncEnumerable RunCoreStreamingAsyn
        if (this._sharedFunc is not null)
        {
-            var updates = Channel.CreateBounded(1);
+            var updates = Channel.CreateBounded(1);

            _ = ProcessAsync();
            async Task ProcessAsync()
@@ -180,10 +180,10 @@ await this._sharedFunc(messages, thread, options, async (messages, thread, optio
        Debug.Assert(this._runFunc is not null, "Expected non-null non-streaming delegate.");
        return GetStreamingRunAsyncViaRunAsync(this._runFunc!(messages, thread, options, this.InnerAgent, cancellationToken));

-        static async IAsyncEnumerable GetStreamingRunAsyncViaRunAsync(Task task)
+        static async IAsyncEnumerable GetStreamingRunAsyncViaRunAsync(Task task)
        {
-            AgentRunResponse response = await task.ConfigureAwait(false);
-            foreach (var update in response.ToAgentRunResponseUpdates())
+            AgentResponse response = await task.ConfigureAwait(false);
+            foreach (var update in response.ToAgentResponseUpdates())
            {
                yield return update;
            }
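A hedged sketch of the renamed delegate shapes consumed above — the `innerAgent` instance and log strings are illustrative assumptions, not part of this diff:

```csharp
// The runFunc delegate now produces Task<AgentResponse> rather than
// Task<AgentRunResponse>.
AIAgent agent = new AIAgentBuilder(innerAgent)
    .Use(
        runFunc: async (messages, thread, options, inner, ct) =>
        {
            Console.WriteLine("before run");  // pre-processing
            AgentResponse response = await inner.RunAsync(messages, thread, options, ct);
            Console.WriteLine("after run");   // post-processing
            return response;
        },
        runStreamingFunc: null) // null: streaming is adapted from runFunc by the delegating agent
    .Build();
```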
diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs
index 0fa6473de0..4a42241b3c 100644
--- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs
+++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs
@@ -149,7 +149,7 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options,
    internal ChatOptions? ChatOptions => this._agentOptions?.ChatOptions;

    ///
-    protected override Task RunCoreAsync(
+    protected override Task RunCoreAsync(
        IEnumerable messages,
        AgentThread? thread = null,
        AgentRunOptions? options = null,
@@ -160,9 +160,9 @@ static Task GetResponseAsync(IChatClient chatClient, List
    ///
-    protected override async IAsyncEnumerable RunCoreStreamingAsync(
+    protected override async IAsyncEnumerable RunCoreStreamingAsync(
        IEnumerable messages,
        AgentThread? thread = null,
        AgentRunOptions? options = null,
@@ -283,7 +283,7 @@ protected override async IAsyncEnumerable RunCoreStreami
        // We can derive the type of supported thread from whether we have a conversation id,
        // so let's update it and set the conversation id for the service thread case.
-        this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId);
+        await this.UpdateThreadWithTypeAndConversationIdAsync(safeThread, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false);

        // To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request.
        await NotifyMessageStoreOfNewMessagesAsync(safeThread, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
@@ -302,19 +302,30 @@ protected override async IAsyncEnumerable RunCoreStreami
            : this.ChatClient.GetService(serviceType, serviceKey));

    ///
-    public override AgentThread GetNewThread()
-        => new ChatClientAgentThread
+    public override async ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default)
+    {
+        ChatMessageStore? messageStore = this._agentOptions?.ChatMessageStoreFactory is not null
+            ? await this._agentOptions.ChatMessageStoreFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false)
+            : null;
+
+        AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null
+            ? await this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false)
+            : null;
+
+        return new ChatClientAgentThread
        {
-            MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }),
-            AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null })
+            MessageStore = messageStore,
+            AIContextProvider = contextProvider
        };
+    }

    ///
    /// Creates a new agent thread instance using an existing conversation identifier to continue that conversation.
    ///
    /// The identifier of an existing conversation to continue.
+    /// The to monitor for cancellation requests.
    ///
-    /// A new instance configured to work with the specified conversation.
+    /// A value task representing the asynchronous operation. The task result contains a new instance configured to work with the specified conversation.
    ///
    ///
    ///
@@ -326,19 +337,26 @@ public override AgentThread GetNewThread()
    /// instances that support server-side conversation storage through their underlying .
    ///
    ///
-    public AgentThread GetNewThread(string conversationId)
-        => new ChatClientAgentThread()
+    public async ValueTask GetNewThreadAsync(string conversationId, CancellationToken cancellationToken = default)
+    {
+        AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null
+            ? await this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false)
+            : null;
+
+        return new ChatClientAgentThread()
        {
            ConversationId = conversationId,
-            AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null })
+            AIContextProvider = contextProvider
        };
+    }

    ///
    /// Creates a new agent thread instance using an existing to continue a conversation.
    ///
    /// The instance to use for managing the conversation's message history.
+    /// The to monitor for cancellation requests.
    ///
-    /// A new instance configured to work with the provided .
+    /// A value task representing the asynchronous operation. The task result contains a new instance configured to work with the provided .
    ///
    ///
    ///
@@ -347,48 +365,55 @@ public AgentThread GetNewThread(string conversationId)
    /// with a may not be compatible with these services.
    ///
    ///
-    /// Where a service requires server-side conversation storage, use .
+    /// Where a service requires server-side conversation storage, use .
    ///
    ///
    /// If the agent detects, during the first run, that the underlying AI service requires server-side conversation storage,
    /// the thread will throw an exception to indicate that it cannot continue using the provided .
    ///
    ///
-    public AgentThread GetNewThread(ChatMessageStore chatMessageStore)
-        => new ChatClientAgentThread()
+    public async ValueTask GetNewThreadAsync(ChatMessageStore chatMessageStore, CancellationToken cancellationToken = default)
+    {
+        AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null
+            ? await this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false)
+            : null;
+
+        return new ChatClientAgentThread()
        {
            MessageStore = Throw.IfNull(chatMessageStore),
-            AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null })
+            AIContextProvider = contextProvider
        };
+    }

    ///
-    public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null)
+    public override async ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
    {
-        Func? chatMessageStoreFactory = this._agentOptions?.ChatMessageStoreFactory is null ?
+        Func>? chatMessageStoreFactory = this._agentOptions?.ChatMessageStoreFactory is null ?
            null :
-            (jse, jso) => this._agentOptions.ChatMessageStoreFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso });
+            (jse, jso, ct) => this._agentOptions.ChatMessageStoreFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso }, ct);

-        Func? aiContextProviderFactory = this._agentOptions?.AIContextProviderFactory is null ?
+        Func>? aiContextProviderFactory = this._agentOptions?.AIContextProviderFactory is null ?
            null :
-            (jse, jso) => this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso });
+            (jse, jso, ct) => this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso }, ct);

-        return new ChatClientAgentThread(
+        return await ChatClientAgentThread.DeserializeAsync(
            serializedThread,
            jsonSerializerOptions,
            chatMessageStoreFactory,
-            aiContextProviderFactory);
+            aiContextProviderFactory,
+            cancellationToken).ConfigureAwait(false);
    }

    #region Private

-    private async Task RunCoreAsync(
+    private async Task RunCoreAsync(
        Func, ChatOptions?, CancellationToken, Task> chatClientRunFunc,
-        Func agentResponseFactoryFunc,
+        Func agentResponseFactoryFunc,
        IEnumerable messages,
        AgentThread? thread = null,
        AgentRunOptions? options = null,
        CancellationToken cancellationToken = default)
-        where TAgentRunResponse : AgentRunResponse
+        where TAgentResponse : AgentResponse
        where TChatClientResponse : ChatResponse
    {
        var inputMessages = Throw.IfNull(messages) as IReadOnlyCollection ?? messages.ToList();
@@ -426,7 +451,7 @@ private async Task RunCoreAsync
        0 })
+        {
+            chatOptions ??= new ChatOptions();
+            chatOptions.AdditionalProperties ??= new();
+            foreach (var kvp in agentRunOptions.AdditionalProperties)
+            {
+                chatOptions.AdditionalProperties[kvp.Key] = kvp.Value;
+            }
+        }
+
        return (chatOptions, agentContinuationToken);
    }
}
@@ -653,7 +689,7 @@ private async Task
            throw new InvalidOperationException("A thread must be provided when continuing a background response with a continuation token.");
        }

-        thread ??= this.GetNewThread();
+        thread ??= await this.GetNewThreadAsync(cancellationToken).ConfigureAwait(false);
        if (thread is not ChatClientAgentThread typedThread)
        {
            throw new InvalidOperationException("The provided thread is not compatible with the agent. Only threads created by the agent can be used.");
        }
@@ -667,7 +703,7 @@ private async Task
        List inputMessagesForChatClient = [];
        IList? aiContextProviderMessages = null;
-        IList? chatMessageStoreMessages = null;
+        IList? chatMessageStoreMessages = [];

        // Populate the thread messages only if we are not continuing an existing response as it's not allowed
        if (chatOptions?.ContinuationToken is null)
        {
@@ -735,13 +771,13 @@ private async Task
        return (typedThread, chatOptions, inputMessagesForChatClient, aiContextProviderMessages, chatMessageStoreMessages, continuationToken);
    }

-    private void UpdateThreadWithTypeAndConversationId(ChatClientAgentThread thread, string? responseConversationId)
+    private async Task UpdateThreadWithTypeAndConversationIdAsync(ChatClientAgentThread thread, string? responseConversationId, CancellationToken cancellationToken)
    {
        if (string.IsNullOrWhiteSpace(responseConversationId) && !string.IsNullOrWhiteSpace(thread.ConversationId))
        {
-            // We were passed a thread that is service managed, but we got no conversation id back from the chat client,
-            // meaning the service doesn't support service managed threads, so the thread cannot be used with this service.
-            throw new InvalidOperationException("Service did not return a valid conversation id when using a service managed thread.");
+            // We were passed an AgentThread that has an id for service managed chat history, but we got no conversation id back from the chat client,
+            // meaning the service doesn't support service managed chat history, so the thread cannot be used with this service.
+            throw new InvalidOperationException("Service did not return a valid conversation id when using an AgentThread with service managed chat history.");
        }

        if (!string.IsNullOrWhiteSpace(responseConversationId))
        {
@@ -752,10 +788,12 @@ private void UpdateThreadWithTypeAndConversationId(ChatClientAgentThread thread,
        }
        else
        {
-            // If the service doesn't use service side thread storage (i.e. we got no id back from invocation), and
+            // If the service doesn't use service side chat history storage (i.e. we got no id back from invocation), and
            // the thread has no MessageStore yet, we should update the thread with the custom MessageStore or
            // default InMemoryMessageStore so that it has somewhere to store the chat history.
-            thread.MessageStore ??= this._agentOptions?.ChatMessageStoreFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }) ?? new InMemoryChatMessageStore();
+            thread.MessageStore ??= this._agentOptions?.ChatMessageStoreFactory is not null
+                ? await this._agentOptions.ChatMessageStoreFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false)
+                : new InMemoryChatMessageStore();
        }
    }
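Call sites for thread creation change shape accordingly; a minimal before/after sketch, assuming a `ChatClientAgent` instance named `agent`:

```csharp
// Before (synchronous factories):
// AgentThread thread = agent.GetNewThread();

// After (factories can be awaited, so creation is async):
AgentThread thread = await agent.GetNewThreadAsync();

// The conversation-id overload follows the same pattern:
AgentThread serviceThread = await agent.GetNewThreadAsync("existing-conversation-id");
```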
diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentCustomOptions.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentCustomOptions.cs
index b0cbd3d793..c5502ad916 100644
--- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentCustomOptions.cs
+++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentCustomOptions.cs
@@ -22,8 +22,8 @@ public partial class ChatClientAgent
    ///
    /// Configuration parameters for controlling the agent's invocation behavior.
    /// The to monitor for cancellation requests. The default is .
-    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
-    public Task RunAsync(
+    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
+    public Task RunAsync(
        AgentThread? thread,
        ChatClientAgentRunOptions? options,
        CancellationToken cancellationToken = default) =>
@@ -39,8 +39,8 @@ public Task RunAsync(
    ///
    /// Configuration parameters for controlling the agent's invocation behavior.
    /// The to monitor for cancellation requests. The default is .
-    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
-    public Task RunAsync(
+    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
+    public Task RunAsync(
        string message,
        AgentThread? thread,
        ChatClientAgentRunOptions? options,
@@ -57,8 +57,8 @@ public Task RunAsync(
    ///
    /// Configuration parameters for controlling the agent's invocation behavior.
    /// The to monitor for cancellation requests. The default is .
-    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
-    public Task RunAsync(
+    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
+    public Task RunAsync(
        ChatMessage message,
        AgentThread? thread,
        ChatClientAgentRunOptions? options,
@@ -75,8 +75,8 @@ public Task RunAsync(
    ///
    /// Configuration parameters for controlling the agent's invocation behavior.
    /// The to monitor for cancellation requests. The default is .
-    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
-    public Task RunAsync(
+    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
+    public Task RunAsync(
        IEnumerable messages,
        AgentThread? thread,
        ChatClientAgentRunOptions? options,
@@ -92,8 +92,8 @@ public Task RunAsync(
    ///
    /// Configuration parameters for controlling the agent's invocation behavior.
    /// The to monitor for cancellation requests. The default is .
-    /// An asynchronous enumerable of instances representing the streaming response.
-    public IAsyncEnumerable RunStreamingAsync(
+    /// An asynchronous enumerable of instances representing the streaming response.
+    public IAsyncEnumerable RunStreamingAsync(
        AgentThread? thread,
        ChatClientAgentRunOptions? options,
        CancellationToken cancellationToken = default) =>
@@ -109,8 +109,8 @@ public IAsyncEnumerable RunStreamingAsync(
    ///
    /// Configuration parameters for controlling the agent's invocation behavior.
    /// The to monitor for cancellation requests. The default is .
-    /// An asynchronous enumerable of instances representing the streaming response.
-    public IAsyncEnumerable RunStreamingAsync(
+    /// An asynchronous enumerable of instances representing the streaming response.
+    public IAsyncEnumerable RunStreamingAsync(
        string message,
        AgentThread? thread,
        ChatClientAgentRunOptions? options,
@@ -127,8 +127,8 @@ public IAsyncEnumerable RunStreamingAsync(
    ///
    /// Configuration parameters for controlling the agent's invocation behavior.
    /// The to monitor for cancellation requests. The default is .
-    /// An asynchronous enumerable of instances representing the streaming response.
-    public IAsyncEnumerable RunStreamingAsync(
+    /// An asynchronous enumerable of instances representing the streaming response.
+    public IAsyncEnumerable RunStreamingAsync(
        ChatMessage message,
        AgentThread? thread,
        ChatClientAgentRunOptions? options,
@@ -145,8 +145,8 @@ public IAsyncEnumerable RunStreamingAsync(
    ///
    /// Configuration parameters for controlling the agent's invocation behavior.
    /// The to monitor for cancellation requests. The default is .
-    /// An asynchronous enumerable of instances representing the streaming response.
-    public IAsyncEnumerable RunStreamingAsync(
+    /// An asynchronous enumerable of instances representing the streaming response.
+    public IAsyncEnumerable RunStreamingAsync(
        IEnumerable messages,
        AgentThread? thread,
        ChatClientAgentRunOptions? options,
@@ -167,8 +167,8 @@ public IAsyncEnumerable RunStreamingAsync(
    /// Using a JSON schema improves reliability if the underlying model supports native structured output with a schema, but might cause an error if the model does not support it.
    ///
    /// The to monitor for cancellation requests. The default is .
-    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
-    public Task> RunAsync(
+    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
+    public Task> RunAsync(
        AgentThread? thread,
        JsonSerializerOptions? serializerOptions,
        ChatClientAgentRunOptions? options,
@@ -191,8 +191,8 @@ public Task> RunAsync(
    /// Using a JSON schema improves reliability if the underlying model supports native structured output with a schema, but might cause an error if the model does not support it.
    ///
    /// The to monitor for cancellation requests. The default is .
-    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
-    public Task> RunAsync(
+    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
+    public Task> RunAsync(
        string message,
        AgentThread? thread,
        JsonSerializerOptions? serializerOptions,
@@ -216,8 +216,8 @@ public Task> RunAsync(
    /// Using a JSON schema improves reliability if the underlying model supports native structured output with a schema, but might cause an error if the model does not support it.
    ///
    /// The to monitor for cancellation requests. The default is .
-    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
-    public Task> RunAsync(
+    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
+    public Task> RunAsync(
        ChatMessage message,
        AgentThread? thread,
        JsonSerializerOptions? serializerOptions,
@@ -241,8 +241,8 @@ public Task> RunAsync(
    /// Using a JSON schema improves reliability if the underlying model supports native structured output with a schema, but might cause an error if the model does not support it.
    ///
    /// The to monitor for cancellation requests. The default is .
-    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
-    public Task> RunAsync(
+    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
+    public Task> RunAsync(
        IEnumerable messages,
        AgentThread? thread,
        JsonSerializerOptions? serializerOptions,
diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs
index dd1ff3b228..719e863f0c 100644
--- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs
+++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs
@@ -2,6 +2,8 @@

 using System;
 using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
 using Microsoft.Extensions.AI;

 namespace Microsoft.Agents.AI;
@@ -40,14 +42,14 @@ public sealed class ChatClientAgentOptions
    /// Gets or sets a factory function to create an instance of
    /// which will be used to store chat messages for this agent.
    ///
-    public Func? ChatMessageStoreFactory { get; set; }
+    public Func>? ChatMessageStoreFactory { get; set; }

    ///
    /// Gets or sets a factory function to create an instance of
    /// which will be used to create a context provider for each new thread, and can then
    /// provide additional context for each agent run.
    ///
-    public Func? AIContextProviderFactory { get; set; }
+    public Func>? AIContextProviderFactory { get; set; }

    ///
    /// Gets or sets a value indicating whether to use the provided instance as is,
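A hedged sketch of wiring the now-asynchronous factories — `LoadStoreAsync` is a hypothetical helper, not part of this diff:

```csharp
var agentOptions = new ChatClientAgentOptions
{
    // The factory now receives a CancellationToken and may await I/O
    // (e.g., opening a database-backed message store).
    ChatMessageStoreFactory = async (ctx, ct) =>
        await LoadStoreAsync(ctx.SerializedState, ctx.JsonSerializerOptions, ct)
};
```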
diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentRunResponse{T}.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentRunResponse{T}.cs
index 352be764eb..a4fadff0c7 100644
--- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentRunResponse{T}.cs
+++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentRunResponse{T}.cs
@@ -12,23 +12,23 @@ namespace Microsoft.Agents.AI;
/// The type of value expected from the chat response.
///
/// Language models are not guaranteed to honor the requested schema. If the model's output is not
-/// parsable as the expected type, you can access the underlying JSON response on the property.
+/// parsable as the expected type, you can access the underlying JSON response on the property.
///
-public sealed class ChatClientAgentRunResponse : AgentRunResponse
+public sealed class ChatClientAgentResponse : AgentResponse
{
    private readonly ChatResponse _response;

    ///
-    /// Initializes a new instance of the class from an existing .
+    /// Initializes a new instance of the class from an existing .
    ///
-    /// The from which to populate this .
+    /// The from which to populate this .
    /// is .
    ///
    /// This constructor creates an agent response that wraps an existing , preserving all
    /// metadata and storing the original response in for access to
    /// the underlying implementation details.
    ///
-    public ChatClientAgentRunResponse(ChatResponse response) : base(response)
+    public ChatClientAgentResponse(ChatResponse response) : base(response)
    {
        _ = Throw.IfNull(response);
diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentStructuredOutput.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentStructuredOutput.cs
index 9a535cd645..6bd62e85a2 100644
--- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentStructuredOutput.cs
+++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentStructuredOutput.cs
@@ -29,12 +29,12 @@ public sealed partial class ChatClientAgent
    /// Using a JSON schema improves reliability if the underlying model supports native structured output with a schema, but might cause an error if the model does not support it.
    ///
    /// The to monitor for cancellation requests. The default is .
-    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
+    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
    ///
    /// This overload is useful when the agent has sufficient context from previous messages in the thread
    /// or from its initial configuration to generate a meaningful response without additional input.
    ///
-    public Task> RunAsync(
+    public Task> RunAsync(
        AgentThread? thread = null,
        JsonSerializerOptions? serializerOptions = null,
        AgentRunOptions? options = null,
@@ -57,13 +57,13 @@ public Task> RunAsync(
    /// Using a JSON schema improves reliability if the underlying model supports native structured output with a schema, but might cause an error if the model does not support it.
    ///
    /// The to monitor for cancellation requests. The default is .
-    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
+    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
    /// is , empty, or contains only whitespace.
    ///
    /// The provided text will be wrapped in a with the role
    /// before being sent to the agent. This is a convenience method for simple text-based interactions.
    ///
-    public Task> RunAsync(
+    public Task> RunAsync(
        string message,
        AgentThread? thread = null,
        JsonSerializerOptions? serializerOptions = null,
@@ -91,9 +91,9 @@ public Task> RunAsync(
    /// Using a JSON schema improves reliability if the underlying model supports native structured output with a schema, but might cause an error if the model does not support it.
    ///
    /// The to monitor for cancellation requests. The default is .
-    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
+    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
    /// is .
-    public Task> RunAsync(
+    public Task> RunAsync(
        ChatMessage message,
        AgentThread? thread = null,
        JsonSerializerOptions? serializerOptions = null,
@@ -121,7 +121,7 @@ public Task> RunAsync(
    /// Using a JSON schema improves reliability if the underlying model supports native structured output with a schema, but might cause an error if the model does not support it.
    ///
    /// The to monitor for cancellation requests. The default is .
-    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
+    /// A task that represents the asynchronous operation. The task result contains an with the agent's output.
    /// The type of structured output to request.
    ///
    ///
@@ -134,7 +134,7 @@ public Task> RunAsync(
    /// The agent's response will also be added to if one is provided.
    ///
    ///
-    public Task> RunAsync(
+    public Task> RunAsync(
        IEnumerable messages,
        AgentThread? thread = null,
        JsonSerializerOptions? serializerOptions = null,
@@ -152,9 +152,9 @@ async Task> GetResponseAsync(IChatClient chatClient, List CreateResponse(ChatResponse chatResponse)
+        static ChatClientAgentResponse CreateResponse(ChatResponse chatResponse)
        {
-            return new ChatClientAgentRunResponse(chatResponse)
+            return new ChatClientAgentResponse(chatResponse)
            {
                ContinuationToken = WrapContinuationToken(chatResponse.ContinuationToken)
            };
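Structured-output call sites change only in the response type name; a sketch assuming `agent`, `thread`, and a `PersonInfo` record (all hypothetical):

```csharp
// ChatClientAgentResponse<T> was previously named ChatClientAgentRunResponse<T>.
ChatClientAgentResponse<PersonInfo> response =
    await agent.RunAsync<PersonInfo>("Extract: Jane Doe, 34, Dublin", thread);
// On schema mismatch, the raw JSON remains reachable via the wrapped ChatResponse.
```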
diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentThread.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentThread.cs
index f4cf4aa033..06326d1ed2 100644
--- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentThread.cs
+++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentThread.cs
@@ -3,6 +3,8 @@
 using System;
 using System.Diagnostics;
 using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
 using Microsoft.Shared.Diagnostics;

 namespace Microsoft.Agents.AI;
@@ -22,48 +24,6 @@ internal ChatClientAgentThread()
    {
    }

-    ///
-    /// Initializes a new instance of the class from previously serialized state.
-    ///
-    /// A representing the serialized state of the thread.
-    /// Optional settings for customizing the JSON deserialization process.
-    ///
-    /// An optional factory function to create a custom from its serialized state.
-    /// If not provided, the default in-memory message store will be used.
-    ///
-    ///
-    /// An optional factory function to create a custom from its serialized state.
-    /// If not provided, no context provider will be configured.
-    ///
-    internal ChatClientAgentThread(
-        JsonElement serializedThreadState,
-        JsonSerializerOptions? jsonSerializerOptions = null,
-        Func? chatMessageStoreFactory = null,
-        Func? aiContextProviderFactory = null)
-    {
-        if (serializedThreadState.ValueKind != JsonValueKind.Object)
-        {
-            throw new ArgumentException("The serialized thread state must be a JSON object.", nameof(serializedThreadState));
-        }
-
-        var state = serializedThreadState.Deserialize(
-            AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ThreadState))) as ThreadState;
-
-        this.AIContextProvider = aiContextProviderFactory?.Invoke(state?.AIContextProviderState ?? default, jsonSerializerOptions);
-
-        if (state?.ConversationId is string threadId)
-        {
-            this.ConversationId = threadId;
-
-            // Since we have an ID, we should not have a chat message store and we can return here.
-            return;
-        }
-
-        this._messageStore =
-            chatMessageStoreFactory?.Invoke(state?.StoreState ?? default, jsonSerializerOptions) ??
-            new InMemoryChatMessageStore(state?.StoreState ?? default, jsonSerializerOptions); // default to an in-memory store
-    }
-
    ///
    /// Gets or sets the ID of the underlying service thread to support cases where the chat history is stored by the agent service.
    ///
@@ -152,6 +112,58 @@ internal set
    ///
    public AIContextProvider? AIContextProvider { get; internal set; }

+    ///
+    /// Creates a new instance of the class from previously serialized state.
+    ///
+    /// A representing the serialized state of the thread.
+    /// Optional settings for customizing the JSON deserialization process.
+    ///
+    /// An optional factory function to create a custom from its serialized state.
+    /// If not provided, the default in-memory message store will be used.
+    ///
+    ///
+    /// An optional factory function to create a custom from its serialized state.
+    /// If not provided, no context provider will be configured.
+    ///
+    /// The to monitor for cancellation requests.
+    /// A task representing the asynchronous operation. The task result contains the deserialized .
+    internal static async Task DeserializeAsync(
+        JsonElement serializedThreadState,
+        JsonSerializerOptions? jsonSerializerOptions = null,
+        Func>? chatMessageStoreFactory = null,
+        Func>? aiContextProviderFactory = null,
+        CancellationToken cancellationToken = default)
+    {
+        if (serializedThreadState.ValueKind != JsonValueKind.Object)
+        {
+            throw new ArgumentException("The serialized thread state must be a JSON object.", nameof(serializedThreadState));
+        }
+
+        var state = serializedThreadState.Deserialize(
+            AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ThreadState))) as ThreadState;
+
+        var thread = new ChatClientAgentThread();
+
+        thread.AIContextProvider = aiContextProviderFactory is not null
+            ? await aiContextProviderFactory.Invoke(state?.AIContextProviderState ?? default, jsonSerializerOptions, cancellationToken).ConfigureAwait(false)
+            : null;
+
+        if (state?.ConversationId is string threadId)
+        {
+            thread.ConversationId = threadId;
+
+            // Since we have an ID, we should not have a chat message store and we can return here.
+            return thread;
+        }
+
+        thread._messageStore =
+            chatMessageStoreFactory is not null
+                ? await chatMessageStoreFactory.Invoke(state?.StoreState ?? default, jsonSerializerOptions, cancellationToken).ConfigureAwait(false)
+                : new InMemoryChatMessageStore(state?.StoreState ?? default, jsonSerializerOptions); // default to an in-memory store
+
+        return thread;
+    }
+
    ///
    public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null)
    {
diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientBuilderExtensions.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientBuilderExtensions.cs
index fd4b6df60a..ee782dce52 100644
--- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientBuilderExtensions.cs
+++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientBuilderExtensions.cs
@@ -49,7 +49,7 @@ public static ChatClientAgent BuildAIAgent(
        IList? tools = null,
        ILoggerFactory? loggerFactory = null,
        IServiceProvider? services = null) =>
-        Throw.IfNull(builder).Build(services).CreateAIAgent(
+        Throw.IfNull(builder).Build(services).AsAIAgent(
            instructions: instructions,
            name: name,
            description: description,
@@ -78,7 +78,7 @@ public static ChatClientAgent BuildAIAgent(
        ChatClientAgentOptions? options,
        ILoggerFactory? loggerFactory = null,
        IServiceProvider? services = null) =>
-        Throw.IfNull(builder).Build(services).CreateAIAgent(
+        Throw.IfNull(builder).Build(services).AsAIAgent(
            options: options,
            loggerFactory: loggerFactory,
            services: services);
diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientExtensions.cs
index f65d41efe7..653f198402 100644
--- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientExtensions.cs
+++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientExtensions.cs
@@ -20,7 +20,7 @@ public static class ChatClientExtensions
    ///
    ///
    /// A new instance.
-    public static ChatClientAgent CreateAIAgent(
+    public static ChatClientAgent AsAIAgent(
        this IChatClient chatClient,
        string? instructions = null,
        string? name = null,
@@ -42,7 +42,7 @@ public static ChatClientAgent CreateAIAgent(
    ///
    ///
    /// A new instance.
-    public static ChatClientAgent CreateAIAgent(
+    public static ChatClientAgent AsAIAgent(
        this IChatClient chatClient,
        ChatClientAgentOptions? options,
        ILoggerFactory? loggerFactory = null,
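The extension-method rename in one line — `chatClient` is an assumed `IChatClient` instance:

```csharp
// Before: ChatClientAgent agent = chatClient.CreateAIAgent(instructions: "You are helpful.");
ChatClientAgent agent = chatClient.AsAIAgent(instructions: "You are helpful.");
```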
diff --git a/dotnet/src/Microsoft.Agents.AI/FunctionInvocationDelegatingAgent.cs b/dotnet/src/Microsoft.Agents.AI/FunctionInvocationDelegatingAgent.cs
index 2463b266c7..0604ef17d8 100644
--- a/dotnet/src/Microsoft.Agents.AI/FunctionInvocationDelegatingAgent.cs
+++ b/dotnet/src/Microsoft.Agents.AI/FunctionInvocationDelegatingAgent.cs
@@ -21,10 +21,10 @@ internal FunctionInvocationDelegatingAgent(AIAgent innerAgent, Func
-    protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
+    protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
        => this.InnerAgent.RunAsync(messages, thread, this.AgentRunOptionsWithFunctionMiddleware(options), cancellationToken);

-    protected override IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
+    protected override IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
        => this.InnerAgent.RunStreamingAsync(messages, thread, this.AgentRunOptionsWithFunctionMiddleware(options), cancellationToken);

    // Decorate options to add the middleware function
diff --git a/dotnet/src/Microsoft.Agents.AI/LoggingAgent.cs b/dotnet/src/Microsoft.Agents.AI/LoggingAgent.cs
index 03b85d1ef5..258ea55ed7 100644
--- a/dotnet/src/Microsoft.Agents.AI/LoggingAgent.cs
+++ b/dotnet/src/Microsoft.Agents.AI/LoggingAgent.cs
@@ -55,7 +55,7 @@ public JsonSerializerOptions JsonSerializerOptions
    }

    ///
-    protected override async Task RunCoreAsync(
+    protected override async Task RunCoreAsync(
        IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
    {
        if (this._logger.IsEnabled(LogLevel.Debug))
@@ -72,7 +72,7 @@ protected override async Task RunCoreAsync(
        try
        {
-            AgentRunResponse response = await base.RunCoreAsync(messages, thread, options, cancellationToken).ConfigureAwait(false);
+            AgentResponse response = await base.RunCoreAsync(messages, thread, options, cancellationToken).ConfigureAwait(false);

            if (this._logger.IsEnabled(LogLevel.Debug))
            {
@@ -101,7 +101,7 @@ protected override async Task RunCoreAsync(
    }

    ///
-    protected override async IAsyncEnumerable RunCoreStreamingAsync(
+    protected override async IAsyncEnumerable RunCoreStreamingAsync(
        IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (this._logger.IsEnabled(LogLevel.Debug))
@@ -116,7 +116,7 @@ protected override async IAsyncEnumerable RunCoreStreami
        }
    }

-    IAsyncEnumerator e;
+    IAsyncEnumerator e;
    try
    {
        e = base.RunCoreStreamingAsync(messages, thread, options, cancellationToken).GetAsyncEnumerator(cancellationToken);
@@ -134,7 +134,7 @@ protected override async IAsyncEnumerable RunCoreStreami
    try
    {
-        AgentRunResponseUpdate? update = null;
+        AgentResponseUpdate? update = null;
        while (true)
        {
            try
diff --git a/dotnet/src/Microsoft.Agents.AI/OpenTelemetryAgent.cs b/dotnet/src/Microsoft.Agents.AI/OpenTelemetryAgent.cs
index 35d31371c3..07dadf4e0b 100644
--- a/dotnet/src/Microsoft.Agents.AI/OpenTelemetryAgent.cs
+++ b/dotnet/src/Microsoft.Agents.AI/OpenTelemetryAgent.cs
@@ -78,25 +78,25 @@ public bool EnableSensitiveData
    }

    ///
-    protected override async Task RunCoreAsync(
+    protected override async Task RunCoreAsync(
        IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
    {
        ChatOptions co = new ForwardedOptions(options, thread, Activity.Current);
        var response = await this._otelClient.GetResponseAsync(messages, co, cancellationToken).ConfigureAwait(false);

-        return response.RawRepresentation as AgentRunResponse ?? new AgentRunResponse(response);
+        return response.RawRepresentation as AgentResponse ?? new AgentResponse(response);
    }

    ///
-    protected override async IAsyncEnumerable RunCoreStreamingAsync(
+    protected override async IAsyncEnumerable RunCoreStreamingAsync(
        IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        ChatOptions co = new ForwardedOptions(options, thread, Activity.Current);
        await foreach (var update in this._otelClient.GetStreamingResponseAsync(messages, co, cancellationToken).ConfigureAwait(false))
        {
-            yield return update.RawRepresentation as AgentRunResponseUpdate ?? new AgentRunResponseUpdate(update);
+            yield return update.RawRepresentation as AgentResponseUpdate ?? new AgentResponseUpdate(update);
        }
    }
diff --git a/dotnet/src/Shared/Samples/BaseSample.cs b/dotnet/src/Shared/Samples/BaseSample.cs
index 36c18f2db4..90c2d991a8 100644
--- a/dotnet/src/Shared/Samples/BaseSample.cs
+++ b/dotnet/src/Shared/Samples/BaseSample.cs
@@ -86,12 +86,12 @@ protected void WriteUserMessage(string message) =>
    /// Processes and writes the latest agent chat response to the console, including metadata and content details.
    ///
    /// This method formats and outputs the most recent message from the provided object. It includes the message role, author name (if available), text content, and
+    /// cref="AgentResponse"/> object. It includes the message role, author name (if available), text content, and
    /// additional content such as images, function calls, and function results. Usage statistics, including token
    /// counts, are also displayed.
-    /// The object containing the chat messages and usage data.
+    /// The object containing the chat messages and usage data.
    /// The flag to indicate whether to print usage information. Defaults to .
-    protected void WriteResponseOutput(AgentRunResponse response, bool? printUsage = true)
+    protected void WriteResponseOutput(AgentResponse response, bool? printUsage = true)
    {
        if (response.Messages.Count == 0)
        {
@@ -150,11 +150,11 @@ protected void WriteMessageOutput(ChatMessage message)
    /// Writes the streaming agent response updates to the console.
    ///
    /// This method formats and outputs the most recent message from the provided object. It includes the message role, author name (if available), text content, and
+    /// cref="AgentResponseUpdate"/> object. It includes the message role, author name (if available), text content, and
    /// additional content such as images, function calls, and function results. Usage statistics, including token
    /// counts, are also displayed.
-    /// The object containing the chat messages and usage data.
-    protected void WriteAgentOutput(AgentRunResponseUpdate update)
+    /// The object containing the chat messages and usage data.
+    protected void WriteAgentOutput(AgentResponseUpdate update)
    {
        if (update.Contents.Count == 0)
        {
diff --git a/dotnet/src/Shared/Samples/OrchestrationSample.cs b/dotnet/src/Shared/Samples/OrchestrationSample.cs
index 55f372de47..6eb8b5f886 100644
--- a/dotnet/src/Shared/Samples/OrchestrationSample.cs
+++ b/dotnet/src/Shared/Samples/OrchestrationSample.cs
@@ -75,13 +75,13 @@ protected static void WriteResponse(IEnumerable response)
    ///
    /// Writes the streamed agent run response updates to the console or test output, including role and author information.
    ///
-    /// An enumerable of objects representing streamed responses.
-    protected static void WriteStreamedResponse(IEnumerable streamedResponses)
+    /// An enumerable of objects representing streamed responses.
+    protected static void WriteStreamedResponse(IEnumerable streamedResponses)
    {
        string? authorName = null;
        ChatRole? authorRole = null;
        StringBuilder builder = new();
-        foreach (AgentRunResponseUpdate response in streamedResponses)
+        foreach (AgentResponseUpdate response in streamedResponses)
        {
            authorName ??= response.AuthorName;
            authorRole ??= response.Role;
@@ -106,7 +106,7 @@ protected sealed class OrchestrationMonitor
    ///
    /// Gets the list of streamed response updates received so far.
    ///
-    public List StreamedResponses { get; } = [];
+    public List StreamedResponses { get; } = [];

    ///
    /// Gets the list of chat messages representing the conversation history.
@@ -131,9 +131,9 @@ public ValueTask ResponseCallbackAsync(IEnumerable response)
    ///
    /// Callback to handle a streamed agent run response update, adding it to the list and writing output if final.
    ///
-    /// The to process.
+    /// The to process.
    /// A representing the asynchronous operation.
-    public ValueTask StreamingResultCallbackAsync(AgentRunResponseUpdate streamedResponse)
+    public ValueTask StreamingResultCallbackAsync(AgentResponseUpdate streamedResponse)
    {
        this.StreamedResponses.Add(streamedResponse);
        return default;
diff --git a/dotnet/src/Shared/Workflows/Execution/WorkflowRunner.cs b/dotnet/src/Shared/Workflows/Execution/WorkflowRunner.cs
index d82bc800bb..b8666451f6 100644
--- a/dotnet/src/Shared/Workflows/Execution/WorkflowRunner.cs
+++ b/dotnet/src/Shared/Workflows/Execution/WorkflowRunner.cs
@@ -177,7 +177,7 @@ public async Task ExecuteAsync(Func workflowProvider, string input)
                    Console.ResetColor();
                    break;

-                case AgentRunUpdateEvent streamEvent:
+                case AgentResponseUpdateEvent streamEvent:
                    if (!string.Equals(messageId, streamEvent.Update.MessageId, StringComparison.Ordinal))
                    {
                        hasStreamed = false;
@@ -230,7 +230,7 @@ public async Task ExecuteAsync(Func workflowProvider, string input)
                    }
                    break;

-                case AgentRunResponseEvent messageEvent:
+                case AgentResponseEvent messageEvent:
                    try
                    {
                        if (hasStreamed)
diff --git a/dotnet/tests/AgentConformance.IntegrationTests/ChatClientAgentRunStreamingTests.cs b/dotnet/tests/AgentConformance.IntegrationTests/ChatClientAgentRunStreamingTests.cs
index 834de0ea4e..2d8d6787aa 100644
--- a/dotnet/tests/AgentConformance.IntegrationTests/ChatClientAgentRunStreamingTests.cs
+++ b/dotnet/tests/AgentConformance.IntegrationTests/ChatClientAgentRunStreamingTests.cs
@@ -22,7 +22,7 @@ public virtual async Task RunWithInstructionsAndNoMessageReturnsExpectedResultAs
    {
        // Arrange
        var agent = await this.Fixture.CreateChatClientAgentAsync(instructions: "Always respond with 'Computer says no', even if there was no user input.");
-        var thread = agent.GetNewThread();
+        var thread = await agent.GetNewThreadAsync();
        await using var agentCleanup = new AgentCleanup(agent, this.Fixture);
        await using var threadCleanup = new ThreadCleanup(thread, this.Fixture);
@@ -53,7 +53,7 @@ public virtual async Task RunWithFunctionsInvokesFunctionsAndReturnsExpectedResu
            AIFunctionFactory.Create(MenuPlugin.GetSpecials),
            AIFunctionFactory.Create(MenuPlugin.GetItemPrice)
        ]);
-        var thread = agent.GetNewThread();
+        var thread = await agent.GetNewThreadAsync();

        foreach (var questionAndAnswer in questionsAndAnswers)
        {
diff --git a/dotnet/tests/AgentConformance.IntegrationTests/ChatClientAgentRunTests.cs b/dotnet/tests/AgentConformance.IntegrationTests/ChatClientAgentRunTests.cs
index ab85bf5ba0..80fd7106ac 100644
--- a/dotnet/tests/AgentConformance.IntegrationTests/ChatClientAgentRunTests.cs
+++ b/dotnet/tests/AgentConformance.IntegrationTests/ChatClientAgentRunTests.cs
@@ -21,7 +21,7 @@ public virtual async Task RunWithInstructionsAndNoMessageReturnsExpectedResultAs
    {
        // Arrange
        var agent = await this.Fixture.CreateChatClientAgentAsync(instructions: "ALWAYS RESPOND WITH 'Computer says no', even if there was no user input.");
-        var thread = agent.GetNewThread();
+        var thread = await agent.GetNewThreadAsync();
        await using var agentCleanup = new AgentCleanup(agent, this.Fixture);
        await using var threadCleanup = new ThreadCleanup(thread, this.Fixture);
@@ -53,7 +53,7 @@ public virtual async Task RunWithFunctionsInvokesFunctionsAndReturnsExpectedResu
            AIFunctionFactory.Create(MenuPlugin.GetSpecials),
            AIFunctionFactory.Create(MenuPlugin.GetItemPrice)
        ]);
-        var thread = agent.GetNewThread();
+        var thread = await agent.GetNewThreadAsync();

        foreach (var questionAndAnswer in questionsAndAnswers)
        {
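For the WorkflowRunner event renames above, a hedged sketch of the dispatch loop — `run` is an assumed in-progress workflow run exposing a stream of workflow events:

```csharp
await foreach (WorkflowEvent evt in run.WatchStreamAsync())
{
    switch (evt)
    {
        case AgentResponseUpdateEvent streamEvent: // was AgentRunUpdateEvent
            Console.Write(streamEvent.Update);
            break;
        case AgentResponseEvent messageEvent:      // was AgentRunResponseEvent
            Console.WriteLine(messageEvent.Response);
            break;
    }
}
```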
b/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs index a2da3e0d6e..d5c85b1866 100644 --- a/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs +++ b/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs @@ -24,7 +24,7 @@ public virtual async Task RunWithNoMessageDoesNotFailAsync() { // Arrange var agent = this.Fixture.Agent; - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await using var cleanup = new ThreadCleanup(thread, this.Fixture); // Act @@ -36,7 +36,7 @@ public virtual async Task RunWithStringReturnsExpectedResultAsync() { // Arrange var agent = this.Fixture.Agent; - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await using var cleanup = new ThreadCleanup(thread, this.Fixture); // Act @@ -52,7 +52,7 @@ public virtual async Task RunWithChatMessageReturnsExpectedResultAsync() { // Arrange var agent = this.Fixture.Agent; - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await using var cleanup = new ThreadCleanup(thread, this.Fixture); // Act @@ -68,7 +68,7 @@ public virtual async Task RunWithChatMessagesReturnsExpectedResultAsync() { // Arrange var agent = this.Fixture.Agent; - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await using var cleanup = new ThreadCleanup(thread, this.Fixture); // Act @@ -92,7 +92,7 @@ public virtual async Task ThreadMaintainsHistoryAsync() const string Q1 = "What is the capital of France."; const string Q2 = "And Austria?"; var agent = this.Fixture.Agent; - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await using var cleanup = new ThreadCleanup(thread, this.Fixture); // Act diff --git a/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs b/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs index 58f8b67d1d..be98bbd2bf 100644 --- a/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs +++ b/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs @@ -24,7 +24,7 @@ public virtual async Task RunWithNoMessageDoesNotFailAsync() { // Arrange var agent = this.Fixture.Agent; - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await using var cleanup = new ThreadCleanup(thread, this.Fixture); // Act @@ -39,7 +39,7 @@ public virtual async Task RunWithStringReturnsExpectedResultAsync() { // Arrange var agent = this.Fixture.Agent; - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await using var cleanup = new ThreadCleanup(thread, this.Fixture); // Act @@ -57,7 +57,7 @@ public virtual async Task RunWithChatMessageReturnsExpectedResultAsync() { // Arrange var agent = this.Fixture.Agent; - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await using var cleanup = new ThreadCleanup(thread, this.Fixture); // Act @@ -74,7 +74,7 @@ public virtual async Task RunWithChatMessagesReturnsExpectedResultAsync() { // Arrange var agent = this.Fixture.Agent; - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await using var cleanup = new ThreadCleanup(thread, this.Fixture); // Act @@ -99,7 +99,7 @@ public virtual async Task ThreadMaintainsHistoryAsync() const string Q1 = "What is the capital of France."; const string Q2 = "And Austria?"; var agent = this.Fixture.Agent; - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await using var cleanup = new 
ThreadCleanup(thread, this.Fixture); // Act diff --git a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs index 236ae7b332..b6002de9ef 100644 --- a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs @@ -148,7 +148,7 @@ public async Task RunAsync_WithNewThread_UpdatesThreadConversationIdAsync() new(ChatRole.User, "Test message") }; - var thread = this._agent.GetNewThread(); + var thread = await this._agent.GetNewThreadAsync(); // Act await this._agent.RunAsync(inputMessages, thread); @@ -168,7 +168,7 @@ public async Task RunAsync_WithExistingThread_SetConversationIdToMessageAsync() new(ChatRole.User, "Test message") }; - var thread = this._agent.GetNewThread(); + var thread = await this._agent.GetNewThreadAsync(); var a2aThread = (A2AAgentThread)thread; a2aThread.ContextId = "existing-context-id"; @@ -201,7 +201,7 @@ public async Task RunAsync_WithThreadHavingDifferentContextId_ThrowsInvalidOpera ContextId = "different-context" }; - var thread = this._agent.GetNewThread(); + var thread = await this._agent.GetNewThreadAsync(); var a2aThread = (A2AAgentThread)thread; a2aThread.ContextId = "existing-context-id"; @@ -210,7 +210,7 @@ public async Task RunAsync_WithThreadHavingDifferentContextId_ThrowsInvalidOpera } [Fact] - public async Task RunStreamingAsync_WithValidUserMessage_YieldsAgentRunResponseUpdatesAsync() + public async Task RunStreamingAsync_WithValidUserMessage_YieldsAgentResponseUpdatesAsync() { // Arrange var inputMessages = new List @@ -227,7 +227,7 @@ public async Task RunStreamingAsync_WithValidUserMessage_YieldsAgentRunResponseU }; // Act - var updates = new List(); + var updates = new List(); await foreach (var update in this._agent.RunStreamingAsync(inputMessages)) { updates.Add(update); @@ -272,7 +272,7 @@ public async Task RunStreamingAsync_WithThread_UpdatesThreadConversationIdAsync( ContextId = "new-stream-context" }; - var thread = this._agent.GetNewThread(); + var thread = await this._agent.GetNewThreadAsync(); // Act await foreach (var _ in this._agent.RunStreamingAsync(inputMessages, thread)) @@ -296,7 +296,7 @@ public async Task RunStreamingAsync_WithExistingThread_SetConversationIdToMessag this._handler.StreamingResponseToReturn = new AgentMessage(); - var thread = this._agent.GetNewThread(); + var thread = await this._agent.GetNewThreadAsync(); var a2aThread = (A2AAgentThread)thread; a2aThread.ContextId = "existing-context-id"; @@ -316,7 +316,7 @@ public async Task RunStreamingAsync_WithExistingThread_SetConversationIdToMessag public async Task RunStreamingAsync_WithThreadHavingDifferentContextId_ThrowsInvalidOperationExceptionAsync() { // Arrange - var thread = this._agent.GetNewThread(); + var thread = await this._agent.GetNewThreadAsync(); var a2aThread = (A2AAgentThread)thread; a2aThread.ContextId = "existing-context-id"; @@ -440,7 +440,7 @@ public async Task RunAsync_WithTaskInThreadAndMessage_AddTaskAsReferencesToMessa Parts = [new TextPart { Text = "Response to task" }] }; - var thread = (A2AAgentThread)this._agent.GetNewThread(); + var thread = (A2AAgentThread)await this._agent.GetNewThreadAsync(); thread.TaskId = "task-123"; var inputMessage = new ChatMessage(ChatRole.User, "Please make the background transparent"); @@ -466,7 +466,7 @@ public async Task RunAsync_WithAgentTask_UpdatesThreadTaskIdAsync() Status = new() { State = TaskState.Submitted } }; - var thread = this._agent.GetNewThread(); 
+ var thread = await this._agent.GetNewThreadAsync(); // Act await this._agent.RunAsync("Start a task", thread); @@ -492,7 +492,7 @@ public async Task RunAsync_WithAgentTaskResponse_ReturnsTaskResponseCorrectlyAsy } }; - var thread = this._agent.GetNewThread(); + var thread = await this._agent.GetNewThreadAsync(); // Act var result = await this._agent.RunAsync("Start a long-running task", thread); @@ -586,7 +586,7 @@ public async Task RunStreamingAsync_WithTaskInThreadAndMessage_AddTaskAsReferenc Parts = [new TextPart { Text = "Response to task" }] }; - var thread = (A2AAgentThread)this._agent.GetNewThread(); + var thread = (A2AAgentThread)await this._agent.GetNewThreadAsync(); thread.TaskId = "task-123"; // Act @@ -613,7 +613,7 @@ public async Task RunStreamingAsync_WithAgentTask_UpdatesThreadTaskIdAsync() Status = new() { State = TaskState.Submitted } }; - var thread = this._agent.GetNewThread(); + var thread = await this._agent.GetNewThreadAsync(); // Act await foreach (var _ in this._agent.RunStreamingAsync("Start a task", thread)) @@ -646,7 +646,7 @@ public async Task RunStreamingAsync_WithAgentMessage_YieldsResponseUpdateAsync() }; // Act - var updates = new List(); + var updates = new List(); await foreach (var update in this._agent.RunStreamingAsync("Test message")) { updates.Add(update); @@ -686,10 +686,10 @@ public async Task RunStreamingAsync_WithAgentTask_YieldsResponseUpdateAsync() ] }; - var thread = this._agent.GetNewThread(); + var thread = await this._agent.GetNewThreadAsync(); // Act - var updates = new List(); + var updates = new List(); await foreach (var update in this._agent.RunStreamingAsync("Start long-running task", thread)) { updates.Add(update); @@ -725,10 +725,10 @@ public async Task RunStreamingAsync_WithTaskStatusUpdateEvent_YieldsResponseUpda Status = new() { State = TaskState.Working } }; - var thread = this._agent.GetNewThread(); + var thread = await this._agent.GetNewThreadAsync(); // Act - var updates = new List(); + var updates = new List(); await foreach (var update in this._agent.RunStreamingAsync("Check task status", thread)) { updates.Add(update); @@ -768,10 +768,10 @@ public async Task RunStreamingAsync_WithTaskArtifactUpdateEvent_YieldsResponseUp } }; - var thread = this._agent.GetNewThread(); + var thread = await this._agent.GetNewThreadAsync(); // Act - var updates = new List(); + var updates = new List(); await foreach (var update in this._agent.RunStreamingAsync("Process artifact", thread)) { updates.Add(update); diff --git a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AAgentCardExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AAgentCardExtensionsTests.cs index 16e80b4b26..f644109b38 100644 --- a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AAgentCardExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AAgentCardExtensionsTests.cs @@ -34,7 +34,7 @@ public A2AAgentCardExtensionsTests() public void GetAIAgent_ReturnsAIAgent() { // Act - var agent = this._agentCard.GetAIAgent(); + var agent = this._agentCard.AsAIAgent(); // Assert Assert.NotNull(agent); @@ -56,7 +56,7 @@ public async Task RunIAgentAsync_SendsRequestToTheUrlSpecifiedInAgentCardAsync() Parts = [new TextPart { Text = "Response" }], }); - var agent = this._agentCard.GetAIAgent(httpClient); + var agent = this._agentCard.AsAIAgent(httpClient); // Act await agent.RunAsync("Test input"); diff --git a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AClientExtensionsTests.cs 
diff --git a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AClientExtensionsTests.cs
index 5b84324e8b..9ad4d982a9 100644
--- a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AClientExtensionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AClientExtensionsTests.cs
@@ -21,7 +21,7 @@ public void GetAIAgent_WithAllParameters_ReturnsA2AAgentWithSpecifiedProperties(
        const string TestDescription = "This is a test agent description";

        // Act
-        var agent = a2aClient.GetAIAgent(TestId, TestName, TestDescription);
+        var agent = a2aClient.AsAIAgent(TestId, TestName, TestDescription);

        // Assert
        Assert.NotNull(agent);
diff --git a/dotnet/tests/Microsoft.Agents.AI.AGUI.UnitTests/AGUIChatClientTests.cs b/dotnet/tests/Microsoft.Agents.AI.AGUI.UnitTests/AGUIChatClientTests.cs
index c61a7e289d..9109118f73 100644
--- a/dotnet/tests/Microsoft.Agents.AI.AGUI.UnitTests/AGUIChatClientTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.AGUI.UnitTests/AGUIChatClientTests.cs
@@ -30,11 +30,11 @@ public async Task RunAsync_AggregatesStreamingUpdates_ReturnsCompleteMessagesAsy
        ]);
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: []);
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: []);
        List<ChatMessage> messages = [new ChatMessage(ChatRole.User, "Test")];

        // Act
-        AgentRunResponse response = await agent.RunAsync(messages);
+        AgentResponse response = await agent.RunAsync(messages);

        // Assert
        Assert.NotNull(response);
@@ -55,11 +55,11 @@ public async Task RunAsync_WithEmptyUpdateStream_ContainsOnlyMetadataMessagesAsy
        ]);
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: []);
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: []);
        List<ChatMessage> messages = [new ChatMessage(ChatRole.User, "Test")];

        // Act
-        AgentRunResponse response = await agent.RunAsync(messages);
+        AgentResponse response = await agent.RunAsync(messages);

        // Assert
        Assert.NotNull(response);
@@ -74,7 +74,7 @@ public async Task RunAsync_WithNullMessages_ThrowsArgumentNullExceptionAsync()
        // Arrange
        using HttpClient httpClient = new();
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: "Test agent", name: "agent1");
+        AIAgent agent = chatClient.AsAIAgent(instructions: "Test agent", name: "agent1");

        // Act & Assert
        await Assert.ThrowsAsync<ArgumentNullException>(() => agent.RunAsync(messages: null!));
@@ -91,11 +91,11 @@ public async Task RunAsync_WithNullThread_CreatesNewThreadAsync()
        ]);
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

-        AIAgent agent = chatClient.CreateAIAgent(instructions: "Test agent", name: "agent1");
+        AIAgent agent = chatClient.AsAIAgent(instructions: "Test agent", name: "agent1");
        List<ChatMessage> messages = [new ChatMessage(ChatRole.User, "Test")];

        // Act
-        AgentRunResponse response = await agent.RunAsync(messages, thread: null);
+        AgentResponse response = await agent.RunAsync(messages, thread: null);

        // Assert
        Assert.NotNull(response);
@@ -115,12 +115,12 @@ public async Task RunStreamingAsync_YieldsAllEvents_FromServerStreamAsync()
        ]);
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

-        AIAgent agent = chatClient.CreateAIAgent(instructions: "Test agent", name: "agent1");
+        AIAgent agent = chatClient.AsAIAgent(instructions: "Test agent", name: "agent1");
        List<ChatMessage> messages = [new ChatMessage(ChatRole.User, "Test")];

        // Act
-        List<AgentRunResponseUpdate> updates = [];
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(messages))
+        List<AgentResponseUpdate> updates = [];
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages))
        {
            // Consume the stream
            updates.Add(update);
@@ -139,7 +139,7 @@ public async Task RunStreamingAsync_WithNullMessages_ThrowsArgumentNullException
        // Arrange
        using HttpClient httpClient = new();
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: "Test agent", name: "agent1");
+        AIAgent agent = chatClient.AsAIAgent(instructions: "Test agent", name: "agent1");

        // Act & Assert
        await Assert.ThrowsAsync<ArgumentNullException>(async () =>
@@ -162,12 +162,12 @@ public async Task RunStreamingAsync_WithNullThread_CreatesNewThreadAsync()
        ]);
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

-        AIAgent agent = chatClient.CreateAIAgent(instructions: "Test agent", name: "agent1");
+        AIAgent agent = chatClient.AsAIAgent(instructions: "Test agent", name: "agent1");
        List<ChatMessage> messages = [new ChatMessage(ChatRole.User, "Test")];

        // Act
-        List<AgentRunResponseUpdate> updates = [];
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(messages, thread: null))
+        List<AgentResponseUpdate> updates = [];
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages, thread: null))
        {
            // Consume the stream
            updates.Add(update);
@@ -195,7 +195,7 @@ public async Task RunStreamingAsync_GeneratesUniqueRunId_ForEachInvocationAsync(
        using HttpClient httpClient = new(handler);
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: []);
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: []);
        List<ChatMessage> messages = [new ChatMessage(ChatRole.User, "Test")];

        // Act
@@ -227,12 +227,12 @@ public async Task RunStreamingAsync_ReturnsStreamingUpdates_AfterCompletionAsync
        ]);
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: []);
-        AgentThread thread = agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: []);
+        AgentThread thread = await agent.GetNewThreadAsync();
        List<ChatMessage> messages = [new ChatMessage(ChatRole.User, "Hello")];

        // Act
-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];
        await foreach (var update in agent.RunStreamingAsync(messages, thread))
        {
            updates.Add(update);
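These AGUI tests exercise the renamed adapter entry point: `CreateAIAgent` on a chat client becomes `AsAIAgent`, and streamed updates are now typed `AgentResponseUpdate`. A hedged sketch of the construction pattern the tests follow (the endpoint URL and serializer context mirror the test fixtures and are illustrative only):

```csharp
using HttpClient httpClient = new();
var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

// Renamed adapter: CreateAIAgent(...) -> AsAIAgent(...).
AIAgent agent = chatClient.AsAIAgent(instructions: "Test agent", name: "agent1");

List<AgentResponseUpdate> updates = [];
await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Test")]))
{
    updates.Add(update); // Collect updates, e.g., for assertions.
}
```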
@@ -244,17 +244,17 @@ public async Task RunStreamingAsync_ReturnsStreamingUpdates_AfterCompletionAsync
    }

    [Fact]
-    public void DeserializeThread_WithValidState_ReturnsChatClientAgentThread()
+    public async Task DeserializeThread_WithValidState_ReturnsChatClientAgentThreadAsync()
    {
        // Arrange
        using var httpClient = new HttpClient();
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: []);
-        AgentThread originalThread = agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: []);
+        AgentThread originalThread = await agent.GetNewThreadAsync();
        JsonElement serialized = originalThread.Serialize();

        // Act
-        AgentThread deserialized = agent.DeserializeThread(serialized);
+        AgentThread deserialized = await agent.DeserializeThreadAsync(serialized);

        // Assert
        Assert.NotNull(deserialized);
@@ -301,12 +301,12 @@ public async Task RunStreamingAsync_InvokesTools_WhenFunctionCallsReturnedAsync(
        ]);
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: [testTool]);
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: [testTool]);
        List<ChatMessage> messages = [new ChatMessage(ChatRole.User, "What's the weather?")];

        // Act
-        List<AgentRunResponseUpdate> allUpdates = [];
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(messages))
+        List<AgentResponseUpdate> allUpdates = [];
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages))
        {
            allUpdates.Add(update);
        }
@@ -353,12 +353,12 @@ public async Task RunStreamingAsync_DoesNotInvokeTools_WhenSomeToolsNotAvailable
        using HttpClient httpClient = new(handler);
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: [tool1]); // Only tool1, not tool2
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: [tool1]); // Only tool1, not tool2
        List<ChatMessage> messages = [new ChatMessage(ChatRole.User, "Test")];

        // Act
-        List<AgentRunResponseUpdate> allUpdates = [];
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(messages))
+        List<AgentResponseUpdate> allUpdates = [];
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages))
        {
            allUpdates.Add(update);
        }
@@ -403,12 +403,12 @@ public async Task RunStreamingAsync_HandlesToolInvocationErrors_GracefullyAsync(
        ]);
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: [faultyTool]);
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: [faultyTool]);
        List<ChatMessage> messages = [new ChatMessage(ChatRole.User, "Test")];

        // Act
-        List<AgentRunResponseUpdate> allUpdates = [];
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(messages))
+        List<AgentResponseUpdate> allUpdates = [];
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages))
        {
            allUpdates.Add(update);
        }
@@ -448,7 +448,7 @@ public async Task RunStreamingAsync_InvokesMultipleTools_InSingleTurnAsync()
        ]);
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: [tool1, tool2]);
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: [tool1, tool2]);
        List<ChatMessage> messages = [new ChatMessage(ChatRole.User, "Test")];

        // Act
@@ -486,12 +486,12 @@ public async Task RunStreamingAsync_UpdatesThreadWithToolMessages_AfterCompletio
        ]);
        var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);

-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: [testTool]);
-        AgentThread thread = agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: [testTool]);
+        AgentThread thread = await agent.GetNewThreadAsync();
        List<ChatMessage> messages = [new ChatMessage(ChatRole.User, "Test")];

        // Act
-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];
        await foreach (var update in agent.RunStreamingAsync(messages, thread))
        {
            updates.Add(update);
diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs
index a1c8cb32bf..8d5f1b0b87 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs
@@ -19,8 +19,8 @@ public class AIAgentTests
{
    private readonly Mock<AIAgent> _agentMock;
    private readonly Mock<AgentThread> _agentThreadMock;
-    private readonly AgentRunResponse _invokeResponse;
-    private readonly List<AgentRunResponseUpdate> _invokeStreamingResponses = [];
+    private readonly AgentResponse _invokeResponse;
+    private readonly List<AgentResponseUpdate> _invokeStreamingResponses = [];

    /// <summary>
    /// Initializes a new instance of the <see cref="AIAgentTests"/> class.
    /// </summary>
@@ -29,13 +29,13 @@ public AIAgentTests()
    {
        this._agentThreadMock = new Mock<AgentThread>(MockBehavior.Strict);

-        this._invokeResponse = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Hi"));
-        this._invokeStreamingResponses.Add(new AgentRunResponseUpdate(ChatRole.Assistant, "Hi"));
+        this._invokeResponse = new AgentResponse(new ChatMessage(ChatRole.Assistant, "Hi"));
+        this._invokeStreamingResponses.Add(new AgentResponseUpdate(ChatRole.Assistant, "Hi"));

        this._agentMock = new Mock<AIAgent> { CallBase = true };
        this._agentMock
            .Protected()
-            .Setup<Task<AgentRunResponse>>("RunCoreAsync",
+            .Setup<Task<AgentResponse>>("RunCoreAsync",
                ItExpr.IsAny<IEnumerable<ChatMessage>>(),
                ItExpr.Is<AgentThread>(t => t == this._agentThreadMock.Object),
                ItExpr.IsAny<AgentRunOptions>(),
@@ -43,7 +43,7 @@ public AIAgentTests()
            .ReturnsAsync(this._invokeResponse);
        this._agentMock
            .Protected()
-            .Setup<IAsyncEnumerable<AgentRunResponseUpdate>>("RunCoreStreamingAsync",
+            .Setup<IAsyncEnumerable<AgentResponseUpdate>>("RunCoreStreamingAsync",
                ItExpr.IsAny<IEnumerable<ChatMessage>>(),
                ItExpr.Is<AgentThread>(t => t == this._agentThreadMock.Object),
                ItExpr.IsAny<AgentRunOptions>(),
@@ -69,7 +69,7 @@ public async Task InvokeWithoutMessageCallsMockedInvokeWithEmptyArrayAsync()
        // Verify that the mocked method was called with the expected parameters
        this._agentMock
            .Protected()
-            .Verify<Task<AgentRunResponse>>("RunCoreAsync",
+            .Verify<Task<AgentResponse>>("RunCoreAsync",
                Times.Once(),
                ItExpr.Is<IEnumerable<ChatMessage>>(messages => !messages.Any()),
                ItExpr.Is<AgentThread>(t => t == this._agentThreadMock.Object),
@@ -96,7 +96,7 @@ public async Task InvokeWithStringMessageCallsMockedInvokeWithMessageInCollectio
        // Verify that the mocked method was called with the expected parameters
        this._agentMock
            .Protected()
-            .Verify<Task<AgentRunResponse>>("RunCoreAsync",
+            .Verify<Task<AgentResponse>>("RunCoreAsync",
                Times.Once(),
                ItExpr.Is<IEnumerable<ChatMessage>>(messages => messages.Count() == 1 && messages.First().Text == Message),
                ItExpr.Is<AgentThread>(t => t == this._agentThreadMock.Object),
@@ -123,7 +123,7 @@ public async Task InvokeWithSingleMessageCallsMockedInvokeWithMessageInCollectio
        // Verify that the mocked method was called with the expected parameters
        this._agentMock
            .Protected()
-            .Verify<Task<AgentRunResponse>>("RunCoreAsync",
+            .Verify<Task<AgentResponse>>("RunCoreAsync",
                Times.Once(),
                ItExpr.Is<IEnumerable<ChatMessage>>(messages => messages.Count() == 1 && messages.First() == message),
                ItExpr.Is<AgentThread>(t => t == this._agentThreadMock.Object),
@@ -152,7 +152,7 @@ public async Task InvokeStreamingWithoutMessageCallsMockedInvokeWithEmptyArrayAs
        // Verify that the mocked method was called with the expected parameters
        this._agentMock
            .Protected()
-            .Verify<IAsyncEnumerable<AgentRunResponseUpdate>>("RunCoreStreamingAsync",
+            .Verify<IAsyncEnumerable<AgentResponseUpdate>>("RunCoreStreamingAsync",
                Times.Once(),
                ItExpr.Is<IEnumerable<ChatMessage>>(messages => !messages.Any()),
                ItExpr.Is<AgentThread>(t => t == this._agentThreadMock.Object),
@@ -182,7 +182,7 @@ public async Task InvokeStreamingWithStringMessageCallsMockedInvokeWithMessageIn
        // Verify that the mocked method was called with the expected parameters
        this._agentMock
            .Protected()
-            .Verify<IAsyncEnumerable<AgentRunResponseUpdate>>("RunCoreStreamingAsync",
+            .Verify<IAsyncEnumerable<AgentResponseUpdate>>("RunCoreStreamingAsync",
                Times.Once(),
                ItExpr.Is<IEnumerable<ChatMessage>>(messages => messages.Count() == 1 && messages.First().Text == Message),
                ItExpr.Is<AgentThread>(t => t == this._agentThreadMock.Object),
@@ -212,7 +212,7 @@ public async Task InvokeStreamingWithSingleMessageCallsMockedInvokeWithMessageIn
        // Verify that the mocked method was called with the expected parameters
        this._agentMock
            .Protected()
-            .Verify<IAsyncEnumerable<AgentRunResponseUpdate>>("RunCoreStreamingAsync",
+            .Verify<IAsyncEnumerable<AgentResponseUpdate>>("RunCoreStreamingAsync",
                Times.Once(),
                ItExpr.Is<IEnumerable<ChatMessage>>(messages => messages.Count() == 1 && messages.First() == message),
                ItExpr.Is<AgentThread>(t => t == this._agentThreadMock.Object),
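Because `RunCoreAsync`/`RunCoreStreamingAsync` are protected, these tests drive them through Moq's `Protected()` API, so every `Setup`/`Verify` type argument changes along with the response rename. A condensed sketch of the updated pattern, mirroring the setups above (`Moq.Protected` is the real Moq API; the rest follows this test fixture):

```csharp
// Requires: using Moq; using Moq.Protected;
var agentMock = new Mock<AIAgent> { CallBase = true };
agentMock
    .Protected()
    .Setup<Task<AgentResponse>>("RunCoreAsync",          // was Task<AgentRunResponse>
        ItExpr.IsAny<IEnumerable<ChatMessage>>(),
        ItExpr.IsAny<AgentThread>(),
        ItExpr.IsAny<AgentRunOptions>(),
        ItExpr.IsAny<CancellationToken>())
    .ReturnsAsync(new AgentResponse(new ChatMessage(ChatRole.Assistant, "Hi")));
```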
@@ -378,20 +378,20 @@ public MockAgent(string? id = null)

    protected override string? IdCore { get; }

-    public override AgentThread GetNewThread()
+    public override async ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default)
        => throw new NotImplementedException();

-    public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null)
+    public override async ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
        => throw new NotImplementedException();

-    protected override Task<AgentRunResponse> RunCoreAsync(
+    protected override Task<AgentResponse> RunCoreAsync(
        IEnumerable<ChatMessage> messages,
        AgentThread? thread = null,
        AgentRunOptions? options = null,
        CancellationToken cancellationToken = default)
        => throw new NotImplementedException();

-    protected override IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(
+    protected override IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(
        IEnumerable<ChatMessage> messages,
        AgentThread? thread = null,
        AgentRunOptions? options = null,
diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentAbstractionsJsonUtilitiesTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentAbstractionsJsonUtilitiesTests.cs
index e286796243..5958bba3b3 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentAbstractionsJsonUtilitiesTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentAbstractionsJsonUtilitiesTests.cs
@@ -79,9 +79,9 @@ public void DefaultOptions_SerializesEnumsAsStrings()
#endif

    [Fact]
-    public void DefaultOptions_UsesCamelCasePropertyNames_ForAgentRunResponse()
+    public void DefaultOptions_UsesCamelCasePropertyNames_ForAgentResponse()
    {
-        var response = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Hello"));
+        var response = new AgentResponse(new ChatMessage(ChatRole.Assistant, "Hello"));
        string json = JsonSerializer.Serialize(response, AgentAbstractionsJsonUtilities.DefaultOptions);
        Assert.Contains("\"messages\"", json);
        Assert.DoesNotContain("\"Messages\"", json);
diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentResponseTests.cs
similarity index 83%
rename from dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseTests.cs
rename to dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentResponseTests.cs
index 8e39b4c4fa..75bc90ca8e 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentResponseTests.cs
@@ -9,12 +9,12 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;

-public class AgentRunResponseTests
+public class AgentResponseTests
{
    [Fact]
    public void ConstructorWithNullEmptyArgsIsValid()
    {
-        AgentRunResponse response;
+        AgentResponse response;

        response = new();
        Assert.Empty(response.Messages);
@@ -26,13 +26,13 @@ public void ConstructorWithNullEmptyArgsIsValid()
        Assert.Empty(response.Text);
        Assert.Null(response.ContinuationToken);

-        Assert.Throws<ArgumentNullException>("message", () => new AgentRunResponse((ChatMessage)null!));
+        Assert.Throws<ArgumentNullException>("message", () => new AgentResponse((ChatMessage)null!));
    }

    [Fact]
    public void ConstructorWithMessagesRoundtrips()
    {
-        AgentRunResponse response = new();
+        AgentResponse response = new();
        Assert.NotNull(response.Messages);
        Assert.Same(response.Messages, response.Messages);
@@ -60,7 +60,7 @@ public void ConstructorWithChatResponseRoundtrips()
            ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })
        };

-        AgentRunResponse response = new(chatResponse);
+        AgentResponse response = new(chatResponse);
        Assert.Same(chatResponse.AdditionalProperties, response.AdditionalProperties);
        Assert.Equal(chatResponse.CreatedAt, response.CreatedAt);
        Assert.Same(chatResponse.Messages, response.Messages);
@@ -73,7 +73,7 @@ public void ConstructorWithChatResponseRoundtrips()
    [Fact]
    public void PropertiesRoundtrip()
    {
-        AgentRunResponse response = new();
+        AgentResponse response = new();

        Assert.Null(response.AgentId);
        response.AgentId = "agentId";
@@ -110,7 +110,7 @@ public void PropertiesRoundtrip()
    [Fact]
    public void JsonSerializationRoundtrips()
    {
-        AgentRunResponse original = new(new ChatMessage(ChatRole.Assistant, "the message"))
+        AgentResponse original = new(new ChatMessage(ChatRole.Assistant, "the message"))
        {
            AgentId = "agentId",
            ResponseId = "id",
@@ -123,7 +123,7 @@ public void JsonSerializationRoundtrips()
        string json = JsonSerializer.Serialize(original, AgentAbstractionsJsonUtilities.DefaultOptions);

-        AgentRunResponse? result = JsonSerializer.Deserialize<AgentRunResponse>(json, AgentAbstractionsJsonUtilities.DefaultOptions);
+        AgentResponse? result = JsonSerializer.Deserialize<AgentResponse>(json, AgentAbstractionsJsonUtilities.DefaultOptions);

        Assert.NotNull(result);
        Assert.Equal(ChatRole.Assistant, result.Messages.Single().Role);
@@ -145,7 +145,7 @@ public void JsonSerializationRoundtrips()
    [Fact]
    public void ToStringOutputsText()
    {
-        AgentRunResponse response = new(new ChatMessage(ChatRole.Assistant, $"This is a test.{Environment.NewLine}It's multiple lines."));
+        AgentResponse response = new(new ChatMessage(ChatRole.Assistant, $"This is a test.{Environment.NewLine}It's multiple lines."));

        Assert.Equal(response.Text, response.ToString());
    }
@@ -153,7 +153,7 @@ public void ToStringOutputsText()
    [Fact]
    public void TextGetConcatenatesAllTextContent()
    {
-        AgentRunResponse response = new(
+        AgentResponse response = new(
            [
                new ChatMessage(
                    ChatRole.Assistant,
@@ -174,15 +174,15 @@ public void TextGetConcatenatesAllTextContent()
    [Fact]
    public void TextGetReturnsEmptyStringWithNoMessages()
    {
-        AgentRunResponse response = new();
+        AgentResponse response = new();

        Assert.Equal(string.Empty, response.Text);
    }

    [Fact]
-    public void ToAgentRunResponseUpdatesProducesUpdates()
+    public void ToAgentResponseUpdatesProducesUpdates()
    {
-        AgentRunResponse response = new(new ChatMessage(new ChatRole("customRole"), "Text") { MessageId = "someMessage" })
+        AgentResponse response = new(new ChatMessage(new ChatRole("customRole"), "Text") { MessageId = "someMessage" })
        {
            AgentId = "agentId",
            ResponseId = "12345",
@@ -194,11 +194,11 @@ public void ToAgentRunResponseUpdatesProducesUpdates()
            },
        };

-        AgentRunResponseUpdate[] updates = response.ToAgentRunResponseUpdates();
+        AgentResponseUpdate[] updates = response.ToAgentResponseUpdates();
        Assert.NotNull(updates);
        Assert.Equal(2, updates.Length);

-        AgentRunResponseUpdate update0 = updates[0];
+        AgentResponseUpdate update0 = updates[0];
        Assert.Equal("agentId", update0.AgentId);
        Assert.Equal("12345", update0.ResponseId);
        Assert.Equal("someMessage", update0.MessageId);
@@ -206,7 +206,7 @@ public void ToAgentRunResponseUpdatesProducesUpdates()
        Assert.Equal("customRole", update0.Role?.Value);
        Assert.Equal("Text", update0.Text);

-        AgentRunResponseUpdate update1 = updates[1];
+        AgentResponseUpdate update1 = updates[1];
        Assert.Equal("value1", update1.AdditionalProperties?["key1"]);
        Assert.Equal(42, update1.AdditionalProperties?["key2"]);
        Assert.IsType<UsageContent>(update1.Contents[0]);
@@ -225,7 +225,7 @@ public void ParseAsStructuredOutputSuccess()
    {
        // Arrange.
        var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger };
-        var response = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult, TestJsonSerializerContext.Default.Animal)));
+        var response = new AgentResponse(new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult, TestJsonSerializerContext.Default.Animal)));

        // Act.
        var animal = response.Deserialize<Animal>();
@@ -243,7 +243,7 @@ public void ParseAsStructuredOutputWithJSOSuccess()
    {
        // Arrange.
        var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger };
-        var response = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult, TestJsonSerializerContext.Default.Animal)));
+        var response = new AgentResponse(new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult, TestJsonSerializerContext.Default.Animal)));

        // Act.
        var animal = response.Deserialize<Animal>(TestJsonSerializerContext.Default.Options);
@@ -259,7 +259,7 @@ public void ParseAsStructuredOutputWithJSOSuccess()
    public void ParseAsStructuredOutputFailsWithEmptyString()
    {
        // Arrange.
-        var response = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, string.Empty));
+        var response = new AgentResponse(new ChatMessage(ChatRole.Assistant, string.Empty));

        // Act & Assert.
        var exception = Assert.Throws<InvalidOperationException>(() => response.Deserialize<Animal>(TestJsonSerializerContext.Default.Options));
@@ -270,7 +270,7 @@ public void ParseAsStructuredOutputFailsWithEmptyString()
    public void ParseAsStructuredOutputFailsWithInvalidJson()
    {
        // Arrange.
-        var response = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "invalid json"));
+        var response = new AgentResponse(new ChatMessage(ChatRole.Assistant, "invalid json"));

        // Act & Assert.
        Assert.Throws<JsonException>(() => response.Deserialize<Animal>(TestJsonSerializerContext.Default.Options));
@@ -280,7 +280,7 @@ public void ParseAsStructuredOutputFailsWithInvalidJson()
    public void ParseAsStructuredOutputFailsWithIncorrectTypedJson()
    {
        // Arrange.
-        var response = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "[]"));
+        var response = new AgentResponse(new ChatMessage(ChatRole.Assistant, "[]"));

        // Act & Assert.
        Assert.Throws<JsonException>(() => response.Deserialize<Animal>(TestJsonSerializerContext.Default.Options));
@@ -297,7 +297,7 @@ public void TryParseAsStructuredOutputSuccess()
    {
        // Arrange.
        var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger };
-        var response = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult, TestJsonSerializerContext.Default.Animal)));
+        var response = new AgentResponse(new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult, TestJsonSerializerContext.Default.Animal)));

        // Act.
        response.TryDeserialize(out Animal? animal);
@@ -315,7 +315,7 @@ public void TryParseAsStructuredOutputWithJSOSuccess()
    {
        // Arrange.
        var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger };
-        var response = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult, TestJsonSerializerContext.Default.Animal)));
+        var response = new AgentResponse(new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult, TestJsonSerializerContext.Default.Animal)));

        // Act.
        response.TryDeserialize(TestJsonSerializerContext.Default.Options, out Animal? animal);
@@ -331,7 +331,7 @@ public void TryParseAsStructuredOutputWithJSOSuccess()
    public void TryParseAsStructuredOutputFailsWithEmptyText()
    {
        // Arrange.
-        var response = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, string.Empty));
+        var response = new AgentResponse(new ChatMessage(ChatRole.Assistant, string.Empty));

        // Act & Assert.
        Assert.False(response.TryDeserialize<Animal>(TestJsonSerializerContext.Default.Options, out _));
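The structured-output tests keep their shape; only the response type is renamed. A hedged sketch of the two deserialization styles they cover (the `Animal` type, the JSON payload, and the serializer context are test fixtures, not framework types):

```csharp
AgentResponse response = new(new ChatMessage(ChatRole.Assistant,
    """{"id":1,"fullName":"Tigger","species":"Tiger"}"""));

// Throwing style: empty or malformed JSON raises an exception.
Animal animal = response.Deserialize<Animal>(TestJsonSerializerContext.Default.Options);

// Non-throwing style: returns false instead of throwing.
if (response.TryDeserialize(TestJsonSerializerContext.Default.Options, out Animal? maybeAnimal))
{
    Console.WriteLine(maybeAnimal!.FullName);
}
```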
@@ -341,7 +341,7 @@ public void TryParseAsStructuredOutputFailsWithEmptyText()
    public void TryParseAsStructuredOutputFailsWithIncorrectTypedJson()
    {
        // Arrange.
-        var response = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "[]"));
+        var response = new AgentResponse(new ChatMessage(ChatRole.Assistant, "[]"));

        // Act & Assert.
        Assert.False(response.TryDeserialize<Animal>(TestJsonSerializerContext.Default.Options, out _));
diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseUpdateExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentResponseUpdateExtensionsTests.cs
similarity index 80%
rename from dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseUpdateExtensionsTests.cs
rename to dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentResponseUpdateExtensionsTests.cs
index a653cf80f5..2f136066e4 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseUpdateExtensionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentResponseUpdateExtensionsTests.cs
@@ -9,9 +9,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;

-public class AgentRunResponseUpdateExtensionsTests
+public class AgentResponseUpdateExtensionsTests
{
-    public static IEnumerable<object[]> ToAgentRunResponseCoalescesVariousSequenceAndGapLengthsMemberData()
+    public static IEnumerable<object[]> ToAgentResponseCoalescesVariousSequenceAndGapLengthsMemberData()
    {
        foreach (bool useAsync in new[] { false, true })
        {
@@ -32,15 +32,15 @@ public static IEnumerable<object[]> ToAgentRunResponseCoalescesVariousSequenceAn
    }

    [Fact]
-    public void ToAgentRunResponseWithInvalidArgsThrows() =>
-        Assert.Throws<ArgumentNullException>("updates", () => ((List<AgentRunResponseUpdate>)null!).ToAgentRunResponse());
+    public void ToAgentResponseWithInvalidArgsThrows() =>
+        Assert.Throws<ArgumentNullException>("updates", () => ((List<AgentResponseUpdate>)null!).ToAgentResponse());

    [Theory]
    [InlineData(false)]
    [InlineData(true)]
-    public async Task ToAgentRunResponseSuccessfullyCreatesResponseAsync(bool useAsync)
+    public async Task ToAgentResponseSuccessfullyCreatesResponseAsync(bool useAsync)
    {
-        AgentRunResponseUpdate[] updates =
+        AgentResponseUpdate[] updates =
        [
            new(ChatRole.Assistant, "Hello") { ResponseId = "someResponse", MessageId = "12345", CreatedAt = new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), AgentId = "agentId" },
            new(new("human"), ", ") { AuthorName = "Someone", AdditionalProperties = new() { ["a"] = "b" } },
@@ -50,9 +50,9 @@ public async Task ToAgentRunResponseSuccessfullyCreatesResponseAsync(bool useAsy
            new() { Contents = [new UsageContent(new() { InputTokenCount = 4, OutputTokenCount = 5 })] },
        ];

-        AgentRunResponse response = useAsync ?
-            updates.ToAgentRunResponse() :
-            await YieldAsync(updates).ToAgentRunResponseAsync();
+        AgentResponse response = useAsync ?
+            updates.ToAgentResponse() :
+            await YieldAsync(updates).ToAgentResponseAsync();

        Assert.NotNull(response);
        Assert.Equal("agentId", response.AgentId);
@@ -90,10 +90,10 @@ public async Task ToAgentRunResponseSuccessfullyCreatesResponseAsync(bool useAsy
    }

    [Theory]
-    [MemberData(nameof(ToAgentRunResponseCoalescesVariousSequenceAndGapLengthsMemberData))]
-    public async Task ToAgentRunResponseCoalescesVariousSequenceAndGapLengthsAsync(bool useAsync, int numSequences, int sequenceLength, int gapLength, bool gapBeginningEnd)
+    [MemberData(nameof(ToAgentResponseCoalescesVariousSequenceAndGapLengthsMemberData))]
+    public async Task ToAgentResponseCoalescesVariousSequenceAndGapLengthsAsync(bool useAsync, int numSequences, int sequenceLength, int gapLength, bool gapBeginningEnd)
    {
-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

        List<string> expected = [];
@@ -133,7 +133,7 @@ void AddGap()
            }
        }

-        AgentRunResponse response = useAsync ? await YieldAsync(updates).ToAgentRunResponseAsync() : updates.ToAgentRunResponse();
+        AgentResponse response = useAsync ? await YieldAsync(updates).ToAgentResponseAsync() : updates.ToAgentResponse();
        Assert.NotNull(response);

        ChatMessage message = response.Messages.Single();
@@ -152,9 +152,9 @@ void AddGap()
    [Theory]
    [InlineData(false)]
    [InlineData(true)]
-    public async Task ToAgentRunResponseCoalescesTextContentAndTextReasoningContentSeparatelyAsync(bool useAsync)
+    public async Task ToAgentResponseCoalescesTextContentAndTextReasoningContentSeparatelyAsync(bool useAsync)
    {
-        AgentRunResponseUpdate[] updates =
+        AgentResponseUpdate[] updates =
        [
            new(null, "A"),
            new(null, "B"),
@@ -174,7 +174,7 @@ public async Task ToAgentRunResponseCoalescesTextContentAndTextReasoningContentS
            new() { Contents = [new TextReasoningContent("P")] },
        ];

-        AgentRunResponse response = useAsync ? await YieldAsync(updates).ToAgentRunResponseAsync() : updates.ToAgentRunResponse();
+        AgentResponse response = useAsync ? await YieldAsync(updates).ToAgentResponseAsync() : updates.ToAgentResponse();

        ChatMessage message = Assert.Single(response.Messages);
        Assert.Equal(8, message.Contents.Count);
        Assert.Equal("ABC", Assert.IsType<TextContent>(message.Contents[0]).Text);
@@ -188,16 +188,16 @@ public async Task ToAgentRunResponseCoalescesTextContentAndTextReasoningContentS
    }

    [Fact]
-    public async Task ToAgentRunResponseUsesContentExtractedFromContentsAsync()
+    public async Task ToAgentResponseUsesContentExtractedFromContentsAsync()
    {
-        AgentRunResponseUpdate[] updates =
+        AgentResponseUpdate[] updates =
        [
            new(null, "Hello, "),
            new(null, "world!"),
            new() { Contents = [new UsageContent(new() { TotalTokenCount = 42 })] },
        ];

-        AgentRunResponse response = await YieldAsync(updates).ToAgentRunResponseAsync();
+        AgentResponse response = await YieldAsync(updates).ToAgentResponseAsync();

        Assert.NotNull(response);
@@ -210,14 +210,14 @@ public async Task ToAgentRunResponseUsesContentExtractedFromContentsAsync()
    [Theory]
    [InlineData(false)]
    [InlineData(true)]
-    public async Task ToAgentRunResponse_AlternativeTimestampsAsync(bool useAsync)
+    public async Task ToAgentResponse_AlternativeTimestampsAsync(bool useAsync)
    {
        DateTimeOffset early = new(2024, 1, 1, 10, 0, 0, TimeSpan.Zero);
        DateTimeOffset middle = new(2024, 1, 1, 11, 0, 0, TimeSpan.Zero);
        DateTimeOffset late = new(2024, 1, 1, 12, 0, 0, TimeSpan.Zero);
        DateTimeOffset unixEpoch = new(1970, 1, 1, 0, 0, 0, TimeSpan.Zero);

-        AgentRunResponseUpdate[] updates =
+        AgentResponseUpdate[] updates =
        [
            // Start with an early timestamp
@@ -242,9 +242,9 @@ public async Task ToAgentRunResponse_AlternativeTimestampsAsync(bool useAsync)
            new(null, "g") { CreatedAt = null },
        ];

-        AgentRunResponse response = useAsync ?
-            updates.ToAgentRunResponse() :
-            await YieldAsync(updates).ToAgentRunResponseAsync();
+        AgentResponse response = useAsync ?
+            updates.ToAgentResponse() :
+            await YieldAsync(updates).ToAgentResponseAsync();

        Assert.Single(response.Messages);
        Assert.Equal("abcdefg", response.Messages[0].Text);
@@ -253,7 +253,7 @@ public async Task ToAgentRunResponse_AlternativeTimestampsAsync(bool useAsync)
        Assert.Equal(late, response.CreatedAt);
    }

-    public static IEnumerable<object[]> ToAgentRunResponse_TimestampFolding_MemberData()
+    public static IEnumerable<object[]> ToAgentResponse_TimestampFolding_MemberData()
    {
        // Base test cases
        var testCases = new (string? timestamp1, string? timestamp2, string? expectedTimestamp)[]
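These extension tests document the coalescing contract under the new names: consecutive text updates fold into one message, usage content is extracted, and timestamps fold to the latest value. A sketch of the two renamed entry points (the `agent` instance is assumed; both overloads appear in the tests above):

```csharp
AgentResponseUpdate[] updates =
[
    new(ChatRole.Assistant, "Hello"),
    new(null, ", world!"),
];

// Synchronous over a materialized collection...
AgentResponse fromList = updates.ToAgentResponse();

// ...or asynchronous over a streamed IAsyncEnumerable<AgentResponseUpdate>.
AgentResponse fromStream = await agent.RunStreamingAsync("hi").ToAgentResponseAsync();
```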
@@ -276,22 +276,22 @@ public async Task ToAgentRunResponse_AlternativeTimestampsAsync(bool useAsync)
    }

    [Theory]
-    [MemberData(nameof(ToAgentRunResponse_TimestampFolding_MemberData))]
-    public async Task ToAgentRunResponse_TimestampFoldingAsync(bool useAsync, string? timestamp1, string? timestamp2, string? expectedTimestamp)
+    [MemberData(nameof(ToAgentResponse_TimestampFolding_MemberData))]
+    public async Task ToAgentResponse_TimestampFoldingAsync(bool useAsync, string? timestamp1, string? timestamp2, string? expectedTimestamp)
    {
        DateTimeOffset? first = timestamp1 is not null ? DateTimeOffset.Parse(timestamp1) : null;
        DateTimeOffset? second = timestamp2 is not null ? DateTimeOffset.Parse(timestamp2) : null;
        DateTimeOffset? expected = expectedTimestamp is not null ? DateTimeOffset.Parse(expectedTimestamp) : null;

-        AgentRunResponseUpdate[] updates =
+        AgentResponseUpdate[] updates =
        [
            new(ChatRole.Assistant, "a") { CreatedAt = first },
            new(null, "b") { CreatedAt = second },
        ];

-        AgentRunResponse response = useAsync ?
-            updates.ToAgentRunResponse() :
-            await YieldAsync(updates).ToAgentRunResponseAsync();
+        AgentResponse response = useAsync ?
+            updates.ToAgentResponse() :
+            await YieldAsync(updates).ToAgentResponseAsync();

        Assert.Single(response.Messages);
        Assert.Equal("ab", response.Messages[0].Text);
@@ -299,9 +299,9 @@ public async Task ToAgentRunResponse_TimestampFoldingAsync(bool useAsync, string
        Assert.Equal(expected, response.CreatedAt);
    }

-    private static async IAsyncEnumerable<AgentRunResponseUpdate> YieldAsync(IEnumerable<AgentRunResponseUpdate> updates)
+    private static async IAsyncEnumerable<AgentResponseUpdate> YieldAsync(IEnumerable<AgentResponseUpdate> updates)
    {
-        foreach (AgentRunResponseUpdate update in updates)
+        foreach (AgentResponseUpdate update in updates)
        {
            await Task.Yield();
            yield return update;
diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseUpdateTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentResponseUpdateTests.cs
similarity index 94%
rename from dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseUpdateTests.cs
rename to dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentResponseUpdateTests.cs
index 32b7acd673..7fda5f680b 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseUpdateTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentResponseUpdateTests.cs
@@ -7,12 +7,12 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;

-public class AgentRunResponseUpdateTests
+public class AgentResponseUpdateTests
{
    [Fact]
    public void ConstructorPropsDefaulted()
    {
-        AgentRunResponseUpdate update = new();
+        AgentResponseUpdate update = new();
        Assert.Null(update.AuthorName);
        Assert.Null(update.Role);
        Assert.Empty(update.Text);
@@ -45,7 +45,7 @@ public void ConstructorWithChatResponseUpdateRoundtrips()
            ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }),
        };

-        AgentRunResponseUpdate response = new(chatResponseUpdate);
+        AgentResponseUpdate response = new(chatResponseUpdate);
        Assert.Same(chatResponseUpdate.AdditionalProperties, response.AdditionalProperties);
        Assert.Equal(chatResponseUpdate.AuthorName, response.AuthorName);
        Assert.Same(chatResponseUpdate.Contents, response.Contents);
@@ -60,7 +60,7 @@ public void ConstructorWithChatResponseUpdateRoundtrips()
    [Fact]
    public void PropertiesRoundtrip()
    {
-        AgentRunResponseUpdate update = new();
+        AgentResponseUpdate update = new();

        Assert.Null(update.AuthorName);
        update.AuthorName = "author";
@@ -114,7 +114,7 @@ public void PropertiesRoundtrip()
    [Fact]
    public void TextGetUsesAllTextContent()
    {
-        AgentRunResponseUpdate update = new()
+        AgentResponseUpdate update = new()
        {
            Role = ChatRole.User,
            Contents =
@@ -142,7 +142,7 @@ public void TextGetUsesAllTextContent()
    [Fact]
    public void JsonSerializationRoundtrips()
    {
-        AgentRunResponseUpdate original = new()
+        AgentResponseUpdate original = new()
        {
            AuthorName = "author",
            Role = ChatRole.Assistant,
@@ -164,7 +164,7 @@ public void JsonSerializationRoundtrips()
        string json = JsonSerializer.Serialize(original, AgentAbstractionsJsonUtilities.DefaultOptions);

-        AgentRunResponseUpdate? result = JsonSerializer.Deserialize<AgentRunResponseUpdate>(json, AgentAbstractionsJsonUtilities.DefaultOptions);
+        AgentResponseUpdate? result = JsonSerializer.Deserialize<AgentResponseUpdate>(json, AgentAbstractionsJsonUtilities.DefaultOptions);

        Assert.NotNull(result);
        Assert.Equal(5, result.Contents.Count);
diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs
index 2a6cc7bb81..8055a95f3a 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs
@@ -17,8 +17,8 @@ public class DelegatingAIAgentTests
{
    private readonly Mock<AIAgent> _innerAgentMock;
    private readonly TestDelegatingAIAgent _delegatingAgent;
-    private readonly AgentRunResponse _testResponse;
-    private readonly List<AgentRunResponseUpdate> _testStreamingResponses;
+    private readonly AgentResponse _testResponse;
+    private readonly List<AgentResponseUpdate> _testStreamingResponses;
    private readonly AgentThread _testThread;

    /// <summary>
@@ -27,19 +27,19 @@ public class DelegatingAIAgentTests
    public DelegatingAIAgentTests()
    {
        this._innerAgentMock = new Mock<AIAgent>();
-        this._testResponse = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Test response"));
-        this._testStreamingResponses = [new AgentRunResponseUpdate(ChatRole.Assistant, "Test streaming response")];
+        this._testResponse = new AgentResponse(new ChatMessage(ChatRole.Assistant, "Test response"));
+        this._testStreamingResponses = [new AgentResponseUpdate(ChatRole.Assistant, "Test streaming response")];
        this._testThread = new TestAgentThread();

        // Setup inner agent mock
        this._innerAgentMock.Protected().SetupGet<string?>("IdCore").Returns("test-agent-id");
        this._innerAgentMock.Setup(x => x.Name).Returns("Test Agent");
        this._innerAgentMock.Setup(x => x.Description).Returns("Test Description");
-        this._innerAgentMock.Setup(x => x.GetNewThread()).Returns(this._testThread);
+        this._innerAgentMock.Setup(x => x.GetNewThreadAsync()).ReturnsAsync(this._testThread);

        this._innerAgentMock
            .Protected()
-            .Setup<Task<AgentRunResponse>>("RunCoreAsync",
+            .Setup<Task<AgentResponse>>("RunCoreAsync",
                ItExpr.IsAny<IEnumerable<ChatMessage>>(),
                ItExpr.IsAny<AgentThread>(),
                ItExpr.IsAny<AgentRunOptions>(),
@@ -48,7 +48,7 @@ public DelegatingAIAgentTests()
        this._innerAgentMock
            .Protected()
-            .Setup<IAsyncEnumerable<AgentRunResponseUpdate>>("RunCoreStreamingAsync",
+            .Setup<IAsyncEnumerable<AgentResponseUpdate>>("RunCoreStreamingAsync",
                ItExpr.IsAny<IEnumerable<ChatMessage>>(),
                ItExpr.IsAny<AgentThread>(),
                ItExpr.IsAny<AgentRunOptions>(),
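The delegation tests verify that members not overridden in a `DelegatingAIAgent` subclass fall through to the wrapped agent, including the new async thread methods. A minimal sketch of the decorator shape these tests exercise, assuming the protected constructor takes the inner agent (the class name here is illustrative):

```csharp
// Hypothetical pass-through decorator; GetNewThreadAsync, RunAsync, etc.
// delegate to `inner` unless overridden, which is what the tests verify.
public sealed class PassThroughAgent(AIAgent inner) : DelegatingAIAgent(inner)
{
}
```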
@@ -132,17 +132,17 @@ public void Description_DelegatesToInnerAgent()
    #region Method Delegation Tests

    /// <summary>
-    /// Verify that GetNewThread delegates to inner agent.
+    /// Verify that GetNewThreadAsync delegates to inner agent.
    /// </summary>
    [Fact]
-    public void GetNewThread_DelegatesToInnerAgent()
+    public async Task GetNewThreadAsync_DelegatesToInnerAgentAsync()
    {
        // Act
-        var thread = this._delegatingAgent.GetNewThread();
+        var thread = await this._delegatingAgent.GetNewThreadAsync();

        // Assert
        Assert.Same(this._testThread, thread);
-        this._innerAgentMock.Verify(x => x.GetNewThread(), Times.Once);
+        this._innerAgentMock.Verify(x => x.GetNewThreadAsync(), Times.Once);
    }

    /// <summary>
@@ -156,13 +156,13 @@ public async Task RunAsyncDefaultsToInnerAgentAsync()
        var expectedThread = new TestAgentThread();
        var expectedOptions = new AgentRunOptions();
        var expectedCancellationToken = new CancellationToken();
-        var expectedResult = new TaskCompletionSource<AgentRunResponse>();
-        var expectedResponse = new AgentRunResponse();
+        var expectedResult = new TaskCompletionSource<AgentResponse>();
+        var expectedResponse = new AgentResponse();

        var innerAgentMock = new Mock<AIAgent>();
        innerAgentMock
            .Protected()
-            .Setup<Task<AgentRunResponse>>("RunCoreAsync",
+            .Setup<Task<AgentResponse>>("RunCoreAsync",
                ItExpr.Is<IEnumerable<ChatMessage>>(m => m == expectedMessages),
                ItExpr.Is<AgentThread>(t => t == expectedThread),
                ItExpr.Is<AgentRunOptions>(o => o == expectedOptions),
@@ -192,7 +192,7 @@ public async Task RunStreamingAsyncDefaultsToInnerAgentAsync()
        var expectedThread = new TestAgentThread();
        var expectedOptions = new AgentRunOptions();
        var expectedCancellationToken = new CancellationToken();
-        AgentRunResponseUpdate[] expectedResults =
+        AgentResponseUpdate[] expectedResults =
        [
            new(ChatRole.Assistant, "Message 1"),
            new(ChatRole.Assistant, "Message 2")
@@ -201,7 +201,7 @@ public async Task RunStreamingAsyncDefaultsToInnerAgentAsync()
        var innerAgentMock = new Mock<AIAgent>();
        innerAgentMock
            .Protected()
-            .Setup<IAsyncEnumerable<AgentRunResponseUpdate>>("RunCoreStreamingAsync",
+            .Setup<IAsyncEnumerable<AgentResponseUpdate>>("RunCoreStreamingAsync",
                ItExpr.Is<IEnumerable<ChatMessage>>(m => m == expectedMessages),
                ItExpr.Is<AgentThread>(t => t == expectedThread),
                ItExpr.Is<AgentRunOptions>(o => o == expectedOptions),
diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs
index ec343504ab..1f6f9bb578 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs
@@ -11,8 +11,8 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;
    PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase,
    DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
    UseStringEnumConverter = true)]
-[JsonSerializable(typeof(AgentRunResponse))]
-[JsonSerializable(typeof(AgentRunResponseUpdate))]
+[JsonSerializable(typeof(AgentResponse))]
+[JsonSerializable(typeof(AgentResponseUpdate))]
[JsonSerializable(typeof(AgentRunOptions))]
[JsonSerializable(typeof(Animal))]
[JsonSerializable(typeof(JsonElement))]
diff --git a/dotnet/tests/Microsoft.Agents.AI.Anthropic.UnitTests/Extensions/AnthropicBetaServiceExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Anthropic.UnitTests/Extensions/AnthropicBetaServiceExtensionsTests.cs
index 400bcf5456..91d2cb988a 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Anthropic.UnitTests/Extensions/AnthropicBetaServiceExtensionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Anthropic.UnitTests/Extensions/AnthropicBetaServiceExtensionsTests.cs
@@ -34,7 +34,7 @@ public void CreateAIAgent_WithClientFactory_AppliesFactoryCorrectly()
        var testChatClient = new TestChatClient(chatClient.Beta.AsIChatClient());

        // Act
-        var agent = chatClient.Beta.CreateAIAgent(
+        var agent = chatClient.Beta.AsAIAgent(
            model: "test-model",
            instructions: "Test instructions",
            name: "Test Agent",
@@ -63,7 +63,7 @@ public void CreateAIAgent_WithClientFactoryUsingAsBuilder_AppliesFactoryCorrectl
        TestChatClient? testChatClient = null;

        // Act
-        var agent = chatClient.Beta.CreateAIAgent(
+        var agent = chatClient.Beta.AsAIAgent(
            model: "test-model",
            instructions: "Test instructions",
            clientFactory: (innerClient) =>
@@ -95,7 +95,7 @@ public void CreateAIAgent_WithOptionsAndClientFactory_AppliesFactoryCorrectly()
        };

        // Act
-        var agent = chatClient.Beta.CreateAIAgent(
+        var agent = chatClient.Beta.AsAIAgent(
            options,
            clientFactory: (innerClient) => testChatClient);

@@ -120,7 +120,7 @@ public void CreateAIAgent_WithoutClientFactory_WorksNormally()
        var chatClient = new TestAnthropicChatClient();

        // Act
-        var agent = chatClient.Beta.CreateAIAgent(
+        var agent = chatClient.Beta.AsAIAgent(
            model: "test-model",
            instructions: "Test instructions",
            name: "Test Agent");
@@ -144,7 +144,7 @@ public void CreateAIAgent_WithNullClientFactory_WorksNormally()
        var chatClient = new TestAnthropicChatClient();

        // Act
-        var agent = chatClient.Beta.CreateAIAgent(
+        var agent = chatClient.Beta.AsAIAgent(
            model: "test-model",
            instructions: "Test instructions",
            name: "Test Agent",
@@ -167,7 +167,7 @@ public void CreateAIAgent_WithNullClient_ThrowsArgumentNullException()
    {
        // Act & Assert
        var exception = Assert.Throws<ArgumentNullException>(() =>
-            ((IBetaService)null!).CreateAIAgent("test-model"));
+            ((IBetaService)null!).AsAIAgent("test-model"));

        Assert.Equal("betaService", exception.ParamName);
    }
@@ -183,7 +183,7 @@ public void CreateAIAgent_WithNullOptions_ThrowsArgumentNullException()

        // Act & Assert
        var exception = Assert.Throws<ArgumentNullException>(() =>
-            chatClient.Beta.CreateAIAgent((ChatClientAgentOptions)null!));
+            chatClient.Beta.AsAIAgent((ChatClientAgentOptions)null!));

        Assert.Equal("options", exception.ParamName);
    }
diff --git a/dotnet/tests/Microsoft.Agents.AI.Anthropic.UnitTests/Extensions/AnthropicClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Anthropic.UnitTests/Extensions/AnthropicClientExtensionsTests.cs
index c8bf4d6a5e..90f20d15c3 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Anthropic.UnitTests/Extensions/AnthropicClientExtensionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Anthropic.UnitTests/Extensions/AnthropicClientExtensionsTests.cs
@@ -101,7 +101,7 @@ public void CreateAIAgent_WithClientFactory_AppliesFactoryCorrectly()
        var testChatClient = new TestChatClient(chatClient.AsIChatClient());

        // Act
-        var agent = chatClient.CreateAIAgent(
+        var agent = chatClient.AsAIAgent(
            model: "test-model",
            instructions: "Test instructions",
            name: "Test Agent",
@@ -130,7 +130,7 @@ public void CreateAIAgent_WithClientFactoryUsingAsBuilder_AppliesFactoryCorrectl
        TestChatClient? testChatClient = null;

        // Act
-        var agent = chatClient.CreateAIAgent(
+        var agent = chatClient.AsAIAgent(
            model: "test-model",
            instructions: "Test instructions",
            clientFactory: (innerClient) =>
@@ -162,7 +162,7 @@ public void CreateAIAgent_WithOptionsAndClientFactory_AppliesFactoryCorrectly()
        };

        // Act
-        var agent = chatClient.CreateAIAgent(
+        var agent = chatClient.AsAIAgent(
            options,
            clientFactory: (innerClient) => testChatClient);

@@ -187,7 +187,7 @@ public void CreateAIAgent_WithoutClientFactory_WorksNormally()
        var chatClient = new TestAnthropicChatClient();

        // Act
-        var agent = chatClient.CreateAIAgent(
+        var agent = chatClient.AsAIAgent(
            model: "test-model",
            instructions: "Test instructions",
            name: "Test Agent");
@@ -211,7 +211,7 @@ public void CreateAIAgent_WithNullClientFactory_WorksNormally()
        var chatClient = new TestAnthropicChatClient();

        // Act
-        var agent = chatClient.CreateAIAgent(
+        var agent = chatClient.AsAIAgent(
            model: "test-model",
            instructions: "Test instructions",
            name: "Test Agent",
@@ -234,7 +234,7 @@ public void CreateAIAgent_WithNullClient_ThrowsArgumentNullException()
    {
        // Act & Assert
        var exception = Assert.Throws<ArgumentNullException>(() =>
-            ((TestAnthropicChatClient)null!).CreateAIAgent("test-model"));
+            ((TestAnthropicChatClient)null!).AsAIAgent("test-model"));

        Assert.Equal("client", exception.ParamName);
    }
@@ -250,7 +250,7 @@ public void CreateAIAgent_WithNullOptions_ThrowsArgumentNullException()

        // Act & Assert
        var exception = Assert.Throws<ArgumentNullException>(() =>
-            chatClient.CreateAIAgent((ChatClientAgentOptions)null!));
+            chatClient.AsAIAgent((ChatClientAgentOptions)null!));

        Assert.Equal("options", exception.ParamName);
    }
diff --git a/dotnet/tests/Microsoft.Agents.AI.AzureAI.Persistent.UnitTests/Extensions/PersistentAgentsClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.AzureAI.Persistent.UnitTests/Extensions/PersistentAgentsClientExtensionsTests.cs
index b661a392be..a3d3be27fe 100644
--- a/dotnet/tests/Microsoft.Agents.AI.AzureAI.Persistent.UnitTests/Extensions/PersistentAgentsClientExtensionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.AzureAI.Persistent.UnitTests/Extensions/PersistentAgentsClientExtensionsTests.cs
@@ -42,7 +42,7 @@ public void GetAIAgent_WithNullOrWhitespaceAgentId_ThrowsArgumentException()
        // Act & Assert - null agentId
        var exception1 = Assert.Throws<ArgumentNullException>(() =>
-            mockClient.Object.GetAIAgent((string)null!));
+            mockClient.Object.GetAIAgent(null!));
        Assert.Equal("agentId", exception1.ParamName);

        // Act & Assert - empty agentId
@@ -314,7 +314,7 @@ public void GetAIAgent_WithResponseAndOptions_WorksCorrectly()
        };

        // Act
-        var agent = client.GetAIAgent(response, options);
+        var agent = client.AsAIAgent(response, options);

        // Assert
        Assert.NotNull(agent);
@@ -341,7 +341,7 @@ public void GetAIAgent_WithPersistentAgentAndOptions_WorksCorrectly()
        };

        // Act
-        var agent = client.GetAIAgent(persistentAgent, options);
+        var agent = client.AsAIAgent(persistentAgent, options);

        // Assert
        Assert.NotNull(agent);
@@ -363,7 +363,7 @@ public void GetAIAgent_WithPersistentAgentAndOptionsWithNullFields_FallsBackToAg
        var options = new ChatClientAgentOptions(); // Empty options

        // Act
-        var agent = client.GetAIAgent(persistentAgent, options);
+        var agent = client.AsAIAgent(persistentAgent, options);

        // Assert
        Assert.NotNull(agent);
@@ -443,7 +443,7 @@ public void GetAIAgent_WithOptionsAndClientFactory_AppliesFactoryCorrectly()
        };

        // Act
-        var agent = client.GetAIAgent(
+        var agent = client.AsAIAgent(
            persistentAgent,
            options,
            clientFactory: (innerClient) => testChatClient);
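Several providers expose the same `clientFactory` hook, and these tests confirm it survives the rename: the factory receives the inner `IChatClient` so callers can wrap it with a decorator (logging, caching, and so on). A sketch of the call shape the tests exercise, with `TestChatClient` standing in for any wrapper and an Anthropic-style client assumed:

```csharp
var agent = chatClient.AsAIAgent(
    model: "test-model",
    instructions: "Test instructions",
    clientFactory: inner => new TestChatClient(inner)); // Wrap the inner IChatClient.
```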
+470,7 @@ public void GetAIAgent_WithNullResponse_ThrowsArgumentNullException() // Act & Assert var exception = Assert.Throws(() => - client.GetAIAgent((Response)null!, options)); + client.AsAIAgent(null!, options)); Assert.Equal("persistentAgentResponse", exception.ParamName); } @@ -487,7 +487,7 @@ public void GetAIAgent_WithNullPersistentAgent_ThrowsArgumentNullException() // Act & Assert var exception = Assert.Throws(() => - client.GetAIAgent((PersistentAgent)null!, options)); + client.AsAIAgent((PersistentAgent)null!, options)); Assert.Equal("persistentAgentMetadata", exception.ParamName); } @@ -504,7 +504,7 @@ public void GetAIAgent_WithNullOptions_ThrowsArgumentNullException() // Act & Assert var exception = Assert.Throws(() => - client.GetAIAgent(persistentAgent, (ChatClientAgentOptions)null!)); + client.AsAIAgent(persistentAgent, (ChatClientAgentOptions)null!)); Assert.Equal("options", exception.ParamName); } diff --git a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs index 4ca2b7f461..528dc323af 100644 --- a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs @@ -25,13 +25,13 @@ namespace Microsoft.Agents.AI.AzureAI.UnitTests; /// public sealed class AzureAIProjectChatClientExtensionsTests { - #region GetAIAgent(AIProjectClient, AgentRecord) Tests + #region AsAIAgent(AIProjectClient, AgentRecord) Tests /// - /// Verify that GetAIAgent throws ArgumentNullException when AIProjectClient is null. + /// Verify that AsAIAgent throws ArgumentNullException when AIProjectClient is null. /// [Fact] - public void GetAIAgent_WithAgentRecord_WithNullClient_ThrowsArgumentNullException() + public void AsAIAgent_WithAgentRecord_WithNullClient_ThrowsArgumentNullException() { // Arrange AIProjectClient? client = null; @@ -39,39 +39,39 @@ public void GetAIAgent_WithAgentRecord_WithNullClient_ThrowsArgumentNullExceptio // Act & Assert var exception = Assert.Throws(() => - client!.GetAIAgent(agentRecord)); + client!.AsAIAgent(agentRecord)); Assert.Equal("aiProjectClient", exception.ParamName); } /// - /// Verify that GetAIAgent throws ArgumentNullException when agentRecord is null. + /// Verify that AsAIAgent throws ArgumentNullException when agentRecord is null. /// [Fact] - public void GetAIAgent_WithAgentRecord_WithNullAgentRecord_ThrowsArgumentNullException() + public void AsAIAgent_WithAgentRecord_WithNullAgentRecord_ThrowsArgumentNullException() { // Arrange var mockClient = new Mock(); // Act & Assert var exception = Assert.Throws(() => - mockClient.Object.GetAIAgent((AgentRecord)null!)); + mockClient.Object.AsAIAgent((AgentRecord)null!)); Assert.Equal("agentRecord", exception.ParamName); } /// - /// Verify that GetAIAgent with AgentRecord creates a valid agent. + /// Verify that AsAIAgent with AgentRecord creates a valid agent. 
/// [Fact] - public void GetAIAgent_WithAgentRecord_CreatesValidAgent() + public void AsAIAgent_WithAgentRecord_CreatesValidAgent() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); AgentRecord agentRecord = this.CreateTestAgentRecord(); // Act - var agent = client.GetAIAgent(agentRecord); + var agent = client.AsAIAgent(agentRecord); // Assert Assert.NotNull(agent); @@ -79,10 +79,10 @@ public void GetAIAgent_WithAgentRecord_CreatesValidAgent() } /// - /// Verify that GetAIAgent with AgentRecord and clientFactory applies the factory. + /// Verify that AsAIAgent with AgentRecord and clientFactory applies the factory. /// [Fact] - public void GetAIAgent_WithAgentRecord_WithClientFactory_AppliesFactoryCorrectly() + public void AsAIAgent_WithAgentRecord_WithClientFactory_AppliesFactoryCorrectly() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); @@ -90,7 +90,7 @@ public void GetAIAgent_WithAgentRecord_WithClientFactory_AppliesFactoryCorrectly TestChatClient? testChatClient = null; // Act - var agent = client.GetAIAgent( + var agent = client.AsAIAgent( agentRecord, clientFactory: (innerClient) => testChatClient = new TestChatClient(innerClient)); @@ -103,13 +103,13 @@ public void GetAIAgent_WithAgentRecord_WithClientFactory_AppliesFactoryCorrectly #endregion - #region GetAIAgent(AIProjectClient, AgentVersion) Tests + #region AsAIAgent(AIProjectClient, AgentVersion) Tests /// - /// Verify that GetAIAgent throws ArgumentNullException when AIProjectClient is null. + /// Verify that AsAIAgent throws ArgumentNullException when AIProjectClient is null. /// [Fact] - public void GetAIAgent_WithAgentVersion_WithNullClient_ThrowsArgumentNullException() + public void AsAIAgent_WithAgentVersion_WithNullClient_ThrowsArgumentNullException() { // Arrange AIProjectClient? client = null; @@ -117,39 +117,39 @@ public void GetAIAgent_WithAgentVersion_WithNullClient_ThrowsArgumentNullExcepti // Act & Assert var exception = Assert.Throws(() => - client!.GetAIAgent(agentVersion)); + client!.AsAIAgent(agentVersion)); Assert.Equal("aiProjectClient", exception.ParamName); } /// - /// Verify that GetAIAgent throws ArgumentNullException when agentVersion is null. + /// Verify that AsAIAgent throws ArgumentNullException when agentVersion is null. /// [Fact] - public void GetAIAgent_WithAgentVersion_WithNullAgentVersion_ThrowsArgumentNullException() + public void AsAIAgent_WithAgentVersion_WithNullAgentVersion_ThrowsArgumentNullException() { // Arrange var mockClient = new Mock(); // Act & Assert var exception = Assert.Throws(() => - mockClient.Object.GetAIAgent((AgentVersion)null!)); + mockClient.Object.AsAIAgent((AgentVersion)null!)); Assert.Equal("agentVersion", exception.ParamName); } /// - /// Verify that GetAIAgent with AgentVersion creates a valid agent. + /// Verify that AsAIAgent with AgentVersion creates a valid agent. /// [Fact] - public void GetAIAgent_WithAgentVersion_CreatesValidAgent() + public void AsAIAgent_WithAgentVersion_CreatesValidAgent() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); AgentVersion agentVersion = this.CreateTestAgentVersion(); // Act - var agent = client.GetAIAgent(agentVersion); + var agent = client.AsAIAgent(agentVersion); // Assert Assert.NotNull(agent); @@ -157,10 +157,10 @@ public void GetAIAgent_WithAgentVersion_CreatesValidAgent() } /// - /// Verify that GetAIAgent with AgentVersion and clientFactory applies the factory. + /// Verify that AsAIAgent with AgentVersion and clientFactory applies the factory. 
/// [Fact] - public void GetAIAgent_WithAgentVersion_WithClientFactory_AppliesFactoryCorrectly() + public void AsAIAgent_WithAgentVersion_WithClientFactory_AppliesFactoryCorrectly() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); @@ -168,7 +168,7 @@ public void GetAIAgent_WithAgentVersion_WithClientFactory_AppliesFactoryCorrectl TestChatClient? testChatClient = null; // Act - var agent = client.GetAIAgent( + var agent = client.AsAIAgent( agentVersion, clientFactory: (innerClient) => testChatClient = new TestChatClient(innerClient)); @@ -183,7 +183,7 @@ public void GetAIAgent_WithAgentVersion_WithClientFactory_AppliesFactoryCorrectl - /// Verify that GetAIAgent with requireInvocableTools=true enforces invocable tools. + /// Verify that AsAIAgent with requireInvocableTools=true enforces invocable tools. /// [Fact] - public void GetAIAgent_WithAgentVersion_WithRequireInvocableToolsTrue_EnforcesInvocableTools() + public void AsAIAgent_WithAgentVersion_WithRequireInvocableToolsTrue_EnforcesInvocableTools() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); @@ -194,7 +194,7 @@ public void GetAIAgent_WithAgentVersion_WithRequireInvocableToolsTrue_EnforcesIn }; // Act - var agent = client.GetAIAgent(agentVersion, tools: tools); + var agent = client.AsAIAgent(agentVersion, tools: tools); // Assert Assert.NotNull(agent); @@ -205,14 +205,14 @@ public void GetAIAgent_WithAgentVersion_WithRequireInvocableToolsTrue_EnforcesIn - /// Verify that GetAIAgent with requireInvocableTools=false allows declarative functions. + /// Verify that AsAIAgent with requireInvocableTools=false allows declarative functions. /// [Fact] - public void GetAIAgent_WithAgentVersion_WithRequireInvocableToolsFalse_AllowsDeclarativeFunctions() + public void AsAIAgent_WithAgentVersion_WithRequireInvocableToolsFalse_AllowsDeclarativeFunctions() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); AgentVersion agentVersion = this.CreateTestAgentVersion(); // Act - should not throw even without tools when requireInvocableTools is false - var agent = client.GetAIAgent(agentVersion); + var agent = client.AsAIAgent(agentVersion); // Assert Assert.NotNull(agent); @@ -374,7 +374,7 @@ public async Task GetAIAgentAsync_WithOptions_CreatesValidAgentAsync() #region GetAIAgent(AIProjectClient, string) Tests /// - /// Verify that GetAIAgent throws ArgumentNullException when AIProjectClient is null. + /// Verify that AsAIAgent throws ArgumentNullException when AIProjectClient is null. /// [Fact] public void GetAIAgent_ByName_WithNullClient_ThrowsArgumentNullException() @@ -390,7 +390,7 @@ public void GetAIAgent_ByName_WithNullClient_ThrowsArgumentNullException() } /// - /// Verify that GetAIAgent throws ArgumentNullException when name is null. + /// Verify that AsAIAgent throws ArgumentNullException when name is null. /// [Fact] public void GetAIAgent_ByName_WithNullName_ThrowsArgumentNullException() @@ -406,7 +406,7 @@ public void GetAIAgent_ByName_WithNullName_ThrowsArgumentNullException() } /// - /// Verify that GetAIAgent throws ArgumentException when name is empty. + /// Verify that AsAIAgent throws ArgumentException when name is empty. /// [Fact] public void GetAIAgent_ByName_WithEmptyName_ThrowsArgumentException() @@ -422,7 +422,7 @@ public void GetAIAgent_ByName_WithEmptyName_ThrowsArgumentException() } /// - /// Verify that GetAIAgent throws InvalidOperationException when agent is not found. + /// Verify that AsAIAgent throws InvalidOperationException when agent is not found.
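The `tools:` parameter exercised by the requireInvocableTools tests takes invocable `AIFunction` instances built with `AIFunctionFactory.Create`, mirroring the inline tool in the tests. A sketch, reusing `projectClient` from the earlier example; `agentVersion` and the lambda are illustrative:

```csharp
using Microsoft.Extensions.AI;  // AIFunctionFactory, AITool

// An invocable implementation for a function the server-side definition
// declares; name, description, and body are placeholders.
AITool getWeather = AIFunctionFactory.Create(
    (string city) => $"Sunny in {city}",
    "get_weather",
    "Returns a canned forecast for a city");

// Passing tools while adapting the agent; the requireInvocableTools tests
// above pin down when an invocable body is mandatory versus optional.
AIAgent agent = projectClient.AsAIAgent(agentVersion, tools: [getWeather]);
```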
/// [Fact] public void GetAIAgent_ByName_WithNonExistentAgent_ThrowsInvalidOperationException() @@ -505,13 +505,13 @@ public async Task GetAIAgentAsync_ByName_WithNonExistentAgent_ThrowsInvalidOpera #endregion - #region GetAIAgent(AIProjectClient, AgentRecord) with tools Tests + #region AsAIAgent(AIProjectClient, AgentRecord) with tools Tests /// - /// Verify that GetAIAgent with additional tools when the definition has no tools does not throw and results in an agent with no tools. + /// Verify that AsAIAgent with additional tools when the definition has no tools does not throw and results in an agent with no tools. /// [Fact] - public void GetAIAgent_WithAgentRecordAndAdditionalTools_WhenDefinitionHasNoTools_ShouldNotThrow() + public void AsAIAgent_WithAgentRecordAndAdditionalTools_WhenDefinitionHasNoTools_ShouldNotThrow() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); @@ -522,7 +522,7 @@ public void GetAIAgent_WithAgentRecordAndAdditionalTools_WhenDefinitionHasNoTool }; // Act - var agent = client.GetAIAgent(agentRecord, tools: tools); + var agent = client.AsAIAgent(agentRecord, tools: tools); // Assert Assert.NotNull(agent); @@ -536,17 +536,17 @@ public void GetAIAgent_WithAgentRecordAndAdditionalTools_WhenDefinitionHasNoTool } /// - /// Verify that GetAIAgent with null tools works correctly. + /// Verify that AsAIAgent with null tools works correctly. /// [Fact] - public void GetAIAgent_WithAgentRecordAndNullTools_WorksCorrectly() + public void AsAIAgent_WithAgentRecordAndNullTools_WorksCorrectly() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); AgentRecord agentRecord = this.CreateTestAgentRecord(); // Act - var agent = client.GetAIAgent(agentRecord, tools: null); + var agent = client.AsAIAgent(agentRecord, tools: null); // Assert Assert.NotNull(agent); @@ -1104,7 +1104,7 @@ public void GetAIAgent_AdditionalAITools_WhenNotInTheDefinitionAreIgnored() var shouldBeIgnoredTool = AIFunctionFactory.Create(() => "test", "additional_tool", "An additional test function that should be ignored"); // Act & Assert - var agent = client.GetAIAgent(agentVersion, tools: [invocableInlineAITool, shouldBeIgnoredTool]); + var agent = client.AsAIAgent(agentVersion, tools: [invocableInlineAITool, shouldBeIgnoredTool]); Assert.NotNull(agent); var version = agent.GetService(); Assert.NotNull(version); @@ -1136,7 +1136,7 @@ public void GetAIAgent_WithParameterTools_AcceptsTools() }; // Act - var agent = client.GetAIAgent(agentRecord, tools: tools); + var agent = client.AsAIAgent(agentRecord, tools: tools); // Assert Assert.NotNull(agent); @@ -1632,7 +1632,7 @@ public void CreateAIAgent_WithOptionsAndTools_GeneratesCorrectOptions() #region AgentName Validation Tests /// - /// Verify that GetAIAgent throws ArgumentException when agent name is invalid. + /// Verify that AsAIAgent throws ArgumentException when agent name is invalid. /// [Theory] [MemberData(nameof(InvalidAgentNameTestData.GetInvalidAgentNames), MemberType = typeof(InvalidAgentNameTestData))] @@ -1846,7 +1846,7 @@ public void GetAIAgent_WithAgentReference_WithInvalidAgentName_ThrowsArgumentExc /// Verify that the underlying chat client created by extension methods can be wrapped with clientFactory. 
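The tool-matching rule these tests capture, as far as the assertions show: the server-side definition decides which functions are bound, and extra local functions are ignored rather than rejected. A sketch whose names mirror the test data (`projectClient` and `agentVersion` carry over from earlier sketches):

```csharp
using Microsoft.Extensions.AI;  // AIFunctionFactory, AITool

AITool declared = AIFunctionFactory.Create(
    () => "test", "test_function", "A function also declared in the definition");
AITool extra = AIFunctionFactory.Create(
    () => "test", "additional_tool", "Not in the definition; silently ignored");

// Only "test_function" ends up on the agent, assuming the server-side
// definition declares it; passing tools: null is also accepted.
AIAgent agent = projectClient.AsAIAgent(agentVersion, tools: [declared, extra]);
```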
/// [Fact] - public void GetAIAgent_WithClientFactory_WrapsUnderlyingChatClient() + public void AsAIAgent_WithClientFactory_WrapsUnderlyingChatClient() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); @@ -1854,7 +1854,7 @@ public void GetAIAgent_WithClientFactory_WrapsUnderlyingChatClient() int factoryCallCount = 0; // Act - var agent = client.GetAIAgent( + var agent = client.AsAIAgent( agentRecord, clientFactory: (innerClient) => { @@ -1903,18 +1903,18 @@ public void CreateAIAgent_WithClientFactory_ReceivesCorrectUnderlyingClient() /// Verify that multiple clientFactory calls create independent wrapped clients. /// [Fact] - public void GetAIAgent_MultipleCallsWithClientFactory_CreatesIndependentClients() + public void AsAIAgent_MultipleCallsWithClientFactory_CreatesIndependentClients() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); AgentRecord agentRecord = this.CreateTestAgentRecord(); // Act - var agent1 = client.GetAIAgent( + var agent1 = client.AsAIAgent( agentRecord, clientFactory: (innerClient) => new TestChatClient(innerClient)); - var agent2 = client.GetAIAgent( + var agent2 = client.AsAIAgent( agentRecord, clientFactory: (innerClient) => new TestChatClient(innerClient)); @@ -2165,7 +2165,7 @@ public async Task GetAIAgent_UserAgentHeaderAddedToRequestsAsync() #region GetAIAgent(AIProjectClient, AgentReference) Tests /// - /// Verify that GetAIAgent throws ArgumentNullException when AIProjectClient is null. + /// Verify that AsAIAgent throws ArgumentNullException when AIProjectClient is null. /// [Fact] public void GetAIAgent_WithAgentReference_WithNullClient_ThrowsArgumentNullException() @@ -2182,7 +2182,7 @@ public void GetAIAgent_WithAgentReference_WithNullClient_ThrowsArgumentNullExcep } /// - /// Verify that GetAIAgent throws ArgumentNullException when agentReference is null. + /// Verify that AsAIAgent throws ArgumentNullException when agentReference is null. 
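The independence guarantee the surrounding tests assert can be seen directly: each `AsAIAgent` call runs the factory once and owns its own wrapped client. A sketch reusing the hypothetical `LoggingChatClient` from the earlier example:

```csharp
int factoryCalls = 0;

AIAgent agent1 = projectClient.AsAIAgent(agentRecord,
    clientFactory: inner => { factoryCalls++; return new LoggingChatClient(inner); });
AIAgent agent2 = projectClient.AsAIAgent(agentRecord,
    clientFactory: inner => { factoryCalls++; return new LoggingChatClient(inner); });

// Two adaptations, two factory invocations, two independent wrapped
// clients; nothing is cached or shared between agent1 and agent2.
Console.WriteLine(factoryCalls); // 2
```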
/// [Fact] public void GetAIAgent_WithAgentReference_WithNullAgentReference_ThrowsArgumentNullException() @@ -2297,7 +2297,7 @@ public void GetService_WithAgentRecord_ReturnsAgentRecord() AgentRecord agentRecord = this.CreateTestAgentRecord(); // Act - var agent = client.GetAIAgent(agentRecord); + var agent = client.AsAIAgent(agentRecord); var retrievedRecord = agent.GetService(); // Assert @@ -2338,7 +2338,7 @@ public void GetService_WithAgentVersion_ReturnsAgentVersion() AgentVersion agentVersion = this.CreateTestAgentVersion(); // Act - var agent = client.GetAIAgent(agentVersion); + var agent = client.AsAIAgent(agentVersion); var retrievedVersion = agent.GetService(); // Assert @@ -2379,7 +2379,7 @@ public void ChatClientMetadata_WithAgentRecord_IsPopulatedCorrectly() AgentRecord agentRecord = this.CreateTestAgentRecord(); // Act - var agent = client.GetAIAgent(agentRecord); + var agent = client.AsAIAgent(agentRecord); var metadata = agent.GetService(); // Assert @@ -2402,7 +2402,7 @@ public void ChatClientMetadata_WithPromptAgentDefinition_SetsDefaultModelIdFromM AgentRecord agentRecord = this.CreateTestAgentRecord(definition); // Act - var agent = client.GetAIAgent(agentRecord); + var agent = client.AsAIAgent(agentRecord); var metadata = agent.GetService(); // Assert @@ -2423,7 +2423,7 @@ public void ChatClientMetadata_WithAgentVersion_IsPopulatedCorrectly() AgentVersion agentVersion = this.CreateTestAgentVersion(); // Act - var agent = client.GetAIAgent(agentVersion); + var agent = client.AsAIAgent(agentVersion); var metadata = agent.GetService(); // Assert @@ -2467,7 +2467,7 @@ public void GetService_WithAgentRecord_ReturnsAlsoAgentReference() AgentRecord agentRecord = this.CreateTestAgentRecord(); // Act - var agent = client.GetAIAgent(agentRecord); + var agent = client.AsAIAgent(agentRecord); var retrievedReference = agent.GetService(); // Assert @@ -2486,7 +2486,7 @@ public void GetService_WithAgentVersion_ReturnsAlsoAgentReference() AgentVersion agentVersion = this.CreateTestAgentVersion(); // Act - var agent = client.GetAIAgent(agentVersion); + var agent = client.AsAIAgent(agentVersion); var retrievedReference = agent.GetService(); // Assert diff --git a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientTests.cs b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientTests.cs index eee9f520b6..0c93c72172 100644 --- a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientTests.cs @@ -53,7 +53,7 @@ public async Task ChatClient_UsesDefaultConversationIdAsync() }); // Act - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await agent.RunAsync("Hello", thread); Assert.True(requestTriggered); @@ -102,7 +102,7 @@ public async Task ChatClient_UsesPerRequestConversationId_WhenNoDefaultConversat }); // Act - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await agent.RunAsync("Hello", thread, options: new ChatClientAgentRunOptions() { ChatOptions = new() { ConversationId = "conv_12345" } }); Assert.True(requestTriggered); @@ -151,7 +151,7 @@ public async Task ChatClient_UsesPerRequestConversationId_EvenWhenDefaultConvers }); // Act - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await agent.RunAsync("Hello", thread, options: new ChatClientAgentRunOptions() { ChatOptions = new() { ConversationId = "conv_12345" } }); 
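Two behaviors the tests on either side of this point rely on, shown together: the adapted agent surfaces its source metadata through `GetService<T>()`, and thread creation is now awaited (`GetNewThread()` becomes `GetNewThreadAsync()`). A sketch continuing from the earlier ones; the generic type arguments are written out here because the extraction above has dropped them:

```csharp
using Azure.AI.Projects;        // AgentRecord, AgentReference
using Microsoft.Extensions.AI;  // ChatClientMetadata

AIAgent agent = projectClient.AsAIAgent(agentRecord);

// Introspection: the originating record/reference and chat metadata are
// reachable without holding extra state alongside the agent.
AgentRecord? record = agent.GetService<AgentRecord>();
AgentReference? reference = agent.GetService<AgentReference>();
ChatClientMetadata? metadata = agent.GetService<ChatClientMetadata>();

// Thread creation is an awaitable, cancellable factory now:
AgentThread thread = await agent.GetNewThreadAsync();
await agent.RunAsync("Hello", thread);
```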
Assert.True(requestTriggered); @@ -200,7 +200,7 @@ public async Task ChatClient_UsesPreviousResponseId_WhenConversationIsNotPrefixe }); // Act - var thread = agent.GetNewThread(); + var thread = await agent.GetNewThreadAsync(); await agent.RunAsync("Hello", thread, options: new ChatClientAgentRunOptions() { ChatOptions = new() { ConversationId = "resp_0888a" } }); Assert.True(requestTriggered); diff --git a/dotnet/tests/Microsoft.Agents.AI.Declarative.UnitTests/AggregatorPromptAgentFactoryTests.cs b/dotnet/tests/Microsoft.Agents.AI.Declarative.UnitTests/AggregatorPromptAgentFactoryTests.cs index 09ee72504a..54bc8ebfed 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Declarative.UnitTests/AggregatorPromptAgentFactoryTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Declarative.UnitTests/AggregatorPromptAgentFactoryTests.cs @@ -66,22 +66,22 @@ public TestAgentFactory(AIAgent? agentToReturn = null) private sealed class TestAgent : AIAgent { - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } - public override AgentThread GetNewThread() + public override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) { throw new NotImplementedException(); } - protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } - protected override IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + protected override IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/AgentEntityTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/AgentEntityTests.cs index b615bf1cd6..f0b5caf9bd 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/AgentEntityTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/AgentEntityTests.cs @@ -41,7 +41,7 @@ public sealed class AgentEntityTests(ITestOutputHelper outputHelper) : IDisposab public async Task EntityNamePrefixAsync() { // Setup - AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).AsAIAgent( name: "TestAgent", instructions: "You are a helpful assistant that always responds with a friendly greeting." 
); @@ -51,7 +51,7 @@ public async Task EntityNamePrefixAsync() // A proxy agent is needed to call the hosted test agent AIAgent simpleAgentProxy = simpleAgent.AsDurableAgentProxy(testHelper.Services); - AgentThread thread = simpleAgentProxy.GetNewThread(); + AgentThread thread = await simpleAgentProxy.GetNewThreadAsync(this.TestTimeoutToken); DurableTaskClient client = testHelper.GetClient(); @@ -88,7 +88,7 @@ await simpleAgentProxy.RunAsync( public async Task RunAgentMethodNamesAllWorkAsync(string runAgentMethodName) { // Setup - AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).AsAIAgent( name: "TestAgent", instructions: "You are a helpful assistant that always responds with a friendly greeting." ); @@ -98,7 +98,7 @@ public async Task RunAgentMethodNamesAllWorkAsync(string runAgentMethodName) // A proxy agent is needed to call the hosted test agent AIAgent simpleAgentProxy = simpleAgent.AsDurableAgentProxy(testHelper.Services); - AgentThread thread = simpleAgentProxy.GetNewThread(); + AgentThread thread = await simpleAgentProxy.GetNewThreadAsync(this.TestTimeoutToken); DurableTaskClient client = testHelper.GetClient(); @@ -143,7 +143,7 @@ await client.Entities.SignalEntityAsync( public async Task OrchestrationIdSetDuringOrchestrationAsync() { // Arrange - AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).AsAIAgent( name: "TestAgent", instructions: "You are a helpful assistant that always responds with a friendly greeting." ); @@ -184,7 +184,7 @@ private sealed class TestOrchestrator : TaskOrchestrator public override async Task RunAsync(TaskOrchestrationContext context, string input) { DurableAIAgent writer = context.GetAgent("TestAgent"); - AgentThread writerThread = writer.GetNewThread(); + AgentThread writerThread = await writer.GetNewThreadAsync(); await writer.RunAsync( message: context.GetInput()!, diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/ExternalClientTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/ExternalClientTests.cs index c43b86e330..9e266dde00 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/ExternalClientTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/ExternalClientTests.cs @@ -41,7 +41,7 @@ public sealed class ExternalClientTests(ITestOutputHelper outputHelper) : IDispo public async Task SimplePromptAsync() { // Setup - AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).AsAIAgent( instructions: "You are a helpful assistant that always responds with a friendly greeting.", name: "TestAgent"); @@ -51,13 +51,13 @@ public async Task SimplePromptAsync() AIAgent simpleAgentProxy = simpleAgent.AsDurableAgentProxy(testHelper.Services); // Act: send a prompt to the agent and wait for a response - AgentThread thread = simpleAgentProxy.GetNewThread(); + AgentThread thread = await simpleAgentProxy.GetNewThreadAsync(this.TestTimeoutToken); await simpleAgentProxy.RunAsync( message: "Hello!", thread, cancellationToken: this.TestTimeoutToken); - AgentRunResponse response = await simpleAgentProxy.RunAsync( + AgentResponse response = await simpleAgentProxy.RunAsync( message: "Repeat what you just said but say it like a pirate", thread, 
cancellationToken: this.TestTimeoutToken); @@ -94,7 +94,7 @@ string SuggestPackingList(string weather, bool isSunny) return isSunny ? "Pack sunglasses and sunscreen." : "Pack a raincoat and umbrella."; } - AIAgent tripPlanningAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + AIAgent tripPlanningAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).AsAIAgent( instructions: "You are a trip planning assistant. Use the weather tool and packing list tool as needed.", name: "TripPlanningAgent", description: "An agent to help plan your day trips", @@ -105,7 +105,7 @@ string SuggestPackingList(string weather, bool isSunny) AIAgent tripPlanningAgentProxy = tripPlanningAgent.AsDurableAgentProxy(testHelper.Services); // Act: send a prompt to the agent - AgentRunResponse response = await tripPlanningAgentProxy.RunAsync( + AgentResponse response = await tripPlanningAgentProxy.RunAsync( message: "Help me figure out what to pack for my Seattle trip next Sunday", cancellationToken: this.TestTimeoutToken); @@ -156,13 +156,13 @@ async Task RunWorkflowAsync(TaskOrchestrationContext context, string nam { // 1. Get agent and create a session DurableAIAgent agent = context.GetAgent("SimpleAgent"); - AgentThread thread = agent.GetNewThread(); + AgentThread thread = await agent.GetNewThreadAsync(this.TestTimeoutToken); // 2. Call an agent and tell it my name await agent.RunAsync($"My name is {name}.", thread); // 3. Call the agent again with the same thread (ask it to tell me my name) - AgentRunResponse response = await agent.RunAsync("What is my name?", thread); + AgentResponse response = await agent.RunAsync("What is my name?", thread); return response.Text; } @@ -174,7 +174,7 @@ async Task RunWorkflowAsync(TaskOrchestrationContext context, string nam // This is the agent that will be used to start the workflow agents.AddAIAgentFactory( "WorkflowAgent", - sp => TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + sp => TestHelper.GetAzureOpenAIChatClient(s_configuration).AsAIAgent( name: "WorkflowAgent", instructions: "You can start greeting workflows and check their status.", services: sp, @@ -184,7 +184,7 @@ async Task RunWorkflowAsync(TaskOrchestrationContext context, string nam ])); // This is the agent that will be called by the workflow - agents.AddAIAgent(TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + agents.AddAIAgent(TestHelper.GetAzureOpenAIChatClient(s_configuration).AsAIAgent( name: "SimpleAgent", instructions: "You are a simple assistant." 
)); @@ -194,14 +194,14 @@ async Task RunWorkflowAsync(TaskOrchestrationContext context, string nam AIAgent workflowManagerAgentProxy = testHelper.Services.GetDurableAgentProxy("WorkflowAgent"); // Act: send a prompt to the agent - AgentThread thread = workflowManagerAgentProxy.GetNewThread(); + AgentThread thread = await workflowManagerAgentProxy.GetNewThreadAsync(this.TestTimeoutToken); await workflowManagerAgentProxy.RunAsync( message: "Start a greeting workflow for \"John Doe\".", thread, cancellationToken: this.TestTimeoutToken); // Act: prompt it again to wait for the workflow to complete - AgentRunResponse response = await workflowManagerAgentProxy.RunAsync( + AgentResponse response = await workflowManagerAgentProxy.RunAsync( message: "Wait for the workflow to complete and tell me the result.", thread, cancellationToken: this.TestTimeoutToken); @@ -217,14 +217,14 @@ await workflowManagerAgentProxy.RunAsync( public void AsDurableAgentProxy_ThrowsWhenAgentNotRegistered() { // Setup: Register one agent but try to use a different one - AIAgent registeredAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + AIAgent registeredAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).AsAIAgent( instructions: "You are a helpful assistant.", name: "RegisteredAgent"); using TestHelper testHelper = TestHelper.Start([registeredAgent], this._outputHelper); // Create an agent with a different name that isn't registered - AIAgent unregisteredAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + AIAgent unregisteredAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).AsAIAgent( instructions: "You are a helpful assistant.", name: "UnregisteredAgent"); diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/OrchestrationTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/OrchestrationTests.cs index 0c702e6062..641cb57dc8 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/OrchestrationTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/OrchestrationTests.cs @@ -57,7 +57,7 @@ static async Task TestOrchestrationAsync(TaskOrchestrationContext contex // Register a different agent, but not "NonExistentAgent" agents.AddAIAgentFactory( "OtherAgent", - sp => TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + sp => TestHelper.GetAzureOpenAIChatClient(s_configuration).AsAIAgent( name: "OtherAgent", instructions: "You are a test agent.")); }, diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/TimeToLiveTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/TimeToLiveTests.cs index 25d40a1c5a..5437b7cdfa 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/TimeToLiveTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/TimeToLiveTests.cs @@ -40,7 +40,7 @@ public async Task EntityExpiresAfterTTLAsync() { // Arrange: Create agent with short TTL (10 seconds) TimeSpan ttl = TimeSpan.FromSeconds(10); - AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).AsAIAgent( name: "TTLTestAgent", instructions: "You are a helpful assistant." 
); @@ -55,7 +55,7 @@ public async Task EntityExpiresAfterTTLAsync() }); AIAgent agentProxy = simpleAgent.AsDurableAgentProxy(testHelper.Services); - AgentThread thread = agentProxy.GetNewThread(); + AgentThread thread = await agentProxy.GetNewThreadAsync(this.TestTimeoutToken); DurableTaskClient client = testHelper.GetClient(); AgentSessionId sessionId = thread.GetService(); @@ -105,7 +105,7 @@ public async Task EntityTTLResetsOnInteractionAsync() { // Arrange: Create agent with short TTL TimeSpan ttl = TimeSpan.FromSeconds(6); - AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).AsAIAgent( name: "TTLResetTestAgent", instructions: "You are a helpful assistant." ); @@ -120,7 +120,7 @@ public async Task EntityTTLResetsOnInteractionAsync() }); AIAgent agentProxy = simpleAgent.AsDurableAgentProxy(testHelper.Services); - AgentThread thread = agentProxy.GetNewThread(); + AgentThread thread = await agentProxy.GetNewThreadAsync(this.TestTimeoutToken); DurableTaskClient client = testHelper.GetClient(); AgentSessionId sessionId = thread.GetService(); diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.A2A.UnitTests/A2AIntegrationTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.A2A.UnitTests/A2AIntegrationTests.cs index 48cb19789a..f8604c7eac 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.A2A.UnitTests/A2AIntegrationTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.A2A.UnitTests/A2AIntegrationTests.cs @@ -77,7 +77,9 @@ public async Task MapA2A_WithAgentCard_CardEndpointReturnsCardWithUrlAsync() Assert.NotNull(url); Assert.NotEmpty(url); Assert.StartsWith("http", url, StringComparison.OrdinalIgnoreCase); - Assert.Equal($"{testServer.BaseAddress.ToString().TrimEnd('/')}/a2a/test-agent/v1/card", url); + + // agentCard's URL matches the agent endpoint + Assert.Equal($"{testServer.BaseAddress.ToString().TrimEnd('/')}/a2a/test-agent", url); } finally { diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.A2A.UnitTests/AIAgentExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.A2A.UnitTests/AIAgentExtensionsTests.cs index 0d5b895974..271e80b966 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.A2A.UnitTests/AIAgentExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.A2A.UnitTests/AIAgentExtensionsTests.cs @@ -101,7 +101,7 @@ public async Task MapA2A_WhenResponseHasAdditionalProperties_ReturnsAgentMessage ["responseKey1"] = "responseValue1", ["responseKey2"] = 123 }; - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, "Test response")]) + AgentResponse response = new([new ChatMessage(ChatRole.Assistant, "Test response")]) { AdditionalProperties = additionalProps }; @@ -130,7 +130,7 @@ public async Task MapA2A_WhenResponseHasAdditionalProperties_ReturnsAgentMessage public async Task MapA2A_WhenResponseHasNullAdditionalProperties_ReturnsAgentMessageWithNullMetadataAsync() { // Arrange - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, "Test response")]) + AgentResponse response = new([new ChatMessage(ChatRole.Assistant, "Test response")]) { AdditionalProperties = null }; @@ -154,7 +154,7 @@ public async Task MapA2A_WhenResponseHasNullAdditionalProperties_ReturnsAgentMes public async Task MapA2A_WhenResponseHasEmptyAdditionalProperties_ReturnsAgentMessageWithNullMetadataAsync() { // Arrange - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, "Test response")]) + AgentResponse response = new([new 
ChatMessage(ChatRole.Assistant, "Test response")]) { AdditionalProperties = [] }; @@ -175,29 +175,29 @@ private static Mock CreateAgentMock(Action optionsCal { Mock agentMock = new() { CallBase = true }; agentMock.SetupGet(x => x.Name).Returns("TestAgent"); - agentMock.Setup(x => x.GetNewThread()).Returns(new TestAgentThread()); + agentMock.Setup(x => x.GetNewThreadAsync()).ReturnsAsync(new TestAgentThread()); agentMock .Protected() - .Setup>("RunCoreAsync", + .Setup>("RunCoreAsync", ItExpr.IsAny>(), ItExpr.IsAny(), ItExpr.IsAny(), ItExpr.IsAny()) .Callback, AgentThread?, AgentRunOptions?, CancellationToken>( (_, _, options, _) => optionsCallback(options)) - .ReturnsAsync(new AgentRunResponse([new ChatMessage(ChatRole.Assistant, "Test response")])); + .ReturnsAsync(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Test response")])); return agentMock; } - private static Mock CreateAgentMockWithResponse(AgentRunResponse response) + private static Mock CreateAgentMockWithResponse(AgentResponse response) { Mock agentMock = new() { CallBase = true }; agentMock.SetupGet(x => x.Name).Returns("TestAgent"); - agentMock.Setup(x => x.GetNewThread()).Returns(new TestAgentThread()); + agentMock.Setup(x => x.GetNewThreadAsync()).ReturnsAsync(new TestAgentThread()); agentMock .Protected() - .Setup>("RunCoreAsync", + .Setup>("RunCoreAsync", ItExpr.IsAny>(), ItExpr.IsAny(), ItExpr.IsAny(), diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs index dfabaca64e..12d0daffc7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs @@ -30,14 +30,14 @@ public async Task ClientReceivesStreamedAssistantMessageAsync() // Arrange await this.SetupTestServerAsync(); var chatClient = new AGUIChatClient(this._client!, "", null); - AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); - ChatClientAgentThread thread = (ChatClientAgentThread)agent.GetNewThread(); + AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); + ChatClientAgentThread thread = (ChatClientAgentThread)await agent.GetNewThreadAsync(); ChatMessage userMessage = new(ChatRole.User, "hello"); - List updates = []; + List updates = []; // Act - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None)) + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None)) { updates.Add(update); } @@ -49,7 +49,7 @@ public async Task ClientReceivesStreamedAssistantMessageAsync() updates.Should().AllSatisfy(u => u.Role.Should().Be(ChatRole.Assistant)); // Verify assistant response message - AgentRunResponse response = updates.ToAgentRunResponse(); + AgentResponse response = updates.ToAgentResponse(); response.Messages.Should().HaveCount(1); response.Messages[0].Role.Should().Be(ChatRole.Assistant); response.Messages[0].Text.Should().Be("Hello from fake agent!"); @@ -61,14 +61,14 @@ public async Task ClientReceivesRunLifecycleEventsAsync() // Arrange await this.SetupTestServerAsync(); var chatClient = new AGUIChatClient(this._client!, "", null); - AIAgent agent = 
chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); - ChatClientAgentThread thread = (ChatClientAgentThread)agent.GetNewThread(); + AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); + ChatClientAgentThread thread = (ChatClientAgentThread)await agent.GetNewThreadAsync(); ChatMessage userMessage = new(ChatRole.User, "test"); - List updates = []; + List updates = []; // Act - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None)) + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None)) { updates.Add(update); } @@ -86,14 +86,14 @@ public async Task ClientReceivesRunLifecycleEventsAsync() updates.Should().Contain(u => !string.IsNullOrEmpty(u.Text)); // All text content updates should have the same message ID - List textUpdates = updates.Where(u => !string.IsNullOrEmpty(u.Text)).ToList(); + List textUpdates = updates.Where(u => !string.IsNullOrEmpty(u.Text)).ToList(); textUpdates.Should().NotBeEmpty(); string? firstMessageId = textUpdates.FirstOrDefault()?.MessageId; firstMessageId.Should().NotBeNullOrEmpty(); textUpdates.Should().AllSatisfy(u => u.MessageId.Should().Be(firstMessageId)); // RunFinished should be the last update - AgentRunResponseUpdate lastUpdate = updates[^1]; + AgentResponseUpdate lastUpdate = updates[^1]; lastUpdate.ResponseId.Should().Be(runId); ChatResponseUpdate lastChatUpdate = lastUpdate.AsChatResponseUpdate(); lastChatUpdate.ConversationId.Should().Be(threadId); @@ -105,12 +105,12 @@ public async Task RunAsyncAggregatesStreamingUpdatesAsync() // Arrange await this.SetupTestServerAsync(); var chatClient = new AGUIChatClient(this._client!, "", null); - AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); - ChatClientAgentThread thread = (ChatClientAgentThread)agent.GetNewThread(); + AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); + ChatClientAgentThread thread = (ChatClientAgentThread)await agent.GetNewThreadAsync(); ChatMessage userMessage = new(ChatRole.User, "hello"); // Act - AgentRunResponse response = await agent.RunAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None); + AgentResponse response = await agent.RunAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None); // Assert response.Messages.Should().NotBeEmpty(); @@ -124,13 +124,13 @@ public async Task MultiTurnConversationPreservesAllMessagesInThreadAsync() // Arrange await this.SetupTestServerAsync(); var chatClient = new AGUIChatClient(this._client!, "", null); - AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); - ChatClientAgentThread chatClientThread = (ChatClientAgentThread)agent.GetNewThread(); + AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); + ChatClientAgentThread chatClientThread = (ChatClientAgentThread)await agent.GetNewThreadAsync(); ChatMessage firstUserMessage = new(ChatRole.User, "First question"); // Act - First turn - List firstTurnUpdates = []; - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([firstUserMessage], chatClientThread, new 
AgentRunOptions(), CancellationToken.None)) + List firstTurnUpdates = []; + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([firstUserMessage], chatClientThread, new AgentRunOptions(), CancellationToken.None)) { firstTurnUpdates.Add(update); } @@ -140,8 +140,8 @@ public async Task MultiTurnConversationPreservesAllMessagesInThreadAsync() // Act - Second turn with another message ChatMessage secondUserMessage = new(ChatRole.User, "Second question"); - List secondTurnUpdates = []; - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([secondUserMessage], chatClientThread, new AgentRunOptions(), CancellationToken.None)) + List secondTurnUpdates = []; + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([secondUserMessage], chatClientThread, new AgentRunOptions(), CancellationToken.None)) { secondTurnUpdates.Add(update); } @@ -150,13 +150,13 @@ public async Task MultiTurnConversationPreservesAllMessagesInThreadAsync() secondTurnUpdates.Should().Contain(u => !string.IsNullOrEmpty(u.Text)); // Verify first turn assistant response - AgentRunResponse firstResponse = firstTurnUpdates.ToAgentRunResponse(); + AgentResponse firstResponse = firstTurnUpdates.ToAgentResponse(); firstResponse.Messages.Should().HaveCount(1); firstResponse.Messages[0].Role.Should().Be(ChatRole.Assistant); firstResponse.Messages[0].Text.Should().Be("Hello from fake agent!"); // Verify second turn assistant response - AgentRunResponse secondResponse = secondTurnUpdates.ToAgentRunResponse(); + AgentResponse secondResponse = secondTurnUpdates.ToAgentResponse(); secondResponse.Messages.Should().HaveCount(1); secondResponse.Messages[0].Role.Should().Be(ChatRole.Assistant); secondResponse.Messages[0].Text.Should().Be("Hello from fake agent!"); @@ -168,20 +168,20 @@ public async Task AgentSendsMultipleMessagesInOneTurnAsync() // Arrange await this.SetupTestServerAsync(useMultiMessageAgent: true); var chatClient = new AGUIChatClient(this._client!, "", null); - AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); - ChatClientAgentThread chatClientThread = (ChatClientAgentThread)agent.GetNewThread(); + AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); + ChatClientAgentThread chatClientThread = (ChatClientAgentThread)await agent.GetNewThreadAsync(); ChatMessage userMessage = new(ChatRole.User, "Tell me a story"); - List updates = []; + List updates = []; // Act - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], chatClientThread, new AgentRunOptions(), CancellationToken.None)) + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], chatClientThread, new AgentRunOptions(), CancellationToken.None)) { updates.Add(update); } // Assert - Should have received text updates with different message IDs - List textUpdates = updates.Where(u => !string.IsNullOrEmpty(u.Text)).ToList(); + List textUpdates = updates.Where(u => !string.IsNullOrEmpty(u.Text)).ToList(); textUpdates.Should().NotBeEmpty(); // Extract unique message IDs @@ -189,7 +189,7 @@ public async Task AgentSendsMultipleMessagesInOneTurnAsync() messageIds.Should().HaveCountGreaterThan(1, "agent should send multiple messages"); // Verify assistant messages from updates - AgentRunResponse response = updates.ToAgentRunResponse(); + AgentResponse response = updates.ToAgentResponse(); 
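The fake agents later in these AGUI tests override the new `AIAgent` surface, but the extraction has eaten the generic type arguments throughout. Here is the same shape reconstructed as an editor's sketch, with the type arguments inferred from the renames elsewhere in this diff (the `InMemoryAgentThread` constructors are assumed from how the fakes forward to them):

```csharp
using System.Runtime.CompilerServices;  // EnumeratorCancellation
using System.Text.Json;
using Microsoft.Agents.AI;
using Microsoft.Extensions.AI;

internal sealed class EchoAgent : AIAgent
{
    // Thread creation and rehydration are async and cancellable now.
    public override ValueTask<AgentThread> GetNewThreadAsync(
        CancellationToken cancellationToken = default)
        => new(new InMemoryAgentThread());

    public override ValueTask<AgentThread> DeserializeThreadAsync(
        JsonElement serializedThread,
        JsonSerializerOptions? jsonSerializerOptions = null,
        CancellationToken cancellationToken = default)
        => new(new InMemoryAgentThread(serializedThread, jsonSerializerOptions));

    // AgentRunResponse -> AgentResponse; the non-streaming core aggregates
    // the streaming core via the renamed ToAgentResponseAsync helper.
    protected override Task<AgentResponse> RunCoreAsync(
        IEnumerable<ChatMessage> messages,
        AgentThread? thread = null,
        AgentRunOptions? options = null,
        CancellationToken cancellationToken = default)
        => this.RunCoreStreamingAsync(messages, thread, options, cancellationToken)
               .ToAgentResponseAsync(cancellationToken);

    protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(
        IEnumerable<ChatMessage> messages,
        AgentThread? thread = null,
        AgentRunOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Stream a deterministic response, as the test fakes do.
        string messageId = Guid.NewGuid().ToString("N");
        foreach (string chunk in new[] { "Hello", " ", "from", " ", "fake", " ", "agent", "!" })
        {
            yield return new AgentResponseUpdate
            {
                MessageId = messageId,
                Role = ChatRole.Assistant,
                Contents = [new TextContent(chunk)],
            };
        }
        await Task.CompletedTask;
    }
}
```

On the consuming side the same rename shows up as `updates.ToAgentResponse()` once the collected `AgentResponseUpdate`s are aggregated, which is what the assertions around this point verify.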
response.Messages.Should().HaveCountGreaterThan(1); response.Messages.Should().AllSatisfy(m => m.Role.Should().Be(ChatRole.Assistant)); } @@ -200,8 +200,8 @@ public async Task UserSendsMultipleMessagesAtOnceAsync() // Arrange await this.SetupTestServerAsync(); var chatClient = new AGUIChatClient(this._client!, "", null); - AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); - ChatClientAgentThread chatClientThread = (ChatClientAgentThread)agent.GetNewThread(); + AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); + ChatClientAgentThread chatClientThread = (ChatClientAgentThread)await agent.GetNewThreadAsync(); // Multiple user messages sent in one turn ChatMessage[] userMessages = @@ -211,10 +211,10 @@ public async Task UserSendsMultipleMessagesAtOnceAsync() new ChatMessage(ChatRole.User, "Third part of question") ]; - List updates = []; + List updates = []; // Act - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(userMessages, chatClientThread, new AgentRunOptions(), CancellationToken.None)) + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(userMessages, chatClientThread, new AgentRunOptions(), CancellationToken.None)) { updates.Add(update); } @@ -224,7 +224,7 @@ public async Task UserSendsMultipleMessagesAtOnceAsync() updates.Should().Contain(u => u.Role == ChatRole.Assistant); // Verify assistant response message - AgentRunResponse response = updates.ToAgentRunResponse(); + AgentResponse response = updates.ToAgentResponse(); response.Messages.Should().HaveCount(1); response.Messages[0].Role.Should().Be(ChatRole.Assistant); response.Messages[0].Text.Should().Be("Hello from fake agent!"); @@ -280,32 +280,28 @@ internal sealed class FakeChatClientAgent : AIAgent public override string? Description => "A fake agent for testing"; - public override AgentThread GetNewThread() - { - return new FakeInMemoryAgentThread(); - } + public override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) => + new(new FakeInMemoryAgentThread()); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) - { - return new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions); - } + public override ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => + new(new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions)); - protected override async Task RunCoreAsync( + protected override async Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { - List updates = []; - await foreach (AgentRunResponseUpdate update in this.RunStreamingAsync(messages, thread, options, cancellationToken).ConfigureAwait(false)) + List updates = []; + await foreach (AgentResponseUpdate update in this.RunStreamingAsync(messages, thread, options, cancellationToken).ConfigureAwait(false)) { updates.Add(update); } - return updates.ToAgentRunResponse(); + return updates.ToAgentResponse(); } - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? 
options = null, @@ -316,7 +312,7 @@ protected override async IAsyncEnumerable RunCoreStreami // Simulate streaming a deterministic response foreach (string chunk in new[] { "Hello", " ", "from", " ", "fake", " ", "agent", "!" }) { - yield return new AgentRunResponseUpdate + yield return new AgentResponseUpdate { MessageId = messageId, Role = ChatRole.Assistant, @@ -348,32 +344,28 @@ internal sealed class FakeMultiMessageAgent : AIAgent public override string? Description => "A fake agent that sends multiple messages for testing"; - public override AgentThread GetNewThread() - { - return new FakeInMemoryAgentThread(); - } + public override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) => + new(new FakeInMemoryAgentThread()); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) - { - return new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions); - } + public override ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => + new(new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions)); - protected override async Task RunCoreAsync( + protected override async Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { - List updates = []; - await foreach (AgentRunResponseUpdate update in this.RunStreamingAsync(messages, thread, options, cancellationToken).ConfigureAwait(false)) + List updates = []; + await foreach (AgentResponseUpdate update in this.RunStreamingAsync(messages, thread, options, cancellationToken).ConfigureAwait(false)) { updates.Add(update); } - return updates.ToAgentRunResponse(); + return updates.ToAgentResponse(); } - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? 
options = null, @@ -383,7 +375,7 @@ protected override async IAsyncEnumerable RunCoreStreami string messageId1 = Guid.NewGuid().ToString("N"); foreach (string chunk in new[] { "First", " ", "message" }) { - yield return new AgentRunResponseUpdate + yield return new AgentResponseUpdate { MessageId = messageId1, Role = ChatRole.Assistant, @@ -397,7 +389,7 @@ protected override async IAsyncEnumerable RunCoreStreami string messageId2 = Guid.NewGuid().ToString("N"); foreach (string chunk in new[] { "Second", " ", "message" }) { - yield return new AgentRunResponseUpdate + yield return new AgentResponseUpdate { MessageId = messageId2, Role = ChatRole.Assistant, @@ -411,7 +403,7 @@ protected override async IAsyncEnumerable RunCoreStreami string messageId3 = Guid.NewGuid().ToString("N"); foreach (string chunk in new[] { "Third", " ", "message" }) { - yield return new AgentRunResponseUpdate + yield return new AgentResponseUpdate { MessageId = messageId3, Role = ChatRole.Assistant, diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs index 1777ff456a..2009fdb91b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs @@ -303,12 +303,12 @@ public FakeForwardedPropsAgent() public JsonElement ReceivedForwardedProperties { get; private set; } - protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { - return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken).ToAgentRunResponseAsync(cancellationToken); + return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken).ToAgentResponseAsync(cancellationToken); } - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -324,7 +324,7 @@ protected override async IAsyncEnumerable RunCoreStreami // Always return a text response string messageId = Guid.NewGuid().ToString("N"); - yield return new AgentRunResponseUpdate + yield return new AgentResponseUpdate { MessageId = messageId, Role = ChatRole.Assistant, @@ -334,12 +334,11 @@ protected override async IAsyncEnumerable RunCoreStreami await Task.CompletedTask; } - public override AgentThread GetNewThread() => new FakeInMemoryAgentThread(); + public override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) => + new(new FakeInMemoryAgentThread()); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) - { - return new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions); - } + public override ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? 
jsonSerializerOptions = null, CancellationToken cancellationToken = default) => + new(new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions)); private sealed class FakeInMemoryAgentThread : InMemoryAgentThread { diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs index df51d1cbc4..14675e4019 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs @@ -33,8 +33,8 @@ public async Task StateSnapshot_IsReturnedAsDataContent_WithCorrectMediaTypeAsyn await this.SetupTestServerAsync(fakeAgent); var chatClient = new AGUIChatClient(this._client!, "", null); - AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); - ChatClientAgentThread thread = (ChatClientAgentThread)agent.GetNewThread(); + AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []); + ChatClientAgentThread thread = (ChatClientAgentThread)await agent.GetNewThreadAsync(); string stateJson = JsonSerializer.Serialize(initialState); byte[] stateBytes = System.Text.Encoding.UTF8.GetBytes(stateJson); @@ -42,10 +42,10 @@ public async Task StateSnapshot_IsReturnedAsDataContent_WithCorrectMediaTypeAsyn ChatMessage stateMessage = new(ChatRole.System, [stateContent]); ChatMessage userMessage = new(ChatRole.User, "update state"); - List updates = []; + List updates = []; // Act - await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage, stateMessage], thread, new AgentRunOptions(), CancellationToken.None)) + await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage, stateMessage], thread, new AgentRunOptions(), CancellationToken.None)) { updates.Add(update); } @@ -54,7 +54,7 @@ public async Task StateSnapshot_IsReturnedAsDataContent_WithCorrectMediaTypeAsyn updates.Should().NotBeEmpty(); // Should receive state snapshot as DataContent with application/json media type - AgentRunResponseUpdate? stateUpdate = updates.FirstOrDefault(u => u.Contents.Any(c => c is DataContent dc && dc.MediaType == "application/json")); + AgentResponseUpdate? stateUpdate = updates.FirstOrDefault(u => u.Contents.Any(c => c is DataContent dc && dc.MediaType == "application/json")); stateUpdate.Should().NotBeNull("should receive state snapshot update"); DataContent? 
dataContent = stateUpdate!.Contents.OfType<DataContent>().FirstOrDefault(dc => dc.MediaType == "application/json");
@@ -76,8 +76,8 @@ public async Task StateSnapshot_HasCorrectAdditionalPropertiesAsync()
         await this.SetupTestServerAsync(fakeAgent);
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);
-        ChatClientAgentThread thread = (ChatClientAgentThread)agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);
+        ChatClientAgentThread thread = (ChatClientAgentThread)await agent.GetNewThreadAsync();

         string stateJson = JsonSerializer.Serialize(initialState);
         byte[] stateBytes = System.Text.Encoding.UTF8.GetBytes(stateJson);
@@ -85,16 +85,16 @@ public async Task StateSnapshot_HasCorrectAdditionalPropertiesAsync()
         ChatMessage stateMessage = new(ChatRole.System, [stateContent]);
         ChatMessage userMessage = new(ChatRole.User, "process");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage, stateMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage, stateMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
         }

         // Assert
-        AgentRunResponseUpdate? stateUpdate = updates.FirstOrDefault(u => u.Contents.Any(c => c is DataContent dc && dc.MediaType == "application/json"));
+        AgentResponseUpdate? stateUpdate = updates.FirstOrDefault(u => u.Contents.Any(c => c is DataContent dc && dc.MediaType == "application/json"));
         stateUpdate.Should().NotBeNull();

         ChatResponseUpdate chatUpdate = stateUpdate!.AsChatResponseUpdate();
@@ -118,8 +118,8 @@ public async Task ComplexState_WithNestedObjectsAndArrays_RoundTripsCorrectlyAsy
         await this.SetupTestServerAsync(fakeAgent);
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);
-        ChatClientAgentThread thread = (ChatClientAgentThread)agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);
+        ChatClientAgentThread thread = (ChatClientAgentThread)await agent.GetNewThreadAsync();

         string stateJson = JsonSerializer.Serialize(complexState);
         byte[] stateBytes = System.Text.Encoding.UTF8.GetBytes(stateJson);
@@ -127,16 +127,16 @@ public async Task ComplexState_WithNestedObjectsAndArrays_RoundTripsCorrectlyAsy
         ChatMessage stateMessage = new(ChatRole.System, [stateContent]);
         ChatMessage userMessage = new(ChatRole.User, "process complex state");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage, stateMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage, stateMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
         }

         // Assert
-        AgentRunResponseUpdate? stateUpdate = updates.FirstOrDefault(u => u.Contents.Any(c => c is DataContent dc && dc.MediaType == "application/json"));
+        AgentResponseUpdate? stateUpdate = updates.FirstOrDefault(u => u.Contents.Any(c => c is DataContent dc && dc.MediaType == "application/json"));
         stateUpdate.Should().NotBeNull();

         DataContent? dataContent = stateUpdate!.Contents.OfType<DataContent>().FirstOrDefault(dc => dc.MediaType == "application/json");
@@ -158,8 +158,8 @@ public async Task StateSnapshot_CanBeUsedInSubsequentRequest_ForStateRoundTripAs
         await this.SetupTestServerAsync(fakeAgent);
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);
-        ChatClientAgentThread thread = (ChatClientAgentThread)agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);
+        ChatClientAgentThread thread = (ChatClientAgentThread)await agent.GetNewThreadAsync();

         string stateJson = JsonSerializer.Serialize(initialState);
         byte[] stateBytes = System.Text.Encoding.UTF8.GetBytes(stateJson);
@@ -167,16 +167,16 @@ public async Task StateSnapshot_CanBeUsedInSubsequentRequest_ForStateRoundTripAs
         ChatMessage stateMessage = new(ChatRole.System, [stateContent]);
         ChatMessage userMessage = new(ChatRole.User, "increment");

-        List<AgentRunResponseUpdate> firstRoundUpdates = [];
+        List<AgentResponseUpdate> firstRoundUpdates = [];

         // Act - First round
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage, stateMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage, stateMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             firstRoundUpdates.Add(update);
         }

         // Extract state snapshot from first round
-        AgentRunResponseUpdate? firstStateUpdate = firstRoundUpdates.FirstOrDefault(u => u.Contents.Any(c => c is DataContent dc && dc.MediaType == "application/json"));
+        AgentResponseUpdate? firstStateUpdate = firstRoundUpdates.FirstOrDefault(u => u.Contents.Any(c => c is DataContent dc && dc.MediaType == "application/json"));
         firstStateUpdate.Should().NotBeNull();

         DataContent? firstStateContent = firstStateUpdate!.Contents.OfType<DataContent>().FirstOrDefault(dc => dc.MediaType == "application/json");
@@ -184,14 +184,14 @@ public async Task StateSnapshot_CanBeUsedInSubsequentRequest_ForStateRoundTripAs
         ChatMessage secondStateMessage = new(ChatRole.System, [firstStateContent!]);
         ChatMessage secondUserMessage = new(ChatRole.User, "increment again");

-        List<AgentRunResponseUpdate> secondRoundUpdates = [];
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([secondUserMessage, secondStateMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        List<AgentResponseUpdate> secondRoundUpdates = [];
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([secondUserMessage, secondStateMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             secondRoundUpdates.Add(update);
         }

         // Assert - Second round should have incremented counter again
-        AgentRunResponseUpdate? secondStateUpdate = secondRoundUpdates.FirstOrDefault(u => u.Contents.Any(c => c is DataContent dc && dc.MediaType == "application/json"));
+        AgentResponseUpdate? secondStateUpdate = secondRoundUpdates.FirstOrDefault(u => u.Contents.Any(c => c is DataContent dc && dc.MediaType == "application/json"));
         secondStateUpdate.Should().NotBeNull();

         DataContent? secondStateContent = secondStateUpdate!.Contents.OfType<DataContent>().FirstOrDefault(dc => dc.MediaType == "application/json");
@@ -209,15 +209,15 @@ public async Task WithoutState_AgentBehavesNormally_NoStateSnapshotReturnedAsync
         await this.SetupTestServerAsync(fakeAgent);
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);
-        ChatClientAgentThread thread = (ChatClientAgentThread)agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);
+        ChatClientAgentThread thread = (ChatClientAgentThread)await agent.GetNewThreadAsync();

         ChatMessage userMessage = new(ChatRole.User, "hello");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
         }
@@ -242,8 +242,8 @@ public async Task EmptyState_DoesNotTriggerStateHandlingAsync()
         await this.SetupTestServerAsync(fakeAgent);
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);
-        ChatClientAgentThread thread = (ChatClientAgentThread)agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);
+        ChatClientAgentThread thread = (ChatClientAgentThread)await agent.GetNewThreadAsync();

         string stateJson = JsonSerializer.Serialize(emptyState);
         byte[] stateBytes = System.Text.Encoding.UTF8.GetBytes(stateJson);
@@ -251,10 +251,10 @@ public async Task EmptyState_DoesNotTriggerStateHandlingAsync()
         ChatMessage stateMessage = new(ChatRole.System, [stateContent]);
         ChatMessage userMessage = new(ChatRole.User, "hello");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage, stateMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage, stateMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
         }
@@ -279,8 +279,8 @@ public async Task NonStreamingRunAsync_WithState_ReturnsStateInResponseAsync()
         await this.SetupTestServerAsync(fakeAgent);
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);
-        ChatClientAgentThread thread = (ChatClientAgentThread)agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);
+        ChatClientAgentThread thread = (ChatClientAgentThread)await agent.GetNewThreadAsync();

         string stateJson = JsonSerializer.Serialize(initialState);
         byte[] stateBytes = System.Text.Encoding.UTF8.GetBytes(stateJson);
@@ -289,7 +289,7 @@ public async Task NonStreamingRunAsync_WithState_ReturnsStateInResponseAsync()
         ChatMessage userMessage = new(ChatRole.User, "process");

         // Act
-        AgentRunResponse response = await agent.RunAsync([userMessage, stateMessage], thread, new AgentRunOptions(), CancellationToken.None);
+        AgentResponse response = await agent.RunAsync([userMessage, stateMessage], thread, new AgentRunOptions(), CancellationToken.None);

         // Assert
         response.Should().NotBeNull();
@@ -342,12 +342,12 @@ internal sealed class FakeStateAgent : AIAgent
 {
     public override string? Description => "Agent for state testing";

-    protected override Task<AgentRunResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
+    protected override Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
     {
-        return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken).ToAgentRunResponseAsync(cancellationToken);
+        return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken).ToAgentResponseAsync(cancellationToken);
     }

-    protected override async IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(
+    protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(
         IEnumerable<ChatMessage> messages,
         AgentThread? thread = null,
         AgentRunOptions? options = null,
@@ -396,7 +396,7 @@ stateObj is JsonElement state &&
             byte[] modifiedStateBytes = System.Text.Encoding.UTF8.GetBytes(modifiedStateJson);
             DataContent modifiedStateContent = new(modifiedStateBytes, "application/json");

-            yield return new AgentRunResponseUpdate
+            yield return new AgentResponseUpdate
             {
                 MessageId = Guid.NewGuid().ToString("N"),
                 Role = ChatRole.Assistant,
@@ -407,7 +407,7 @@ stateObj is JsonElement state &&

         // Always return a text response
         string messageId = Guid.NewGuid().ToString("N");
-        yield return new AgentRunResponseUpdate
+        yield return new AgentResponseUpdate
         {
             MessageId = messageId,
             Role = ChatRole.Assistant,
@@ -417,12 +417,11 @@ stateObj is JsonElement state &&
         await Task.CompletedTask;
     }

-    public override AgentThread GetNewThread() => new FakeInMemoryAgentThread();
+    public override ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default) =>
+        new(new FakeInMemoryAgentThread());

-    public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null)
-    {
-        return new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions);
-    }
+    public override ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) =>
+        new(new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions));

     private sealed class FakeInMemoryAgentThread : InMemoryAgentThread
     {
diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ToolCallingTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ToolCallingTests.cs
index 178ed20d73..5d5f145733 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ToolCallingTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ToolCallingTests.cs
@@ -44,14 +44,14 @@ public async Task ServerTriggersSingleFunctionCallAsync()
         await this.SetupTestServerAsync(serverTools: [serverTool]);
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: []);
-        AgentThread thread = agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: []);
+        AgentThread thread = await agent.GetNewThreadAsync();
         ChatMessage userMessage = new(ChatRole.User, "Call the server function");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
         }
@@ -92,14 +92,14 @@ public async Task ServerTriggersMultipleFunctionCallsAsync()
         await this.SetupTestServerAsync(serverTools: [getWeatherTool, getTimeTool]);
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: []);
-        AgentThread thread = agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: []);
+        AgentThread thread = await agent.GetNewThreadAsync();
         ChatMessage userMessage = new(ChatRole.User, "What's the weather and time?");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
         }
@@ -133,14 +133,14 @@ public async Task ClientTriggersSingleFunctionCallAsync()
         await this.SetupTestServerAsync();
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: [clientTool]);
-        AgentThread thread = agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: [clientTool]);
+        AgentThread thread = await agent.GetNewThreadAsync();
         ChatMessage userMessage = new(ChatRole.User, "Call the client function");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
         }
@@ -181,14 +181,14 @@ public async Task ClientTriggersMultipleFunctionCallsAsync()
         await this.SetupTestServerAsync();
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: [calculateTool, formatTool]);
-        AgentThread thread = agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: [calculateTool, formatTool]);
+        AgentThread thread = await agent.GetNewThreadAsync();
         ChatMessage userMessage = new(ChatRole.User, "Calculate 5 + 3 and format 'hello'");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
         }
@@ -232,14 +232,14 @@ public async Task ServerAndClientTriggerFunctionCallsSimultaneouslyAsync()
         await this.SetupTestServerAsync(serverTools: [serverTool]);
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: [clientTool]);
-        AgentThread thread = agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: [clientTool]);
+        AgentThread thread = await agent.GetNewThreadAsync();
         ChatMessage userMessage = new(ChatRole.User, "Get both server and client data");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
             this._output.WriteLine($"Update: {update.Contents.Count} contents");
@@ -297,14 +297,14 @@ public async Task FunctionCallsPreserveCallIdAndNameAsync()
         await this.SetupTestServerAsync(serverTools: [testTool]);
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: []);
-        AgentThread thread = agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: []);
+        AgentThread thread = await agent.GetNewThreadAsync();
         ChatMessage userMessage = new(ChatRole.User, "Call the test function");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
         }
@@ -341,14 +341,14 @@ public async Task ParallelFunctionCallsFromServerAreHandledCorrectlyAsync()
         await this.SetupTestServerAsync(serverTools: [func1, func2], triggerParallelCalls: true);
         var chatClient = new AGUIChatClient(this._client!, "", null);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: []);
-        AgentThread thread = agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: []);
+        AgentThread thread = await agent.GetNewThreadAsync();
         ChatMessage userMessage = new(ChatRole.User, "Call both functions in parallel");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
         }
@@ -427,14 +427,14 @@ public async Task ServerToolCallWithCustomArgumentsAsync()
         await this.SetupTestServerAsync(serverTools: [serverTool], jsonSerializerOptions: ServerJsonContext.Default.Options);
         var chatClient = new AGUIChatClient(this._client!, "", null, ServerJsonContext.Default.Options);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: []);
-        AgentThread thread = agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: []);
+        AgentThread thread = await agent.GetNewThreadAsync();
         ChatMessage userMessage = new(ChatRole.User, "Get server forecast for Seattle for 5 days");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
         }
@@ -473,14 +473,14 @@ public async Task ClientToolCallWithCustomArgumentsAsync()
         await this.SetupTestServerAsync();
         var chatClient = new AGUIChatClient(this._client!, "", null, ClientJsonContext.Default.Options);
-        AIAgent agent = chatClient.CreateAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: [clientTool]);
-        AgentThread thread = agent.GetNewThread();
+        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Test assistant", tools: [clientTool]);
+        AgentThread thread = await agent.GetNewThreadAsync();
         ChatMessage userMessage = new(ChatRole.User, "Get client forecast for Portland with hourly data");

-        List<AgentRunResponseUpdate> updates = [];
+        List<AgentResponseUpdate> updates = [];

         // Act
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync([userMessage], thread, new AgentRunOptions(), CancellationToken.None))
         {
             updates.Add(update);
         }
@@ -518,7 +518,7 @@ private async Task SetupTestServerAsync(
         this._app = builder.Build();
         // FakeChatClient will receive options.Tools containing both server and client tools (merged by framework)
         var fakeChatClient = new FakeToolCallingChatClient(triggerParallelCalls, this._output, jsonSerializerOptions: jsonSerializerOptions);
-        AIAgent baseAgent = fakeChatClient.CreateAIAgent(instructions: null, name: "base-agent", description: "A base agent for tool testing", tools: serverTools ?? []);
+        AIAgent baseAgent = fakeChatClient.AsAIAgent(instructions: null, name: "base-agent", description: "A base agent for tool testing", tools: serverTools ?? []);
         this._app.MapAGUI("/agent", baseAgent);

         await this._app.StartAsync();
diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs
index 402451b061..a98d76d065 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs
@@ -425,26 +425,27 @@ private sealed class MultiResponseAgent : AIAgent

         public override string? Description => "Agent that produces multiple text chunks";

-        public override AgentThread GetNewThread() => new TestInMemoryAgentThread();
+        public override ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default) =>
+            new(new TestInMemoryAgentThread());

-        public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) =>
-            new TestInMemoryAgentThread(serializedThread, jsonSerializerOptions);
+        public override ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) =>
+            new(new TestInMemoryAgentThread(serializedThread, jsonSerializerOptions));

-        protected override Task<AgentRunResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
+        protected override Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
         {
             throw new NotImplementedException();
         }

-        protected override async IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(
+        protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(
             IEnumerable<ChatMessage> messages,
             AgentThread? thread = null,
             AgentRunOptions? options = null,
             [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             await Task.CompletedTask;
-            yield return new AgentRunResponseUpdate(new ChatResponseUpdate(ChatRole.Assistant, "First"));
-            yield return new AgentRunResponseUpdate(new ChatResponseUpdate(ChatRole.Assistant, " part"));
-            yield return new AgentRunResponseUpdate(new ChatResponseUpdate(ChatRole.Assistant, " of response"));
+            yield return new AgentResponseUpdate(new ChatResponseUpdate(ChatRole.Assistant, "First"));
+            yield return new AgentResponseUpdate(new ChatResponseUpdate(ChatRole.Assistant, " part"));
+            yield return new AgentResponseUpdate(new ChatResponseUpdate(ChatRole.Assistant, " of response"));
         }
     }
@@ -514,24 +515,25 @@ private sealed class TestAgent : AIAgent

         public override string? Description => "Test agent";

-        public override AgentThread GetNewThread() => new TestInMemoryAgentThread();
+        public override ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default) =>
+            new(new TestInMemoryAgentThread());

-        public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) =>
-            new TestInMemoryAgentThread(serializedThread, jsonSerializerOptions);
+        public override ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) =>
+            new(new TestInMemoryAgentThread(serializedThread, jsonSerializerOptions));

-        protected override Task<AgentRunResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
+        protected override Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
         {
             throw new NotImplementedException();
         }

-        protected override async IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(
+        protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(
             IEnumerable<ChatMessage> messages,
             AgentThread? thread = null,
             AgentRunOptions? options = null,
             [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             await Task.CompletedTask;
-            yield return new AgentRunResponseUpdate(new ChatResponseUpdate(ChatRole.Assistant, "Test response"));
+            yield return new AgentResponseUpdate(new ChatResponseUpdate(ChatRole.Assistant, "Test response"));
         }
     }
 }
diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.IntegrationTests/SamplesValidation.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.IntegrationTests/SamplesValidation.cs
index c80cd73941..0dccceaa1a 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.IntegrationTests/SamplesValidation.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.IntegrationTests/SamplesValidation.cs
@@ -286,7 +286,7 @@ await this.RunSampleTestAsync(samplePath, async (logs) =>
             string startResponseText = await startResponse.Content.ReadAsStringAsync();
             this._outputHelper.WriteLine($"Agent response: {startResponseText}");

-            // The response should be deserializable as an AgentRunResponse object and have a valid thread ID
+            // The response should be deserializable as an AgentResponse object and have a valid thread ID
             startResponse.Headers.TryGetValues("x-ms-thread-id", out IEnumerable<string>? agentIdValues);
             string? threadId = agentIdValues?.FirstOrDefault();
             Assert.NotNull(threadId);
diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.UnitTests/TestAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.UnitTests/TestAgent.cs
index e6824a2dd4..2f8faa320a 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.UnitTests/TestAgent.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.UnitTests/TestAgent.cs
@@ -11,19 +11,19 @@ internal sealed class TestAgent(string name, string description) : AIAgent

     public override string? Description => description;

-    public override AgentThread GetNewThread() => new DummyAgentThread();
+    public override ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default) => new(new DummyAgentThread());

-    public override AgentThread DeserializeThread(
+    public override ValueTask<AgentThread> DeserializeThreadAsync(
         JsonElement serializedThread,
-        JsonSerializerOptions? jsonSerializerOptions = null) => new DummyAgentThread();
+        JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => new(new DummyAgentThread());

-    protected override Task<AgentRunResponse> RunCoreAsync(
+    protected override Task<AgentResponse> RunCoreAsync(
         IEnumerable<ChatMessage> messages,
         AgentThread? thread = null,
         AgentRunOptions? options = null,
-        CancellationToken cancellationToken = default) => Task.FromResult(new AgentRunResponse([.. messages]));
+        CancellationToken cancellationToken = default) => Task.FromResult(new AgentResponse([.. messages]));

-    protected override IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(
+    protected override IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(
         IEnumerable<ChatMessage> messages,
         AgentThread? thread = null,
         AgentRunOptions? options = null,
diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/HostedAgentBuilderToolsExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/HostedAgentBuilderToolsExtensionsTests.cs
index a229c7e1f8..28b621714f 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/HostedAgentBuilderToolsExtensionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/HostedAgentBuilderToolsExtensionsTests.cs
@@ -2,6 +2,7 @@

 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
 using Microsoft.Extensions.AI;
@@ -17,49 +18,40 @@ public sealed class HostedAgentBuilderToolsExtensionsTests
     [Fact]
     public void WithAITool_ThrowsWhenBuilderIsNull()
     {
-        // Arrange
         var tool = new DummyAITool();

-        // Act & Assert
         Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITool(null!, tool));
     }

     [Fact]
     public void WithAITool_ThrowsWhenToolIsNull()
     {
-        // Arrange
         var services = new ServiceCollection();
         var builder = services.AddAIAgent("test-agent", "Test instructions");

-        // Act & Assert
-        Assert.Throws<ArgumentNullException>(() => builder.WithAITool(null!));
+        Assert.Throws<ArgumentNullException>(() => builder.WithAITool(tool: null!));
     }

     [Fact]
     public void WithAITools_ThrowsWhenBuilderIsNull()
     {
-        // Arrange
         var tools = new[] { new DummyAITool() };

-        // Act & Assert
         Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITools(null!, tools));
     }

     [Fact]
     public void WithAITools_ThrowsWhenToolsArrayIsNull()
     {
-        // Arrange
         var services = new ServiceCollection();
         var builder = services.AddAIAgent("test-agent", "Test instructions");

-        // Act & Assert
         Assert.Throws<ArgumentNullException>(() => builder.WithAITools(null!));
     }

     [Fact]
     public void RegisteredTools_ResolvesAllToolsForAgent()
     {
-        // Arrange
         var services = new ServiceCollection();
         services.AddSingleton<IChatClient>(new MockChatClient());
@@ -73,9 +65,13 @@ public void RegisteredTools_ResolvesAllToolsForAgent()

         var serviceProvider = services.BuildServiceProvider();

-        var agent1Tools = ResolveAgentTools(serviceProvider, "test-agent");
+        var agent1Tools = ResolveToolsFromAgent(serviceProvider, "test-agent");
         Assert.Contains(tool1, agent1Tools);
         Assert.Contains(tool2, agent1Tools);
+
+        var agent1ToolsDI = ResolveToolsFromDI(serviceProvider, "test-agent");
+        Assert.Contains(tool1, agent1ToolsDI);
+        Assert.Contains(tool2, agent1ToolsDI);
     }

     [Fact]
@@ -100,21 +96,160 @@ public void RegisteredTools_IsolatedPerAgent()

         var serviceProvider = services.BuildServiceProvider();

-        var agent1Tools = ResolveAgentTools(serviceProvider, "agent1");
-        var agent2Tools = ResolveAgentTools(serviceProvider, "agent2");
+        var agent1Tools = ResolveToolsFromAgent(serviceProvider, "agent1");
+        var agent2Tools = ResolveToolsFromAgent(serviceProvider, "agent2");
+
+        var agent1ToolsDI = ResolveToolsFromDI(serviceProvider, "agent1");
+        var agent2ToolsDI = ResolveToolsFromDI(serviceProvider, "agent2");

         Assert.Contains(tool1, agent1Tools);
         Assert.Contains(tool2, agent1Tools);
+        Assert.Contains(tool1, agent1ToolsDI);
+        Assert.Contains(tool2, agent1ToolsDI);
+
         Assert.Contains(tool3, agent2Tools);
+        Assert.Contains(tool3, agent2ToolsDI);
     }

-    private static IList<AITool> ResolveAgentTools(IServiceProvider serviceProvider, string name)
+    private static IList<AITool> ResolveToolsFromAgent(IServiceProvider serviceProvider, string name)
     {
         var agent = serviceProvider.GetRequiredKeyedService<AIAgent>(name) as ChatClientAgent;
         Assert.NotNull(agent?.ChatOptions?.Tools);
         return agent.ChatOptions.Tools;
     }

+    private static List<AITool> ResolveToolsFromDI(IServiceProvider serviceProvider, string name)
+    {
+        var tools = serviceProvider.GetKeyedServices<AITool>(name);
+        Assert.NotNull(tools);
+        return tools.ToList();
+    }
+
+    [Fact]
+    public void WithAIToolFactory_ThrowsWhenBuilderIsNull()
+    {
+        Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITool(null!, CreateTool));
+
+        static AITool CreateTool(IServiceProvider _) => new DummyAITool();
+    }
+
+    [Fact]
+    public void WithAIToolFactory_ThrowsWhenFactoryIsNull()
+    {
+        var services = new ServiceCollection();
+        var builder = services.AddAIAgent("test-agent", "Test instructions");
+
+        Assert.Throws<ArgumentNullException>(() => builder.WithAITool(factory: null!));
+    }
+
+    [Fact]
+    public void WithAIToolFactory_RegistersToolFromFactory()
+    {
+        var services = new ServiceCollection();
+        services.AddSingleton<IChatClient>(new MockChatClient());
+
+        DummyAITool? createdTool = null;
+        var builder = services.AddAIAgent("test-agent", "Test instructions");
+        builder.WithAITool(sp =>
+        {
+            createdTool = new DummyAITool();
+            return createdTool;
+        });
+
+        var serviceProvider = services.BuildServiceProvider();
+        var tools = ResolveToolsFromDI(serviceProvider, "test-agent");
+
+        Assert.Single(tools);
+        Assert.Same(createdTool, tools[0]);
+    }
+
+    [Fact]
+    public void WithAIToolFactory_CanAccessServicesFromFactory()
+    {
+        var services = new ServiceCollection();
+        var mockChatClient = new MockChatClient();
+        services.AddSingleton<IChatClient>(mockChatClient);
+
+        IChatClient? resolvedChatClient = null;
+        var builder = services.AddAIAgent("test-agent", "Test instructions");
+        builder.WithAITool(sp =>
+        {
+            resolvedChatClient = sp.GetService<IChatClient>();
+            return new DummyAITool();
+        });
+
+        var serviceProvider = services.BuildServiceProvider();
+        _ = ResolveToolsFromDI(serviceProvider, "test-agent");
+
+        Assert.Same(mockChatClient, resolvedChatClient);
+    }
+
+    [Fact]
+    public void WithAIToolFactory_ToolsAreIsolatedPerAgent()
+    {
+        var services = new ServiceCollection();
+        services.AddSingleton<IChatClient>(new MockChatClient());
+
+        var tool1 = new DummyAITool();
+        var tool2 = new DummyAITool();
+
+        var builder1 = services.AddAIAgent("agent1", "Agent 1 instructions");
+        var builder2 = services.AddAIAgent("agent2", "Agent 2 instructions");
+
+        builder1.WithAITool(_ => tool1);
+        builder2.WithAITool(_ => tool2);
+
+        var serviceProvider = services.BuildServiceProvider();
+        var agent1Tools = ResolveToolsFromDI(serviceProvider, "agent1");
+        var agent2Tools = ResolveToolsFromDI(serviceProvider, "agent2");
+
+        Assert.Single(agent1Tools);
+        Assert.Contains(tool1, agent1Tools);
+        Assert.DoesNotContain(tool2, agent1Tools);
+
+        Assert.Single(agent2Tools);
+        Assert.Contains(tool2, agent2Tools);
+        Assert.DoesNotContain(tool1, agent2Tools);
+    }
+
+    [Fact]
+    public void WithAIToolFactory_CanCombineWithDirectToolRegistration()
+    {
+        var services = new ServiceCollection();
+        services.AddSingleton<IChatClient>(new MockChatClient());
+
+        var directTool = new DummyAITool();
+        var factoryTool = new DummyAITool();
+
+        var builder = services.AddAIAgent("test-agent", "Test instructions");
+        builder
+            .WithAITool(directTool)
+            .WithAITool(_ => factoryTool);
+
+        var serviceProvider = services.BuildServiceProvider();
+        var tools = ResolveToolsFromDI(serviceProvider, "test-agent");
+
+        Assert.Equal(2, tools.Count);
+        Assert.Contains(directTool, tools);
+        Assert.Contains(factoryTool, tools);
+    }
+
+    [Fact]
+    public void WithAIToolFactory_ToolsAvailableOnAgent()
+    {
+        var services = new ServiceCollection();
+        services.AddSingleton<IChatClient>(new MockChatClient());
+
+        var factoryTool = new DummyAITool();
+        var builder = services.AddAIAgent("test-agent", "Test instructions");
+        builder.WithAITool(_ => factoryTool);
+
+        var serviceProvider = services.BuildServiceProvider();
+        var agentTools = ResolveToolsFromAgent(serviceProvider, "test-agent");
+
+        Assert.Contains(factoryTool, agentTools);
+    }
+
     /// <summary>
     /// Dummy AITool implementation for testing.
     /// </summary>
diff --git a/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/AIAgentWithOpenAIExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/AIAgentWithOpenAIExtensionsTests.cs
index 60c37c9b82..d29535eddb 100644
--- a/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/AIAgentWithOpenAIExtensionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/AIAgentWithOpenAIExtensionsTests.cs
@@ -78,12 +78,12 @@ public async Task RunAsync_CallsUnderlyingAgentAsync()

         mockAgent
             .Protected()
-            .Setup<Task<AgentRunResponse>>("RunCoreAsync",
+            .Setup<Task<AgentResponse>>("RunCoreAsync",
                 ItExpr.IsAny<IEnumerable<ChatMessage>>(),
                 ItExpr.IsAny<AgentThread>(),
                 ItExpr.IsAny<AgentRunOptions>(),
                 ItExpr.IsAny<CancellationToken>())
-            .ReturnsAsync(new AgentRunResponse([responseMessage]));
+            .ReturnsAsync(new AgentResponse([responseMessage]));

         // Act
         var result = await mockAgent.Object.RunAsync(openAiMessages, mockThread.Object, options, cancellationToken);
@@ -160,7 +160,7 @@ public async Task RunStreamingAsync_CallsUnderlyingAgentAsync()
             OpenAIChatMessage.CreateUserMessage(TestMessageText)
         };

-        var responseUpdates = new List<AgentRunResponseUpdate>
+        var responseUpdates = new List<AgentResponseUpdate>
         {
             new(ChatRole.Assistant, ResponseText1),
             new(ChatRole.Assistant, ResponseText2)
@@ -168,7 +168,7 @@ public async Task RunStreamingAsync_CallsUnderlyingAgentAsync()

         mockAgent
             .Protected()
-            .Setup<IAsyncEnumerable<AgentRunResponseUpdate>>("RunCoreStreamingAsync",
+            .Setup<IAsyncEnumerable<AgentResponseUpdate>>("RunCoreStreamingAsync",
                 ItExpr.IsAny<IEnumerable<ChatMessage>>(),
                 ItExpr.IsAny<AgentThread>(),
                 ItExpr.IsAny<AgentRunOptions>(),
@@ -199,9 +199,9 @@ public async Task RunStreamingAsync_CallsUnderlyingAgentAsync()
     }

     /// <summary>
-    /// Helper method to convert a list of AgentRunResponseUpdate to an async enumerable.
+    /// Helper method to convert a list of AgentResponseUpdate to an async enumerable.
     /// </summary>
-    private static async IAsyncEnumerable<AgentRunResponseUpdate> ToAsyncEnumerableAsync(IEnumerable<AgentRunResponseUpdate> updates)
+    private static async IAsyncEnumerable<AgentResponseUpdate> ToAsyncEnumerableAsync(IEnumerable<AgentResponseUpdate> updates)
     {
         foreach (var update in updates)
         {
diff --git a/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIAssistantClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIAssistantClientExtensionsTests.cs
index 3e9fe4d82a..8400adfbcc 100644
--- a/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIAssistantClientExtensionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIAssistantClientExtensionsTests.cs
@@ -210,10 +210,10 @@ public void CreateAIAgent_WithNullOptions_ThrowsArgumentNullException()
     }

     /// <summary>
-    /// Verify that GetAIAgent with ClientResult and options works correctly.
+    /// Verify that AsAIAgent with ClientResult and options works correctly.
     /// </summary>
     [Fact]
-    public void GetAIAgent_WithClientResultAndOptions_WorksCorrectly()
+    public void AsAIAgent_WithClientResultAndOptions_WorksCorrectly()
     {
         // Arrange
         var assistantClient = new TestAssistantClient();
@@ -228,7 +228,7 @@ public void GetAIAgent_WithClientResultAndOptions_WorksCorrectly()
         };

         // Act
-        var agent = assistantClient.GetAIAgent(clientResult, options);
+        var agent = assistantClient.AsAIAgent(clientResult, options);

         // Assert
         Assert.NotNull(agent);
@@ -238,10 +238,10 @@ public void GetAIAgent_WithClientResultAndOptions_WorksCorrectly()
     }

     /// <summary>
-    /// Verify that GetAIAgent with Assistant and options works correctly.
+    /// Verify that AsAIAgent with Assistant and options works correctly.
     /// </summary>
     [Fact]
-    public void GetAIAgent_WithAssistantAndOptions_WorksCorrectly()
+    public void AsAIAgent_WithAssistantAndOptions_WorksCorrectly()
     {
         // Arrange
         var assistantClient = new TestAssistantClient();
@@ -255,7 +255,7 @@ public void GetAIAgent_WithAssistantAndOptions_WorksCorrectly()
         };

         // Act
-        var agent = assistantClient.GetAIAgent(assistant, options);
+        var agent = assistantClient.AsAIAgent(assistant, options);

         // Assert
         Assert.NotNull(agent);
@@ -265,10 +265,10 @@ public void GetAIAgent_WithAssistantAndOptions_WorksCorrectly()
     }

     /// <summary>
-    /// Verify that GetAIAgent with Assistant and options falls back to assistant metadata when options are null.
+    /// Verify that AsAIAgent with Assistant and options falls back to assistant metadata when options are null.
     /// </summary>
     [Fact]
-    public void GetAIAgent_WithAssistantAndOptionsWithNullFields_FallsBackToAssistantMetadata()
+    public void AsAIAgent_WithAssistantAndOptionsWithNullFields_FallsBackToAssistantMetadata()
     {
         // Arrange
         var assistantClient = new TestAssistantClient();
@@ -277,7 +277,7 @@ public void GetAIAgent_WithAssistantAndOptionsWithNullFields_FallsBackToAssistan
         var options = new ChatClientAgentOptions(); // Empty options

         // Act
-        var agent = assistantClient.GetAIAgent(assistant, options);
+        var agent = assistantClient.AsAIAgent(assistant, options);

         // Assert
         Assert.NotNull(agent);
@@ -341,10 +341,10 @@ public async Task GetAIAgentAsync_WithAgentIdAndOptions_WorksCorrectlyAsync()
     }

     /// <summary>
-    /// Verify that GetAIAgent with clientFactory parameter correctly applies the factory.
+    /// Verify that AsAIAgent with clientFactory parameter correctly applies the factory.
     /// </summary>
     [Fact]
-    public void GetAIAgent_WithClientFactory_AppliesFactoryCorrectly()
+    public void AsAIAgent_WithClientFactory_AppliesFactoryCorrectly()
     {
         // Arrange
         var assistantClient = new TestAssistantClient();
@@ -357,7 +357,7 @@ public void GetAIAgent_WithClientFactory_AppliesFactoryCorrectly()
         };

         // Act
-        var agent = assistantClient.GetAIAgent(
+        var agent = assistantClient.AsAIAgent(
             assistant,
             options,
             clientFactory: (innerClient) => testChatClient);
@@ -373,10 +373,10 @@ public void GetAIAgent_WithClientFactory_AppliesFactoryCorrectly()
     }

     /// <summary>
-    /// Verify that GetAIAgent throws ArgumentNullException when assistantClientResult is null.
+    /// Verify that AsAIAgent throws ArgumentNullException when assistantClientResult is null.
     /// </summary>
     [Fact]
-    public void GetAIAgent_WithNullClientResult_ThrowsArgumentNullException()
+    public void AsAIAgent_WithNullClientResult_ThrowsArgumentNullException()
     {
         // Arrange
         var assistantClient = new TestAssistantClient();
@@ -384,16 +384,16 @@ public void GetAIAgent_WithNullClientResult_ThrowsArgumentNullException()

         // Act & Assert
         var exception = Assert.Throws<ArgumentNullException>(() =>
-            assistantClient.GetAIAgent((ClientResult<Assistant>)null!, options));
+            assistantClient.AsAIAgent(null!, options));

         Assert.Equal("assistantClientResult", exception.ParamName);
     }

     /// <summary>
-    /// Verify that GetAIAgent throws ArgumentNullException when assistant is null.
+    /// Verify that AsAIAgent throws ArgumentNullException when assistant is null.
     /// </summary>
     [Fact]
-    public void GetAIAgent_WithNullAssistant_ThrowsArgumentNullException()
+    public void AsAIAgent_WithNullAssistant_ThrowsArgumentNullException()
     {
         // Arrange
         var assistantClient = new TestAssistantClient();
@@ -401,16 +401,16 @@ public void GetAIAgent_WithNullAssistant_ThrowsArgumentNullException()

         // Act & Assert
         var exception = Assert.Throws<ArgumentNullException>(() =>
-            assistantClient.GetAIAgent((Assistant)null!, options));
+            assistantClient.AsAIAgent((Assistant)null!, options));

         Assert.Equal("assistantMetadata", exception.ParamName);
     }

     /// <summary>
-    /// Verify that GetAIAgent throws ArgumentNullException when options is null.
+    /// Verify that AsAIAgent throws ArgumentNullException when options is null.
     /// </summary>
     [Fact]
-    public void GetAIAgent_WithNullOptions_ThrowsArgumentNullException()
+    public void AsAIAgent_WithNullOptions_ThrowsArgumentNullException()
     {
         // Arrange
         var assistantClient = new TestAssistantClient();
@@ -418,7 +418,7 @@ public void GetAIAgent_WithNullOptions_ThrowsArgumentNullException()

         // Act & Assert
         var exception = Assert.Throws<ArgumentNullException>(() =>
-            assistantClient.GetAIAgent(assistant, (ChatClientAgentOptions)null!));
+            assistantClient.AsAIAgent(assistant, (ChatClientAgentOptions)null!));

         Assert.Equal("options", exception.ParamName);
     }
@@ -518,10 +518,10 @@ public void CreateAIAgent_WithOptionsAndServices_PassesServicesToAgent()
     }

     /// <summary>
-    /// Verify that GetAIAgent with services parameter correctly passes it through to the ChatClientAgent.
+    /// Verify that AsAIAgent with services parameter correctly passes it through to the ChatClientAgent.
     /// </summary>
     [Fact]
-    public void GetAIAgent_WithServices_PassesServicesToAgent()
+    public void AsAIAgent_WithServices_PassesServicesToAgent()
     {
         // Arrange
         var assistantClient = new TestAssistantClient();
@@ -529,7 +529,7 @@ public void GetAIAgent_WithServices_PassesServicesToAgent()
         var assistant = ModelReaderWriter.Read<Assistant>(BinaryData.FromString("""{"id": "asst_abc123", "name": "Test Agent"}"""))!;

         // Act
-        var agent = assistantClient.GetAIAgent(assistant, services: serviceProvider);
+        var agent = assistantClient.AsAIAgent(assistant, services: serviceProvider);

         // Assert
         Assert.NotNull(agent);
diff --git a/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIChatClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIChatClientExtensionsTests.cs
index 09c36ef218..42ed26c6ef 100644
--- a/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIChatClientExtensionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIChatClientExtensionsTests.cs
@@ -76,7 +76,7 @@ public void CreateAIAgent_WithClientFactory_AppliesFactoryCorrectly()
         var testChatClient = new TestChatClient(chatClient.AsIChatClient());

         // Act
-        var agent = chatClient.CreateAIAgent(
+        var agent = chatClient.AsAIAgent(
             instructions: "Test instructions",
             name: "Test Agent",
             description: "Test description",
@@ -104,7 +104,7 @@ public void CreateAIAgent_WithClientFactoryUsingAsBuilder_AppliesFactoryCorrectl
         TestChatClient? testChatClient = null;

         // Act
-        var agent = chatClient.CreateAIAgent(
+        var agent = chatClient.AsAIAgent(
             instructions: "Test instructions",
             clientFactory: (innerClient) => innerClient.AsBuilder().Use((innerClient) => testChatClient = new TestChatClient(innerClient)).Build());
@@ -135,7 +135,7 @@ public void CreateAIAgent_WithOptionsAndClientFactory_AppliesFactoryCorrectly()
         };

         // Act
-        var agent = chatClient.CreateAIAgent(
+        var agent = chatClient.AsAIAgent(
             options,
             clientFactory: (innerClient) => testChatClient);
@@ -160,7 +160,7 @@ public void CreateAIAgent_WithoutClientFactory_WorksNormally()
         var chatClient = new TestOpenAIChatClient();

         // Act
-        var agent = chatClient.CreateAIAgent(
+        var agent = chatClient.AsAIAgent(
             instructions: "Test instructions",
             name: "Test Agent");
@@ -183,7 +183,7 @@ public void CreateAIAgent_WithNullClientFactory_WorksNormally()
         var chatClient = new TestOpenAIChatClient();

         // Act
-        var agent = chatClient.CreateAIAgent(
+        var agent = chatClient.AsAIAgent(
             instructions: "Test instructions",
             name: "Test Agent",
             clientFactory: null);
@@ -205,7 +205,7 @@ public void CreateAIAgent_WithNullClient_ThrowsArgumentNullException()
     {
         // Act & Assert
         var exception = Assert.Throws<ArgumentNullException>(() =>
-            ((OpenAIChatClient)null!).CreateAIAgent());
+            ((OpenAIChatClient)null!).AsAIAgent());

         Assert.Equal("client", exception.ParamName);
     }
@@ -221,7 +221,7 @@ public void CreateAIAgent_WithNullOptions_ThrowsArgumentNullException()

         // Act & Assert
         var exception = Assert.Throws<ArgumentNullException>(() =>
-            chatClient.CreateAIAgent((ChatClientAgentOptions)null!));
+            chatClient.AsAIAgent((ChatClientAgentOptions)null!));

         Assert.Equal("options", exception.ParamName);
     }
diff --git a/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs
index 127fe1a58f..8723deeac9 100644
--- a/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs
b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs @@ -75,7 +75,7 @@ public void CreateAIAgent_WithClientFactory_AppliesFactoryCorrectly() var testChatClient = new TestChatClient(responseClient.AsIChatClient()); // Act - var agent = responseClient.CreateAIAgent( + var agent = responseClient.AsAIAgent( instructions: "Test instructions", name: "Test Agent", description: "Test description", @@ -102,7 +102,7 @@ public void CreateAIAgent_WithoutClientFactory_WorksNormally() var responseClient = new TestOpenAIResponseClient(); // Act - var agent = responseClient.CreateAIAgent( + var agent = responseClient.AsAIAgent( instructions: "Test instructions", name: "Test Agent"); @@ -125,7 +125,7 @@ public void CreateAIAgent_WithNullClientFactory_WorksNormally() var responseClient = new TestOpenAIResponseClient(); // Act - var agent = responseClient.CreateAIAgent( + var agent = responseClient.AsAIAgent( instructions: "Test instructions", name: "Test Agent", clientFactory: null); @@ -147,7 +147,7 @@ public void CreateAIAgent_WithNullClient_ThrowsArgumentNullException() { // Act & Assert var exception = Assert.Throws(() => - ((ResponsesClient)null!).CreateAIAgent()); + ((ResponsesClient)null!).AsAIAgent()); Assert.Equal("client", exception.ParamName); } @@ -163,7 +163,7 @@ public void CreateAIAgent_WithNullOptions_ThrowsArgumentNullException() // Act & Assert var exception = Assert.Throws(() => - responseClient.CreateAIAgent((ChatClientAgentOptions)null!)); + responseClient.AsAIAgent((ChatClientAgentOptions)null!)); Assert.Equal("options", exception.ParamName); } @@ -179,7 +179,7 @@ public void CreateAIAgent_WithServices_PassesServicesToAgent() var serviceProvider = new TestServiceProvider(); // Act - var agent = responseClient.CreateAIAgent( + var agent = responseClient.AsAIAgent( instructions: "Test instructions", name: "Test Agent", services: serviceProvider); @@ -211,7 +211,7 @@ public void CreateAIAgent_WithOptionsAndServices_PassesServicesToAgent() }; // Act - var agent = responseClient.CreateAIAgent(options, services: serviceProvider); + var agent = responseClient.AsAIAgent(options, services: serviceProvider); // Assert Assert.NotNull(agent); @@ -237,7 +237,7 @@ public void CreateAIAgent_WithClientFactoryAndServices_AppliesBothCorrectly() var testChatClient = new TestChatClient(responseClient.AsIChatClient()); // Act - var agent = responseClient.CreateAIAgent( + var agent = responseClient.AsAIAgent( instructions: "Test instructions", name: "Test Agent", clientFactory: (innerClient) => testChatClient, diff --git a/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/PurviewWrapperTests.cs b/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/PurviewWrapperTests.cs index eafc67f7fc..ed012669d6 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/PurviewWrapperTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/PurviewWrapperTests.cs @@ -296,10 +296,10 @@ public async Task ProcessAgentContentAsync_WithAllowedPromptAndBlockedResponse_R new(ChatRole.User, "Test message") }; var mockAgent = new Mock(); - var innerResponse = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Sensitive response")); + var innerResponse = new AgentResponse(new ChatMessage(ChatRole.Assistant, "Sensitive response")); mockAgent.Protected() - .Setup>("RunCoreAsync", + .Setup>("RunCoreAsync", ItExpr.IsAny>(), ItExpr.IsAny(), ItExpr.IsAny(), @@ -335,10 +335,10 @@ public async Task ProcessAgentContentAsync_WithAllowedPromptAndResponse_ReturnsI 
new(ChatRole.User, "Test message") }; var mockAgent = new Mock(); - var innerResponse = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Safe response")); + var innerResponse = new AgentResponse(new ChatMessage(ChatRole.Assistant, "Safe response")); mockAgent.Protected() - .Setup>("RunCoreAsync", + .Setup>("RunCoreAsync", ItExpr.IsAny>(), ItExpr.IsAny(), ItExpr.IsAny(), @@ -378,10 +378,10 @@ public async Task ProcessAgentContentAsync_WithIgnoreExceptions_ContinuesOnError new(ChatRole.User, "Test message") }; - var expectedResponse = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Response from inner agent")); + var expectedResponse = new AgentResponse(new ChatMessage(ChatRole.Assistant, "Response from inner agent")); var mockAgent = new Mock(); mockAgent.Protected() - .Setup>("RunCoreAsync", + .Setup>("RunCoreAsync", ItExpr.IsAny>(), ItExpr.IsAny(), ItExpr.IsAny(), @@ -445,10 +445,10 @@ public async Task ProcessAgentContentAsync_ExtractsThreadIdFromMessageAdditional } }; - var expectedResponse = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Response")); + var expectedResponse = new AgentResponse(new ChatMessage(ChatRole.Assistant, "Response")); var mockAgent = new Mock(); mockAgent.Protected() - .Setup>("RunCoreAsync", + .Setup>("RunCoreAsync", ItExpr.IsAny>(), ItExpr.IsAny(), ItExpr.IsAny(), @@ -487,10 +487,10 @@ public async Task ProcessAgentContentAsync_GeneratesThreadId_WhenNotProvidedAsyn new(ChatRole.User, "Test message") }; - var expectedResponse = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Response")); + var expectedResponse = new AgentResponse(new ChatMessage(ChatRole.Assistant, "Response")); var mockAgent = new Mock(); mockAgent.Protected() - .Setup>("RunCoreAsync", + .Setup>("RunCoreAsync", ItExpr.IsAny>(), ItExpr.IsAny(), ItExpr.IsAny(), @@ -527,10 +527,10 @@ public async Task ProcessAgentContentAsync_PassesResolvedUserId_ToResponseProces new(ChatRole.User, "Test message") }; var mockAgent = new Mock(); - var innerResponse = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Response")); + var innerResponse = new AgentResponse(new ChatMessage(ChatRole.Assistant, "Response")); mockAgent.Protected() - .Setup>("RunCoreAsync", + .Setup>("RunCoreAsync", ItExpr.IsAny>(), ItExpr.IsAny(), ItExpr.IsAny(), diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AIAgentBuilderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AIAgentBuilderTests.cs index 7f455327dc..48b2475a2b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AIAgentBuilderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AIAgentBuilderTests.cs @@ -376,7 +376,7 @@ public void Use_WithRunFuncOnly_CreatesAnonymousDelegatingAgent() var builder = new AIAgentBuilder(mockAgent.Object); // Act - var result = builder.Use((_, _, _, _, _) => Task.FromResult(new AgentRunResponse()), null).Build(); + var result = builder.Use((_, _, _, _, _) => Task.FromResult(new AgentResponse()), null).Build(); // Assert Assert.IsType(result); @@ -393,7 +393,7 @@ public void Use_WithStreamingFuncOnly_CreatesAnonymousDelegatingAgent() var builder = new AIAgentBuilder(mockAgent.Object); // Act - var result = builder.Use(null, (_, _, _, _, _) => AsyncEnumerable.Empty()).Build(); + var result = builder.Use(null, (_, _, _, _, _) => AsyncEnumerable.Empty()).Build(); // Assert Assert.IsType(result); @@ -411,8 +411,8 @@ public void Use_WithBothDelegates_CreatesAnonymousDelegatingAgent() // Act var result = builder.Use( - (_, _, _, _, _) => Task.FromResult(new AgentRunResponse()), - (_, _, 
_, _, _) => AsyncEnumerable.Empty()).Build(); + (_, _, _, _, _) => Task.FromResult(new AgentResponse()), + (_, _, _, _, _) => AsyncEnumerable.Empty()).Build(); // Assert Assert.IsType(result); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs index d039c95652..43039a7b76 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs @@ -121,7 +121,7 @@ public void CreateFromAgent_WithNullOptions_UsesAgentProperties() public async Task CreateFromAgent_WhenFunctionInvokedAsync_CallsAgentRunAsync() { // Arrange - var expectedResponse = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Test response")); + var expectedResponse = new AgentResponse(new ChatMessage(ChatRole.Assistant, "Test response")); var testAgent = new TestAgent("TestAgent", "Test description", expectedResponse); var aiFunction = testAgent.AsAIFunction(); @@ -139,7 +139,7 @@ public async Task CreateFromAgent_WhenFunctionInvokedAsync_CallsAgentRunAsync() public async Task CreateFromAgent_WhenFunctionInvokedWithCancellationTokenAsync_PassesCancellationTokenAsync() { // Arrange - var expectedResponse = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Test response")); + var expectedResponse = new AgentResponse(new ChatMessage(ChatRole.Assistant, "Test response")); var testAgent = new TestAgent("TestAgent", "Test description", expectedResponse); using var cancellationTokenSource = new CancellationTokenSource(); var cancellationToken = cancellationTokenSource.Token; @@ -257,7 +257,7 @@ public void CreateFromAgent_WithCustomOptionsOverridingNullAgentProperties_UsesC public async Task CreateFromAgent_InvokeWithComplexResponseFromAgentAsync_ReturnsCorrectResponseAsync() { // Arrange - var expectedResponse = new AgentRunResponse + var expectedResponse = new AgentResponse { AgentId = "agent-123", ResponseId = "response-456", @@ -307,10 +307,10 @@ public void CreateFromAgent_SanitizesAgentName(string agentName, string expected /// private sealed class TestAgent : AIAgent { - private readonly AgentRunResponse? _responseToReturn; + private readonly AgentResponse? _responseToReturn; private readonly Exception? _exceptionToThrow; - public TestAgent(string? name, string? description, AgentRunResponse responseToReturn) + public TestAgent(string? name, string? description, AgentResponse responseToReturn) { this.Name = name; this.Description = description; @@ -324,10 +324,10 @@ public TestAgent(string? name, string? description, Exception exceptionToThrow) this._exceptionToThrow = exceptionToThrow; } - public override AgentThread GetNewThread() + public override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); public override string? 
Name { get; } @@ -337,7 +337,7 @@ public override AgentThread DeserializeThread(JsonElement serializedThread, Json public CancellationToken LastCancellationToken { get; private set; } public int RunAsyncCallCount { get; private set; } - protected override Task RunCoreAsync( + protected override Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -355,14 +355,14 @@ protected override Task RunCoreAsync( return Task.FromResult(this._responseToReturn!); } - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var response = await this.RunAsync(messages, thread, options, cancellationToken); - foreach (var update in response.ToAgentRunResponseUpdates()) + foreach (var update in response.ToAgentResponseUpdates()) { yield return update; } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentJsonUtilitiesTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentJsonUtilitiesTests.cs index 65815607d9..c9cd1e827e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentJsonUtilitiesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentJsonUtilitiesTests.cs @@ -79,9 +79,9 @@ public void DefaultOptions_SerializesEnumsAsStrings() #endif [Fact] - public void DefaultOptions_UsesCamelCasePropertyNames_ForAgentRunResponse() + public void DefaultOptions_UsesCamelCasePropertyNames_ForAgentResponse() { - var response = new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Hello")); + var response = new AgentResponse(new ChatMessage(ChatRole.Assistant, "Hello")); string json = JsonSerializer.Serialize(response, AgentJsonUtilities.DefaultOptions); Assert.Contains("\"messages\"", json); Assert.DoesNotContain("\"Messages\"", json); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AnonymousDelegatingAIAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AnonymousDelegatingAIAgentTests.cs index 4e91fc1430..43937a1148 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AnonymousDelegatingAIAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AnonymousDelegatingAIAgentTests.cs @@ -21,8 +21,8 @@ public class AnonymousDelegatingAIAgentTests private readonly List _testMessages; private readonly AgentThread _testThread; private readonly AgentRunOptions _testOptions; - private readonly AgentRunResponse _testResponse; - private readonly AgentRunResponseUpdate[] _testStreamingResponses; + private readonly AgentResponse _testResponse; + private readonly AgentResponseUpdate[] _testStreamingResponses; public AnonymousDelegatingAIAgentTests() { @@ -30,15 +30,15 @@ public AnonymousDelegatingAIAgentTests() this._testMessages = [new ChatMessage(ChatRole.User, "Test message")]; this._testThread = new Mock().Object; this._testOptions = new AgentRunOptions(); - this._testResponse = new AgentRunResponse([new ChatMessage(ChatRole.Assistant, "Test response")]); + this._testResponse = new AgentResponse([new ChatMessage(ChatRole.Assistant, "Test response")]); this._testStreamingResponses = [ - new AgentRunResponseUpdate(ChatRole.Assistant, "Response 1"), - new AgentRunResponseUpdate(ChatRole.Assistant, "Response 2") + new AgentResponseUpdate(ChatRole.Assistant, "Response 1"), + new AgentResponseUpdate(ChatRole.Assistant, "Response 2") ]; this._innerAgentMock .Protected() - .Setup>("RunCoreAsync", + 
.Setup>("RunCoreAsync", ItExpr.IsAny>(), ItExpr.IsAny(), ItExpr.IsAny(), @@ -47,7 +47,7 @@ public AnonymousDelegatingAIAgentTests() this._innerAgentMock .Protected() - .Setup>("RunCoreStreamingAsync", + .Setup>("RunCoreStreamingAsync", ItExpr.IsAny>(), ItExpr.IsAny(), ItExpr.IsAny(), @@ -191,7 +191,7 @@ public async Task RunAsync_WithSharedFunc_ContextPropagatedAsync() this._innerAgentMock .Protected() - .Verify>("RunCoreAsync", + .Verify>("RunCoreAsync", Times.Once(), ItExpr.Is>(m => m == this._testMessages), ItExpr.Is(t => t == this._testThread), @@ -441,7 +441,7 @@ public async Task SharedFunc_DoesNotCallInner_ThrowsInvalidOperationAsync() var exception = await Assert.ThrowsAsync( () => agent.RunAsync(this._testMessages, this._testThread, this._testOptions)); - Assert.Contains("without producing an AgentRunResponse", exception.Message); + Assert.Contains("without producing an AgentResponse", exception.Message); } #endregion @@ -468,7 +468,7 @@ public async Task AsyncLocalContext_MaintainedAcrossDelegatesAsync() this._innerAgentMock .Protected() - .Setup>("RunCoreAsync", + .Setup>("RunCoreAsync", ItExpr.IsAny>(), ItExpr.IsAny(), ItExpr.IsAny(), @@ -740,7 +740,7 @@ public async Task AIAgentBuilder_Use_MultipleMiddlewareWithSeparateDelegates_Exe var runExecutionOrder = new List(); var streamingExecutionOrder = new List(); - static async IAsyncEnumerable FirstStreamingMiddlewareAsync( + static async IAsyncEnumerable FirstStreamingMiddlewareAsync( IEnumerable messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, [EnumeratorCancellation] CancellationToken cancellationToken, List executionOrder) @@ -753,7 +753,7 @@ static async IAsyncEnumerable FirstStreamingMiddlewareAs executionOrder.Add("First-Streaming-Post"); } - static async IAsyncEnumerable SecondStreamingMiddlewareAsync( + static async IAsyncEnumerable SecondStreamingMiddlewareAsync( IEnumerable messages, AgentThread? thread, AgentRunOptions? 
options, AIAgent innerAgent, [EnumeratorCancellation] CancellationToken cancellationToken, List executionOrder) @@ -891,7 +891,7 @@ public async Task AIAgentBuilder_Use_MiddlewareHandlesException_RecoveryWorksAsy { // Arrange var executionOrder = new List(); - var fallbackResponse = new AgentRunResponse([new ChatMessage(ChatRole.Assistant, "Fallback response")]); + var fallbackResponse = new AgentResponse([new ChatMessage(ChatRole.Assistant, "Fallback response")]); var agent = new AIAgentBuilder(this._innerAgentMock.Object) .Use( @@ -938,7 +938,7 @@ public async Task AIAgentBuilder_Use_CancellationTokenPropagation_WorksCorrectly // Setup mock to throw OperationCanceledException when cancelled token is used this._innerAgentMock .Protected() - .Setup>("RunCoreAsync", + .Setup>("RunCoreAsync", ItExpr.IsAny>(), ItExpr.IsAny(), ItExpr.IsAny(), @@ -973,7 +973,7 @@ await Assert.ThrowsAsync( public async Task AIAgentBuilder_Use_MiddlewareShortCircuits_InnerAgentNotCalledAsync() { // Arrange - var shortCircuitResponse = new AgentRunResponse([new ChatMessage(ChatRole.Assistant, "Short-circuited")]); + var shortCircuitResponse = new AgentResponse([new ChatMessage(ChatRole.Assistant, "Short-circuited")]); var executionOrder = new List(); var agent = new AIAgentBuilder(this._innerAgentMock.Object) @@ -1007,7 +1007,7 @@ public async Task AIAgentBuilder_Use_MiddlewareShortCircuits_InnerAgentNotCalled // Verify inner agent was never called this._innerAgentMock .Protected() - .Verify>("RunCoreAsync", + .Verify>("RunCoreAsync", Times.Never(), ItExpr.IsAny>(), ItExpr.IsAny(), diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs index 58cf5f718f..896a4ceba5 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. 
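The `ChatClientAgentOptionsTests` changes below track a signature change on `ChatClientAgentOptions`: both factory properties move from synchronous delegates to `ValueTask`-returning ones that accept a `CancellationToken`. A minimal sketch of the new shape, assuming the `Microsoft.Agents.AI` types exercised by these tests (`CreateContextProviderAsync` is a hypothetical helper):

```csharp
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Agents.AI;

var options = new ChatClientAgentOptions
{
    Name = "assistant",
    // A synchronous result can simply be wrapped in a ValueTask.
    ChatMessageStoreFactory = (ctx, cancellationToken) =>
        new ValueTask<ChatMessageStore>(new InMemoryChatMessageStore()),
    // Genuinely asynchronous setup can now be awaited inside the factory.
    AIContextProviderFactory = async (ctx, cancellationToken) =>
        await CreateContextProviderAsync(cancellationToken), // hypothetical helper
};
```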
using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; @@ -115,12 +117,11 @@ public void Clone_CreatesDeepCopyWithSameValues() const string Description = "Test description"; var tools = new List { AIFunctionFactory.Create(() => "test") }; - static ChatMessageStore ChatMessageStoreFactory( - ChatClientAgentOptions.ChatMessageStoreFactoryContext ctx) => new Mock().Object; + static ValueTask ChatMessageStoreFactoryAsync( + ChatClientAgentOptions.ChatMessageStoreFactoryContext ctx, CancellationToken ct) => new(new Mock().Object); - static AIContextProvider AIContextProviderFactory( - ChatClientAgentOptions.AIContextProviderFactoryContext ctx) => - new Mock().Object; + static ValueTask AIContextProviderFactoryAsync( + ChatClientAgentOptions.AIContextProviderFactoryContext ctx, CancellationToken ct) => new(new Mock().Object); var original = new ChatClientAgentOptions() { @@ -128,8 +129,8 @@ static AIContextProvider AIContextProviderFactory( Description = Description, ChatOptions = new() { Tools = tools }, Id = "test-id", - ChatMessageStoreFactory = ChatMessageStoreFactory, - AIContextProviderFactory = AIContextProviderFactory + ChatMessageStoreFactory = ChatMessageStoreFactoryAsync, + AIContextProviderFactory = AIContextProviderFactoryAsync }; // Act diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentRunOptionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentRunOptionsTests.cs index 9dd1fce4fb..1aa49dc328 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentRunOptionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentRunOptionsTests.cs @@ -143,7 +143,7 @@ IChatClient ClientFactory(IChatClient client) var options = new ChatClientAgentRunOptions { ChatClientFactory = ClientFactory }; // Act - var responseUpdates = new List(); + var responseUpdates = new List(); await foreach (var update in agent.RunStreamingAsync(messages, null, options, CancellationToken.None)) { responseUpdates.Add(update); @@ -215,7 +215,7 @@ public async Task RunStreamingAsync_WithoutChatClientFactory_UsesOriginalClientA var messages = new List { new(ChatRole.User, "Test message") }; // Act - No ChatClientFactory provided - var responseUpdates = new List(); + var responseUpdates = new List(); await foreach (var update in agent.RunStreamingAsync(messages, null, null, CancellationToken.None)) { responseUpdates.Add(update); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index 29d3d3afee..546dc258cd 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -241,8 +241,8 @@ public async Task RunAsyncRetrievesMessagesFromThreadWhenThreadStoresMessagesThr ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" } }); - // Create a thread using the agent's GetNewThread method - var thread = agent.GetNewThread(); + // Create a thread using the agent's GetNewThreadAsync method + var thread = await agent.GetNewThreadAsync(); // Act await agent.RunAsync([new(ChatRole.User, "new message")], thread: thread); @@ -438,8 +438,8 @@ public async Task RunAsyncUsesChatMessageStoreWhenNoConversationIdReturnedByChat It.IsAny>(), It.IsAny(), 
It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); - Mock> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).Returns(new InMemoryChatMessageStore()); + Mock>> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatMessageStore()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, @@ -447,7 +447,7 @@ public async Task RunAsyncUsesChatMessageStoreWhenNoConversationIdReturnedByChat }); // Act - ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; + ChatClientAgentThread? thread = await agent.GetNewThreadAsync() as ChatClientAgentThread; await agent.RunAsync([new(ChatRole.User, "test")], thread); // Assert @@ -455,7 +455,7 @@ public async Task RunAsyncUsesChatMessageStoreWhenNoConversationIdReturnedByChat Assert.Equal(2, messageStore.Count); Assert.Equal("test", messageStore[0].Text); Assert.Equal("response", messageStore[1].Text); - mockFactory.Verify(f => f(It.IsAny()), Times.Once); + mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } /// @@ -477,7 +477,7 @@ public async Task RunAsyncUsesDefaultInMemoryChatMessageStoreWhenNoConversationI }); // Act - ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; + ChatClientAgentThread? thread = await agent.GetNewThreadAsync() as ChatClientAgentThread; await agent.RunAsync([new(ChatRole.User, "test")], thread); // Assert @@ -509,8 +509,8 @@ public async Task RunAsyncUsesChatMessageStoreFactoryWhenProvidedAndNoConversati It.IsAny(), It.IsAny())).Returns(new ValueTask()); - Mock> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).Returns(mockChatMessageStore.Object); + Mock>> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockChatMessageStore.Object); ChatClientAgent agent = new(mockService.Object, options: new() { @@ -519,7 +519,7 @@ public async Task RunAsyncUsesChatMessageStoreFactoryWhenProvidedAndNoConversati }); // Act - ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; + ChatClientAgentThread? thread = await agent.GetNewThreadAsync() as ChatClientAgentThread; await agent.RunAsync([new(ChatRole.User, "test")], thread); // Assert @@ -538,7 +538,7 @@ public async Task RunAsyncUsesChatMessageStoreFactoryWhenProvidedAndNoConversati It.Is(x => x.RequestMessages.Count() == 1 && x.ChatMessageStoreMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), It.IsAny()), Times.Once); - mockFactory.Verify(f => f(It.IsAny()), Times.Once); + mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } /// @@ -557,8 +557,8 @@ public async Task RunAsyncNotifiesChatMessageStoreOnFailureAsync() Mock mockChatMessageStore = new(); - Mock> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).Returns(mockChatMessageStore.Object); + Mock>> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockChatMessageStore.Object); ChatClientAgent agent = new(mockService.Object, options: new() { @@ -567,7 +567,7 @@ public async Task RunAsyncNotifiesChatMessageStoreOnFailureAsync() }); // Act - ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; + ChatClientAgentThread? 
thread = await agent.GetNewThreadAsync() as ChatClientAgentThread; await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], thread)); // Assert @@ -576,7 +576,7 @@ public async Task RunAsyncNotifiesChatMessageStoreOnFailureAsync() It.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages == null && x.InvokeException!.Message == "Test Error"), It.IsAny()), Times.Once); - mockFactory.Verify(f => f(It.IsAny()), Times.Once); + mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } /// @@ -592,8 +592,8 @@ public async Task RunAsyncThrowsWhenChatMessageStoreFactoryProvidedAndConversati It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" }); - Mock> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).Returns(new InMemoryChatMessageStore()); + Mock>> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatMessageStore()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, @@ -601,7 +601,7 @@ public async Task RunAsyncThrowsWhenChatMessageStoreFactoryProvidedAndConversati }); // Act & Assert - ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; + ChatClientAgentThread? thread = await agent.GetNewThreadAsync() as ChatClientAgentThread; var exception = await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], thread)); Assert.Equal("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.", exception.Message); } @@ -649,10 +649,10 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) .Returns(new ValueTask()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = _ => mockProvider.Object, ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act - var thread = agent.GetNewThread() as ChatClientAgentThread; + var thread = await agent.GetNewThreadAsync() as ChatClientAgentThread; await agent.RunAsync(requestMessages, thread); // Assert @@ -711,7 +711,7 @@ public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync() .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) .Returns(new ValueTask()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = _ => mockProvider.Object, ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act await Assert.ThrowsAsync(() => agent.RunAsync(requestMessages)); @@ -757,7 +757,7 @@ public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextA .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) .ReturnsAsync(new AIContext()); - ChatClientAgent agent = new(mockService.Object, options: new() 
{ AIContextProviderFactory = _ => mockProvider.Object, ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act await agent.RunAsync([new(ChatRole.User, "user message")]); @@ -801,15 +801,15 @@ public async Task RunAsyncWithTypeParameterInvokesChatClientMethodForStructuredO ChatClientAgent agent = new(mockService.Object, options: new()); // Act - AgentRunResponse agentRunResponse = await agent.RunAsync(messages: [new(ChatRole.User, "Hello")], serializerOptions: JsonContext2.Default.Options); + AgentResponse agentResponse = await agent.RunAsync(messages: [new(ChatRole.User, "Hello")], serializerOptions: JsonContext2.Default.Options); // Assert - Assert.Single(agentRunResponse.Messages); + Assert.Single(agentResponse.Messages); - Assert.NotNull(agentRunResponse.Result); - Assert.Equal(expectedSO.Id, agentRunResponse.Result.Id); - Assert.Equal(expectedSO.FullName, agentRunResponse.Result.FullName); - Assert.Equal(expectedSO.Species, agentRunResponse.Result.Species); + Assert.NotNull(agentResponse.Result); + Assert.Equal(expectedSO.Id, agentResponse.Result.Id); + Assert.Equal(expectedSO.FullName, agentResponse.Result.FullName); + Assert.Equal(expectedSO.Species, agentResponse.Result.Species); } #endregion @@ -1120,433 +1120,6 @@ public void ChatOptionsReturnsClonedCopyWhenAgentOptionsHaveChatOptions() #endregion - #region ChatOptions Merging Tests - - /// - /// Verify that ChatOptions merging works when agent has ChatOptions but request doesn't. - /// - [Fact] - public async Task ChatOptionsMergingUsesAgentOptionsWhenRequestHasNoneAsync() - { - // Arrange - var agentChatOptions = new ChatOptions { MaxOutputTokens = 100, Temperature = 0.7f, Instructions = "test instructions" }; - Mock mockService = new(); - ChatOptions? capturedChatOptions = null; - mockService.Setup( - s => s.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => - capturedChatOptions = opts) - .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); - - ChatClientAgent agent = new(mockService.Object, options: new() - { - ChatOptions = agentChatOptions - }); - var messages = new List { new(ChatRole.User, "test") }; - - // Act - await agent.RunAsync(messages); - - // Assert - Assert.NotNull(capturedChatOptions); - Assert.Equal(100, capturedChatOptions.MaxOutputTokens); - Assert.Equal(0.7f, capturedChatOptions.Temperature); - Assert.Equal("test instructions", capturedChatOptions.Instructions); - } - - [Fact] - public async Task ChatOptionsMergingUsesAgentOptionsConstructorWhenRequestHasNoneAsync() - { - Mock mockService = new(); - ChatOptions? 
capturedChatOptions = null; - mockService.Setup( - s => s.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => - capturedChatOptions = opts) - .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); - - ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" } }); - var messages = new List { new(ChatRole.User, "test") }; - - // Act - await agent.RunAsync(messages); - - // Assert - Assert.NotNull(capturedChatOptions); - Assert.Equal("test instructions", capturedChatOptions.Instructions); - } - - /// - /// Verify that ChatOptions merging works when request has ChatOptions but agent doesn't. - /// - [Fact] - public async Task ChatOptionsMergingUsesRequestOptionsWhenAgentHasNoneAsync() - { - // Arrange - var requestChatOptions = new ChatOptions { MaxOutputTokens = 200, Temperature = 0.3f, Instructions = "test instructions" }; - Mock mockService = new(); - ChatOptions? capturedChatOptions = null; - mockService.Setup( - s => s.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => - capturedChatOptions = opts) - .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); - - ChatClientAgent agent = new(mockService.Object); - var messages = new List { new(ChatRole.User, "test") }; - - // Act - await agent.RunAsync(messages, options: new ChatClientAgentRunOptions(requestChatOptions)); - - // Assert - Assert.NotNull(capturedChatOptions); - Assert.Equivalent(requestChatOptions, capturedChatOptions); // Should be the same instance since no merging needed - Assert.Equal(200, capturedChatOptions.MaxOutputTokens); - Assert.Equal(0.3f, capturedChatOptions.Temperature); - Assert.Equal("test instructions", capturedChatOptions.Instructions); - } - - /// - /// Verify that ChatOptions merging prioritizes request options over agent options. - /// - [Fact] - public async Task ChatOptionsMergingPrioritizesRequestOptionsOverAgentOptionsAsync() - { - // Arrange - var agentChatOptions = new ChatOptions - { - Instructions = "test instructions", - MaxOutputTokens = 100, - Temperature = 0.7f, - TopP = 0.9f, - ModelId = "agent-model", - AdditionalProperties = new AdditionalPropertiesDictionary { ["key"] = "agent-value" } - }; - var requestChatOptions = new ChatOptions - { - // TopP and ModelId not set, should use agent values - MaxOutputTokens = 200, - Temperature = 0.3f, - AdditionalProperties = new AdditionalPropertiesDictionary { ["key"] = "request-value" }, - Instructions = "request instructions" - }; - var expectedChatOptionsMerge = new ChatOptions - { - MaxOutputTokens = 200, // Request value takes priority - Temperature = 0.3f, // Request value takes priority - AdditionalProperties = new AdditionalPropertiesDictionary { ["key"] = "request-value" }, // Request value takes priority - TopP = 0.9f, // Agent value used when request doesn't specify - ModelId = "agent-model", // Agent value used when request doesn't specify - Instructions = "test instructions\nrequest instructions" // Request is in addition to agent instructions - }; - - Mock mockService = new(); - ChatOptions? 
capturedChatOptions = null; - mockService.Setup( - s => s.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => - capturedChatOptions = opts) - .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); - - ChatClientAgent agent = new(mockService.Object, options: new() - { - ChatOptions = agentChatOptions - }); - var messages = new List { new(ChatRole.User, "test") }; - - // Act - await agent.RunAsync(messages, options: new ChatClientAgentRunOptions(requestChatOptions)); - - // Assert - Assert.NotNull(capturedChatOptions); - Assert.Equivalent(expectedChatOptionsMerge, capturedChatOptions); // Should be the same instance (modified in place) - Assert.Equal(200, capturedChatOptions.MaxOutputTokens); // Request value takes priority - Assert.Equal(0.3f, capturedChatOptions.Temperature); // Request value takes priority - Assert.NotNull(capturedChatOptions.AdditionalProperties); - Assert.Equal("request-value", capturedChatOptions.AdditionalProperties["key"]); // Request value takes priority - Assert.Equal(0.9f, capturedChatOptions.TopP); // Agent value used when request doesn't specify - Assert.Equal("agent-model", capturedChatOptions.ModelId); // Agent value used when request doesn't specify - } - - /// - /// Verify that ChatOptions merging returns null when both agent and request have no ChatOptions. - /// - [Fact] - public async Task ChatOptionsMergingReturnsNullWhenBothAgentAndRequestHaveNoneAsync() - { - // Arrange - Mock mockService = new(); - ChatOptions? capturedChatOptions = null; - mockService.Setup( - s => s.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => - capturedChatOptions = opts) - .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); - - ChatClientAgent agent = new(mockService.Object); - var messages = new List { new(ChatRole.User, "test") }; - - // Act - await agent.RunAsync(messages); - - // Assert - Assert.Null(capturedChatOptions); - } - - /// - /// Verify that ChatOptions merging concatenates Tools from agent and request. - /// - [Fact] - public async Task ChatOptionsMergingConcatenatesToolsFromAgentAndRequestAsync() - { - // Arrange - var agentTool = AIFunctionFactory.Create(() => "agent tool"); - var requestTool = AIFunctionFactory.Create(() => "request tool"); - - var agentChatOptions = new ChatOptions - { - Instructions = "test instructions", - Tools = [agentTool] - }; - var requestChatOptions = new ChatOptions - { - Tools = [requestTool] - }; - - Mock mockService = new(); - ChatOptions? 
capturedChatOptions = null; - mockService.Setup( - s => s.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => - capturedChatOptions = opts) - .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); - - ChatClientAgent agent = new(mockService.Object, options: new() - { - ChatOptions = agentChatOptions - }); - var messages = new List { new(ChatRole.User, "test") }; - - // Act - await agent.RunAsync(messages, options: new ChatClientAgentRunOptions(requestChatOptions)); - - // Assert - Assert.NotNull(capturedChatOptions); - Assert.NotNull(capturedChatOptions.Tools); - Assert.Equal(2, capturedChatOptions.Tools.Count); - - // Request tools should come first, then agent tools - Assert.Contains(requestTool, capturedChatOptions.Tools); - Assert.Contains(agentTool, capturedChatOptions.Tools); - } - - /// - /// Verify that ChatOptions merging uses agent Tools when request has no Tools. - /// - [Fact] - public async Task ChatOptionsMergingUsesAgentToolsWhenRequestHasNoToolsAsync() - { - // Arrange - var agentTool = AIFunctionFactory.Create(() => "agent tool"); - - var agentChatOptions = new ChatOptions - { - Instructions = "test instructions", - Tools = [agentTool] - }; - var requestChatOptions = new ChatOptions - { - // No Tools specified - MaxOutputTokens = 100 - }; - - Mock mockService = new(); - ChatOptions? capturedChatOptions = null; - mockService.Setup( - s => s.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => - capturedChatOptions = opts) - .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); - - ChatClientAgent agent = new(mockService.Object, options: new() - { - ChatOptions = agentChatOptions - }); - var messages = new List { new(ChatRole.User, "test") }; - - // Act - await agent.RunAsync(messages, options: new ChatClientAgentRunOptions(requestChatOptions)); - - // Assert - Assert.NotNull(capturedChatOptions); - Assert.NotNull(capturedChatOptions.Tools); - Assert.Single(capturedChatOptions.Tools); - Assert.Contains(agentTool, capturedChatOptions.Tools); // Should contain the agent's tool - } - - /// - /// Verify that ChatOptions merging uses RawRepresentationFactory from request first, with fallback to agent. - /// - [Theory] - [InlineData("MockAgentSetting", "MockRequestSetting", "MockRequestSetting")] - [InlineData("MockAgentSetting", null, "MockAgentSetting")] - [InlineData(null, "MockRequestSetting", "MockRequestSetting")] - public async Task ChatOptionsMergingUsesRawRepresentationFactoryWithFallbackAsync(string? agentSetting, string? requestSetting, string expectedSetting) - { - // Arrange - var agentChatOptions = new ChatOptions - { - Instructions = "test instructions", - RawRepresentationFactory = _ => agentSetting - }; - var requestChatOptions = new ChatOptions - { - RawRepresentationFactory = _ => requestSetting - }; - - Mock mockService = new(); - ChatOptions? 
capturedChatOptions = null; - mockService.Setup( - s => s.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => - capturedChatOptions = opts) - .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); - - ChatClientAgent agent = new(mockService.Object, options: new() - { - ChatOptions = agentChatOptions - }); - var messages = new List { new(ChatRole.User, "test") }; - - // Act - await agent.RunAsync(messages, options: new ChatClientAgentRunOptions(requestChatOptions)); - - // Assert - Assert.NotNull(capturedChatOptions); - Assert.NotNull(capturedChatOptions.RawRepresentationFactory); - Assert.Equal(expectedSetting, capturedChatOptions.RawRepresentationFactory(null!)); - } - - /// - /// Verify that ChatOptions merging handles all scalar properties correctly. - /// - [Fact] - public async Task ChatOptionsMergingHandlesAllScalarPropertiesCorrectlyAsync() - { - // Arrange - var agentChatOptions = new ChatOptions - { - MaxOutputTokens = 100, - Temperature = 0.7f, - TopP = 0.9f, - TopK = 50, - PresencePenalty = 0.1f, - FrequencyPenalty = 0.2f, - Instructions = "agent instructions", - ModelId = "agent-model", - Seed = 12345, - ConversationId = "agent-conversation", - AllowMultipleToolCalls = true, - StopSequences = ["agent-stop"] - }; - var requestChatOptions = new ChatOptions - { - MaxOutputTokens = 200, - Temperature = 0.3f, - Instructions = "request instructions", - - // Other properties not set, should use agent values - StopSequences = ["request-stop"] - }; - - var expectedChatOptionsMerge = new ChatOptions - { - MaxOutputTokens = 200, - Temperature = 0.3f, - - // Agent value used when request doesn't specify - TopP = 0.9f, - TopK = 50, - PresencePenalty = 0.1f, - FrequencyPenalty = 0.2f, - Instructions = "agent instructions\nrequest instructions", - ModelId = "agent-model", - Seed = 12345, - ConversationId = "agent-conversation", - AllowMultipleToolCalls = true, - - // Merged StopSequences - StopSequences = ["request-stop", "agent-stop"] - }; - - Mock mockService = new(); - ChatOptions? 
capturedChatOptions = null; - mockService.Setup( - s => s.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => - capturedChatOptions = opts) - .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); - - ChatClientAgent agent = new(mockService.Object, options: new() - { - ChatOptions = agentChatOptions - }); - var messages = new List { new(ChatRole.User, "test") }; - - // Act - await agent.RunAsync(messages, options: new ChatClientAgentRunOptions(requestChatOptions)); - - // Assert - Assert.NotNull(capturedChatOptions); - Assert.Equivalent(expectedChatOptionsMerge, capturedChatOptions); // Should be the equivalent instance (modified in place) - - // Request values should take priority - Assert.Equal(200, capturedChatOptions.MaxOutputTokens); - Assert.Equal(0.3f, capturedChatOptions.Temperature); - - // Merge StopSequences - Assert.Equal(["request-stop", "agent-stop"], capturedChatOptions.StopSequences); - - // Agent values should be used when request doesn't specify - Assert.Equal(0.9f, capturedChatOptions.TopP); - Assert.Equal(50, capturedChatOptions.TopK); - Assert.Equal(0.1f, capturedChatOptions.PresencePenalty); - Assert.Equal(0.2f, capturedChatOptions.FrequencyPenalty); - Assert.Equal("agent-model", capturedChatOptions.ModelId); - Assert.Equal(12345, capturedChatOptions.Seed); - Assert.Equal("agent-conversation", capturedChatOptions.ConversationId); - Assert.Equal(true, capturedChatOptions.AllowMultipleToolCalls); - } - - #endregion - #region GetService Method Tests /// @@ -1972,7 +1545,7 @@ public async Task VerifyChatClientAgentStreamingAsync() // Act var updates = agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")]); - List result = []; + List result = []; await foreach (var update in updates) { result.Add(update); @@ -2010,8 +1583,8 @@ public async Task RunStreamingAsyncUsesChatMessageStoreWhenNoConversationIdRetur It.IsAny>(), It.IsAny(), It.IsAny())).Returns(ToAsyncEnumerableAsync(returnUpdates)); - Mock> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).Returns(new InMemoryChatMessageStore()); + Mock>> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatMessageStore()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, @@ -2019,7 +1592,7 @@ public async Task RunStreamingAsyncUsesChatMessageStoreWhenNoConversationIdRetur }); // Act - ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; + ChatClientAgentThread? 
thread = await agent.GetNewThreadAsync() as ChatClientAgentThread; await agent.RunStreamingAsync([new(ChatRole.User, "test")], thread).ToListAsync(); // Assert @@ -2027,7 +1600,7 @@ public async Task RunStreamingAsyncUsesChatMessageStoreWhenNoConversationIdRetur Assert.Equal(2, messageStore.Count); Assert.Equal("test", messageStore[0].Text); Assert.Equal("what?", messageStore[1].Text); - mockFactory.Verify(f => f(It.IsAny()), Times.Once); + mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } /// @@ -2048,8 +1621,8 @@ public async Task RunStreamingAsyncThrowsWhenChatMessageStoreFactoryProvidedAndC It.IsAny>(), It.IsAny(), It.IsAny())).Returns(ToAsyncEnumerableAsync(returnUpdates)); - Mock> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).Returns(new InMemoryChatMessageStore()); + Mock>> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatMessageStore()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, @@ -2057,7 +1630,7 @@ public async Task RunStreamingAsyncThrowsWhenChatMessageStoreFactoryProvidedAndC }); // Act & Assert - ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; + ChatClientAgentThread? thread = await agent.GetNewThreadAsync() as ChatClientAgentThread; var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync([new(ChatRole.User, "test")], thread).ToListAsync()); Assert.Equal("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.", exception.Message); } @@ -2105,12 +1678,18 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) .Returns(new ValueTask()); - ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] }, AIContextProviderFactory = _ => mockProvider.Object }); + ChatClientAgent agent = new( + mockService.Object, + options: new() + { + ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] }, + AIContextProviderFactory = (_, _) => new(mockProvider.Object) + }); // Act - var thread = agent.GetNewThread() as ChatClientAgentThread; + var thread = await agent.GetNewThreadAsync() as ChatClientAgentThread; var updates = agent.RunStreamingAsync(requestMessages, thread); - _ = await updates.ToAgentRunResponseAsync(); + _ = await updates.ToAgentResponseAsync(); // Assert // Should contain: base instructions, user message, context message, base function, context function @@ -2168,13 +1747,19 @@ public async Task RunStreamingAsyncInvokesAIContextProviderWhenGetResponseFailsA .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) .Returns(new ValueTask()); - ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] }, AIContextProviderFactory = _ => mockProvider.Object }); + ChatClientAgent agent = new( + mockService.Object, + options: new() + { + ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] }, + AIContextProviderFactory = (_, _) => new(mockProvider.Object) + }); // Act await Assert.ThrowsAsync(async () => { var updates = 
agent.RunStreamingAsync(requestMessages); - await updates.ToAgentRunResponseAsync(); + await updates.ToAgentResponseAsync(); }); // Assert diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentThreadTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentThreadTests.cs index 48caef1b3d..57af3b6449 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentThreadTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentThreadTests.cs @@ -93,7 +93,7 @@ public void SetChatMessageStoreThrowsWhenConversationIdIsSet() #region Deserialize Tests [Fact] - public async Task VerifyDeserializeConstructorWithMessagesAsync() + public async Task VerifyDeserializeWithMessagesAsync() { // Arrange var json = JsonSerializer.Deserialize(""" @@ -103,7 +103,7 @@ public async Task VerifyDeserializeConstructorWithMessagesAsync() """, TestJsonSerializerContext.Default.JsonElement); // Act. - var thread = new ChatClientAgentThread(json); + var thread = await ChatClientAgentThread.DeserializeAsync(json); // Assert Assert.Null(thread.ConversationId); @@ -115,7 +115,7 @@ public async Task VerifyDeserializeConstructorWithMessagesAsync() } [Fact] - public async Task VerifyDeserializeConstructorWithIdAsync() + public async Task VerifyDeserializeWithIdAsync() { // Arrange var json = JsonSerializer.Deserialize(""" @@ -125,7 +125,7 @@ public async Task VerifyDeserializeConstructorWithIdAsync() """, TestJsonSerializerContext.Default.JsonElement); // Act - var thread = new ChatClientAgentThread(json); + var thread = await ChatClientAgentThread.DeserializeAsync(json); // Assert Assert.Equal("TestConvId", thread.ConversationId); @@ -133,7 +133,7 @@ public async Task VerifyDeserializeConstructorWithIdAsync() } [Fact] - public async Task VerifyDeserializeConstructorWithAIContextProviderAsync() + public async Task VerifyDeserializeWithAIContextProviderAsync() { // Arrange var json = JsonSerializer.Deserialize(""" @@ -145,7 +145,7 @@ public async Task VerifyDeserializeConstructorWithAIContextProviderAsync() Mock mockProvider = new(); // Act - var thread = new ChatClientAgentThread(json, aiContextProviderFactory: (_, _) => mockProvider.Object); + var thread = await ChatClientAgentThread.DeserializeAsync(json, aiContextProviderFactory: (_, _, _) => new(mockProvider.Object)); // Assert Assert.Null(thread.MessageStore); @@ -153,14 +153,14 @@ public async Task VerifyDeserializeConstructorWithAIContextProviderAsync() } [Fact] - public async Task DeserializeContructorWithInvalidJsonThrowsAsync() + public async Task DeserializeWithInvalidJsonThrowsAsync() { // Arrange var invalidJson = JsonSerializer.Deserialize("[42]", TestJsonSerializerContext.Default.JsonElement); var thread = new ChatClientAgentThread(); // Act & Assert - Assert.Throws(() => new ChatClientAgentThread(invalidJson)); + await Assert.ThrowsAsync(() => ChatClientAgentThread.DeserializeAsync(invalidJson)); } #endregion Deserialize Tests diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs index cfccb7267a..79af3add1d 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs @@ -264,7 +264,7 @@ public async Task 
RunStreamingAsync_WhenContinuationTokenReceived_WrapsContinuat ChatClientAgentThread thread = new(); // Act - var actualUpdates = new List(); + var actualUpdates = new List(); await foreach (var u in agent.RunStreamingAsync([new(ChatRole.User, "hi")], thread, options: new ChatClientAgentRunOptions(new ChatOptions { AllowBackgroundResponses = true }))) { actualUpdates.Add(u); @@ -543,7 +543,7 @@ public async Task RunStreamingAsync_WhenInputMessagesPresentInContinuationToken_ }; // Act - var updates = new List(); + var updates = new List(); await foreach (var update in agent.RunStreamingAsync(thread, options: runOptions)) { updates.Add(update); @@ -591,7 +591,7 @@ public async Task RunStreamingAsync_WhenResponseUpdatesPresentInContinuationToke }; // Act - var updates = new List(); + var updates = new List(); await foreach (var update in agent.RunStreamingAsync(thread, options: runOptions)) { updates.Add(update); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatOptionsMergingTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatOptionsMergingTests.cs new file mode 100644 index 0000000000..6dda0f0278 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatOptionsMergingTests.cs @@ -0,0 +1,442 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Moq; + +namespace Microsoft.Agents.AI.UnitTests; + +/// +/// Contains tests for merging in . +/// +public class ChatClientAgent_ChatOptionsMergingTests +{ + /// + /// Verify that ChatOptions merging works when agent has ChatOptions but request doesn't. + /// + [Fact] + public async Task ChatOptionsMergingUsesAgentOptionsWhenRequestHasNoneAsync() + { + // Arrange + var agentChatOptions = new ChatOptions { MaxOutputTokens = 100, Temperature = 0.7f, Instructions = "test instructions" }; + Mock mockService = new(); + ChatOptions? capturedChatOptions = null; + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + capturedChatOptions = opts) + .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + ChatClientAgent agent = new(mockService.Object, options: new() + { + ChatOptions = agentChatOptions + }); + var messages = new List { new(ChatRole.User, "test") }; + + // Act + await agent.RunAsync(messages); + + // Assert + Assert.NotNull(capturedChatOptions); + Assert.Equal(100, capturedChatOptions.MaxOutputTokens); + Assert.Equal(0.7f, capturedChatOptions.Temperature); + Assert.Equal("test instructions", capturedChatOptions.Instructions); + } + + [Fact] + public async Task ChatOptionsMergingUsesAgentOptionsConstructorWhenRequestHasNoneAsync() + { + Mock mockService = new(); + ChatOptions? 
capturedChatOptions = null; + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + capturedChatOptions = opts) + .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" } }); + var messages = new List { new(ChatRole.User, "test") }; + + // Act + await agent.RunAsync(messages); + + // Assert + Assert.NotNull(capturedChatOptions); + Assert.Equal("test instructions", capturedChatOptions.Instructions); + } + + /// + /// Verify that ChatOptions merging works when request has ChatOptions but agent doesn't. + /// + [Fact] + public async Task ChatOptionsMergingUsesRequestOptionsWhenAgentHasNoneAsync() + { + // Arrange + var requestChatOptions = new ChatOptions { MaxOutputTokens = 200, Temperature = 0.3f, Instructions = "test instructions" }; + Mock mockService = new(); + ChatOptions? capturedChatOptions = null; + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + capturedChatOptions = opts) + .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + ChatClientAgent agent = new(mockService.Object); + var messages = new List { new(ChatRole.User, "test") }; + + // Act + await agent.RunAsync(messages, options: new ChatClientAgentRunOptions(requestChatOptions)); + + // Assert + Assert.NotNull(capturedChatOptions); + Assert.Equivalent(requestChatOptions, capturedChatOptions); // Should be the same instance since no merging needed + Assert.Equal(200, capturedChatOptions.MaxOutputTokens); + Assert.Equal(0.3f, capturedChatOptions.Temperature); + Assert.Equal("test instructions", capturedChatOptions.Instructions); + } + + /// + /// Verify that merging prioritizes over request and that in turn over agent level . + /// + [Fact] + public async Task ChatOptionsMergingPrioritizesRequestOptionsOverAgentOptionsAsync() + { + // Arrange + var agentChatOptions = new ChatOptions + { + Instructions = "test instructions", + MaxOutputTokens = 100, + Temperature = 0.7f, + TopP = 0.9f, + ModelId = "agent-model", + AdditionalProperties = new AdditionalPropertiesDictionary { ["key1"] = "agent-value", ["key2"] = "agent-value", ["key3"] = "agent-value" } + }; + var requestChatOptions = new ChatOptions + { + // TopP and ModelId not set, should use agent values + MaxOutputTokens = 200, + Temperature = 0.3f, + AdditionalProperties = new AdditionalPropertiesDictionary { ["key2"] = "request-value", ["key3"] = "request-value" }, + Instructions = "request instructions" + }; + var agentRunOptionsAdditionalProperties = new AdditionalPropertiesDictionary { ["key3"] = "runoptions-value" }; + var expectedChatOptionsMerge = new ChatOptions + { + MaxOutputTokens = 200, // Request value takes priority + Temperature = 0.3f, // Request value takes priority + // Check that each level of precedence is respected in AdditionalProperties + AdditionalProperties = new AdditionalPropertiesDictionary { ["key1"] = "agent-value", ["key2"] = "request-value", ["key3"] = "runoptions-value" }, + TopP = 0.9f, // Agent value used when request doesn't specify + ModelId = "agent-model", // Agent value used when request doesn't specify + Instructions = "test instructions\nrequest instructions" // Request is in addition to agent instructions + }; + + Mock mockService = new(); + ChatOptions? 
capturedChatOptions = null; + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + capturedChatOptions = opts) + .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + ChatClientAgent agent = new(mockService.Object, options: new() + { + ChatOptions = agentChatOptions + }); + var messages = new List { new(ChatRole.User, "test") }; + + // Act + await agent.RunAsync(messages, options: new ChatClientAgentRunOptions(requestChatOptions) { AdditionalProperties = agentRunOptionsAdditionalProperties }); + + // Assert + Assert.NotNull(capturedChatOptions); + Assert.Equivalent(expectedChatOptionsMerge, capturedChatOptions); // Should be the same instance (modified in place) + Assert.Equal(200, capturedChatOptions.MaxOutputTokens); // Request value takes priority + Assert.Equal(0.3f, capturedChatOptions.Temperature); // Request value takes priority + Assert.NotNull(capturedChatOptions.AdditionalProperties); + Assert.Equal("agent-value", capturedChatOptions.AdditionalProperties["key1"]); // Agent value used when request doesn't specify + Assert.Equal("request-value", capturedChatOptions.AdditionalProperties["key2"]); // Request ChatOptions value takes priority over agent ChatOptions value + Assert.Equal("runoptions-value", capturedChatOptions.AdditionalProperties["key3"]); // Run options value takes priority over request and agent ChatOptions values + Assert.Equal(0.9f, capturedChatOptions.TopP); // Agent value used when request doesn't specify + Assert.Equal("agent-model", capturedChatOptions.ModelId); // Agent value used when request doesn't specify + } + + /// + /// Verify that ChatOptions merging returns null when both agent and request have no ChatOptions. + /// + [Fact] + public async Task ChatOptionsMergingReturnsNullWhenBothAgentAndRequestHaveNoneAsync() + { + // Arrange + Mock mockService = new(); + ChatOptions? capturedChatOptions = null; + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + capturedChatOptions = opts) + .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + ChatClientAgent agent = new(mockService.Object); + var messages = new List { new(ChatRole.User, "test") }; + + // Act + await agent.RunAsync(messages); + + // Assert + Assert.Null(capturedChatOptions); + } + + /// + /// Verify that ChatOptions merging concatenates Tools from agent and request. + /// + [Fact] + public async Task ChatOptionsMergingConcatenatesToolsFromAgentAndRequestAsync() + { + // Arrange + var agentTool = AIFunctionFactory.Create(() => "agent tool"); + var requestTool = AIFunctionFactory.Create(() => "request tool"); + + var agentChatOptions = new ChatOptions + { + Instructions = "test instructions", + Tools = [agentTool] + }; + var requestChatOptions = new ChatOptions + { + Tools = [requestTool] + }; + + Mock mockService = new(); + ChatOptions? 
capturedChatOptions = null; + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + capturedChatOptions = opts) + .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + ChatClientAgent agent = new(mockService.Object, options: new() + { + ChatOptions = agentChatOptions + }); + var messages = new List { new(ChatRole.User, "test") }; + + // Act + await agent.RunAsync(messages, options: new ChatClientAgentRunOptions(requestChatOptions)); + + // Assert + Assert.NotNull(capturedChatOptions); + Assert.NotNull(capturedChatOptions.Tools); + Assert.Equal(2, capturedChatOptions.Tools.Count); + + // Request tools should come first, then agent tools + Assert.Contains(requestTool, capturedChatOptions.Tools); + Assert.Contains(agentTool, capturedChatOptions.Tools); + } + + /// + /// Verify that ChatOptions merging uses agent Tools when request has no Tools. + /// + [Fact] + public async Task ChatOptionsMergingUsesAgentToolsWhenRequestHasNoToolsAsync() + { + // Arrange + var agentTool = AIFunctionFactory.Create(() => "agent tool"); + + var agentChatOptions = new ChatOptions + { + Instructions = "test instructions", + Tools = [agentTool] + }; + var requestChatOptions = new ChatOptions + { + // No Tools specified + MaxOutputTokens = 100 + }; + + Mock mockService = new(); + ChatOptions? capturedChatOptions = null; + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + capturedChatOptions = opts) + .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + ChatClientAgent agent = new(mockService.Object, options: new() + { + ChatOptions = agentChatOptions + }); + var messages = new List { new(ChatRole.User, "test") }; + + // Act + await agent.RunAsync(messages, options: new ChatClientAgentRunOptions(requestChatOptions)); + + // Assert + Assert.NotNull(capturedChatOptions); + Assert.NotNull(capturedChatOptions.Tools); + Assert.Single(capturedChatOptions.Tools); + Assert.Contains(agentTool, capturedChatOptions.Tools); // Should contain the agent's tool + } + + /// + /// Verify that ChatOptions merging uses RawRepresentationFactory from request first, with fallback to agent. + /// + [Theory] + [InlineData("MockAgentSetting", "MockRequestSetting", "MockRequestSetting")] + [InlineData("MockAgentSetting", null, "MockAgentSetting")] + [InlineData(null, "MockRequestSetting", "MockRequestSetting")] + public async Task ChatOptionsMergingUsesRawRepresentationFactoryWithFallbackAsync(string? agentSetting, string? requestSetting, string expectedSetting) + { + // Arrange + var agentChatOptions = new ChatOptions + { + Instructions = "test instructions", + RawRepresentationFactory = _ => agentSetting + }; + var requestChatOptions = new ChatOptions + { + RawRepresentationFactory = _ => requestSetting + }; + + Mock mockService = new(); + ChatOptions? 
capturedChatOptions = null; + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + capturedChatOptions = opts) + .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + ChatClientAgent agent = new(mockService.Object, options: new() + { + ChatOptions = agentChatOptions + }); + var messages = new List { new(ChatRole.User, "test") }; + + // Act + await agent.RunAsync(messages, options: new ChatClientAgentRunOptions(requestChatOptions)); + + // Assert + Assert.NotNull(capturedChatOptions); + Assert.NotNull(capturedChatOptions.RawRepresentationFactory); + Assert.Equal(expectedSetting, capturedChatOptions.RawRepresentationFactory(null!)); + } + + /// + /// Verify that ChatOptions merging handles all scalar properties correctly. + /// + [Fact] + public async Task ChatOptionsMergingHandlesAllScalarPropertiesCorrectlyAsync() + { + // Arrange + var agentChatOptions = new ChatOptions + { + MaxOutputTokens = 100, + Temperature = 0.7f, + TopP = 0.9f, + TopK = 50, + PresencePenalty = 0.1f, + FrequencyPenalty = 0.2f, + Instructions = "agent instructions", + ModelId = "agent-model", + Seed = 12345, + ConversationId = "agent-conversation", + AllowMultipleToolCalls = true, + StopSequences = ["agent-stop"] + }; + var requestChatOptions = new ChatOptions + { + MaxOutputTokens = 200, + Temperature = 0.3f, + Instructions = "request instructions", + + // Other properties not set, should use agent values + StopSequences = ["request-stop"] + }; + + var expectedChatOptionsMerge = new ChatOptions + { + MaxOutputTokens = 200, + Temperature = 0.3f, + + // Agent value used when request doesn't specify + TopP = 0.9f, + TopK = 50, + PresencePenalty = 0.1f, + FrequencyPenalty = 0.2f, + Instructions = "agent instructions\nrequest instructions", + ModelId = "agent-model", + Seed = 12345, + ConversationId = "agent-conversation", + AllowMultipleToolCalls = true, + + // Merged StopSequences + StopSequences = ["request-stop", "agent-stop"] + }; + + Mock mockService = new(); + ChatOptions? 
capturedChatOptions = null; + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + capturedChatOptions = opts) + .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + ChatClientAgent agent = new(mockService.Object, options: new() + { + ChatOptions = agentChatOptions + }); + var messages = new List { new(ChatRole.User, "test") }; + + // Act + await agent.RunAsync(messages, options: new ChatClientAgentRunOptions(requestChatOptions)); + + // Assert + Assert.NotNull(capturedChatOptions); + Assert.Equivalent(expectedChatOptionsMerge, capturedChatOptions); // Should be the equivalent instance (modified in place) + + // Request values should take priority + Assert.Equal(200, capturedChatOptions.MaxOutputTokens); + Assert.Equal(0.3f, capturedChatOptions.Temperature); + + // Merge StopSequences + Assert.Equal(["request-stop", "agent-stop"], capturedChatOptions.StopSequences); + + // Agent values should be used when request doesn't specify + Assert.Equal(0.9f, capturedChatOptions.TopP); + Assert.Equal(50, capturedChatOptions.TopK); + Assert.Equal(0.1f, capturedChatOptions.PresencePenalty); + Assert.Equal(0.2f, capturedChatOptions.FrequencyPenalty); + Assert.Equal("agent-model", capturedChatOptions.ModelId); + Assert.Equal(12345, capturedChatOptions.Seed); + Assert.Equal("agent-conversation", capturedChatOptions.ConversationId); + Assert.Equal(true, capturedChatOptions.AllowMultipleToolCalls); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs index 04eabf36af..98e5b0ed1a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. 
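The new `ChatClientAgent_ChatOptionsMergingTests.cs` file above pins down the merge precedence: per-run `AdditionalProperties` win over request-level `ChatOptions`, which in turn win over agent-level `ChatOptions`; unset request values fall back to the agent's, tools and stop sequences are concatenated (request first), and instructions are appended rather than replaced. A sketch of the observable behavior under those tests' assumptions (`chatClient` is any `IChatClient`):

```csharp
using System.Threading.Tasks;
using Microsoft.Agents.AI;
using Microsoft.Extensions.AI;

static async Task<AgentResponse> RunWithMergedOptionsAsync(IChatClient chatClient)
{
    var agent = new ChatClientAgent(chatClient, options: new()
    {
        ChatOptions = new() { Temperature = 0.7f, TopP = 0.9f, Instructions = "agent instructions" },
    });

    // Request values override the agent's scalars; TopP is unset here, so 0.9f survives,
    // and the instructions are appended after the agent's rather than replacing them.
    var runOptions = new ChatClientAgentRunOptions(new ChatOptions
    {
        Temperature = 0.3f,
        Instructions = "request instructions",
    });

    // Effective options seen by the IChatClient:
    //   Temperature = 0.3f, TopP = 0.9f,
    //   Instructions = "agent instructions\nrequest instructions"
    return await agent.RunAsync("hello", thread: null, options: runOptions);
}
```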
using System.Text.Json; +using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; @@ -12,7 +13,7 @@ namespace Microsoft.Agents.AI.UnitTests.ChatClient; public class ChatClientAgent_DeserializeThreadTests { [Fact] - public void DeserializeThread_UsesAIContextProviderFactory_IfProvided() + public async Task DeserializeThread_UsesAIContextProviderFactory_IfProvidedAsync() { // Arrange var mockChatClient = new Mock(); @@ -21,10 +22,10 @@ public void DeserializeThread_UsesAIContextProviderFactory_IfProvided() var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions { ChatOptions = new() { Instructions = "Test instructions" }, - AIContextProviderFactory = _ => + AIContextProviderFactory = (_, _) => { factoryCalled = true; - return mockContextProvider.Object; + return new ValueTask(mockContextProvider.Object); } }); @@ -35,7 +36,7 @@ public void DeserializeThread_UsesAIContextProviderFactory_IfProvided() """, TestJsonSerializerContext.Default.JsonElement); // Act - var thread = agent.DeserializeThread(json); + var thread = await agent.DeserializeThreadAsync(json); // Assert Assert.True(factoryCalled, "AIContextProviderFactory was not called."); @@ -45,7 +46,7 @@ public void DeserializeThread_UsesAIContextProviderFactory_IfProvided() } [Fact] - public void DeserializeThread_UsesChatMessageStoreFactory_IfProvided() + public async Task DeserializeThread_UsesChatMessageStoreFactory_IfProvidedAsync() { // Arrange var mockChatClient = new Mock(); @@ -54,10 +55,10 @@ public void DeserializeThread_UsesChatMessageStoreFactory_IfProvided() var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions { ChatOptions = new() { Instructions = "Test instructions" }, - ChatMessageStoreFactory = _ => + ChatMessageStoreFactory = (_, _) => { factoryCalled = true; - return mockMessageStore.Object; + return new ValueTask(mockMessageStore.Object); } }); @@ -68,7 +69,7 @@ public void DeserializeThread_UsesChatMessageStoreFactory_IfProvided() """, TestJsonSerializerContext.Default.JsonElement); // Act - var thread = agent.DeserializeThread(json); + var thread = await agent.DeserializeThreadAsync(json); // Assert Assert.True(factoryCalled, "ChatMessageStoreFactory was not called."); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs index 628d738e72..e6cc7e90e9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs @@ -1,17 +1,18 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; namespace Microsoft.Agents.AI.UnitTests.ChatClient; /// -/// Contains unit tests for the ChatClientAgent.GetNewThread methods. +/// Contains unit tests for the ChatClientAgent.GetNewThreadAsync methods. 
///
public class ChatClientAgent_GetNewThreadTests
{
    [Fact]
-    public void GetNewThread_UsesAIContextProviderFactory_IfProvided()
+    public async Task GetNewThread_UsesAIContextProviderFactory_IfProvidedAsync()
    {
        // Arrange
        var mockChatClient = new Mock();
@@ -20,15 +21,15 @@ public void GetNewThread_UsesAIContextProviderFactory_IfProvided()
        var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions
        {
            ChatOptions = new() { Instructions = "Test instructions" },
-            AIContextProviderFactory = _ =>
+            AIContextProviderFactory = (_, _) =>
            {
                factoryCalled = true;
-                return mockContextProvider.Object;
+                return new ValueTask(mockContextProvider.Object);
            }
        });

        // Act
-        var thread = agent.GetNewThread();
+        var thread = await agent.GetNewThreadAsync();

        // Assert
        Assert.True(factoryCalled, "AIContextProviderFactory was not called.");
@@ -38,7 +39,7 @@ public void GetNewThread_UsesAIContextProviderFactory_IfProvided()
    }

    [Fact]
-    public void GetNewThread_UsesChatMessageStoreFactory_IfProvided()
+    public async Task GetNewThread_UsesChatMessageStoreFactory_IfProvidedAsync()
    {
        // Arrange
        var mockChatClient = new Mock();
@@ -47,15 +48,15 @@ public void GetNewThread_UsesChatMessageStoreFactory_IfProvided()
        var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions
        {
            ChatOptions = new() { Instructions = "Test instructions" },
-            ChatMessageStoreFactory = _ =>
+            ChatMessageStoreFactory = (_, _) =>
            {
                factoryCalled = true;
-                return mockMessageStore.Object;
+                return new ValueTask(mockMessageStore.Object);
            }
        });

        // Act
-        var thread = agent.GetNewThread();
+        var thread = await agent.GetNewThreadAsync();

        // Assert
        Assert.True(factoryCalled, "ChatMessageStoreFactory was not called.");
@@ -65,7 +66,7 @@ public void GetNewThread_UsesChatMessageStoreFactory_IfProvided()
    }

    [Fact]
-    public void GetNewThread_UsesChatMessageStore_FromTypedOverload()
+    public async Task GetNewThread_UsesChatMessageStore_FromTypedOverloadAsync()
    {
        // Arrange
        var mockChatClient = new Mock();
@@ -73,7 +74,7 @@ public void GetNewThread_UsesChatMessageStore_FromTypedOverload()
        var agent = new ChatClientAgent(mockChatClient.Object);

        // Act
-        var thread = agent.GetNewThread(mockMessageStore.Object);
+        var thread = await agent.GetNewThreadAsync(mockMessageStore.Object);

        // Assert
        Assert.IsType(thread);
@@ -82,7 +83,7 @@ public void GetNewThread_UsesChatMessageStore_FromTypedOverload()
    }

    [Fact]
-    public void GetNewThread_UsesConversationId_FromTypedOverload()
+    public async Task GetNewThread_UsesConversationId_FromTypedOverloadAsync()
    {
        // Arrange
        var mockChatClient = new Mock();
@@ -90,7 +91,7 @@ public void GetNewThread_UsesConversationId_FromTypedOverload()
        var agent = new ChatClientAgent(mockChatClient.Object);

        // Act
-        var thread = agent.GetNewThread(TestConversationId);
+        var thread = await agent.GetNewThreadAsync(TestConversationId);

        // Assert
        Assert.IsType(thread);
diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_RunWithCustomOptionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_RunWithCustomOptionsTests.cs
index 85cb4cf0b4..4c85bcbb51 100644
--- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_RunWithCustomOptionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_RunWithCustomOptionsTests.cs
@@ -30,11 +30,11 @@ public async Task RunAsync_WithThreadAndOptions_CallsBaseMethodAsync()
                It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "Response")]));

        ChatClientAgent agent = new(mockChatClient.Object);
-        AgentThread thread = agent.GetNewThread();
+        AgentThread thread = await agent.GetNewThreadAsync();
        ChatClientAgentRunOptions options = new();

        // Act
-        AgentRunResponse result = await agent.RunAsync(thread, options);
+        AgentResponse result = await agent.RunAsync(thread, options);

        // Assert
        Assert.NotNull(result);
@@ -59,11 +59,11 @@ public async Task RunAsync_WithStringMessageAndOptions_CallsBaseMethodAsync()
                It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "Response")]));

        ChatClientAgent agent = new(mockChatClient.Object);
-        AgentThread thread = agent.GetNewThread();
+        AgentThread thread = await agent.GetNewThreadAsync();
        ChatClientAgentRunOptions options = new();

        // Act
-        AgentRunResponse result = await agent.RunAsync("Test message", thread, options);
+        AgentResponse result = await agent.RunAsync("Test message", thread, options);

        // Assert
        Assert.NotNull(result);
@@ -88,12 +88,12 @@ public async Task RunAsync_WithChatMessageAndOptions_CallsBaseMethodAsync()
                It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "Response")]));

        ChatClientAgent agent = new(mockChatClient.Object);
-        AgentThread thread = agent.GetNewThread();
+        AgentThread thread = await agent.GetNewThreadAsync();
        ChatMessage message = new(ChatRole.User, "Test message");
        ChatClientAgentRunOptions options = new();

        // Act
-        AgentRunResponse result = await agent.RunAsync(message, thread, options);
+        AgentResponse result = await agent.RunAsync(message, thread, options);

        // Assert
        Assert.NotNull(result);
@@ -118,12 +118,12 @@ public async Task RunAsync_WithMessagesCollectionAndOptions_CallsBaseMethodAsync
                It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "Response")]));

        ChatClientAgent agent = new(mockChatClient.Object);
-        AgentThread thread = agent.GetNewThread();
+        AgentThread thread = await agent.GetNewThreadAsync();
        IEnumerable messages = [new(ChatRole.User, "Message 1"), new(ChatRole.User, "Message 2")];
        ChatClientAgentRunOptions options = new();

        // Act
-        AgentRunResponse result = await agent.RunAsync(messages, thread, options);
+        AgentResponse result = await agent.RunAsync(messages, thread, options);

        // Assert
        Assert.NotNull(result);
@@ -151,7 +151,7 @@ public async Task RunAsync_WithChatOptionsInRunOptions_UsesChatOptionsAsync()
        ChatClientAgentRunOptions options = new(new ChatOptions { Temperature = 0.5f });

        // Act
-        AgentRunResponse result = await agent.RunAsync("Test", null, options);
+        AgentResponse result = await agent.RunAsync("Test", null, options);

        // Assert
        Assert.NotNull(result);
@@ -179,11 +179,11 @@ public async Task RunStreamingAsync_WithThreadAndOptions_CallsBaseMethodAsync()
                It.IsAny())).Returns(GetAsyncUpdatesAsync());

        ChatClientAgent agent = new(mockChatClient.Object);
-        AgentThread thread = agent.GetNewThread();
+        AgentThread thread = await agent.GetNewThreadAsync();
        ChatClientAgentRunOptions options = new();

        // Act
-        var updates = new List();
+        var updates = new List();
        await foreach (var update in agent.RunStreamingAsync(thread, options))
        {
            updates.Add(update);
@@ -211,11 +211,11 @@ public async Task RunStreamingAsync_WithStringMessageAndOptions_CallsBaseMethodA
                It.IsAny())).Returns(GetAsyncUpdatesAsync());

        ChatClientAgent agent = new(mockChatClient.Object);
-        AgentThread thread = agent.GetNewThread();
+        AgentThread thread = await agent.GetNewThreadAsync();
        ChatClientAgentRunOptions options = new();

        // Act
-        var updates = new List();
+        var updates = new List();
        await foreach (var update in agent.RunStreamingAsync("Test message", thread, options))
        {
            updates.Add(update);
@@ -243,12 +243,12 @@ public async Task RunStreamingAsync_WithChatMessageAndOptions_CallsBaseMethodAsy
                It.IsAny())).Returns(GetAsyncUpdatesAsync());

        ChatClientAgent agent = new(mockChatClient.Object);
-        AgentThread thread = agent.GetNewThread();
+        AgentThread thread = await agent.GetNewThreadAsync();
        ChatMessage message = new(ChatRole.User, "Test message");
        ChatClientAgentRunOptions options = new();

        // Act
-        var updates = new List();
+        var updates = new List();
        await foreach (var update in agent.RunStreamingAsync(message, thread, options))
        {
            updates.Add(update);
@@ -276,12 +276,12 @@ public async Task RunStreamingAsync_WithMessagesCollectionAndOptions_CallsBaseMe
                It.IsAny())).Returns(GetAsyncUpdatesAsync());

        ChatClientAgent agent = new(mockChatClient.Object);
-        AgentThread thread = agent.GetNewThread();
+        AgentThread thread = await agent.GetNewThreadAsync();
        IEnumerable messages = [new ChatMessage(ChatRole.User, "Message 1"), new ChatMessage(ChatRole.User, "Message 2")];
        ChatClientAgentRunOptions options = new();

        // Act
-        var updates = new List();
+        var updates = new List();
        await foreach (var update in agent.RunStreamingAsync(messages, thread, options))
        {
            updates.Add(update);
@@ -324,16 +324,16 @@ public async Task RunAsyncOfT_WithThreadAndOptions_CallsBaseMethodAsync()
                It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, """{"id":2, "fullName":"Tigger", "species":"Tiger"}""")]));

        ChatClientAgent agent = new(mockChatClient.Object);
-        AgentThread thread = agent.GetNewThread();
+        AgentThread thread = await agent.GetNewThreadAsync();
        ChatClientAgentRunOptions options = new();

        // Act
-        AgentRunResponse agentRunResponse = await agent.RunAsync(thread, JsonContext_WithCustomRunOptions.Default.Options, options);
+        AgentResponse agentResponse = await agent.RunAsync(thread, JsonContext_WithCustomRunOptions.Default.Options, options);

        // Assert
-        Assert.NotNull(agentRunResponse);
-        Assert.Single(agentRunResponse.Messages);
-        Assert.Equal("Tigger", agentRunResponse.Result.FullName);
+        Assert.NotNull(agentResponse);
+        Assert.Single(agentResponse.Messages);
+        Assert.Equal("Tigger", agentResponse.Result.FullName);
        mockChatClient.Verify(
            x => x.GetResponseAsync(
                It.IsAny>(),
@@ -354,16 +354,16 @@ public async Task RunAsyncOfT_WithStringMessageAndOptions_CallsBaseMethodAsync()
                It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, """{"id":2, "fullName":"Tigger", "species":"Tiger"}""")]));

        ChatClientAgent agent = new(mockChatClient.Object);
-        AgentThread thread = agent.GetNewThread();
+        AgentThread thread = await agent.GetNewThreadAsync();
        ChatClientAgentRunOptions options = new();

        // Act
-        AgentRunResponse agentRunResponse = await agent.RunAsync("Test message", thread, JsonContext_WithCustomRunOptions.Default.Options, options);
+        AgentResponse agentResponse = await agent.RunAsync("Test message", thread, JsonContext_WithCustomRunOptions.Default.Options, options);

        // Assert
-        Assert.NotNull(agentRunResponse);
-        Assert.Single(agentRunResponse.Messages);
-        Assert.Equal("Tigger", agentRunResponse.Result.FullName);
+        Assert.NotNull(agentResponse);
+        Assert.Single(agentResponse.Messages);
+        Assert.Equal("Tigger", agentResponse.Result.FullName);
        mockChatClient.Verify(
            x => x.GetResponseAsync(
                It.Is>(msgs => msgs.Any(m => m.Text == "Test message")),
@@ -384,17 +384,17 @@ public async Task RunAsyncOfT_WithChatMessageAndOptions_CallsBaseMethodAsync()
                It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, """{"id":2, "fullName":"Tigger", "species":"Tiger"}""")]));

        ChatClientAgent agent = new(mockChatClient.Object);
-        AgentThread thread = agent.GetNewThread();
+        AgentThread thread = await agent.GetNewThreadAsync();
        ChatMessage message = new(ChatRole.User, "Test message");
        ChatClientAgentRunOptions options = new();

        // Act
-        AgentRunResponse agentRunResponse = await agent.RunAsync(message, thread, JsonContext_WithCustomRunOptions.Default.Options, options);
+        AgentResponse agentResponse = await agent.RunAsync(message, thread, JsonContext_WithCustomRunOptions.Default.Options, options);

        // Assert
-        Assert.NotNull(agentRunResponse);
-        Assert.Single(agentRunResponse.Messages);
-        Assert.Equal("Tigger", agentRunResponse.Result.FullName);
+        Assert.NotNull(agentResponse);
+        Assert.Single(agentResponse.Messages);
+        Assert.Equal("Tigger", agentResponse.Result.FullName);
        mockChatClient.Verify(
            x => x.GetResponseAsync(
                It.Is>(msgs => msgs.Contains(message)),
@@ -415,17 +415,17 @@ public async Task RunAsyncOfT_WithMessagesCollectionAndOptions_CallsBaseMethodAs
                It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, """{"id":2, "fullName":"Tigger", "species":"Tiger"}""")]));

        ChatClientAgent agent = new(mockChatClient.Object);
-        AgentThread thread = agent.GetNewThread();
+        AgentThread thread = await agent.GetNewThreadAsync();
        IEnumerable messages = [new(ChatRole.User, "Message 1"), new(ChatRole.User, "Message 2")];
        ChatClientAgentRunOptions options = new();

        // Act
-        AgentRunResponse agentRunResponse = await agent.RunAsync(messages, thread, JsonContext_WithCustomRunOptions.Default.Options, options);
+        AgentResponse agentResponse = await agent.RunAsync(messages, thread, JsonContext_WithCustomRunOptions.Default.Options, options);

        // Assert
-        Assert.NotNull(agentRunResponse);
-        Assert.Single(agentRunResponse.Messages);
-        Assert.Equal("Tigger", agentRunResponse.Result.FullName);
+        Assert.NotNull(agentResponse);
+        Assert.Single(agentResponse.Messages);
+        Assert.Equal("Tigger", agentResponse.Result.FullName);
        mockChatClient.Verify(
            x => x.GetResponseAsync(
                It.IsAny>(),
diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientExtensionsTests.cs
index 51beb6aa2e..484b0a6fff 100644
--- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientExtensionsTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientExtensionsTests.cs
@@ -19,7 +19,7 @@ public void CreateAIAgent_WithBasicParameters_CreatesAgent()
        var chatClientMock = new Mock();

        // Act
-        var agent = chatClientMock.Object.CreateAIAgent(
+        var agent = chatClientMock.Object.AsAIAgent(
            instructions: "Test instructions",
            name: "TestAgent",
            description: "Test description"
@@ -40,7 +40,7 @@ public void CreateAIAgent_WithTools_SetsToolsInOptions()
        var tools = new List { new Mock().Object };

        // Act
-        var agent = chatClientMock.Object.CreateAIAgent(tools: tools);
+        var agent = chatClientMock.Object.AsAIAgent(tools: tools);

        // Assert
        Assert.NotNull(agent);
@@ -62,7 +62,7 @@ public void CreateAIAgent_WithOptions_CreatesAgentWithOptions()
        };

        // Act
-        var agent = chatClientMock.Object.CreateAIAgent(options);
+        var agent = chatClientMock.Object.AsAIAgent(options);

        // Assert
        Assert.NotNull(agent);
@@ -79,7 +79,7 @@ public void CreateAIAgent_WithNullClient_Throws()
        IChatClient chatClient = null!;

        // Act & Assert
-        Assert.Throws(() => chatClient.CreateAIAgent(instructions: "instructions"));
+        Assert.Throws(() => chatClient.AsAIAgent(instructions: "instructions"));
    }

    [Fact]
@@ -89,6 +89,6 @@ public void CreateAIAgent_WithNullClientAndOptions_Throws()
        IChatClient chatClient = null!;

        // Act & Assert
-        Assert.Throws(() => chatClient.CreateAIAgent(options: new() { ChatOptions = new() { Instructions = "instructions" } }));
+        Assert.Throws(() => chatClient.AsAIAgent(options: new() { ChatOptions = new() { Instructions = "instructions" } }));
    }
}
diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/FunctionInvocationDelegatingAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/FunctionInvocationDelegatingAgentTests.cs
index 5866e610b6..50955234e5 100644
--- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/FunctionInvocationDelegatingAgentTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/FunctionInvocationDelegatingAgentTests.cs
@@ -717,7 +717,7 @@ public async Task RunAsync_FunctionMiddlewareWithRunningMiddleware_BothExecuteAs
        var innerAgent = new ChatClientAgent(mockChatClient.Object);
        var messages = new List { new(ChatRole.User, "Test message") };

-        async Task RunningMiddlewareCallbackAsync(IEnumerable messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
+        async Task RunningMiddlewareCallbackAsync(IEnumerable messages, AgentThread? thread, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
        {
            executionOrder.Add("Running-Pre");
            var result = await innerAgent.RunAsync(messages, thread, options, cancellationToken);
@@ -800,7 +800,7 @@ public async Task RunStreamingAsync_WithFunctionCall_InvokesMiddlewareAsync()
        // Act
        var options = new ChatClientAgentRunOptions(new ChatOptions { Tools = [testFunction] });
-        var responseUpdates = new List();
+        var responseUpdates = new List();
        await foreach (var update in middleware.RunStreamingAsync(messages, null, options, CancellationToken.None))
        {
            responseUpdates.Add(update);
diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/LoggingAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/LoggingAgentTests.cs
index 58e9536491..b5e701cb38 100644
--- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/LoggingAgentTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/LoggingAgentTests.cs
@@ -83,7 +83,7 @@ public async Task RunAsync_LogsAtDebugLevelAsync()
            RunAsyncFunc = async (messages, thread, options, cancellationToken) =>
            {
                await Task.Yield();
-                return new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Test response"));
+                return new AgentResponse(new ChatMessage(ChatRole.Assistant, "Test response"));
            }
        };
@@ -126,7 +126,7 @@ public async Task RunAsync_LogsAtTraceLevel_IncludesSensitiveDataAsync()
            RunAsyncFunc = async (messages, thread, options, cancellationToken) =>
            {
                await Task.Yield();
-                return new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "Test response"));
+                return new AgentResponse(new ChatMessage(ChatRole.Assistant, "Test response"));
            }
        };
@@ -228,11 +228,11 @@ public async Task RunStreamingAsync_LogsAtDebugLevelAsync()
            RunStreamingAsyncFunc = CallbackAsync
        };

-        static async IAsyncEnumerable CallbackAsync(
+        static async IAsyncEnumerable CallbackAsync(
            IEnumerable messages, AgentThread? thread, AgentRunOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken)
        {
            await Task.Yield();
-            yield return new AgentRunResponseUpdate(ChatRole.Assistant, "Test");
+            yield return new AgentResponseUpdate(ChatRole.Assistant, "Test");
        }

        var agent = new LoggingAgent(innerAgent, mockLogger.Object);
@@ -277,12 +277,12 @@ public async Task RunStreamingAsync_LogsUpdatesAtTraceLevelAsync()
            RunStreamingAsyncFunc = CallbackAsync
        };

-        static async IAsyncEnumerable CallbackAsync(
+        static async IAsyncEnumerable CallbackAsync(
            IEnumerable messages, AgentThread? thread, AgentRunOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken)
        {
            await Task.Yield();
-            yield return new AgentRunResponseUpdate(ChatRole.Assistant, "Update 1");
-            yield return new AgentRunResponseUpdate(ChatRole.Assistant, "Update 2");
+            yield return new AgentResponseUpdate(ChatRole.Assistant, "Update 1");
+            yield return new AgentResponseUpdate(ChatRole.Assistant, "Update 2");
        }

        var agent = new LoggingAgent(innerAgent, mockLogger.Object);
@@ -317,7 +317,7 @@ public async Task RunStreamingAsync_OnCancellation_LogsCanceledAsync()
            RunStreamingAsyncFunc = CallbackAsync
        };

-        static async IAsyncEnumerable CallbackAsync(
+        static async IAsyncEnumerable CallbackAsync(
            IEnumerable messages, AgentThread? thread, AgentRunOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken)
        {
            await Task.Yield();
@@ -364,7 +364,7 @@ public async Task RunStreamingAsync_OnException_LogsFailedAsync()
            RunStreamingAsyncFunc = CallbackAsync
        };

-        static async IAsyncEnumerable CallbackAsync(
+        static async IAsyncEnumerable CallbackAsync(
            IEnumerable messages, AgentThread? thread, AgentRunOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken)
        {
            await Task.Yield();
diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/OpenTelemetryAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/OpenTelemetryAgentTests.cs
index 405832763c..84dc1d7a20 100644
--- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/OpenTelemetryAgentTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/OpenTelemetryAgentTests.cs
@@ -82,7 +82,7 @@ public async Task WithoutChatOptions_ExpectedInformationLogged_Async(bool enable
            RunAsyncFunc = async (messages, thread, options, cancellationToken) =>
            {
                await Task.Yield();
-                return new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "The blue whale, I think."))
+                return new AgentResponse(new ChatMessage(ChatRole.Assistant, "The blue whale, I think."))
                {
                    ResponseId = "id123",
                    Usage = new UsageDetails
@@ -107,7 +107,7 @@ public async Task WithoutChatOptions_ExpectedInformationLogged_Async(bool enable
            null,
        };

-        async static IAsyncEnumerable CallbackAsync(
+        async static IAsyncEnumerable CallbackAsync(
            IEnumerable messages, AgentThread? thread, AgentRunOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken)
        {
            await Task.Yield();
@@ -115,13 +115,13 @@ async static IAsyncEnumerable CallbackAsync(
            foreach (string text in new[] { "The ", "blue ", "whale,", " ", "", "I", " think." })
            {
                await Task.Yield();
-                yield return new AgentRunResponseUpdate(ChatRole.Assistant, text)
+                yield return new AgentResponseUpdate(ChatRole.Assistant, text)
                {
                    ResponseId = "id123",
                };
            }

-            yield return new AgentRunResponseUpdate
+            yield return new AgentResponseUpdate
            {
                Contents = [new UsageContent(new()
                {
@@ -307,7 +307,7 @@ public async Task WithChatOptions_ExpectedInformationLogged_Async(
            RunAsyncFunc = async (messages, thread, options, cancellationToken) =>
            {
                await Task.Yield();
-                return new AgentRunResponse(new ChatMessage(ChatRole.Assistant, "The blue whale, I think."))
+                return new AgentResponse(new ChatMessage(ChatRole.Assistant, "The blue whale, I think."))
                {
                    ResponseId = "id123",
                    Usage = new UsageDetails
@@ -332,7 +332,7 @@ public async Task WithChatOptions_ExpectedInformationLogged_Async(
            null,
        };

-        async static IAsyncEnumerable CallbackAsync(
+        async static IAsyncEnumerable CallbackAsync(
            IEnumerable messages, AgentThread? thread, AgentRunOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken)
        {
            await Task.Yield();
@@ -340,13 +340,13 @@ async static IAsyncEnumerable CallbackAsync(
            foreach (string text in new[] { "The ", "blue ", "whale,", " ", "", "I", " think." })
            {
                await Task.Yield();
-                yield return new AgentRunResponseUpdate(ChatRole.Assistant, text)
+                yield return new AgentResponseUpdate(ChatRole.Assistant, text)
                {
                    ResponseId = "id123",
                };
            }

-            yield return new AgentRunResponseUpdate
+            yield return new AgentResponseUpdate
            {
                Contents = [new UsageContent(new()
                {
diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestAIAgent.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestAIAgent.cs
index 3d2cdff868..473c01bb6b 100644
--- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestAIAgent.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestAIAgent.cs
@@ -16,24 +16,24 @@ internal sealed class TestAIAgent : AIAgent
    public Func DeserializeThreadFunc = delegate { throw new NotSupportedException(); };
    public Func GetNewThreadFunc = delegate { throw new NotSupportedException(); };
-    public Func, AgentThread?, AgentRunOptions?, CancellationToken, Task> RunAsyncFunc = delegate { throw new NotSupportedException(); };
-    public Func, AgentThread?, AgentRunOptions?, CancellationToken, IAsyncEnumerable> RunStreamingAsyncFunc = delegate { throw new NotSupportedException(); };
+    public Func, AgentThread?, AgentRunOptions?, CancellationToken, Task> RunAsyncFunc = delegate { throw new NotSupportedException(); };
+    public Func, AgentThread?, AgentRunOptions?, CancellationToken, IAsyncEnumerable> RunStreamingAsyncFunc = delegate { throw new NotSupportedException(); };
    public Func? GetServiceFunc;

    public override string? Name => this.NameFunc?.Invoke() ?? base.Name;
    public override string? Description => this.DescriptionFunc?.Invoke() ?? base.Description;

-    public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) =>
-        this.DeserializeThreadFunc(serializedThread, jsonSerializerOptions);
+    public override ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) =>
+        new(this.DeserializeThreadFunc(serializedThread, jsonSerializerOptions));

-    public override AgentThread GetNewThread() =>
-        this.GetNewThreadFunc();
+    public override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) =>
+        new(this.GetNewThreadFunc());

-    protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
+    protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
        this.RunAsyncFunc(messages, thread, options, cancellationToken);

-    protected override IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
+    protected override IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
        this.RunStreamingAsyncFunc(messages, thread, options, cancellationToken);

    public override object? GetService(Type serviceType, object? serviceKey = null) =>
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/Framework/WorkflowEvents.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/Framework/WorkflowEvents.cs
index a9f1789449..0cf044e02e 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/Framework/WorkflowEvents.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/Framework/WorkflowEvents.cs
@@ -18,7 +18,7 @@ public WorkflowEvents(IReadOnlyList workflowEvents)
        this.ExecutorInvokeEvents = workflowEvents.OfType().ToList();
        this.ExecutorCompleteEvents = workflowEvents.OfType().ToList();
        this.InputEvents = workflowEvents.OfType().ToList();
-        this.AgentResponseEvents = workflowEvents.OfType().ToList();
+        this.AgentResponseEvents = workflowEvents.OfType().ToList();
    }

    public IReadOnlyList Events { get; }
@@ -29,5 +29,5 @@ public WorkflowEvents(IReadOnlyList workflowEvents)
    public IReadOnlyList ExecutorInvokeEvents { get; }
    public IReadOnlyList ExecutorCompleteEvents { get; }
    public IReadOnlyList InputEvents { get; }
-    public IReadOnlyList AgentResponseEvents { get; }
+    public IReadOnlyList AgentResponseEvents { get; }
}
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/Framework/WorkflowHarness.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/Framework/WorkflowHarness.cs
index ed3e0367f7..afd9b18fb9 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/Framework/WorkflowHarness.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/Framework/WorkflowHarness.cs
@@ -142,7 +142,7 @@ private static async IAsyncEnumerable MonitorAndDisposeWorkflowRu
                Console.WriteLine($"ACTION: {actionInvokeEvent.ActionId} [{actionInvokeEvent.ActionType}]");
                break;

-            case AgentRunResponseEvent responseEvent:
+            case AgentResponseEvent responseEvent:
                if (!string.IsNullOrEmpty(responseEvent.Response.Text))
                {
                    Console.WriteLine($"AGENT: {responseEvent.Response.AgentId}: {responseEvent.Response.Text}");
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/Framework/WorkflowTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/Framework/WorkflowTest.cs
index 3649355182..20cc823553 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/Framework/WorkflowTest.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/Framework/WorkflowTest.cs
@@ -124,7 +124,7 @@ public static void EventCounts(int actualCount, Testcase testcase, bool isComple
        }
    }

-    public static void Responses(IReadOnlyList responseEvents, Testcase testcase)
+    public static void Responses(IReadOnlyList responseEvents, Testcase testcase)
    {
        Assert.True(responseEvents.Count >= testcase.Validation.MinResponseCount, $"Response count less than expected: {testcase.Validation.MinResponseCount} (Actual: {responseEvents.Count})");
        if (testcase.Validation.MaxResponseCount != -1)
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/MediaInputTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/MediaInputTest.cs
index d5976a3174..ae3cdfd7a9 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/MediaInputTest.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/MediaInputTest.cs
@@ -84,7 +84,7 @@ private async Task ValidateFileAsync(AIContent fileContent)
        WorkflowEvents workflowEvents = await harness.RunWorkflowAsync(inputMessage).ConfigureAwait(false);
        ConversationUpdateEvent conversationEvent = Assert.Single(workflowEvents.ConversationEvents);
        this.Output.WriteLine("CONVERSATION: " + conversationEvent.ConversationId);
-        AgentRunResponseEvent agentResponseEvent = Assert.Single(workflowEvents.AgentResponseEvents);
+        AgentResponseEvent agentResponseEvent = Assert.Single(workflowEvents.AgentResponseEvents);
        this.Output.WriteLine("RESPONSE: " + agentResponseEvent.Response.Text);
        Assert.NotEmpty(agentResponseEvent.Response.Text);
    }
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/DeclarativeWorkflowTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/DeclarativeWorkflowTest.cs
index d606770ff8..cef0c7aeea 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/DeclarativeWorkflowTest.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/DeclarativeWorkflowTest.cs
@@ -352,7 +352,7 @@ private async Task RunWorkflowAsync(string workflowPath, TInput workflow
                this.Output.WriteLine($"ACTIVITY: {activityEvent.Message}");
                break;

-            case AgentRunResponseEvent messageEvent:
+            case AgentResponseEvent messageEvent:
                this.Output.WriteLine($"MESSAGE: {messageEvent.Response.Messages[0].Text.Trim()}");
                break;
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Events/ExternalInputRequestTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Events/ExternalInputRequestTest.cs
index 49d06337fe..d1165d84d4 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Events/ExternalInputRequestTest.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Events/ExternalInputRequestTest.cs
@@ -15,7 +15,7 @@ public sealed class ExternalInputRequestTest(ITestOutputHelper output) : EventTe
    public void VerifySerializationWithText()
    {
        // Arrange
-        ExternalInputRequest source = new(new AgentRunResponse(new ChatMessage(ChatRole.User, "Wassup?")));
+        ExternalInputRequest source = new(new AgentResponse(new ChatMessage(ChatRole.User, "Wassup?")));

        // Act
        ExternalInputRequest copy = VerifyEventSerialization(source);
@@ -30,7 +30,7 @@ public void VerifySerializationWithRequests()
    {
        // Arrange
        ExternalInputRequest source =
-            new(new AgentRunResponse(
+            new(new AgentResponse(
                new ChatMessage(
                    ChatRole.Assistant,
                    [
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/CancelWorkflow.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/CancelWorkflow.cs
index 9cf72036fd..d407be3ae1 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/CancelWorkflow.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/CancelWorkflow.cs
@@ -60,8 +60,8 @@ await context.FormatTemplateAsync(
                NEVER 1!
                """
            );
-        AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
-        await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false);
+        AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
+        await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false);

        return default;
    }
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/Condition.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/Condition.cs
index 68deb1f19e..3dfa7b4f68 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/Condition.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/Condition.cs
@@ -1,4 +1,4 @@
-// ------------------------------------------------------------------------------
+// ------------------------------------------------------------------------------
//
//     This code was generated by a tool.
//
@@ -47,7 +47,7 @@ protected override async ValueTask ExecuteAsync(TInput message, IWorkflowContext
        await context.QueueStateUpdateAsync("TestValue", UnassignedValue.Instance, "Local").ConfigureAwait(false);
    }
}
-
+
///
/// Assigns an evaluated expression, other variable, or literal value to the "Local.TestValue" variable.
///
@@ -58,11 +58,11 @@ internal sealed class SetvariableTestExecutor(FormulaSession session) : ActionEx
    {
        object? evaluatedValue = await context.EvaluateValueAsync("Value(System.LastMessageText)").ConfigureAwait(false);
        await context.QueueStateUpdateAsync(key: "TestValue", value: evaluatedValue, scopeName: "Local").ConfigureAwait(false);
-
+
        return default;
    }
}
-
+
///
/// Conditional branching similar to an if / elseif / elseif / else chain.
///
@@ -76,17 +76,17 @@ internal sealed class ConditiongroupTestExecutor(FormulaSession session) : Actio
    {
        return "conditionItem_odd";
    }
-
+
    bool condition1 = await context.EvaluateValueAsync("Mod(Local.TestValue, 2) = 0").ConfigureAwait(false);
    if (condition1)
    {
        return "conditionItem_even";
    }
-
+
    return "conditionGroup_testElseActions";
    }
}
-
+
///
/// Formats a message template and sends an activity event.
///
@@ -101,13 +101,13 @@ await context.FormatTemplateAsync(
                ODD
                """
            );
-        AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
-        await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false);
-
+        AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
+        await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false);
+
        return default;
    }
}
-
+
///
/// Formats a message template and sends an activity event.
///
@@ -122,13 +122,13 @@ await context.FormatTemplateAsync(
                EVEN
                """
            );
-        AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
-        await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false);
-
+        AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
+        await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false);
+
        return default;
    }
}
-
+
///
/// Formats a message template and sends an activity event.
///
@@ -143,16 +143,16 @@ await context.FormatTemplateAsync(
                All done!
                """
            );
-        AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
-        await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false);
-
+        AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
+        await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false);
+
        return default;
    }
}
-
+
public static Workflow CreateWorkflow(
    DeclarativeWorkflowOptions options,
-    Func? inputTransform = null)
+    Func? inputTransform = null)
    where TInput : notnull
{
    // Create root executor to initialize the workflow.
@@ -198,4 +198,4 @@ public static Workflow CreateWorkflow(
    // Build the workflow
    return builder.Build(validateOrphans: false);
}
-}
\ No newline at end of file
+}
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/ConditionElse.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/ConditionElse.cs
index 47e278bc59..2f64bdd3e5 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/ConditionElse.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/ConditionElse.cs
@@ -1,4 +1,4 @@
-// ------------------------------------------------------------------------------
+// ------------------------------------------------------------------------------
//
//     This code was generated by a tool.
//
@@ -47,7 +47,7 @@ protected override async ValueTask ExecuteAsync(TInput message, IWorkflowContext
        await context.QueueStateUpdateAsync("TestValue", UnassignedValue.Instance, "Local").ConfigureAwait(false);
    }
}
-
+
///
/// Assigns an evaluated expression, other variable, or literal value to the "Local.TestValue" variable.
///
@@ -58,11 +58,11 @@ internal sealed class SetvariableTestExecutor(FormulaSession session) : ActionEx
    {
        object? evaluatedValue = await context.EvaluateValueAsync("Value(System.LastMessageText)").ConfigureAwait(false);
        await context.QueueStateUpdateAsync(key: "TestValue", value: evaluatedValue, scopeName: "Local").ConfigureAwait(false);
-
+
        return default;
    }
}
-
+
///
/// Conditional branching similar to an if / elseif / elseif / else chain.
///
@@ -76,11 +76,11 @@ internal sealed class ConditiongroupTestExecutor(FormulaSession session) : Actio
    {
        return "conditionItem_odd";
    }
-
+
    return "conditionGroup_testElseActions";
    }
}
-
+
///
/// Formats a message template and sends an activity event.
///
@@ -95,13 +95,13 @@ await context.FormatTemplateAsync(
                ODD
                """
            );
-        AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
-        await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false);
-
+        AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
+        await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false);
+
        return default;
    }
}
-
+
///
/// Formats a message template and sends an activity event.
///
@@ -116,13 +116,13 @@ await context.FormatTemplateAsync(
                EVEN
                """
            );
-        AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
-        await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false);
-
+        AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
+        await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false);
+
        return default;
    }
}
-
+
///
/// Formats a message template and sends an activity event.
///
@@ -137,16 +137,16 @@ await context.FormatTemplateAsync(
                All done!
                """
            );
-        AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
-        await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false);
-
+        AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
+        await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false);
+
        return default;
    }
}
-
+
public static Workflow CreateWorkflow(
    DeclarativeWorkflowOptions options,
-    Func? inputTransform = null)
+    Func? inputTransform = null)
    where TInput : notnull
{
    // Create root executor to initialize the workflow.
@@ -190,4 +190,4 @@ public static Workflow CreateWorkflow(
    // Build the workflow
    return builder.Build(validateOrphans: false);
}
-}
\ No newline at end of file
+}
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/EndConversation.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/EndConversation.cs
index 9cf72036fd..d407be3ae1 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/EndConversation.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/EndConversation.cs
@@ -60,8 +60,8 @@ await context.FormatTemplateAsync(
                NEVER 1!
                """
            );
-        AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
-        await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false);
+        AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]);
+        await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false);

        return default;
    }
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/EndWorkflow.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/EndWorkflow.cs
index d17ddeb3a6..d407be3ae1 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/EndWorkflow.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/EndWorkflow.cs
@@ -1,4 +1,4 @@
-// ------------------------------------------------------------------------------
+// ------------------------------------------------------------------------------
//
//     This code was generated by a tool.
//
@@ -60,8 +60,8 @@ await context.FormatTemplateAsync(
                NEVER 1!
""" ); - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false); + AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); + await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false); return default; } @@ -91,4 +91,4 @@ public static Workflow CreateWorkflow( // Build the workflow return builder.Build(validateOrphans: false); } -} \ No newline at end of file +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/Goto.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/Goto.cs index d684232acf..d8413387c9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/Goto.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/Goto.cs @@ -60,8 +60,8 @@ await context.FormatTemplateAsync( NEVER 1! """ ); - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false); + AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); + await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false); return default; } @@ -81,8 +81,8 @@ await context.FormatTemplateAsync( NEVER 2! """ ); - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false); + AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); + await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false); return default; } @@ -102,8 +102,8 @@ await context.FormatTemplateAsync( NEVER 3! """ ); - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false); + AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); + await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false); return default; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/InvokeAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/InvokeAgent.cs index 08d324cbcd..08002d16fb 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/InvokeAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/InvokeAgent.cs @@ -70,7 +70,7 @@ internal sealed class InvokeAgentExecutor(FormulaSession session, WorkflowAgentP bool autoSend = true; IList? 
inputMessages = await context.EvaluateListAsync("[UserMessage(System.LastMessageText)]").ConfigureAwait(false); - AgentRunResponse agentResponse = + AgentResponse agentResponse = await InvokeAgentAsync( context, agentName, @@ -81,7 +81,7 @@ await InvokeAgentAsync( if (autoSend) { - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, agentResponse)).ConfigureAwait(false); + await context.AddEventAsync(new AgentResponseEvent(this.Id, agentResponse)).ConfigureAwait(false); } return default; diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/LoopBreak.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/LoopBreak.cs index f9bd0b6bd8..f4ee656e0c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/LoopBreak.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/LoopBreak.cs @@ -135,8 +135,8 @@ await context.FormatTemplateAsync( x{Local.Count} - {Local.LoopIndex}:{Local.LoopValue} """ ); - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false); + AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); + await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false); return default; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/LoopContinue.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/LoopContinue.cs index 507f99995c..474d69e7ed 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/LoopContinue.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/LoopContinue.cs @@ -135,8 +135,8 @@ await context.FormatTemplateAsync( x{Local.Count} - {Local.LoopIndex}:{Local.LoopValue} """ ); - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false); + AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); + await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false); return default; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/LoopEach.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/LoopEach.cs index 6d141cfe21..06137f222d 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/LoopEach.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/LoopEach.cs @@ -135,8 +135,8 @@ await context.FormatTemplateAsync( x{Local.Count} - {Local.LoopIndex}:{Local.LoopValue} """ ); - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false); + AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); + await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false); return default; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/SendActivity.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/SendActivity.cs index 69a99467ca..05cd29c574 100644 --- 
a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/SendActivity.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/Workflows/SendActivity.cs @@ -77,8 +77,8 @@ await context.FormatTemplateAsync( Input: "{Local.TestValue}" """ ); - AgentRunResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); - await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response)).ConfigureAwait(false); + AgentResponse response = new([new ChatMessage(ChatRole.Assistant, activityText)]); + await context.AddEventAsync(new AgentResponseEvent(this.Id, response)).ConfigureAwait(false); return default; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs index c45ef8726e..4ed540c34d 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs @@ -135,26 +135,26 @@ private class DoubleEchoAgent(string name) : AIAgent { public override string Name => name; - public override AgentThread GetNewThread() - => new DoubleEchoAgentThread(); + public override ValueTask GetNewThreadAsync(CancellationToken cancellationToken = default) + => new(new DoubleEchoAgentThread()); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) - => new DoubleEchoAgentThread(); + public override ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + => new(new DoubleEchoAgentThread()); - protected override Task RunCoreAsync( + protected override Task RunCoreAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { await Task.Yield(); var contents = messages.SelectMany(m => m.Contents).ToList(); string id = Guid.NewGuid().ToString("N"); - yield return new AgentRunResponseUpdate(ChatRole.Assistant, this.Name) { AuthorName = this.Name, MessageId = id }; - yield return new AgentRunResponseUpdate(ChatRole.Assistant, contents) { AuthorName = this.Name, MessageId = id }; - yield return new AgentRunResponseUpdate(ChatRole.Assistant, contents) { AuthorName = this.Name, MessageId = id }; + yield return new AgentResponseUpdate(ChatRole.Assistant, this.Name) { AuthorName = this.Name, MessageId = id }; + yield return new AgentResponseUpdate(ChatRole.Assistant, contents) { AuthorName = this.Name, MessageId = id }; + yield return new AgentResponseUpdate(ChatRole.Assistant, contents) { AuthorName = this.Name, MessageId = id }; } } @@ -393,7 +393,7 @@ public async Task BuildGroupChat_AgentsRunInOrderAsync(int maxIterations) WorkflowOutputEvent? 
output = null; await foreach (WorkflowEvent evt in run.WatchStreamAsync().ConfigureAwait(false)) { - if (evt is AgentRunUpdateEvent executorComplete) + if (evt is AgentResponseUpdateEvent executorComplete) { sb.Append(executorComplete.Data); } @@ -409,7 +409,7 @@ public async Task BuildGroupChat_AgentsRunInOrderAsync(int maxIterations) private sealed class DoubleEchoAgentWithBarrier(string name, StrongBox> barrier, StrongBox remaining) : DoubleEchoAgent(name) { - protected override async IAsyncEnumerable RunCoreStreamingAsync( + protected override async IAsyncEnumerable RunCoreStreamingAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { if (Interlocked.Decrement(ref remaining.Value) == 0) diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/ChatMessageBuilder.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/ChatMessageBuilder.cs index 659babe7e1..f3f7990c6c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/ChatMessageBuilder.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/ChatMessageBuilder.cs @@ -25,7 +25,7 @@ public static IEnumerable ToContentStream(this string? message) return splits.Select(text => (AIContent)new TextContent(text) { RawRepresentation = text }); } - public static AgentRunResponseUpdate ToResponseUpdate(this AIContent content, string? messageId = null, DateTimeOffset? createdAt = null, string? responseId = null, string? agentId = null, string? authorName = null) => + public static AgentResponseUpdate ToResponseUpdate(this AIContent content, string? messageId = null, DateTimeOffset? createdAt = null, string? responseId = null, string? agentId = null, string? authorName = null) => new() { Role = ChatRole.Assistant, @@ -37,7 +37,7 @@ public static AgentRunResponseUpdate ToResponseUpdate(this AIContent content, st Contents = [content], }; - public static IEnumerable ToAgentRunStream(this string message, DateTimeOffset? createdAt = null, string? messageId = null, string? responseId = null, string? agentId = null, string? authorName = null) + public static IEnumerable ToAgentRunStream(this string message, DateTimeOffset? createdAt = null, string? messageId = null, string? responseId = null, string? agentId = null, string? authorName = null) { messageId ??= Guid.NewGuid().ToString("N"); @@ -54,7 +54,7 @@ public static ChatMessage ToChatMessage(this IEnumerable contents, st RawRepresentation = rawRepresentation, }; - public static IEnumerable StreamMessage(this ChatMessage message, string? responseId = null, string? agentId = null) + public static IEnumerable StreamMessage(this ChatMessage message, string? responseId = null, string? agentId = null) { responseId ??= Guid.NewGuid().ToString("N"); string messageId = message.MessageId ?? Guid.NewGuid().ToString("N"); @@ -62,7 +62,7 @@ public static IEnumerable StreamMessage(this ChatMessage return message.Contents.Select(content => content.ToResponseUpdate(messageId, message.CreatedAt, responseId: responseId, agentId: agentId, authorName: message.AuthorName)); } - public static IEnumerable StreamMessages(this List messages, string? agentId = null) => + public static IEnumerable StreamMessages(this List messages, string? agentId = null) => messages.SelectMany(message => message.StreamMessage(agentId)); public static List ToChatMessages(this IEnumerable messages, string? 
authorName = null)
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs
index b3e53da6f8..90b334ff01 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs
@@ -39,7 +39,7 @@ public async Task RunAsyncShouldExecuteWorkflowAsync()
         run.OutgoingEvents.Should().NotBeEmpty("workflow should produce events during execution");

         // Check that we have an agent execution event
-        var agentEvents = run.OutgoingEvents.OfType<AgentRunUpdateEvent>().ToList();
+        var agentEvents = run.OutgoingEvents.OfType<AgentResponseUpdateEvent>().ToList();
         agentEvents.Should().NotBeEmpty("agent should have executed and produced update events");

         // Check that we have output events
@@ -79,7 +79,7 @@ public async Task StreamAsyncWithTurnTokenShouldExecuteWorkflowAsync()
         events.Should().NotBeEmpty("workflow should produce events during execution");

         // Check that we have agent execution events
-        var agentEvents = events.OfType<AgentRunUpdateEvent>().ToList();
+        var agentEvents = events.OfType<AgentResponseUpdateEvent>().ToList();
         agentEvents.Should().NotBeEmpty("agent should have executed and produced update events");

         // Check that we have output events
@@ -125,8 +125,8 @@ public async Task RunAsyncAndStreamAsyncShouldProduceSimilarResultsAsync()
         nonStreamingEvents.Should().NotBeEmpty("non-streaming version should also produce events");

         // Both should have similar types of events
-        var streamingAgentEvents = streamingEvents.OfType<AgentRunUpdateEvent>().Count();
-        var nonStreamingAgentEvents = nonStreamingEvents.OfType<AgentRunUpdateEvent>().Count();
+        var streamingAgentEvents = streamingEvents.OfType<AgentResponseUpdateEvent>().Count();
+        var nonStreamingAgentEvents = nonStreamingEvents.OfType<AgentResponseUpdateEvent>().Count();

         nonStreamingAgentEvents.Should().Be(streamingAgentEvents,
             "both versions should produce the same number of agent events");
@@ -144,12 +144,12 @@ public SimpleTestAgent(string name)

     public override string Name { get; }

-    public override AgentThread GetNewThread() => new SimpleTestAgentThread();
+    public override ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default) => new(new SimpleTestAgentThread());

-    public override AgentThread DeserializeThread(System.Text.Json.JsonElement serializedThread,
-        System.Text.Json.JsonSerializerOptions? jsonSerializerOptions = null) => new SimpleTestAgentThread();
+    public override ValueTask<AgentThread> DeserializeThreadAsync(System.Text.Json.JsonElement serializedThread,
+        System.Text.Json.JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => new(new SimpleTestAgentThread());

-    protected override Task<AgentRunResponse> RunCoreAsync(
+    protected override Task<AgentResponse> RunCoreAsync(
         IEnumerable<ChatMessage> messages,
         AgentThread? thread = null,
         AgentRunOptions? options = null,
@@ -157,10 +157,10 @@ protected override Task<AgentRunResponse> RunCoreAsync(
     {
         var lastMessage = messages.LastOrDefault();
         var responseMessage = new ChatMessage(ChatRole.Assistant, $"Echo: {lastMessage?.Text ?? "no message"}");
-        return Task.FromResult(new AgentRunResponse(responseMessage));
+        return Task.FromResult(new AgentResponse(responseMessage));
     }

-    protected override async IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(
+    protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(
         IEnumerable<ChatMessage> messages,
         AgentThread? thread = null,
         AgentRunOptions? options = null,
@@ -174,14 +174,14 @@ protected override async IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreami
         string messageId = Guid.NewGuid().ToString("N");

         // Yield role first
-        yield return new AgentRunResponseUpdate(ChatRole.Assistant, this.Name)
+        yield return new AgentResponseUpdate(ChatRole.Assistant, this.Name)
         {
             AuthorName = this.Name,
             MessageId = messageId
         };

         // Then yield content
-        yield return new AgentRunResponseUpdate(ChatRole.Assistant, responseText)
+        yield return new AgentResponseUpdate(ChatRole.Assistant, responseText)
         {
             AuthorName = this.Name,
             MessageId = messageId
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/MessageMergerTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/MessageMergerTests.cs
index 793d01673e..93448aa327 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/MessageMergerTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/MessageMergerTests.cs
@@ -23,12 +23,12 @@ public void Test_MessageMerger_AssemblesMessage()

         MessageMerger merger = new();

-        foreach (AgentRunResponseUpdate update in "Hello Agent Framework Workflows!".ToAgentRunStream(authorName: TestAuthorName1, agentId: TestAgentId1, messageId: messageId, createdAt: creationTime, responseId: responseId))
+        foreach (AgentResponseUpdate update in "Hello Agent Framework Workflows!".ToAgentRunStream(authorName: TestAuthorName1, agentId: TestAgentId1, messageId: messageId, createdAt: creationTime, responseId: responseId))
         {
             merger.AddUpdate(update);
         }

-        AgentRunResponse response = merger.ComputeMerged(responseId);
+        AgentResponse response = merger.ComputeMerged(responseId);

         response.Messages.Should().HaveCount(1);
         response.Messages[0].Role.Should().Be(ChatRole.Assistant);
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs
index 5eb8696221..fab0c2bc3d 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs
@@ -24,16 +24,16 @@ private sealed class TestExecutor() : Executor("TestExecutor")

     private sealed class TestAgent : AIAgent
     {
-        public override AgentThread GetNewThread()
+        public override ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default)
             => throw new NotImplementedException();

-        public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null)
+        public override ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
             => throw new NotImplementedException();

-        protected override Task<AgentRunResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
+        protected override Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
             throw new NotImplementedException();

-        protected override IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
+        protected override IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
             throw new NotImplementedException();
     }
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs
index a351c45b20..772d56bbcc 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs
@@ -39,9 +39,9 @@ public static async ValueTask RunAsync(TextWriter writer, IWorkflowExecutionEnvi
         {
             Debug.WriteLine($"{executorCompleted.ExecutorId}: {executorCompleted.Data}");
         }
-        else if (evt is AgentRunUpdateEvent update)
+        else if (evt is AgentResponseUpdateEvent update)
         {
-            AgentRunResponse response = update.AsResponse();
+            AgentResponse response = update.AsResponse();

             foreach (ChatMessage message in response.Messages)
             {
@@ -60,23 +60,23 @@ internal sealed class HelloAgent(string id = nameof(HelloAgent)) : AIAgent
     protected override string? IdCore => id;
     public override string? Name => id;

-    public override AgentThread GetNewThread()
-        => new HelloAgentThread();
+    public override ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default)
+        => new(new HelloAgentThread());

-    public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null)
-        => new HelloAgentThread();
+    public override ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
+        => new(new HelloAgentThread());

-    protected override async Task<AgentRunResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
+    protected override async Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
     {
-        IEnumerable<AgentRunResponseUpdate> update = [
+        IEnumerable<AgentResponseUpdate> update = [
             await this.RunCoreStreamingAsync(messages, thread, options, cancellationToken)
                       .SingleAsync(cancellationToken)
                       .ConfigureAwait(false)];

-        return update.ToAgentRunResponse();
+        return update.ToAgentResponse();
     }

-    protected override async IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
         yield return new(ChatRole.Assistant, "Hello World!")
         {
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/07_GroupChat_Workflow_HostAsAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/07_GroupChat_Workflow_HostAsAgent.cs
index 1739fe7a34..71844aff69 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/07_GroupChat_Workflow_HostAsAgent.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/07_GroupChat_Workflow_HostAsAgent.cs
@@ -19,8 +19,8 @@ public static async ValueTask RunAsync(TextWriter writer, IWorkflowExecutionEnvi

     for (int i = 0; i < numIterations; i++)
     {
-        AgentThread thread = agent.GetNewThread();
-        await foreach (AgentRunResponseUpdate update in agent.RunStreamingAsync(thread).ConfigureAwait(false))
+        AgentThread thread = await agent.GetNewThreadAsync();
+        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(thread).ConfigureAwait(false))
         {
             if (update.RawRepresentation is WorkflowEvent)
             {
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/10_Sequential_HostAsAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/10_Sequential_HostAsAgent.cs
index fc23d44155..6b87aabd97 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/10_Sequential_HostAsAgent.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/10_Sequential_HostAsAgent.cs
@@ -21,10 +21,10 @@ public static async ValueTask RunAsync(TextWriter writer, IWorkflowExecutionEnvi
     {
         AIAgent hostAgent = WorkflowInstance.AsAgent("echo-workflow", "EchoW", executionEnvironment: executionEnvironment);

-        AgentThread thread = hostAgent.GetNewThread();
+        AgentThread thread = await hostAgent.GetNewThreadAsync();
         foreach (string input in inputs)
         {
-            AgentRunResponse response;
+            AgentResponse response;
             ResponseContinuationToken? continuationToken = null;
             do
             {
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/11_Concurrent_HostAsAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/11_Concurrent_HostAsAgent.cs
index d47b90223c..dc939fda8b 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/11_Concurrent_HostAsAgent.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/11_Concurrent_HostAsAgent.cs
@@ -33,10 +33,10 @@ public static async ValueTask RunAsync(TextWriter writer, IWorkflowExecutionEnvi
     {
         AIAgent hostAgent = WorkflowInstance.AsAgent("echo-workflow", "EchoW", executionEnvironment: executionEnvironment);

-        AgentThread thread = hostAgent.GetNewThread();
+        AgentThread thread = await hostAgent.GetNewThreadAsync();
         foreach (string input in inputs)
         {
-            AgentRunResponse response;
+            AgentResponse response;
             ResponseContinuationToken? continuationToken = null;
             do
             {
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs
index c319a0ac32..5cf4e07120 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs
@@ -69,10 +69,10 @@ public static async ValueTask RunAsync(TextWriter writer, IWorkflowExecutionEnvi
     {
         AIAgent hostAgent = WorkflowInstance.AsAgent("echo-workflow", "EchoW", executionEnvironment: executionEnvironment);

-        AgentThread thread = hostAgent.GetNewThread();
+        AgentThread thread = await hostAgent.GetNewThreadAsync();
         foreach (string input in inputs)
         {
-            AgentRunResponse response;
+            AgentResponse response;
             ResponseContinuationToken? continuationToken = null;
             do
             {
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs
index ed9af701c6..ddd2b1fcdd 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs
@@ -51,32 +51,32 @@ static ChatMessage ToMessage(string text)
         return result;
     }

-    public override AgentThread GetNewThread()
-        => new TestAgentThread();
+    public override ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default)
+        => new(new TestAgentThread());

-    public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null)
-        => new TestAgentThread();
+    public override ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
+        => new(new TestAgentThread());

     public static TestAIAgent FromStrings(params string[] messages) => new(ToChatMessages(messages));

     public List<ChatMessage> Messages { get; } = Validate(messages) ?? [];

-    protected override Task<AgentRunResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
-        Task.FromResult(new AgentRunResponse(this.Messages)
+    protected override Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
+        Task.FromResult(new AgentResponse(this.Messages)
         {
             AgentId = this.Id,
             ResponseId = Guid.NewGuid().ToString("N")
         });

-    protected override async IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
         string responseId = Guid.NewGuid().ToString("N");

         foreach (ChatMessage message in this.Messages)
         {
             foreach (AIContent content in message.Contents)
             {
-                yield return new AgentRunResponseUpdate()
+                yield return new AgentResponseUpdate()
                 {
                     AgentId = this.Id,
                     MessageId = message.MessageId,
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs
index 422d7a16ba..b971736b74 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs
@@ -16,15 +16,13 @@ internal class TestEchoAgent(string? id = null, string? name = null, string? pre
     protected override string? IdCore => id;
     public override string? Name => name ?? base.Name;

-    public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null)
+    public override async ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
     {
-        return serializedThread.Deserialize<EchoAgentThread>(jsonSerializerOptions) ?? this.GetNewThread();
+        return serializedThread.Deserialize<EchoAgentThread>(jsonSerializerOptions) ?? await this.GetNewThreadAsync(cancellationToken);
     }

-    public override AgentThread GetNewThread()
-    {
-        return new EchoAgentThread();
-    }
+    public override ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default) =>
+        new(new EchoAgentThread());

     private static ChatMessage UpdateThread(ChatMessage message, InMemoryAgentThread? thread = null)
     {
@@ -60,9 +58,9 @@ protected virtual IEnumerable<ChatMessage> GetEpilogueMessages(AgentRunOptions?
         return [];
     }

-    protected override Task<AgentRunResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
+    protected override Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
     {
-        AgentRunResponse result =
+        AgentResponse result =
             new(this.EchoMessages(messages, thread, options).ToList())
             {
                 AgentId = this.Id,
@@ -73,7 +71,7 @@ protected override Task<AgentRunResponse> RunCoreAsync(IEnumerable
         return Task.FromResult(result);
     }

-    protected override async IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
         string responseId = Guid.NewGuid().ToString("N");
diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs
index be3c96d9f5..eed0d72dac 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs
@@ -41,23 +41,23 @@ public Thread(JsonElement serializedThread, JsonSerializerOptions? jsonSerialize
         {
         }
     }

-    public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null)
+    public override ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
     {
-        return new Thread(serializedThread, jsonSerializerOptions);
+        return new(new Thread(serializedThread, jsonSerializerOptions));
     }

-    public override AgentThread GetNewThread()
+    public override ValueTask<AgentThread> GetNewThreadAsync(CancellationToken cancellationToken = default)
     {
-        return new Thread();
+        return new(new Thread());
     }

-    protected override async Task<AgentRunResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
+    protected override async Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
     {
         return await this.RunStreamingAsync(messages, thread, options, cancellationToken)
-            .ToAgentRunResponseAsync(cancellationToken);
+            .ToAgentResponseAsync(cancellationToken);
     }

-    protected override async IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
         const string ErrorMessage = "Simulated agent failure.";
         if (failByThrowing)
@@ -65,7 +65,7 @@ protected override async IAsyncEnumerable<AgentRunResponseUpdate> RunCoreStreami
             throw new ExpectedException(ErrorMessage);
         }

-        yield return new AgentRunResponseUpdate(ChatRole.Assistant, [new ErrorContent(ErrorMessage)]);
+        yield return new AgentResponseUpdate(ChatRole.Assistant, [new ErrorContent(ErrorMessage)]);
     }
 }
@@ -91,13 +91,13 @@ public async Task Test_AsAgent_ErrorContentStreamedOutAsync(bool includeExceptio
         Workflow workflow = CreateWorkflow(failByThrowing);

         // Act
-        List<AgentRunResponseUpdate> updates = await workflow.AsAgent("WorkflowAgent", includeExceptionDetails: includeExceptionDetails)
+        List<AgentResponseUpdate> updates = await workflow.AsAgent("WorkflowAgent", includeExceptionDetails: includeExceptionDetails)
             .RunStreamingAsync(new ChatMessage(ChatRole.User, "Hello"))
             .ToListAsync();

         // Assert
         bool hadErrorContent = false;
-        foreach (AgentRunResponseUpdate update in updates)
+        foreach (AgentResponseUpdate update in updates)
         {
             if (update.Contents.Any())
             {
diff --git a/python/.cspell.json b/python/.cspell.json
index 3fea304d38..73588b3b35 100644
--- a/python/.cspell.json
+++ b/python/.cspell.json
@@ -25,6 +25,7 @@
   "words": [
     "aeiou",
     "aiplatform",
+    "agui",
     "azuredocindex",
     "azuredocs",
     "azurefunctions",
diff --git a/python/.vscode/settings.json b/python/.vscode/settings.json
index 47da1de9e4..181b926ac0 100644
--- a/python/.vscode/settings.json
+++ b/python/.vscode/settings.json
@@ -18,9 +18,6 @@
   },
   "python.analysis.autoFormatStrings": true,
   "python.analysis.importFormat": "relative",
-  "python.analysis.exclude": [
-    "samples/semantic-kernel-migration"
-  ],
   "python.analysis.packageIndexDepths": [
     {
       "name": "agent_framework",
diff --git a/python/CHANGELOG.md b/python/CHANGELOG.md
index c1099e7ba7..a05b0a72a5 100644
--- a/python/CHANGELOG.md
+++ b/python/CHANGELOG.md
@@ -7,6 +7,34 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## [Unreleased]

+## [1.0.0b260114] - 2026-01-14
+
+### Added
+
+- **agent-framework-azure-ai**: Create/Get
Agent API for Azure V2 ([#3059](https://github.com/microsoft/agent-framework/pull/3059)) by @moonbox3 +- **agent-framework-declarative**: Add declarative workflow runtime ([#2815](https://github.com/microsoft/agent-framework/pull/2815)) by @emattson +- **agent-framework-ag-ui**: Add dependencies param to ag-ui FastAPI endpoint ([#3191](https://github.com/microsoft/agent-framework/pull/3191)) by @emattson +- **agent-framework-ag-ui**: Add Pydantic request model and OpenAPI tags support to AG-UI FastAPI endpoint ([#2522](https://github.com/microsoft/agent-framework/pull/2522)) by @claude89757 +- **agent-framework-core**: Add tool call/result content types and update connectors and samples ([#2971](https://github.com/microsoft/agent-framework/pull/2971)) by @moonbox3 +- **agent-framework-core**: Add more specific exceptions to Workflow ([#3188](https://github.com/microsoft/agent-framework/pull/3188)) by @taochenms + +### Changed + +- **agent-framework-core**: [BREAKING] Refactor orchestrations ([#3023](https://github.com/microsoft/agent-framework/pull/3023)) by @taochenms +- **agent-framework-core**: [BREAKING] Introducing Options as TypedDict and Generic ([#3140](https://github.com/microsoft/agent-framework/pull/3140)) by @eavanvalkenburg +- **agent-framework-core**: [BREAKING] Removed display_name, renamed context_providers, middleware and AggregateContextProvider ([#3139](https://github.com/microsoft/agent-framework/pull/3139)) by @eavanvalkenburg +- **agent-framework-core**: MCP Improvements: improved connection loss behavior, pagination for loading and a param to control representation ([#3154](https://github.com/microsoft/agent-framework/pull/3154)) by @eavanvalkenburg +- **agent-framework-azure-ai**: Azure AI direct A2A endpoint support ([#3127](https://github.com/microsoft/agent-framework/pull/3127)) by @moonbox3 + +### Fixed + +- **agent-framework-anthropic**: Fix duplicate ToolCallStartEvent in streaming tool calls ([#3051](https://github.com/microsoft/agent-framework/pull/3051)) by @emattson +- **agent-framework-anthropic**: Fix Anthropic streaming response bugs ([#3141](https://github.com/microsoft/agent-framework/pull/3141)) by @eavanvalkenburg +- **agent-framework-ag-ui**: Execute tools with approval_mode, fix shared state, code cleanup ([#3079](https://github.com/microsoft/agent-framework/pull/3079)) by @emattson +- **agent-framework-azure-ai**: Fix AzureAIClient tool call bug for AG-UI use ([#3148](https://github.com/microsoft/agent-framework/pull/3148)) by @emattson +- **agent-framework-core**: Fix MCPStreamableHTTPTool to use new streamable_http_client API ([#3088](https://github.com/microsoft/agent-framework/pull/3088)) by @Copilot +- **agent-framework-core**: Multiple bug fixes ([#3150](https://github.com/microsoft/agent-framework/pull/3150)) by @eavanvalkenburg + ## [1.0.0b260107] - 2026-01-07 ### Added diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index cd85509a40..ba534436d6 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -26,8 +26,8 @@ from a2a.types import Part as A2APart from a2a.types import Role as A2ARole from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, BaseAgent, ChatMessage, @@ -189,15 +189,15 @@ async def __aexit__( async def run( self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + messages: str | ChatMessage 
| Sequence[str | ChatMessage] | None = None, *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: """Get a response from the agent. This method returns the final result of the agent's execution - as a single AgentRunResponse object. The caller is blocked until + as a single AgentResponse object. The caller is blocked until the final result is available. Args: @@ -212,19 +212,19 @@ async def run( """ # Collect all updates and use framework to consolidate updates into response updates = [update async for update in self.run_stream(messages, thread=thread, **kwargs)] - return AgentRunResponse.from_agent_run_response_updates(updates) + return AgentResponse.from_agent_run_response_updates(updates) async def run_stream( self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: """Run the agent as a stream. This method will return the intermediate steps and final results of the - agent's execution as a stream of AgentRunResponseUpdate objects to the caller. + agent's execution as a stream of AgentResponseUpdate objects to the caller. Args: messages: The message(s) to send to the agent. @@ -245,7 +245,7 @@ async def run_stream( if isinstance(item, Message): # Process A2A Message contents = self._parse_contents_from_a2a(item.parts) - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=contents, role=Role.ASSISTANT if item.role == A2ARole.agent else Role.USER, response_id=str(getattr(item, "message_id", uuid.uuid4())), @@ -260,7 +260,7 @@ async def run_stream( for message in task_messages: # Use the artifact's ID from raw_representation as message_id for unique identification artifact_id = getattr(message.raw_representation, "artifact_id", None) - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=message.contents, role=message.role, response_id=task.id, @@ -269,7 +269,7 @@ async def run_stream( ) else: # Empty task - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=[], role=Role.ASSISTANT, response_id=task.id, diff --git a/python/packages/a2a/pyproject.toml b/python/packages/a2a/pyproject.toml index b9935fb9c3..1139d16e3b 100644 --- a/python/packages/a2a/pyproject.toml +++ b/python/packages/a2a/pyproject.toml @@ -4,7 +4,7 @@ description = "A2A integration for Microsoft Agent Framework." 
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index 58ab18fee4..5d77345b20 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -21,8 +21,8 @@ ) from a2a.types import Role as A2ARole from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, ChatMessage, DataContent, ErrorContent, @@ -131,7 +131,7 @@ async def test_run_with_message_response(a2a_agent: A2AAgent, mock_a2a_client: M response = await a2a_agent.run("Hello agent") - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert len(response.messages) == 1 assert response.messages[0].role == Role.ASSISTANT assert response.messages[0].text == "Hello from agent!" @@ -146,7 +146,7 @@ async def test_run_with_task_response_single_artifact(a2a_agent: A2AAgent, mock_ response = await a2a_agent.run("Generate a report") - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert len(response.messages) == 1 assert response.messages[0].role == Role.ASSISTANT assert response.messages[0].text == "Generated report content" @@ -165,7 +165,7 @@ async def test_run_with_task_response_multiple_artifacts(a2a_agent: A2AAgent, mo response = await a2a_agent.run("Generate multiple outputs") - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert len(response.messages) == 3 assert response.messages[0].text == "First artifact content" @@ -185,7 +185,7 @@ async def test_run_with_task_response_no_artifacts(a2a_agent: A2AAgent, mock_a2a response = await a2a_agent.run("Do something with no output") - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert response.response_id == "task-empty" @@ -357,13 +357,13 @@ async def test_run_stream_with_message_response(a2a_agent: A2AAgent, mock_a2a_cl mock_a2a_client.add_message_response("msg-stream-123", "Streaming response from agent!", "agent") # Collect streaming updates - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in a2a_agent.run_stream("Hello agent"): updates.append(update) # Verify streaming response assert len(updates) == 1 - assert isinstance(updates[0], AgentRunResponseUpdate) + assert isinstance(updates[0], AgentResponseUpdate) assert updates[0].role == Role.ASSISTANT assert len(updates[0].contents) == 1 diff --git a/python/packages/ag-ui/README.md b/python/packages/ag-ui/README.md index 1e3d6b567f..ec5602cef9 100644 --- a/python/packages/ag-ui/README.md +++ b/python/packages/ag-ui/README.md @@ -82,6 +82,53 @@ This integration supports all 7 AG-UI features: 6. **Shared State**: Bidirectional state sync between client and server 7. **Predictive State Updates**: Stream tool arguments as optimistic state updates during execution +## Security: Authentication & Authorization + +The AG-UI endpoint does not enforce authentication by default. **For production deployments, you should add authentication** using FastAPI's dependency injection system via the `dependencies` parameter. 
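+
+For instance, any bearer-token (JWT) check can guard the endpoint through the same `dependencies` mechanism. The sketch below is illustrative only: it assumes PyJWT and a symmetric secret in an `AG_UI_JWT_SECRET` environment variable, and the algorithm and claim validation will depend on your identity provider.
+
+```python
+import os
+
+import jwt  # PyJWT, assumed here purely for illustration
+from fastapi import Depends, HTTPException
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+
+bearer_scheme = HTTPBearer(auto_error=False)
+JWT_SECRET = os.environ.get("AG_UI_JWT_SECRET", "")
+
+async def verify_jwt(
+    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
+) -> None:
+    """Reject the request unless it carries a valid bearer token."""
+    if credentials is None:
+        raise HTTPException(status_code=401, detail="Missing bearer token")
+    try:
+        # Validate signature and expiry; add audience/issuer checks as needed.
+        jwt.decode(credentials.credentials, JWT_SECRET, algorithms=["HS256"])
+    except jwt.InvalidTokenError:
+        raise HTTPException(status_code=401, detail="Invalid bearer token")
+
+# Register it the same way as the API key example below:
+# add_agent_framework_fastapi_endpoint(app, agent, "/", dependencies=[Depends(verify_jwt)])
+```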
+ +### API Key Authentication Example + +```python +import os +from fastapi import Depends, FastAPI, HTTPException, Security +from fastapi.security import APIKeyHeader +from agent_framework import ChatAgent +from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint + +# Configure API key authentication +API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False) +EXPECTED_API_KEY = os.environ.get("AG_UI_API_KEY") + +async def verify_api_key(api_key: str | None = Security(API_KEY_HEADER)) -> None: + """Verify the API key provided in the request header.""" + if not api_key or api_key != EXPECTED_API_KEY: + raise HTTPException(status_code=401, detail="Invalid or missing API key") + +# Create agent and app +agent = ChatAgent(name="my_agent", instructions="...", chat_client=...) +app = FastAPI() + +# Register endpoint WITH authentication +add_agent_framework_fastapi_endpoint( + app, + agent, + "/", + dependencies=[Depends(verify_api_key)], # Authentication enforced here +) +``` + +### Other Authentication Options + +The `dependencies` parameter accepts any FastAPI dependency, enabling integration with: + +- **OAuth 2.0 / OpenID Connect** - Use `fastapi.security.OAuth2PasswordBearer` +- **JWT Tokens** - Validate tokens with libraries like `python-jose` +- **Azure AD / Entra ID** - Use `azure-identity` for Microsoft identity platform +- **Rate Limiting** - Add request throttling dependencies +- **Custom Authentication** - Implement your organization's auth requirements + +For a complete authentication example, see [getting_started/server.py](getting_started/server.py). + ## Architecture The package uses a clean, orchestrator-based architecture: diff --git a/python/packages/ag-ui/agent_framework_ag_ui/__init__.py b/python/packages/ag-ui/agent_framework_ag_ui/__init__.py index 143f2499a0..c6dc575d36 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/__init__.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/__init__.py @@ -16,22 +16,32 @@ from ._endpoint import add_agent_framework_fastapi_endpoint from ._event_converters import AGUIEventConverter from ._http_service import AGUIHttpService +from ._types import AgentState, AGUIChatOptions, AGUIRequest, PredictStateConfig, RunMetadata try: __version__ = importlib.metadata.version(__name__) except importlib.metadata.PackageNotFoundError: __version__ = "0.0.0" +# Default OpenAPI tags for AG-UI endpoints +DEFAULT_TAGS = ["AG-UI"] + __all__ = [ "AgentFrameworkAgent", "add_agent_framework_fastapi_endpoint", "AGUIChatClient", + "AGUIChatOptions", "AGUIEventConverter", "AGUIHttpService", + "AGUIRequest", + "AgentState", "ConfirmationStrategy", "DefaultConfirmationStrategy", + "PredictStateConfig", + "RunMetadata", "TaskPlannerConfirmationStrategy", "RecipeConfirmationStrategy", "DocumentWriterConfirmationStrategy", + "DEFAULT_TAGS", "__version__", ] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py index 23860150be..806f5ab1bb 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py @@ -24,6 +24,7 @@ def __init__( self, state_schema: Any | None = None, predict_state_config: dict[str, dict[str, str]] | None = None, + use_service_thread: bool = False, require_confirmation: bool = True, ): """Initialize agent configuration. 
@@ -31,10 +32,12 @@ def __init__( Args: state_schema: Optional state schema for state management; accepts dict or Pydantic model/class predict_state_config: Configuration for predictive state updates + use_service_thread: Whether the agent thread is service-managed require_confirmation: Whether predictive updates require confirmation """ self.state_schema = self._normalize_state_schema(state_schema) self.predict_state_config = predict_state_config or {} + self.use_service_thread = use_service_thread self.require_confirmation = require_confirmation @staticmethod @@ -86,6 +89,7 @@ def __init__( predict_state_config: dict[str, dict[str, str]] | None = None, require_confirmation: bool = True, orchestrators: list[Orchestrator] | None = None, + use_service_thread: bool = False, confirmation_strategy: ConfirmationStrategy | None = None, ): """Initialize the AG-UI compatible agent wrapper. @@ -101,6 +105,7 @@ def __init__( Set to False for agentic generative UI that updates automatically. orchestrators: Custom orchestrators (auto-configured if None). Orchestrators are checked in order; first match handles the request. + use_service_thread: Whether the agent thread is service-managed. confirmation_strategy: Strategy for generating confirmation messages. Defaults to DefaultConfirmationStrategy if None. """ @@ -111,6 +116,7 @@ def __init__( self.config = AgentConfig( state_schema=state_schema, predict_state_config=predict_state_config, + use_service_thread=use_service_thread, require_confirmation=require_confirmation, ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index db2f160a9d..e31036803c 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -4,17 +4,17 @@ import json import logging +import sys import uuid from collections.abc import AsyncIterable, MutableSequence from functools import wraps -from typing import Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, cast import httpx from agent_framework import ( AIFunction, BaseChatClient, ChatMessage, - ChatOptions, ChatResponse, ChatResponseUpdate, DataContent, @@ -30,6 +30,26 @@ from ._message_adapters import agent_framework_messages_to_agui from ._utils import convert_tools_to_agui_format +if TYPE_CHECKING: + from ._types import AGUIChatOptions + +from typing import TypedDict + +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + +if sys.version_info >= (3, 11): + from typing import Self # pragma: no cover +else: + from typing_extensions import Self # pragma: no cover + logger: logging.Logger = logging.getLogger(__name__) @@ -55,7 +75,14 @@ def _unwrap_server_function_call_contents(contents: MutableSequence[Contents | d contents[idx] = content.function_call_content # type: ignore[assignment] -TBaseChatClient = TypeVar("TBaseChatClient", bound=type[BaseChatClient]) +TBaseChatClient = TypeVar("TBaseChatClient", bound=type[BaseChatClient[Any]]) + +TAGUIChatOptions = TypeVar( + "TAGUIChatOptions", + bound=TypedDict, # type: ignore[valid-type] + default="AGUIChatOptions", + covariant=True, +) def _apply_server_function_call_unwrap(chat_client: TBaseChatClient) -> TBaseChatClient: @@ -91,7 +118,7 @@ async def response_wrapper(self, *args: 
Any, **kwargs: Any) -> ChatResponse: @use_function_invocation @use_instrumentation @use_chat_middleware -class AGUIChatClient(BaseChatClient): +class AGUIChatClient(BaseChatClient[TAGUIChatOptions], Generic[TAGUIChatOptions]): """Chat client for communicating with AG-UI compliant servers. This client implements the BaseChatClient interface and automatically handles: @@ -168,6 +195,19 @@ class AGUIChatClient(BaseChatClient): async with AGUIChatClient(endpoint="http://localhost:8888/") as client: response = await client.get_response("Hello!") print(response.messages[0].text) + + Using custom ChatOptions with type safety: + + .. code-block:: python + + from typing import TypedDict + from agent_framework_ag_ui import AGUIChatClient, AGUIChatOptions + + class MyOptions(AGUIChatOptions, total=False): + my_custom_option: str + + client: AGUIChatClient[MyOptions] = AGUIChatClient(endpoint="http://localhost:8888/") + response = await client.get_response("Hello", options={"my_custom_option": "value"}) """ OTEL_PROVIDER_NAME = "agui" @@ -201,7 +241,7 @@ async def close(self) -> None: """Close the HTTP client.""" await self._http_service.close() - async def __aenter__(self) -> "AGUIChatClient": + async def __aenter__(self) -> Self: """Enter async context manager.""" return self @@ -280,36 +320,38 @@ def _convert_messages_to_agui_format(self, messages: list[ChatMessage]) -> list[ """ return agent_framework_messages_to_agui(messages) - def _get_thread_id(self, chat_options: ChatOptions) -> str: + def _get_thread_id(self, options: dict[str, Any]) -> str: """Get or generate thread ID from chat options. Args: - chat_options: Chat options containing metadata + options: Chat options containing metadata Returns: Thread ID string """ thread_id = None - if chat_options.metadata: - thread_id = chat_options.metadata.get("thread_id") + metadata = options.get("metadata") + if metadata: + thread_id = metadata.get("thread_id") if not thread_id: thread_id = f"thread_{uuid.uuid4().hex}" return thread_id + @override async def _inner_get_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> ChatResponse: """Internal method to get non-streaming response. Keyword Args: messages: List of chat messages - chat_options: Chat options for the request + options: Chat options for the request **kwargs: Additional keyword arguments Returns: @@ -318,23 +360,24 @@ async def _inner_get_response( return await ChatResponse.from_chat_response_generator( self._inner_get_streaming_response( messages=messages, - chat_options=chat_options, + options=options, **kwargs, ) ) + @override async def _inner_get_streaming_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: """Internal method to get streaming response. 
Keyword Args: messages: List of chat messages - chat_options: Chat options for the request + options: Chat options for the request **kwargs: Additional keyword arguments Yields: @@ -342,20 +385,21 @@ async def _inner_get_streaming_response( """ messages_to_send, state = self._extract_state_from_messages(messages) - thread_id = self._get_thread_id(chat_options) + thread_id = self._get_thread_id(options) run_id = f"run_{uuid.uuid4().hex}" agui_messages = self._convert_messages_to_agui_format(messages_to_send) # Send client tools to server so LLM knows about them # Client tools execute via ChatAgent's @use_function_invocation wrapper - agui_tools = convert_tools_to_agui_format(chat_options.tools) + agui_tools = convert_tools_to_agui_format(options.get("tools")) # Build set of client tool names (matches .NET clientToolSet) # Used to distinguish client vs server tools in response stream client_tool_set: set[str] = set() - if chat_options.tools: - for tool in chat_options.tools: + tools = options.get("tools") + if tools: + for tool in tools: if hasattr(tool, "name"): client_tool_set.add(tool.name) # type: ignore[arg-type] self._last_client_tool_set = client_tool_set # type: ignore[attr-defined] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py b/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py index eedf88db14..7948d4f935 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py @@ -4,14 +4,17 @@ import copy import logging +from collections.abc import Sequence from typing import Any from ag_ui.encoder import EventEncoder from agent_framework import AgentProtocol -from fastapi import FastAPI, Request +from fastapi import FastAPI +from fastapi.params import Depends from fastapi.responses import StreamingResponse from ._agent import AgentFrameworkAgent +from ._types import AGUIRequest logger = logging.getLogger(__name__) @@ -24,6 +27,8 @@ def add_agent_framework_fastapi_endpoint( predict_state_config: dict[str, dict[str, str]] | None = None, allow_origins: list[str] | None = None, default_state: dict[str, Any] | None = None, + tags: list[str] | None = None, + dependencies: Sequence[Depends] | None = None, ) -> None: """Add an AG-UI endpoint to a FastAPI app. @@ -36,6 +41,11 @@ def add_agent_framework_fastapi_endpoint( Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}} allow_origins: CORS origins (not yet implemented) default_state: Optional initial state to seed when the client does not provide state keys + tags: OpenAPI tags for endpoint categorization (defaults to ["AG-UI"]) + dependencies: Optional FastAPI dependencies for authentication/authorization. + These dependencies run before the endpoint handler. Use this to add + authentication checks, rate limiting, or other middleware-like behavior. + Example: `dependencies=[Depends(verify_api_key)]` """ if isinstance(agent, AgentProtocol): wrapped_agent = AgentFrameworkAgent( @@ -46,15 +56,15 @@ def add_agent_framework_fastapi_endpoint( else: wrapped_agent = agent - @app.post(path) - async def agent_endpoint(request: Request): # type: ignore[misc] + @app.post(path, tags=tags or ["AG-UI"], dependencies=dependencies) # type: ignore[arg-type] + async def agent_endpoint(request_body: AGUIRequest): # type: ignore[misc] """Handle AG-UI agent requests. Note: Function is accessed via FastAPI's decorator registration, despite appearing unused to static analysis. 
""" try: - input_data = await request.json() + input_data = request_body.model_dump(exclude_none=True) if default_state: state = input_data.setdefault("state", {}) for key, value in default_state.items(): diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index 812c99064d..ddf3ebba01 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -24,7 +24,7 @@ ToolCallStartEvent, ) from agent_framework import ( - AgentRunResponseUpdate, + AgentResponseUpdate, FunctionApprovalRequestContent, FunctionCallContent, FunctionResultContent, @@ -81,9 +81,9 @@ def __init__( self.should_stop_after_confirm: bool = False # Flag to stop run after confirm_changes self.suppressed_summary: str = "" # Store LLM summary to show after confirmation - async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[BaseEvent]: + async def from_agent_run_update(self, update: AgentResponseUpdate) -> list[BaseEvent]: """ - Convert an AgentRunResponseUpdate to AG-UI events. + Convert an AgentResponseUpdate to AG-UI events. Args: update: The agent run update to convert. diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py index 977c276627..0f86516448 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py @@ -3,57 +3,85 @@ """Tool handling helpers.""" import logging -from typing import Any +from typing import TYPE_CHECKING, Any -from agent_framework import BaseChatClient, ChatAgent +from agent_framework import BaseChatClient + +if TYPE_CHECKING: + from agent_framework import AgentProtocol logger = logging.getLogger(__name__) -def collect_server_tools(agent: Any) -> list[Any]: - """Collect server tools from ChatAgent or duck-typed agent.""" - if isinstance(agent, ChatAgent): - tools_from_agent = agent.chat_options.tools - server_tools = list(tools_from_agent) if tools_from_agent else [] - logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools") - for tool in server_tools: - tool_name = getattr(tool, "name", "unknown") - approval_mode = getattr(tool, "approval_mode", None) - logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}") - return server_tools - - try: - chat_options_attr = getattr(agent, "chat_options", None) - if chat_options_attr is not None: - return getattr(chat_options_attr, "tools", None) or [] - except AttributeError: +def _collect_mcp_tool_functions(mcp_tools: list[Any]) -> list[Any]: + """Extract functions from connected MCP tools. + + Args: + mcp_tools: List of MCP tool instances. + + Returns: + List of functions from connected MCP tools. + """ + functions: list[Any] = [] + for mcp_tool in mcp_tools: + if getattr(mcp_tool, "is_connected", False) and hasattr(mcp_tool, "functions"): + functions.extend(mcp_tool.functions) + return functions + + +def collect_server_tools(agent: "AgentProtocol") -> list[Any]: + """Collect server tools from an agent. + + This includes both regular tools from default_options and MCP tools. + MCP tools are stored separately for lifecycle management but their + functions need to be included for tool execution during approval flows. + + Args: + agent: Agent instance to collect tools from. Works with ChatAgent + or any agent with default_options and optional mcp_tools attributes. 
+ + Returns: + List of tools including both regular tools and connected MCP tool functions. + """ + # Get tools from default_options + default_options = getattr(agent, "default_options", None) + if default_options is None: return [] - return [] + tools_from_agent = default_options.get("tools") if isinstance(default_options, dict) else None + server_tools = list(tools_from_agent) if tools_from_agent else [] + + # Include functions from connected MCP tools (only available on ChatAgent) + mcp_tools = getattr(agent, "mcp_tools", None) + if mcp_tools: + server_tools.extend(_collect_mcp_tool_functions(mcp_tools)) + + logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools") + for tool in server_tools: + tool_name = getattr(tool, "name", "unknown") + approval_mode = getattr(tool, "approval_mode", None) + logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}") + return server_tools -def register_additional_client_tools(agent: Any, client_tools: list[Any] | None) -> None: - """Register client tools as additional declaration-only tools to avoid server execution.""" + +def register_additional_client_tools(agent: "AgentProtocol", client_tools: list[Any] | None) -> None: + """Register client tools as additional declaration-only tools to avoid server execution. + + Args: + agent: Agent instance to register tools on. Works with ChatAgent + or any agent with a chat_client attribute. + client_tools: List of client tools to register. + """ if not client_tools: return - if isinstance(agent, ChatAgent): - chat_client = agent.chat_client - if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None: - chat_client.function_invocation_configuration.additional_tools = client_tools - logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)") + chat_client = getattr(agent, "chat_client", None) + if chat_client is None: return - try: - chat_client_attr = getattr(agent, "chat_client", None) - if chat_client_attr is not None: - fic = getattr(chat_client_attr, "function_invocation_configuration", None) - if fic is not None: - fic.additional_tools = client_tools # type: ignore[attr-defined] - logger.debug( - f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)" - ) - except AttributeError: - return + if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None: + chat_client.function_invocation_configuration.additional_tools = client_tools + logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)") def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list[Any] | None: diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 3067e3e4a7..b5566f0aec 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -6,7 +6,7 @@ import logging import uuid from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Sequence from typing import TYPE_CHECKING, Any from ag_ui.core import ( @@ -53,11 +53,18 @@ merge_tools, register_additional_client_tools, ) -from ._utils import convert_agui_tools_to_agent_framework, generate_event_id, get_role_value +from ._utils import ( + convert_agui_tools_to_agent_framework, + 
generate_event_id, + get_conversation_id_from_update, + get_role_value, +) if TYPE_CHECKING: from ._agent import AgentConfig from ._confirmation_strategies import ConfirmationStrategy + from ._events import AgentFrameworkEventBridge + from ._orchestration._state_manager import StateManager logger = logging.getLogger(__name__) @@ -92,6 +99,8 @@ def __init__( self._last_message = None self._run_id: str | None = None self._thread_id: str | None = None + self._supplied_run_id: str | None = None + self._supplied_thread_id: str | None = None @property def messages(self): @@ -125,26 +134,66 @@ def last_message(self): self._last_message = self.messages[-1] return self._last_message + @property + def supplied_run_id(self) -> str | None: + """Get the supplied run ID, if any.""" + if self._supplied_run_id is None: + self._supplied_run_id = self.input_data.get("run_id") or self.input_data.get("runId") + return self._supplied_run_id + @property def run_id(self) -> str: - """Get or generate run ID.""" + """Get supplied run ID or generate a new run ID.""" + if self._run_id: + return self._run_id + + if self.supplied_run_id: + self._run_id = self.supplied_run_id + if self._run_id is None: - self._run_id = self.input_data.get("run_id") or self.input_data.get("runId") or str(uuid.uuid4()) - # This should never be None after the if block above, but satisfy type checkers - if self._run_id is None: # pragma: no cover - raise RuntimeError("Failed to initialize run_id") + self._run_id = str(uuid.uuid4()) + return self._run_id + @property + def supplied_thread_id(self) -> str | None: + """Get the supplied thread ID, if any.""" + if self._supplied_thread_id is None: + self._supplied_thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") + return self._supplied_thread_id + @property def thread_id(self) -> str: - """Get or generate thread ID.""" + """Get supplied thread ID or generate a new thread ID.""" + if self._thread_id: + return self._thread_id + + if self.supplied_thread_id: + self._thread_id = self.supplied_thread_id + if self._thread_id is None: - self._thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") or str(uuid.uuid4()) - # This should never be None after the if block above, but satisfy type checkers - if self._thread_id is None: # pragma: no cover - raise RuntimeError("Failed to initialize thread_id") + self._thread_id = str(uuid.uuid4()) + return self._thread_id + def update_run_id(self, new_run_id: str) -> None: + """Update the run ID in the context. + + Args: + new_run_id: The new run ID to set + """ + self._supplied_run_id = new_run_id + self._run_id = new_run_id + + def update_thread_id(self, new_thread_id: str) -> None: + """Update the thread ID in the context. + + Args: + new_thread_id: The new thread ID to set + """ + self._supplied_thread_id = new_thread_id + self._thread_id = new_thread_id + class Orchestrator(ABC): """Base orchestrator for agent execution flows.""" @@ -297,6 +346,28 @@ def can_handle(self, context: ExecutionContext) -> bool: """ return True + def _create_initial_events( + self, event_bridge: "AgentFrameworkEventBridge", state_manager: "StateManager" + ) -> Sequence[BaseEvent]: + """Generate initial events for the run. 
+ + Args: + event_bridge: Event bridge for creating events + state_manager: State manager that supplies the predictive and initial snapshot events + Returns: + Initial AG-UI events + """ + events: list[BaseEvent] = [event_bridge.create_run_started_event()] + + predict_event = state_manager.predict_state_event() + if predict_event: + events.append(predict_event) + + snapshot_event = state_manager.initial_snapshot_event(event_bridge) + if snapshot_event: + events.append(snapshot_event) + + return events + async def run( self, context: ExecutionContext, @@ -319,7 +390,7 @@ async def run( response_format = None if isinstance(context.agent, ChatAgent): - response_format = context.agent.chat_options.response_format + response_format = context.agent.default_options.get("response_format") skip_text_content = response_format is not None client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools")) @@ -342,17 +413,11 @@ async def run( approval_tool_name=approval_tool_name, ) - yield event_bridge.create_run_started_event() - - predict_event = state_manager.predict_state_event() - if predict_event: - yield predict_event - - snapshot_event = state_manager.initial_snapshot_event(event_bridge) - if snapshot_event: - yield snapshot_event + if context.config.use_service_thread: + thread = AgentThread(service_thread_id=context.supplied_thread_id) + else: + thread = AgentThread() - thread = AgentThread() thread.metadata = { # type: ignore[attr-defined] "ag_ui_thread_id": context.thread_id, "ag_ui_run_id": context.run_id, @@ -363,6 +428,8 @@ async def run( provider_messages = context.messages or [] snapshot_messages = context.snapshot_messages if not provider_messages: + for event in self._create_initial_events(event_bridge, state_manager): + yield event logger.warning("No messages provided in AG-UI input") yield event_bridge.create_run_finished_event() return @@ -434,10 +501,10 @@ async def run( run_kwargs: dict[str, Any] = { "thread": thread, "tools": tools_param, - "metadata": safe_metadata, + "options": {"metadata": safe_metadata}, } if safe_metadata: - run_kwargs["store"] = True + run_kwargs["options"]["store"] = True async def _resolve_approval_responses( messages: list[Any], @@ -554,13 +621,41 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap confirmation_message = strategy.on_state_rejected() message_id = generate_event_id() + for event in self._create_initial_events(event_bridge, state_manager): + yield event yield TextMessageStartEvent(message_id=message_id, role="assistant") yield TextMessageContentEvent(message_id=message_id, delta=confirmation_message) yield TextMessageEndEvent(message_id=message_id) yield event_bridge.create_run_finished_event() return + should_recreate_event_bridge = False async for update in context.agent.run_stream(messages_to_run, **run_kwargs): + conv_id = get_conversation_id_from_update(update) + if conv_id and conv_id != context.thread_id: + context.update_thread_id(conv_id) + should_recreate_event_bridge = True + + if update.response_id and update.response_id != context.run_id: + context.update_run_id(update.response_id) + should_recreate_event_bridge = True + + if should_recreate_event_bridge: + event_bridge = AgentFrameworkEventBridge( + run_id=context.run_id, + thread_id=context.thread_id, + predict_state_config=context.config.predict_state_config, + current_state=current_state, + skip_text_content=skip_text_content, + require_confirmation=context.config.require_confirmation, + approval_tool_name=approval_tool_name, + ) + should_recreate_event_bridge = False + + if update_count == 0: + for event
in self._create_initial_events(event_bridge, state_manager): + yield event + update_count += 1 logger.info(f"[STREAM] Received update #{update_count} from agent") if all_updates is not None: @@ -646,11 +741,11 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap yield end_event if response_format and all_updates: - from agent_framework import AgentRunResponse + from agent_framework import AgentResponse from pydantic import BaseModel logger.info(f"Processing structured output, update count: {len(all_updates)}") - final_response = AgentRunResponse.from_agent_run_response_updates( + final_response = AgentResponse.from_agent_run_response_updates( all_updates, output_format_type=response_format ) @@ -672,6 +767,11 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap yield TextMessageEndEvent(message_id=message_id) logger.info(f"Emitted conversational message with length={len(response_dict['message'])}") + if all_updates is not None and len(all_updates) == 0: + logger.info("No updates received from agent - emitting initial events") + for event in self._create_initial_events(event_bridge, state_manager): + yield event + logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}") if event_bridge.current_message_id: logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_types.py b/python/packages/ag-ui/agent_framework_ag_ui/_types.py index da7d80ea66..226abae692 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_types.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_types.py @@ -2,8 +2,25 @@ """Type definitions for AG-UI integration.""" +import sys from typing import Any, TypedDict +from agent_framework import ChatOptions + +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + +__all__ = [ + "AGUIChatOptions", + "AgentState", + "PredictStateConfig", + "RunMetadata", +] + +from pydantic import BaseModel, Field + class PredictStateConfig(TypedDict): """Configuration for predictive state updates.""" @@ -25,3 +42,97 @@ class AgentState(TypedDict): """Base state for AG-UI agents.""" messages: list[Any] | None + + +class AGUIRequest(BaseModel): + """Request model for AG-UI endpoints.""" + + messages: list[dict[str, Any]] = Field( + ..., + description="AG-UI format messages array", + ) + run_id: str | None = Field( + None, + description="Optional run identifier for tracking", + ) + thread_id: str | None = Field( + None, + description="Optional thread identifier for conversation context", + ) + state: dict[str, Any] | None = Field( + None, + description="Optional shared state for agentic generative UI", + ) + + +# region AG-UI Chat Options TypedDict + + +class AGUIChatOptions(ChatOptions, total=False): + """AG-UI protocol-specific chat options dict. + + Extends base ChatOptions for the AG-UI (Agent-UI) protocol. + AG-UI is a streaming protocol for connecting AI agents to user interfaces. + Options are forwarded to the remote AG-UI server. + + See: https://github.com/ag-ui/ag-ui-protocol + + Keys: + # Inherited from ChatOptions (forwarded to remote server): + model_id: The model identifier (forwarded as-is to server). + temperature: Sampling temperature. + top_p: Nucleus sampling parameter. + max_tokens: Maximum tokens to generate. + stop: Stop sequences. 
+ tools: List of tools - sent to server so LLM knows about client tools. + Server executes its own tools; client tools execute locally via + @use_function_invocation middleware. + tool_choice: How the model should use tools. + metadata: Metadata dict containing thread_id for conversation continuity. + + # Options with limited support (depends on remote server): + frequency_penalty: Forwarded if remote server supports it. + presence_penalty: Forwarded if remote server supports it. + seed: Forwarded if remote server supports it. + response_format: Forwarded if remote server supports it. + logit_bias: Forwarded if remote server supports it. + user: Forwarded if remote server supports it. + + # Options not typically used in AG-UI: + store: Not applicable for AG-UI protocol. + allow_multiple_tool_calls: Handled by underlying server. + + # AG-UI-specific options: + forward_props: Additional properties to forward to the AG-UI server. + Useful for passing custom parameters to specific server implementations. + context: Shared context/state to send to the server. + + Note: + AG-UI is a protocol bridge - actual option support depends on the + remote server implementation. The client sends all options to the + server, which decides how to handle them. + + Thread ID management: + - Pass ``thread_id`` in ``metadata`` to maintain conversation continuity + - If not provided, a new thread ID is auto-generated + """ + + # AG-UI-specific options + forward_props: dict[str, Any] + """Additional properties to forward to the AG-UI server.""" + + context: dict[str, Any] + """Shared context/state to send to the server.""" + + # ChatOptions fields not applicable for AG-UI + store: None # type: ignore[misc] + """Not applicable for AG-UI protocol.""" + + +AGUI_OPTION_TRANSLATIONS: dict[str, str] = {} +"""Maps ChatOptions keys to AG-UI parameter names (protocol uses standard names).""" + +TAGUIChatOptions = TypeVar("TAGUIChatOptions", bound=TypedDict, default="AGUIChatOptions", covariant=True) # type: ignore[valid-type] + + +# endregion diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py index c0da986308..9f42e24770 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py @@ -10,7 +10,7 @@ from datetime import date, datetime from typing import Any -from agent_framework import AIFunction, Role, ToolProtocol +from agent_framework import AgentResponseUpdate, AIFunction, ChatResponseUpdate, Role, ToolProtocol # Role mapping constants AGUI_TO_FRAMEWORK_ROLE: dict[str, Role] = { @@ -259,3 +259,17 @@ def convert_tools_to_agui_format( continue return results if results else None + + +def get_conversation_id_from_update(update: AgentResponseUpdate) -> str | None: + """Extract conversation ID from AgentResponseUpdate metadata. 
+
+    Args:
+        update: AgentResponseUpdate instance.
+
+    Returns:
+        Conversation ID if present, else None.
+    """
+    if isinstance(update.raw_representation, ChatResponseUpdate):
+        return update.raw_representation.conversation_id
+    return None
diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/README.md b/python/packages/ag-ui/agent_framework_ag_ui_examples/README.md
index 620f18dbbf..e9d6d4ed17 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui_examples/README.md
+++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/README.md
@@ -169,7 +169,7 @@ The package uses a clean, orchestrator-based architecture:
 - **AgentFrameworkAgent**: Lightweight wrapper that delegates to orchestrators
 - **Orchestrators**: Handle different execution flows (default, human-in-the-loop, etc.)
 - **Confirmation Strategies**: Domain-specific confirmation messages (extensible)
-- **AgentFrameworkEventBridge**: Converts AgentRunResponseUpdate to AG-UI events
+- **AgentFrameworkEventBridge**: Converts AgentResponseUpdate to AG-UI events
 - **Message Adapters**: Bidirectional conversion between AG-UI and Agent Framework message formats
 - **FastAPI Endpoint**: Streaming HTTP endpoint with Server-Sent Events (SSE)
 
@@ -198,10 +198,10 @@ def my_tool(param: str) -> str:
 def my_custom_agent(chat_client: ChatClientProtocol) -> AgentFrameworkAgent:
     """Create a custom agent with the specified chat client.
-    
+
     Args:
         chat_client: The chat client to use for the agent
-    
+
     Returns:
         A configured AgentFrameworkAgent instance
     """
@@ -211,7 +211,7 @@ def my_custom_agent(chat_client: ChatClientProtocol) -> AgentFrameworkAgent:
         chat_client=chat_client,
         tools=[my_tool],
     )
-    
+
     return AgentFrameworkAgent(
         agent=agent,
         name="MyCustomAgent",
@@ -302,13 +302,13 @@ from agent_framework.ag_ui import AgentFrameworkAgent, ConfirmationStrategy
 class CustomConfirmationStrategy(ConfirmationStrategy):
     def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str:
         return "Your custom approval message!"
-    
+
     def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str:
         return "Your custom rejection message!"
-    
+
     def on_state_confirmed(self) -> str:
         return "State changes confirmed!"
-    
+
     def on_state_rejected(self) -> str:
         return "State changes rejected!"
@@ -349,7 +349,7 @@ class MyCustomOrchestrator(Orchestrator):
     def can_handle(self, context: ExecutionContext) -> bool:
         # Return True if this orchestrator should handle the request
         return context.input_data.get("custom_mode") == True
-    
+
     async def run(self, context: ExecutionContext):
         # Custom execution logic
         yield RunStartedEvent(...)
diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/human_in_the_loop_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/human_in_the_loop_agent.py
index ab7a3533cd..dbfdab5272 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/human_in_the_loop_agent.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/human_in_the_loop_agent.py
@@ -3,6 +3,7 @@
 """Human-in-the-loop agent demonstrating step customization (Feature 5)."""
 
 from enum import Enum
+from typing import Any
 
 from agent_framework import ChatAgent, ChatClientProtocol, ai_function
 from pydantic import BaseModel, Field
@@ -42,7 +43,7 @@ def generate_task_steps(steps: list[TaskStep]) -> str:
     return f"Generated {len(steps)} execution steps for the task."
 
 
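+# NOTE: the factory below accepts ChatClientProtocol[Any] because ChatClientProtocol
+# is generic over its options TypedDict in this release. A hypothetical fully-typed
+# variant would pin the options type instead, e.g.:
+#
+#     def typed_agent(chat_client: ChatClientProtocol[ChatOptions]) -> ChatAgent[ChatOptions]:
+#         return ChatAgent[ChatOptions](name="typed_agent", chat_client=chat_client)
+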
-def human_in_the_loop_agent(chat_client: ChatClientProtocol) -> ChatAgent: +def human_in_the_loop_agent(chat_client: ChatClientProtocol[Any]) -> ChatAgent[Any]: """Create a human-in-the-loop agent using tool-based approach for predictive state. Args: diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/recipe_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/recipe_agent.py index 051937f2a9..05c42efb30 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/recipe_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/recipe_agent.py @@ -3,6 +3,7 @@ """Recipe agent example demonstrating shared state management (Feature 3).""" from enum import Enum +from typing import Any from agent_framework import ChatAgent, ChatClientProtocol, ai_function from agent_framework.ag_ui import AgentFrameworkAgent, RecipeConfirmationStrategy @@ -101,7 +102,7 @@ def update_recipe(recipe: Recipe) -> str: """ -def recipe_agent(chat_client: ChatClientProtocol) -> AgentFrameworkAgent: +def recipe_agent(chat_client: ChatClientProtocol[Any]) -> AgentFrameworkAgent: """Create a recipe agent with streaming state updates. Args: diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/research_assistant_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/research_assistant_agent.py index ad5c4f425c..52515bc0a4 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/research_assistant_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/research_assistant_agent.py @@ -3,6 +3,7 @@ """Example agent demonstrating agentic generative UI with custom events during execution.""" import asyncio +from typing import Any from agent_framework import ChatAgent, ChatClientProtocol, ai_function from agent_framework.ag_ui import AgentFrameworkAgent @@ -87,8 +88,8 @@ async def analyze_data(dataset: str) -> str: ) -def research_assistant_agent(chat_client: ChatClientProtocol) -> AgentFrameworkAgent: - """Create a research assistant agent with progress events. +def research_assistant_agent(chat_client: ChatClientProtocol[Any]) -> AgentFrameworkAgent: + """Create a research assistant agent. Args: chat_client: The chat client to use for the agent diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/simple_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/simple_agent.py index e4bffaea0d..3e72fd3a11 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/simple_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/simple_agent.py @@ -2,10 +2,12 @@ """Simple agentic chat example (Feature 1: Agentic Chat).""" +from typing import Any + from agent_framework import ChatAgent, ChatClientProtocol -def simple_agent(chat_client: ChatClientProtocol) -> ChatAgent: +def simple_agent(chat_client: ChatClientProtocol[Any]) -> ChatAgent[Any]: """Create a simple chat agent. Args: @@ -14,7 +16,7 @@ def simple_agent(chat_client: ChatClientProtocol) -> ChatAgent: Returns: A configured ChatAgent instance """ - return ChatAgent( + return ChatAgent[Any]( name="simple_chat_agent", instructions="You are a helpful assistant. 
Be concise and friendly.", chat_client=chat_client, diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_planner_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_planner_agent.py index 6609f06aa6..c79c36f511 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_planner_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_planner_agent.py @@ -2,6 +2,8 @@ """Example agent demonstrating human-in-the-loop with function approvals.""" +from typing import Any + from agent_framework import ChatAgent, ChatClientProtocol, ai_function from agent_framework.ag_ui import AgentFrameworkAgent, TaskPlannerConfirmationStrategy @@ -59,7 +61,7 @@ def book_meeting_room(room_name: str, date: str, start_time: str, end_time: str) ) -def task_planner_agent(chat_client: ChatClientProtocol) -> AgentFrameworkAgent: +def task_planner_agent(chat_client: ChatClientProtocol[Any]) -> AgentFrameworkAgent: """Create a task planner agent with user approval for actions. Args: diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py index 567dd348b4..572df2720b 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py @@ -52,7 +52,7 @@ def generate_task_steps(steps: list[TaskStep]) -> str: return "Steps generated." -def _create_task_steps_agent(chat_client: ChatClientProtocol) -> AgentFrameworkAgent: +def _create_task_steps_agent(chat_client: ChatClientProtocol[Any]) -> AgentFrameworkAgent: """Create the task steps agent using tool-based approach for streaming. Args: @@ -61,7 +61,7 @@ def _create_task_steps_agent(chat_client: ChatClientProtocol) -> AgentFrameworkA Returns: A configured AgentFrameworkAgent instance """ - agent = ChatAgent( + agent = ChatAgent[Any]( name="task_steps_agent", instructions="""You are a helpful assistant that breaks down tasks into actionable steps. @@ -331,7 +331,7 @@ async def run_agent(self, input_data: dict[str, Any]) -> AsyncGenerator[Any, Non yield run_finished_event -def task_steps_agent_wrapped(chat_client: ChatClientProtocol) -> TaskStepsAgentWithExecution: +def task_steps_agent_wrapped(chat_client: ChatClientProtocol[Any]) -> TaskStepsAgentWithExecution: """Create a task steps agent with execution simulation. 
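+
+    The returned wrapper drives its own execution simulation via ``run_agent``, so
+    it is registered like any other AG-UI agent; a rough sketch (the path name is
+    illustrative)::
+
+        add_agent_framework_fastapi_endpoint(app, task_steps_agent_wrapped(chat_client), "/task_steps")
+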
Args: diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py index 0a99e6f1a1..db1788fd25 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py @@ -2,11 +2,17 @@ """Example agent demonstrating Tool-based Generative UI (Feature 5).""" -from typing import Any +import sys +from typing import Any, TypedDict -from agent_framework import AIFunction, ChatAgent, ChatClientProtocol +from agent_framework import AIFunction, ChatAgent, ChatClientProtocol, ChatOptions from agent_framework.ag_ui import AgentFrameworkAgent +if sys.version_info >= (3, 13): + from typing import TypeVar # type: ignore # pragma: no cover +else: + from typing_extensions import TypeVar # type: ignore # pragma: no cover + # Declaration-only tools (func=None) - actual rendering happens on the client side generate_haiku = AIFunction[Any, str]( name="generate_haiku", @@ -150,15 +156,17 @@ For other requests, use the appropriate tool (create_chart, display_timeline, show_comparison_table). """ +TOptions = TypeVar("TOptions", bound=TypedDict, default="ChatOptions") # type: ignore[valid-type] + -def ui_generator_agent(chat_client: ChatClientProtocol) -> AgentFrameworkAgent: - """Create a UI generator agent with frontend rendering tools. +def ui_generator_agent(chat_client: ChatClientProtocol[TOptions]) -> AgentFrameworkAgent: + """Create a UI generator agent with custom React component rendering. Args: chat_client: The chat client to use for the agent Returns: - A configured AgentFrameworkAgent instance with UI generation tools + A configured AgentFrameworkAgent instance with UI generation capabilities """ agent = ChatAgent( name="ui_generator", @@ -166,7 +174,7 @@ def ui_generator_agent(chat_client: ChatClientProtocol) -> AgentFrameworkAgent: chat_client=chat_client, tools=[generate_haiku, create_chart, display_timeline, show_comparison_table], # Force tool usage - the LLM MUST call a tool, cannot respond with plain text - chat_options={"tool_choice": "required"}, + default_options={"tool_choice": "required"}, # type: ignore ) return AgentFrameworkAgent( diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/weather_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/weather_agent.py index 6edaa02616..32324d72eb 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/weather_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/weather_agent.py @@ -57,7 +57,7 @@ def get_forecast(location: str, days: int = 3) -> str: return f"{days}-day forecast for {location}:\n" + "\n".join(forecast) -def weather_agent(chat_client: ChatClientProtocol) -> ChatAgent: +def weather_agent(chat_client: ChatClientProtocol[Any]) -> ChatAgent[Any]: """Create a weather agent with get_weather and get_forecast tools. Args: @@ -66,7 +66,7 @@ def weather_agent(chat_client: ChatClientProtocol) -> ChatAgent: Returns: A configured ChatAgent instance with weather tools """ - return ChatAgent( + return ChatAgent[Any]( name="weather_agent", instructions=( "You are a helpful weather assistant. 
" diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py index ebfc42ea19..e71abe7507 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py @@ -4,6 +4,7 @@ import logging import os +from typing import TYPE_CHECKING import uvicorn from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint @@ -19,6 +20,10 @@ from ..agents.ui_generator_agent import ui_generator_agent from ..agents.weather_agent import weather_agent +if TYPE_CHECKING: + from agent_framework import ChatOptions + from agent_framework._clients import BaseChatClient + # Configure logging to file and console (disabled by default - set ENABLE_DEBUG_LOGGING=1 to enable) if os.getenv("ENABLE_DEBUG_LOGGING"): log_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "ag_ui_server.log") @@ -60,7 +65,7 @@ # Create a shared chat client for all agents # You can use different chat clients for different agents if needed -chat_client = AzureOpenAIChatClient() +chat_client: BaseChatClient[ChatOptions] = AzureOpenAIChatClient() # Agentic Chat - basic chat agent add_agent_framework_fastapi_endpoint( diff --git a/python/packages/ag-ui/getting_started/server.py b/python/packages/ag-ui/getting_started/server.py index e4ed669516..c8889126e9 100644 --- a/python/packages/ag-ui/getting_started/server.py +++ b/python/packages/ag-ui/getting_started/server.py @@ -9,7 +9,8 @@ from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint from agent_framework.azure import AzureOpenAIChatClient from dotenv import load_dotenv -from fastapi import FastAPI +from fastapi import Depends, FastAPI, HTTPException, Security +from fastapi.security import APIKeyHeader load_dotenv() @@ -31,6 +32,60 @@ raise ValueError("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME environment variable is required") +# ============================================================================ +# AUTHENTICATION EXAMPLE +# ============================================================================ +# This demonstrates how to secure the AG-UI endpoint with API key authentication. +# In production, you should use a more robust authentication mechanism such as: +# - OAuth 2.0 / OpenID Connect +# - JWT tokens with proper validation +# - Azure AD / Entra ID integration +# - Your organization's identity provider +# +# The API key should be stored securely (e.g., Azure Key Vault, environment variables) +# and rotated regularly. +# ============================================================================ + +# API key header configuration +API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False) + +# Get the expected API key from environment variable +# In production, use a secrets manager like Azure Key Vault +EXPECTED_API_KEY = os.environ.get("AG_UI_API_KEY") + + +async def verify_api_key(api_key: str | None = Security(API_KEY_HEADER)) -> None: + """Verify the API key provided in the request header. + + Args: + api_key: The API key from the X-API-Key header + + Raises: + HTTPException: If the API key is missing or invalid + """ + if not EXPECTED_API_KEY: + # If no API key is configured, log a warning but allow the request + # This maintains backward compatibility but warns about the security risk + logger.warning( + "AG_UI_API_KEY environment variable not set. " + "The endpoint is accessible without authentication. 
" + "Set AG_UI_API_KEY to enable API key authentication." + ) + return + + if not api_key: + raise HTTPException( + status_code=401, + detail="Missing API key. Provide X-API-Key header.", + ) + + if api_key != EXPECTED_API_KEY: + raise HTTPException( + status_code=403, + detail="Invalid API key.", + ) + + # Server-side tool (executes on server) @ai_function(description="Get the time zone for a location.") def get_time_zone(location: str) -> str: @@ -72,8 +127,14 @@ def get_time_zone(location: str) -> str: # Create FastAPI app app = FastAPI(title="AG-UI Server") -# Register the AG-UI endpoint -add_agent_framework_fastapi_endpoint(app, agent, "/") +# Register the AG-UI endpoint with authentication +# The dependencies parameter accepts FastAPI Depends() objects that run before the handler +add_agent_framework_fastapi_endpoint( + app, + agent, + "/", + dependencies=[Depends(verify_api_key)], +) if __name__ == "__main__": import uvicorn diff --git a/python/packages/ag-ui/pyproject.toml b/python/packages/ag-ui/pyproject.toml index 10f3e19e40..21fe4f234b 100644 --- a/python/packages/ag-ui/pyproject.toml +++ b/python/packages/ag-ui/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "agent-framework-ag-ui" -version = "1.0.0b260107" +version = "1.0.0b260114" description = "AG-UI protocol integration for Agent Framework" readme = "README.md" license-files = ["LICENSE"] diff --git a/python/packages/ag-ui/tests/test_ag_ui_client.py b/python/packages/ag-ui/tests/test_ag_ui_client.py index 09570c1be4..bc1cc6d711 100644 --- a/python/packages/ag-ui/tests/test_ag_ui_client.py +++ b/python/packages/ag-ui/tests/test_ag_ui_client.py @@ -9,13 +9,13 @@ from agent_framework import ( ChatMessage, ChatOptions, + ChatResponse, ChatResponseUpdate, FunctionCallContent, Role, TextContent, ai_function, ) -from agent_framework._types import ChatResponse from pytest import MonkeyPatch from agent_framework_ag_ui._client import AGUIChatClient, ServerFunctionCallContent @@ -40,22 +40,22 @@ def convert_messages_to_agui_format(self, messages: list[ChatMessage]) -> list[d """Expose message conversion helper.""" return self._convert_messages_to_agui_format(messages) - def get_thread_id(self, chat_options: ChatOptions) -> str: + def get_thread_id(self, options: dict[str, Any]) -> str: """Expose thread id helper.""" - return self._get_thread_id(chat_options) + return self._get_thread_id(options) async def inner_get_streaming_response( - self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions + self, *, messages: MutableSequence[ChatMessage], options: dict[str, Any] ) -> AsyncIterable[ChatResponseUpdate]: """Proxy to protected streaming call.""" - async for update in self._inner_get_streaming_response(messages=messages, chat_options=chat_options): + async for update in self._inner_get_streaming_response(messages=messages, options=options): yield update async def inner_get_response( - self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions + self, *, messages: MutableSequence[ChatMessage], options: dict[str, Any] ) -> ChatResponse: """Proxy to protected response call.""" - return await self._inner_get_response(messages=messages, chat_options=chat_options) + return await self._inner_get_response(messages=messages, options=options) class TestAGUIChatClient: @@ -191,7 +191,7 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str chat_options = ChatOptions() updates: list[ChatResponseUpdate] = [] - async for update in client.inner_get_streaming_response(messages=messages, 
chat_options=chat_options): + async for update in client.inner_get_streaming_response(messages=messages, options=chat_options): updates.append(update) assert len(updates) == 4 @@ -221,9 +221,9 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [ChatMessage(role="user", text="Test message")] - chat_options = ChatOptions() + chat_options = {} - response = await client.inner_get_response(messages=messages, chat_options=chat_options) + response = await client.inner_get_response(messages=messages, options=chat_options) assert response is not None assert len(response.messages) > 0 @@ -266,7 +266,7 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str messages = [ChatMessage(role="user", text="Test with tools")] chat_options = ChatOptions(tools=[test_tool]) - response = await client.inner_get_response(messages=messages, chat_options=chat_options) + response = await client.inner_get_response(messages=messages, options=chat_options) assert response is not None @@ -288,10 +288,9 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [ChatMessage(role="user", text="Test server tool execution")] - chat_options = ChatOptions() updates: list[ChatResponseUpdate] = [] - async for update in client.get_streaming_response(messages, chat_options=chat_options): + async for update in client.get_streaming_response(messages): updates.append(update) function_calls = [ @@ -332,9 +331,8 @@ async def fake_auto_invoke(*args: object, **kwargs: Any) -> None: monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [ChatMessage(role="user", text="Test server tool execution")] - chat_options = ChatOptions(tool_choice="auto", tools=[client_tool]) - async for _ in client.get_streaming_response(messages, chat_options=chat_options): + async for _ in client.get_streaming_response(messages, options={"tool_choice": "auto", "tools": [client_tool]}): pass async def test_state_transmission(self, monkeypatch: MonkeyPatch) -> None: @@ -370,6 +368,6 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str chat_options = ChatOptions() - response = await client.inner_get_response(messages=messages, chat_options=chat_options) + response = await client.inner_get_response(messages=messages, options=chat_options) assert response is not None diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index 281b81c968..f919c00a56 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -9,12 +9,11 @@ from typing import Any import pytest -from agent_framework import ChatAgent, ChatMessage, ChatOptions, TextContent -from agent_framework._types import ChatResponseUpdate +from agent_framework import ChatAgent, ChatMessage, ChatOptions, ChatResponseUpdate, TextContent from pydantic import BaseModel sys.path.insert(0, str(Path(__file__).parent)) -from test_helpers_ag_ui import StreamingChatClientStub +from utils_test_ag_ui import StreamingChatClientStub async def test_agent_initialization_basic(): @@ -22,11 +21,15 @@ async def test_agent_initialization_basic(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], 
chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) + agent = ChatAgent[ChatOptions]( + chat_client=StreamingChatClientStub(stream_fn), + name="test_agent", + instructions="Test", + ) wrapper = AgentFrameworkAgent(agent=agent) assert wrapper.name == "test_agent" @@ -40,7 +43,7 @@ async def test_agent_initialization_with_state_schema(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) @@ -56,7 +59,7 @@ async def test_agent_initialization_with_predict_state_config(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) @@ -72,7 +75,7 @@ async def test_agent_initialization_with_pydantic_state_schema(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) @@ -95,7 +98,7 @@ async def test_run_started_event_emission(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) @@ -119,7 +122,7 @@ async def test_predict_state_custom_event_emission(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) @@ -151,7 +154,7 @@ async def test_initial_state_snapshot_with_schema(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) @@ -181,7 +184,7 @@ async def test_state_initialization_object_type(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) @@ -208,7 +211,7 @@ async def test_state_initialization_array_type(): from 
agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) @@ -235,7 +238,7 @@ async def test_run_finished_event_emission(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) @@ -257,7 +260,7 @@ async def test_tool_result_confirm_changes_accepted(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Document updated")]) @@ -304,7 +307,7 @@ async def test_tool_result_confirm_changes_rejected(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="OK")]) @@ -338,7 +341,7 @@ async def test_tool_result_function_approval_accepted(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="OK")]) @@ -384,7 +387,7 @@ async def test_tool_result_function_approval_rejected(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="OK")]) @@ -423,10 +426,11 @@ async def test_thread_metadata_tracking(): thread_metadata: dict[str, Any] = {} async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - if chat_options.metadata: - thread_metadata.update(chat_options.metadata) + metadata = options.get("metadata") + if metadata: + thread_metadata.update(metadata) yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) @@ -448,15 +452,16 @@ async def stream_fn( async def test_state_context_injection(): """Test that current state is injected into thread metadata.""" - from agent_framework.ag_ui import AgentFrameworkAgent + from agent_framework_ag_ui import AgentFrameworkAgent thread_metadata: dict[str, Any] = {} async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], 
**kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - if chat_options.metadata: - thread_metadata.update(chat_options.metadata) + metadata = options.get("metadata") + if metadata: + thread_metadata.update(metadata) yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) @@ -485,7 +490,7 @@ async def test_no_messages_provided(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) @@ -509,7 +514,7 @@ async def test_message_end_event_emission(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Hello world")]) @@ -537,7 +542,7 @@ async def test_error_handling_with_exception(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: if False: yield ChatResponseUpdate(contents=[]) @@ -558,7 +563,7 @@ async def test_json_decode_error_in_tool_result(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: if False: yield ChatResponseUpdate(contents=[]) @@ -595,7 +600,7 @@ async def test_suppressed_summary_with_document_state(): from agent_framework.ag_ui import AgentFrameworkAgent, DocumentWriterConfirmationStrategy async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="Response")]) @@ -632,6 +637,60 @@ async def stream_fn( assert "written" in full_text.lower() or "document" in full_text.lower() +async def test_agent_with_use_service_thread_is_false(): + """Test that when use_service_thread is False, the AgentThread used to run the agent is NOT set to the service thread ID.""" + from agent_framework.ag_ui import AgentFrameworkAgent + + request_service_thread_id: str | None = None + + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + nonlocal request_service_thread_id + thread = kwargs.get("thread") + request_service_thread_id = thread.service_thread_id if thread else None + yield ChatResponseUpdate( + contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345" + ) + + agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn)) + wrapper = AgentFrameworkAgent(agent=agent, use_service_thread=False) + + input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"} + + events: list[Any] = [] + async for 
event in wrapper.run_agent(input_data):
+        events.append(event)
+    assert request_service_thread_id is None  # type: ignore[attr-defined] (service_thread_id should NOT be set)
+
+
+async def test_agent_with_use_service_thread_is_true():
+    """Test that when use_service_thread is True, the AgentThread used to run the agent is set to the service thread ID."""
+    from agent_framework.ag_ui import AgentFrameworkAgent
+
+    request_service_thread_id: str | None = None
+
+    async def stream_fn(
+        messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
+    ) -> AsyncIterator[ChatResponseUpdate]:
+        nonlocal request_service_thread_id
+        thread = kwargs.get("thread")
+        request_service_thread_id = thread.service_thread_id if thread else None
+        yield ChatResponseUpdate(
+            contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345"
+        )
+
+    agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn))
+    wrapper = AgentFrameworkAgent(agent=agent, use_service_thread=True)
+
+    input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"}
+
+    events: list[Any] = []
+    async for event in wrapper.run_agent(input_data):
+        events.append(event)
+    assert request_service_thread_id == "conv_123456"  # type: ignore[attr-defined] (service_thread_id should be set)
+
+
 async def test_function_approval_mode_executes_tool():
     """Test that function approval with approval_mode='always_require' sends the correct messages."""
     from agent_framework import FunctionResultContent, ai_function
@@ -648,7 +707,7 @@ def get_datetime() -> str:
         return "2025/12/01 12:00:00"
 
     async def stream_fn(
-        messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
+        messages: MutableSequence[ChatMessage], options: ChatOptions, **kwargs: Any
     ) -> AsyncIterator[ChatResponseUpdate]:
         # Capture the messages received by the chat client
         messages_received.clear()
@@ -656,9 +715,9 @@ async def stream_fn(
         yield ChatResponseUpdate(contents=[TextContent(text="Processing completed")])
 
     agent = ChatAgent(
+        chat_client=StreamingChatClientStub(stream_fn),
         name="test_agent",
         instructions="Test",
-        chat_client=StreamingChatClientStub(stream_fn),
         tools=[get_datetime],
     )
     wrapper = AgentFrameworkAgent(agent=agent)
@@ -739,7 +798,7 @@ def delete_all_data() -> str:
         return "All data deleted"
 
     async def stream_fn(
-        messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
+        messages: MutableSequence[ChatMessage], options: ChatOptions, **kwargs: Any
     ) -> AsyncIterator[ChatResponseUpdate]:
         # Capture the messages received by the chat client
         messages_received.clear()
diff --git a/python/packages/ag-ui/tests/test_backend_tool_rendering.py b/python/packages/ag-ui/tests/test_backend_tool_rendering.py
index 97654182cf..446da23ff2 100644
--- a/python/packages/ag-ui/tests/test_backend_tool_rendering.py
+++ b/python/packages/ag-ui/tests/test_backend_tool_rendering.py
@@ -12,7 +12,7 @@
     ToolCallResultEvent,
     ToolCallStartEvent,
 )
-from agent_framework import AgentRunResponseUpdate, FunctionCallContent, FunctionResultContent, TextContent
+from agent_framework import AgentResponseUpdate, FunctionCallContent, FunctionResultContent, TextContent
 
 from agent_framework_ag_ui._events import AgentFrameworkEventBridge
 
@@ -28,7 +28,7 @@ async def test_tool_call_flow():
         arguments={"location": "Seattle"},
     )
 
-    update1 = AgentRunResponseUpdate(contents=[tool_call])
+    update1 = AgentResponseUpdate(contents=[tool_call])
     events1 = await bridge.from_agent_run_update(update1)
 
     # Should
have: ToolCallStartEvent, ToolCallArgsEvent @@ -49,7 +49,7 @@ async def test_tool_call_flow(): result="Weather in Seattle: Rainy, 52°F", ) - update2 = AgentRunResponseUpdate(contents=[tool_result]) + update2 = AgentResponseUpdate(contents=[tool_result]) events2 = await bridge.from_agent_run_update(update2) # Should have: ToolCallEndEvent, ToolCallResultEvent @@ -78,7 +78,7 @@ async def test_text_with_tool_call(): arguments={"location": "San Francisco", "days": 3}, ) - update = AgentRunResponseUpdate(contents=[text_content, tool_call]) + update = AgentResponseUpdate(contents=[text_content, tool_call]) events = await bridge.from_agent_run_update(update) # Should have: TextMessageStart, TextMessageContent, ToolCallStart, ToolCallArgs @@ -107,7 +107,7 @@ async def test_multiple_tool_results(): FunctionResultContent(call_id="tool-3", result="Result 3"), ] - update = AgentRunResponseUpdate(contents=results) + update = AgentResponseUpdate(contents=results) events = await bridge.from_agent_run_update(update) # Should have 3 pairs of ToolCallEndEvent + ToolCallResultEvent = 6 events diff --git a/python/packages/ag-ui/tests/test_document_writer_flow.py b/python/packages/ag-ui/tests/test_document_writer_flow.py index 1ea164beef..2e5cec9f95 100644 --- a/python/packages/ag-ui/tests/test_document_writer_flow.py +++ b/python/packages/ag-ui/tests/test_document_writer_flow.py @@ -3,8 +3,7 @@ """Tests for document writer predictive state flow with confirm_changes.""" from ag_ui.core import EventType, StateDeltaEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallStartEvent -from agent_framework import FunctionCallContent, FunctionResultContent, TextContent -from agent_framework._types import AgentRunResponseUpdate +from agent_framework import AgentResponseUpdate, FunctionCallContent, FunctionResultContent, TextContent from agent_framework_ag_ui._events import AgentFrameworkEventBridge @@ -27,7 +26,7 @@ async def test_streaming_document_with_state_deltas(): name="write_document_local", arguments='{"document":"Once', ) - update1 = AgentRunResponseUpdate(contents=[tool_call_start]) + update1 = AgentResponseUpdate(contents=[tool_call_start]) events1 = await bridge.from_agent_run_update(update1) # Should have ToolCallStartEvent and ToolCallArgsEvent @@ -36,7 +35,7 @@ async def test_streaming_document_with_state_deltas(): # Second chunk - incomplete JSON, should try partial extraction tool_call_chunk2 = FunctionCallContent(call_id="call_123", name="write_document_local", arguments=" upon a time") - update2 = AgentRunResponseUpdate(contents=[tool_call_chunk2]) + update2 = AgentResponseUpdate(contents=[tool_call_chunk2]) events2 = await bridge.from_agent_run_update(update2) # Should emit StateDeltaEvent with partial document @@ -77,7 +76,7 @@ async def test_confirm_changes_emission(): result="Document written.", ) - update = AgentRunResponseUpdate(contents=[tool_result]) + update = AgentResponseUpdate(contents=[tool_result]) events = await bridge.from_agent_run_update(update) # Should have: ToolCallEndEvent, ToolCallResultEvent, StateSnapshotEvent, confirm_changes sequence @@ -117,7 +116,7 @@ async def test_text_suppression_before_confirm(): # Text content that should be suppressed text = TextContent(text="I have written a story about pirates.") - update = AgentRunResponseUpdate(contents=[text]) + update = AgentResponseUpdate(contents=[text]) events = await bridge.from_agent_run_update(update) @@ -152,7 +151,7 @@ async def test_no_confirm_for_non_predictive_tools(): result="Sunny, 72°F", ) - update = 
AgentRunResponseUpdate(contents=[tool_result]) + update = AgentResponseUpdate(contents=[tool_result]) events = await bridge.from_agent_run_update(update) # Should NOT have confirm_changes @@ -181,7 +180,7 @@ async def test_state_delta_deduplication(): name="write_document_local", arguments='{"document":"Same text"}', ) - update1 = AgentRunResponseUpdate(contents=[tool_call1]) + update1 = AgentResponseUpdate(contents=[tool_call1]) events1 = await bridge.from_agent_run_update(update1) # Count state deltas @@ -195,7 +194,7 @@ async def test_state_delta_deduplication(): name="write_document_local", arguments='{"document":"Same text"}', # Identical content ) - update2 = AgentRunResponseUpdate(contents=[tool_call2]) + update2 = AgentResponseUpdate(contents=[tool_call2]) events2 = await bridge.from_agent_run_update(update2) # Should NOT emit state delta (same value) @@ -222,7 +221,7 @@ async def test_predict_state_config_multiple_fields(): name="create_post", arguments='{"title":"My Post","body":"Post content"}', ) - update = AgentRunResponseUpdate(contents=[tool_call]) + update = AgentResponseUpdate(contents=[tool_call]) events = await bridge.from_agent_run_update(update) # Should emit StateDeltaEvent for both fields diff --git a/python/packages/ag-ui/tests/test_endpoint.py b/python/packages/ag-ui/tests/test_endpoint.py index 36c9e3bc32..59cb884c5c 100644 --- a/python/packages/ag-ui/tests/test_endpoint.py +++ b/python/packages/ag-ui/tests/test_endpoint.py @@ -6,16 +6,16 @@ import sys from pathlib import Path -from agent_framework import ChatAgent, TextContent -from agent_framework._types import ChatResponseUpdate -from fastapi import FastAPI +from agent_framework import ChatAgent, ChatResponseUpdate, TextContent +from fastapi import FastAPI, Header, HTTPException +from fastapi.params import Depends from fastapi.testclient import TestClient +from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint from agent_framework_ag_ui._agent import AgentFrameworkAgent -from agent_framework_ag_ui._endpoint import add_agent_framework_fastapi_endpoint sys.path.insert(0, str(Path(__file__).parent)) -from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates +from utils_test_ag_ui import StreamingChatClientStub, stream_from_updates def build_chat_client(response_text: str = "Test response") -> StreamingChatClientStub: @@ -176,11 +176,8 @@ async def test_endpoint_error_handling(): # Send invalid JSON to trigger parsing error before streaming response = client.post("/failing", data=b"invalid json", headers={"content-type": "application/json"}) # type: ignore - # The exception handler catches it and returns JSON error - assert response.status_code == 200 - content = json.loads(response.content) - assert "error" in content - assert content["error"] == "An internal error has occurred." 
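+    # FastAPI now rejects the malformed body during request-model validation; the
+    # 422 payload is a standard envelope, roughly:
+    #   {"detail": [{"loc": ["body", ...], "msg": "...", "type": "..."}]}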
+ # Pydantic validation now returns 422 for invalid request body + assert response.status_code == 422 async def test_endpoint_multiple_paths(): @@ -266,3 +263,206 @@ async def test_endpoint_complex_input(): ) assert response.status_code == 200 + + +async def test_endpoint_openapi_schema(): + """Test that endpoint generates proper OpenAPI schema with request model.""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) + + add_agent_framework_fastapi_endpoint(app, agent, path="/schema-test") + + client = TestClient(app) + response = client.get("/openapi.json") + + assert response.status_code == 200 + openapi_spec = response.json() + + # Verify the endpoint exists in the schema + assert "/schema-test" in openapi_spec["paths"] + endpoint_spec = openapi_spec["paths"]["/schema-test"]["post"] + + # Verify request body schema is defined + assert "requestBody" in endpoint_spec + request_body = endpoint_spec["requestBody"] + assert "content" in request_body + assert "application/json" in request_body["content"] + + # Verify schema references AGUIRequest model + schema_ref = request_body["content"]["application/json"]["schema"] + assert "$ref" in schema_ref + assert "AGUIRequest" in schema_ref["$ref"] + + # Verify AGUIRequest model is in components + assert "components" in openapi_spec + assert "schemas" in openapi_spec["components"] + assert "AGUIRequest" in openapi_spec["components"]["schemas"] + + # Verify AGUIRequest has required fields + agui_request_schema = openapi_spec["components"]["schemas"]["AGUIRequest"] + assert "properties" in agui_request_schema + assert "messages" in agui_request_schema["properties"] + assert "run_id" in agui_request_schema["properties"] + assert "thread_id" in agui_request_schema["properties"] + assert "state" in agui_request_schema["properties"] + assert "required" in agui_request_schema + assert "messages" in agui_request_schema["required"] + + +async def test_endpoint_default_tags(): + """Test that endpoint uses default 'AG-UI' tag.""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) + + add_agent_framework_fastapi_endpoint(app, agent, path="/default-tags") + + client = TestClient(app) + response = client.get("/openapi.json") + + assert response.status_code == 200 + openapi_spec = response.json() + + endpoint_spec = openapi_spec["paths"]["/default-tags"]["post"] + assert "tags" in endpoint_spec + assert endpoint_spec["tags"] == ["AG-UI"] + + +async def test_endpoint_custom_tags(): + """Test that endpoint accepts custom tags.""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) + + add_agent_framework_fastapi_endpoint(app, agent, path="/custom-tags", tags=["Custom", "Agent"]) + + client = TestClient(app) + response = client.get("/openapi.json") + + assert response.status_code == 200 + openapi_spec = response.json() + + endpoint_spec = openapi_spec["paths"]["/custom-tags"]["post"] + assert "tags" in endpoint_spec + assert endpoint_spec["tags"] == ["Custom", "Agent"] + + +async def test_endpoint_missing_required_field(): + """Test that endpoint validates required fields with Pydantic.""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) + + add_agent_framework_fastapi_endpoint(app, agent, path="/validation") + + client = TestClient(app) + + # Missing required 'messages' field should trigger validation error + response = 
client.post("/validation", json={"run_id": "test-123"}) + + assert response.status_code == 422 + error_detail = response.json() + assert "detail" in error_detail + + +async def test_endpoint_internal_error_handling(): + """Test endpoint error handling when an exception occurs before streaming starts.""" + from unittest.mock import patch + + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) + + # Use default_state to trigger the code path that can raise an exception + add_agent_framework_fastapi_endpoint(app, agent, path="/error-test", default_state={"key": "value"}) + + client = TestClient(app) + + # Mock copy.deepcopy to raise an exception during default_state processing + with patch("agent_framework_ag_ui._endpoint.copy.deepcopy") as mock_deepcopy: + mock_deepcopy.side_effect = Exception("Simulated internal error") + response = client.post("/error-test", json={"messages": [{"role": "user", "content": "Hello"}]}) + + assert response.status_code == 200 + assert response.json() == {"error": "An internal error has occurred."} + + +async def test_endpoint_with_dependencies_blocks_unauthorized(): + """Test that endpoint blocks requests when authentication dependency fails.""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) + + async def require_api_key(x_api_key: str | None = Header(None)): + if x_api_key != "secret-key": + raise HTTPException(status_code=401, detail="Unauthorized") + + add_agent_framework_fastapi_endpoint(app, agent, path="/protected", dependencies=[Depends(require_api_key)]) + + client = TestClient(app) + + # Request without API key should be rejected + response = client.post("/protected", json={"messages": [{"role": "user", "content": "Hello"}]}) + assert response.status_code == 401 + assert response.json()["detail"] == "Unauthorized" + + +async def test_endpoint_with_dependencies_allows_authorized(): + """Test that endpoint allows requests when authentication dependency passes.""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) + + async def require_api_key(x_api_key: str | None = Header(None)): + if x_api_key != "secret-key": + raise HTTPException(status_code=401, detail="Unauthorized") + + add_agent_framework_fastapi_endpoint(app, agent, path="/protected", dependencies=[Depends(require_api_key)]) + + client = TestClient(app) + + # Request with valid API key should succeed + response = client.post( + "/protected", + json={"messages": [{"role": "user", "content": "Hello"}]}, + headers={"x-api-key": "secret-key"}, + ) + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + +async def test_endpoint_with_multiple_dependencies(): + """Test that endpoint supports multiple dependencies.""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) + + execution_order: list[str] = [] + + async def first_dependency(): + execution_order.append("first") + + async def second_dependency(): + execution_order.append("second") + + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/multi-deps", + dependencies=[Depends(first_dependency), Depends(second_dependency)], + ) + + client = TestClient(app) + response = client.post("/multi-deps", json={"messages": [{"role": "user", "content": "Hello"}]}) + + assert response.status_code == 200 + assert "first" in execution_order + assert 
"second" in execution_order + + +async def test_endpoint_without_dependencies_is_accessible(): + """Test that endpoint without dependencies remains accessible (backward compatibility).""" + app = FastAPI() + agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client()) + + # No dependencies parameter - should be accessible without auth + add_agent_framework_fastapi_endpoint(app, agent, path="/open") + + client = TestClient(app) + response = client.post("/open", json={"messages": [{"role": "user", "content": "Hello"}]}) + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" diff --git a/python/packages/ag-ui/tests/test_events_comprehensive.py b/python/packages/ag-ui/tests/test_events_comprehensive.py index cfd45ea5c8..295ba00372 100644 --- a/python/packages/ag-ui/tests/test_events_comprehensive.py +++ b/python/packages/ag-ui/tests/test_events_comprehensive.py @@ -5,7 +5,7 @@ import json from agent_framework import ( - AgentRunResponseUpdate, + AgentResponseUpdate, FunctionApprovalRequestContent, FunctionCallContent, FunctionResultContent, @@ -19,7 +19,7 @@ async def test_basic_text_message_conversion(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentRunResponseUpdate(contents=[TextContent(text="Hello")]) + update = AgentResponseUpdate(contents=[TextContent(text="Hello")]) events = await bridge.from_agent_run_update(update) assert len(events) == 2 @@ -35,8 +35,8 @@ async def test_text_message_streaming(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update1 = AgentRunResponseUpdate(contents=[TextContent(text="Hello ")]) - update2 = AgentRunResponseUpdate(contents=[TextContent(text="world")]) + update1 = AgentResponseUpdate(contents=[TextContent(text="Hello ")]) + update2 = AgentResponseUpdate(contents=[TextContent(text="world")]) events1 = await bridge.from_agent_run_update(update1) events2 = await bridge.from_agent_run_update(update2) @@ -61,7 +61,7 @@ async def test_skip_text_content_for_structured_outputs(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread", skip_text_content=True) - update = AgentRunResponseUpdate(contents=[TextContent(text='{"result": "data"}')]) + update = AgentResponseUpdate(contents=[TextContent(text='{"result": "data"}')]) events = await bridge.from_agent_run_update(update) # No events should be emitted @@ -74,9 +74,9 @@ async def test_skip_text_content_for_empty_text(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update1 = AgentRunResponseUpdate(contents=[TextContent(text="Hello ")]) - update2 = AgentRunResponseUpdate(contents=[TextContent(text="")]) # Empty chunk - update3 = AgentRunResponseUpdate(contents=[TextContent(text="world")]) + update1 = AgentResponseUpdate(contents=[TextContent(text="Hello ")]) + update2 = AgentResponseUpdate(contents=[TextContent(text="")]) # Empty chunk + update3 = AgentResponseUpdate(contents=[TextContent(text="world")]) events1 = await bridge.from_agent_run_update(update1) events2 = await bridge.from_agent_run_update(update2) @@ -105,7 +105,7 @@ async def test_tool_call_with_name(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentRunResponseUpdate(contents=[FunctionCallContent(name="search_web", call_id="call_123")]) + update = AgentResponseUpdate(contents=[FunctionCallContent(name="search_web", call_id="call_123")]) events = await 
bridge.from_agent_run_update(update) assert len(events) == 1 @@ -121,17 +121,15 @@ async def test_tool_call_streaming_args(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") # First chunk: name only - update1 = AgentRunResponseUpdate(contents=[FunctionCallContent(name="search_web", call_id="call_123")]) + update1 = AgentResponseUpdate(contents=[FunctionCallContent(name="search_web", call_id="call_123")]) events1 = await bridge.from_agent_run_update(update1) # Second chunk: arguments chunk 1 (name can be empty string for continuation) - update2 = AgentRunResponseUpdate( - contents=[FunctionCallContent(name="", call_id="call_123", arguments='{"query": "')] - ) + update2 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_123", arguments='{"query": "')]) events2 = await bridge.from_agent_run_update(update2) # Third chunk: arguments chunk 2 - update3 = AgentRunResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_123", arguments='AI"}')]) + update3 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_123", arguments='AI"}')]) events3 = await bridge.from_agent_run_update(update3) # First update: ToolCallStartEvent @@ -152,6 +150,42 @@ async def test_tool_call_streaming_args(): assert events1[0].tool_call_id == events2[0].tool_call_id == events3[0].tool_call_id +async def test_streaming_tool_call_no_duplicate_start_events(): + """Test that streaming tool calls emit exactly one ToolCallStartEvent. + + This is a regression test for the Anthropic streaming fix where input_json_delta + events were incorrectly passing the tool name, causing duplicate ToolCallStartEvents. + + The correct behavior is: + - Initial FunctionCallContent with name -> emits ToolCallStartEvent + - Subsequent FunctionCallContent with name="" -> emits only ToolCallArgsEvent + + See: https://github.com/microsoft/agent-framework/pull/3051 + """ + from agent_framework_ag_ui._events import AgentFrameworkEventBridge + + bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") + + # Simulate streaming tool call: first chunk has name, subsequent chunks have name="" + update1 = AgentResponseUpdate(contents=[FunctionCallContent(name="get_weather", call_id="call_789")]) + update2 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_789", arguments='{"loc":')]) + update3 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_789", arguments='"SF"}')]) + + events1 = await bridge.from_agent_run_update(update1) + events2 = await bridge.from_agent_run_update(update2) + events3 = await bridge.from_agent_run_update(update3) + + # Count all ToolCallStartEvents - should be exactly 1 + all_events = events1 + events2 + events3 + tool_call_start_count = sum(1 for e in all_events if e.type == "TOOL_CALL_START") + assert tool_call_start_count == 1, f"Expected 1 ToolCallStartEvent, got {tool_call_start_count}" + + # Verify event types + assert events1[0].type == "TOOL_CALL_START" + assert events2[0].type == "TOOL_CALL_ARGS" + assert events3[0].type == "TOOL_CALL_ARGS" + + async def test_tool_result_with_dict(): """Test FunctionResultContent with dict result.""" from agent_framework_ag_ui._events import AgentFrameworkEventBridge @@ -159,7 +193,7 @@ async def test_tool_result_with_dict(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") result_data = {"status": "success", "count": 42} - update = AgentRunResponseUpdate(contents=[FunctionResultContent(call_id="call_123", 
result=result_data)]) + update = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result=result_data)]) events = await bridge.from_agent_run_update(update) # Should emit ToolCallEndEvent + ToolCallResultEvent @@ -180,7 +214,7 @@ async def test_tool_result_with_string(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentRunResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result="Search complete")]) + update = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result="Search complete")]) events = await bridge.from_agent_run_update(update) assert len(events) == 2 @@ -195,7 +229,7 @@ async def test_tool_result_with_none(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentRunResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result=None)]) + update = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result=None)]) events = await bridge.from_agent_run_update(update) assert len(events) == 2 @@ -211,7 +245,7 @@ async def test_multiple_tool_results_in_sequence(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[ FunctionResultContent(call_id="call_1", result="Result 1"), FunctionResultContent(call_id="call_2", result="Result 2"), @@ -248,7 +282,7 @@ async def test_function_approval_request_basic(): function_call=func_call, ) - update = AgentRunResponseUpdate(contents=[approval]) + update = AgentResponseUpdate(contents=[approval]) events = await bridge.from_agent_run_update(update) # Should emit: ToolCallEndEvent + CustomEvent @@ -276,7 +310,7 @@ async def test_empty_predict_state_config(): ) # Tool call with arguments - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[ FunctionCallContent(name="write_doc", call_id="call_1", arguments='{"content": "test"}'), FunctionResultContent(call_id="call_1", result="Done"), @@ -311,7 +345,7 @@ async def test_tool_not_in_predict_state_config(): ) # Different tool name - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[ FunctionCallContent(name="search_web", call_id="call_1", arguments='{"query": "AI"}'), FunctionResultContent(call_id="call_1", result="Results"), @@ -340,7 +374,7 @@ async def test_state_management_tracking(): ) # Streaming tool call - update1 = AgentRunResponseUpdate( + update1 = AgentResponseUpdate( contents=[ FunctionCallContent(name="write_doc", call_id="call_1"), FunctionCallContent(name="", call_id="call_1", arguments='{"content": "Hello"}'), @@ -353,7 +387,7 @@ async def test_state_management_tracking(): assert bridge.pending_state_updates["document"] == "Hello" # Tool result should update current_state - update2 = AgentRunResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]) + update2 = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]) await bridge.from_agent_run_update(update2) # current_state should be updated @@ -377,7 +411,7 @@ async def test_wildcard_tool_argument(): ) # Complete tool call with dict arguments - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[ FunctionCallContent( name="create_recipe", @@ -467,7 +501,7 @@ async def test_state_snapshot_after_tool_result(): ) # Tool call with streaming args - update1 = AgentRunResponseUpdate( + update1 = AgentResponseUpdate( contents=[ 
FunctionCallContent(name="write_doc", call_id="call_1"), FunctionCallContent(name="", call_id="call_1", arguments='{"content": "Test"}'), @@ -476,7 +510,7 @@ async def test_state_snapshot_after_tool_result(): await bridge.from_agent_run_update(update1) # Tool result should trigger StateSnapshotEvent - update2 = AgentRunResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]) + update2 = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]) events = await bridge.from_agent_run_update(update2) # Should have: ToolCallEnd, ToolCallResult, StateSnapshot, ToolCallStart (confirm_changes), ToolCallArgs, ToolCallEnd @@ -492,12 +526,12 @@ async def test_message_id_persistence_across_chunks(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") # First chunk - update1 = AgentRunResponseUpdate(contents=[TextContent(text="Hello ")]) + update1 = AgentResponseUpdate(contents=[TextContent(text="Hello ")]) events1 = await bridge.from_agent_run_update(update1) message_id = events1[0].message_id # Second chunk - update2 = AgentRunResponseUpdate(contents=[TextContent(text="world")]) + update2 = AgentResponseUpdate(contents=[TextContent(text="world")]) events2 = await bridge.from_agent_run_update(update2) # Should use same message_id @@ -512,14 +546,14 @@ async def test_tool_call_id_tracking(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") # First chunk with name - update1 = AgentRunResponseUpdate(contents=[FunctionCallContent(name="search", call_id="call_1")]) + update1 = AgentResponseUpdate(contents=[FunctionCallContent(name="search", call_id="call_1")]) await bridge.from_agent_run_update(update1) assert bridge.current_tool_call_id == "call_1" assert bridge.current_tool_call_name == "search" # Second chunk with args but no name - update2 = AgentRunResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_1", arguments='{"q":"AI"}')]) + update2 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_1", arguments='{"q":"AI"}')]) events2 = await bridge.from_agent_run_update(update2) # Should still track same tool call @@ -540,7 +574,7 @@ async def test_tool_name_reset_after_result(): ) # Tool call - update1 = AgentRunResponseUpdate( + update1 = AgentResponseUpdate( contents=[ FunctionCallContent(name="write_doc", call_id="call_1"), FunctionCallContent(name="", call_id="call_1", arguments='{"content": "Test"}'), @@ -551,7 +585,7 @@ async def test_tool_name_reset_after_result(): assert bridge.current_tool_call_name == "write_doc" # Tool result with predictive state (should trigger confirm_changes and reset) - update2 = AgentRunResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]) + update2 = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]) await bridge.from_agent_run_update(update2) # Tool name should be reset @@ -577,7 +611,7 @@ async def test_function_approval_with_wildcard_argument(): ), ) - update = AgentRunResponseUpdate(contents=[approval_content]) + update = AgentResponseUpdate(contents=[approval_content]) events = await bridge.from_agent_run_update(update) # Should emit StateSnapshotEvent with entire parsed args as value @@ -603,7 +637,7 @@ async def test_function_approval_missing_argument(): function_call=FunctionCallContent(name="process", call_id="call_1", arguments='{"other_field": "value"}'), ) - update = AgentRunResponseUpdate(contents=[approval_content]) + update = 
AgentResponseUpdate(contents=[approval_content]) events = await bridge.from_agent_run_update(update) # Should not emit StateSnapshotEvent since argument not found @@ -618,7 +652,7 @@ async def test_empty_predict_state_config_no_deltas(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread", predict_state_config={}) # Tool call with arguments - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[ FunctionCallContent(name="search", call_id="call_1"), FunctionCallContent(name="", call_id="call_1", arguments='{"query": "test"}'), @@ -642,7 +676,7 @@ async def test_tool_with_no_matching_config(): ) # Tool call for different tool - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[ FunctionCallContent(name="search_web", call_id="call_1"), FunctionCallContent(name="", call_id="call_1", arguments='{"query": "test"}'), @@ -662,7 +696,7 @@ async def test_tool_call_without_name_or_id(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") # This should not crash but log an error - update = AgentRunResponseUpdate(contents=[FunctionCallContent(name="", call_id="", arguments='{"arg": "val"}')]) + update = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="", arguments='{"arg": "val"}')]) events = await bridge.from_agent_run_update(update) # Should emit ToolCallArgsEvent with generated ID @@ -681,7 +715,7 @@ async def test_state_delta_count_logging(): # Emit multiple state deltas with different content each time for i in range(15): - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[ FunctionCallContent(name="", call_id="call_1", arguments=f'{{"text": "Content variation {i}"}}'), ] @@ -703,7 +737,7 @@ async def test_tool_result_with_empty_list(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentRunResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result=[])]) + update = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result=[])]) events = await bridge.from_agent_run_update(update) assert len(events) == 2 @@ -725,7 +759,7 @@ class MockTextContent: bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[FunctionResultContent(call_id="call_123", result=[MockTextContent("Hello from MCP tool!")])] ) events = await bridge.from_agent_run_update(update) @@ -749,7 +783,7 @@ class MockTextContent: bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[ FunctionResultContent( call_id="call_123", @@ -777,7 +811,7 @@ class MockModel(BaseModel): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[FunctionResultContent(call_id="call_123", result=[MockModel(value=1), MockModel(value=2)])] ) events = await bridge.from_agent_run_update(update) diff --git a/python/packages/ag-ui/tests/test_human_in_the_loop.py b/python/packages/ag-ui/tests/test_human_in_the_loop.py index 55a2869c91..00e64472b6 100644 --- a/python/packages/ag-ui/tests/test_human_in_the_loop.py +++ b/python/packages/ag-ui/tests/test_human_in_the_loop.py @@ -2,8 +2,7 @@ """Tests for human in the loop (function approval requests).""" -from agent_framework import FunctionApprovalRequestContent, FunctionCallContent -from 
agent_framework._types import AgentRunResponseUpdate +from agent_framework import AgentResponseUpdate, FunctionApprovalRequestContent, FunctionCallContent from agent_framework_ag_ui._events import AgentFrameworkEventBridge @@ -28,7 +27,7 @@ async def test_function_approval_request_emission(): function_call=func_call, ) - update = AgentRunResponseUpdate(contents=[approval_request]) + update = AgentResponseUpdate(contents=[approval_request]) events = await bridge.from_agent_run_update(update) # Should emit ToolCallEndEvent + CustomEvent for approval request @@ -67,7 +66,7 @@ async def test_function_approval_request_with_confirm_changes(): function_call=func_call, ) - update = AgentRunResponseUpdate(contents=[approval_request]) + update = AgentResponseUpdate(contents=[approval_request]) events = await bridge.from_agent_run_update(update) # Should emit: ToolCallEndEvent, CustomEvent, and confirm_changes (Start, Args, End) = 5 events @@ -130,7 +129,7 @@ async def test_multiple_approval_requests(): function_call=func_call_2, ) - update = AgentRunResponseUpdate(contents=[approval_1, approval_2]) + update = AgentResponseUpdate(contents=[approval_1, approval_2]) events = await bridge.from_agent_run_update(update) # Should emit ToolCallEndEvent + CustomEvent for each approval (4 events total) @@ -175,7 +174,7 @@ async def test_function_approval_request_sets_stop_flag(): function_call=func_call, ) - update = AgentRunResponseUpdate(contents=[approval_request]) + update = AgentResponseUpdate(contents=[approval_request]) await bridge.from_agent_run_update(update) assert bridge.should_stop_after_confirm is True diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index 8c00602538..279ddedc82 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -3,11 +3,20 @@ """Tests for AG-UI orchestrators.""" from collections.abc import AsyncGenerator -from types import SimpleNamespace from typing import Any - -from agent_framework import AgentRunResponseUpdate, TextContent, ai_function -from agent_framework._tools import FunctionInvocationConfiguration +from unittest.mock import MagicMock + +from ag_ui.core import BaseEvent, RunFinishedEvent +from agent_framework import ( + AgentResponseUpdate, + AgentThread, + BaseChatClient, + ChatAgent, + ChatResponseUpdate, + FunctionInvocationConfiguration, + TextContent, + ai_function, +) from agent_framework_ag_ui._agent import AgentConfig from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, ExecutionContext @@ -19,56 +28,77 @@ def server_tool() -> str: return "server" -class DummyAgent: - """Minimal agent stub to capture run_stream parameters.""" - - def __init__(self) -> None: - self.chat_options = SimpleNamespace(tools=[server_tool], response_format=None) - self.tools = [server_tool] - self.chat_client = SimpleNamespace( - function_invocation_configuration=FunctionInvocationConfiguration(), - ) - self.seen_tools: list[Any] | None = None +def _create_mock_chat_agent( + tools: list[Any] | None = None, + response_format: Any = None, + capture_tools: list[Any] | None = None, + capture_messages: list[Any] | None = None, +) -> ChatAgent: + """Create a ChatAgent with mocked chat client for testing. + + Args: + tools: Tools to configure on the agent. + response_format: Response format to configure. + capture_tools: If provided, tools passed to run_stream will be appended here. 
+        capture_messages: If provided, messages passed to run_stream will be appended here.
+    """
+    mock_chat_client = MagicMock(spec=BaseChatClient)
+    mock_chat_client.function_invocation_configuration = FunctionInvocationConfiguration()
+
+    agent = ChatAgent(
+        chat_client=mock_chat_client,
+        tools=tools or [server_tool],
+        response_format=response_format,
+    )
 
-    async def run_stream(
-        self,
+    # Create a mock run_stream that captures parameters and yields a simple response
+    async def mock_run_stream(
         messages: list[Any],
         *,
-        thread: Any,
+        thread: AgentThread,
         tools: list[Any] | None = None,
         **kwargs: Any,
-    ) -> AsyncGenerator[AgentRunResponseUpdate, None]:
-        self.seen_tools = tools
-        yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
-
-
-class RecordingAgent:
-    """Agent stub that captures messages passed to run_stream."""
-
-    def __init__(self) -> None:
-        self.chat_options = SimpleNamespace(tools=[], response_format=None)
-        self.tools: list[Any] = []
-        self.chat_client = SimpleNamespace(
-            function_invocation_configuration=FunctionInvocationConfiguration(),
+    ) -> AsyncGenerator[AgentResponseUpdate, None]:
+        if capture_tools is not None and tools is not None:
+            capture_tools.extend(tools)
+        if capture_messages is not None:
+            capture_messages.extend(messages)
+        yield AgentResponseUpdate(
+            contents=[TextContent(text="ok")],
+            role="assistant",
+            response_id=thread.metadata.get("ag_ui_run_id"),  # type: ignore[attr-defined] (metadata always created in orchestrator)
+            raw_representation=ChatResponseUpdate(
+                contents=[TextContent(text="ok")],
+                conversation_id=thread.metadata.get("ag_ui_thread_id"),  # type: ignore[attr-defined] (metadata always created in orchestrator)
+                response_id=thread.metadata.get("ag_ui_run_id"),  # type: ignore[attr-defined] (metadata always created in orchestrator)
+            ),
         )
-        self.seen_messages: list[Any] | None = None
 
-    async def run_stream(
-        self,
-        messages: list[Any],
-        *,
-        thread: Any,
-        tools: list[Any] | None = None,
-        **kwargs: Any,
-    ) -> AsyncGenerator[AgentRunResponseUpdate, None]:
-        self.seen_messages = messages
-        yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
+    # Patch the run_stream method
+    agent.run_stream = mock_run_stream  # type: ignore[method-assign]
+
+    return agent
 
 
 async def test_default_orchestrator_merges_client_tools() -> None:
     """Client tool declarations are merged with server tools before running agent."""
-
-    agent = DummyAgent()
+    captured_tools: list[Any] = []
+    agent = _create_mock_chat_agent(tools=[server_tool], capture_tools=captured_tools)
     orchestrator = DefaultOrchestrator()
 
     input_data = {
@@ -101,8 +131,8 @@ async def test_default_orchestrator_merges_client_tools() -> None:
     async for event in orchestrator.run(context):
         events.append(event)
 
-    assert
agent.seen_tools is not None - tool_names = [getattr(tool, "name", "?") for tool in agent.seen_tools] + assert len(captured_tools) > 0 + tool_names = [getattr(tool, "name", "?") for tool in captured_tools] assert "server_tool" in tool_names assert "get_weather" in tool_names assert agent.chat_client.function_invocation_configuration.additional_tools @@ -110,8 +140,7 @@ async def test_default_orchestrator_merges_client_tools() -> None: async def test_default_orchestrator_with_camel_case_ids() -> None: """Client tool is able to extract camelCase IDs.""" - - agent = DummyAgent() + agent = _create_mock_chat_agent() orchestrator = DefaultOrchestrator() input_data = { @@ -137,6 +166,7 @@ async def test_default_orchestrator_with_camel_case_ids() -> None: events.append(event) # assert the last event has the expected run_id and thread_id + assert isinstance(events[-1], RunFinishedEvent) last_event = events[-1] assert last_event.run_id == "test-camelcase-runid" assert last_event.thread_id == "test-camelcase-threadid" @@ -144,8 +174,7 @@ async def test_default_orchestrator_with_camel_case_ids() -> None: async def test_default_orchestrator_with_snake_case_ids() -> None: """Client tool is able to extract snake_case IDs.""" - - agent = DummyAgent() + agent = _create_mock_chat_agent() orchestrator = DefaultOrchestrator() input_data = { @@ -166,11 +195,12 @@ async def test_default_orchestrator_with_snake_case_ids() -> None: config=AgentConfig(), ) - events = [] + events: list[BaseEvent] = [] async for event in orchestrator.run(context): events.append(event) # assert the last event has the expected run_id and thread_id + assert isinstance(events[-1], RunFinishedEvent) last_event = events[-1] assert last_event.run_id == "test-snakecase-runid" assert last_event.thread_id == "test-snakecase-threadid" @@ -178,8 +208,8 @@ async def test_default_orchestrator_with_snake_case_ids() -> None: async def test_state_context_injected_when_tool_call_state_mismatch() -> None: """State context should be injected when current state differs from tool call args.""" - - agent = RecordingAgent() + captured_messages: list[Any] = [] + agent = _create_mock_chat_agent(tools=[], capture_messages=captured_messages) orchestrator = DefaultOrchestrator() tool_recipe = {"title": "Salad", "special_preferences": []} @@ -216,9 +246,9 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None: async for _event in orchestrator.run(context): pass - assert agent.seen_messages is not None + assert len(captured_messages) > 0 state_messages = [] - for msg in agent.seen_messages: + for msg in captured_messages: role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) if role_value != "system": continue @@ -231,8 +261,8 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None: async def test_state_context_not_injected_when_tool_call_matches_state() -> None: """State context should be skipped when tool call args match current state.""" - - agent = RecordingAgent() + captured_messages: list[Any] = [] + agent = _create_mock_chat_agent(tools=[], capture_messages=captured_messages) orchestrator = DefaultOrchestrator() input_data = { @@ -265,9 +295,9 @@ async def test_state_context_not_injected_when_tool_call_matches_state() -> None async for _event in orchestrator.run(context): pass - assert agent.seen_messages is not None + assert len(captured_messages) > 0 state_messages = [] - for msg in agent.seen_messages: + for msg in captured_messages: role_value = msg.role.value if 
hasattr(msg.role, "value") else str(msg.role) if role_value != "system": continue diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py index 041e25c3d2..6c311d593a 100644 --- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py +++ b/python/packages/ag-ui/tests/test_orchestrators_coverage.py @@ -9,7 +9,7 @@ from typing import Any from agent_framework import ( - AgentRunResponseUpdate, + AgentResponseUpdate, ChatMessage, TextContent, ai_function, @@ -20,7 +20,7 @@ from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, HumanInTheLoopOrchestrator sys.path.insert(0, str(Path(__file__).parent)) -from test_helpers_ag_ui import StubAgent, TestExecutionContext +from utils_test_ag_ui import StubAgent, TestExecutionContext @ai_function(approval_mode="always_require") @@ -29,7 +29,7 @@ def approval_tool(param: str) -> str: return f"executed: {param}" -DEFAULT_CHAT_OPTIONS = SimpleNamespace(tools=[approval_tool], response_format=None) +DEFAULT_OPTIONS: dict[str, Any] = {"tools": [approval_tool], "response_format": None} async def test_human_in_the_loop_json_decode_error() -> None: @@ -54,8 +54,8 @@ async def test_human_in_the_loop_json_decode_error() -> None: ] agent = StubAgent( - chat_options=SimpleNamespace(tools=[approval_tool], response_format=None), - updates=[AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")], + default_options={"tools": [approval_tool], "response_format": None}, + updates=[AgentResponseUpdate(contents=[TextContent(text="response")], role="assistant")], ) context = TestExecutionContext( input_data=input_data, @@ -106,7 +106,7 @@ async def test_sanitize_tool_history_confirm_changes() -> None: input_data: dict[str, Any] = {"messages": []} agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( input_data=input_data, @@ -151,7 +151,7 @@ async def test_sanitize_tool_history_orphaned_tool_result() -> None: orchestrator = DefaultOrchestrator() input_data: dict[str, Any] = {"messages": []} agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( input_data=input_data, @@ -191,7 +191,7 @@ async def test_orphaned_tool_result_sanitization() -> None: } agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( input_data=input_data, @@ -234,7 +234,7 @@ async def test_deduplicate_messages_empty_tool_results() -> None: orchestrator = DefaultOrchestrator() input_data: dict[str, Any] = {"messages": []} agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( input_data=input_data, @@ -279,7 +279,7 @@ async def test_deduplicate_messages_duplicate_assistant_tool_calls() -> None: orchestrator = DefaultOrchestrator() input_data: dict[str, Any] = {"messages": []} agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( input_data=input_data, @@ -323,7 +323,7 @@ async def test_deduplicate_messages_duplicate_system_messages() -> None: orchestrator = DefaultOrchestrator() input_data: dict[str, Any] = {"messages": []} agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( input_data=input_data, @@ -362,7 +362,7 @@ async def test_state_context_injection() -> None: } 
agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( input_data=input_data, @@ -407,7 +407,7 @@ async def test_state_context_injection_with_tool_calls_and_input_state() -> None orchestrator = DefaultOrchestrator() input_data: dict[str, Any] = {"messages": [], "state": {"weather": "sunny"}} agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( input_data=input_data, @@ -449,15 +449,15 @@ class RecipeState(BaseModel): # Agent with structured output agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, updates=[ - AgentRunResponseUpdate( + AgentResponseUpdate( contents=[TextContent(text='{"ingredients": ["tomato"], "message": "Added tomato"}')], role="assistant", ) ], ) - agent.chat_options.response_format = RecipeState + agent.default_options["response_format"] = RecipeState context = TestExecutionContext( input_data=input_data, @@ -510,9 +510,9 @@ def get_weather(location: str) -> str: } agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) - agent.chat_options.tools = [get_weather] + agent.default_options["tools"] = [get_weather] context = TestExecutionContext( input_data=input_data, @@ -559,9 +559,9 @@ def server_tool() -> str: } agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) - agent.chat_options.tools = [server_tool] + agent.default_options["tools"] = [server_tool] context = TestExecutionContext( input_data=input_data, @@ -587,7 +587,7 @@ async def test_empty_messages_handling() -> None: input_data: dict[str, Any] = {"messages": []} agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( input_data=input_data, @@ -621,7 +621,7 @@ async def test_all_messages_filtered_handling() -> None: } agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( input_data=input_data, @@ -663,7 +663,7 @@ async def test_confirm_changes_with_invalid_json_fallback() -> None: orchestrator = DefaultOrchestrator() input_data: dict[str, Any] = {"messages": []} agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( input_data=input_data, @@ -691,7 +691,7 @@ async def test_confirm_changes_closes_active_message_before_finish() -> None: from agent_framework import FunctionCallContent, FunctionResultContent updates = [ - AgentRunResponseUpdate( + AgentResponseUpdate( contents=[ FunctionCallContent( name="write_document_local", @@ -700,13 +700,13 @@ async def test_confirm_changes_closes_active_message_before_finish() -> None: ) ] ), - AgentRunResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]), + AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]), ] orchestrator = DefaultOrchestrator() input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Start"}]} agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, updates=updates, ) context = TestExecutionContext( @@ -751,7 +751,7 @@ async def test_tool_result_kept_when_call_id_matches() -> None: orchestrator = DefaultOrchestrator() input_data: dict[str, Any] = {"messages": []} agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( 
input_data=input_data, @@ -781,7 +781,7 @@ class CustomAgent: """Custom agent without ChatAgent type.""" def __init__(self) -> None: - self.chat_options = SimpleNamespace(tools=[], response_format=None) + self.default_options: dict[str, Any] = {"tools": [], "response_format": None} self.chat_client = SimpleNamespace(function_invocation_configuration=SimpleNamespace()) self.messages_received: list[Any] = [] @@ -792,9 +792,9 @@ async def run_stream( thread: Any = None, tools: list[Any] | None = None, **kwargs: Any, - ) -> AsyncGenerator[AgentRunResponseUpdate, None]: + ) -> AsyncGenerator[AgentResponseUpdate, None]: self.messages_received = messages - yield AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant") + yield AgentResponseUpdate(contents=[TextContent(text="response")], role="assistant") from agent_framework import ChatMessage, TextContent @@ -827,7 +827,7 @@ async def test_initial_state_snapshot_with_array_schema() -> None: orchestrator = DefaultOrchestrator() input_data: dict[str, Any] = {"messages": [], "state": {}} agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) context = TestExecutionContext( input_data=input_data, @@ -859,9 +859,9 @@ class OutputModel(BaseModel): input_data: dict[str, Any] = {"messages": []} agent = StubAgent( - chat_options=DEFAULT_CHAT_OPTIONS, + default_options=DEFAULT_OPTIONS, ) - agent.chat_options.response_format = OutputModel + agent.default_options["response_format"] = OutputModel context = TestExecutionContext( input_data=input_data, diff --git a/python/packages/ag-ui/tests/test_service_thread_id.py b/python/packages/ag-ui/tests/test_service_thread_id.py new file mode 100644 index 0000000000..8c00f7b67c --- /dev/null +++ b/python/packages/ag-ui/tests/test_service_thread_id.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft. All rights reserved. 
+
+"""Tests for service-managed thread IDs and service-generated response IDs."""
+
+import sys
+from pathlib import Path
+from typing import Any
+
+from ag_ui.core import RunFinishedEvent, RunStartedEvent
+from agent_framework import TextContent
+from agent_framework._types import AgentResponseUpdate, ChatResponseUpdate
+
+sys.path.insert(0, str(Path(__file__).parent))
+from utils_test_ag_ui import StubAgent
+
+
+async def test_service_thread_id_when_there_are_updates():
+    """Test that service-managed thread IDs (conversation_id) are correctly set as the thread_id in events."""
+    from agent_framework.ag_ui import AgentFrameworkAgent
+
+    updates: list[AgentResponseUpdate] = [
+        AgentResponseUpdate(
+            contents=[TextContent(text="Hello, user!")],
+            response_id="resp_67890",
+            raw_representation=ChatResponseUpdate(
+                contents=[TextContent(text="Hello, user!")],
+                conversation_id="conv_12345",
+                response_id="resp_67890",
+            ),
+        )
+    ]
+    agent = StubAgent(updates=updates)
+    wrapper = AgentFrameworkAgent(agent=agent)
+
+    input_data = {
+        "messages": [{"role": "user", "content": "Hi"}],
+    }
+
+    events: list[Any] = []
+    async for event in wrapper.run_agent(input_data):
+        events.append(event)
+
+    assert isinstance(events[0], RunStartedEvent)
+    assert events[0].run_id == "resp_67890"
+    assert events[0].thread_id == "conv_12345"
+    assert isinstance(events[-1], RunFinishedEvent)
+
+
+async def test_service_thread_id_when_no_user_message():
+    """Test that when the user submits no messages, emitted events still have a thread_id."""
+    from agent_framework.ag_ui import AgentFrameworkAgent
+
+    updates: list[AgentResponseUpdate] = []
+    agent = StubAgent(updates=updates)
+    wrapper = AgentFrameworkAgent(agent=agent)
+
+    input_data: dict[str, list[dict[str, str]]] = {
+        "messages": [],
+    }
+
+    events: list[Any] = []
+    async for event in wrapper.run_agent(input_data):
+        events.append(event)
+
+    assert len(events) == 2
+    assert isinstance(events[0], RunStartedEvent)
+    assert events[0].thread_id
+    assert isinstance(events[-1], RunFinishedEvent)
+
+
+async def test_service_thread_id_when_user_supplied_thread_id():
+    """Test that user-supplied thread IDs are preserved in emitted events."""
+    from agent_framework.ag_ui import AgentFrameworkAgent
+
+    updates: list[AgentResponseUpdate] = []
+    agent = StubAgent(updates=updates)
+    wrapper = AgentFrameworkAgent(agent=agent)
+
+    input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Hi"}], "threadId": "conv_12345"}
+
+    events: list[Any] = []
+    async for event in wrapper.run_agent(input_data):
+        events.append(event)
+
+    assert isinstance(events[0], RunStartedEvent)
+    assert events[0].thread_id == "conv_12345"
+    assert isinstance(events[-1], RunFinishedEvent)
diff --git a/python/packages/ag-ui/tests/test_shared_state.py b/python/packages/ag-ui/tests/test_shared_state.py
index 36f80b9d47..469f5f5ad8 100644
--- a/python/packages/ag-ui/tests/test_shared_state.py
+++ b/python/packages/ag-ui/tests/test_shared_state.py
@@ -8,14 +8,13 @@
 import pytest
 from ag_ui.core import StateSnapshotEvent
 
-from agent_framework import ChatAgent, TextContent
-from agent_framework._types import ChatResponseUpdate
+from agent_framework import ChatAgent, ChatResponseUpdate, TextContent
 from agent_framework_ag_ui._agent import AgentFrameworkAgent
 from agent_framework_ag_ui._events import AgentFrameworkEventBridge
 
 sys.path.insert(0, str(Path(__file__).parent))
-from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates
+from utils_test_ag_ui import
StreamingChatClientStub, stream_from_updates @pytest.fixture diff --git a/python/packages/ag-ui/tests/test_structured_output.py b/python/packages/ag-ui/tests/test_structured_output.py index c5f9719938..b9a04353be 100644 --- a/python/packages/ag-ui/tests/test_structured_output.py +++ b/python/packages/ag-ui/tests/test_structured_output.py @@ -8,12 +8,11 @@ from pathlib import Path from typing import Any -from agent_framework import ChatAgent, ChatMessage, ChatOptions, TextContent -from agent_framework._types import ChatResponseUpdate +from agent_framework import ChatAgent, ChatMessage, ChatOptions, ChatResponseUpdate, TextContent from pydantic import BaseModel sys.path.insert(0, str(Path(__file__).parent)) -from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates +from utils_test_ag_ui import StreamingChatClientStub, stream_from_updates class RecipeOutput(BaseModel): @@ -41,14 +40,14 @@ async def test_structured_output_with_recipe(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate( contents=[TextContent(text='{"recipe": {"name": "Pasta"}, "message": "Here is your recipe"}')] ) agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) - agent.chat_options = ChatOptions(response_format=RecipeOutput) + agent.default_options = ChatOptions(response_format=RecipeOutput) wrapper = AgentFrameworkAgent( agent=agent, @@ -79,7 +78,7 @@ async def test_structured_output_with_steps(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: steps_data = { "steps": [ @@ -90,7 +89,7 @@ async def stream_fn( yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(steps_data))]) agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) - agent.chat_options = ChatOptions(response_format=StepsOutput) + agent.default_options = ChatOptions(response_format=StepsOutput) wrapper = AgentFrameworkAgent( agent=agent, @@ -125,7 +124,7 @@ async def test_structured_output_with_no_schema_match(): agent = ChatAgent( name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_from_updates(updates)) ) - agent.chat_options = ChatOptions(response_format=GenericOutput) + agent.default_options = ChatOptions(response_format=GenericOutput) wrapper = AgentFrameworkAgent( agent=agent, @@ -155,12 +154,12 @@ class DataOutput(BaseModel): info: str async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}, "info": "processed"}')]) agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) - agent.chat_options = ChatOptions(response_format=DataOutput) + agent.default_options = ChatOptions(response_format=DataOutput) wrapper = AgentFrameworkAgent( agent=agent, @@ -214,13 +213,13 @@ async def test_structured_output_with_message_field(): from agent_framework.ag_ui import 
AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: output_data = {"recipe": {"name": "Salad"}, "message": "Fresh salad recipe ready"} yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(output_data))]) agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) - agent.chat_options = ChatOptions(response_format=RecipeOutput) + agent.default_options = ChatOptions(response_format=RecipeOutput) wrapper = AgentFrameworkAgent( agent=agent, @@ -249,13 +248,13 @@ async def test_empty_updates_no_structured_processing(): from agent_framework.ag_ui import AgentFrameworkAgent async def stream_fn( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: if False: yield ChatResponseUpdate(contents=[]) agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) - agent.chat_options = ChatOptions(response_format=RecipeOutput) + agent.default_options = ChatOptions(response_format=RecipeOutput) wrapper = AgentFrameworkAgent(agent=agent) diff --git a/python/packages/ag-ui/tests/test_tooling.py b/python/packages/ag-ui/tests/test_tooling.py index b802d654c6..23d82dda90 100644 --- a/python/packages/ag-ui/tests/test_tooling.py +++ b/python/packages/ag-ui/tests/test_tooling.py @@ -1,8 +1,14 @@ # Copyright (c) Microsoft. All rights reserved. -from types import SimpleNamespace +from unittest.mock import MagicMock -from agent_framework_ag_ui._orchestration._tooling import merge_tools, register_additional_client_tools +from agent_framework import ChatAgent, ai_function + +from agent_framework_ag_ui._orchestration._tooling import ( + collect_server_tools, + merge_tools, + register_additional_client_tools, +) class DummyTool: @@ -11,6 +17,30 @@ def __init__(self, name: str) -> None: self.declaration_only = True +class MockMCPTool: + """Mock MCP tool that simulates connected MCP tool with functions.""" + + def __init__(self, functions: list[DummyTool], is_connected: bool = True) -> None: + self.functions = functions + self.is_connected = is_connected + + +@ai_function +def regular_tool() -> str: + """Regular tool for testing.""" + return "result" + + +def _create_chat_agent_with_tool(tool_name: str = "regular_tool") -> ChatAgent: + """Create a ChatAgent with a mocked chat client and a simple tool. + + Note: tool_name parameter is kept for API compatibility but the tool + will always be named 'regular_tool' since ai_function uses the function name. 
+ """ + mock_chat_client = MagicMock() + return ChatAgent(chat_client=mock_chat_client, tools=[regular_tool]) + + def test_merge_tools_filters_duplicates() -> None: server = [DummyTool("a"), DummyTool("b")] client = [DummyTool("b"), DummyTool("c")] @@ -23,14 +53,79 @@ def test_merge_tools_filters_duplicates() -> None: def test_register_additional_client_tools_assigns_when_configured() -> None: - class Fic: - def __init__(self) -> None: - self.additional_tools = None + """register_additional_client_tools should set additional_tools on the chat client.""" + from agent_framework import BaseChatClient, FunctionInvocationConfiguration + + mock_chat_client = MagicMock(spec=BaseChatClient) + mock_chat_client.function_invocation_configuration = FunctionInvocationConfiguration() - holder = SimpleNamespace(function_invocation_configuration=Fic()) - agent = SimpleNamespace(chat_client=holder) + agent = ChatAgent(chat_client=mock_chat_client) tools = [DummyTool("x")] register_additional_client_tools(agent, tools) - assert holder.function_invocation_configuration.additional_tools == tools + assert mock_chat_client.function_invocation_configuration.additional_tools == tools + + +def test_collect_server_tools_includes_mcp_tools_when_connected() -> None: + """MCP tool functions should be included when the MCP tool is connected.""" + mcp_function1 = DummyTool("mcp_function_1") + mcp_function2 = DummyTool("mcp_function_2") + mock_mcp = MockMCPTool([mcp_function1, mcp_function2], is_connected=True) + + agent = _create_chat_agent_with_tool("regular_tool") + agent.mcp_tools = [mock_mcp] + + tools = collect_server_tools(agent) + + names = [getattr(t, "name", None) for t in tools] + assert "regular_tool" in names + assert "mcp_function_1" in names + assert "mcp_function_2" in names + assert len(tools) == 3 + + +def test_collect_server_tools_excludes_mcp_tools_when_not_connected() -> None: + """MCP tool functions should be excluded when the MCP tool is not connected.""" + mcp_function = DummyTool("mcp_function") + mock_mcp = MockMCPTool([mcp_function], is_connected=False) + + agent = _create_chat_agent_with_tool("regular_tool") + agent.mcp_tools = [mock_mcp] + + tools = collect_server_tools(agent) + + names = [getattr(t, "name", None) for t in tools] + assert "regular_tool" in names + assert "mcp_function" not in names + assert len(tools) == 1 + + +def test_collect_server_tools_works_with_no_mcp_tools() -> None: + """collect_server_tools should work when there are no MCP tools.""" + agent = _create_chat_agent_with_tool("regular_tool") + + tools = collect_server_tools(agent) + + names = [getattr(t, "name", None) for t in tools] + assert "regular_tool" in names + assert len(tools) == 1 + + +def test_collect_server_tools_with_mcp_tools_via_public_property() -> None: + """collect_server_tools should access MCP tools via the public mcp_tools property.""" + mcp_function = DummyTool("mcp_function") + mock_mcp = MockMCPTool([mcp_function], is_connected=True) + + agent = _create_chat_agent_with_tool("regular_tool") + agent.mcp_tools = [mock_mcp] + + # Verify the public property works + assert agent.mcp_tools == [mock_mcp] + + tools = collect_server_tools(agent) + + names = [getattr(t, "name", None) for t in tools] + assert "regular_tool" in names + assert "mcp_function" in names + assert len(tools) == 2 diff --git a/python/packages/ag-ui/tests/test_helpers_ag_ui.py b/python/packages/ag-ui/tests/utils_test_ag_ui.py similarity index 65% rename from python/packages/ag-ui/tests/test_helpers_ag_ui.py rename to 
python/packages/ag-ui/tests/utils_test_ag_ui.py index fc82b11510..c3fa590cd1 100644 --- a/python/packages/ag-ui/tests/test_helpers_ag_ui.py +++ b/python/packages/ag-ui/tests/utils_test_ag_ui.py @@ -2,30 +2,37 @@ """Shared test stubs for AG-UI tests.""" +import sys from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, MutableSequence from types import SimpleNamespace -from typing import Any +from typing import Any, Generic from agent_framework import ( AgentProtocol, - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, + BaseChatClient, ChatMessage, - ChatOptions, + ChatResponse, + ChatResponseUpdate, TextContent, ) -from agent_framework._clients import BaseChatClient -from agent_framework._types import ChatResponse, ChatResponseUpdate +from agent_framework._clients import TOptions_co from agent_framework_ag_ui._message_adapters import _deduplicate_messages, _sanitize_tool_history from agent_framework_ag_ui._orchestrators import ExecutionContext +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + StreamFn = Callable[..., AsyncIterator[ChatResponseUpdate]] ResponseFn = Callable[..., Awaitable[ChatResponse]] -class StreamingChatClientStub(BaseChatClient): +class StreamingChatClientStub(BaseChatClient[TOptions_co], Generic[TOptions_co]): """Typed streaming stub that satisfies ChatClientProtocol.""" def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) -> None: @@ -33,20 +40,22 @@ def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) - self._stream_fn = stream_fn self._response_fn = response_fn + @override async def _inner_get_streaming_response( - self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + self, *, messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - async for update in self._stream_fn(messages, chat_options, **kwargs): + async for update in self._stream_fn(messages, options, **kwargs): yield update + @override async def _inner_get_response( - self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + self, *, messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> ChatResponse: if self._response_fn is not None: - return await self._response_fn(messages, chat_options, **kwargs) + return await self._response_fn(messages, options, **kwargs) contents: list[Any] = [] - async for update in self._stream_fn(messages, chat_options, **kwargs): + async for update in self._stream_fn(messages, options, **kwargs): contents.extend(update.contents) return ChatResponse( @@ -59,7 +68,7 @@ def stream_from_updates(updates: list[ChatResponseUpdate]) -> StreamFn: """Create a stream function that yields from a static list of updates.""" async def _stream( - messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: for update in updates: yield update @@ -72,46 +81,32 @@ class StubAgent(AgentProtocol): def __init__( self, - updates: list[AgentRunResponseUpdate] | None = None, + updates: list[AgentResponseUpdate] | None = None, *, agent_id: str = "stub-agent", agent_name: str | None = "stub-agent", - chat_options: Any | None = None, + 
default_options: Any | None = None, chat_client: Any | None = None, ) -> None: - self._id = agent_id - self._name = agent_name - self._description = "stub agent" - self.updates = updates or [AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")] - self.chat_options = chat_options or SimpleNamespace(tools=None, response_format=None) + self.id = agent_id + self.name = agent_name + self.description = "stub agent" + self.updates = updates or [AgentResponseUpdate(contents=[TextContent(text="response")], role="assistant")] + self.default_options: dict[str, Any] = ( + default_options if isinstance(default_options, dict) else {"tools": None, "response_format": None} + ) self.chat_client = chat_client or SimpleNamespace(function_invocation_configuration=None) self.messages_received: list[Any] = [] self.tools_received: list[Any] | None = None - @property - def id(self) -> str: - return self._id - - @property - def name(self) -> str | None: - return self._name - - @property - def display_name(self) -> str: - return self._name or self._id - - @property - def description(self) -> str | None: - return self._description - async def run( self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: - return AgentRunResponse(messages=[], response_id="stub-response") + ) -> AgentResponse: + return AgentResponse(messages=[], response_id="stub-response") def run_stream( self, @@ -119,8 +114,8 @@ def run_stream( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: - async def _stream() -> AsyncIterator[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: + async def _stream() -> AsyncIterator[AgentResponseUpdate]: self.messages_received = [] if messages is None else list(messages) # type: ignore[arg-type] self.tools_received = kwargs.get("tools") for update in self.updates: diff --git a/python/packages/anthropic/agent_framework_anthropic/__init__.py b/python/packages/anthropic/agent_framework_anthropic/__init__.py index e81064b213..706740a127 100644 --- a/python/packages/anthropic/agent_framework_anthropic/__init__.py +++ b/python/packages/anthropic/agent_framework_anthropic/__init__.py @@ -2,7 +2,7 @@ import importlib.metadata -from ._chat_client import AnthropicClient +from ._chat_client import AnthropicChatOptions, AnthropicClient try: __version__ = importlib.metadata.version(__name__) @@ -10,6 +10,7 @@ __version__ = "0.0.0" # Fallback for development mode __all__ = [ + "AnthropicChatOptions", "AnthropicClient", "__version__", ] diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index b29b13fbd3..c9223e614b 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. 
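The `_chat_client.py` changes below move Anthropic-specific configuration from `ChatOptions` objects to a TypedDict-backed options dict. A minimal calling sketch, assuming the `AnthropicClient` constructor and the `get_response(..., options=...)` signature shown in this patch's docstrings (the model ID and option values are illustrative):

```python
import asyncio

from agent_framework_anthropic import AnthropicClient


async def main() -> None:
    # Options are passed as a plain dict; keys follow the AnthropicChatOptions
    # TypedDict defined below. max_tokens is required by Anthropic and falls
    # back to a default of 1024 when omitted.
    client = AnthropicClient(model_id="claude-sonnet-4-5-20250929")
    response = await client.get_response(
        "Hello",
        options={"max_tokens": 1024, "temperature": 0.5},
    )
    print(response)


asyncio.run(main())
```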
+ +import sys from collections.abc import AsyncIterable, MutableMapping, MutableSequence, Sequence -from typing import Any, ClassVar, Final, TypeVar +from typing import Any, ClassVar, Final, Generic, Literal, TypedDict from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, @@ -49,15 +51,132 @@ BetaTextBlock, BetaUsage, ) -from anthropic.types.beta.beta_bash_code_execution_tool_result_error import BetaBashCodeExecutionToolResultError -from anthropic.types.beta.beta_code_execution_tool_result_error import BetaCodeExecutionToolResultError +from anthropic.types.beta.beta_bash_code_execution_tool_result_error import ( + BetaBashCodeExecutionToolResultError, +) +from anthropic.types.beta.beta_code_execution_tool_result_error import ( + BetaCodeExecutionToolResultError, +) from pydantic import SecretStr, ValidationError +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + +__all__ = [ + "AnthropicChatOptions", + "AnthropicClient", + "ThinkingConfig", +] + logger = get_logger("agent_framework.anthropic") ANTHROPIC_DEFAULT_MAX_TOKENS: Final[int] = 1024 BETA_FLAGS: Final[list[str]] = ["mcp-client-2025-04-04", "code-execution-2025-08-25"] + +# region Anthropic Chat Options TypedDict + + +class ThinkingConfig(TypedDict, total=False): + """Configuration for enabling Claude's extended thinking. + + When enabled, responses include ``thinking`` content blocks showing Claude's + thinking process before the final answer. Requires a minimum budget of 1,024 + tokens and counts towards your ``max_tokens`` limit. + + See https://docs.claude.com/en/docs/build-with-claude/extended-thinking for details. + + Keys: + type: "enabled" to enable extended thinking, "disabled" to disable. + budget_tokens: The token budget for thinking (minimum 1024, required when type="enabled"). + """ + + type: Literal["enabled", "disabled"] + budget_tokens: int + + +class AnthropicChatOptions(ChatOptions, total=False): + """Anthropic-specific chat options. + + Extends ChatOptions with options specific to Anthropic's Messages API. + Options that Anthropic doesn't support are typed as None to indicate they're unavailable. + + Note: + Anthropic REQUIRES max_tokens to be specified. If not provided, + a default of 1024 will be used. + + Keys: + model_id: The model to use for the request, + translates to ``model`` in Anthropic API. + temperature: Sampling temperature between 0 and 1. + top_p: Nucleus sampling parameter. + max_tokens: Maximum number of tokens to generate (REQUIRED). + stop: Stop sequences, + translates to ``stop_sequences`` in Anthropic API. + tools: List of tools (functions) available to the model. + tool_choice: How the model should use tools. + response_format: Structured output schema. + metadata: Request metadata with user_id for tracking. + user: User identifier, translates to ``metadata.user_id`` in Anthropic API. + instructions: System instructions for the model, + translates to ``system`` in Anthropic API. + top_k: Number of top tokens to consider for sampling. + service_tier: Service tier ("auto" or "standard_only"). + thinking: Extended thinking configuration for Claude models. + When enabled, responses include ``thinking`` content blocks showing Claude's + thinking process before the final answer. 
Requires a minimum budget of 1,024 + tokens and counts towards your ``max_tokens`` limit. + See https://docs.claude.com/en/docs/build-with-claude/extended-thinking for details. + container: Container configuration for skills. + additional_beta_flags: Additional beta flags to enable on the request. + """ + + # Anthropic-specific generation parameters (supported by all models) + top_k: int + service_tier: Literal["auto", "standard_only"] + + # Extended thinking (Claude models) + thinking: ThinkingConfig + + # Skills + container: dict[str, Any] + + # Beta features + additional_beta_flags: list[str] + + # Unsupported base options (override with None to indicate not supported) + logit_bias: None # type: ignore[misc] + seed: None # type: ignore[misc] + frequency_penalty: None # type: ignore[misc] + presence_penalty: None # type: ignore[misc] + store: None # type: ignore[misc] + + +TAnthropicOptions = TypeVar( + "TAnthropicOptions", + bound=TypedDict, # type: ignore[valid-type] + default="AnthropicChatOptions", + covariant=True, +) + +# Translation between framework options keys and Anthropic Messages API +OPTION_TRANSLATIONS: dict[str, str] = { + "model_id": "model", + "stop": "stop_sequences", + "instructions": "system", +} + + +# region Role and Finish Reason Maps + + ROLE_MAP: dict[Role, str] = { Role.USER: "user", Role.ASSISTANT: "assistant", @@ -111,13 +230,10 @@ class AnthropicSettings(AFBaseSettings): chat_model_id: str | None = None -TAnthropicClient = TypeVar("TAnthropicClient", bound="AnthropicClient") - - @use_function_invocation @use_instrumentation @use_chat_middleware -class AnthropicClient(BaseChatClient): +class AnthropicClient(BaseChatClient[TAnthropicOptions], Generic[TAnthropicOptions]): """Anthropic Chat client.""" OTEL_PROVIDER_NAME: ClassVar[str] = "anthropic" # type: ignore[reportIncompatibleVariableOverride, misc] @@ -177,6 +293,18 @@ def __init__( anthropic_client=anthropic_client, ) + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework.anthropic import AnthropicChatOptions + + + class MyOptions(AnthropicChatOptions, total=False): + my_custom_option: str + + + client: AnthropicClient[MyOptions] = AnthropicClient(model_id="claude-sonnet-4-5-20250929") + response = await client.get_response("Hello", options={"my_custom_option": "value"}) + """ try: anthropic_settings = AnthropicSettings( @@ -212,29 +340,31 @@ def __init__( # region Get response methods + @override async def _inner_get_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> ChatResponse: # prepare - run_options = self._prepare_options(messages, chat_options, **kwargs) + run_options = self._prepare_options(messages, options, **kwargs) # execute message = await self.anthropic_client.beta.messages.create(**run_options, stream=False) # process return self._process_message(message) + @override async def _inner_get_streaming_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: # prepare - run_options = self._prepare_options(messages, chat_options, **kwargs) + run_options = self._prepare_options(messages, options, **kwargs) # execute and process async for chunk in await self.anthropic_client.beta.messages.create(**run_options, stream=True): parsed_chunk = self._process_stream_event(chunk) @@ -246,35 +376,31 @@ async def _inner_get_streaming_response( def _prepare_options( 
self, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> dict[str, Any]: - """Create run options for the Anthropic client based on messages and chat options. + """Create run options for the Anthropic client based on messages and options. Args: messages: The list of chat messages. - chat_options: The chat options. + options: The options dict. kwargs: Additional keyword arguments. Returns: A dictionary of run options for the Anthropic client. """ - run_options: dict[str, Any] = chat_options.to_dict( - exclude={ - "type", - "instructions", # handled via system message - "tool_choice", # handled separately - "allow_multiple_tool_calls", # handled via tool_choice - "additional_properties", # handled separately - } - ) + # Prepend instructions from options if they exist + instructions = options.get("instructions") + if instructions: + from agent_framework._types import prepend_instructions_to_messages - # translations between ChatOptions and Anthropic API - translations = { - "model_id": "model", - "stop": "stop_sequences", - } - for old_key, new_key in translations.items(): + messages = prepend_instructions_to_messages(list(messages), instructions, role="system") + + # Start with a copy of options + run_options: dict[str, Any] = {k: v for k, v in options.items() if v is not None and k not in {"instructions"}} + + # Translation between options keys and Anthropic Messages API + for old_key, new_key in OPTION_TRANSLATIONS.items(): if old_key in run_options and old_key != new_key: run_options[new_key] = run_options.pop(old_key) @@ -296,31 +422,30 @@ def _prepare_options( run_options["system"] = messages[0].text # betas - run_options["betas"] = self._prepare_betas(chat_options) + run_options["betas"] = self._prepare_betas(options) # extra headers run_options["extra_headers"] = {"User-Agent": AGENT_FRAMEWORK_USER_AGENT} + # Handle user option -> metadata.user_id (Anthropic uses metadata.user_id instead of user) + if user := run_options.pop("user", None): + metadata = run_options.get("metadata", {}) + if "user_id" not in metadata: + metadata["user_id"] = user + run_options["metadata"] = metadata + # tools, mcp servers and tool choice - if tools_config := self._prepare_tools_for_anthropic(chat_options): + if tools_config := self._prepare_tools_for_anthropic(options): run_options.update(tools_config) - # additional properties - additional_options = { - key: value - for key, value in chat_options.additional_properties.items() - if value is not None and key != "additional_beta_flags" - } - if additional_options: - run_options.update(additional_options) run_options.update(kwargs) return run_options - def _prepare_betas(self, chat_options: ChatOptions) -> set[str]: + def _prepare_betas(self, options: dict[str, Any]) -> set[str]: """Prepare the beta flags for the Anthropic API request. Args: - chat_options: The chat options that may contain additional beta flags. + options: The options dict that may contain additional beta flags. Returns: A set of beta flag strings to include in the request. 
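# For illustration, a minimal standalone sketch of the OPTION_TRANSLATIONS rewrite
# that _prepare_options applies above; `opts` is a hypothetical options dict, not an
# object from this module. (In _prepare_options itself, `instructions` is pulled out
# first and prepended as a system message, so it never reaches this loop.)
opts = {"model_id": "claude-sonnet-4-5", "stop": ["END"], "max_tokens": 1024}
for old_key, new_key in {"model_id": "model", "stop": "stop_sequences", "instructions": "system"}.items():
    if old_key in opts and old_key != new_key:
        opts[new_key] = opts.pop(old_key)
assert opts == {"max_tokens": 1024, "model": "claude-sonnet-4-5", "stop_sequences": ["END"]}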
@@ -328,7 +453,7 @@ def _prepare_betas(self, chat_options: ChatOptions) -> set[str]: return { *BETA_FLAGS, *self.additional_beta_flags, - *chat_options.additional_properties.get("additional_beta_flags", []), + *options.get("additional_beta_flags", []), } def _prepare_messages_for_anthropic(self, messages: MutableSequence[ChatMessage]) -> list[dict[str, Any]]: @@ -370,7 +495,10 @@ def _prepare_message_for_anthropic(self, message: ChatMessage) -> dict[str, Any] logger.debug(f"Ignoring unsupported data content media type: {content.media_type} for now") case "uri": if content.has_top_level_media_type("image"): - a_content.append({"type": "image", "source": {"type": "url", "url": content.uri}}) + a_content.append({ + "type": "image", + "source": {"type": "url", "url": content.uri}, + }) else: logger.debug(f"Ignoring unsupported data content media type: {content.media_type} for now") case "function_call": @@ -397,22 +525,25 @@ def _prepare_message_for_anthropic(self, message: ChatMessage) -> dict[str, Any] "content": a_content, } - def _prepare_tools_for_anthropic(self, chat_options: ChatOptions) -> dict[str, Any] | None: + def _prepare_tools_for_anthropic(self, options: dict[str, Any]) -> dict[str, Any] | None: """Prepare tools and tool choice configuration for the Anthropic API request. Args: - chat_options: The chat options containing tools and tool choice settings. + options: The options dict containing tools and tool choice settings. Returns: A dictionary with tools, mcp_servers, and tool_choice configuration, or None if empty. """ + from agent_framework._types import validate_tool_mode + result: dict[str, Any] = {} + tools = options.get("tools") # Process tools - if chat_options.tools: + if tools: tool_list: list[MutableMapping[str, Any]] = [] mcp_server_list: list[MutableMapping[str, Any]] = [] - for tool in chat_options.tools: + for tool in tools: match tool: case MutableMapping(): tool_list.append(tool) @@ -457,34 +588,31 @@ def _prepare_tools_for_anthropic(self, chat_options: ChatOptions) -> dict[str, A result["mcp_servers"] = mcp_server_list # Process tool choice - if chat_options.tool_choice is not None: - tool_choice_mode = ( - chat_options.tool_choice if isinstance(chat_options.tool_choice, str) else chat_options.tool_choice.mode - ) - match tool_choice_mode: - case "auto": - tool_choice: dict[str, Any] = {"type": "auto"} - if chat_options.allow_multiple_tool_calls is not None: - tool_choice["disable_parallel_tool_use"] = not chat_options.allow_multiple_tool_calls - result["tool_choice"] = tool_choice - case "required": - if ( - not isinstance(chat_options.tool_choice, str) - and chat_options.tool_choice.required_function_name - ): - tool_choice = { - "type": "tool", - "name": chat_options.tool_choice.required_function_name, - } - else: - tool_choice = {"type": "any"} - if chat_options.allow_multiple_tool_calls is not None: - tool_choice["disable_parallel_tool_use"] = not chat_options.allow_multiple_tool_calls - result["tool_choice"] = tool_choice - case "none": - result["tool_choice"] = {"type": "none"} - case _: - logger.debug(f"Ignoring unsupported tool choice mode: {tool_choice_mode} for now") + if options.get("tool_choice") is None: + return result or None + tool_mode = validate_tool_mode(options.get("tool_choice")) + allow_multiple = options.get("allow_multiple_tool_calls") + match tool_mode.get("mode"): + case "auto": + tool_choice: dict[str, Any] = {"type": "auto"} + if allow_multiple is not None: + tool_choice["disable_parallel_tool_use"] = not allow_multiple + 
result["tool_choice"] = tool_choice + case "required": + if "required_function_name" in tool_mode: + tool_choice = { + "type": "tool", + "name": tool_mode["required_function_name"], + } + else: + tool_choice = {"type": "any"} + if allow_multiple is not None: + tool_choice["disable_parallel_tool_use"] = not allow_multiple + result["tool_choice"] = tool_choice + case "none": + result["tool_choice"] = {"type": "none"} + case _: + logger.debug(f"Ignoring unsupported tool choice mode: {tool_mode} for now") return result or None @@ -531,7 +659,10 @@ def _process_stream_event(self, event: BetaRawMessageStreamEvent) -> ChatRespons return ChatResponseUpdate( response_id=event.message.id, - contents=[*self._parse_contents_from_anthropic(event.message.content), *usage_details], + contents=[ + *self._parse_contents_from_anthropic(event.message.content), + *usage_details, + ], model_id=event.message.model, finish_reason=FINISH_REASON_MAP.get(event.message.stop_reason) if event.message.stop_reason @@ -579,7 +710,8 @@ def _parse_usage_from_anthropic(self, usage: BetaUsage | BetaMessageDeltaUsage | return usage_details def _parse_contents_from_anthropic( - self, content: Sequence[BetaContentBlock | BetaRawContentBlockDelta | BetaTextBlock] + self, + content: Sequence[BetaContentBlock | BetaRawContentBlockDelta | BetaTextBlock], ) -> list[Contents]: """Parse contents from the Anthropic message.""" contents: list[Contents] = [] @@ -609,7 +741,12 @@ def _parse_contents_from_anthropic( contents.append( CodeInterpreterToolCallContent( call_id=content_block.id, - inputs=[TextContent(text=str(content_block.input), raw_representation=content_block)], + inputs=[ + TextContent( + text=str(content_block.input), + raw_representation=content_block, + ) + ], raw_representation=content_block, ) ) @@ -630,7 +767,10 @@ def _parse_contents_from_anthropic( parsed_output = self._parse_contents_from_anthropic(content_block.content) elif isinstance(content_block.content, (str, bytes)): parsed_output = [ - TextContent(text=str(content_block.content), raw_representation=content_block) + TextContent( + text=str(content_block.content), + raw_representation=content_block, + ) ] else: parsed_output = self._parse_contents_from_anthropic([content_block.content]) @@ -679,7 +819,8 @@ def _parse_contents_from_anthropic( for code_file_content in content_block.content.content: code_outputs.append( HostedFileContent( - file_id=code_file_content.file_id, raw_representation=code_file_content + file_id=code_file_content.file_id, + raw_representation=code_file_content, ) ) contents.append( @@ -720,7 +861,8 @@ def _parse_contents_from_anthropic( for bash_file_content in content_block.content.content: contents.append( HostedFileContent( - file_id=bash_file_content.file_id, raw_representation=bash_file_content + file_id=bash_file_content.file_id, + raw_representation=bash_file_content, ) ) contents.append( @@ -832,17 +974,27 @@ def _parse_contents_from_anthropic( ) ) case "input_json_delta": - call_id, name = self._last_call_id_name if self._last_call_id_name else ("", "") + # For streaming argument deltas, only pass call_id and arguments. + # Pass empty string for name - it causes ag-ui to emit duplicate ToolCallStartEvents + # since it triggers on `if content.name:`. The initial tool_use event already + # provides the name, so deltas should only carry incremental arguments. + # This matches OpenAI's behavior where streaming chunks have name="". 
+ call_id, _ = self._last_call_id_name if self._last_call_id_name else ("", "") contents.append( FunctionCallContent( call_id=call_id, - name=name, + name="", arguments=content_block.partial_json, raw_representation=content_block, ) ) case "thinking" | "thinking_delta": - contents.append(TextReasoningContent(text=content_block.thinking, raw_representation=content_block)) + contents.append( + TextReasoningContent( + text=content_block.thinking, + raw_representation=content_block, + ) + ) case _: logger.debug(f"Ignoring unsupported content type: {content_block.type} for now") return contents @@ -865,7 +1017,10 @@ def _parse_citations_from_anthropic( if not cit.annotated_regions: cit.annotated_regions = [] cit.annotated_regions.append( - TextSpanRegion(start_index=citation.start_char_index, end_index=citation.end_char_index) + TextSpanRegion( + start_index=citation.start_char_index, + end_index=citation.end_char_index, + ) ) case "page_location": cit.title = citation.document_title @@ -888,7 +1043,10 @@ def _parse_citations_from_anthropic( if not cit.annotated_regions: cit.annotated_regions = [] cit.annotated_regions.append( - TextSpanRegion(start_index=citation.start_block_index, end_index=citation.end_block_index) + TextSpanRegion( + start_index=citation.start_block_index, + end_index=citation.end_block_index, + ) ) case "web_search_result_location": cit.title = citation.title @@ -901,7 +1059,10 @@ def _parse_citations_from_anthropic( if not cit.annotated_regions: cit.annotated_regions = [] cit.annotated_regions.append( - TextSpanRegion(start_index=citation.start_block_index, end_index=citation.end_block_index) + TextSpanRegion( + start_index=citation.start_block_index, + end_index=citation.end_block_index, + ) ) case _: logger.debug(f"Unknown citation type encountered: {citation.type}") diff --git a/python/packages/anthropic/pyproject.toml b/python/packages/anthropic/pyproject.toml index 73528ccfec..53eb699cd4 100644 --- a/python/packages/anthropic/pyproject.toml +++ b/python/packages/anthropic/pyproject.toml @@ -4,7 +4,7 @@ description = "Anthropic integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/anthropic/tests/test_anthropic_client.py b/python/packages/anthropic/tests/test_anthropic_client.py index e8a3ac9cb0..828d9916c2 100644 --- a/python/packages/anthropic/tests/test_anthropic_client.py +++ b/python/packages/anthropic/tests/test_anthropic_client.py @@ -595,6 +595,53 @@ def test_parse_contents_from_anthropic_tool_use(mock_anthropic_client: MagicMock assert result[0].name == "get_weather" +def test_parse_contents_from_anthropic_input_json_delta_no_duplicate_name(mock_anthropic_client: MagicMock) -> None: + """Test that input_json_delta events have empty name to prevent duplicate ToolCallStartEvents. + + When streaming tool calls, the initial tool_use event provides the name, + and subsequent input_json_delta events should have name="" to prevent + ag-ui from emitting duplicate ToolCallStartEvents. 
+ """ + chat_client = create_test_anthropic_client(mock_anthropic_client) + + # First, simulate a tool_use event that sets _last_call_id_name + tool_use_content = MagicMock() + tool_use_content.type = "tool_use" + tool_use_content.id = "call_123" + tool_use_content.name = "get_weather" + tool_use_content.input = {} + + result = chat_client._parse_contents_from_anthropic([tool_use_content]) + assert len(result) == 1 + assert isinstance(result[0], FunctionCallContent) + assert result[0].call_id == "call_123" + assert result[0].name == "get_weather" # Initial event has name + + # Now simulate input_json_delta events (argument streaming) + delta_content_1 = MagicMock() + delta_content_1.type = "input_json_delta" + delta_content_1.partial_json = '{"location":' + + result = chat_client._parse_contents_from_anthropic([delta_content_1]) + assert len(result) == 1 + assert isinstance(result[0], FunctionCallContent) + assert result[0].call_id == "call_123" + assert result[0].name == "" # Delta events should have empty name + assert result[0].arguments == '{"location":' + + # Another delta + delta_content_2 = MagicMock() + delta_content_2.type = "input_json_delta" + delta_content_2.partial_json = '"San Francisco"}' + + result = chat_client._parse_contents_from_anthropic([delta_content_2]) + assert len(result) == 1 + assert isinstance(result[0], FunctionCallContent) + assert result[0].call_id == "call_123" + assert result[0].name == "" # Still empty name for subsequent deltas + assert result[0].arguments == '"San Francisco"}' + + # Stream Processing Tests @@ -630,7 +677,7 @@ async def test_inner_get_response(mock_anthropic_client: MagicMock) -> None: chat_options = ChatOptions(max_tokens=10) response = await chat_client._inner_get_response( # type: ignore[attr-defined] - messages=messages, chat_options=chat_options + messages=messages, options=chat_options ) assert response is not None @@ -655,7 +702,7 @@ async def mock_stream(): chunks: list[ChatResponseUpdate] = [] async for chunk in chat_client._inner_get_streaming_response( # type: ignore[attr-defined] - messages=messages, chat_options=chat_options + messages=messages, options=chat_options ): if chunk: chunks.append(chunk) @@ -683,7 +730,7 @@ async def test_anthropic_client_integration_basic_chat() -> None: messages = [ChatMessage(role=Role.USER, text="Say 'Hello, World!' 
and nothing else.")] - response = await client.get_response(messages=messages, chat_options=ChatOptions(max_tokens=50)) + response = await client.get_response(messages=messages, options={"max_tokens": 50}) assert response is not None assert len(response.messages) > 0 @@ -701,7 +748,7 @@ async def test_anthropic_client_integration_streaming_chat() -> None: messages = [ChatMessage(role=Role.USER, text="Count from 1 to 5.")] chunks = [] - async for chunk in client.get_streaming_response(messages=messages, chat_options=ChatOptions(max_tokens=50)): + async for chunk in client.get_streaming_response(messages=messages, options={"max_tokens": 50}): chunks.append(chunk) assert len(chunks) > 0 @@ -719,7 +766,7 @@ async def test_anthropic_client_integration_function_calling() -> None: response = await client.get_response( messages=messages, - chat_options=ChatOptions(tools=tools, max_tokens=100), + options={"tools": tools, "max_tokens": 100}, ) assert response is not None @@ -749,7 +796,7 @@ async def test_anthropic_client_integration_hosted_tools() -> None: response = await client.get_response( messages=messages, - chat_options=ChatOptions(tools=tools, max_tokens=100), + options={"tools": tools, "max_tokens": 100}, ) assert response is not None @@ -767,7 +814,7 @@ async def test_anthropic_client_integration_with_system_message() -> None: ChatMessage(role=Role.USER, text="Hello!"), ] - response = await client.get_response(messages=messages, chat_options=ChatOptions(max_tokens=50)) + response = await client.get_response(messages=messages, options={"max_tokens": 50}) assert response is not None assert len(response.messages) > 0 @@ -783,7 +830,7 @@ async def test_anthropic_client_integration_temperature_control() -> None: response = await client.get_response( messages=messages, - chat_options=ChatOptions(max_tokens=20, temperature=0.0), + options={"max_tokens": 20, "temperature": 0.0}, ) assert response is not None diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_search_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_search_provider.py index a63ad1deb2..ac81a3c50b 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_search_provider.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_search_provider.py @@ -91,16 +91,16 @@ except ImportError: _agentic_retrieval_available = False -if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover -else: - from typing_extensions import Self # pragma: no cover - if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: from typing_extensions import override # type: ignore[import] # pragma: no cover +if sys.version_info >= (3, 11): + from typing import Self # pragma: no cover +else: + from typing_extensions import Self # pragma: no cover + """Azure AI Search Context Provider for Agent Framework. This module provides context providers for Azure AI Search integration with two modes: diff --git a/python/packages/azure-ai-search/pyproject.toml b/python/packages/azure-ai-search/pyproject.toml index 3c782022a6..e43f010023 100644 --- a/python/packages/azure-ai-search/pyproject.toml +++ b/python/packages/azure-ai-search/pyproject.toml @@ -4,7 +4,7 @@ description = "Azure AI Search integration for Microsoft Agent Framework." 
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/azure-ai/agent_framework_azure_ai/__init__.py b/python/packages/azure-ai/agent_framework_azure_ai/__init__.py index cf2423693d..c0cd4d249c 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/__init__.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/__init__.py @@ -2,8 +2,10 @@ import importlib.metadata -from ._chat_client import AzureAIAgentClient +from ._agent_provider import AzureAIAgentsProvider +from ._chat_client import AzureAIAgentClient, AzureAIAgentOptions from ._client import AzureAIClient +from ._project_provider import AzureAIProjectAgentProvider from ._shared import AzureAISettings try: @@ -13,7 +15,10 @@ __all__ = [ "AzureAIAgentClient", + "AzureAIAgentOptions", + "AzureAIAgentsProvider", "AzureAIClient", + "AzureAIProjectAgentProvider", "AzureAISettings", "__version__", ] diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py b/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py new file mode 100644 index 0000000000..6ed8853977 --- /dev/null +++ b/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py @@ -0,0 +1,519 @@ +# Copyright (c) Microsoft. All rights reserved. + +import sys +from collections.abc import Callable, MutableMapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, TypedDict, cast + +from agent_framework import ( + AGENT_FRAMEWORK_USER_AGENT, + AIFunction, + ChatAgent, + ContextProvider, + Middleware, + ToolProtocol, + normalize_tools, +) +from agent_framework._mcp import MCPTool +from agent_framework.exceptions import ServiceInitializationError +from azure.ai.agents.aio import AgentsClient +from azure.ai.agents.models import Agent, ResponseFormatJsonSchema, ResponseFormatJsonSchemaType +from azure.core.credentials_async import AsyncTokenCredential +from pydantic import BaseModel, ValidationError + +from ._chat_client import AzureAIAgentClient +from ._shared import AzureAISettings, from_azure_ai_agent_tools, to_azure_ai_agent_tools + +if TYPE_CHECKING: + from ._chat_client import AzureAIAgentOptions + +if sys.version_info >= (3, 13): + from typing import Self, TypeVar # pragma: no cover +else: + from typing_extensions import Self, TypeVar # pragma: no cover + + +# Type variable for options - allows typed ChatAgent[TOptions] returns +# Default matches AzureAIAgentClient's default options type +TOptions_co = TypeVar( + "TOptions_co", + bound=TypedDict, # type: ignore[valid-type] + default="AzureAIAgentOptions", + covariant=True, +) + + +class AzureAIAgentsProvider(Generic[TOptions_co]): + """Provider for Azure AI Agent Service V1 (Persistent Agents API). + + This provider enables creating, retrieving, and wrapping Azure AI agents as ChatAgent + instances. It manages the underlying AgentsClient lifecycle and provides a high-level + interface for agent operations. + + The provider can be initialized with either: + - An existing AgentsClient instance + - Azure credentials and endpoint for automatic client creation + + Examples: + Using credentials (auto-creates client): + + .. 
code-block:: python + + from agent_framework.azure import AzureAIAgentsProvider + from azure.identity.aio import AzureCliCredential + + async with ( + AzureCliCredential() as credential, + AzureAIAgentsProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( + name="MyAgent", + instructions="You are a helpful assistant.", + ) + result = await agent.run("Hello!") + + Using existing AgentsClient: + + .. code-block:: python + + from agent_framework.azure import AzureAIAgentsProvider + from azure.ai.agents.aio import AgentsClient + + async with AgentsClient(endpoint=endpoint, credential=credential) as client: + provider = AzureAIAgentsProvider(agents_client=client) + agent = await provider.create_agent(name="MyAgent", instructions="...") + """ + + def __init__( + self, + agents_client: AgentsClient | None = None, + *, + project_endpoint: str | None = None, + credential: AsyncTokenCredential | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initialize the Azure AI Agents Provider. + + Args: + agents_client: An existing AgentsClient to use. If provided, the provider + will not manage its lifecycle. + + Keyword Args: + project_endpoint: The Azure AI Project endpoint URL. + Can also be set via AZURE_AI_PROJECT_ENDPOINT environment variable. + credential: Azure async credential for authentication. + Required if agents_client is not provided. + env_file_path: Path to .env file for loading settings. + env_file_encoding: Encoding of the .env file. + + Raises: + ServiceInitializationError: If required parameters are missing or invalid. + """ + try: + self._settings = AzureAISettings( + project_endpoint=project_endpoint, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + except ValidationError as ex: + raise ServiceInitializationError("Failed to create Azure AI settings.", ex) from ex + + self._should_close_client = False + + if agents_client is not None: + self._agents_client = agents_client + else: + if not self._settings.project_endpoint: + raise ServiceInitializationError( + "Azure AI project endpoint is required. Provide 'project_endpoint' parameter " + "or set 'AZURE_AI_PROJECT_ENDPOINT' environment variable." + ) + if not credential: + raise ServiceInitializationError("Azure credential is required when agents_client is not provided.") + self._agents_client = AgentsClient( + endpoint=self._settings.project_endpoint, + credential=credential, + user_agent=AGENT_FRAMEWORK_USER_AGENT, + ) + self._should_close_client = True + + async def __aenter__(self) -> "Self": + """Async context manager entry.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + await self.close() + + async def close(self) -> None: + """Close the provider and release resources. + + Only closes the AgentsClient if it was created by this provider. 
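Clients supplied by the caller are left open, since their lifecycle belongs to the caller.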
+ """ + if self._should_close_client: + await self._agents_client.close() + + async def create_agent( + self, + name: str, + *, + model: str | None = None, + instructions: str | None = None, + description: str | None = None, + tools: ToolProtocol + | Callable[..., Any] + | MutableMapping[str, Any] + | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] + | None = None, + default_options: TOptions_co | None = None, + middleware: Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + ) -> "ChatAgent[TOptions_co]": + """Create a new agent on the Azure AI service and return a ChatAgent. + + This method creates a persistent agent on the Azure AI service with the specified + configuration and returns a local ChatAgent instance for interaction. + + Args: + name: The name for the agent. + + Keyword Args: + model: The model deployment name to use. Falls back to + AZURE_AI_MODEL_DEPLOYMENT_NAME environment variable if not provided. + instructions: Instructions for the agent's behavior. + description: A description of the agent's purpose. + tools: Tools to make available to the agent. + default_options: A TypedDict containing default chat options for the agent. + These options are applied to every run unless overridden. + middleware: List of middleware to intercept agent and function invocations. + context_provider: Context provider to include during agent invocation. + + Returns: + ChatAgent: A ChatAgent instance configured with the created agent. + + Raises: + ServiceInitializationError: If model deployment name is not available. + + Examples: + .. code-block:: python + + agent = await provider.create_agent( + name="WeatherAgent", + instructions="You are a helpful weather assistant.", + tools=get_weather, + ) + """ + resolved_model = model or self._settings.model_deployment_name + if not resolved_model: + raise ServiceInitializationError( + "Model deployment name is required. Provide 'model' parameter " + "or set 'AZURE_AI_MODEL_DEPLOYMENT_NAME' environment variable." 
+ ) + + # Extract response_format from default_options if present + opts = dict(default_options) if default_options else {} + response_format = opts.get("response_format") + + args: dict[str, Any] = { + "model": resolved_model, + "name": name, + } + + if description: + args["description"] = description + if instructions: + args["instructions"] = instructions + + # Handle response format + if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel): + args["response_format"] = self._create_response_format_config(response_format) + + # Normalize and convert tools + # Local MCP tools (MCPTool) are handled by ChatAgent at runtime, not stored on the Azure agent + normalized_tools = normalize_tools(tools) + if normalized_tools: + # Only convert non-MCP tools to Azure AI format + non_mcp_tools = [t for t in normalized_tools if not isinstance(t, MCPTool)] + if non_mcp_tools: + # Pass run_options to capture tool_resources (e.g., for file search vector stores) + run_options: dict[str, Any] = {} + args["tools"] = to_azure_ai_agent_tools(non_mcp_tools, run_options) + if "tool_resources" in run_options: + args["tool_resources"] = run_options["tool_resources"] + + # Create the agent on the service + created_agent = await self._agents_client.create_agent(**args) + + # Create ChatAgent wrapper + return self._to_chat_agent_from_agent( + created_agent, + normalized_tools, + default_options=default_options, + middleware=middleware, + context_provider=context_provider, + ) + + async def get_agent( + self, + id: str, + *, + tools: ToolProtocol + | Callable[..., Any] + | MutableMapping[str, Any] + | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] + | None = None, + default_options: TOptions_co | None = None, + middleware: Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + ) -> "ChatAgent[TOptions_co]": + """Retrieve an existing agent from the service and return a ChatAgent. + + This method fetches an agent by ID from the Azure AI service + and returns a local ChatAgent instance for interaction. + + Args: + id: The ID of the agent to retrieve from the service. + + Keyword Args: + tools: Tools to make available to the agent. Required if the agent + has function tools that need implementations. + default_options: A TypedDict containing default chat options for the agent. + These options are applied to every run unless overridden. + middleware: List of middleware to intercept agent and function invocations. + context_provider: Context provider to include during agent invocation. + + Returns: + ChatAgent: A ChatAgent instance configured with the retrieved agent. + + Raises: + ServiceInitializationError: If required function tools are not provided. + + Examples: + .. 
code-block:: python + + agent = await provider.get_agent("agent-123") + + # With function tools + agent = await provider.get_agent("agent-123", tools=my_function) + """ + agent = await self._agents_client.get_agent(id) + + # Validate function tools + normalized_tools = normalize_tools(tools) + self._validate_function_tools(agent.tools, normalized_tools) + + return self._to_chat_agent_from_agent( + agent, + normalized_tools, + default_options=default_options, + middleware=middleware, + context_provider=context_provider, + ) + + def as_agent( + self, + agent: Agent, + tools: ToolProtocol + | Callable[..., Any] + | MutableMapping[str, Any] + | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] + | None = None, + default_options: TOptions_co | None = None, + middleware: Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + ) -> "ChatAgent[TOptions_co]": + """Wrap an existing Agent SDK object as a ChatAgent without making HTTP calls. + + Use this method when you already have an Agent object from a previous + SDK operation and want to use it with the Agent Framework. + + Args: + agent: The Agent object to wrap. + tools: Tools to make available to the agent. Required if the agent + has function tools that need implementations. + default_options: A TypedDict containing default chat options for the agent. + These options are applied to every run unless overridden. + middleware: List of middleware to intercept agent and function invocations. + context_provider: Context provider to include during agent invocation. + + Returns: + ChatAgent: A ChatAgent instance configured with the agent. + + Raises: + ServiceInitializationError: If required function tools are not provided. + + Examples: + .. code-block:: python + + # Create agent directly with SDK + sdk_agent = await agents_client.create_agent( + model="gpt-4", + name="MyAgent", + instructions="...", + ) + + # Wrap as ChatAgent + chat_agent = provider.as_agent(sdk_agent) + """ + # Validate function tools + normalized_tools = normalize_tools(tools) + self._validate_function_tools(agent.tools, normalized_tools) + + return self._to_chat_agent_from_agent( + agent, + normalized_tools, + default_options=default_options, + middleware=middleware, + context_provider=context_provider, + ) + + def _to_chat_agent_from_agent( + self, + agent: Agent, + provided_tools: Sequence[ToolProtocol | MutableMapping[str, Any]] | None = None, + default_options: TOptions_co | None = None, + middleware: Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + ) -> "ChatAgent[TOptions_co]": + """Create a ChatAgent from an Agent SDK object. + + Args: + agent: The Agent SDK object. + provided_tools: User-provided tools (including function implementations). + default_options: A TypedDict containing default chat options for the agent. + These options are applied to every run unless overridden. + middleware: List of middleware to intercept agent and function invocations. + context_provider: Context provider to include during agent invocation. 
+ """ + # Create the underlying client + client = AzureAIAgentClient( + agents_client=self._agents_client, + agent_id=agent.id, + agent_name=agent.name, + agent_description=agent.description, + should_cleanup_agent=False, # Provider manages agent lifecycle + ) + + # Merge tools: convert agent's hosted tools + user-provided function tools + merged_tools = self._merge_tools(agent.tools, provided_tools) + + return ChatAgent( # type: ignore[return-value] + chat_client=client, + id=agent.id, + name=agent.name, + description=agent.description, + instructions=agent.instructions, + model_id=agent.model, + tools=merged_tools, + default_options=default_options, # type: ignore[arg-type] + middleware=middleware, + context_provider=context_provider, + ) + + def _merge_tools( + self, + agent_tools: Sequence[Any] | None, + provided_tools: Sequence[ToolProtocol | MutableMapping[str, Any]] | None, + ) -> list[ToolProtocol | dict[str, Any]]: + """Merge hosted tools from agent with user-provided function tools. + + Args: + agent_tools: Tools from the agent definition (Azure AI format). + provided_tools: User-provided tools (Agent Framework format). + + Returns: + Combined list of tools for the ChatAgent. + """ + merged: list[ToolProtocol | dict[str, Any]] = [] + + # Convert hosted tools from agent definition + hosted_tools = from_azure_ai_agent_tools(agent_tools) + for hosted_tool in hosted_tools: + # Skip function tool dicts - they don't have implementations + # Skip OpenAPI tool dicts - they're defined on the agent, not needed at runtime + if isinstance(hosted_tool, dict): + tool_type = hosted_tool.get("type") + if tool_type == "function" or tool_type == "openapi": + continue + merged.append(hosted_tool) + + # Add user-provided function tools and MCP tools + if provided_tools: + for provided_tool in provided_tools: + # AIFunction - has implementation for function calling + # MCPTool - ChatAgent handles MCP connection and tool discovery at runtime + if isinstance(provided_tool, (AIFunction, MCPTool)): + merged.append(provided_tool) # type: ignore[reportUnknownArgumentType] + + return merged + + def _validate_function_tools( + self, + agent_tools: Sequence[Any] | None, + provided_tools: Sequence[ToolProtocol | MutableMapping[str, Any]] | None, + ) -> None: + """Validate that required function tools are provided. + + Raises: + ServiceInitializationError: If agent has function tools but user + didn't provide implementations. + """ + if not agent_tools: + return + + # Get function tool names from agent definition + function_tool_names: set[str] = set() + for tool in agent_tools: + if isinstance(tool, dict): + tool_dict = cast(dict[str, Any], tool) + if tool_dict.get("type") == "function": + func_def = cast(dict[str, Any], tool_dict.get("function", {})) + name = func_def.get("name") + if isinstance(name, str): + function_tool_names.add(name) + elif hasattr(tool, "type") and tool.type == "function": + func_attr = getattr(tool, "function", None) + if func_attr and hasattr(func_attr, "name"): + function_tool_names.add(str(func_attr.name)) + + if not function_tool_names: + return + + # Get provided function names + provided_names: set[str] = set() + if provided_tools: + for tool in provided_tools: + if isinstance(tool, AIFunction): + provided_names.add(tool.name) + + # Check for missing implementations + missing = function_tool_names - provided_names + if missing: + raise ServiceInitializationError( + f"Agent has function tools that require implementations: {missing}. 
" + "Provide these functions via the 'tools' parameter." + ) + + def _create_response_format_config( + self, + response_format: type[BaseModel], + ) -> ResponseFormatJsonSchemaType: + """Create response format configuration for Azure AI. + + Args: + response_format: Pydantic model for structured output. + + Returns: + Azure AI response format configuration. + """ + return ResponseFormatJsonSchemaType( + json_schema=ResponseFormatJsonSchema( + name=response_format.__name__, + schema=response_format.model_json_schema(), + ) + ) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 50d18bbdc1..931f57500e 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -2,15 +2,13 @@ import ast import json -import os import re import sys -from collections.abc import AsyncIterable, MutableMapping, MutableSequence, Sequence -from typing import Any, ClassVar, TypeVar +from collections.abc import AsyncIterable, Mapping, MutableMapping, MutableSequence, Sequence +from typing import Any, ClassVar, Generic, TypedDict from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, - AIFunction, BaseChatClient, ChatMessage, ChatOptions, @@ -23,16 +21,11 @@ FunctionApprovalResponseContent, FunctionCallContent, FunctionResultContent, - HostedCodeInterpreterTool, HostedFileContent, - HostedFileSearchTool, HostedMCPTool, - HostedVectorStoreContent, - HostedWebSearchTool, Role, TextContent, TextSpanRegion, - ToolMode, ToolProtocol, UriContent, UsageContent, @@ -42,7 +35,7 @@ use_chat_middleware, use_function_invocation, ) -from agent_framework.exceptions import ServiceInitializationError, ServiceResponseException +from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError, ServiceResponseException from agent_framework.observability import use_instrumentation from azure.ai.agents.aio import AgentsClient from azure.ai.agents.models import ( @@ -53,14 +46,9 @@ AgentStreamEvent, AsyncAgentEventHandler, AsyncAgentRunStream, - BingCustomSearchTool, - BingGroundingTool, - CodeInterpreterToolDefinition, - FileSearchTool, FunctionName, FunctionToolDefinition, ListSortOrder, - McpTool, MessageDeltaChunk, MessageDeltaTextContent, MessageDeltaTextFileCitationAnnotation, @@ -90,10 +78,18 @@ ToolOutput, ) from azure.core.credentials_async import AsyncTokenCredential -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError -from ._shared import AzureAISettings +from ._shared import AzureAISettings, to_azure_ai_agent_tools +if sys.version_info >= (3, 13): + from typing import TypeVar # type: ignore # pragma: no cover +else: + from typing_extensions import TypeVar # type: ignore # pragma: no cover +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover if sys.version_info >= (3, 11): from typing import Self # pragma: no cover else: @@ -102,14 +98,106 @@ logger = get_logger("agent_framework.azure") +__all__ = ["AzureAIAgentClient", "AzureAIAgentOptions"] + + +# region Azure AI Agent Options TypedDict + + +class AzureAIAgentOptions(ChatOptions, total=False): + """Azure AI Foundry Agent Service-specific options dict. + + Extends base ChatOptions with Azure AI Agent Service parameters. 
+ Azure AI Agents provides a managed agent runtime with built-in + tools for code interpreter, file search, and web search. + + See: https://learn.microsoft.com/azure/ai-services/agents/ + + Keys: + # Inherited from ChatOptions: + model_id: The model deployment name, + translates to ``model`` in Azure AI API. + temperature: Sampling temperature between 0 and 2. + top_p: Nucleus sampling parameter. + max_tokens: Maximum number of tokens to generate, + translates to ``max_completion_tokens`` in Azure AI API. + tools: List of tools available to the agent. + tool_choice: How the model should use tools. + allow_multiple_tool_calls: Whether to allow parallel tool calls, + translates to ``parallel_tool_calls`` in Azure AI API. + response_format: Structured output schema. + metadata: Request metadata for tracking. + instructions: System instructions for the agent. + + # Options not supported in Azure AI Agent Service: + stop: Not supported. + seed: Not supported. + frequency_penalty: Not supported. + presence_penalty: Not supported. + user: Not supported. + store: Not supported. + logit_bias: Not supported. + + # Azure AI Agent-specific options: + conversation_id: Thread ID to continue conversation in. + tool_resources: Resources for tools (file IDs, vector stores). + """ + + # Azure AI Agent-specific options + conversation_id: str # type: ignore[misc] + """Thread ID to continue a conversation in an existing thread.""" + + tool_resources: dict[str, Any] + """Tool-specific resources for code_interpreter and file_search. + For code_interpreter: {"file_ids": ["file-abc123"]} + For file_search: {"vector_store_ids": ["vs-abc123"]} + """ + + # ChatOptions fields not supported in Azure AI Agent Service + stop: None # type: ignore[misc] + """Not supported in Azure AI Agent Service.""" + + seed: None # type: ignore[misc] + """Not supported in Azure AI Agent Service.""" -TAzureAIAgentClient = TypeVar("TAzureAIAgentClient", bound="AzureAIAgentClient") + frequency_penalty: None # type: ignore[misc] + """Not supported in Azure AI Agent Service.""" + + presence_penalty: None # type: ignore[misc] + """Not supported in Azure AI Agent Service.""" + + user: None # type: ignore[misc] + """Not supported in Azure AI Agent Service.""" + + store: None # type: ignore[misc] + """Not supported in Azure AI Agent Service.""" + + logit_bias: None # type: ignore[misc] + """Not supported in Azure AI Agent Service.""" + + +AZURE_AI_AGENT_OPTION_TRANSLATIONS: dict[str, str] = { + "model_id": "model", + "max_tokens": "max_completion_tokens", + "allow_multiple_tool_calls": "parallel_tool_calls", +} +"""Maps ChatOptions keys to Azure AI Agents API parameter names.""" + +TAzureAIAgentOptions = TypeVar( + "TAzureAIAgentOptions", + bound=TypedDict, # type: ignore[valid-type] + default="AzureAIAgentOptions", + covariant=True, +) + + +# endregion @use_function_invocation @use_instrumentation @use_chat_middleware -class AzureAIAgentClient(BaseChatClient): +class AzureAIAgentClient(BaseChatClient[TAzureAIAgentOptions], Generic[TAzureAIAgentOptions]): """Azure AI Agent Chat client.""" OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai" # type: ignore[reportIncompatibleVariableOverride, misc] @@ -162,19 +250,31 @@ def __init__( # Using environment variables # Set AZURE_AI_PROJECT_ENDPOINT=https://your-project.cognitiveservices.azure.com - # Set AZURE_AI_MODEL_DEPLOYMENT_NAME=gpt-4 + # Set AZURE_AI_MODEL_DEPLOYMENT_NAME= credential = DefaultAzureCredential() client = AzureAIAgentClient(credential=credential) # Or passing parameters directly client 
= AzureAIAgentClient( project_endpoint="https://your-project.cognitiveservices.azure.com", - model_deployment_name="gpt-4", + model_deployment_name="", credential=credential, ) # Or loading from a .env file client = AzureAIAgentClient(credential=credential, env_file_path="path/to/.env") + + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework_azure_ai import AzureAIAgentOptions + + + class MyOptions(AzureAIAgentOptions, total=False): + my_custom_option: str + + + client: AzureAIAgentClient[MyOptions] = AzureAIAgentClient(credential=credential) + response = await client.get_response("Hello", options={"my_custom_option": "value"}) """ try: azure_ai_settings = AzureAISettings( @@ -240,46 +340,29 @@ async def close(self) -> None: await self._cleanup_agent_if_needed() await self._close_client_if_needed() - @classmethod - def from_settings(cls: type[TAzureAIAgentClient], settings: dict[str, Any]) -> TAzureAIAgentClient: - """Initialize a AzureAIAgentClient from a dictionary of settings. - - Args: - settings: A dictionary of settings for the service. - """ - return cls( - agents_client=settings.get("agents_client"), - agent_id=settings.get("agent_id"), - thread_id=settings.get("thread_id"), - project_endpoint=settings.get("project_endpoint"), - model_deployment_name=settings.get("model_deployment_name"), - agent_name=settings.get("agent_name"), - credential=settings.get("credential"), - env_file_path=settings.get("env_file_path"), - should_cleanup_agent=settings.get("should_cleanup_agent", True), - ) - + @override async def _inner_get_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> ChatResponse: return await ChatResponse.from_chat_response_generator( - updates=self._inner_get_streaming_response(messages=messages, chat_options=chat_options, **kwargs), - output_format_type=chat_options.response_format, + updates=self._inner_get_streaming_response(messages=messages, options=options, **kwargs), + output_format_type=options.get("response_format"), ) + @override async def _inner_get_streaming_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: Mapping[str, Any], **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: # prepare - run_options, required_action_results = await self._prepare_options(messages, chat_options, **kwargs) + run_options, required_action_results = await self._prepare_options(messages, options, **kwargs) agent_id = await self._get_agent_id_or_create(run_options) # execute and process @@ -783,46 +866,31 @@ async def _load_agent_definition_if_needed(self) -> Agent | None: self._agent_definition = await self.agents_client.get_agent(self.agent_id) return self._agent_definition - def _prepare_tool_choice(self, chat_options: ChatOptions) -> None: - """Prepare the tools and tool choice for the chat options. - - Args: - chat_options: The chat options to prepare. 
- """ - chat_tool_mode = chat_options.tool_choice - if chat_tool_mode is None or chat_tool_mode == ToolMode.NONE or chat_tool_mode == "none": - chat_options.tools = None - chat_options.tool_choice = ToolMode.NONE - return - - chat_options.tool_choice = chat_tool_mode - async def _prepare_options( self, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: Mapping[str, Any], **kwargs: Any, ) -> tuple[dict[str, Any], list[FunctionResultContent | FunctionApprovalResponseContent] | None]: agent_definition = await self._load_agent_definition_if_needed() - # Use to_dict with exclusions for properties handled separately - run_options: dict[str, Any] = chat_options.to_dict( - exclude={ - "type", - "instructions", # handled via messages - "tools", # handled separately - "tool_choice", # handled separately - "response_format", # handled separately - "additional_properties", # handled separately - "frequency_penalty", # not supported - "presence_penalty", # not supported - "user", # not supported - "stop", # not supported - "logit_bias", # not supported - "seed", # not supported - "store", # not supported - } - ) + # Build run_options from options dict, excluding specific keys + exclude_keys = { + "type", + "instructions", # handled via messages + "tools", # handled separately + "tool_choice", # handled separately + "response_format", # handled separately + "additional_properties", # handled separately + "frequency_penalty", # not supported + "presence_penalty", # not supported + "user", # not supported + "stop", # not supported + "logit_bias", # not supported + "seed", # not supported + "store", # not supported + } + run_options: dict[str, Any] = {k: v for k, v in options.items() if k not in exclude_keys and v is not None} # Translation between ChatOptions and Azure AI Agents API translations = { @@ -840,21 +908,31 @@ async def _prepare_options( # tools and tool_choice if tool_definitions := await self._prepare_tool_definitions_and_resources( - chat_options, agent_definition, run_options + options, agent_definition, run_options ): run_options["tools"] = tool_definitions - if tool_choice := self._prepare_tool_choice_mode(chat_options): + if tool_choice := self._prepare_tool_choice_mode(options): run_options["tool_choice"] = tool_choice # response format - if chat_options.response_format is not None: - run_options["response_format"] = ResponseFormatJsonSchemaType( - json_schema=ResponseFormatJsonSchema( - name=chat_options.response_format.__name__, - schema=chat_options.response_format.model_json_schema(), + response_format = options.get("response_format") + if response_format is not None: + if isinstance(response_format, type) and issubclass(response_format, BaseModel): + # Pydantic model - convert to Azure format + run_options["response_format"] = ResponseFormatJsonSchemaType( + json_schema=ResponseFormatJsonSchema( + name=response_format.__name__, + schema=response_format.model_json_schema(), + ) + ) + elif isinstance(response_format, Mapping): + # Runtime JSON schema dict - pass through as-is + run_options["response_format"] = response_format + else: + raise ServiceInvalidRequestError( + "response_format must be a Pydantic BaseModel class or a dict with runtime JSON schema." 
) - ) # messages additional_messages, instructions, required_action_results = self._prepare_messages(messages) @@ -873,41 +951,40 @@ async def _prepare_options( run_options["instructions"] = "\n".join(instructions) # thread_id resolution (conversation_id takes precedence, then kwargs, then instance default) - run_options["thread_id"] = chat_options.conversation_id or kwargs.get("conversation_id") or self.thread_id + run_options["thread_id"] = options.get("conversation_id") or kwargs.get("conversation_id") or self.thread_id return run_options, required_action_results def _prepare_tool_choice_mode( - self, chat_options: ChatOptions + self, options: Mapping[str, Any] ) -> AgentsToolChoiceOptionMode | AgentsNamedToolChoice | None: """Prepare the tool choice mode for Azure AI Agents API.""" - if chat_options.tool_choice is None: + tool_choice = options.get("tool_choice") + if tool_choice is None: return None - if chat_options.tool_choice == "none": + if tool_choice == "none": return AgentsToolChoiceOptionMode.NONE - if chat_options.tool_choice == "auto": + if tool_choice == "auto": return AgentsToolChoiceOptionMode.AUTO - if ( - isinstance(chat_options.tool_choice, ToolMode) - and chat_options.tool_choice == "required" - and chat_options.tool_choice.required_function_name is not None - ): - return AgentsNamedToolChoice( - type=AgentsNamedToolChoiceType.FUNCTION, - function=FunctionName(name=chat_options.tool_choice.required_function_name), - ) + if isinstance(tool_choice, Mapping) and tool_choice.get("mode") == "required": + req_fn = tool_choice.get("required_function_name") + if req_fn: + return AgentsNamedToolChoice( + type=AgentsNamedToolChoiceType.FUNCTION, + function=FunctionName(name=str(req_fn)), + ) return None async def _prepare_tool_definitions_and_resources( self, - chat_options: ChatOptions, + options: Mapping[str, Any], agent_definition: Agent | None, run_options: dict[str, Any], ) -> list[ToolDefinition | dict[str, Any]]: """Prepare tool definitions and resources for the run options.""" tool_definitions: list[ToolDefinition | dict[str, Any]] = [] - # Add tools from existing agent (exclude function tools - passed via chat_options.tools) + # Add tools from existing agent (exclude function tools - passed via options.get("tools")) if agent_definition is not None: agent_tools = [tool for tool in agent_definition.tools if not isinstance(tool, FunctionToolDefinition)] if agent_tools: @@ -916,11 +993,13 @@ async def _prepare_tool_definitions_and_resources( run_options["tool_resources"] = agent_definition.tool_resources # Add run tools if tool_choice allows - if chat_options.tool_choice is not None and chat_options.tool_choice != "none" and chat_options.tools: - tool_definitions.extend(await self._prepare_tools_for_azure_ai(chat_options.tools, run_options)) + tool_choice = options.get("tool_choice") + tools = options.get("tools") + if tool_choice is not None and tool_choice != "none" and tools: + tool_definitions.extend(to_azure_ai_agent_tools(tools, run_options)) # Handle MCP tool resources - mcp_resources = self._prepare_mcp_resources(chat_options.tools) + mcp_resources = self._prepare_mcp_resources(tools) if mcp_resources: if "tool_resources" not in run_options: run_options["tool_resources"] = {} @@ -1016,82 +1095,6 @@ def _prepare_messages( return additional_messages, instructions, required_action_results - async def _prepare_tools_for_azure_ai( - self, tools: Sequence["ToolProtocol | MutableMapping[str, Any]"], run_options: dict[str, Any] | None = None - ) -> list[ToolDefinition | 
dict[str, Any]]: - """Prepare tool definitions for the Azure AI Agents API.""" - tool_definitions: list[ToolDefinition | dict[str, Any]] = [] - for tool in tools: - match tool: - case AIFunction(): - tool_definitions.append(tool.to_json_schema_spec()) # type: ignore[reportUnknownArgumentType] - case HostedWebSearchTool(): - additional_props = tool.additional_properties or {} - config_args: dict[str, Any] = {} - if count := additional_props.get("count"): - config_args["count"] = count - if freshness := additional_props.get("freshness"): - config_args["freshness"] = freshness - if market := additional_props.get("market"): - config_args["market"] = market - if set_lang := additional_props.get("set_lang"): - config_args["set_lang"] = set_lang - # Bing Grounding - connection_id = additional_props.get("connection_id") or os.getenv("BING_CONNECTION_ID") - # Custom Bing Search - custom_connection_id = additional_props.get("custom_connection_id") or os.getenv( - "BING_CUSTOM_CONNECTION_ID" - ) - custom_instance_name = additional_props.get("custom_instance_name") or os.getenv( - "BING_CUSTOM_INSTANCE_NAME" - ) - bing_search: BingGroundingTool | BingCustomSearchTool | None = None - if (connection_id) and not custom_connection_id and not custom_instance_name: - if connection_id: - conn_id = connection_id - else: - raise ServiceInitializationError("Parameter connection_id is not provided.") - bing_search = BingGroundingTool(connection_id=conn_id, **config_args) - if custom_connection_id and custom_instance_name: - bing_search = BingCustomSearchTool( - connection_id=custom_connection_id, - instance_name=custom_instance_name, - **config_args, - ) - if not bing_search: - raise ServiceInitializationError( - "Bing search tool requires either 'connection_id' for Bing Grounding " - "or both 'custom_connection_id' and 'custom_instance_name' for Custom Bing Search. 
" - "These can be provided via additional_properties or environment variables: " - "'BING_CONNECTION_ID', 'BING_CUSTOM_CONNECTION_ID', " - "'BING_CUSTOM_INSTANCE_NAME'" - ) - tool_definitions.extend(bing_search.definitions) - case HostedCodeInterpreterTool(): - tool_definitions.append(CodeInterpreterToolDefinition()) - case HostedMCPTool(): - mcp_tool = McpTool( - server_label=tool.name.replace(" ", "_"), - server_url=str(tool.url), - allowed_tools=list(tool.allowed_tools) if tool.allowed_tools else [], - ) - tool_definitions.extend(mcp_tool.definitions) - case HostedFileSearchTool(): - vector_stores = [inp for inp in tool.inputs or [] if isinstance(inp, HostedVectorStoreContent)] - if vector_stores: - file_search = FileSearchTool(vector_store_ids=[vs.vector_store_id for vs in vector_stores]) - tool_definitions.extend(file_search.definitions) - # Set tool_resources for file search to work properly with Azure AI - if run_options is not None and "tool_resources" not in run_options: - run_options["tool_resources"] = file_search.resources - case ToolDefinition(): - tool_definitions.append(tool) - case dict(): - tool_definitions.append(tool) - case _: - raise ServiceInitializationError(f"Unsupported tool type: {type(tool)}") - return tool_definitions - def _prepare_tool_outputs_for_azure_ai( self, required_action_results: list[FunctionResultContent | FunctionApprovalResponseContent] | None, diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 2e8edc7e47..c735cce049 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -2,19 +2,18 @@ import sys from collections.abc import Mapping, MutableSequence -from typing import Any, ClassVar, TypeVar, cast +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypedDict, TypeVar, cast from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, ChatMessage, - ChatOptions, HostedMCPTool, TextContent, get_logger, use_chat_middleware, use_function_invocation, ) -from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError +from agent_framework.exceptions import ServiceInitializationError from agent_framework.observability import use_instrumentation from agent_framework.openai._responses_client import OpenAIBaseResponsesClient from azure.ai.projects.aio import AIProjectClient @@ -22,37 +21,44 @@ MCPTool, PromptAgentDefinition, PromptAgentDefinitionText, - ResponseTextFormatConfigurationJsonObject, - ResponseTextFormatConfigurationJsonSchema, - ResponseTextFormatConfigurationText, ) from azure.core.credentials_async import AsyncTokenCredential from azure.core.exceptions import ResourceNotFoundError -from pydantic import BaseModel, ValidationError +from pydantic import ValidationError -from ._shared import AzureAISettings +from ._shared import AzureAISettings, create_text_format_config -if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover -else: - from typing_extensions import Self # pragma: no cover +if TYPE_CHECKING: + from agent_framework.openai import OpenAIResponsesOptions +if sys.version_info >= (3, 13): + from typing import TypeVar # type: ignore # pragma: no cover +else: + from typing_extensions import TypeVar # type: ignore # pragma: no cover if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: from typing_extensions import override # type: ignore[import] # pragma: no cover 
+if sys.version_info >= (3, 11): + from typing import Self # pragma: no cover +else: + from typing_extensions import Self # pragma: no cover logger = get_logger("agent_framework.azure") - -TAzureAIClient = TypeVar("TAzureAIClient", bound="AzureAIClient") +TAzureAIClientOptions = TypeVar( + "TAzureAIClientOptions", + bound=TypedDict, # type: ignore[valid-type] + default="OpenAIResponsesOptions", + covariant=True, +) @use_function_invocation @use_instrumentation @use_chat_middleware -class AzureAIClient(OpenAIBaseResponsesClient): +class AzureAIClient(OpenAIBaseResponsesClient[TAzureAIClientOptions], Generic[TAzureAIClientOptions]): """Azure AI Agent client.""" OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai" # type: ignore[reportIncompatibleVariableOverride, misc] @@ -115,6 +121,18 @@ def __init__( # Or loading from a .env file client = AzureAIClient(credential=credential, env_file_path="path/to/.env") + + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework import ChatOptions + + + class MyOptions(ChatOptions, total=False): + my_custom_option: str + + + client: AzureAIClient[MyOptions] = AzureAIClient(credential=credential) + response = await client.get_response("Hello", options={"my_custom_option": "value"}) """ try: azure_ai_settings = AzureAISettings( @@ -265,45 +283,11 @@ async def close(self) -> None: """Close the project_client.""" await self._close_client_if_needed() - def _create_text_format_config( - self, response_format: Any - ) -> ( - ResponseTextFormatConfigurationJsonSchema - | ResponseTextFormatConfigurationJsonObject - | ResponseTextFormatConfigurationText - ): - """Convert response_format into Azure text format configuration.""" - if isinstance(response_format, type) and issubclass(response_format, BaseModel): - return ResponseTextFormatConfigurationJsonSchema( - name=response_format.__name__, - schema=response_format.model_json_schema(), - ) - - if isinstance(response_format, Mapping): - format_config = self._convert_response_format(response_format) - format_type = format_config.get("type") - if format_type == "json_schema": - config_kwargs: dict[str, Any] = { - "name": format_config.get("name") or "response", - "schema": format_config["schema"], - } - if "strict" in format_config: - config_kwargs["strict"] = format_config["strict"] - if "description" in format_config: - config_kwargs["description"] = format_config["description"] - return ResponseTextFormatConfigurationJsonSchema(**config_kwargs) - if format_type == "json_object": - return ResponseTextFormatConfigurationJsonObject() - if format_type == "text": - return ResponseTextFormatConfigurationText() - - raise ServiceInvalidRequestError("response_format must be a Pydantic model or mapping.") - async def _get_agent_reference_or_create( self, run_options: dict[str, Any], messages_instructions: str | None, - chat_options: ChatOptions | None = None, + chat_options: Mapping[str, Any] | None = None, ) -> dict[str, str]: """Determine which agent to use and create if needed. @@ -315,11 +299,6 @@ async def _get_agent_reference_or_create( Returns: dict[str, str]: The agent reference to use. """ - # chat_options is needed separately because the base class excludes response_format - # from run_options (transforming it to text/text_format for OpenAI). Azure's agent - # creation API requires the original response_format to build its own config format. - if chat_options is None: - chat_options = ChatOptions() # Agent name must be explicitly provided by the user. 
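+        # The name is used to look up, or create a new version of, the
+        # server-side agent, so there is no sensible default to infer here.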
if self.agent_name is None: raise ServiceInitializationError( @@ -356,13 +335,8 @@ async def _get_agent_reference_or_create( # response_format is accessed from chat_options or additional_properties # since the base class excludes it from run_options - response_format: Any = ( - chat_options.response_format - if chat_options.response_format is not None - else chat_options.additional_properties.get("response_format") - ) - if response_format: - args["text"] = PromptAgentDefinitionText(format=self._create_text_format_config(response_format)) + if chat_options and (response_format := chat_options.get("response_format")): + args["text"] = PromptAgentDefinitionText(format=create_text_format_config(response_format)) # Combine instructions from messages and options combined_instructions = [ @@ -392,12 +366,12 @@ async def _close_client_if_needed(self) -> None: async def _prepare_options( self, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> dict[str, Any]: """Take ChatOptions and create the specific options for Azure AI.""" prepared_messages, instructions = self._prepare_messages_for_azure_ai(messages) - run_options = await super()._prepare_options(prepared_messages, chat_options, **kwargs) + run_options = await super()._prepare_options(prepared_messages, options, **kwargs) # WORKAROUND: Azure AI Projects 'create responses' API has schema divergence from OpenAI's # Responses API. Azure requires 'type' at item level and 'annotations' in content items. @@ -409,18 +383,33 @@ async def _prepare_options( if not self._is_application_endpoint: # Application-scoped response APIs do not support "agent" property. - agent_reference = await self._get_agent_reference_or_create(run_options, instructions, chat_options) + agent_reference = await self._get_agent_reference_or_create(run_options, instructions, options) run_options["extra_body"] = {"agent": agent_reference} # Remove properties that are not supported on request level # but were configured on agent level - exclude = ["model", "tools", "response_format", "temperature", "top_p", "text", "text_format"] + exclude = [ + "model", + "tools", + "response_format", + "temperature", + "top_p", + "text", + "text_format", + ] for property in exclude: run_options.pop(property, None) return run_options + @override + def _check_model_presence(self, run_options: dict[str, Any]) -> None: + if not run_options.get("model"): + if not self.model_id: + raise ValueError("model_deployment_name must be a non-empty string") + run_options["model"] = self.model_id + def _transform_input_for_azure_ai(self, input_items: list[dict[str, Any]]) -> list[dict[str, Any]]: """Transform input items to match Azure AI Projects expected schema. 
@@ -460,9 +449,9 @@ def _transform_input_for_azure_ai(self, input_items: list[dict[str, Any]]) -> li return transformed @override - def _get_current_conversation_id(self, chat_options: ChatOptions, **kwargs: Any) -> str | None: + def _get_current_conversation_id(self, options: dict[str, Any], **kwargs: Any) -> str | None: """Get the current conversation ID from chat options or kwargs.""" - return chat_options.conversation_id or kwargs.get("conversation_id") or self.conversation_id + return options.get("conversation_id") or kwargs.get("conversation_id") or self.conversation_id def _prepare_messages_for_azure_ai( self, messages: MutableSequence[ChatMessage] diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py b/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py new file mode 100644 index 0000000000..edad03f5b4 --- /dev/null +++ b/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py @@ -0,0 +1,454 @@ +# Copyright (c) Microsoft. All rights reserved. + +import sys +from collections.abc import Callable, MutableMapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, TypedDict + +from agent_framework import ( + AGENT_FRAMEWORK_USER_AGENT, + AIFunction, + ChatAgent, + ContextProvider, + Middleware, + ToolProtocol, + get_logger, + normalize_tools, +) +from agent_framework.exceptions import ServiceInitializationError +from azure.ai.projects.aio import AIProjectClient +from azure.ai.projects.models import ( + AgentReference, + AgentVersionDetails, + FunctionTool, + PromptAgentDefinition, + PromptAgentDefinitionText, +) +from azure.core.credentials_async import AsyncTokenCredential +from pydantic import ValidationError + +from ._client import AzureAIClient +from ._shared import AzureAISettings, create_text_format_config, from_azure_ai_tools, to_azure_ai_tools + +if TYPE_CHECKING: + from agent_framework.openai import OpenAIResponsesOptions + +if sys.version_info >= (3, 13): + from typing import Self, TypeVar # pragma: no cover +else: + from typing_extensions import Self, TypeVar # pragma: no cover + + +logger = get_logger("agent_framework.azure") + + +# Type variable for options - allows typed ChatAgent[TOptions] returns +# Default matches AzureAIClient's default options type +TOptions_co = TypeVar( + "TOptions_co", + bound=TypedDict, # type: ignore[valid-type] + default="OpenAIResponsesOptions", + covariant=True, +) + + +class AzureAIProjectAgentProvider(Generic[TOptions_co]): + """Provider for Azure AI Agent Service (Responses API). + + This provider allows you to create, retrieve, and manage Azure AI agents + using the AIProjectClient from the Azure AI Projects SDK. + + Examples: + Using with explicit AIProjectClient: + + .. code-block:: python + + from agent_framework.azure import AzureAIProjectAgentProvider + from azure.ai.projects.aio import AIProjectClient + from azure.identity.aio import DefaultAzureCredential + + async with AIProjectClient(endpoint, credential) as client: + provider = AzureAIProjectAgentProvider(client) + agent = await provider.create_agent( + name="MyAgent", + model="gpt-4", + instructions="You are a helpful assistant.", + ) + response = await agent.run("Hello!") + + Using with credential and endpoint (auto-creates client): + + .. 
code-block:: python + + from agent_framework.azure import AzureAIProjectAgentProvider + from azure.identity.aio import DefaultAzureCredential + + async with AzureAIProjectAgentProvider(credential=credential) as provider: + agent = await provider.create_agent( + name="MyAgent", + model="gpt-4", + instructions="You are a helpful assistant.", + ) + response = await agent.run("Hello!") + """ + + def __init__( + self, + project_client: AIProjectClient | None = None, + *, + project_endpoint: str | None = None, + model: str | None = None, + credential: AsyncTokenCredential | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initialize an Azure AI Project Agent Provider. + + Args: + project_client: An existing AIProjectClient to use. If not provided, one will be created. + project_endpoint: The Azure AI Project endpoint URL. + Can also be set via environment variable AZURE_AI_PROJECT_ENDPOINT. + Ignored when a project_client is passed. + model: The default model deployment name to use for agent creation. + Can also be set via environment variable AZURE_AI_MODEL_DEPLOYMENT_NAME. + credential: Azure async credential to use for authentication. + Required when project_client is not provided. + env_file_path: Path to environment file for loading settings. + env_file_encoding: Encoding of the environment file. + + Raises: + ServiceInitializationError: If required parameters are missing or invalid. + """ + try: + self._settings = AzureAISettings( + project_endpoint=project_endpoint, + model_deployment_name=model, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + except ValidationError as ex: + raise ServiceInitializationError("Failed to create Azure AI settings.", ex) from ex + + # Track whether we should close client connection + self._should_close_client = False + + if project_client is None: + if not self._settings.project_endpoint: + raise ServiceInitializationError( + "Azure AI project endpoint is required. Set via 'project_endpoint' parameter " + "or 'AZURE_AI_PROJECT_ENDPOINT' environment variable." + ) + + if not credential: + raise ServiceInitializationError("Azure credential is required when project_client is not provided.") + + project_client = AIProjectClient( + endpoint=self._settings.project_endpoint, + credential=credential, + user_agent=AGENT_FRAMEWORK_USER_AGENT, + ) + self._should_close_client = True + + self._project_client = project_client + + async def create_agent( + self, + name: str, + model: str | None = None, + instructions: str | None = None, + description: str | None = None, + tools: ToolProtocol + | Callable[..., Any] + | MutableMapping[str, Any] + | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] + | None = None, + default_options: TOptions_co | None = None, + middleware: Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + ) -> "ChatAgent[TOptions_co]": + """Create a new agent on the Azure AI service and return a local ChatAgent wrapper. + + Args: + name: The name of the agent to create. + model: The model deployment name to use. Falls back to AZURE_AI_MODEL_DEPLOYMENT_NAME + environment variable if not provided. + instructions: Instructions for the agent. + description: A description of the agent. + tools: Tools to make available to the agent. + default_options: A TypedDict containing default chat options for the agent. + These options are applied to every run unless overridden. 
+ middleware: List of middleware to intercept agent and function invocations. + context_provider: Context provider to include during agent invocation. + + Returns: + ChatAgent: A ChatAgent instance configured with the created agent. + + Raises: + ServiceInitializationError: If required parameters are missing. + """ + # Resolve model from parameter or environment variable + resolved_model = model or self._settings.model_deployment_name + if not resolved_model: + raise ServiceInitializationError( + "Model deployment name is required. Provide 'model' parameter " + "or set 'AZURE_AI_MODEL_DEPLOYMENT_NAME' environment variable." + ) + + # Extract response_format from default_options if present + opts = dict(default_options) if default_options else {} + response_format = opts.get("response_format") + + args: dict[str, Any] = {"model": resolved_model} + + if instructions: + args["instructions"] = instructions + if response_format and isinstance(response_format, (type, dict)): + args["text"] = PromptAgentDefinitionText( + format=create_text_format_config(response_format) # type: ignore[arg-type] + ) + + # Normalize tools once and reuse for both Azure AI API and ChatAgent + normalized_tools = normalize_tools(tools) + if normalized_tools: + args["tools"] = to_azure_ai_tools(normalized_tools) + + created_agent = await self._project_client.agents.create_version( + agent_name=name, + definition=PromptAgentDefinition(**args), + description=description, + ) + + return self._to_chat_agent_from_details( + created_agent, + normalized_tools, + default_options=default_options, + middleware=middleware, + context_provider=context_provider, + ) + + async def get_agent( + self, + *, + name: str | None = None, + reference: AgentReference | None = None, + tools: ToolProtocol + | Callable[..., Any] + | MutableMapping[str, Any] + | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] + | None = None, + default_options: TOptions_co | None = None, + middleware: Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + ) -> "ChatAgent[TOptions_co]": + """Retrieve an existing agent from the Azure AI service and return a local ChatAgent wrapper. + + You must provide either name or reference. Use `as_agent()` if you already have + AgentVersionDetails and want to avoid an async call. + + Args: + name: The name of the agent to retrieve (fetches latest version). + reference: Reference containing the agent's name and optionally a specific version. + tools: Tools to make available to the agent. Required if the agent has function tools. + default_options: A TypedDict containing default chat options for the agent. + These options are applied to every run unless overridden. + middleware: List of middleware to intercept agent and function invocations. + context_provider: Context provider to include during agent invocation. + + Returns: + ChatAgent: A ChatAgent instance configured with the retrieved agent. + + Raises: + ValueError: If no identifier is provided or required tools are missing. 
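+
+        Example:
+            A minimal sketch; the agent name and `get_weather` tool are illustrative:
+
+            .. code-block:: python
+
+                # Fetch the latest version by name, providing implementations
+                # for any function tools declared on the agent.
+                agent = await provider.get_agent(name="MyAgent", tools=[get_weather])
+                response = await agent.run("Hello!")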
+ """ + existing_agent: AgentVersionDetails + + if reference and reference.version: + # Fetch specific version + existing_agent = await self._project_client.agents.get_version( + agent_name=reference.name, agent_version=reference.version + ) + elif agent_name := (reference.name if reference else name): + # Fetch latest version + details = await self._project_client.agents.get(agent_name=agent_name) + existing_agent = details.versions.latest + else: + raise ValueError("Either name or reference must be provided to get an agent.") + + if not isinstance(existing_agent.definition, PromptAgentDefinition): + raise ValueError("Agent definition must be PromptAgentDefinition to get a ChatAgent.") + + # Validate that required function tools are provided + self._validate_function_tools(existing_agent.definition.tools, tools) + + return self._to_chat_agent_from_details( + existing_agent, + normalize_tools(tools), + default_options=default_options, + middleware=middleware, + context_provider=context_provider, + ) + + def as_agent( + self, + details: AgentVersionDetails, + tools: ToolProtocol + | Callable[..., Any] + | MutableMapping[str, Any] + | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] + | None = None, + default_options: TOptions_co | None = None, + middleware: Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + ) -> "ChatAgent[TOptions_co]": + """Wrap an SDK agent version object into a ChatAgent without making HTTP calls. + + Use this when you already have an AgentVersionDetails from a previous API call. + + Args: + details: The AgentVersionDetails to wrap. + tools: Tools to make available to the agent. Required if the agent has function tools. + default_options: A TypedDict containing default chat options for the agent. + These options are applied to every run unless overridden. + middleware: List of middleware to intercept agent and function invocations. + context_provider: Context provider to include during agent invocation. + + Returns: + ChatAgent: A ChatAgent instance configured with the agent version. + + Raises: + ValueError: If the agent definition is not a PromptAgentDefinition or required tools are missing. + """ + if not isinstance(details.definition, PromptAgentDefinition): + raise ValueError("Agent definition must be PromptAgentDefinition to create a ChatAgent.") + + # Validate that required function tools are provided + self._validate_function_tools(details.definition.tools, tools) + + return self._to_chat_agent_from_details( + details, + normalize_tools(tools), + default_options=default_options, + middleware=middleware, + context_provider=context_provider, + ) + + def _to_chat_agent_from_details( + self, + details: AgentVersionDetails, + provided_tools: Sequence[ToolProtocol | MutableMapping[str, Any]] | None = None, + default_options: TOptions_co | None = None, + middleware: Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + ) -> "ChatAgent[TOptions_co]": + """Create a ChatAgent from an AgentVersionDetails. + + Args: + details: The AgentVersionDetails containing the agent definition. + provided_tools: User-provided tools (including function implementations). + These are merged with hosted tools from the definition. + default_options: A TypedDict containing default chat options for the agent. + These options are applied to every run unless overridden. + middleware: List of middleware to intercept agent and function invocations. 
+ context_provider: Context provider to include during agent invocation. + """ + if not isinstance(details.definition, PromptAgentDefinition): + raise ValueError("Agent definition must be PromptAgentDefinition to get a ChatAgent.") + + client = AzureAIClient( + project_client=self._project_client, + agent_name=details.name, + agent_version=details.version, + agent_description=details.description, + model_deployment_name=details.definition.model, + ) + + # Merge tools: hosted tools from definition + user-provided function tools + # from_azure_ai_tools converts hosted tools (MCP, code interpreter, file search, web search) + # but function tools need the actual implementations from provided_tools + merged_tools = self._merge_tools(details.definition.tools, provided_tools) + + return ChatAgent( # type: ignore[return-value] + chat_client=client, + id=details.id, + name=details.name, + description=details.description, + instructions=details.definition.instructions, + model_id=details.definition.model, + tools=merged_tools, + default_options=default_options, # type: ignore[arg-type] + middleware=middleware, + context_provider=context_provider, + ) + + def _merge_tools( + self, + definition_tools: Sequence[Any] | None, + provided_tools: Sequence[ToolProtocol | MutableMapping[str, Any]] | None, + ) -> list[ToolProtocol | dict[str, Any]]: + """Merge hosted tools from definition with user-provided function tools. + + Args: + definition_tools: Tools from the agent definition (Azure AI format). + provided_tools: User-provided tools (Agent Framework format), including function implementations. + + Returns: + Combined list of tools for the ChatAgent. + """ + merged: list[ToolProtocol | dict[str, Any]] = [] + + # Convert hosted tools from definition (MCP, code interpreter, file search, web search) + # Function tools from the definition are skipped - we use user-provided implementations instead + hosted_tools = from_azure_ai_tools(definition_tools) + for hosted_tool in hosted_tools: + # Skip function tool dicts - they don't have implementations + if isinstance(hosted_tool, dict) and hosted_tool.get("type") == "function": + continue + merged.append(hosted_tool) + + # Add user-provided function tools (these have the actual implementations) + if provided_tools: + for provided_tool in provided_tools: + if isinstance(provided_tool, AIFunction): + merged.append(provided_tool) # type: ignore[reportUnknownArgumentType] + + return merged + + def _validate_function_tools( + self, + agent_tools: Sequence[Any] | None, + provided_tools: ToolProtocol + | Callable[..., Any] + | MutableMapping[str, Any] + | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] + | None, + ) -> None: + """Validate that required function tools are provided.""" + # Normalize and validate function tools + normalized_tools = normalize_tools(provided_tools) + tool_names = {tool.name for tool in normalized_tools if isinstance(tool, AIFunction)} + + # If function tools exist in agent definition but were not provided, + # we need to raise an error, as it won't be possible to invoke the function. 
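+        # Hosted tools (code interpreter, file search, web search, MCP) are not
+        # checked here: they run service-side and need no local implementation.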
+ missing_tools = [ + tool.name for tool in (agent_tools or []) if isinstance(tool, FunctionTool) and tool.name not in tool_names + ] + + if missing_tools: + raise ValueError( + f"The following tools required by the agent definition were not provided: {', '.join(missing_tools)}" + ) + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + """Async context manager exit.""" + await self.close() + + async def close(self) -> None: + """Close the provider and release resources. + + Only closes the underlying AIProjectClient if it was created by this provider. + """ + if self._should_close_client: + await self._project_client.close() diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_shared.py b/python/packages/azure-ai/agent_framework_azure_ai/_shared.py index a120e9f92e..b99a3b5f66 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_shared.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_shared.py @@ -1,8 +1,49 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import ClassVar +import os +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any, ClassVar, Literal, cast +from agent_framework import ( + AIFunction, + Contents, + HostedCodeInterpreterTool, + HostedFileContent, + HostedFileSearchTool, + HostedMCPTool, + HostedVectorStoreContent, + HostedWebSearchTool, + ToolProtocol, + get_logger, +) from agent_framework._pydantic import AFBaseSettings +from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError +from azure.ai.agents.models import ( + BingCustomSearchTool, + BingGroundingTool, + CodeInterpreterToolDefinition, + McpTool, + ToolDefinition, +) +from azure.ai.agents.models import FileSearchTool as AgentsFileSearchTool +from azure.ai.projects.models import ( + ApproximateLocation, + CodeInterpreterTool, + CodeInterpreterToolAuto, + FunctionTool, + MCPTool, + ResponseTextFormatConfigurationJsonObject, + ResponseTextFormatConfigurationJsonSchema, + ResponseTextFormatConfigurationText, + Tool, + WebSearchPreviewTool, +) +from azure.ai.projects.models import ( + FileSearchTool as ProjectsFileSearchTool, +) +from pydantic import BaseModel + +logger = get_logger("agent_framework.azure") class AzureAISettings(AFBaseSettings): @@ -44,3 +85,481 @@ class AzureAISettings(AFBaseSettings): project_endpoint: str | None = None model_deployment_name: str | None = None + + +def to_azure_ai_agent_tools( + tools: Sequence[ToolProtocol | MutableMapping[str, Any]] | None, + run_options: dict[str, Any] | None = None, +) -> list[ToolDefinition | dict[str, Any]]: + """Convert Agent Framework tools to Azure AI V1 SDK tool definitions. + + Args: + tools: Sequence of Agent Framework tools to convert. + run_options: Optional dict of run options; file search tools may set "tool_resources" on it. + + Returns: + List of Azure AI V1 SDK tool definitions. + + Raises: + ServiceInitializationError: If tool configuration is invalid.
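+
+    Example:
+        A minimal sketch:
+
+        .. code-block:: python
+
+            run_options: dict[str, Any] = {}
+            definitions = to_azure_ai_agent_tools([HostedCodeInterpreterTool()], run_options)
+            # definitions now holds a CodeInterpreterToolDefinition; a file search
+            # tool would additionally populate run_options["tool_resources"].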
+ """ + if not tools: + return [] + + tool_definitions: list[ToolDefinition | dict[str, Any]] = [] + for tool in tools: + match tool: + case AIFunction(): + tool_definitions.append(tool.to_json_schema_spec()) # type: ignore[reportUnknownArgumentType] + case HostedWebSearchTool(): + additional_props = tool.additional_properties or {} + config_args: dict[str, Any] = {} + if count := additional_props.get("count"): + config_args["count"] = count + if freshness := additional_props.get("freshness"): + config_args["freshness"] = freshness + if market := additional_props.get("market"): + config_args["market"] = market + if set_lang := additional_props.get("set_lang"): + config_args["set_lang"] = set_lang + # Bing Grounding + connection_id = additional_props.get("connection_id") or os.getenv("BING_CONNECTION_ID") + # Custom Bing Search + custom_connection_id = additional_props.get("custom_connection_id") or os.getenv( + "BING_CUSTOM_CONNECTION_ID" + ) + custom_instance_name = additional_props.get("custom_instance_name") or os.getenv( + "BING_CUSTOM_INSTANCE_NAME" + ) + bing_search: BingGroundingTool | BingCustomSearchTool | None = None + if connection_id and not custom_connection_id and not custom_instance_name: + bing_search = BingGroundingTool(connection_id=connection_id, **config_args) + if custom_connection_id and custom_instance_name: + bing_search = BingCustomSearchTool( + connection_id=custom_connection_id, + instance_name=custom_instance_name, + **config_args, + ) + if not bing_search: + raise ServiceInitializationError( + "Bing search tool requires either 'connection_id' for Bing Grounding " + "or both 'custom_connection_id' and 'custom_instance_name' for Custom Bing Search. " + "These can be provided via additional_properties or environment variables: " + "'BING_CONNECTION_ID', 'BING_CUSTOM_CONNECTION_ID', 'BING_CUSTOM_INSTANCE_NAME'" + ) + tool_definitions.extend(bing_search.definitions) + case HostedCodeInterpreterTool(): + tool_definitions.append(CodeInterpreterToolDefinition()) + case HostedMCPTool(): + mcp_tool = McpTool( + server_label=tool.name.replace(" ", "_"), + server_url=str(tool.url), + allowed_tools=list(tool.allowed_tools) if tool.allowed_tools else [], + ) + tool_definitions.extend(mcp_tool.definitions) + case HostedFileSearchTool(): + vector_stores = [inp for inp in tool.inputs or [] if isinstance(inp, HostedVectorStoreContent)] + if vector_stores: + file_search = AgentsFileSearchTool(vector_store_ids=[vs.vector_store_id for vs in vector_stores]) + tool_definitions.extend(file_search.definitions) + # Set tool_resources for file search to work properly with Azure AI + if run_options is not None and "tool_resources" not in run_options: + run_options["tool_resources"] = file_search.resources + case ToolDefinition(): + tool_definitions.append(tool) + case dict(): + tool_definitions.append(tool) + case _: + raise ServiceInitializationError(f"Unsupported tool type: {type(tool)}") + return tool_definitions + + +def from_azure_ai_agent_tools( + tools: Sequence[ToolDefinition | dict[str, Any]] | None, +) -> list[ToolProtocol | dict[str, Any]]: + """Convert Azure AI V1 SDK tool definitions to Agent Framework tools. + + Args: + tools: Sequence of Azure AI V1 SDK tool definitions. + + Returns: + List of Agent Framework tools. 
+ """ + if not tools: + return [] + + result: list[ToolProtocol | dict[str, Any]] = [] + for tool in tools: + # Handle SDK objects + if isinstance(tool, CodeInterpreterToolDefinition): + result.append(HostedCodeInterpreterTool()) + elif isinstance(tool, dict): + # Handle dict format + converted = _convert_dict_tool(tool) + if converted is not None: + result.append(converted) + elif hasattr(tool, "type"): + # Handle other SDK objects by type + converted = _convert_sdk_tool(tool) + if converted is not None: + result.append(converted) + return result + + +def _convert_dict_tool(tool: dict[str, Any]) -> ToolProtocol | dict[str, Any] | None: + """Convert a dict-format Azure AI tool to Agent Framework tool.""" + tool_type = tool.get("type") + + if tool_type == "code_interpreter": + return HostedCodeInterpreterTool() + + if tool_type == "file_search": + file_search_config = tool.get("file_search", {}) + vector_store_ids = file_search_config.get("vector_store_ids", []) + inputs = [HostedVectorStoreContent(vector_store_id=vs_id) for vs_id in vector_store_ids] + return HostedFileSearchTool(inputs=inputs if inputs else None) # type: ignore + + if tool_type == "bing_grounding": + bing_config = tool.get("bing_grounding", {}) + connection_id = bing_config.get("connection_id") + return HostedWebSearchTool(additional_properties={"connection_id": connection_id} if connection_id else None) + + if tool_type == "bing_custom_search": + bing_config = tool.get("bing_custom_search", {}) + return HostedWebSearchTool( + additional_properties={ + "custom_connection_id": bing_config.get("connection_id"), + "custom_instance_name": bing_config.get("instance_name"), + } + ) + + if tool_type == "mcp": + # Hosted MCP tools are defined on the Azure agent, no local handling needed + # Azure may not return full server_url, so skip conversion + return None + + if tool_type == "function": + # Function tools are returned as dicts - users must provide implementations + return tool + + # Unknown tool type - pass through + return tool + + +def _convert_sdk_tool(tool: ToolDefinition) -> ToolProtocol | dict[str, Any] | None: + """Convert an SDK-object Azure AI tool to Agent Framework tool.""" + tool_type = getattr(tool, "type", None) + + if tool_type == "code_interpreter": + return HostedCodeInterpreterTool() + + if tool_type == "file_search": + file_search_config = getattr(tool, "file_search", None) + vector_store_ids = getattr(file_search_config, "vector_store_ids", []) if file_search_config else [] + inputs = [HostedVectorStoreContent(vector_store_id=vs_id) for vs_id in vector_store_ids] + return HostedFileSearchTool(inputs=inputs if inputs else None) # type: ignore + + if tool_type == "bing_grounding": + bing_config = getattr(tool, "bing_grounding", None) + connection_id = getattr(bing_config, "connection_id", None) if bing_config else None + return HostedWebSearchTool(additional_properties={"connection_id": connection_id} if connection_id else None) + + if tool_type == "bing_custom_search": + bing_config = getattr(tool, "bing_custom_search", None) + return HostedWebSearchTool( + additional_properties={ + "custom_connection_id": getattr(bing_config, "connection_id", None) if bing_config else None, + "custom_instance_name": getattr(bing_config, "instance_name", None) if bing_config else None, + } + ) + + if tool_type == "mcp": + # Hosted MCP tools are defined on the Azure agent, no local handling needed + # Azure may not return full server_url, so skip conversion + return None + + if tool_type == "function": + # Function tools 
from SDK don't have implementations - skip + return None + + # Unknown tool type - convert to dict if possible + if hasattr(tool, "as_dict"): + return tool.as_dict() # type: ignore[union-attr] + return {"type": tool_type} if tool_type else {} + + +def from_azure_ai_tools(tools: Sequence[Tool | dict[str, Any]] | None) -> list[ToolProtocol | dict[str, Any]]: + """Parses and converts a sequence of Azure AI tools into Agent Framework compatible tools. + + Args: + tools: A sequence of tool objects or dictionaries + defining the tools to be parsed. Can be None. + + Returns: + list[ToolProtocol | dict[str, Any]]: A list of converted tools compatible with the + Agent Framework. + """ + agent_tools: list[ToolProtocol | dict[str, Any]] = [] + if not tools: + return agent_tools + for tool in tools: + # Handle raw dictionary tools + tool_dict = tool if isinstance(tool, dict) else dict(tool) + tool_type = tool_dict.get("type") + + if tool_type == "mcp": + mcp_tool = cast(MCPTool, tool_dict) + approval_mode: Literal["always_require", "never_require"] | dict[str, set[str]] | None = None + if require_approval := mcp_tool.get("require_approval"): + if require_approval == "always": + approval_mode = "always_require" + elif require_approval == "never": + approval_mode = "never_require" + elif isinstance(require_approval, dict): + approval_mode = {} + if "always" in require_approval: + approval_mode["always_require_approval"] = set(require_approval["always"].get("tool_names", [])) # type: ignore + if "never" in require_approval: + approval_mode["never_require_approval"] = set(require_approval["never"].get("tool_names", [])) # type: ignore + + agent_tools.append( + HostedMCPTool( + name=mcp_tool.get("server_label", "").replace("_", " "), + url=mcp_tool.get("server_url", ""), + description=mcp_tool.get("server_description"), + headers=mcp_tool.get("headers"), + allowed_tools=mcp_tool.get("allowed_tools"), + approval_mode=approval_mode, # type: ignore + ) + ) + elif tool_type == "code_interpreter": + ci_tool = cast(CodeInterpreterTool, tool_dict) + container = ci_tool.get("container", {}) + ci_inputs: list[Contents] = [] + if "file_ids" in container: + for file_id in container["file_ids"]: + ci_inputs.append(HostedFileContent(file_id=file_id)) + + agent_tools.append(HostedCodeInterpreterTool(inputs=ci_inputs if ci_inputs else None)) # type: ignore + elif tool_type == "file_search": + fs_tool = cast(ProjectsFileSearchTool, tool_dict) + fs_inputs: list[Contents] = [] + if "vector_store_ids" in fs_tool: + for vs_id in fs_tool["vector_store_ids"]: + fs_inputs.append(HostedVectorStoreContent(vector_store_id=vs_id)) + + agent_tools.append( + HostedFileSearchTool( + inputs=fs_inputs if fs_inputs else None, # type: ignore + max_results=fs_tool.get("max_num_results"), + ) + ) + elif tool_type == "web_search_preview": + ws_tool = cast(WebSearchPreviewTool, tool_dict) + additional_properties: dict[str, Any] = {} + if user_location := ws_tool.get("user_location"): + additional_properties["user_location"] = { + "city": user_location.get("city"), + "country": user_location.get("country"), + "region": user_location.get("region"), + "timezone": user_location.get("timezone"), + } + + agent_tools.append(HostedWebSearchTool(additional_properties=additional_properties)) + else: + agent_tools.append(tool_dict) + return agent_tools + + +def to_azure_ai_tools( + tools: Sequence[ToolProtocol | MutableMapping[str, Any]] | None, +) -> list[Tool | dict[str, Any]]: + """Converts Agent Framework tools into Azure AI compatible tools. 
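+
+    Hosted tools (MCP, code interpreter, file search, and web search) are mapped
+    to their Azure AI equivalents, AIFunction tools become FunctionTool
+    definitions, raw dict tools are passed through unchanged, and unsupported
+    tool types are logged and skipped.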
+ + Args: + tools: A sequence of Agent Framework tool objects or dictionaries + defining the tools to be converted. Can be None. + + Returns: + list[Tool | dict[str, Any]]: A list of converted tools compatible with Azure AI. + """ + azure_tools: list[Tool | dict[str, Any]] = [] + if not tools: + return azure_tools + + for tool in tools: + if isinstance(tool, ToolProtocol): + match tool: + case HostedMCPTool(): + azure_tools.append(_prepare_mcp_tool_for_azure_ai(tool)) + case HostedCodeInterpreterTool(): + file_ids: list[str] = [] + if tool.inputs: + for tool_input in tool.inputs: + if isinstance(tool_input, HostedFileContent): + file_ids.append(tool_input.file_id) + container = CodeInterpreterToolAuto(file_ids=file_ids if file_ids else None) + ci_tool: CodeInterpreterTool = CodeInterpreterTool(container=container) + azure_tools.append(ci_tool) + case AIFunction(): + params = tool.parameters() + params["additionalProperties"] = False + azure_tools.append( + FunctionTool( + name=tool.name, + parameters=params, + strict=False, + description=tool.description, + ) + ) + case HostedFileSearchTool(): + if not tool.inputs: + raise ValueError("HostedFileSearchTool requires inputs to be specified.") + vector_store_ids: list[str] = [ + inp.vector_store_id for inp in tool.inputs if isinstance(inp, HostedVectorStoreContent) + ] + if not vector_store_ids: + raise ValueError( + "HostedFileSearchTool requires inputs to be of type `HostedVectorStoreContent`." + ) + fs_tool: ProjectsFileSearchTool = ProjectsFileSearchTool(vector_store_ids=vector_store_ids) + if tool.max_results: + fs_tool["max_num_results"] = tool.max_results + azure_tools.append(fs_tool) + case HostedWebSearchTool(): + ws_tool: WebSearchPreviewTool = WebSearchPreviewTool() + if tool.additional_properties: + location: dict[str, str] | None = ( + tool.additional_properties.get("user_location", None) + if tool.additional_properties + else None + ) + if location: + ws_tool.user_location = ApproximateLocation( + city=location.get("city"), + country=location.get("country"), + region=location.get("region"), + timezone=location.get("timezone"), + ) + azure_tools.append(ws_tool) + case _: + logger.debug("Unsupported tool passed (type: %s)", type(tool)) + else: + # Handle raw dictionary tools + tool_dict = tool if isinstance(tool, dict) else dict(tool) + azure_tools.append(tool_dict) + + return azure_tools + + +def _prepare_mcp_tool_for_azure_ai(tool: HostedMCPTool) -> MCPTool: + """Convert HostedMCPTool to Azure AI MCPTool format. + + Args: + tool: The HostedMCPTool to convert. + + Returns: + MCPTool: The converted Azure AI MCPTool. 
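+
+    Example:
+        A minimal sketch; the server name and URL are illustrative:
+
+        .. code-block:: python
+
+            hosted = HostedMCPTool(name="docs server", url="https://mcp.example.com")
+            mcp = _prepare_mcp_tool_for_azure_ai(hosted)
+            # mcp["server_label"] == "docs_server" (spaces become underscores)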
+ """ + mcp: MCPTool = MCPTool(server_label=tool.name.replace(" ", "_"), server_url=str(tool.url)) + + if tool.description: + mcp["server_description"] = tool.description + + if tool.headers: + mcp["headers"] = tool.headers + + if tool.allowed_tools: + mcp["allowed_tools"] = list(tool.allowed_tools) + + if tool.approval_mode: + match tool.approval_mode: + case str(): + mcp["require_approval"] = "always" if tool.approval_mode == "always_require" else "never" + case _: + if always_require_approvals := tool.approval_mode.get("always_require_approval"): + mcp["require_approval"] = {"always": {"tool_names": list(always_require_approvals)}} + if never_require_approvals := tool.approval_mode.get("never_require_approval"): + mcp["require_approval"] = {"never": {"tool_names": list(never_require_approvals)}} + + return mcp + + +def create_text_format_config( + response_format: type[BaseModel] | Mapping[str, Any], +) -> ( + ResponseTextFormatConfigurationJsonSchema + | ResponseTextFormatConfigurationJsonObject + | ResponseTextFormatConfigurationText +): + """Convert response_format into Azure text format configuration.""" + if isinstance(response_format, type) and issubclass(response_format, BaseModel): + schema = response_format.model_json_schema() + # Ensure additionalProperties is explicitly false to satisfy Azure validation + if isinstance(schema, dict): + schema.setdefault("additionalProperties", False) + return ResponseTextFormatConfigurationJsonSchema( + name=response_format.__name__, + schema=schema, + strict=True, + ) + + if isinstance(response_format, Mapping): + format_config = _convert_response_format(response_format) + format_type = format_config.get("type") + if format_type == "json_schema": + # Ensure schema includes additionalProperties=False to satisfy Azure validation + schema = dict(format_config.get("schema", {})) # type: ignore[assignment] + schema.setdefault("additionalProperties", False) + config_kwargs: dict[str, Any] = { + "name": format_config.get("name") or "response", + "schema": schema, + } + if "strict" in format_config: + config_kwargs["strict"] = format_config["strict"] + if "description" in format_config: + config_kwargs["description"] = format_config["description"] + return ResponseTextFormatConfigurationJsonSchema(**config_kwargs) + if format_type == "json_object": + return ResponseTextFormatConfigurationJsonObject() + if format_type == "text": + return ResponseTextFormatConfigurationText() + + raise ServiceInvalidRequestError("response_format must be a Pydantic model or mapping.") + + +def _convert_response_format(response_format: Mapping[str, Any]) -> dict[str, Any]: + """Convert Chat style response_format into Responses text format config.""" + if "format" in response_format and isinstance(response_format["format"], Mapping): + return dict(cast("Mapping[str, Any]", response_format["format"])) + + format_type = response_format.get("type") + if format_type == "json_schema": + schema_section = response_format.get("json_schema", response_format) + if not isinstance(schema_section, Mapping): + raise ServiceInvalidRequestError("json_schema response_format must be a mapping.") + schema_section_typed = cast("Mapping[str, Any]", schema_section) + schema: Any = schema_section_typed.get("schema") + if schema is None: + raise ServiceInvalidRequestError("json_schema response_format requires a schema.") + name: str = str( + schema_section_typed.get("name") + or schema_section_typed.get("title") + or (cast("Mapping[str, Any]", schema).get("title") if isinstance(schema, Mapping) 
else None) + or "response" + ) + format_config: dict[str, Any] = { + "type": "json_schema", + "name": name, + "schema": schema, + } + if "strict" in schema_section: + format_config["strict"] = schema_section["strict"] + if "description" in schema_section and schema_section["description"] is not None: + format_config["description"] = schema_section["description"] + return format_config + + if format_type in {"json_object", "text"}: + return {"type": format_type} + + raise ServiceInvalidRequestError("Unsupported response_format provided for Azure AI client.") diff --git a/python/packages/azure-ai/pyproject.toml b/python/packages/azure-ai/pyproject.toml index 37491b42e5..65c2b2c0c9 100644 --- a/python/packages/azure-ai/pyproject.toml +++ b/python/packages/azure-ai/pyproject.toml @@ -4,7 +4,7 @@ description = "Azure AI Foundry integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/azure-ai/tests/test_agent_provider.py b/python/packages/azure-ai/tests/test_agent_provider.py new file mode 100644 index 0000000000..3df8d318ec --- /dev/null +++ b/python/packages/azure-ai/tests/test_agent_provider.py @@ -0,0 +1,803 @@ +# Copyright (c) Microsoft. All rights reserved. + +import os +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from agent_framework import ( + ChatAgent, + HostedCodeInterpreterTool, + HostedFileSearchTool, + HostedMCPTool, + HostedVectorStoreContent, + HostedWebSearchTool, + ai_function, +) +from agent_framework.exceptions import ServiceInitializationError +from azure.ai.agents.models import ( + Agent, + CodeInterpreterToolDefinition, +) +from pydantic import BaseModel + +from agent_framework_azure_ai import ( + AzureAIAgentsProvider, + AzureAISettings, +) +from agent_framework_azure_ai._shared import ( + from_azure_ai_agent_tools, + to_azure_ai_agent_tools, +) + +skip_if_azure_ai_integration_tests_disabled = pytest.mark.skipif( + os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true" + or os.getenv("AZURE_AI_PROJECT_ENDPOINT", "") in ("", "https://test-project.cognitiveservices.azure.com/"), + reason="No real AZURE_AI_PROJECT_ENDPOINT provided; skipping integration tests." 
+ if os.getenv("RUN_INTEGRATION_TESTS", "false").lower() == "true" + else "Integration tests are disabled.", +) + + +# region Provider Initialization Tests + + +def test_provider_init_with_agents_client(mock_agents_client: MagicMock) -> None: + """Test AzureAIAgentsProvider initialization with existing AgentsClient.""" + provider = AzureAIAgentsProvider(agents_client=mock_agents_client) + + assert provider._agents_client is mock_agents_client # type: ignore + assert provider._should_close_client is False # type: ignore + + +def test_provider_init_with_credential( + azure_ai_unit_test_env: dict[str, str], + mock_azure_credential: MagicMock, +) -> None: + """Test AzureAIAgentsProvider initialization with credential.""" + with patch("agent_framework_azure_ai._agent_provider.AgentsClient") as mock_client_class: + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + + provider = AzureAIAgentsProvider(credential=mock_azure_credential) + + mock_client_class.assert_called_once() + assert provider._agents_client is mock_client_instance # type: ignore + assert provider._should_close_client is True # type: ignore + + +def test_provider_init_with_explicit_endpoint(mock_azure_credential: MagicMock) -> None: + """Test AzureAIAgentsProvider initialization with explicit endpoint.""" + with patch("agent_framework_azure_ai._agent_provider.AgentsClient") as mock_client_class: + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + + provider = AzureAIAgentsProvider( + project_endpoint="https://custom-endpoint.com/", + credential=mock_azure_credential, + ) + + mock_client_class.assert_called_once() + call_kwargs = mock_client_class.call_args.kwargs + assert call_kwargs["endpoint"] == "https://custom-endpoint.com/" + assert provider._should_close_client is True # type: ignore + + +def test_provider_init_missing_endpoint_raises( + mock_azure_credential: MagicMock, +) -> None: + """Test AzureAIAgentsProvider raises error when endpoint is missing.""" + # Mock AzureAISettings to return None for project_endpoint + with patch("agent_framework_azure_ai._agent_provider.AzureAISettings") as mock_settings_class: + mock_settings = MagicMock() + mock_settings.project_endpoint = None + mock_settings.model_deployment_name = "test-model" + mock_settings_class.return_value = mock_settings + + with pytest.raises(ServiceInitializationError) as exc_info: + AzureAIAgentsProvider(credential=mock_azure_credential) + + assert "project endpoint is required" in str(exc_info.value).lower() + + +def test_provider_init_missing_credential_raises(azure_ai_unit_test_env: dict[str, str]) -> None: + """Test AzureAIAgentsProvider raises error when credential is missing.""" + with pytest.raises(ServiceInitializationError) as exc_info: + AzureAIAgentsProvider() + + assert "credential is required" in str(exc_info.value).lower() + + +# endregion + + +# region Context Manager Tests + + +async def test_provider_context_manager_closes_client(mock_agents_client: MagicMock) -> None: + """Test that context manager closes client when it was created by provider.""" + with patch("agent_framework_azure_ai._agent_provider.AgentsClient") as mock_client_class: + mock_client_instance = AsyncMock() + mock_client_class.return_value = mock_client_instance + + with patch.object(AzureAIAgentsProvider, "__init__", lambda self: None): # type: ignore + provider = AzureAIAgentsProvider.__new__(AzureAIAgentsProvider) + provider._agents_client = mock_client_instance # type: ignore + 
provider._should_close_client = True # type: ignore + provider._settings = AzureAISettings(project_endpoint="https://test.com") # type: ignore + + async with provider: + pass + + mock_client_instance.close.assert_called_once() + + +async def test_provider_context_manager_does_not_close_external_client(mock_agents_client: MagicMock) -> None: + """Test that context manager does not close externally provided client.""" + mock_agents_client.close = AsyncMock() + + provider = AzureAIAgentsProvider(agents_client=mock_agents_client) + + async with provider: + pass + + mock_agents_client.close.assert_not_called() + + +# endregion + + +# region create_agent Tests + + +async def test_create_agent_basic( + azure_ai_unit_test_env: dict[str, str], + mock_agents_client: MagicMock, +) -> None: + """Test creating a basic agent.""" + mock_agent = MagicMock(spec=Agent) + mock_agent.id = "test-agent-id" + mock_agent.name = "TestAgent" + mock_agent.description = "A test agent" + mock_agent.instructions = "Be helpful" + mock_agent.model = "gpt-4" + mock_agent.temperature = 0.7 + mock_agent.top_p = 0.9 + mock_agent.tools = [] + mock_agents_client.create_agent = AsyncMock(return_value=mock_agent) + + provider = AzureAIAgentsProvider(agents_client=mock_agents_client) + + agent = await provider.create_agent( + name="TestAgent", + instructions="Be helpful", + description="A test agent", + ) + + assert isinstance(agent, ChatAgent) + assert agent.name == "TestAgent" + assert agent.id == "test-agent-id" + mock_agents_client.create_agent.assert_called_once() + + +async def test_create_agent_with_model( + azure_ai_unit_test_env: dict[str, str], + mock_agents_client: MagicMock, +) -> None: + """Test creating an agent with explicit model.""" + mock_agent = MagicMock(spec=Agent) + mock_agent.id = "test-agent-id" + mock_agent.name = "TestAgent" + mock_agent.description = None + mock_agent.instructions = None + mock_agent.model = "custom-model" + mock_agent.temperature = None + mock_agent.top_p = None + mock_agent.tools = [] + mock_agents_client.create_agent = AsyncMock(return_value=mock_agent) + + provider = AzureAIAgentsProvider(agents_client=mock_agents_client) + + await provider.create_agent(name="TestAgent", model="custom-model") + + call_kwargs = mock_agents_client.create_agent.call_args.kwargs + assert call_kwargs["model"] == "custom-model" + + +async def test_create_agent_with_tools( + azure_ai_unit_test_env: dict[str, str], + mock_agents_client: MagicMock, +) -> None: + """Test creating an agent with tools.""" + mock_agent = MagicMock(spec=Agent) + mock_agent.id = "test-agent-id" + mock_agent.name = "TestAgent" + mock_agent.description = None + mock_agent.instructions = None + mock_agent.model = "gpt-4" + mock_agent.temperature = None + mock_agent.top_p = None + mock_agent.tools = [] + mock_agents_client.create_agent = AsyncMock(return_value=mock_agent) + + provider = AzureAIAgentsProvider(agents_client=mock_agents_client) + + @ai_function + def get_weather(city: str) -> str: + """Get weather for a city.""" + return f"Weather in {city}" + + await provider.create_agent(name="TestAgent", tools=get_weather) + + call_kwargs = mock_agents_client.create_agent.call_args.kwargs + assert "tools" in call_kwargs + assert len(call_kwargs["tools"]) > 0 + + +async def test_create_agent_with_response_format( + azure_ai_unit_test_env: dict[str, str], + mock_agents_client: MagicMock, +) -> None: + """Test creating an agent with structured response format via default_options.""" + + class WeatherResponse(BaseModel): + temperature: 
float
+    description: str
+
+    mock_agent = MagicMock(spec=Agent)
+    mock_agent.id = "test-agent-id"
+    mock_agent.name = "TestAgent"
+    mock_agent.description = None
+    mock_agent.instructions = None
+    mock_agent.model = "gpt-4"
+    mock_agent.temperature = None
+    mock_agent.top_p = None
+    mock_agent.tools = []
+    mock_agents_client.create_agent = AsyncMock(return_value=mock_agent)
+
+    provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
+
+    await provider.create_agent(
+        name="TestAgent",
+        default_options={"response_format": WeatherResponse},
+    )
+
+    call_kwargs = mock_agents_client.create_agent.call_args.kwargs
+    assert "response_format" in call_kwargs
+
+
+async def test_create_agent_missing_model_raises(
+    mock_agents_client: MagicMock,
+) -> None:
+    """Test that create_agent raises error when model is not specified."""
+    # Create provider with mocked settings that has no model
+    with patch("agent_framework_azure_ai._agent_provider.AzureAISettings") as mock_settings_class:
+        mock_settings = MagicMock()
+        mock_settings.project_endpoint = "https://test.com"
+        mock_settings.model_deployment_name = None  # No model configured
+        mock_settings_class.return_value = mock_settings
+
+        provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
+
+        with pytest.raises(ServiceInitializationError) as exc_info:
+            await provider.create_agent(name="TestAgent")
+
+        assert "model deployment name is required" in str(exc_info.value).lower()
+
+
+# endregion
+
+
+# region get_agent Tests
+
+
+async def test_get_agent_by_id(
+    azure_ai_unit_test_env: dict[str, str],
+    mock_agents_client: MagicMock,
+) -> None:
+    """Test getting an agent by ID."""
+    mock_agent = MagicMock(spec=Agent)
+    mock_agent.id = "existing-agent-id"
+    mock_agent.name = "ExistingAgent"
+    mock_agent.description = "An existing agent"
+    mock_agent.instructions = "Be helpful"
+    mock_agent.model = "gpt-4"
+    mock_agent.temperature = 0.7
+    mock_agent.top_p = 0.9
+    mock_agent.tools = []
+    mock_agents_client.get_agent = AsyncMock(return_value=mock_agent)
+
+    provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
+
+    agent = await provider.get_agent("existing-agent-id")
+
+    assert isinstance(agent, ChatAgent)
+    assert agent.id == "existing-agent-id"
+    mock_agents_client.get_agent.assert_called_once_with("existing-agent-id")
+
+
+async def test_get_agent_with_function_tools(
+    azure_ai_unit_test_env: dict[str, str],
+    mock_agents_client: MagicMock,
+) -> None:
+    """Test getting an agent that has function tools requires tool implementations."""
+    mock_function_tool = MagicMock()
+    mock_function_tool.type = "function"
+    mock_function_tool.function = MagicMock()
+    mock_function_tool.function.name = "get_weather"
+
+    mock_agent = MagicMock(spec=Agent)
+    mock_agent.id = "agent-with-tools"
+    mock_agent.name = "AgentWithTools"
+    mock_agent.description = None
+    mock_agent.instructions = None
+    mock_agent.model = "gpt-4"
+    mock_agent.temperature = None
+    mock_agent.top_p = None
+    mock_agent.tools = [mock_function_tool]
+    mock_agents_client.get_agent = AsyncMock(return_value=mock_agent)
+
+    provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
+
+    with pytest.raises(ServiceInitializationError) as exc_info:
+        await provider.get_agent("agent-with-tools")
+
+    assert "get_weather" in str(exc_info.value)
+
+
+async def test_get_agent_with_provided_function_tools(
+    azure_ai_unit_test_env: dict[str, str],
+    mock_agents_client: MagicMock,
+) -> None:
+    """Test getting an agent with function tools when implementations are provided."""
+    mock_function_tool = MagicMock()
+    mock_function_tool.type = "function"
+    mock_function_tool.function = MagicMock()
+    mock_function_tool.function.name = "get_weather"
+
+    mock_agent = MagicMock(spec=Agent)
+    mock_agent.id = "agent-with-tools"
+    mock_agent.name = "AgentWithTools"
+    mock_agent.description = None
+    mock_agent.instructions = None
+    mock_agent.model = "gpt-4"
+    mock_agent.temperature = None
+    mock_agent.top_p = None
+    mock_agent.tools = [mock_function_tool]
+    mock_agents_client.get_agent = AsyncMock(return_value=mock_agent)
+
+    @ai_function
+    def get_weather(city: str) -> str:
+        """Get weather for a city."""
+        return f"Weather in {city}"
+
+    provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
+
+    agent = await provider.get_agent("agent-with-tools", tools=get_weather)
+
+    assert isinstance(agent, ChatAgent)
+    assert agent.id == "agent-with-tools"
+
+
+# endregion
+
+
+# region as_agent Tests
+
+
+def test_as_agent_wraps_without_http(
+    azure_ai_unit_test_env: dict[str, str],
+    mock_agents_client: MagicMock,
+) -> None:
+    """Test as_agent wraps Agent object without making HTTP calls."""
+    mock_agent = MagicMock(spec=Agent)
+    mock_agent.id = "wrap-agent-id"
+    mock_agent.name = "WrapAgent"
+    mock_agent.description = "Wrapped agent"
+    mock_agent.instructions = "Be helpful"
+    mock_agent.model = "gpt-4"
+    mock_agent.temperature = 0.5
+    mock_agent.top_p = 0.8
+    mock_agent.tools = []
+
+    provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
+
+    agent = provider.as_agent(mock_agent)
+
+    assert isinstance(agent, ChatAgent)
+    assert agent.id == "wrap-agent-id"
+    assert agent.name == "WrapAgent"
+    # Ensure no HTTP calls were made
+    mock_agents_client.get_agent.assert_not_called()
+    mock_agents_client.create_agent.assert_not_called()
+
+
+def test_as_agent_with_function_tools_validates(
+    azure_ai_unit_test_env: dict[str, str],
+    mock_agents_client: MagicMock,
+) -> None:
+    """Test as_agent validates that function tool implementations are provided."""
+    mock_function_tool = MagicMock()
+    mock_function_tool.type = "function"
+    mock_function_tool.function = MagicMock()
+    mock_function_tool.function.name = "my_function"
+
+    mock_agent = MagicMock(spec=Agent)
+    mock_agent.id = "agent-id"
+    mock_agent.name = "Agent"
+    mock_agent.description = None
+    mock_agent.instructions = None
+    mock_agent.model = "gpt-4"
+    mock_agent.temperature = None
+    mock_agent.top_p = None
+    mock_agent.tools = [mock_function_tool]
+
+    provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
+
+    with pytest.raises(ServiceInitializationError) as exc_info:
+        provider.as_agent(mock_agent)
+
+    assert "my_function" in str(exc_info.value)
+
+
+def test_as_agent_with_hosted_tools(
+    azure_ai_unit_test_env: dict[str, str],
+    mock_agents_client: MagicMock,
+) -> None:
+    """Test as_agent handles hosted tools correctly."""
+    mock_code_interpreter = MagicMock()
+    mock_code_interpreter.type = "code_interpreter"
+
+    mock_agent = MagicMock(spec=Agent)
+    mock_agent.id = "agent-id"
+    mock_agent.name = "Agent"
+    mock_agent.description = None
+    mock_agent.instructions = None
+    mock_agent.model = "gpt-4"
+    mock_agent.temperature = None
+    mock_agent.top_p = None
+    mock_agent.tools = [mock_code_interpreter]
+
+    provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
+
+    agent = provider.as_agent(mock_agent)
+
+    assert isinstance(agent, ChatAgent)
+    # Should have HostedCodeInterpreterTool in the default_options tools
+    assert any(isinstance(t, HostedCodeInterpreterTool) for t in (agent.default_options.get("tools") or []))
+
+
+# endregion
+
+
+# region Tool Conversion Tests - to_azure_ai_agent_tools
+
+
+def test_to_azure_ai_agent_tools_empty() -> None:
+    """Test converting empty tools list."""
+    result = to_azure_ai_agent_tools(None)
+    assert result == []
+
+    result = to_azure_ai_agent_tools([])
+    assert result == []
+
+
+def test_to_azure_ai_agent_tools_function() -> None:
+    """Test converting AIFunction to Azure tool definition."""
+
+    @ai_function
+    def get_weather(city: str) -> str:
+        """Get weather for a city."""
+        return f"Weather in {city}"
+
+    result = to_azure_ai_agent_tools([get_weather])
+
+    assert len(result) == 1
+    assert result[0]["type"] == "function"
+    assert result[0]["function"]["name"] == "get_weather"
+
+
+def test_to_azure_ai_agent_tools_code_interpreter() -> None:
+    """Test converting HostedCodeInterpreterTool."""
+    tool = HostedCodeInterpreterTool()
+
+    result = to_azure_ai_agent_tools([tool])
+
+    assert len(result) == 1
+    assert isinstance(result[0], CodeInterpreterToolDefinition)
+
+
+def test_to_azure_ai_agent_tools_file_search() -> None:
+    """Test converting HostedFileSearchTool with vector stores."""
+    tool = HostedFileSearchTool(inputs=[HostedVectorStoreContent(vector_store_id="vs-123")])
+    run_options: dict[str, Any] = {}
+
+    result = to_azure_ai_agent_tools([tool], run_options)
+
+    assert len(result) == 1
+    assert "tool_resources" in run_options
+
+
+def test_to_azure_ai_agent_tools_web_search_bing_grounding(monkeypatch: Any) -> None:
+    """Test converting HostedWebSearchTool for Bing Grounding."""
+    # Use a properly formatted connection ID as required by Azure SDK
+    valid_conn_id = (
+        "/subscriptions/test-sub/resourceGroups/test-rg/"
+        "providers/Microsoft.CognitiveServices/accounts/test-account/"
+        "projects/test-project/connections/test-connection"
+    )
+    monkeypatch.setenv("BING_CONNECTION_ID", valid_conn_id)
+    tool = HostedWebSearchTool()
+
+    result = to_azure_ai_agent_tools([tool])
+
+    assert len(result) > 0
+
+
+def test_to_azure_ai_agent_tools_web_search_custom(monkeypatch: Any) -> None:
+    """Test converting HostedWebSearchTool for Custom Bing Search."""
+    monkeypatch.setenv("BING_CUSTOM_CONNECTION_ID", "custom-conn-id")
+    monkeypatch.setenv("BING_CUSTOM_INSTANCE_NAME", "my-instance")
+    tool = HostedWebSearchTool()
+
+    result = to_azure_ai_agent_tools([tool])
+
+    assert len(result) > 0
+
+
+def test_to_azure_ai_agent_tools_web_search_missing_config(monkeypatch: Any) -> None:
+    """Test converting HostedWebSearchTool raises error when config is missing."""
+    monkeypatch.delenv("BING_CONNECTION_ID", raising=False)
+    monkeypatch.delenv("BING_CUSTOM_CONNECTION_ID", raising=False)
+    monkeypatch.delenv("BING_CUSTOM_INSTANCE_NAME", raising=False)
+    tool = HostedWebSearchTool()
+
+    with pytest.raises(ServiceInitializationError):
+        to_azure_ai_agent_tools([tool])
+
+
+def test_to_azure_ai_agent_tools_mcp() -> None:
+    """Test converting HostedMCPTool."""
+    tool = HostedMCPTool(
+        name="my mcp server",
+        url="https://mcp.example.com",
+        allowed_tools=["tool1", "tool2"],
+    )
+
+    result = to_azure_ai_agent_tools([tool])
+
+    assert len(result) > 0
+
+
+def test_to_azure_ai_agent_tools_dict_passthrough() -> None:
+    """Test that dict tools are passed through."""
+    tool = {"type": "custom_tool", "config": {"key": "value"}}
+
+    result = to_azure_ai_agent_tools([tool])
+
+    assert len(result) == 1
+    assert result[0] == tool
+
+
+def test_to_azure_ai_agent_tools_unsupported_type() -> None:
+    """Test that unsupported tool types raise error."""
+
+    class UnsupportedTool:
+        pass
+
+    with pytest.raises(ServiceInitializationError):
+        to_azure_ai_agent_tools([UnsupportedTool()])  # type: ignore
+
+
+# endregion
+
+
+# region Tool Conversion Tests - from_azure_ai_agent_tools
+
+
+def test_from_azure_ai_agent_tools_empty() -> None:
+    """Test converting empty tools list."""
+    result = from_azure_ai_agent_tools(None)
+    assert result == []
+
+    result = from_azure_ai_agent_tools([])
+    assert result == []
+
+
+def test_from_azure_ai_agent_tools_code_interpreter() -> None:
+    """Test converting CodeInterpreterToolDefinition."""
+    tool = CodeInterpreterToolDefinition()
+
+    result = from_azure_ai_agent_tools([tool])
+
+    assert len(result) == 1
+    assert isinstance(result[0], HostedCodeInterpreterTool)
+
+
+def test_from_azure_ai_agent_tools_code_interpreter_dict() -> None:
+    """Test converting code_interpreter dict."""
+    tool = {"type": "code_interpreter"}
+
+    result = from_azure_ai_agent_tools([tool])
+
+    assert len(result) == 1
+    assert isinstance(result[0], HostedCodeInterpreterTool)
+
+
+def test_from_azure_ai_agent_tools_file_search_dict() -> None:
+    """Test converting file_search dict with vector store IDs."""
+    tool = {
+        "type": "file_search",
+        "file_search": {"vector_store_ids": ["vs-123", "vs-456"]},
+    }
+
+    result = from_azure_ai_agent_tools([tool])
+
+    assert len(result) == 1
+    assert isinstance(result[0], HostedFileSearchTool)
+    assert len(result[0].inputs or []) == 2
+
+
+def test_from_azure_ai_agent_tools_bing_grounding_dict() -> None:
+    """Test converting bing_grounding dict."""
+    tool = {
+        "type": "bing_grounding",
+        "bing_grounding": {"connection_id": "conn-123"},
+    }
+
+    result = from_azure_ai_agent_tools([tool])
+
+    assert len(result) == 1
+    assert isinstance(result[0], HostedWebSearchTool)
+
+    additional_properties = result[0].additional_properties
+
+    assert additional_properties
+    assert additional_properties.get("connection_id") == "conn-123"
+
+
+def test_from_azure_ai_agent_tools_bing_custom_search_dict() -> None:
+    """Test converting bing_custom_search dict."""
+    tool = {
+        "type": "bing_custom_search",
+        "bing_custom_search": {
+            "connection_id": "custom-conn",
+            "instance_name": "my-instance",
+        },
+    }
+
+    result = from_azure_ai_agent_tools([tool])
+
+    assert len(result) == 1
+    assert isinstance(result[0], HostedWebSearchTool)
+    additional_properties = result[0].additional_properties
+
+    assert additional_properties
+    assert additional_properties.get("custom_connection_id") == "custom-conn"
+
+
+def test_from_azure_ai_agent_tools_mcp_dict() -> None:
+    """Test that mcp dict is skipped (hosted on Azure, no local handling needed)."""
+    tool = {
+        "type": "mcp",
+        "mcp": {
+            "server_label": "my_server",
+            "server_url": "https://mcp.example.com",
+            "allowed_tools": ["tool1"],
+        },
+    }
+
+    result = from_azure_ai_agent_tools([tool])
+
+    # MCP tools are hosted on Azure agent, skipped in conversion
+    assert len(result) == 0
+
+
+def test_from_azure_ai_agent_tools_function_dict() -> None:
+    """Test converting function tool dict (returned as-is)."""
+    tool: dict[str, Any] = {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get weather",
+            "parameters": {},
+        },
+    }
+
+    result = from_azure_ai_agent_tools([tool])
+
+    assert len(result) == 1
+    assert result[0] == tool
+
+
+def test_from_azure_ai_agent_tools_unknown_dict() -> None:
+    """Test converting unknown tool type dict."""
+    tool = {"type": "unknown_tool", "config": "value"}
+
+    result = from_azure_ai_agent_tools([tool])
+
+    assert len(result) == 1
+    assert result[0] == tool
+
+
+# endregion
+
+
+# region Integration Tests
+
+
+@skip_if_azure_ai_integration_tests_disabled
+async def test_integration_create_agent() -> None:
+    """Integration test: Create an agent using the provider."""
+    from azure.identity.aio import AzureCliCredential
+
+    async with (
+        AzureCliCredential() as credential,
+        AzureAIAgentsProvider(credential=credential) as provider,
+    ):
+        agent = await provider.create_agent(
+            name="IntegrationTestAgent",
+            instructions="You are a helpful assistant for testing.",
+        )
+
+        try:
+            assert isinstance(agent, ChatAgent)
+            assert agent.name == "IntegrationTestAgent"
+            assert agent.id is not None
+        finally:
+            # Cleanup: delete the agent
+            if agent.id:
+                await provider._agents_client.delete_agent(agent.id)  # type: ignore
+
+
+@skip_if_azure_ai_integration_tests_disabled
+async def test_integration_get_agent() -> None:
+    """Integration test: Get an existing agent using the provider."""
+    from azure.identity.aio import AzureCliCredential
+
+    async with (
+        AzureCliCredential() as credential,
+        AzureAIAgentsProvider(credential=credential) as provider,
+    ):
+        # First create an agent
+        created = await provider._agents_client.create_agent(  # type: ignore
+            model=os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME", "gpt-4o"),
+            name="GetAgentTest",
+            instructions="Test agent",
+        )
+
+        try:
+            # Then get it using the provider
+            agent = await provider.get_agent(created.id)
+
+            assert isinstance(agent, ChatAgent)
+            assert agent.id == created.id
+        finally:
+            await provider._agents_client.delete_agent(created.id)  # type: ignore
+
+
+@skip_if_azure_ai_integration_tests_disabled
+async def test_integration_create_and_run() -> None:
+    """Integration test: Create an agent and run a conversation."""
+    from azure.identity.aio import AzureCliCredential
+
+    async with (
+        AzureCliCredential() as credential,
+        AzureAIAgentsProvider(credential=credential) as provider,
+    ):
+        agent = await provider.create_agent(
+            name="RunTestAgent",
+            instructions="You are a helpful assistant. Always respond with 'Hello!' to any greeting.",
+        )
+
+        try:
+            result = await agent.run("Hi there!")
+
+            assert result is not None
+            assert len(result.messages) > 0
+        finally:
+            if agent.id:
+                await provider._agents_client.delete_agent(agent.id)  # type: ignore
+
+
+# endregion
diff --git a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py
index 134a3586b0..21bedbf710 100644
--- a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py
+++ b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py
@@ -8,10 +8,9 @@
 import pytest
 from agent_framework import (
-    AgentRunResponse,
-    AgentRunResponseUpdate,
+    AgentResponse,
+    AgentResponseUpdate,
     AgentThread,
-    AIFunction,
     ChatAgent,
     ChatClientProtocol,
     ChatMessage,
@@ -28,10 +27,8 @@
     HostedFileSearchTool,
     HostedMCPTool,
     HostedVectorStoreContent,
-    HostedWebSearchTool,
     Role,
     TextContent,
-    ToolMode,
     UriContent,
 )
 from agent_framework._serialization import SerializationMixin
@@ -39,7 +36,6 @@
 from azure.ai.agents.models import (
     AgentsNamedToolChoice,
     AgentsNamedToolChoiceType,
-    CodeInterpreterToolDefinition,
     FileInfo,
     MessageDeltaChunk,
     MessageDeltaTextContent,
@@ -197,34 +193,6 @@ def test_azure_ai_chat_client_init_missing_model_deployment_for_agent_creation()
     )
 
 
-def test_azure_ai_chat_client_from_dict(mock_agents_client: MagicMock) -> None:
-    """Test AzureAIAgentClient.from_dict method."""
-    settings = {
-        "agents_client": mock_agents_client,
-        "agent_id": "test-agent-id",
-        "thread_id": "test-thread-id",
-        "project_endpoint": "https://test-endpoint.com/",
-        "model_deployment_name": "test-model",
-        "agent_name": "TestAgent",
-    }
-
-    azure_ai_settings = AzureAISettings(
-        project_endpoint=settings["project_endpoint"],
-        model_deployment_name=settings["model_deployment_name"],
-    )
-
-    chat_client: AzureAIAgentClient = create_test_azure_ai_chat_client(
-        mock_agents_client,
-        agent_id=settings["agent_id"],  # type: ignore
-        thread_id=settings["thread_id"],  # type: ignore
-        azure_ai_settings=azure_ai_settings,
-    )
-
-    assert chat_client.agents_client is mock_agents_client
-    assert chat_client.agent_id == "test-agent-id"
-    assert chat_client.thread_id == "test-thread-id"
-
-
 def test_azure_ai_chat_client_init_missing_credential(azure_ai_unit_test_env: dict[str, str]) -> None:
     """Test AzureAIAgentClient.__init__ when credential is missing and no agents_client provided."""
     with pytest.raises(
@@ -253,7 +221,7 @@ def test_azure_ai_chat_client_init_validation_error(mock_azure_credential: Magic
     )
 
 
-def test_azure_ai_chat_client_from_settings() -> None:
+def test_azure_ai_chat_client_from_dict() -> None:
     """Test from_settings class method."""
     mock_agents_client = MagicMock()
     settings = {
@@ -265,7 +233,7 @@ def test_azure_ai_chat_client_from_settings() -> None:
         "agent_name": "TestAgent",
    }
 
-    client = AzureAIAgentClient.from_settings(settings)
+    client = AzureAIAgentClient.from_dict(settings)
 
     assert client.agents_client is mock_agents_client
     assert client.agent_id == "test-agent"
@@ -372,7 +340,7 @@ async def test_azure_ai_chat_client_prepare_options_basic(mock_agents_client: Ma
     chat_client = create_test_azure_ai_chat_client(mock_agents_client)
     messages = [ChatMessage(role=Role.USER, text="Hello")]
-    chat_options = ChatOptions(max_tokens=100, temperature=0.7)
+    chat_options: ChatOptions = {"max_tokens": 100, "temperature": 0.7}
 
     run_options, tool_results = await chat_client._prepare_options(messages, chat_options)  # type: ignore
 
@@ -386,7 +354,7 @@ async def test_azure_ai_chat_client_prepare_options_no_chat_options(mock_agents_
 
     messages = [ChatMessage(role=Role.USER, text="Hello")]
 
-    run_options, tool_results = await chat_client._prepare_options(messages, ChatOptions())  # type: ignore
+    run_options, tool_results = await chat_client._prepare_options(messages, {})  # type: ignore
 
     assert run_options is not None
     assert tool_results is None
@@ -403,7 +371,7 @@ async def test_azure_ai_chat_client_prepare_options_with_image_content(mock_agen
     image_content = UriContent(uri="https://example.com/image.jpg", media_type="image/jpeg")
     messages = [ChatMessage(role=Role.USER, contents=[image_content])]
 
-    run_options, _ = await chat_client._prepare_options(messages, ChatOptions())  # type: ignore
+    run_options, _ = await chat_client._prepare_options(messages, {})  # type: ignore
 
     assert "additional_messages" in run_options
     assert len(run_options["additional_messages"]) == 1
@@ -494,7 +462,7 @@ async def test_azure_ai_chat_client_prepare_options_with_messages(mock_agents_cl
         ChatMessage(role=Role.USER, text="Hello"),
     ]
 
-    run_options, _ = await chat_client._prepare_options(messages, ChatOptions())  # type: ignore
+    run_options, _ = await chat_client._prepare_options(messages, {})  # type: ignore
 
     assert "instructions" in run_options
     assert "You are a helpful assistant" in run_options["instructions"]
@@ -506,7 +474,7 @@
     """Test _inner_get_response method."""
     chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
     messages = [ChatMessage(role=Role.USER, text="Hello")]
-    chat_options = ChatOptions()
+    chat_options: ChatOptions = {}
 
     async def mock_streaming_response():
         yield ChatResponseUpdate(role=Role.ASSISTANT, text="Hello back")
@@ -518,7 +486,7 @@ async def mock_streaming_response():
         mock_response = ChatResponse(role=Role.ASSISTANT, text="Hello back")
         mock_from_generator.return_value = mock_response
 
-        result = await chat_client._inner_get_response(messages=messages, chat_options=chat_options)  # type: ignore
+        result = await chat_client._inner_get_response(messages=messages, options=chat_options)  # type: ignore
 
         assert result is mock_response
         mock_from_generator.assert_called_once()
@@ -627,8 +595,7 @@ async def test_azure_ai_chat_client_prepare_options_with_none_tool_choice(
     """Test _prepare_options with tool_choice set to 'none'."""
     chat_client = create_test_azure_ai_chat_client(mock_agents_client)
 
-    chat_options = ChatOptions()
-    chat_options.tool_choice = "none"
+    chat_options: ChatOptions = {"tool_choice": "none"}
 
     run_options, _ = await chat_client._prepare_options([], chat_options)  # type: ignore
 
@@ -643,8 +610,7 @@ async def test_azure_ai_chat_client_prepare_options_with_auto_tool_choice(
     """Test _prepare_options with tool_choice set to 'auto'."""
     chat_client = create_test_azure_ai_chat_client(mock_agents_client)
 
-    chat_options = ChatOptions()
-    chat_options.tool_choice = "auto"
+    chat_options = {"tool_choice": "auto"}
 
     run_options, _ = await chat_client._prepare_options([], chat_options)  # type: ignore
 
@@ -653,35 +619,17 @@
     assert run_options["tool_choice"] == AgentsToolChoiceOptionMode.AUTO
 
 
-async def test_azure_ai_chat_client_prepare_tool_choice_none_string(
-    mock_agents_client: MagicMock,
-) -> None:
-    """Test _prepare_tool_choice when tool_choice is string 'none'."""
-    chat_client = create_test_azure_ai_chat_client(mock_agents_client)
-
-    # Create a mock tool for testing
-    mock_tool = MagicMock()
-    chat_options = ChatOptions(tools=[mock_tool], tool_choice="none")
-
-    # Call the method
-    chat_client._prepare_tool_choice(chat_options)  # type: ignore
-
-    # Verify tools are cleared and tool_choice is set to NONE mode
-    assert chat_options.tools is None
-    assert chat_options.tool_choice == ToolMode.NONE.mode
-
-
 async def test_azure_ai_chat_client_prepare_options_tool_choice_required_specific_function(
     mock_agents_client: MagicMock,
 ) -> None:
-    """Test _prepare_options with ToolMode.REQUIRED specifying a specific function name."""
+    """Test _prepare_options with required tool_choice specifying a specific function name."""
     chat_client = create_test_azure_ai_chat_client(mock_agents_client)
 
-    required_tool_mode = ToolMode.REQUIRED("specific_function_name")
+    required_tool_mode = {"mode": "required", "required_function_name": "specific_function_name"}
 
     dict_tool = {"type": "function", "function": {"name": "test_function"}}
-    chat_options = ChatOptions(tools=[dict_tool], tool_choice=required_tool_mode)
+    chat_options = {"tools": [dict_tool], "tool_choice": required_tool_mode}
 
     messages = [ChatMessage(role=Role.USER, text="Hello")]
     run_options, _ = await chat_client._prepare_options(messages, chat_options)  # type: ignore
@@ -703,8 +651,7 @@ async def test_azure_ai_chat_client_prepare_options_with_response_format(
     class TestResponseModel(BaseModel):
         name: str = Field(description="Test name")
 
-    chat_options = ChatOptions()
-    chat_options.response_format = TestResponseModel
+    chat_options: ChatOptions = {"response_format": TestResponseModel}
 
     run_options, _ = await chat_client._prepare_options([], chat_options)  # type: ignore
 
@@ -722,60 +669,6 @@ def test_azure_ai_chat_client_service_url_method(mock_agents_client: MagicMock)
     assert url == "https://test-endpoint.com/"
 
 
-async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_ai_function(mock_agents_client: MagicMock) -> None:
-    """Test _prepare_tools_for_azure_ai with AIFunction tool."""
-
-    chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
-
-    # Create a mock AIFunction
-    mock_ai_function = MagicMock(spec=AIFunction)
-    mock_ai_function.to_json_schema_spec.return_value = {"type": "function", "function": {"name": "test_function"}}
-
-    result = await chat_client._prepare_tools_for_azure_ai([mock_ai_function])  # type: ignore
-
-    assert len(result) == 1
-    assert result[0] == {"type": "function", "function": {"name": "test_function"}}
-    mock_ai_function.to_json_schema_spec.assert_called_once()
-
-
-async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_code_interpreter(mock_agents_client: MagicMock) -> None:
-    """Test _prepare_tools_for_azure_ai with HostedCodeInterpreterTool."""
-
-    chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
-
-    code_interpreter_tool = HostedCodeInterpreterTool()
-
-    result = await chat_client._prepare_tools_for_azure_ai([code_interpreter_tool])  # type: ignore
-
-    assert len(result) == 1
-    assert isinstance(result[0], CodeInterpreterToolDefinition)
-
-
-async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_mcp_tool(mock_agents_client: MagicMock) -> None:
-    """Test _prepare_tools_for_azure_ai with HostedMCPTool."""
-
-    chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
-
-    mcp_tool = HostedMCPTool(name="Test MCP Tool", url="https://example.com/mcp", allowed_tools=["tool1", "tool2"])
-
-    # Mock McpTool to have a definitions attribute
-    with patch("agent_framework_azure_ai._chat_client.McpTool") as mock_mcp_tool_class:
-        mock_mcp_tool = MagicMock()
-        mock_mcp_tool.definitions = [{"type": "mcp", "name": "test_mcp"}]
-        mock_mcp_tool_class.return_value = mock_mcp_tool
-
-        result = await chat_client._prepare_tools_for_azure_ai([mcp_tool])  # type: ignore
-
-        assert len(result) == 1
-        assert result[0] == {"type": "mcp", "name": "test_mcp"}
-        # Check that the call was made (order of allowed_tools may vary)
-        mock_mcp_tool_class.assert_called_once()
-        call_args = mock_mcp_tool_class.call_args[1]
-        assert call_args["server_label"] == "Test_MCP_Tool"
-        assert call_args["server_url"] == "https://example.com/mcp"
-        assert set(call_args["allowed_tools"]) == {"tool1", "tool2"}
-
-
 async def test_azure_ai_chat_client_prepare_options_mcp_never_require(mock_agents_client: MagicMock) -> None:
     """Test _prepare_options with HostedMCPTool having never_require approval mode."""
     chat_client = create_test_azure_ai_chat_client(mock_agents_client)
@@ -783,10 +676,9 @@ async def test_azure_ai_chat_client_prepare_options_mcp_never_require(mock_agent
     mcp_tool = HostedMCPTool(name="Test MCP Tool", url="https://example.com/mcp", approval_mode="never_require")
 
     messages = [ChatMessage(role=Role.USER, text="Hello")]
-    chat_options = ChatOptions(tools=[mcp_tool], tool_choice="auto")
+    chat_options: ChatOptions = {"tools": [mcp_tool], "tool_choice": "auto"}
 
-    with patch("agent_framework_azure_ai._chat_client.McpTool") as mock_mcp_tool_class:
-        # Mock _prepare_tools_for_azure_ai to avoid actual tool preparation
+    with patch("agent_framework_azure_ai._shared.McpTool") as mock_mcp_tool_class:
         mock_mcp_tool_instance = MagicMock()
         mock_mcp_tool_instance.definitions = [{"type": "mcp", "name": "test_mcp"}]
         mock_mcp_tool_class.return_value = mock_mcp_tool_instance
@@ -816,10 +708,9 @@ async def test_azure_ai_chat_client_prepare_options_mcp_with_headers(mock_agents
     )
 
     messages = [ChatMessage(role=Role.USER, text="Hello")]
-    chat_options = ChatOptions(tools=[mcp_tool], tool_choice="auto")
+    chat_options: ChatOptions = {"tools": [mcp_tool], "tool_choice": "auto"}
 
-    with patch("agent_framework_azure_ai._chat_client.McpTool") as mock_mcp_tool_class:
-        # Mock _prepare_tools_for_azure_ai to avoid actual tool preparation
+    with patch("agent_framework_azure_ai._shared.McpTool") as mock_mcp_tool_class:
         mock_mcp_tool_instance = MagicMock()
         mock_mcp_tool_instance.definitions = [{"type": "mcp", "name": "test_mcp"}]
         mock_mcp_tool_class.return_value = mock_mcp_tool_instance
@@ -837,121 +728,6 @@ async def test_azure_ai_chat_client_prepare_options_mcp_with_headers(mock_agents
         assert mcp_resource["headers"] == headers
 
 
-async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_web_search_bing_grounding(
-    mock_agents_client: MagicMock,
-) -> None:
-    """Test _prepare_tools_for_azure_ai with HostedWebSearchTool using Bing Grounding."""
-
-    chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
-
-    web_search_tool = HostedWebSearchTool(
-        additional_properties={
-            "connection_id": "test-connection-id",
-            "count": 5,
-            "freshness": "Day",
-            "market": "en-US",
-            "set_lang": "en",
-        }
-    )
-
-    # Mock BingGroundingTool
-    with patch("agent_framework_azure_ai._chat_client.BingGroundingTool") as mock_bing_grounding:
-        mock_bing_tool = MagicMock()
-        mock_bing_tool.definitions = [{"type": "bing_grounding"}]
-        mock_bing_grounding.return_value = mock_bing_tool
-
-        result = await chat_client._prepare_tools_for_azure_ai([web_search_tool])  # type: ignore
-
-        assert len(result) == 1
-        assert result[0] == {"type": "bing_grounding"}
-        call_args = mock_bing_grounding.call_args[1]
-        assert call_args["count"] == 5
-        assert call_args["freshness"] == "Day"
-        assert call_args["market"] == "en-US"
-        assert call_args["set_lang"] == "en"
-        assert "connection_id" in call_args
-
-
-async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_web_search_bing_grounding_with_connection_id(
-    mock_agents_client: MagicMock,
-) -> None:
-    """Test _prepare_tools_... with HostedWebSearchTool using Bing Grounding with connection_id (no HTTP call)."""
-
-    chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
-
-    web_search_tool = HostedWebSearchTool(
-        additional_properties={
-            "connection_id": "direct-connection-id",
-            "count": 3,
-        }
-    )
-
-    # Mock BingGroundingTool
-    with patch("agent_framework_azure_ai._chat_client.BingGroundingTool") as mock_bing_grounding:
-        mock_bing_tool = MagicMock()
-        mock_bing_tool.definitions = [{"type": "bing_grounding"}]
-        mock_bing_grounding.return_value = mock_bing_tool
-
-        result = await chat_client._prepare_tools_for_azure_ai([web_search_tool])  # type: ignore
-
-        assert len(result) == 1
-        assert result[0] == {"type": "bing_grounding"}
-        mock_bing_grounding.assert_called_once_with(connection_id="direct-connection-id", count=3)
-
-
-async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_web_search_custom_bing(
-    mock_agents_client: MagicMock,
-) -> None:
-    """Test _prepare_tools_for_azure_ai with HostedWebSearchTool using Custom Bing Search."""
-
-    chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
-
-    web_search_tool = HostedWebSearchTool(
-        additional_properties={
-            "custom_connection_id": "custom-connection-id",
-            "custom_instance_name": "custom-instance",
-            "count": 10,
-        }
-    )
-
-    # Mock BingCustomSearchTool
-    with patch("agent_framework_azure_ai._chat_client.BingCustomSearchTool") as mock_custom_bing:
-        mock_custom_tool = MagicMock()
-        mock_custom_tool.definitions = [{"type": "bing_custom_search"}]
-        mock_custom_bing.return_value = mock_custom_tool
-
-        result = await chat_client._prepare_tools_for_azure_ai([web_search_tool])  # type: ignore
-
-        assert len(result) == 1
-        assert result[0] == {"type": "bing_custom_search"}
-
-
-async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_file_search_with_vector_stores(
-    mock_agents_client: MagicMock,
-) -> None:
-    """Test _prepare_tools_for_azure_ai with HostedFileSearchTool using vector stores."""
-
-    chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
-
-    vector_store_input = HostedVectorStoreContent(vector_store_id="vs-123")
-    file_search_tool = HostedFileSearchTool(inputs=[vector_store_input])
-
-    # Mock FileSearchTool
-    with patch("agent_framework_azure_ai._chat_client.FileSearchTool") as mock_file_search:
-        mock_file_tool = MagicMock()
-        mock_file_tool.definitions = [{"type": "file_search"}]
-        mock_file_tool.resources = {"vector_store_ids": ["vs-123"]}
-        mock_file_search.return_value = mock_file_tool
-
-        run_options = {}
-        result = await chat_client._prepare_tools_for_azure_ai([file_search_tool], run_options)  # type: ignore
-
-        assert len(result) == 1
-        assert result[0] == {"type": "file_search"}
-        assert run_options["tool_resources"] == {"vector_store_ids": ["vs-123"]}
-        mock_file_search.assert_called_once_with(vector_store_ids=["vs-123"])
-
-
 async def test_azure_ai_chat_client_create_agent_stream_submit_tool_approvals(
     mock_agents_client: MagicMock,
 ) -> None:
@@ -993,28 +769,6 @@ async def test_azure_ai_chat_client_create_agent_stream_submit_tool_approvals(
         assert call_args["tool_approvals"][0].approve is True
 
 
-async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_dict_tool(mock_agents_client: MagicMock) -> None:
-    """Test _prepare_tools_for_azure_ai with dictionary tool definition."""
-    chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
-
-    dict_tool = {"type": "custom_tool", "config": {"param": "value"}}
-
-    result = await chat_client._prepare_tools_for_azure_ai([dict_tool])  # type: ignore
-
-    assert len(result) == 1
-    assert result[0] == dict_tool
-
-
-async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_unsupported_tool(mock_agents_client: MagicMock) -> None:
-    """Test _prepare_tools_for_azure_ai with unsupported tool type."""
-    chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
-
-    unsupported_tool = "not_a_tool"
-
-    with pytest.raises(ServiceInitializationError, match="Unsupported tool type: "):
-        await chat_client._prepare_tools_for_azure_ai([unsupported_tool])  # type: ignore
-
-
 async def test_azure_ai_chat_client_get_active_thread_run_with_active_run(mock_agents_client: MagicMock) -> None:
     """Test _get_active_thread_run when there's an active run."""
@@ -1518,8 +1272,7 @@ async def test_azure_ai_chat_client_get_response_tools() -> None:
     # Test that the agents_client can be used to get a response
     response = await azure_ai_chat_client.get_response(
         messages=messages,
-        tools=[get_weather],
-        tool_choice="auto",
+        options={"tools": [get_weather], "tool_choice": "auto"},
     )
     assert response is not None
@@ -1571,8 +1324,7 @@ async def test_azure_ai_chat_client_streaming_tools() -> None:
     # Test that the agents_client can be used to get a response
     response = azure_ai_chat_client.get_streaming_response(
         messages=messages,
-        tools=[get_weather],
-        tool_choice="auto",
+        options={"tools": [get_weather], "tool_choice": "auto"},
     )
     full_message: str = ""
     async for chunk in response:
@@ -1596,7 +1348,7 @@ async def test_azure_ai_chat_client_agent_basic_run() -> None:
     response = await agent.run("Hello! Please respond with 'Hello World' exactly.")
 
     # Validate response
-    assert isinstance(response, AgentRunResponse)
+    assert isinstance(response, AgentResponse)
     assert response.text is not None
     assert len(response.text) > 0
     assert "Hello World" in response.text
@@ -1613,7 +1365,7 @@ async def test_azure_ai_chat_client_agent_basic_run_streaming() -> None:
     full_message: str = ""
     async for chunk in agent.run_stream("Please respond with exactly: 'This is a streaming response test.'"):
         assert chunk is not None
-        assert isinstance(chunk, AgentRunResponseUpdate)
+        assert isinstance(chunk, AgentResponseUpdate)
         if chunk.text:
             full_message += chunk.text
 
@@ -1637,14 +1389,14 @@ async def test_azure_ai_chat_client_agent_thread_persistence() -> None:
     first_response = await agent.run(
         "Remember this number: 42. What number did I just tell you to remember?", thread=thread
     )
-    assert isinstance(first_response, AgentRunResponse)
+    assert isinstance(first_response, AgentResponse)
     assert "42" in first_response.text
 
     # Second message - test conversation memory
     second_response = await agent.run(
         "What number did I tell you to remember in my previous message?", thread=thread
     )
-    assert isinstance(second_response, AgentRunResponse)
+    assert isinstance(second_response, AgentResponse)
     assert "42" in second_response.text
 
@@ -1661,7 +1413,7 @@ async def test_azure_ai_chat_client_agent_existing_thread_id() -> None:
     first_response = await first_agent.run("My name is Alice. Remember this.", thread=thread)
 
     # Validate first response
-    assert isinstance(first_response, AgentRunResponse)
+    assert isinstance(first_response, AgentResponse)
     assert first_response.text is not None
 
     # The thread ID is set after the first response
@@ -1680,7 +1432,7 @@
     response2 = await second_agent.run("What is my name?", thread=thread)
 
     # Validate that the agent remembers the previous conversation
-    assert isinstance(response2, AgentRunResponse)
+    assert isinstance(response2, AgentResponse)
     assert response2.text is not None
     # Should reference Alice from the previous conversation
     assert "alice" in response2.text.lower()
@@ -1700,7 +1452,7 @@ async def test_azure_ai_chat_client_agent_code_interpreter():
     response = await agent.run("Write Python code to calculate the factorial of 5 and show the result.")
 
     # Validate response
-    assert isinstance(response, AgentRunResponse)
+    assert isinstance(response, AgentResponse)
    assert response.text is not None
     # Factorial of 5 is 120
     assert "120" in response.text or "factorial" in response.text.lower()
@@ -1735,7 +1487,7 @@ async def test_azure_ai_chat_client_agent_file_search():
     response = await agent.run("Who is the youngest employee in the files?")
 
     # Validate response
-    assert isinstance(response, AgentRunResponse)
+    assert isinstance(response, AgentResponse)
     assert response.text is not None
     # Should find information about Alice Johnson (age 24) being the youngest
     assert any(term in response.text.lower() for term in ["alice", "johnson", "24"])
@@ -1772,10 +1524,10 @@ async def test_azure_ai_chat_client_agent_hosted_mcp_tool() -> None:
     ) as agent:
         response = await agent.run(
             "How to create an Azure storage account using az cli?",
-            max_tokens=200,
+            options={"max_tokens": 200},
         )
 
-        assert isinstance(response, AgentRunResponse)
+        assert isinstance(response, AgentResponse)
         assert response.text is not None
         assert len(response.text) > 0
 
@@ -1800,7 +1552,7 @@ async def test_azure_ai_chat_client_agent_level_tool_persistence():
     # First run - agent-level tool should be available
     first_response = await agent.run("What's the weather like in Chicago?")
 
-    assert isinstance(first_response, AgentRunResponse)
+    assert isinstance(first_response, AgentResponse)
     assert first_response.text is not None
     # Should use the agent-level weather tool
     assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "25"])
@@ -1808,7 +1560,7 @@
     # Second run - agent-level tool should still be available (persistence test)
     second_response = await agent.run("What's the weather in Miami?")
 
-    assert isinstance(second_response, AgentRunResponse)
+    assert isinstance(second_response, AgentResponse)
     assert second_response.text is not None
     # Should use the agent-level weather tool again
     assert any(term in second_response.text.lower() for term in ["miami", "sunny", "25"])
 
@@ -1823,23 +1575,17 @@ async def test_azure_ai_chat_client_agent_chat_options_run_level() -> None:
     ) as agent:
         response = await agent.run(
             "Provide a brief, helpful response.",
-            max_tokens=100,
-            temperature=0.7,
-            top_p=0.9,
-            seed=123,
-            user="comprehensive-test-user",
             tools=[get_weather],
-            tool_choice="auto",
-            frequency_penalty=0.1,
-            presence_penalty=0.1,
-            stop=["END"],
-            store=True,
-            logit_bias={"test": 1},
-            metadata={"test": "value"},
-            additional_properties={"custom_param": "test_value"},
+            options={
+                "max_tokens": 100,
+                "temperature": 0.7,
+                "top_p": 0.9,
+                "tool_choice": "auto",
+                "metadata": {"test": "value"},
+            },
         )
 
-    assert isinstance(response, AgentRunResponse)
+    assert isinstance(response, AgentResponse)
     assert response.text is not None
     assert len(response.text) > 0
 
@@ -1850,26 +1596,20 @@ async def test_azure_ai_chat_client_agent_chat_options_agent_level() -> None:
     async with ChatAgent(
         chat_client=AzureAIAgentClient(credential=AzureCliCredential()),
         instructions="You are a helpful assistant.",
-        max_tokens=100,
-        temperature=0.7,
-        top_p=0.9,
-        seed=123,
-        user="comprehensive-test-user",
         tools=[get_weather],
-        tool_choice="auto",
-        frequency_penalty=0.1,
-        presence_penalty=0.1,
-        stop=["END"],
-        store=True,
-        logit_bias={"test": 1},
-        metadata={"test": "value"},
-        request_kwargs={"custom_param": "test_value"},
+        default_options={
+            "max_tokens": 100,
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "tool_choice": "auto",
+            "metadata": {"test": "value"},
+        },
     ) as agent:
         response = await agent.run(
             "Provide a brief, helpful response.",
         )
 
-    assert isinstance(response, AgentRunResponse)
+    assert isinstance(response, AgentResponse)
     assert response.text is not None
     assert len(response.text) > 0
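The dominant change in test_azure_ai_agent_client.py is mechanical: per-call keyword arguments and `ChatOptions(...)` objects become plain `options` / `default_options` dicts, and `AgentRunResponse`/`AgentRunResponseUpdate` become `AgentResponse`/`AgentResponseUpdate`. A condensed sketch of the new shape, using only names that appear in this diff (an illustration, not canonical API docs):

```python
# Sketch of the dict-based options pattern this diff migrates the tests to.
from agent_framework import AgentResponse, ChatAgent
from agent_framework_azure_ai import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential


def get_weather(location: str) -> str:
    """Toy tool, as used throughout these tests."""
    return f"The weather in {location} is sunny with a high of 25°C."


async def demo() -> None:
    async with ChatAgent(
        chat_client=AzureAIAgentClient(credential=AzureCliCredential()),
        instructions="You are a helpful assistant.",
        tools=[get_weather],
        # Agent-level defaults: one dict instead of many keyword arguments.
        default_options={"max_tokens": 100, "temperature": 0.7, "tool_choice": "auto"},
    ) as agent:
        # Run-level overrides travel the same way.
        response = await agent.run("Hi!", options={"max_tokens": 50})
        assert isinstance(response, AgentResponse)
```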
diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py
index 3b1b500ede..dad8f049fe 100644
--- a/python/packages/azure-ai/tests/test_azure_ai_client.py
+++ b/python/packages/azure-ai/tests/test_azure_ai_client.py
@@ -1,33 +1,49 @@
 # Copyright (c) Microsoft. All rights reserved.
 
+import json
 import os
-from collections.abc import AsyncIterator
+from collections.abc import AsyncGenerator, AsyncIterator
 from contextlib import asynccontextmanager
-from typing import Annotated
+from typing import Annotated, Any
 from unittest.mock import AsyncMock, MagicMock, patch
+from uuid import uuid4
 
 import pytest
 from agent_framework import (
-    AgentRunResponse,
-    AgentRunResponseUpdate,
+    AgentResponse,
     ChatAgent,
     ChatClientProtocol,
     ChatMessage,
     ChatOptions,
+    ChatResponse,
+    HostedCodeInterpreterTool,
+    HostedFileContent,
+    HostedFileSearchTool,
+    HostedMCPTool,
+    HostedVectorStoreContent,
+    HostedWebSearchTool,
     Role,
     TextContent,
 )
 from agent_framework.exceptions import ServiceInitializationError
 from azure.ai.projects.aio import AIProjectClient
 from azure.ai.projects.models import (
+    ApproximateLocation,
+    CodeInterpreterTool,
+    CodeInterpreterToolAuto,
+    FileSearchTool,
+    MCPTool,
     ResponseTextFormatConfigurationJsonSchema,
+    WebSearchPreviewTool,
 )
 from azure.identity.aio import AzureCliCredential
 from openai.types.responses.parsed_response import ParsedResponse
 from openai.types.responses.response import Response as OpenAIResponse
 from pydantic import BaseModel, ConfigDict, Field, ValidationError
+from pytest import fixture, param
 
 from agent_framework_azure_ai import AzureAIClient, AzureAISettings
+from agent_framework_azure_ai._shared import from_azure_ai_tools
 
 skip_if_azure_ai_integration_tests_disabled = pytest.mark.skipif(
     os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true"
@@ -41,6 +57,32 @@
 )
 
 
+@pytest.fixture
+def mock_project_client() -> MagicMock:
+    """Fixture that provides a mock AIProjectClient."""
+    mock_client = MagicMock()
+
+    # Mock agents property
+    mock_client.agents = MagicMock()
+    mock_client.agents.create_version = AsyncMock()
+
+    # Mock conversations property
+    mock_client.conversations = MagicMock()
+    mock_client.conversations.create = AsyncMock()
+
+    # Mock telemetry property
+    mock_client.telemetry = MagicMock()
+    mock_client.telemetry.get_application_insights_connection_string = AsyncMock()
+
+    # Mock get_openai_client method
+    mock_client.get_openai_client = AsyncMock()
+
+    # Mock close method
+    mock_client.close = AsyncMock()
+
+    return mock_client
+
+
 @asynccontextmanager
 async def temporary_chat_client(agent_name: str) -> AsyncIterator[AzureAIClient]:
     """Async context manager that creates an Azure AI agent and yields an `AzureAIClient`.
@@ -121,7 +163,7 @@ def test_azure_ai_settings_init_with_explicit_values() -> None:
     assert settings.model_deployment_name == "custom-model"
 
 
-def test_azure_ai_client_init_with_project_client(mock_project_client: MagicMock) -> None:
+def test_init_with_project_client(mock_project_client: MagicMock) -> None:
     """Test AzureAIClient initialization with existing project_client."""
     with patch("agent_framework_azure_ai._client.AzureAISettings") as mock_settings:
         mock_settings.return_value.project_endpoint = None
@@ -140,7 +182,7 @@ def test_azure_ai_client_init_with_project_client(mock_project_client: MagicMock
     assert isinstance(client, ChatClientProtocol)
 
 
-def test_azure_ai_client_init_auto_create_client(
+def test_init_auto_create_client(
     azure_ai_unit_test_env: dict[str, str],
     mock_azure_credential: MagicMock,
 ) -> None:
@@ -164,7 +206,7 @@ def test_azure_ai_client_init_auto_create_client(
     mock_ai_project_client.assert_called_once()
 
 
-def test_azure_ai_client_init_missing_project_endpoint() -> None:
+def test_init_missing_project_endpoint() -> None:
     """Test AzureAIClient initialization when project_endpoint is missing and no project_client provided."""
     with patch("agent_framework_azure_ai._client.AzureAISettings") as mock_settings:
         mock_settings.return_value.project_endpoint = None
@@ -174,7 +216,7 @@
         AzureAIClient(credential=MagicMock())
 
 
-def test_azure_ai_client_init_missing_credential(azure_ai_unit_test_env: dict[str, str]) -> None:
+def test_init_missing_credential(azure_ai_unit_test_env: dict[str, str]) -> None:
     """Test AzureAIClient.__init__ when credential is missing and no project_client provided."""
     with pytest.raises(
         ServiceInitializationError, match="Azure credential is required when project_client is not provided"
@@ -185,7 +227,7 @@
     )
 
 
-def test_azure_ai_client_init_validation_error(mock_azure_credential: MagicMock) -> None:
+def test_init_validation_error(mock_azure_credential: MagicMock) -> None:
     """Test that ValidationError in AzureAISettings is properly handled."""
     with patch("agent_framework_azure_ai._client.AzureAISettings") as mock_settings:
         mock_settings.side_effect = ValidationError.from_exception_data("test", [])
@@ -194,7 +236,7 @@
         AzureAIClient(credential=mock_azure_credential)
 
 
-async def test_azure_ai_client_get_agent_reference_or_create_existing_version(
+async def test_get_agent_reference_or_create_existing_version(
    mock_project_client: MagicMock,
 ) -> None:
     """Test _get_agent_reference_or_create when agent_version is already provided."""
@@ -205,7 +247,7 @@
     assert agent_ref == {"name": "existing-agent", "version": "1.0", "type": "agent_reference"}
 
 
-async def test_azure_ai_client_get_agent_reference_or_create_missing_agent_name(
+async def test_get_agent_reference_or_create_missing_agent_name(
     mock_project_client: MagicMock,
 ) -> None:
     """Test _get_agent_reference_or_create raises when agent_name is missing."""
@@ -215,7 +257,7 @@
         await client._get_agent_reference_or_create({}, None)  # type: ignore
 
 
-async def test_azure_ai_client_get_agent_reference_or_create_new_agent(
+async def test_get_agent_reference_or_create_new_agent(
     mock_project_client: MagicMock,
     azure_ai_unit_test_env: dict[str, str],
 ) -> None:
@@ -239,7 +281,7 @@
     assert client.agent_version == "1.0"
 
 
-async def test_azure_ai_client_get_agent_reference_missing_model(
+async def test_get_agent_reference_missing_model(
     mock_project_client: MagicMock,
 ) -> None:
     """Test _get_agent_reference_or_create when model is missing for agent creation."""
@@ -249,7 +291,7 @@
         await client._get_agent_reference_or_create({}, None)  # type: ignore
 
 
-async def test_azure_ai_client_prepare_messages_for_azure_ai_with_system_messages(
+async def test_prepare_messages_for_azure_ai_with_system_messages(
     mock_project_client: MagicMock,
 ) -> None:
     """Test _prepare_messages_for_azure_ai converts system/developer messages to instructions."""
@@ -269,7 +311,7 @@
     assert instructions == "You are a helpful assistant."
 
 
-async def test_azure_ai_client_prepare_messages_for_azure_ai_no_system_messages(
+async def test_prepare_messages_for_azure_ai_no_system_messages(
     mock_project_client: MagicMock,
 ) -> None:
     """Test _prepare_messages_for_azure_ai with no system/developer messages."""
@@ -286,7 +328,7 @@
     assert instructions is None
 
 
-def test_azure_ai_client_transform_input_for_azure_ai(mock_project_client: MagicMock) -> None:
+def test_transform_input_for_azure_ai(mock_project_client: MagicMock) -> None:
     """Test _transform_input_for_azure_ai adds required fields for Azure AI schema.
 
     WORKAROUND TEST: Azure AI Projects API requires 'type' at item level and
@@ -331,7 +373,7 @@
     assert result[1]["content"][0]["text"] == "Hi there!"
-def test_azure_ai_client_transform_input_preserves_existing_fields(mock_project_client: MagicMock) -> None: +def test_transform_input_preserves_existing_fields(mock_project_client: MagicMock) -> None: """Test _transform_input_for_azure_ai preserves existing type and annotations.""" client = create_test_azure_ai_client(mock_project_client) @@ -353,7 +395,7 @@ def test_azure_ai_client_transform_input_preserves_existing_fields(mock_project_ assert result[0]["content"][0]["annotations"] == [{"some": "annotation"}] -def test_azure_ai_client_transform_input_handles_non_dict_content(mock_project_client: MagicMock) -> None: +def test_transform_input_handles_non_dict_content(mock_project_client: MagicMock) -> None: """Test _transform_input_for_azure_ai handles non-dict content items.""" client = create_test_azure_ai_client(mock_project_client) @@ -373,12 +415,11 @@ def test_azure_ai_client_transform_input_handles_non_dict_content(mock_project_c assert result[0]["content"] == ["plain string content"] -async def test_azure_ai_client_prepare_options_basic(mock_project_client: MagicMock) -> None: +async def test_prepare_options_basic(mock_project_client: MagicMock) -> None: """Test prepare_options basic functionality.""" client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0") messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])] - chat_options = ChatOptions() with ( patch.object(client.__class__.__bases__[0], "_prepare_options", return_value={"model": "test-model"}), @@ -388,7 +429,7 @@ async def test_azure_ai_client_prepare_options_basic(mock_project_client: MagicM return_value={"name": "test-agent", "version": "1.0", "type": "agent_reference"}, ), ): - run_options = await client._prepare_options(messages, chat_options) + run_options = await client._prepare_options(messages, {}) assert "extra_body" in run_options assert run_options["extra_body"]["agent"]["name"] == "test-agent" @@ -401,7 +442,7 @@ async def test_azure_ai_client_prepare_options_basic(mock_project_client: MagicM ("https://example.com/api/projects/my-project", True), ], ) -async def test_azure_ai_client_prepare_options_with_application_endpoint( +async def test_prepare_options_with_application_endpoint( mock_azure_credential: MagicMock, endpoint: str, expects_agent: bool ) -> None: client = AzureAIClient( @@ -413,7 +454,6 @@ async def test_azure_ai_client_prepare_options_with_application_endpoint( ) messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])] - chat_options = ChatOptions() with ( patch.object(client.__class__.__bases__[0], "_prepare_options", return_value={"model": "test-model"}), @@ -423,7 +463,7 @@ async def test_azure_ai_client_prepare_options_with_application_endpoint( return_value={"name": "test-agent", "version": "1", "type": "agent_reference"}, ), ): - run_options = await client._prepare_options(messages, chat_options) + run_options = await client._prepare_options(messages, {}) if expects_agent: assert "extra_body" in run_options @@ -439,7 +479,7 @@ async def test_azure_ai_client_prepare_options_with_application_endpoint( ("https://example.com/api/projects/my-project", True), ], ) -async def test_azure_ai_client_prepare_options_with_application_project_client( +async def test_prepare_options_with_application_project_client( mock_project_client: MagicMock, endpoint: str, expects_agent: bool ) -> None: mock_project_client._config = MagicMock() @@ -453,7 +493,6 @@ async def 
test_azure_ai_client_prepare_options_with_application_project_client( ) messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])] - chat_options = ChatOptions() with ( patch.object(client.__class__.__bases__[0], "_prepare_options", return_value={"model": "test-model"}), @@ -463,7 +502,7 @@ async def test_azure_ai_client_prepare_options_with_application_project_client( return_value={"name": "test-agent", "version": "1", "type": "agent_reference"}, ), ): - run_options = await client._prepare_options(messages, chat_options) + run_options = await client._prepare_options(messages, {}) if expects_agent: assert "extra_body" in run_options @@ -472,7 +511,7 @@ async def test_azure_ai_client_prepare_options_with_application_project_client( assert "extra_body" not in run_options -async def test_azure_ai_client_initialize_client(mock_project_client: MagicMock) -> None: +async def test_initialize_client(mock_project_client: MagicMock) -> None: """Test _initialize_client method.""" client = create_test_azure_ai_client(mock_project_client) @@ -485,7 +524,7 @@ async def test_azure_ai_client_initialize_client(mock_project_client: MagicMock) mock_project_client.get_openai_client.assert_called_once() -def test_azure_ai_client_update_agent_name_and_description(mock_project_client: MagicMock) -> None: +def test_update_agent_name_and_description(mock_project_client: MagicMock) -> None: """Test _update_agent_name_and_description method.""" client = create_test_azure_ai_client(mock_project_client) @@ -506,7 +545,7 @@ def test_azure_ai_client_update_agent_name_and_description(mock_project_client: mock_update.assert_called_once_with(None) -async def test_azure_ai_client_async_context_manager(mock_project_client: MagicMock) -> None: +async def test_async_context_manager(mock_project_client: MagicMock) -> None: """Test async context manager functionality.""" client = create_test_azure_ai_client(mock_project_client, should_close_client=True) @@ -519,7 +558,7 @@ async def test_azure_ai_client_async_context_manager(mock_project_client: MagicM mock_project_client.close.assert_called_once() -async def test_azure_ai_client_close_method(mock_project_client: MagicMock) -> None: +async def test_close_method(mock_project_client: MagicMock) -> None: """Test close method.""" client = create_test_azure_ai_client(mock_project_client, should_close_client=True) @@ -530,7 +569,7 @@ async def test_azure_ai_client_close_method(mock_project_client: MagicMock) -> N mock_project_client.close.assert_called_once() -async def test_azure_ai_client_close_client_when_should_close_false(mock_project_client: MagicMock) -> None: +async def test_close_client_when_should_close_false(mock_project_client: MagicMock) -> None: """Test _close_client_if_needed when should_close_client is False.""" client = create_test_azure_ai_client(mock_project_client, should_close_client=False) @@ -542,7 +581,7 @@ async def test_azure_ai_client_close_client_when_should_close_false(mock_project mock_project_client.close.assert_not_called() -async def test_azure_ai_client_agent_creation_with_instructions( +async def test_agent_creation_with_instructions( mock_project_client: MagicMock, ) -> None: """Test agent creation with combined instructions.""" @@ -564,7 +603,7 @@ async def test_azure_ai_client_agent_creation_with_instructions( assert call_args[1]["definition"].instructions == "Message instructions. Option instructions. 
" -async def test_azure_ai_client_agent_creation_with_additional_args( +async def test_agent_creation_with_additional_args( mock_project_client: MagicMock, ) -> None: """Test agent creation with additional arguments.""" @@ -588,7 +627,7 @@ async def test_azure_ai_client_agent_creation_with_additional_args( assert definition.top_p == 0.8 -async def test_azure_ai_client_agent_creation_with_tools( +async def test_agent_creation_with_tools( mock_project_client: MagicMock, ) -> None: """Test agent creation with tools.""" @@ -610,7 +649,7 @@ async def test_azure_ai_client_agent_creation_with_tools( assert call_args[1]["definition"].tools == test_tools -async def test_azure_ai_client_use_latest_version_existing_agent( +async def test_use_latest_version_existing_agent( mock_project_client: MagicMock, ) -> None: """Test _get_agent_reference_or_create when use_latest_version=True and agent exists.""" @@ -634,7 +673,7 @@ async def test_azure_ai_client_use_latest_version_existing_agent( assert client.agent_version == "2.5" -async def test_azure_ai_client_use_latest_version_agent_not_found( +async def test_use_latest_version_agent_not_found( mock_project_client: MagicMock, ) -> None: """Test _get_agent_reference_or_create when use_latest_version=True but agent doesn't exist.""" @@ -663,7 +702,7 @@ async def test_azure_ai_client_use_latest_version_agent_not_found( assert client.agent_version == "1.0" -async def test_azure_ai_client_use_latest_version_false( +async def test_use_latest_version_false( mock_project_client: MagicMock, ) -> None: """Test _get_agent_reference_or_create when use_latest_version=False (default behavior).""" @@ -685,7 +724,7 @@ async def test_azure_ai_client_use_latest_version_false( assert agent_ref == {"name": "test-agent", "version": "1.0", "type": "agent_reference"} -async def test_azure_ai_client_use_latest_version_with_existing_agent_version( +async def test_use_latest_version_with_existing_agent_version( mock_project_client: MagicMock, ) -> None: """Test that use_latest_version is ignored when agent_version is already provided.""" @@ -711,7 +750,7 @@ class ResponseFormatModel(BaseModel): model_config = ConfigDict(extra="forbid") -async def test_azure_ai_client_agent_creation_with_response_format( +async def test_agent_creation_with_response_format( mock_project_client: MagicMock, ) -> None: """Test agent creation with response_format configuration.""" @@ -724,7 +763,7 @@ async def test_azure_ai_client_agent_creation_with_response_format( mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent) run_options = {"model": "test-model"} - chat_options = ChatOptions(response_format=ResponseFormatModel) + chat_options = {"response_format": ResponseFormatModel} await client._get_agent_reference_or_create(run_options, None, chat_options) # type: ignore @@ -751,9 +790,10 @@ async def test_azure_ai_client_agent_creation_with_response_format( assert "name" in schema["properties"] assert "value" in schema["properties"] assert "description" in schema["properties"] + assert "additionalProperties" in schema -async def test_azure_ai_client_agent_creation_with_mapping_response_format( +async def test_agent_creation_with_mapping_response_format( mock_project_client: MagicMock, ) -> None: """Test agent creation when response_format is provided as a mapping.""" @@ -786,9 +826,9 @@ async def test_azure_ai_client_agent_creation_with_mapping_response_format( "schema": runtime_schema, }, } - chat_options = ChatOptions(response_format=response_format_mapping) # type: ignore + 
chat_options = {"response_format": response_format_mapping} - await client._get_agent_reference_or_create(run_options, None, chat_options) # type: ignore + await client._get_agent_reference_or_create(run_options, None, chat_options) call_args = mock_project_client.agents.create_version.call_args created_definition = call_args[1]["definition"] @@ -802,14 +842,14 @@ async def test_azure_ai_client_agent_creation_with_mapping_response_format( assert format_config.strict is True -async def test_azure_ai_client_prepare_options_excludes_response_format( +async def test_prepare_options_excludes_response_format( mock_project_client: MagicMock, ) -> None: """Test that prepare_options excludes response_format, text, and text_format from final run options.""" client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0") messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])] - chat_options = ChatOptions() + chat_options: ChatOptions = {} with ( patch.object( @@ -932,30 +972,59 @@ def test_get_conversation_id_with_parsed_response_no_conversation() -> None: assert result == "resp_parsed_12345" -@pytest.fixture -def mock_project_client() -> MagicMock: - """Fixture that provides a mock AIProjectClient.""" - mock_client = MagicMock() - - # Mock agents property - mock_client.agents = MagicMock() - mock_client.agents.create_version = AsyncMock() - - # Mock conversations property - mock_client.conversations = MagicMock() - mock_client.conversations.create = AsyncMock() +def test_from_azure_ai_tools() -> None: + """Test from_azure_ai_tools.""" + # Test MCP tool + mcp_tool = MCPTool(server_label="test_server", server_url="http://localhost:8080") + parsed_tools = from_azure_ai_tools([mcp_tool]) + assert len(parsed_tools) == 1 + assert isinstance(parsed_tools[0], HostedMCPTool) + assert parsed_tools[0].name == "test server" + assert str(parsed_tools[0].url).rstrip("/") == "http://localhost:8080" + + # Test Code Interpreter tool + ci_tool = CodeInterpreterTool(container=CodeInterpreterToolAuto(file_ids=["file-1"])) + parsed_tools = from_azure_ai_tools([ci_tool]) + assert len(parsed_tools) == 1 + assert isinstance(parsed_tools[0], HostedCodeInterpreterTool) + assert parsed_tools[0].inputs is not None + assert len(parsed_tools[0].inputs) == 1 + + tool_input = parsed_tools[0].inputs[0] + + assert tool_input and isinstance(tool_input, HostedFileContent) and tool_input.file_id == "file-1" + + # Test File Search tool + fs_tool = FileSearchTool(vector_store_ids=["vs-1"], max_num_results=5) + parsed_tools = from_azure_ai_tools([fs_tool]) + assert len(parsed_tools) == 1 + assert isinstance(parsed_tools[0], HostedFileSearchTool) + assert parsed_tools[0].inputs is not None + assert len(parsed_tools[0].inputs) == 1 + + tool_input = parsed_tools[0].inputs[0] + + assert tool_input and isinstance(tool_input, HostedVectorStoreContent) and tool_input.vector_store_id == "vs-1" + assert parsed_tools[0].max_results == 5 + + # Test Web Search tool + ws_tool = WebSearchPreviewTool( + user_location=ApproximateLocation(city="Seattle", country="US", region="WA", timezone="PST") + ) + parsed_tools = from_azure_ai_tools([ws_tool]) + assert len(parsed_tools) == 1 + assert isinstance(parsed_tools[0], HostedWebSearchTool) + assert parsed_tools[0].additional_properties - # Mock telemetry property - mock_client.telemetry = MagicMock() - mock_client.telemetry.get_application_insights_connection_string = AsyncMock() + user_location = 
+    user_location = parsed_tools[0].additional_properties["user_location"]

-    # Mock get_openai_client method
-    mock_client.get_openai_client = AsyncMock()
+    assert user_location["city"] == "Seattle"
+    assert user_location["country"] == "US"
+    assert user_location["region"] == "WA"
+    assert user_location["timezone"] == "PST"

-    # Mock close method
-    mock_client.close = AsyncMock()
-
-    return mock_client
+# region Integration Tests


 def get_weather(
@@ -965,143 +1034,356 @@ def get_weather(
     return f"The weather in {location} is sunny with a high of 25°C."


-@pytest.mark.flaky
-@skip_if_azure_ai_integration_tests_disabled
-async def test_azure_ai_chat_client_agent_basic_run() -> None:
-    """Test ChatAgent basic run functionality with AzureAIClient."""
+class OutputStruct(BaseModel):
+    """A structured output for testing purposes."""
+
+    location: str
+    weather: str
+
+
+@fixture
+async def client() -> AsyncGenerator[AzureAIClient, None]:
+    """Create a client to test with."""
+    agent_name = f"test-agent-{uuid4()}"
+    endpoint = os.environ["AZURE_AI_PROJECT_ENDPOINT"]
     async with (
-        temporary_chat_client(agent_name="BasicRunAgent") as chat_client,
-        ChatAgent(chat_client=chat_client) as agent,
+        AzureCliCredential() as credential,
+        AIProjectClient(endpoint=endpoint, credential=credential) as project_client,
     ):
-        response = await agent.run("Hello! Please respond with 'Hello World' exactly.")
-
-        # Validate response
-        assert isinstance(response, AgentRunResponse)
-        assert response.text is not None
-        assert len(response.text) > 0
-        assert "Hello World" in response.text
+        client = AzureAIClient(
+            project_client=project_client,
+            agent_name=agent_name,
+        )
+        try:
+            assert client.function_invocation_configuration
+            client.function_invocation_configuration.max_iterations = 1
+            yield client
+        finally:
+            await project_client.agents.delete(agent_name=agent_name)


 @pytest.mark.flaky
 @skip_if_azure_ai_integration_tests_disabled
-async def test_azure_ai_chat_client_agent_basic_run_streaming() -> None:
-    """Test ChatAgent basic streaming functionality with AzureAIClient."""
-    async with (
-        temporary_chat_client(agent_name="BasicRunStreamingAgent") as chat_client,
-        ChatAgent(chat_client=chat_client) as agent,
-    ):
-        full_message: str = ""
-        async for chunk in agent.run_stream("Please respond with exactly: 'This is a streaming response test.'"):
-            assert chunk is not None
-            assert isinstance(chunk, AgentRunResponseUpdate)
-            if chunk.text:
-                full_message += chunk.text
+@pytest.mark.parametrize(
+    "option_name,option_value,needs_validation",
+    [
+        # Simple ChatOptions - just verify they don't fail
+        param("top_p", 0.9, False, id="top_p"),
+        param("max_tokens", 500, False, id="max_tokens"),
+        param("seed", 123, False, id="seed"),
+        param("user", "test-user-id", False, id="user"),
+        param("metadata", {"test_key": "test_value"}, False, id="metadata"),
+        param("frequency_penalty", 0.5, False, id="frequency_penalty"),
+        param("presence_penalty", 0.3, False, id="presence_penalty"),
+        param("stop", ["END"], False, id="stop"),
+        param("allow_multiple_tool_calls", True, False, id="allow_multiple_tool_calls"),
+        param("tool_choice", "none", True, id="tool_choice_none"),
+        param("tool_choice", "auto", True, id="tool_choice_auto"),
+        param("tool_choice", "required", True, id="tool_choice_required_any"),
+        param(
+            "tool_choice",
+            {"mode": "required", "required_function_name": "get_weather"},
+            True,
+            id="tool_choice_required",
+        ),
+        # OpenAIResponsesOptions - just verify they don't fail
id="safety_identifier"), + param("truncation", "auto", False, id="truncation"), + param("top_logprobs", 5, False, id="top_logprobs"), + param("prompt_cache_key", "test-cache-key", False, id="prompt_cache_key"), + param("max_tool_calls", 3, False, id="max_tool_calls"), + ], +) +async def test_integration_options( + option_name: str, + option_value: Any, + needs_validation: bool, + client: AzureAIClient, +) -> None: + """Parametrized test covering options that can be set at runtime for a Foundry Agent. - # Validate streaming response - assert len(full_message) > 0 - assert "streaming response test" in full_message.lower() + Tests both streaming and non-streaming modes for each option to ensure + they don't cause failures. Options marked with needs_validation also + check that the feature actually works correctly. + + This test reuses a single agent. + """ + # Prepare test message + if option_name.startswith("tool_choice"): + # Use weather-related prompt for tool tests + messages = [ChatMessage(role="user", text="What is the weather in Seattle?")] + else: + # Generic prompt for simple options + messages = [ChatMessage(role="user", text="Say 'Hello World' briefly.")] + + # Build options dict + options: dict[str, Any] = {option_name: option_value, "tools": [get_weather]} + + for streaming in [False, True]: + if streaming: + # Test streaming mode + response_gen = client.get_streaming_response( + messages=messages, + options=options, + ) + + output_format = option_value if option_name == "response_format" else None + response = await ChatResponse.from_chat_response_generator(response_gen, output_format_type=output_format) + else: + # Test non-streaming mode + response = await client.get_response( + messages=messages, + options=options, + ) + + assert response is not None + assert isinstance(response, ChatResponse) + assert response.text is not None, f"No text in response for option '{option_name}'" + assert len(response.text) > 0, f"Empty response for option '{option_name}'" + + # Validate based on option type + if needs_validation: + if option_name.startswith("tool_choice"): + # Should have called the weather function + text = response.text.lower() + assert "sunny" in text or "seattle" in text, f"Tool not invoked for {option_name}" + elif option_name == "response_format": + if option_value == OutputStruct: + # Should have structured output + assert response.value is not None, "No structured output" + assert isinstance(response.value, OutputStruct) + assert "seattle" in response.value.location.lower() + else: + # Runtime JSON schema + assert response.value is None, "No structured output, can't parse any json." 
+                    response_value = json.loads(response.text)
+                    assert isinstance(response_value, dict)
+                    assert "location" in response_value
+                    assert "seattle" in response_value["location"].lower()


 @pytest.mark.flaky
 @skip_if_azure_ai_integration_tests_disabled
-async def test_azure_ai_chat_client_agent_with_tools() -> None:
-    """Test ChatAgent tools with AzureAIClient."""
-    async with (
-        temporary_chat_client(agent_name="RunToolsAgent") as chat_client,
-        ChatAgent(chat_client=chat_client, tools=[get_weather]) as agent,
-    ):
-        response = await agent.run("What's the weather like in Seattle?")
+@pytest.mark.parametrize(
+    "option_name,option_value,needs_validation",
+    [
+        param("temperature", 0.7, False, id="temperature"),
+        # Complex options requiring output validation
+        param("response_format", OutputStruct, True, id="response_format_pydantic"),
+        param(
+            "response_format",
+            {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "WeatherDigest",
+                    "strict": True,
+                    "schema": {
+                        "title": "WeatherDigest",
+                        "type": "object",
+                        "properties": {
+                            "location": {"type": "string"},
+                            "conditions": {"type": "string"},
+                            "temperature_c": {"type": "number"},
+                            "advisory": {"type": "string"},
+                        },
+                        "required": ["location", "conditions", "temperature_c", "advisory"],
+                        "additionalProperties": False,
+                    },
+                },
+            },
+            True,
+            id="response_format_runtime_json_schema",
+        ),
+    ],
+)
+async def test_integration_agent_options(
+    option_name: str,
+    option_value: Any,
+    needs_validation: bool,
+) -> None:
+    """Test Foundry agent level options in both streaming and non-streaming modes.

-        # Validate response
-        assert isinstance(response, AgentRunResponse)
-        assert response.text is not None
-        assert len(response.text) > 0
-        assert any(word in response.text.lower() for word in ["sunny", "25"])
+    Tests both streaming and non-streaming modes for each option to ensure
+    they don't cause failures. Options marked with needs_validation also
+    check that the feature actually works correctly.
+
+    This test creates a new client and uses it for both streaming and non-streaming tests.
+ """ + async with temporary_chat_client(agent_name=f"test-agent-{option_name.replace('_', '-')}-{uuid4()}") as client: + for streaming in [False, True]: + # Prepare test message + if option_name.startswith("response_format"): + # Use prompt that works well with structured output + messages = [ChatMessage(role="user", text="The weather in Seattle is sunny")] + messages.append(ChatMessage(role="user", text="What is the weather in Seattle?")) + else: + # Generic prompt for simple options + messages = [ChatMessage(role="user", text="Say 'Hello World' briefly.")] + + # Build options dict + options = {option_name: option_value} + + if streaming: + # Test streaming mode + response_gen = client.get_streaming_response( + messages=messages, + options=options, + ) + + output_format = option_value if option_name.startswith("response_format") else None + response = await ChatResponse.from_chat_response_generator( + response_gen, output_format_type=output_format + ) + else: + # Test non-streaming mode + response = await client.get_response( + messages=messages, + options=options, + ) + + assert response is not None + assert isinstance(response, ChatResponse) + assert response.text is not None, f"No text in response for option '{option_name}'" + assert len(response.text) > 0, f"Empty response for option '{option_name}'" + + # Validate based on option type + if needs_validation and option_name.startswith("response_format"): + if option_value == OutputStruct: + # Should have structured output + assert response.value is not None, "No structured output" + assert isinstance(response.value, OutputStruct) + assert "seattle" in response.value.location.lower() + else: + # Runtime JSON schema + assert response.value is None, "No structured output, can't parse any json." + response_value = json.loads(response.text) + assert isinstance(response_value, dict) + assert "location" in response_value + assert "seattle" in response_value["location"].lower() -class ReleaseBrief(BaseModel): - """Structured output model for release brief.""" - title: str = Field(description="A short title for the release.") - summary: str = Field(description="A brief summary of what was released.") - highlights: list[str] = Field(description="Key highlights from the release.") - model_config = ConfigDict(extra="forbid") +@pytest.mark.flaky +@skip_if_azure_ai_integration_tests_disabled +async def test_integration_web_search() -> None: + async with temporary_chat_client(agent_name="af-int-test-web-search") as client: + for streaming in [False, True]: + content = { + "messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.", + "options": { + "tool_choice": "auto", + "tools": [HostedWebSearchTool()], + }, + } + if streaming: + response = await ChatResponse.from_chat_response_generator(client.get_streaming_response(**content)) + else: + response = await client.get_response(**content) + + assert response is not None + assert isinstance(response, ChatResponse) + assert "Rumi" in response.text + assert "Mira" in response.text + assert "Zoey" in response.text + + # Test that the client will use the web search tool with location + additional_properties = { + "user_location": { + "country": "US", + "city": "Seattle", + } + } + content = { + "messages": "What is the current weather? 
+                "messages": "What is the current weather? Do not ask for my current location.",
+                "options": {
+                    "tool_choice": "auto",
+                    "tools": [HostedWebSearchTool(additional_properties=additional_properties)],
+                },
+            }
+            if streaming:
+                response = await ChatResponse.from_chat_response_generator(client.get_streaming_response(**content))
+            else:
+                response = await client.get_response(**content)
+            assert response.text is not None


 @pytest.mark.flaky
 @skip_if_azure_ai_integration_tests_disabled
-async def test_azure_ai_chat_client_agent_with_response_format() -> None:
-    """Test ChatAgent with response_format (structured output) using AzureAIClient."""
-    async with (
-        temporary_chat_client(agent_name="ResponseFormatAgent") as chat_client,
-        ChatAgent(chat_client=chat_client) as agent,
-    ):
-        response = await agent.run(
-            "Summarize the following release notes into a ReleaseBrief:\n\n"
-            "Version 2.0 Release Notes:\n"
-            "- Added new streaming API for real-time responses\n"
-            "- Improved error handling with detailed messages\n"
-            "- Performance boost of 50% in batch processing\n"
-            "- Fixed memory leak in connection pooling",
-            response_format=ReleaseBrief,
+async def test_integration_agent_hosted_mcp_tool() -> None:
+    """Integration test for HostedMCPTool with Azure Response Agent using Microsoft Learn MCP."""
+    async with temporary_chat_client(agent_name="af-int-test-mcp") as client:
+        response = await client.get_response(
+            "How to create an Azure storage account using az cli?",
+            options={
+                # this needs to be high enough to handle the full MCP tool response.
+                "max_tokens": 5000,
+                "tools": HostedMCPTool(
+                    name="Microsoft Learn MCP",
+                    url="https://learn.microsoft.com/api/mcp",
+                    description="A Microsoft Learn MCP server for documentation questions",
+                    approval_mode="never_require",
+                ),
+            },
         )
-
-        # Validate response
-        assert isinstance(response, AgentRunResponse)
-        assert response.value is not None
-        assert isinstance(response.value, ReleaseBrief)
-
-        # Validate structured output fields
-        brief = response.value
-        assert len(brief.title) > 0
-        assert len(brief.summary) > 0
-        assert len(brief.highlights) > 0
+        assert isinstance(response, ChatResponse)
+        assert response.text
+        # Should contain Azure-related content since it's asking about Azure CLI
+        assert any(term in response.text.lower() for term in ["azure", "storage", "account", "cli"])


 @pytest.mark.flaky
 @skip_if_azure_ai_integration_tests_disabled
-async def test_azure_ai_chat_client_agent_with_runtime_json_schema() -> None:
-    """Test ChatAgent with runtime JSON schema (structured output) using AzureAIClient."""
-    runtime_schema = {
-        "title": "WeatherDigest",
-        "type": "object",
-        "properties": {
-            "location": {"type": "string"},
-            "conditions": {"type": "string"},
-            "temperature_c": {"type": "number"},
-            "advisory": {"type": "string"},
-        },
-        "required": ["location", "conditions", "temperature_c", "advisory"],
-        "additionalProperties": False,
-    }
-
-    async with (
-        temporary_chat_client(agent_name="RuntimeSchemaAgent") as chat_client,
-        ChatAgent(chat_client=chat_client) as agent,
-    ):
-        response = await agent.run(
-            "Give a brief weather digest for Seattle.",
-            additional_chat_options={
-                "response_format": {
-                    "type": "json_schema",
-                    "json_schema": {
-                        "name": runtime_schema["title"],
-                        "strict": True,
-                        "schema": runtime_schema,
-                    },
-                },
+async def test_integration_agent_hosted_code_interpreter_tool():
+    """Test Azure Responses Client agent with HostedCodeInterpreterTool through AzureAIClient."""
+    async with temporary_chat_client(agent_name="af-int-test-code-interpreter") as client:
+        response = await client.get_response(
+            "Calculate the sum of numbers from 1 to 10 using Python code.",
+            options={
+                "tools": [HostedCodeInterpreterTool()],
             },
         )
+        # Should contain calculation result (sum of 1-10 = 55) or code execution content
+        contains_relevant_content = any(
+            term in response.text.lower() for term in ["55", "sum", "code", "python", "calculate", "10"]
+        )
+        assert contains_relevant_content or len(response.text.strip()) > 10

-        # Validate response
-        assert isinstance(response, AgentRunResponse)
-        assert response.text is not None
-
-        # Parse JSON and validate structure
-        import json
+
+@pytest.mark.flaky
+@skip_if_azure_ai_integration_tests_disabled
+async def test_integration_agent_existing_thread():
+    """Test Azure Responses Client agent with existing thread to continue conversations across agent instances."""
+    # First conversation - capture the thread
+    preserved_thread = None

-        parsed = json.loads(response.text)
-        assert "location" in parsed
-        assert "conditions" in parsed
-        assert "temperature_c" in parsed
-        assert "advisory" in parsed
+    async with (
+        temporary_chat_client(agent_name="af-int-test-existing-thread") as client,
+        ChatAgent(
+            chat_client=client,
+            instructions="You are a helpful assistant with good memory.",
+        ) as first_agent,
+    ):
+        # Start a conversation and capture the thread
+        thread = first_agent.get_new_thread()
+        first_response = await first_agent.run("My hobby is photography. Remember this.", thread=thread, store=True)
+
+        assert isinstance(first_response, AgentResponse)
+        assert first_response.text is not None
+
+        # Preserve the thread for reuse
+        preserved_thread = thread
+
+    # Second conversation - reuse the thread in a new agent instance
+    if preserved_thread:
+        async with (
+            temporary_chat_client(agent_name="af-int-test-existing-thread-2") as client,
+            ChatAgent(
+                chat_client=client,
+                instructions="You are a helpful assistant with good memory.",
+            ) as second_agent,
+        ):
+            # Reuse the preserved thread
+            second_response = await second_agent.run("What is my hobby?", thread=preserved_thread)
+
+            assert isinstance(second_response, AgentResponse)
+            assert second_response.text is not None
+            assert "photography" in second_response.text.lower()
diff --git a/python/packages/azure-ai/tests/test_provider.py b/python/packages/azure-ai/tests/test_provider.py
new file mode 100644
index 0000000000..e3dfa0995a
--- /dev/null
+++ b/python/packages/azure-ai/tests/test_provider.py
@@ -0,0 +1,436 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+import os
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from agent_framework import ChatAgent
+from agent_framework.exceptions import ServiceInitializationError
+from azure.ai.projects.aio import AIProjectClient
+from azure.ai.projects.models import (
+    AgentReference,
+    AgentVersionDetails,
+    FunctionTool,
+    PromptAgentDefinition,
+)
+from azure.identity.aio import AzureCliCredential
+
+from agent_framework_azure_ai import AzureAIProjectAgentProvider
+
+skip_if_azure_ai_integration_tests_disabled = pytest.mark.skipif(
+    os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true"
+    or os.getenv("AZURE_AI_PROJECT_ENDPOINT", "") in ("", "https://test-project.cognitiveservices.azure.com/")
+    or os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME", "") == "",
+    reason=(
+        "No real AZURE_AI_PROJECT_ENDPOINT or AZURE_AI_MODEL_DEPLOYMENT_NAME provided; skipping integration tests."
+        if os.getenv("RUN_INTEGRATION_TESTS", "false").lower() == "true"
+        else "Integration tests are disabled."
+ ), +) + + +@pytest.fixture +def mock_project_client() -> MagicMock: + """Fixture that provides a mock AIProjectClient.""" + mock_client = MagicMock() + + # Mock agents property + mock_client.agents = MagicMock() + mock_client.agents.create_version = AsyncMock() + + # Mock conversations property + mock_client.conversations = MagicMock() + mock_client.conversations.create = AsyncMock() + + # Mock telemetry property + mock_client.telemetry = MagicMock() + mock_client.telemetry.get_application_insights_connection_string = AsyncMock() + + # Mock get_openai_client method + mock_client.get_openai_client = AsyncMock() + + # Mock close method + mock_client.close = AsyncMock() + + return mock_client + + +@pytest.fixture +def mock_azure_credential() -> MagicMock: + """Fixture that provides a mock Azure credential.""" + return MagicMock() + + +@pytest.fixture +def azure_ai_unit_test_env(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: + """Fixture that sets up Azure AI environment variables for unit testing.""" + env_vars = { + "AZURE_AI_PROJECT_ENDPOINT": "https://test-project.cognitiveservices.azure.com/", + "AZURE_AI_MODEL_DEPLOYMENT_NAME": "test-model-deployment", + } + for key, value in env_vars.items(): + monkeypatch.setenv(key, value) + return env_vars + + +def test_provider_init_with_project_client(mock_project_client: MagicMock) -> None: + """Test AzureAIProjectAgentProvider initialization with existing project_client.""" + provider = AzureAIProjectAgentProvider(project_client=mock_project_client) + + assert provider._project_client is mock_project_client # type: ignore + assert not provider._should_close_client # type: ignore + + +def test_provider_init_with_credential_and_endpoint( + azure_ai_unit_test_env: dict[str, str], + mock_azure_credential: MagicMock, +) -> None: + """Test AzureAIProjectAgentProvider initialization with credential and endpoint.""" + with patch("agent_framework_azure_ai._project_provider.AIProjectClient") as mock_ai_project_client: + mock_client = MagicMock() + mock_ai_project_client.return_value = mock_client + + provider = AzureAIProjectAgentProvider( + project_endpoint=azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"], + credential=mock_azure_credential, + ) + + assert provider._project_client is mock_client # type: ignore + assert provider._should_close_client # type: ignore + + # Verify AIProjectClient was called with correct parameters + mock_ai_project_client.assert_called_once() + + +def test_provider_init_missing_endpoint() -> None: + """Test AzureAIProjectAgentProvider initialization when endpoint is missing.""" + with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings: + mock_settings.return_value.project_endpoint = None + mock_settings.return_value.model_deployment_name = "test-model" + + with pytest.raises(ServiceInitializationError, match="Azure AI project endpoint is required"): + AzureAIProjectAgentProvider(credential=MagicMock()) + + +def test_provider_init_missing_credential(azure_ai_unit_test_env: dict[str, str]) -> None: + """Test AzureAIProjectAgentProvider initialization when credential is missing.""" + with pytest.raises( + ServiceInitializationError, match="Azure credential is required when project_client is not provided" + ): + AzureAIProjectAgentProvider( + project_endpoint=azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"], + ) + + +async def test_provider_create_agent( + mock_project_client: MagicMock, + azure_ai_unit_test_env: dict[str, str], +) -> None: + """Test 
+    """Test AzureAIProjectAgentProvider.create_agent method."""
+    with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
+        mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
+        mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
+
+        provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
+
+        # Mock agent creation response
+        mock_agent_version = MagicMock(spec=AgentVersionDetails)
+        mock_agent_version.id = "agent-id"
+        mock_agent_version.name = "test-agent"
+        mock_agent_version.version = "1.0"
+        mock_agent_version.description = "Test Agent"
+        mock_agent_version.definition = MagicMock(spec=PromptAgentDefinition)
+        mock_agent_version.definition.model = "gpt-4"
+        mock_agent_version.definition.instructions = "Test instructions"
+        mock_agent_version.definition.temperature = 0.7
+        mock_agent_version.definition.top_p = 0.9
+        mock_agent_version.definition.tools = []
+
+        mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent_version)
+
+        agent = await provider.create_agent(
+            name="test-agent",
+            model="gpt-4",
+            instructions="Test instructions",
+            description="Test Agent",
+        )
+
+        assert isinstance(agent, ChatAgent)
+        assert agent.name == "test-agent"
+        mock_project_client.agents.create_version.assert_called_once()
+
+
+async def test_provider_create_agent_with_env_model(
+    mock_project_client: MagicMock,
+    azure_ai_unit_test_env: dict[str, str],
+) -> None:
+    """Test AzureAIProjectAgentProvider.create_agent uses model from env var."""
+    with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
+        mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
+        mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
+
+        provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
+
+        # Mock agent creation response
+        mock_agent_version = MagicMock(spec=AgentVersionDetails)
+        mock_agent_version.id = "agent-id"
+        mock_agent_version.name = "test-agent"
+        mock_agent_version.version = "1.0"
+        mock_agent_version.description = None
+        mock_agent_version.definition = MagicMock(spec=PromptAgentDefinition)
+        mock_agent_version.definition.model = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
+        mock_agent_version.definition.instructions = None
+        mock_agent_version.definition.temperature = None
+        mock_agent_version.definition.top_p = None
+        mock_agent_version.definition.tools = []
+
+        mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent_version)
+
+        # Call without model parameter - should use env var
+        agent = await provider.create_agent(name="test-agent")
+
+        assert isinstance(agent, ChatAgent)
+        # Verify the model from env var was used
+        call_args = mock_project_client.agents.create_version.call_args
+        assert call_args[1]["definition"].model == azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
+
+
+async def test_provider_create_agent_missing_model(mock_project_client: MagicMock) -> None:
+    """Test AzureAIProjectAgentProvider.create_agent raises when model is missing."""
+    with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
+        mock_settings.return_value.project_endpoint = "https://test.com"
+        mock_settings.return_value.model_deployment_name = None
+
+        provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
+
+        with pytest.raises(ServiceInitializationError, match="Model deployment name is required"):
+            await provider.create_agent(name="test-agent")
+
+
+async def test_provider_get_agent_with_name(mock_project_client: MagicMock) -> None:
+    """Test AzureAIProjectAgentProvider.get_agent with name parameter."""
+    provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
+
+    # Mock agent response
+    mock_agent_version = MagicMock(spec=AgentVersionDetails)
+    mock_agent_version.id = "agent-id"
+    mock_agent_version.name = "test-agent"
+    mock_agent_version.version = "1.0"
+    mock_agent_version.description = "Test Agent"
+    mock_agent_version.definition = MagicMock(spec=PromptAgentDefinition)
+    mock_agent_version.definition.model = "gpt-4"
+    mock_agent_version.definition.instructions = "Test instructions"
+    mock_agent_version.definition.temperature = None
+    mock_agent_version.definition.top_p = None
+    mock_agent_version.definition.tools = []
+
+    mock_agent_object = MagicMock()
+    mock_agent_object.versions.latest = mock_agent_version
+
+    mock_project_client.agents = AsyncMock()
+    mock_project_client.agents.get.return_value = mock_agent_object
+
+    agent = await provider.get_agent(name="test-agent")
+
+    assert isinstance(agent, ChatAgent)
+    assert agent.name == "test-agent"
+    mock_project_client.agents.get.assert_called_with(agent_name="test-agent")
+
+
+async def test_provider_get_agent_with_reference(mock_project_client: MagicMock) -> None:
+    """Test AzureAIProjectAgentProvider.get_agent with reference parameter."""
+    provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
+
+    # Mock agent response
+    mock_agent_version = MagicMock(spec=AgentVersionDetails)
+    mock_agent_version.id = "agent-id"
+    mock_agent_version.name = "test-agent"
+    mock_agent_version.version = "1.0"
+    mock_agent_version.description = "Test Agent"
+    mock_agent_version.definition = MagicMock(spec=PromptAgentDefinition)
+    mock_agent_version.definition.model = "gpt-4"
+    mock_agent_version.definition.instructions = "Test instructions"
+    mock_agent_version.definition.temperature = None
+    mock_agent_version.definition.top_p = None
+    mock_agent_version.definition.tools = []
+
+    mock_project_client.agents = AsyncMock()
+    mock_project_client.agents.get_version.return_value = mock_agent_version
+
+    agent_reference = AgentReference(name="test-agent", version="1.0")
+    agent = await provider.get_agent(reference=agent_reference)
+
+    assert isinstance(agent, ChatAgent)
+    assert agent.name == "test-agent"
+    mock_project_client.agents.get_version.assert_called_with(agent_name="test-agent", agent_version="1.0")
+
+
+async def test_provider_get_agent_missing_parameters(mock_project_client: MagicMock) -> None:
+    """Test AzureAIProjectAgentProvider.get_agent raises when no identifier provided."""
+    provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
+
+    with pytest.raises(ValueError, match="Either name or reference must be provided"):
+        await provider.get_agent()
+
+
+async def test_provider_get_agent_missing_function_tools(mock_project_client: MagicMock) -> None:
+    """Test AzureAIProjectAgentProvider.get_agent raises when required tools are missing."""
+    provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
+
+    # Mock agent with function tools
+    mock_agent_version = MagicMock(spec=AgentVersionDetails)
+    mock_agent_version.id = "agent-id"
+    mock_agent_version.name = "test-agent"
+    mock_agent_version.version = "1.0"
+    mock_agent_version.description = None
+    mock_agent_version.definition = MagicMock(spec=PromptAgentDefinition)
+    mock_agent_version.definition.tools = [
+        FunctionTool(name="test_tool", parameters=[], strict=True, description="Test tool")
+    ]
+
+    mock_agent_object = MagicMock()
+    mock_agent_object.versions.latest = mock_agent_version
+
+    mock_project_client.agents = AsyncMock()
+    mock_project_client.agents.get.return_value = mock_agent_object
+
+    with pytest.raises(
+        ValueError, match="The following prompt agent definition required tools were not provided: test_tool"
+    ):
+        await provider.get_agent(name="test-agent")
+
+
+def test_provider_as_agent(mock_project_client: MagicMock) -> None:
+    """Test AzureAIProjectAgentProvider.as_agent method."""
+    provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
+
+    # Create mock agent version
+    mock_agent_version = MagicMock(spec=AgentVersionDetails)
+    mock_agent_version.id = "agent-id"
+    mock_agent_version.name = "test-agent"
+    mock_agent_version.version = "1.0"
+    mock_agent_version.description = "Test Agent"
+    mock_agent_version.definition = MagicMock(spec=PromptAgentDefinition)
+    mock_agent_version.definition.model = "gpt-4"
+    mock_agent_version.definition.instructions = "Test instructions"
+    mock_agent_version.definition.temperature = 0.7
+    mock_agent_version.definition.top_p = 0.9
+    mock_agent_version.definition.tools = []
+
+    with patch("agent_framework_azure_ai._project_provider.AzureAIClient") as mock_azure_ai_client:
+        agent = provider.as_agent(mock_agent_version)
+
+        assert isinstance(agent, ChatAgent)
+        assert agent.name == "test-agent"
+        assert agent.description == "Test Agent"
+
+        # Verify AzureAIClient was called with correct parameters
+        mock_azure_ai_client.assert_called_once()
+        call_kwargs = mock_azure_ai_client.call_args[1]
+        assert call_kwargs["project_client"] is mock_project_client
+        assert call_kwargs["agent_name"] == "test-agent"
+        assert call_kwargs["agent_version"] == "1.0"
+        assert call_kwargs["agent_description"] == "Test Agent"
+        assert call_kwargs["model_deployment_name"] == "gpt-4"
+
+
+async def test_provider_context_manager(mock_project_client: MagicMock) -> None:
+    """Test AzureAIProjectAgentProvider async context manager."""
+    with patch("agent_framework_azure_ai._project_provider.AIProjectClient") as mock_ai_project_client:
+        mock_client = MagicMock()
+        mock_client.close = AsyncMock()
+        mock_ai_project_client.return_value = mock_client
+
+        with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
+            mock_settings.return_value.project_endpoint = "https://test.com"
+            mock_settings.return_value.model_deployment_name = "test-model"
+
+            async with AzureAIProjectAgentProvider(credential=MagicMock()) as provider:
+                assert provider._project_client is mock_client  # type: ignore
+
+            # Should call close after exiting context
+            mock_client.close.assert_called_once()
+
+
+async def test_provider_context_manager_with_provided_client(mock_project_client: MagicMock) -> None:
+    """Test AzureAIProjectAgentProvider context manager doesn't close provided client."""
+    mock_project_client.close = AsyncMock()
+
+    async with AzureAIProjectAgentProvider(project_client=mock_project_client) as provider:
+        assert provider._project_client is mock_project_client  # type: ignore
+
+    # Should NOT call close when client was provided
+    mock_project_client.close.assert_not_called()
+
+
+async def test_provider_close_method(mock_project_client: MagicMock) -> None:
+    """Test AzureAIProjectAgentProvider.close method."""
patch("agent_framework_azure_ai._project_provider.AIProjectClient") as mock_ai_project_client: + mock_client = MagicMock() + mock_client.close = AsyncMock() + mock_ai_project_client.return_value = mock_client + + with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings: + mock_settings.return_value.project_endpoint = "https://test.com" + mock_settings.return_value.model_deployment_name = "test-model" + + provider = AzureAIProjectAgentProvider(credential=MagicMock()) + await provider.close() + + mock_client.close.assert_called_once() + + +def test_create_text_format_config_sets_strict_for_pydantic_models() -> None: + """Test that create_text_format_config sets strict=True for Pydantic models.""" + from pydantic import BaseModel + + from agent_framework_azure_ai._shared import create_text_format_config + + class TestSchema(BaseModel): + subject: str + summary: str + + result = create_text_format_config(TestSchema) + + # Verify strict=True is set + assert result["strict"] is True + assert result["name"] == "TestSchema" + assert "schema" in result + + +@pytest.mark.flaky +@skip_if_azure_ai_integration_tests_disabled +async def test_provider_create_and_get_agent_integration() -> None: + """Integration test for provider create_agent and get_agent.""" + endpoint = os.environ["AZURE_AI_PROJECT_ENDPOINT"] + model = os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"] + + async with ( + AzureCliCredential() as credential, + AIProjectClient(endpoint=endpoint, credential=credential) as project_client, + ): + provider = AzureAIProjectAgentProvider(project_client=project_client) + + try: + # Create agent + agent = await provider.create_agent( + name="ProviderTestAgent", + model=model, + instructions="You are a helpful assistant. Always respond with 'Hello from provider!'", + ) + + assert isinstance(agent, ChatAgent) + assert agent.name == "ProviderTestAgent" + + # Run the agent + response = await agent.run("Hi!") + assert response.text is not None + assert len(response.text) > 0 + + # Get the same agent + retrieved_agent = await provider.get_agent(name="ProviderTestAgent") + assert retrieved_agent.name == "ProviderTestAgent" + + finally: + # Cleanup + await project_client.agents.delete(agent_name="ProviderTestAgent") diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index 11a35f7f92..60ec15e21c 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -63,10 +63,10 @@ def __init__(self, result: Any): class AgentTask(_TypedCompoundTask): - """A custom Task that wraps entity calls and provides typed AgentRunResponse results. + """A custom Task that wraps entity calls and provides typed AgentResponse results. This task wraps the underlying entity call task and intercepts its completion - to convert the raw result into a typed AgentRunResponse object. + to convert the raw result into a typed AgentResponse object. """ def __init__( @@ -97,7 +97,7 @@ def __init__( self.id = entity_task.id def try_set_value(self, child: TaskBase) -> None: - """Transition the AgentTask to a terminal state and set its value to `AgentRunResponse`. + """Transition the AgentTask to a terminal state and set its value to `AgentResponse`. 

         Parameters
         ----------
@@ -124,7 +124,7 @@ def try_set_value(self, child: TaskBase) -> None:
                     response,
                 )

-                # Set the typed AgentRunResponse as this task's result
+                # Set the typed AgentResponse as this task's result
                 self.set_value(is_error=False, value=response)
         except Exception as e:
             logger.exception(
@@ -151,17 +151,16 @@ def generate_unique_id(self) -> str:
     def get_run_request(
         self,
         message: str,
-        response_format: type[BaseModel] | None,
-        enable_tool_calls: bool,
-        wait_for_response: bool = True,
+        *,
+        options: dict[str, Any] | None = None,
     ) -> RunRequest:
         """Get the current run request from the orchestration context.

         Args:
             message: The message to send to the agent
-            response_format: Optional Pydantic model for response parsing
-            enable_tool_calls: Whether to enable tool calls
-            wait_for_response: Must be True for orchestration contexts
+            options: Optional options dictionary. Supported keys include
+                ``response_format``, ``enable_tool_calls``, and ``wait_for_response``.
+                Additional keys are forwarded to the agent execution.

         Returns:
             RunRequest: The current run request
@@ -169,12 +168,9 @@ def get_run_request(
         Raises:
             ValueError: If wait_for_response=False (not supported in orchestrations)
         """
-        request = super().get_run_request(
-            message,
-            response_format,
-            enable_tool_calls,
-            wait_for_response,
-        )
+        # Delegate to the base implementation, then attach the orchestration id
+
+        request = super().get_run_request(message, options=options)
         request.orchestration_id = self.context.instance_id
         return request

diff --git a/python/packages/azurefunctions/pyproject.toml b/python/packages/azurefunctions/pyproject.toml
index 4a05986469..c6a8ecbfe6 100644
--- a/python/packages/azurefunctions/pyproject.toml
+++ b/python/packages/azurefunctions/pyproject.toml
@@ -4,7 +4,7 @@
 description = "Azure Functions integration for Microsoft Agent Framework."
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/azurefunctions/tests/integration_tests/conftest.py b/python/packages/azurefunctions/tests/integration_tests/conftest.py index e2f19d6037..ee81028b80 100644 --- a/python/packages/azurefunctions/tests/integration_tests/conftest.py +++ b/python/packages/azurefunctions/tests/integration_tests/conftest.py @@ -6,13 +6,18 @@ """ import subprocess +import sys from collections.abc import Iterator, Mapping +from pathlib import Path from typing import Any import pytest import requests -from .testutils import ( +# Add the integration_tests directory to the path so testutils can be imported +sys.path.insert(0, str(Path(__file__).parent)) + +from testutils import ( FunctionAppStartupError, build_base_url, cleanup_function_app, diff --git a/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py b/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py index d98e23824e..7af3a3b653 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py @@ -16,8 +16,7 @@ import pytest from agent_framework_durabletask import THREAD_ID_HEADER - -from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled +from testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file pytestmark = [ diff --git a/python/packages/azurefunctions/tests/integration_tests/test_02_multi_agent.py b/python/packages/azurefunctions/tests/integration_tests/test_02_multi_agent.py index f473a2be11..7a4adfd8dd 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_02_multi_agent.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_02_multi_agent.py @@ -15,8 +15,7 @@ """ import pytest - -from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled +from testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file pytestmark = [ diff --git a/python/packages/azurefunctions/tests/integration_tests/test_03_reliable_streaming.py b/python/packages/azurefunctions/tests/integration_tests/test_03_reliable_streaming.py index 44fb8efb2f..032935ee29 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_03_reliable_streaming.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_03_reliable_streaming.py @@ -19,8 +19,7 @@ import pytest import requests - -from .testutils import ( +from testutils import ( SampleTestHelper, skip_if_azure_functions_integration_tests_disabled, ) diff --git a/python/packages/azurefunctions/tests/integration_tests/test_04_single_agent_orchestration_chaining.py b/python/packages/azurefunctions/tests/integration_tests/test_04_single_agent_orchestration_chaining.py index e4bb1cd930..fff06c9d8d 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_04_single_agent_orchestration_chaining.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_04_single_agent_orchestration_chaining.py @@ -19,8 +19,7 @@ """ import 
 import pytest
-
-from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled
+from testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled

 # Module-level markers - applied to all tests in this file
 pytestmark = [
diff --git a/python/packages/azurefunctions/tests/integration_tests/test_05_multi_agent_orchestration_concurrency.py b/python/packages/azurefunctions/tests/integration_tests/test_05_multi_agent_orchestration_concurrency.py
index aac8f361c6..d2d9cbbed8 100644
--- a/python/packages/azurefunctions/tests/integration_tests/test_05_multi_agent_orchestration_concurrency.py
+++ b/python/packages/azurefunctions/tests/integration_tests/test_05_multi_agent_orchestration_concurrency.py
@@ -19,8 +19,7 @@
 """

 import pytest
-
-from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled
+from testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled

 # Module-level markers - applied to all tests in this file
 pytestmark = [
diff --git a/python/packages/azurefunctions/tests/integration_tests/test_06_multi_agent_orchestration_conditionals.py b/python/packages/azurefunctions/tests/integration_tests/test_06_multi_agent_orchestration_conditionals.py
index d7f13777bb..0b2a9f7073 100644
--- a/python/packages/azurefunctions/tests/integration_tests/test_06_multi_agent_orchestration_conditionals.py
+++ b/python/packages/azurefunctions/tests/integration_tests/test_06_multi_agent_orchestration_conditionals.py
@@ -19,8 +19,7 @@
 """

 import pytest
-
-from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled
+from testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled

 # Module-level markers - applied to all tests in this file
 pytestmark = [
diff --git a/python/packages/azurefunctions/tests/integration_tests/test_07_single_agent_orchestration_hitl.py b/python/packages/azurefunctions/tests/integration_tests/test_07_single_agent_orchestration_hitl.py
index ade46033bc..f21410ebf5 100644
--- a/python/packages/azurefunctions/tests/integration_tests/test_07_single_agent_orchestration_hitl.py
+++ b/python/packages/azurefunctions/tests/integration_tests/test_07_single_agent_orchestration_hitl.py
@@ -21,8 +21,7 @@
 import time

 import pytest
-
-from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled
+from testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled

 # Module-level markers - applied to all tests in this file
 pytestmark = [
diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py
index 466cb8ea85..4fb1617a73 100644
--- a/python/packages/azurefunctions/tests/test_app.py
+++ b/python/packages/azurefunctions/tests/test_app.py
@@ -12,7 +12,7 @@
 import azure.durable_functions as df
 import azure.functions as func
 import pytest
-from agent_framework import AgentRunResponse, ChatMessage, ErrorContent
+from agent_framework import AgentResponse, ChatMessage, ErrorContent
 from agent_framework_durabletask import (
     MIMETYPE_APPLICATION_JSON,
     MIMETYPE_TEXT_PLAIN,
@@ -356,7 +356,7 @@ async def test_entity_run_agent_operation(self) -> None:
         """Test that entity can run agent operation."""
         mock_agent = Mock()
         mock_agent.run = AsyncMock(
-            return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Test response")])
+            return_value=AgentResponse(messages=[ChatMessage(role="assistant", text="Test response")])
         )
         entity = AgentEntity(mock_agent, state_provider=_InMemoryStateProvider(thread_id="test-conv-123"))
@@ -366,7 +366,7 @@ async def test_entity_run_agent_operation(self) -> None:
             "correlationId": "corr-app-entity-1",
         })

-        assert isinstance(result, AgentRunResponse)
+        assert isinstance(result, AgentResponse)
         assert result.text == "Test response"
         assert entity.state.message_count == 2

@@ -374,7 +374,7 @@ async def test_entity_stores_conversation_history(self) -> None:
         """Test that the entity stores conversation history."""
         mock_agent = Mock()
         mock_agent.run = AsyncMock(
-            return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response 1")])
+            return_value=AgentResponse(messages=[ChatMessage(role="assistant", text="Response 1")])
         )
         entity = AgentEntity(mock_agent, state_provider=_InMemoryStateProvider(thread_id="conv-1"))
@@ -408,7 +408,7 @@ async def test_entity_increments_message_count(self) -> None:
         """Test that the entity increments the message count."""
         mock_agent = Mock()
         mock_agent.run = AsyncMock(
-            return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")])
+            return_value=AgentResponse(messages=[ChatMessage(role="assistant", text="Response")])
         )
         entity = AgentEntity(mock_agent, state_provider=_InMemoryStateProvider(thread_id="conv-1"))
@@ -449,7 +449,7 @@ def test_entity_function_handles_run_operation(self) -> None:
         """Test that the entity function handles the run operation."""
         mock_agent = Mock()
         mock_agent.run = AsyncMock(
-            return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")])
+            return_value=AgentResponse(messages=[ChatMessage(role="assistant", text="Response")])
         )

         entity_function = create_agent_entity(mock_agent)
@@ -476,7 +476,7 @@ def test_entity_function_handles_run_agent_operation(self) -> None:
         """Test that the entity function handles the deprecated run_agent operation for backward compatibility."""
         mock_agent = Mock()
         mock_agent.run = AsyncMock(
-            return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")])
+            return_value=AgentResponse(messages=[ChatMessage(role="assistant", text="Response")])
         )

         entity_function = create_agent_entity(mock_agent)
@@ -633,7 +633,7 @@ async def test_entity_handles_agent_error(self) -> None:
             "correlationId": "corr-app-error-1",
         })

-        assert isinstance(result, AgentRunResponse)
+        assert isinstance(result, AgentResponse)
         assert len(result.messages) == 1
         content = result.messages[0].contents[0]
         assert isinstance(content, ErrorContent)
diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py
index 7ae845ed2b..555b588887 100644
--- a/python/packages/azurefunctions/tests/test_entities.py
+++ b/python/packages/azurefunctions/tests/test_entities.py
@@ -10,19 +10,21 @@
 from unittest.mock import AsyncMock, Mock

 import pytest
-from agent_framework import AgentRunResponse, ChatMessage
+from agent_framework import AgentResponse, ChatMessage, Role

 from agent_framework_azurefunctions._entities import create_agent_entity

 TFunc = TypeVar("TFunc", bound=Callable[..., Any])


-def _agent_response(text: str | None) -> AgentRunResponse:
-    """Create an AgentRunResponse with a single assistant message."""
+def _agent_response(text: str | None) -> AgentResponse:
+    """Create an AgentResponse with a single assistant message."""
     message = (
-        ChatMessage(role="assistant", text=text) if text is not None else ChatMessage(role="assistant", contents=[])
+        ChatMessage(role=Role.ASSISTANT, text=text)
+        if text is not None
+        else ChatMessage(role=Role.ASSISTANT, contents=[])
     )
-    return AgentRunResponse(messages=[message])
+    return AgentResponse(messages=[message])


 class TestCreateAgentEntity:
diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py
index c19d99177f..2b9a4126d4 100644
--- a/python/packages/azurefunctions/tests/test_orchestration.py
+++ b/python/packages/azurefunctions/tests/test_orchestration.py
@@ -6,7 +6,7 @@
 from unittest.mock import Mock

 import pytest
-from agent_framework import AgentRunResponse, ChatMessage, Role
+from agent_framework import AgentResponse, ChatMessage, Role
 from agent_framework_durabletask import DurableAIAgent
 from azure.durable_functions.models.Task import TaskBase, TaskState
@@ -136,7 +136,7 @@ def test_try_set_value_success(self) -> None:
         # Simulate successful entity task completion
         entity_task.state = TaskState.SUCCEEDED
-        entity_task.result = AgentRunResponse(messages=[ChatMessage(role="assistant", text="Test response")]).to_dict()
+        entity_task.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Test response")]).to_dict()

         # Clear pending_tasks to simulate that parent has processed the child
         task.pending_tasks.clear()
@@ -144,9 +144,9 @@ def test_try_set_value_success(self) -> None:
         # Call try_set_value
         task.try_set_value(entity_task)

-        # Verify task completed successfully with AgentRunResponse
+        # Verify task completed successfully with AgentResponse
         assert task.state == TaskState.SUCCEEDED
-        assert isinstance(task.result, AgentRunResponse)
+        assert isinstance(task.result, AgentResponse)
         assert task.result.text == "Test response"

     def test_try_set_value_failure(self) -> None:
@@ -178,9 +178,7 @@ class TestSchema(BaseModel):
         # Simulate successful entity task with JSON response
         entity_task.state = TaskState.SUCCEEDED
-        entity_task.result = AgentRunResponse(
-            messages=[ChatMessage(role="assistant", text='{"answer": "42"}')]
-        ).to_dict()
+        entity_task.result = AgentResponse(messages=[ChatMessage(role="assistant", text='{"answer": "42"}')]).to_dict()

         # Clear pending_tasks to simulate that parent has processed the child
         task.pending_tasks.clear()
@@ -190,7 +188,7 @@ class TestSchema(BaseModel):

         # Verify task completed and value was parsed
         assert task.state == TaskState.SUCCEEDED
-        assert isinstance(task.result, AgentRunResponse)
+        assert isinstance(task.result, AgentResponse)
         assert isinstance(task.result.value, TestSchema)
         assert task.result.value.answer == "42"

@@ -219,7 +217,7 @@ def test_fire_and_forget_calls_signal_entity(self, executor_with_uuid: tuple[Any
         thread = agent.get_new_thread()

         # Run with wait_for_response=False
-        result = agent.run("Test message", thread=thread, wait_for_response=False)
+        result = agent.run("Test message", thread=thread, options={"wait_for_response": False})

         # Verify signal_entity was called and call_entity was not
         assert context.signal_entity.call_count == 1
@@ -236,7 +234,7 @@ def test_fire_and_forget_returns_completed_task(self, executor_with_uuid: tuple[
         agent = DurableAIAgent(executor, "TestAgent")
         thread = agent.get_new_thread()

-        result = agent.run("Test message", thread=thread, wait_for_response=False)
+        result = agent.run("Test message", thread=thread, options={"wait_for_response": False})

         # Task should be immediately complete
         assert isinstance(result, AgentTask)
@@ -250,11 +248,11 @@ def test_fire_and_forget_returns_acceptance_response(self, executor_with_uuid: t
         agent = DurableAIAgent(executor, "TestAgent")
         thread = agent.get_new_thread()

-        result = agent.run("Test message", thread=thread, wait_for_response=False)
+        result = agent.run("Test message", thread=thread, options={"wait_for_response": False})

         # Get the result
         response = result.result
-        assert isinstance(response, AgentRunResponse)
+        assert isinstance(response, AgentResponse)
         assert len(response.messages) == 1
         assert response.messages[0].role == Role.SYSTEM
         # Check message contains key information
@@ -271,7 +269,7 @@ def test_blocking_mode_still_works(self, executor_with_uuid: tuple[Any, Mock, st
         agent = DurableAIAgent(executor, "TestAgent")
         thread = agent.get_new_thread()

-        result = agent.run("Test message", thread=thread, wait_for_response=True)
+        result = agent.run("Test message", thread=thread, options={"wait_for_response": True})

         # Verify call_entity was called and signal_entity was not
         assert context.call_entity.call_count == 1
diff --git a/python/packages/bedrock/agent_framework_bedrock/__init__.py b/python/packages/bedrock/agent_framework_bedrock/__init__.py
index 84f3e5946c..c33badcb35 100644
--- a/python/packages/bedrock/agent_framework_bedrock/__init__.py
+++ b/python/packages/bedrock/agent_framework_bedrock/__init__.py
@@ -2,7 +2,7 @@

 import importlib.metadata

-from ._chat_client import BedrockChatClient
+from ._chat_client import BedrockChatClient, BedrockChatOptions, BedrockGuardrailConfig, BedrockSettings

 try:
     __version__ = importlib.metadata.version(__name__)
@@ -11,5 +11,8 @@

 __all__ = [
     "BedrockChatClient",
+    "BedrockChatOptions",
+    "BedrockGuardrailConfig",
+    "BedrockSettings",
     "__version__",
 ]
diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py
index c1e404834f..e9e1eeff96 100644
--- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py
+++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py
@@ -2,9 +2,10 @@

 import asyncio
 import json
+import sys
 from collections import deque
 from collections.abc import AsyncIterable, MutableMapping, MutableSequence, Sequence
-from typing import Any, ClassVar
+from typing import Any, ClassVar, Generic, Literal, TypedDict
 from uuid import uuid4

 from agent_framework import (
@@ -28,6 +29,7 @@
     prepare_function_call_results,
     use_chat_middleware,
     use_function_invocation,
+    validate_tool_mode,
 )
 from agent_framework._pydantic import AFBaseSettings
 from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidResponseError
@@ -37,11 +39,151 @@
 from botocore.config import Config as BotoConfig
 from pydantic import SecretStr, ValidationError

+if sys.version_info >= (3, 13):
+    from typing import TypeVar
+else:
+    from typing_extensions import TypeVar
+
+if sys.version_info >= (3, 12):
+    from typing import override  # type: ignore # pragma: no cover
+else:
+    from typing_extensions import override  # type: ignore[import] # pragma: no cover
+
 logger = get_logger("agent_framework.bedrock")

+
+__all__ = [
+    "BedrockChatClient",
+    "BedrockChatOptions",
+    "BedrockGuardrailConfig",
+    "BedrockSettings",
+]
+
+
+# region Bedrock Chat Options TypedDict
+
+
 DEFAULT_REGION = "us-east-1"
 DEFAULT_MAX_TOKENS = 1024

+
+class BedrockGuardrailConfig(TypedDict, total=False):
+    """Amazon Bedrock Guardrails configuration.
+
+    See: https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails.html
+    """
+
+    guardrailIdentifier: str
+    """The identifier of the guardrail to apply."""
+
+    guardrailVersion: str
+    """The version of the guardrail to use."""
+
+    trace: Literal["enabled", "disabled"]
+    """Whether to include guardrail trace information in the response."""
+
+    streamProcessingMode: Literal["sync", "async"]
+    """How to process guardrails during streaming (sync blocks, async does not)."""
+
+
+class BedrockChatOptions(ChatOptions, total=False):
+    """Amazon Bedrock Converse API-specific chat options dict.
+
+    Extends base ChatOptions with Bedrock-specific parameters.
+    Bedrock uses a unified Converse API that works across multiple
+    foundation models (Claude, Titan, Llama, etc.).
+
+    See: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+
+    Keys:
+        # Inherited from ChatOptions (mapped to Bedrock):
+        model_id: The Bedrock model identifier,
+            translates to ``modelId`` in Bedrock API.
+        temperature: Sampling temperature,
+            translates to ``inferenceConfig.temperature``.
+        top_p: Nucleus sampling parameter,
+            translates to ``inferenceConfig.topP``.
+        max_tokens: Maximum number of tokens to generate,
+            translates to ``inferenceConfig.maxTokens``.
+        stop: Stop sequences,
+            translates to ``inferenceConfig.stopSequences``.
+        tools: List of tools available to the model,
+            translates to ``toolConfig.tools``.
+        tool_choice: How the model should use tools,
+            translates to ``toolConfig.toolChoice``.
+
+        # Options not supported in Bedrock Converse API:
+        seed: Not supported.
+        frequency_penalty: Not supported.
+        presence_penalty: Not supported.
+        allow_multiple_tool_calls: Not supported (models handle parallel calls automatically).
+        response_format: Not directly supported (use model-specific prompting).
+        user: Not supported.
+        store: Not supported.
+        logit_bias: Not supported.
+        metadata: Not supported (use additional_properties for additionalModelRequestFields).
+
+        # Bedrock-specific options:
+        guardrailConfig: Guardrails configuration for content filtering.
+        performanceConfig: Performance optimization settings.
+        requestMetadata: Key-value metadata for the request.
+        promptVariables: Variables for prompt management (if using managed prompts).
+    """
+
+    # Bedrock-specific options
+    guardrailConfig: BedrockGuardrailConfig
+    """Guardrails configuration for content filtering and safety."""
+
+    performanceConfig: dict[str, Any]
+    """Performance optimization settings (e.g., latency optimization).
+    See: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-performance.html"""
+
+    requestMetadata: dict[str, str]
+    """Key-value metadata for the request (max 2048 characters total)."""
+
+    promptVariables: dict[str, dict[str, str]]
+    """Variables for prompt management when using managed prompts."""
+
+    # ChatOptions fields not supported in Bedrock
+    seed: None  # type: ignore[misc]
+    """Not supported in Bedrock Converse API."""
+
+    frequency_penalty: None  # type: ignore[misc]
+    """Not supported in Bedrock Converse API."""
+
+    presence_penalty: None  # type: ignore[misc]
+    """Not supported in Bedrock Converse API."""
+
+    allow_multiple_tool_calls: None  # type: ignore[misc]
+    """Not supported. Bedrock models handle parallel tool calls automatically."""
+
+    response_format: None  # type: ignore[misc]
Use model-specific prompting for JSON output.""" + + user: None # type: ignore[misc] + """Not supported in Bedrock Converse API.""" + + store: None # type: ignore[misc] + """Not supported in Bedrock Converse API.""" + + logit_bias: None # type: ignore[misc] + """Not supported in Bedrock Converse API.""" + + +BEDROCK_OPTION_TRANSLATIONS: dict[str, str] = { + "model_id": "modelId", + "max_tokens": "maxTokens", + "top_p": "topP", + "stop": "stopSequences", +} +"""Maps ChatOptions keys to Bedrock Converse API parameter names.""" + +TBedrockChatOptions = TypeVar("TBedrockChatOptions", bound=TypedDict, default="BedrockChatOptions", covariant=True) # type: ignore[valid-type] + + +# endregion + + ROLE_MAP: dict[Role, str] = { Role.USER: "user", Role.ASSISTANT: "assistant", @@ -74,7 +216,7 @@ class BedrockSettings(AFBaseSettings): @use_function_invocation @use_instrumentation @use_chat_middleware -class BedrockChatClient(BaseChatClient): +class BedrockChatClient(BaseChatClient[TBedrockChatOptions], Generic[TBedrockChatOptions]): """Async chat client for Amazon Bedrock's Converse API.""" OTEL_PROVIDER_NAME: ClassVar[str] = "aws.bedrock" # type: ignore[reportIncompatibleVariableOverride, misc] @@ -106,6 +248,26 @@ def __init__( env_file_path: Optional .env file path used by ``BedrockSettings`` to load defaults. env_file_encoding: Encoding for the optional .env file. kwargs: Additional arguments forwarded to ``BaseChatClient``. + + Examples: + .. code-block:: python + + from agent_framework.bedrock import BedrockChatClient + + # Basic usage with default credentials + client = BedrockChatClient(model_id="") + + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework_bedrock import BedrockChatOptions + + + class MyOptions(BedrockChatOptions, total=False): + my_custom_option: str + + + client = BedrockChatClient[MyOptions](model_id="") + response = await client.get_response("Hello", options={"my_custom_option": "value"}) """ try: settings = BedrockSettings( @@ -143,25 +305,27 @@ def _create_session(settings: BedrockSettings) -> Boto3Session: session_kwargs["aws_session_token"] = settings.session_token.get_secret_value() return Boto3Session(**session_kwargs) + @override async def _inner_get_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> ChatResponse: - request = self._build_converse_request(messages, chat_options, **kwargs) + request = self._prepare_options(messages, options, **kwargs) raw_response = await asyncio.to_thread(self._bedrock_client.converse, **request) return self._process_converse_response(raw_response) + @override async def _inner_get_streaming_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: - response = await self._inner_get_response(messages=messages, chat_options=chat_options, **kwargs) + response = await self._inner_get_response(messages=messages, options=options, **kwargs) contents = list(response.messages[0].contents if response.messages else []) if response.usage_details: contents.append(UsageContent(details=response.usage_details)) @@ -173,13 +337,13 @@ async def _inner_get_streaming_response( raw_representation=response.raw_representation, ) - def _build_converse_request( + def _prepare_options( self, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> dict[str, Any]: - 
model_id = chat_options.model_id or self.model_id + model_id = options.get("model_id") or self.model_id if not model_id: raise ServiceInitializationError( "Bedrock model_id is required. Set via chat options or BEDROCK_CHAT_MODEL_ID environment variable." @@ -188,40 +352,42 @@ def _build_converse_request( system_prompts, conversation = self._prepare_bedrock_messages(messages) if not conversation: raise ServiceInitializationError("At least one non-system message is required for Bedrock requests.") + # Prepend instructions from options if they exist + if instructions := options.get("instructions"): + system_prompts = [{"text": instructions}, *system_prompts] - payload: dict[str, Any] = { + run_options: dict[str, Any] = { "modelId": model_id, "messages": conversation, + "inferenceConfig": {"maxTokens": options.get("max_tokens", DEFAULT_MAX_TOKENS)}, } if system_prompts: - payload["system"] = system_prompts - - inference_config: dict[str, Any] = {} - inference_config["maxTokens"] = ( - chat_options.max_tokens if chat_options.max_tokens is not None else DEFAULT_MAX_TOKENS - ) - if chat_options.temperature is not None: - inference_config["temperature"] = chat_options.temperature - if chat_options.top_p is not None: - inference_config["topP"] = chat_options.top_p - if chat_options.stop is not None: - inference_config["stopSequences"] = chat_options.stop - if inference_config: - payload["inferenceConfig"] = inference_config - - tool_config = self._convert_tools_to_bedrock_config(chat_options.tools) - if tool_choice := self._convert_tool_choice(chat_options.tool_choice): - if tool_config is None: - tool_config = {} - tool_config["toolChoice"] = tool_choice + run_options["system"] = system_prompts + + if (temperature := options.get("temperature")) is not None: + run_options["inferenceConfig"]["temperature"] = temperature + if (top_p := options.get("top_p")) is not None: + run_options["inferenceConfig"]["topP"] = top_p + if (stop := options.get("stop")) is not None: + run_options["inferenceConfig"]["stopSequences"] = stop + + tool_config = self._prepare_tools(options.get("tools")) + if tool_mode := validate_tool_mode(options.get("tool_choice")): + tool_config = tool_config or {} + match tool_mode.get("mode"): + case "auto" | "none": + tool_config["toolChoice"] = {tool_mode.get("mode"): {}} + case "required": + if required_name := tool_mode.get("required_function_name"): + tool_config["toolChoice"] = {"tool": {"name": required_name}} + else: + tool_config["toolChoice"] = {"any": {}} + case _: + raise ServiceInitializationError(f"Unsupported tool mode for Bedrock: {tool_mode.get('mode')}") if tool_config: - payload["toolConfig"] = tool_config + run_options["toolConfig"] = tool_config - if chat_options.additional_properties: - payload.update(chat_options.additional_properties) - if kwargs: - payload.update(kwargs) - return payload + return run_options def _prepare_bedrock_messages( self, messages: Sequence[ChatMessage] @@ -374,12 +540,10 @@ def _normalize_tool_result_value(self, value: Any) -> dict[str, Any]: return {"text": str(value)} return {"text": str(value)} - def _convert_tools_to_bedrock_config( - self, tools: list[ToolProtocol | MutableMapping[str, Any]] | None - ) -> dict[str, Any] | None: + def _prepare_tools(self, tools: list[ToolProtocol | MutableMapping[str, Any]] | None) -> dict[str, Any] | None: + converted: list[dict[str, Any]] = [] if not tools: return None - converted: list[dict[str, Any]] = [] for tool in tools: if isinstance(tool, MutableMapping): converted.append(dict(tool)) @@ 
-396,24 +560,6 @@ def _convert_tools_to_bedrock_config( logger.debug("Ignoring unsupported tool type for Bedrock: %s", type(tool)) return {"tools": converted} if converted else None - def _convert_tool_choice(self, tool_choice: Any) -> dict[str, Any] | None: - if not tool_choice: - return None - mode = tool_choice.mode if hasattr(tool_choice, "mode") else str(tool_choice) - required_name = getattr(tool_choice, "required_function_name", None) - match mode: - case "auto": - return {"auto": {}} - case "none": - return {"none": {}} - case "required": - if required_name: - return {"tool": {"name": required_name}} - return {"any": {}} - case _: - logger.debug("Unsupported tool choice mode for Bedrock: %s", mode) - return None - @staticmethod def _generate_tool_call_id() -> str: return f"tool-call-{uuid4().hex}" diff --git a/python/packages/bedrock/pyproject.toml b/python/packages/bedrock/pyproject.toml index 2e60e11288..d94035dd53 100644 --- a/python/packages/bedrock/pyproject.toml +++ b/python/packages/bedrock/pyproject.toml @@ -4,7 +4,7 @@ description = "Amazon Bedrock integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/bedrock/samples/bedrock_sample.py b/python/packages/bedrock/samples/bedrock_sample.py index 9e14b5a385..57901f49e8 100644 --- a/python/packages/bedrock/samples/bedrock_sample.py +++ b/python/packages/bedrock/samples/bedrock_sample.py @@ -5,13 +5,12 @@ from collections.abc import Sequence from agent_framework import ( - AgentRunResponse, + AgentResponse, ChatAgent, FunctionCallContent, FunctionResultContent, Role, TextContent, - ToolMode, ai_function, ) @@ -31,7 +30,7 @@ async def main() -> None: chat_client=BedrockChatClient(), instructions="You are a concise travel assistant.", name="BedrockWeatherAgent", - tool_choice=ToolMode.AUTO, + tool_choice="auto", tools=[get_weather], ) @@ -40,7 +39,7 @@ async def main() -> None: _log_response(response) -def _log_response(response: AgentRunResponse) -> None: +def _log_response(response: AgentResponse) -> None: logging.info("\nConversation transcript:") for idx, message in enumerate(response.messages, start=1): tag = f"{idx}. 
{message.role.value if isinstance(message.role, Role) else message.role}" diff --git a/python/packages/bedrock/tests/test_bedrock_client.py b/python/packages/bedrock/tests/test_bedrock_client.py index 4086dfa429..5842426483 100644 --- a/python/packages/bedrock/tests/test_bedrock_client.py +++ b/python/packages/bedrock/tests/test_bedrock_client.py @@ -6,7 +6,7 @@ from typing import Any import pytest -from agent_framework import ChatMessage, ChatOptions, Role, TextContent +from agent_framework import ChatMessage, Role, TextContent from agent_framework.exceptions import ServiceInitializationError from agent_framework_bedrock import BedrockChatClient @@ -46,7 +46,7 @@ def test_get_response_invokes_bedrock_runtime() -> None: ChatMessage(role=Role.USER, contents=[TextContent(text="hello")]), ] - response = asyncio.run(client.get_response(messages=messages, chat_options=ChatOptions(max_tokens=32))) + response = asyncio.run(client.get_response(messages=messages, options={"max_tokens": 32})) assert stub.calls, "Expected the runtime client to be called" payload = stub.calls[0] @@ -66,4 +66,4 @@ def test_build_request_requires_non_system_messages() -> None: messages = [ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="Only system text")])] with pytest.raises(ServiceInitializationError): - client._build_converse_request(messages, ChatOptions()) + client._prepare_options(messages, {}) diff --git a/python/packages/bedrock/tests/test_bedrock_settings.py b/python/packages/bedrock/tests/test_bedrock_settings.py index a3b0894d28..1924c750c6 100644 --- a/python/packages/bedrock/tests/test_bedrock_settings.py +++ b/python/packages/bedrock/tests/test_bedrock_settings.py @@ -13,7 +13,6 @@ FunctionResultContent, Role, TextContent, - ToolMode, ) from pydantic import BaseModel @@ -46,10 +45,13 @@ def test_build_request_includes_tool_config() -> None: client = _build_client() tool = AIFunction(name="get_weather", description="desc", func=_dummy_weather, input_model=_WeatherArgs) - options = ChatOptions(tools=[tool], tool_choice=ToolMode.REQUIRED("get_weather")) + options = { + "tools": [tool], + "tool_choice": {"mode": "required", "required_function_name": "get_weather"}, + } messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="hi")])] - request = client._build_converse_request(messages, options) + request = client._prepare_options(messages, options) assert request["toolConfig"]["tools"][0]["toolSpec"]["name"] == "get_weather" assert request["toolConfig"]["toolChoice"] == {"tool": {"name": "get_weather"}} @@ -57,7 +59,7 @@ def test_build_request_includes_tool_config() -> None: def test_build_request_serializes_tool_history() -> None: client = _build_client() - options = ChatOptions() + options: ChatOptions = {} messages = [ ChatMessage(role=Role.USER, contents=[TextContent(text="how's weather?")]), ChatMessage( @@ -70,7 +72,7 @@ def test_build_request_serializes_tool_history() -> None: ), ] - request = client._build_converse_request(messages, options) + request = client._prepare_options(messages, options) assistant_block = request["messages"][1]["content"][0]["toolUse"] result_block = request["messages"][2]["content"][0]["toolResult"] diff --git a/python/packages/chatkit/README.md b/python/packages/chatkit/README.md index afdb6f237f..cd4464d7de 100644 --- a/python/packages/chatkit/README.md +++ b/python/packages/chatkit/README.md @@ -4,7 +4,7 @@ This package provides an integration layer between Microsoft Agent Framework and [OpenAI ChatKit 
(Python)](https://github.com/openai/chatkit-python/). Specifically, it mirrors the [Agent SDK integration](https://github.com/openai/chatkit-python/blob/main/docs/server.md#agents-sdk-integration), and provides the following helpers: -- `stream_agent_response`: A helper to convert a streamed `AgentRunResponseUpdate` +- `stream_agent_response`: A helper to convert a streamed `AgentResponseUpdate` from a Microsoft Agent Framework agent that implements `AgentProtocol` to ChatKit events. - `ThreadItemConverter`: An extendable helper class to convert ChatKit thread items to `ChatMessage` objects that can be consumed by an Agent Framework agent. diff --git a/python/packages/chatkit/agent_framework_chatkit/_converter.py b/python/packages/chatkit/agent_framework_chatkit/_converter.py index 1070d83926..252ac8a753 100644 --- a/python/packages/chatkit/agent_framework_chatkit/_converter.py +++ b/python/packages/chatkit/agent_framework_chatkit/_converter.py @@ -6,11 +6,6 @@ import sys from collections.abc import Awaitable, Callable, Sequence -if sys.version_info >= (3, 11): - from typing import assert_never -else: - from typing_extensions import assert_never - from agent_framework import ( ChatMessage, DataContent, @@ -38,6 +33,11 @@ WorkflowItem, ) +if sys.version_info >= (3, 11): + from typing import assert_never +else: + from typing_extensions import assert_never + logger = logging.getLogger(__name__) diff --git a/python/packages/chatkit/agent_framework_chatkit/_streaming.py b/python/packages/chatkit/agent_framework_chatkit/_streaming.py index daeaa0b4ab..b0273c5944 100644 --- a/python/packages/chatkit/agent_framework_chatkit/_streaming.py +++ b/python/packages/chatkit/agent_framework_chatkit/_streaming.py @@ -6,7 +6,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Callable from datetime import datetime -from agent_framework import AgentRunResponseUpdate, TextContent +from agent_framework import AgentResponseUpdate, TextContent from chatkit.types import ( AssistantMessageContent, AssistantMessageContentPartTextDelta, @@ -19,13 +19,13 @@ async def stream_agent_response( - response_stream: AsyncIterable[AgentRunResponseUpdate], + response_stream: AsyncIterable[AgentResponseUpdate], thread_id: str, generate_id: Callable[[str], str] | None = None, ) -> AsyncIterator[ThreadStreamEvent]: - """Convert a streamed AgentRunResponseUpdate from Agent Framework to ChatKit events. + """Convert a streamed AgentResponseUpdate from Agent Framework to ChatKit events. - This helper function takes a stream of AgentRunResponseUpdate objects from + This helper function takes a stream of AgentResponseUpdate objects from a Microsoft Agent Framework agent and converts them to ChatKit ThreadStreamEvent objects that can be consumed by the ChatKit UI. @@ -34,7 +34,7 @@ async def stream_agent_response( text chunk as it arrives from the agent. Args: - response_stream: An async iterable of AgentRunResponseUpdate objects + response_stream: An async iterable of AgentResponseUpdate objects from an Agent Framework agent. thread_id: The ChatKit thread ID for the conversation. generate_id: Optional function to generate IDs for ChatKit items. diff --git a/python/packages/chatkit/pyproject.toml b/python/packages/chatkit/pyproject.toml index 3fa92669f2..8621411503 100644 --- a/python/packages/chatkit/pyproject.toml +++ b/python/packages/chatkit/pyproject.toml @@ -4,7 +4,7 @@ description = "OpenAI ChatKit integration for Microsoft Agent Framework."
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/chatkit/tests/test_streaming.py b/python/packages/chatkit/tests/test_streaming.py index 2e5041613a..ead7c5f33e 100644 --- a/python/packages/chatkit/tests/test_streaming.py +++ b/python/packages/chatkit/tests/test_streaming.py @@ -4,7 +4,7 @@ from unittest.mock import Mock -from agent_framework import AgentRunResponseUpdate, Role, TextContent +from agent_framework import AgentResponseUpdate, Role, TextContent from chatkit.types import ( ThreadItemAddedEvent, ThreadItemDoneEvent, @@ -34,7 +34,7 @@ async def test_stream_single_text_update(self): """Test streaming single text update.""" async def single_update_stream(): - yield AgentRunResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="Hello world")]) + yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="Hello world")]) events = [] async for event in stream_agent_response(single_update_stream(), thread_id="test_thread"): @@ -59,8 +59,8 @@ async def test_stream_multiple_text_updates(self): """Test streaming multiple text updates.""" async def multiple_updates_stream(): - yield AgentRunResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="Hello ")]) - yield AgentRunResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="world!")]) + yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="Hello ")]) + yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="world!")]) events = [] async for event in stream_agent_response(multiple_updates_stream(), thread_id="test_thread"): @@ -91,7 +91,7 @@ def custom_id_generator(item_type: str) -> str: return f"custom_{item_type}_123" async def single_update_stream(): - yield AgentRunResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="Test")]) + yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="Test")]) events = [] async for event in stream_agent_response( @@ -107,8 +107,8 @@ async def test_stream_empty_content_updates(self): """Test streaming updates with empty content.""" async def empty_content_stream(): - yield AgentRunResponseUpdate(role=Role.ASSISTANT, contents=[]) - yield AgentRunResponseUpdate(role=Role.ASSISTANT, contents=None) + yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[]) + yield AgentResponseUpdate(role=Role.ASSISTANT, contents=None) events = [] async for event in stream_agent_response(empty_content_stream(), thread_id="test_thread"): @@ -130,7 +130,7 @@ async def test_stream_non_text_content(self): del non_text_content.text async def non_text_stream(): - yield AgentRunResponseUpdate(role=Role.ASSISTANT, contents=[non_text_content]) + yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[non_text_content]) events = [] async for event in stream_agent_response(non_text_stream(), thread_id="test_thread"): diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py index 35abd6dec7..606e1e83b6 100644 --- a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py +++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py @@ -4,11 +4,10 @@ from typing import Any, ClassVar from agent_framework 
import ( - AgentMiddlewares, - AgentRunResponse, - AgentRunResponseUpdate, + AgentMiddlewareTypes, + AgentResponse, + AgentResponseUpdate, AgentThread, - AggregateContextProvider, BaseAgent, ChatMessage, ContextProvider, @@ -79,8 +78,8 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None, - middleware: AgentMiddlewares | list[AgentMiddlewares] | None = None, + context_provider: ContextProvider | None = None, + middleware: list[AgentMiddlewareTypes] | None = None, environment_id: str | None = None, agent_identifier: str | None = None, client_id: str | None = None, @@ -107,8 +106,8 @@ def __init__( id: id of the CopilotAgent name: Name of the CopilotAgent description: Description of the CopilotAgent - context_providers: Context Providers, to be used by the copilot agent. - middleware: Agent middlewares used by the agent. + context_provider: Context provider to be used by the copilot agent. + middleware: Agent middleware used by the agent; should be a list of AgentMiddlewareTypes. environment_id: Environment ID of the Power Platform environment containing the Copilot Studio app. Can also be set via COPILOTSTUDIOAGENT__ENVIRONMENTID environment variable. @@ -138,7 +137,7 @@ def __init__( id=id, name=name, description=description, - context_providers=context_providers, + context_provider=context_provider, middleware=middleware, ) if not client: @@ -211,15 +210,15 @@ async def run( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: """Get a response from the agent. This method returns the final result of the agent's execution - as a single AgentRunResponse object. The caller is blocked until + as a single AgentResponse object. The caller is blocked until the final result is available. Note: For streaming responses, use the run_stream method, which returns - intermediate steps and the final result as a stream of AgentRunResponseUpdate + intermediate steps and the final result as a stream of AgentResponseUpdate objects. Streaming only the final result is not feasible because the timing of the final result's availability is unknown, and blocking the caller until then is undesirable in streaming scenarios. @@ -249,7 +248,7 @@ async def run( response_messages = [message async for message in self._process_activities(activities, streaming=False)] response_id = response_messages[0].message_id if response_messages else None - return AgentRunResponse(messages=response_messages, response_id=response_id) + return AgentResponse(messages=response_messages, response_id=response_id) async def run_stream( self, @@ -257,13 +256,13 @@ async def run_stream( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: """Run the agent as a stream. This method will return the intermediate steps and final results of the - agent's execution as a stream of AgentRunResponseUpdate objects to the caller. + agent's execution as a stream of AgentResponseUpdate objects to the caller. - Note: An AgentRunResponseUpdate object contains a chunk of a message. + Note: An AgentResponseUpdate object contains a chunk of a message. Args: messages: The message(s) to send to the agent.
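The hunks above and below apply the PR's central rename, `AgentRunResponse` → `AgentResponse` and `AgentRunResponseUpdate` → `AgentResponseUpdate`, to the Copilot Studio agent. A minimal sketch of consuming the renamed surface (the `agent` construction is elided, and the prompt strings are placeholders):

```python
from agent_framework import AgentResponse, AgentResponseUpdate


async def consume(agent) -> None:
    # Non-streaming: a single AgentResponse once the run completes.
    response: AgentResponse = await agent.run("Hello")
    for message in response.messages:
        print(message.role, message.contents)

    # Streaming: AgentResponseUpdate chunks as they are produced.
    async for update in agent.run_stream("Hello"):
        assert isinstance(update, AgentResponseUpdate)
        print(update.contents)
```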
@@ -286,7 +285,7 @@ async def run_stream( activities = self.client.ask_question(question, thread.service_thread_id) async for message in self._process_activities(activities, streaming=True): - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( role=message.role, contents=message.contents, author_name=message.author_name, diff --git a/python/packages/copilotstudio/pyproject.toml b/python/packages/copilotstudio/pyproject.toml index 0ead8b437d..83e1202016 100644 --- a/python/packages/copilotstudio/pyproject.toml +++ b/python/packages/copilotstudio/pyproject.toml @@ -4,7 +4,7 @@ description = "Copilot Studio integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/copilotstudio/tests/test_copilot_agent.py b/python/packages/copilotstudio/tests/test_copilot_agent.py index 2a2e36263a..4777557d32 100644 --- a/python/packages/copilotstudio/tests/test_copilot_agent.py +++ b/python/packages/copilotstudio/tests/test_copilot_agent.py @@ -5,8 +5,8 @@ import pytest from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, ChatMessage, Role, @@ -133,7 +133,7 @@ async def test_run_with_string_message(self, mock_copilot_client: MagicMock, moc response = await agent.run("test message") - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert len(response.messages) == 1 content = response.messages[0].contents[0] assert isinstance(content, TextContent) @@ -153,7 +153,7 @@ async def test_run_with_chat_message(self, mock_copilot_client: MagicMock, mock_ chat_message = ChatMessage(role=Role.USER, contents=[TextContent("test message")]) response = await agent.run(chat_message) - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert len(response.messages) == 1 content = response.messages[0].contents[0] assert isinstance(content, TextContent) @@ -173,7 +173,7 @@ async def test_run_with_thread(self, mock_copilot_client: MagicMock, mock_activi response = await agent.run("test message", thread=thread) - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert len(response.messages) == 1 assert thread.service_thread_id == "test-conversation-id" @@ -204,7 +204,7 @@ async def test_run_stream_with_string_message(self, mock_copilot_client: MagicMo response_count = 0 async for response in agent.run_stream("test message"): - assert isinstance(response, AgentRunResponseUpdate) + assert isinstance(response, AgentResponseUpdate) content = response.contents[0] assert isinstance(content, TextContent) assert content.text == "Streaming response" @@ -231,7 +231,7 @@ async def test_run_stream_with_thread(self, mock_copilot_client: MagicMock) -> N response_count = 0 async for response in agent.run_stream("test message", thread=thread): - assert isinstance(response, AgentRunResponseUpdate) + assert isinstance(response, AgentResponseUpdate) content = response.contents[0] assert isinstance(content, TextContent) assert content.text == "Streaming response" @@ -285,7 +285,7 @@ async def test_run_multiple_activities(self, mock_copilot_client: MagicMock) -> response = await agent.run("test message") - 
assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert len(response.messages) == 2 async def test_run_list_of_messages(self, mock_copilot_client: MagicMock, mock_activity: MagicMock) -> None: @@ -301,7 +301,7 @@ async def test_run_list_of_messages(self, mock_copilot_client: MagicMock, mock_a messages = ["Hello", "How are you?"] response = await agent.run(messages) - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert len(response.messages) == 1 async def test_run_stream_start_conversation_failure(self, mock_copilot_client: MagicMock) -> None: diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index aadd1be40a..628ac7fb17 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -7,7 +7,16 @@ from contextlib import AbstractAsyncContextManager, AsyncExitStack from copy import deepcopy from itertools import chain -from typing import Any, ClassVar, Literal, Protocol, TypeVar, cast, runtime_checkable +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Protocol, + TypedDict, + cast, + runtime_checkable, +) from uuid import uuid4 from mcp import types @@ -18,36 +27,82 @@ from ._clients import BaseChatClient, ChatClientProtocol from ._logging import get_logger from ._mcp import LOG_LEVEL_MAPPING, MCPTool -from ._memory import AggregateContextProvider, Context, ContextProvider +from ._memory import Context, ContextProvider from ._middleware import Middleware, use_agent_middleware from ._serialization import SerializationMixin from ._threads import AgentThread, ChatMessageStoreProtocol from ._tools import FUNCTION_INVOKING_CHAT_CLIENT_MARKER, AIFunction, ToolProtocol from ._types import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, ChatMessage, - ChatOptions, ChatResponse, ChatResponseUpdate, Role, - ToolMode, ) from .exceptions import AgentExecutionException, AgentInitializationError from .observability import use_agent_instrumentation +if TYPE_CHECKING: + from ._types import ChatOptions + + +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: from typing_extensions import override # type: ignore[import] # pragma: no cover + if sys.version_info >= (3, 11): from typing import Self # pragma: no cover else: from typing_extensions import Self # pragma: no cover + logger = get_logger("agent_framework") TThreadType = TypeVar("TThreadType", bound="AgentThread") +TOptions_co = TypeVar( + "TOptions_co", + bound=TypedDict, # type: ignore[valid-type] + default="ChatOptions", + covariant=True, +) + + +def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: + """Merge two options dicts, with override values taking precedence. + + Args: + base: The base options dict. + override: The override options dict (values take precedence). + + Returns: + A new merged options dict. 
+ """ + result = dict(base) + for key, value in override.items(): + if value is None: + continue + if key == "tools" and result.get("tools"): + # Combine tool lists + result["tools"] = list(result["tools"]) + list(value) + elif key == "logit_bias" and result.get("logit_bias"): + # Merge logit_bias dicts + result["logit_bias"] = {**result["logit_bias"], **value} + elif key == "metadata" and result.get("metadata"): + # Merge metadata dicts + result["metadata"] = {**result["metadata"], **value} + elif key == "instructions" and result.get("instructions"): + # Concatenate instructions + result["instructions"] = f"{result['instructions']}\n{value}" + else: + result[key] = value + return result def _sanitize_agent_name(agent_name: str | None) -> str | None: @@ -116,37 +171,22 @@ class AgentProtocol(Protocol): # No need to inherit from AgentProtocol or use any framework classes class CustomAgent: def __init__(self): - self._id = "custom-agent-001" - self._name = "Custom Agent" - - @property - def id(self) -> str: - return self._id - - @property - def name(self) -> str | None: - return self._name - - @property - def display_name(self) -> str: - return self.name or self.id - - @property - def description(self) -> str | None: - return "A fully custom agent implementation" + self.id = "custom-agent-001" + self.name = "Custom Agent" + self.description = "A fully custom agent implementation" async def run(self, messages=None, *, thread=None, **kwargs): # Your custom implementation - from agent_framework import AgentRunResponse + from agent_framework import AgentResponse - return AgentRunResponse(messages=[], response_id="custom-response") + return AgentResponse(messages=[], response_id="custom-response") def run_stream(self, messages=None, *, thread=None, **kwargs): # Your custom streaming implementation async def _stream(): - from agent_framework import AgentRunResponseUpdate + from agent_framework import AgentResponseUpdate - yield AgentRunResponseUpdate() + yield AgentResponseUpdate() return _stream() @@ -160,41 +200,25 @@ def get_new_thread(self, **kwargs): assert isinstance(instance, AgentProtocol) """ - @property - def id(self) -> str: - """Returns the ID of the agent.""" - ... - - @property - def name(self) -> str | None: - """Returns the name of the agent.""" - ... - - @property - def display_name(self) -> str: - """Returns the display name of the agent.""" - ... - - @property - def description(self) -> str | None: - """Returns the description of the agent.""" - ... + id: str + name: str | None + description: str | None async def run( self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: """Get a response from the agent. This method returns the final result of the agent's execution - as a single AgentRunResponse object. The caller is blocked until + as a single AgentResponse object. The caller is blocked until the final result is available. Note: For streaming responses, use the run_stream method, which returns - intermediate steps and the final result as a stream of AgentRunResponseUpdate + intermediate steps and the final result as a stream of AgentResponseUpdate objects. Streaming only the final result is not feasible because the timing of the final result's availability is unknown, and blocking the caller until then is undesirable in streaming scenarios. 
@@ -213,17 +237,17 @@ async def run( def run_stream( self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: """Run the agent as a stream. This method will return the intermediate steps and final results of the - agent's execution as a stream of AgentRunResponseUpdate objects to the caller. + agent's execution as a stream of AgentResponseUpdate objects to the caller. - Note: An AgentRunResponseUpdate object contains a chunk of a message. + Note: An AgentResponseUpdate object contains a chunk of a message. Args: messages: The message(s) to send to the agent. @@ -259,19 +283,19 @@ class BaseAgent(SerializationMixin): Examples: .. code-block:: python - from agent_framework import BaseAgent, AgentThread, AgentRunResponse + from agent_framework import BaseAgent, AgentThread, AgentResponse, AgentResponseUpdate # Create a concrete subclass that implements the protocol class SimpleAgent(BaseAgent): async def run(self, messages=None, *, thread=None, **kwargs): # Custom implementation - return AgentRunResponse(messages=[], response_id="simple-response") + return AgentResponse(messages=[], response_id="simple-response") def run_stream(self, messages=None, *, thread=None, **kwargs): async def _stream(): # Custom streaming implementation - yield AgentRunResponseUpdate() + yield AgentResponseUpdate() return _stream() @@ -289,7 +313,6 @@ async def _stream(): # Access agent properties print(agent.id) # Custom or auto-generated UUID - print(agent.display_name) # Returns name or id """ DEFAULT_EXCLUDE: ClassVar[set[str]] = {"additional_properties"} @@ -300,8 +323,8 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_providers: ContextProvider | Sequence[ContextProvider] | None = None, - middleware: Middleware | Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + middleware: Sequence[Middleware] | None = None, additional_properties: MutableMapping[str, Any] | None = None, **kwargs: Any, ) -> None: @@ -312,8 +335,8 @@ def __init__( a new UUID will be generated. name: The name of the agent, can be None. description: The description of the agent. - context_providers: The collection of multiple context providers to include during agent invocation. - middleware: List of middleware to intercept agent and function invocations. + context_provider: The context provider to include during agent invocation. + middleware: List of middleware. additional_properties: Additional properties set on the agent. kwargs: Additional keyword arguments (merged into additional_properties).
""" @@ -322,11 +345,10 @@ def __init__( self.id = id self.name = name self.description = description - self.context_provider = self._prepare_context_providers(context_providers) - if middleware is None or isinstance(middleware, Sequence): - self.middleware: list[Middleware] | None = cast(list[Middleware], middleware) if middleware else None - else: - self.middleware = [middleware] + self.context_provider = context_provider + self.middleware: list[Middleware] | None = ( + cast(list[Middleware], middleware) if middleware is not None else None + ) # Merge kwargs into additional_properties self.additional_properties: dict[str, Any] = cast(dict[str, Any], additional_properties or {}) @@ -356,14 +378,6 @@ async def _notify_thread_of_new_messages( if thread.context_provider: await thread.context_provider.invoked(input_messages, response_messages, **kwargs) - @property - def display_name(self) -> str: - """Returns the display name of the agent. - - This is the name if present, otherwise the id. - """ - return self.name or self.id - def get_new_thread(self, **kwargs: Any) -> AgentThread: """Return a new AgentThread instance that is compatible with the agent. @@ -398,8 +412,8 @@ def as_tool( description: str | None = None, arg_name: str = "task", arg_description: str | None = None, - stream_callback: Callable[[AgentRunResponseUpdate], None] - | Callable[[AgentRunResponseUpdate], Awaitable[None]] + stream_callback: Callable[[AgentResponseUpdate], None] + | Callable[[AgentResponseUpdate], Awaitable[None]] | None = None, ) -> AIFunction[BaseModel, str]: """Create an AIFunction tool that wraps this agent. @@ -464,7 +478,7 @@ async def agent_wrapper(**kwargs: Any) -> str: return (await self.run(input_text, **forwarded_kwargs)).text # Use streaming mode - accumulate updates and create final response - response_updates: list[AgentRunResponseUpdate] = [] + response_updates: list[AgentResponseUpdate] = [] async for update in self.run_stream(input_text, **forwarded_kwargs): response_updates.append(update) if is_async_callback: @@ -473,7 +487,7 @@ async def agent_wrapper(**kwargs: Any) -> str: stream_callback(update) # Create final text from accumulated updates - return AgentRunResponse.from_agent_run_response_updates(response_updates).text + return AgentResponse.from_agent_run_response_updates(response_updates).text agent_tool: AIFunction[BaseModel, str] = AIFunction( name=tool_name, @@ -486,7 +500,7 @@ async def agent_wrapper(**kwargs: Any) -> str: def _normalize_messages( self, - messages: str | ChatMessage | Sequence[str] | Sequence[ChatMessage] | None = None, + messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, ) -> list[ChatMessage]: if messages is None: return [] @@ -499,38 +513,29 @@ def _normalize_messages( return [ChatMessage(role=Role.USER, text=msg) if isinstance(msg, str) else msg for msg in messages] - def _prepare_context_providers( - self, - context_providers: ContextProvider | Sequence[ContextProvider] | None = None, - ) -> AggregateContextProvider | None: - if not context_providers: - return None - - if isinstance(context_providers, AggregateContextProvider): - return context_providers - - return AggregateContextProvider(context_providers) - # region ChatAgent @use_agent_middleware @use_agent_instrumentation(capture_usage=False) # type: ignore[arg-type,misc] -class ChatAgent(BaseAgent): # type: ignore[misc] +class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc] """A Chat Client Agent. 
This is the primary agent implementation that uses a chat client to interact with language models. It supports tools, context providers, middleware, and both streaming and non-streaming responses. + The generic type parameter TOptions specifies which options TypedDict this agent + accepts. This enables IDE autocomplete and type checking for provider-specific options. + Examples: Basic usage: .. code-block:: python from agent_framework import ChatAgent - from agent_framework.clients import OpenAIChatClient + from agent_framework.openai import OpenAIChatClient # Create a basic chat agent client = OpenAIChatClient(model_id="gpt-4") @@ -562,72 +567,55 @@ def get_weather(location: str) -> str: async for update in agent.run_stream("What's the weather in Paris?"): print(update.text, end="") - With additional provider specific options: + With typed options for IDE autocomplete: .. code-block:: python - agent = ChatAgent( + from agent_framework import ChatAgent + from agent_framework.openai import OpenAIChatClient, OpenAIChatOptions + + client = OpenAIChatClient(model_id="gpt-4o") + agent: ChatAgent[OpenAIChatOptions] = ChatAgent( chat_client=client, name="reasoning-agent", instructions="You are a reasoning assistant.", - model_id="gpt-5", - temperature=0.7, - max_tokens=500, - additional_chat_options={ - "reasoning": {"effort": "high", "summary": "concise"} - }, # OpenAI Responses specific. + default_options={ + "temperature": 0.7, + "max_tokens": 500, + "reasoning_effort": "high", # OpenAI-specific, IDE will autocomplete! + }, ) - # Use streaming responses - async for update in agent.run_stream("How do you prove the pythagorean theorem?"): - print(update.text, end="") + # Or pass options at runtime + response = await agent.run( + "What is 25 * 47?", + options={"temperature": 0.0, "logprobs": True}, + ) """ AGENT_PROVIDER_NAME: ClassVar[str] = "microsoft.agent_framework" def __init__( self, - chat_client: ChatClientProtocol, + chat_client: ChatClientProtocol[TOptions_co], instructions: str | None = None, *, id: str | None = None, name: str | None = None, description: str | None = None, - chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, - context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None, - middleware: Middleware | list[Middleware] | None = None, - # chat options - allow_multiple_tool_calls: bool | None = None, - conversation_id: str | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str | int, float] | None = None, - max_tokens: int | None = None, - metadata: dict[str, Any] | None = None, - model_id: str | None = None, - presence_penalty: float | None = None, - response_format: type[BaseModel] | None = None, - seed: int | None = None, - stop: str | Sequence[str] | None = None, - store: bool | None = None, - temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", tools: ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] | None = None, - top_p: float | None = None, - user: str | None = None, - additional_chat_options: dict[str, Any] | None = None, + default_options: TOptions_co | None = None, + chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, + context_provider: ContextProvider | None = None, + middleware: Sequence[Middleware] | None = None, **kwargs: Any, ) -> None: """Initialize a ChatAgent instance.
- Note: - The set of parameters from frequency_penalty to request_kwargs are used to - call the chat client. They can also be passed to both run methods. - When both are set, the ones passed to the run methods take precedence. - Args: chat_client: The chat client to use for the agent. instructions: Optional instructions for the agent. @@ -639,35 +627,24 @@ description: A brief description of the agent's purpose. chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol. If not provided, the default in-memory store will be used. - context_providers: The collection of multiple context providers to include during agent invocation. + context_provider: The context provider to include during agent invocation. middleware: List of middleware to intercept agent and function invocations. - allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response. - conversation_id: The conversation ID for service-managed threads. - Cannot be used together with chat_message_store_factory. - frequency_penalty: The frequency penalty to use. - logit_bias: The logit bias to use. - max_tokens: The maximum number of tokens to generate. - metadata: Additional metadata to include in the request. - model_id: The model_id to use for the agent. - This overrides the model_id set in the chat client if it contains one. - presence_penalty: The presence penalty to use. - response_format: The format of the response. - seed: The random seed to use. - stop: The stop sequence(s) for the request. - store: Whether to store the response. - temperature: The sampling temperature to use. - tool_choice: The tool choice for the request. + default_options: A TypedDict containing chat options. When using a typed agent like + ``ChatAgent[OpenAIChatOptions]``, this enables IDE autocomplete for + standard options such as temperature, max_tokens, model_id, and + tool_choice, as well as provider-specific options like reasoning_effort. + You can also create your own TypedDict for custom chat clients. + These can be overridden at runtime via the ``options`` parameter of ``run()`` and ``run_stream()``. tools: The tools to use for the request. - top_p: The nucleus sampling probability to use. - user: The user to associate with the request. - additional_chat_options: A dictionary of other values that will be passed through - to the chat_client ``get_response`` and ``get_streaming_response`` methods. - This can be used to pass provider specific parameters. kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``. Raises: AgentInitializationError: If both conversation_id and chat_message_store_factory are provided. """ + # Extract conversation_id from options for validation + opts = dict(default_options) if default_options else {} + conversation_id = opts.get("conversation_id") + if conversation_id is not None and chat_message_store_factory is not None: raise AgentInitializationError( "Cannot specify both conversation_id and chat_message_store_factory. 
" @@ -683,41 +660,51 @@ def __init__( id=id, name=name, description=description, - context_providers=context_providers, + context_provider=context_provider, middleware=middleware, **kwargs, ) - self.chat_client = chat_client + self.chat_client: ChatClientProtocol[TOptions_co] = chat_client self.chat_message_store_factory = chat_message_store_factory + # Get tools from options or named parameter (named param takes precedence) + tools_ = tools if tools is not None else opts.pop("tools", None) + + # Handle instructions - named parameter takes precedence over options + instructions_ = instructions if instructions is not None else opts.pop("instructions", None) + # We ignore the MCP Servers here and store them separately, # we add their functions to the tools list at runtime normalized_tools: list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] = ( # type:ignore[reportUnknownVariableType] - [] if tools is None else tools if isinstance(tools, list) else [tools] # type: ignore[list-item] + [] if tools_ is None else tools_ if isinstance(tools_, list) else [tools_] # type: ignore[list-item] ) - self._local_mcp_tools = [tool for tool in normalized_tools if isinstance(tool, MCPTool)] + self.mcp_tools: list[MCPTool] = [tool for tool in normalized_tools if isinstance(tool, MCPTool)] agent_tools = [tool for tool in normalized_tools if not isinstance(tool, MCPTool)] - self.chat_options = ChatOptions( - model_id=model_id or (str(chat_client.model_id) if hasattr(chat_client, "model_id") else None), - allow_multiple_tool_calls=allow_multiple_tool_calls, - conversation_id=conversation_id, - frequency_penalty=frequency_penalty, - instructions=instructions, - logit_bias=logit_bias, - max_tokens=max_tokens, - metadata=metadata, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - store=store, - temperature=temperature, - tool_choice=tool_choice, - tools=agent_tools, - top_p=top_p, - user=user, - additional_properties=additional_chat_options or {}, # type: ignore - ) + + # Build chat options dict + self.default_options: dict[str, Any] = { + "model_id": opts.pop("model_id", None) or (getattr(self.chat_client, "model_id", None)), + "allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None), + "conversation_id": conversation_id, + "frequency_penalty": opts.pop("frequency_penalty", None), + "instructions": instructions_, + "logit_bias": opts.pop("logit_bias", None), + "max_tokens": opts.pop("max_tokens", None), + "metadata": opts.pop("metadata", None), + "presence_penalty": opts.pop("presence_penalty", None), + "response_format": opts.pop("response_format", None), + "seed": opts.pop("seed", None), + "stop": opts.pop("stop", None), + "store": opts.pop("store", None), + "temperature": opts.pop("temperature", None), + "tool_choice": opts.pop("tool_choice", "auto"), + "tools": agent_tools, + "top_p": opts.pop("top_p", None), + "user": opts.pop("user", None), + **opts, # Remaining options are provider-specific + } + # Remove None values from chat_options + self.default_options = {k: v for k, v in self.default_options.items() if v is not None} self._async_exit_stack = AsyncExitStack() self._update_agent_name_and_description() @@ -733,7 +720,7 @@ async def __aenter__(self) -> "Self": Returns: The ChatAgent instance. 
""" - for context_manager in chain([self.chat_client], self._local_mcp_tools): + for context_manager in chain([self.chat_client], self.mcp_tools): if isinstance(context_manager, AbstractAsyncContextManager): await self._async_exit_stack.enter_async_context(context_manager) return self @@ -769,32 +756,17 @@ def _update_agent_name_and_description(self) -> None: async def run( self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, *, thread: AgentThread | None = None, - allow_multiple_tool_calls: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str | int, float] | None = None, - max_tokens: int | None = None, - metadata: dict[str, Any] | None = None, - model_id: str | None = None, - presence_penalty: float | None = None, - response_format: type[BaseModel] | None = None, - seed: int | None = None, - stop: str | Sequence[str] | None = None, - store: bool | None = None, - temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, tools: ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] | None = None, - top_p: float | None = None, - user: str | None = None, - additional_chat_options: dict[str, Any] | None = None, + options: TOptions_co | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: """Run the agent with the given messages and options. Note: @@ -808,36 +780,29 @@ async def run( Keyword Args: thread: The thread to use for the agent. - allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response. - frequency_penalty: The frequency penalty to use. - logit_bias: The logit bias to use. - max_tokens: The maximum number of tokens to generate. - metadata: Additional metadata to include in the request. - model_id: The model_id to use for the agent. - presence_penalty: The presence penalty to use. - response_format: The format of the response. - seed: The random seed to use. - stop: The stop sequence(s) for the request. - store: Whether to store the response. - temperature: The sampling temperature to use. - tool_choice: The tool choice for the request. - tools: The tools to use for the request. - top_p: The nucleus sampling probability to use. - user: The user to associate with the request. - additional_chat_options: Additional properties to include in the request. - Use this field for provider-specific parameters. + tools: The tools to use for this specific run (merged with default tools). + options: A TypedDict containing chat options. When using a typed agent like + ``ChatAgent[OpenAIChatOptions]``, this enables IDE autocomplete for + provider-specific options including temperature, max_tokens, model_id, + tool_choice, and provider-specific options like reasoning_effort. kwargs: Additional keyword arguments for the agent. Will only be passed to functions that are called. Returns: - An AgentRunResponse containing the agent's response. + An AgentResponse containing the agent's response. 
""" + # Build options dict from provided options + opts = dict(options) if options else {} + + # Get tools from options or named parameter (named param takes precedence) + tools_ = tools if tools is not None else opts.pop("tools", None) + input_messages = self._normalize_messages(messages) thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages( - thread=thread, input_messages=input_messages + thread=thread, input_messages=input_messages, **kwargs ) normalized_tools: list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] = ( # type:ignore[reportUnknownVariableType] - [] if tools is None else tools if isinstance(tools, list) else [tools] + [] if tools_ is None else tools_ if isinstance(tools_, list) else [tools_] ) agent_name = self._get_agent_name() @@ -852,32 +817,35 @@ async def run( else: final_tools.append(tool) # type: ignore - for mcp_server in self._local_mcp_tools: + for mcp_server in self.mcp_tools: if not mcp_server.is_connected: await self._async_exit_stack.enter_async_context(mcp_server) final_tools.extend(mcp_server.functions) - merged_additional_options = additional_chat_options or {} - co = run_chat_options & ChatOptions( - model_id=model_id, - conversation_id=thread.service_thread_id, - allow_multiple_tool_calls=allow_multiple_tool_calls, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - max_tokens=max_tokens, - metadata=metadata, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - store=store, - temperature=temperature, - tool_choice=tool_choice, - tools=final_tools, - top_p=top_p, - user=user, - additional_properties=merged_additional_options, # type: ignore[arg-type] - ) + # Build options dict from run() options merged with provided options + run_opts: dict[str, Any] = { + "model_id": opts.pop("model_id", None), + "conversation_id": thread.service_thread_id, + "allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None), + "frequency_penalty": opts.pop("frequency_penalty", None), + "logit_bias": opts.pop("logit_bias", None), + "max_tokens": opts.pop("max_tokens", None), + "metadata": opts.pop("metadata", None), + "presence_penalty": opts.pop("presence_penalty", None), + "response_format": opts.pop("response_format", None), + "seed": opts.pop("seed", None), + "stop": opts.pop("stop", None), + "store": opts.pop("store", None), + "temperature": opts.pop("temperature", None), + "tool_choice": opts.pop("tool_choice", None), + "tools": final_tools, + "top_p": opts.pop("top_p", None), + "user": opts.pop("user", None), + **opts, # Remaining options are provider-specific + } + # Remove None values and merge with chat_options + run_opts = {k: v for k, v in run_opts.items() if v is not None} + co = _merge_options(run_chat_options, run_opts) # Ensure thread is forwarded in kwargs for tool invocation kwargs["thread"] = thread @@ -885,7 +853,7 @@ async def run( filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} response = await self.chat_client.get_response( messages=thread_messages, - chat_options=co, + options=co, # type: ignore[arg-type] **filtered_kwargs, ) @@ -904,7 +872,7 @@ async def run( response.messages, **{k: v for k, v in kwargs.items() if k != "thread"}, ) - return AgentRunResponse( + return AgentResponse( messages=response.messages, response_id=response.response_id, created_at=response.created_at, @@ -916,32 +884,17 @@ async def run( async def run_stream( self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = 
None, + messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, *, thread: AgentThread | None = None, - allow_multiple_tool_calls: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str | int, float] | None = None, - max_tokens: int | None = None, - metadata: dict[str, Any] | None = None, - model_id: str | None = None, - presence_penalty: float | None = None, - response_format: type[BaseModel] | None = None, - seed: int | None = None, - stop: str | Sequence[str] | None = None, - store: bool | None = None, - temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, tools: ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] | None = None, - top_p: float | None = None, - user: str | None = None, - additional_chat_options: dict[str, Any] | None = None, + options: TOptions_co | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: """Stream the agent with the given messages and options. Note: @@ -955,30 +908,23 @@ Keyword Args: thread: The thread to use for the agent. - allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response. - frequency_penalty: The frequency penalty to use. - logit_bias: The logit bias to use. - max_tokens: The maximum number of tokens to generate. - metadata: Additional metadata to include in the request. - model_id: The model_id to use for the agent. - presence_penalty: The presence penalty to use. - response_format: The format of the response. - seed: The random seed to use. - stop: The stop sequence(s) for the request. - store: Whether to store the response. - temperature: The sampling temperature to use. - tool_choice: The tool choice for the request. - tools: The tools to use for the request. - top_p: The nucleus sampling probability to use. - user: The user to associate with the request. - additional_chat_options: Additional properties to include in the request. - Use this field for provider-specific parameters. - kwargs: Any additional keyword arguments. + tools: The tools to use for this specific run (merged with agent-level tools). + options: A TypedDict containing chat options. When using a typed agent like + ``ChatAgent[OpenAIChatOptions]``, this enables IDE autocomplete for + standard options such as temperature, max_tokens, model_id, and + tool_choice, as well as provider-specific options like reasoning_effort. + kwargs: Additional keyword arguments for the agent. Will only be passed to functions that are called. Yields: - AgentRunResponseUpdate objects containing chunks of the agent's response. + AgentResponseUpdate objects containing chunks of the agent's response. 
""" + # Build options dict from provided options + opts = dict(options) if options else {} + + # Get tools from options or named parameter (named param takes precedence) + tools_ = tools if tools is not None else opts.pop("tools", None) + input_messages = self._normalize_messages(messages) thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages( thread=thread, input_messages=input_messages, **kwargs @@ -987,7 +933,7 @@ async def run_stream( # Resolve final tool list (runtime provided tools + local MCP server tools) final_tools: list[ToolProtocol | MutableMapping[str, Any] | Callable[..., Any]] = [] normalized_tools: list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] = ( # type: ignore[reportUnknownVariableType] - [] if tools is None else tools if isinstance(tools, list) else [tools] + [] if tools_ is None else tools_ if isinstance(tools_, list) else [tools_] ) # Normalize tools argument to a list without mutating the original parameter for tool in normalized_tools: @@ -998,32 +944,35 @@ async def run_stream( else: final_tools.append(tool) - for mcp_server in self._local_mcp_tools: + for mcp_server in self.mcp_tools: if not mcp_server.is_connected: await self._async_exit_stack.enter_async_context(mcp_server) final_tools.extend(mcp_server.functions) - merged_additional_options = additional_chat_options or {} - co = run_chat_options & ChatOptions( - conversation_id=thread.service_thread_id, - allow_multiple_tool_calls=allow_multiple_tool_calls, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - max_tokens=max_tokens, - metadata=metadata, - model_id=model_id, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - store=store, - temperature=temperature, - tool_choice=tool_choice, - tools=final_tools, - top_p=top_p, - user=user, - additional_properties=merged_additional_options, # type: ignore[arg-type] - ) + # Build options dict from run_stream() options merged with provided options + run_opts: dict[str, Any] = { + "model_id": opts.pop("model_id", None), + "conversation_id": thread.service_thread_id, + "allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None), + "frequency_penalty": opts.pop("frequency_penalty", None), + "logit_bias": opts.pop("logit_bias", None), + "max_tokens": opts.pop("max_tokens", None), + "metadata": opts.pop("metadata", None), + "presence_penalty": opts.pop("presence_penalty", None), + "response_format": opts.pop("response_format", None), + "seed": opts.pop("seed", None), + "stop": opts.pop("stop", None), + "store": opts.pop("store", None), + "temperature": opts.pop("temperature", None), + "tool_choice": opts.pop("tool_choice", None), + "tools": final_tools, + "top_p": opts.pop("top_p", None), + "user": opts.pop("user", None), + **opts, # Remaining options are provider-specific + } + # Remove None values and merge with chat_options + run_opts = {k: v for k, v in run_opts.items() if v is not None} + co = _merge_options(run_chat_options, run_opts) # Ensure thread is forwarded in kwargs for tool invocation kwargs["thread"] = thread @@ -1032,7 +981,7 @@ async def run_stream( response_updates: list[ChatResponseUpdate] = [] async for update in self.chat_client.get_streaming_response( messages=thread_messages, - chat_options=co, + options=co, # type: ignore[arg-type] **filtered_kwargs, ): response_updates.append(update) @@ -1040,7 +989,7 @@ async def run_stream( if update.author_name is None: update.author_name = agent_name - yield AgentRunResponseUpdate( + 
yield AgentResponseUpdate( contents=update.contents, role=update.role, author_name=update.author_name, @@ -1051,7 +1000,9 @@ async def run_stream( raw_representation=update, ) - response = ChatResponse.from_chat_response_updates(response_updates, output_format_type=co.response_format) + response = ChatResponse.from_chat_response_updates( + response_updates, output_format_type=co.get("response_format") + ) await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id) await self._notify_thread_of_new_messages( @@ -1096,9 +1047,9 @@ def get_new_thread( service_thread_id=service_thread_id, context_provider=self.context_provider, ) - if self.chat_options.conversation_id is not None: + if self.default_options.get("conversation_id") is not None: return AgentThread( - service_thread_id=self.chat_options.conversation_id, + service_thread_id=self.default_options["conversation_id"], context_provider=self.context_provider, ) if self.chat_message_store_factory is not None: @@ -1255,7 +1206,7 @@ async def _prepare_thread_and_messages( thread: AgentThread | None, input_messages: list[ChatMessage] | None = None, **kwargs: Any, - ) -> tuple[AgentThread, ChatOptions, list[ChatMessage]]: + ) -> tuple[AgentThread, dict[str, Any], list[ChatMessage]]: """Prepare the thread and messages for agent execution. This method prepares the conversation thread, merges context provider data, @@ -1275,7 +1226,7 @@ async def _prepare_thread_and_messages( Raises: AgentExecutionException: If the conversation IDs on the thread and agent don't match. """ - chat_options = deepcopy(self.chat_options) if self.chat_options else ChatOptions() + chat_options = deepcopy(self.default_options) if self.default_options else {} thread = thread or self.get_new_thread() if thread.service_thread_id and thread.context_provider: await thread.context_provider.thread_created(thread.service_thread_id) @@ -1292,21 +1243,21 @@ async def _prepare_thread_and_messages( if context.messages: thread_messages.extend(context.messages) if context.tools: - if chat_options.tools is not None: - chat_options.tools.extend(context.tools) + if chat_options.get("tools") is not None: + chat_options["tools"].extend(context.tools) else: - chat_options.tools = list(context.tools) + chat_options["tools"] = list(context.tools) if context.instructions: - chat_options.instructions = ( + chat_options["instructions"] = ( context.instructions - if not chat_options.instructions - else f"{chat_options.instructions}\n{context.instructions}" + if not chat_options.get("instructions") + else f"{chat_options['instructions']}\n{context.instructions}" ) thread_messages.extend(input_messages or []) if ( thread.service_thread_id - and chat_options.conversation_id - and thread.service_thread_id != chat_options.conversation_id + and chat_options.get("conversation_id") + and thread.service_thread_id != chat_options["conversation_id"] ): raise AgentExecutionException( "The conversation_id set on the agent is different from the one set on the thread, " diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 6743902475..f48e8af86a 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -1,14 +1,27 @@ # Copyright (c) Microsoft. All rights reserved. 
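The `run()` and `run_stream()` rewrites above funnel every per-run setting through a single `options` dict: explicit `None` values are stripped, recognized keys are lifted into the request, and any remaining keys pass through as provider-specific options before being merged over the agent's defaults. A minimal sketch of that precedence, assuming `_merge_options` behaves as a last-writer-wins dict merge (the helper itself is not defined in this diff):

```python
# Sketch of the option-precedence rules used by run()/run_stream() above.
# _merge_options here is a stand-in with assumed last-writer-wins semantics;
# the framework helper of the same name is not shown in this diff.
from typing import Any


def _merge_options(defaults: dict[str, Any], run_opts: dict[str, Any]) -> dict[str, Any]:
    merged: dict[str, Any] = dict(defaults)
    merged.update(run_opts)  # run-level values override agent defaults
    return merged


agent_defaults = {"temperature": 0.2, "max_tokens": 500}
requested = {"temperature": 0.9, "seed": None, "reasoning_effort": "high"}

# None values are stripped before the merge, mirroring the code above.
run_opts = {k: v for k, v in requested.items() if v is not None}

assert _merge_options(agent_defaults, run_opts) == {
    "temperature": 0.9,  # run-level override wins
    "max_tokens": 500,  # agent default preserved
    "reasoning_effort": "high",  # unknown key passes through as provider-specific
}
```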
import asyncio +import sys from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Callable, MutableMapping, MutableSequence, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, TypeVar, runtime_checkable - -from pydantic import BaseModel +from collections.abc import ( + AsyncIterable, + Callable, + MutableMapping, + MutableSequence, + Sequence, +) +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Protocol, + TypedDict, + runtime_checkable, +) from ._logging import get_logger -from ._memory import AggregateContextProvider, ContextProvider +from ._memory import ContextProvider from ._middleware import ( ChatMiddleware, ChatMiddlewareCallable, @@ -18,11 +31,27 @@ ) from ._serialization import SerializationMixin from ._threads import ChatMessageStoreProtocol -from ._tools import FUNCTION_INVOKING_CHAT_CLIENT_MARKER, FunctionInvocationConfiguration, ToolProtocol -from ._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, ToolMode, prepare_messages +from ._tools import ( + FUNCTION_INVOKING_CHAT_CLIENT_MARKER, + FunctionInvocationConfiguration, + ToolProtocol, +) +from ._types import ( + ChatMessage, + ChatResponse, + ChatResponseUpdate, + prepare_messages, + validate_chat_options, +) + +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar if TYPE_CHECKING: from ._agents import ChatAgent + from ._types import ChatOptions TInput = TypeVar("TInput", contravariant=True) @@ -39,14 +68,26 @@ # region ChatClientProtocol Protocol +# Contravariant for the Protocol +TOptions_contra = TypeVar( + "TOptions_contra", + bound=TypedDict, # type: ignore[valid-type] + default="ChatOptions", + contravariant=True, +) + @runtime_checkable -class ChatClientProtocol(Protocol): +class ChatClientProtocol(Protocol[TOptions_contra]): # """A protocol for a chat client that can generate responses. This protocol defines the interface that all chat clients must implement, including methods for generating both streaming and non-streaming responses. + The generic type parameter TOptions specifies which options TypedDict this + client accepts, enabling IDE autocomplete and type checking for provider-specific + options. + Note: Protocols use structural subtyping (duck typing). Classes don't need to explicitly inherit from this protocol to be considered compatible. @@ -59,10 +100,6 @@ class ChatClientProtocol(Protocol): # Any class implementing the required methods is compatible class CustomChatClient: - @property - def additional_properties(self) -> dict[str, Any]: - return {} - async def get_response(self, messages, **kwargs): # Your custom implementation return ChatResponse(messages=[], response_id="custom") @@ -81,61 +118,21 @@ async def _stream(): assert isinstance(client, ChatClientProtocol) """ - @property - def additional_properties(self) -> dict[str, Any]: - """Get additional properties associated with the client.""" - ... 
+ additional_properties: dict[str, Any] async def get_response( self, - messages: str | ChatMessage | list[str] | list[ChatMessage], + messages: str | ChatMessage | Sequence[str | ChatMessage], *, - frequency_penalty: float | None = None, - logit_bias: dict[str | int, float] | None = None, - max_tokens: int | None = None, - metadata: dict[str, Any] | None = None, - model_id: str | None = None, - presence_penalty: float | None = None, - response_format: type[BaseModel] | None = None, - seed: int | None = None, - stop: str | Sequence[str] | None = None, - store: bool | None = None, - temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, - tools: ToolProtocol - | Callable[..., Any] - | MutableMapping[str, Any] - | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] - | None = None, - top_p: float | None = None, - user: str | None = None, - additional_properties: dict[str, Any] | None = None, + options: TOptions_contra | None = None, **kwargs: Any, ) -> ChatResponse: """Send input and return the response. Args: messages: The sequence of input messages to send. - - Keyword Args: - frequency_penalty: The frequency penalty to use. - logit_bias: The logit bias to use. - max_tokens: The maximum number of tokens to generate. - metadata: Additional metadata to include in the request. - model_id: The model_id to use for the agent. - presence_penalty: The presence penalty to use. - response_format: The format of the response. - seed: The random seed to use. - stop: The stop sequence(s) for the request. - store: Whether to store the response. - temperature: The sampling temperature to use. - tool_choice: The tool choice for the request. - tools: The tools to use for the request. - top_p: The nucleus sampling probability to use. - user: The user to associate with the request. - additional_properties: Additional properties to include in the request. - kwargs: Any additional keyword arguments. - Will only be passed to functions that are called. + options: Chat options as a TypedDict. + **kwargs: Additional chat options. Returns: The response messages generated by the client. @@ -147,155 +144,48 @@ async def get_response( def get_streaming_response( self, - messages: str | ChatMessage | list[str] | list[ChatMessage], + messages: str | ChatMessage | Sequence[str | ChatMessage], *, - frequency_penalty: float | None = None, - logit_bias: dict[str | int, float] | None = None, - max_tokens: int | None = None, - metadata: dict[str, Any] | None = None, - model_id: str | None = None, - presence_penalty: float | None = None, - response_format: type[BaseModel] | None = None, - seed: int | None = None, - stop: str | Sequence[str] | None = None, - store: bool | None = None, - temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, - tools: ToolProtocol - | Callable[..., Any] - | MutableMapping[str, Any] - | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] - | None = None, - top_p: float | None = None, - user: str | None = None, - additional_properties: dict[str, Any] | None = None, + options: TOptions_contra | None = None, **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: """Send input messages and stream the response. Args: messages: The sequence of input messages to send. - - Keyword Args: - frequency_penalty: The frequency penalty to use. - logit_bias: The logit bias to use. 
- max_tokens: The maximum number of tokens to generate. - metadata: Additional metadata to include in the request. - model_id: The model_id to use for the agent. - presence_penalty: The presence penalty to use. - response_format: The format of the response. - seed: The random seed to use. - stop: The stop sequence(s) for the request. - store: Whether to store the response. - temperature: The sampling temperature to use. - tool_choice: The tool choice for the request. - tools: The tools to use for the request. - top_p: The nucleus sampling probability to use. - user: The user to associate with the request. - additional_properties: Additional properties to include in the request. - kwargs: Any additional keyword arguments. - Will only be passed to functions that are called. + options: Chat options as a TypedDict. + **kwargs: Additional chat options. Yields: - ChatResponseUpdate: An async iterable of chat response updates containing - the content of the response messages generated by the client. - - Raises: - ValueError: If the input message sequence is ``None``. + ChatResponseUpdate: Partial response updates as they're generated. """ ... +# endregion + + # region ChatClientBase +# Covariant for the BaseChatClient +TOptions_co = TypeVar( + "TOptions_co", + bound=TypedDict, # type: ignore[valid-type] + default="ChatOptions", + covariant=True, +) -def _merge_chat_options( - *, - base_chat_options: ChatOptions | Any | None, - model_id: str | None = None, - allow_multiple_tool_calls: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str | int, float] | None = None, - max_tokens: int | None = None, - metadata: dict[str, Any] | None = None, - presence_penalty: float | None = None, - response_format: type[BaseModel] | None = None, - seed: int | None = None, - stop: str | Sequence[str] | None = None, - store: bool | None = None, - temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, - tools: list[ToolProtocol | dict[str, Any] | Callable[..., Any]] | None = None, - top_p: float | None = None, - user: str | None = None, - additional_properties: dict[str, Any] | None = None, -) -> ChatOptions: - """Merge base chat options with direct parameters to create a new ChatOptions instance. - - When both base_chat_options and individual parameters are provided, the individual - parameters take precedence and override the corresponding values in base_chat_options. - Tools from both sources are combined into a single list. - - Keyword Args: - base_chat_options: Optional base ChatOptions to merge with direct parameters. - model_id: The model_id to use for the agent. - allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response. - frequency_penalty: The frequency penalty to use. - logit_bias: The logit bias to use. - max_tokens: The maximum number of tokens to generate. - metadata: Additional metadata to include in the request. - presence_penalty: The presence penalty to use. - response_format: The format of the response. - seed: The random seed to use. - stop: The stop sequence(s) for the request. - store: Whether to store the response. - temperature: The sampling temperature to use. - tool_choice: The tool choice for the request. - tools: The normalized tools to use for the request. - top_p: The nucleus sampling probability to use. - user: The user to associate with the request. - additional_properties: Additional properties to include in the request. 
- - Returns: - A new ChatOptions instance with merged values. - - Raises: - TypeError: If base_chat_options is not None and not an instance of ChatOptions. - """ - # Validate base_chat_options type if provided - if base_chat_options is not None and not isinstance(base_chat_options, ChatOptions): - raise TypeError("chat_options must be an instance of ChatOptions") - - if base_chat_options is None: - base_chat_options = ChatOptions() - - return base_chat_options & ChatOptions( - model_id=model_id, - allow_multiple_tool_calls=allow_multiple_tool_calls, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - max_tokens=max_tokens, - metadata=metadata, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - store=store, - temperature=temperature, - top_p=top_p, - tool_choice=tool_choice, - tools=tools, - user=user, - additional_properties=additional_properties, - ) - - -class BaseChatClient(SerializationMixin, ABC): + +class BaseChatClient(SerializationMixin, ABC, Generic[TOptions_co]): """Base class for chat clients. This abstract base class provides core functionality for chat client implementations, including middleware support, message preparation, and tool normalization. + The generic type parameter TOptions specifies which options TypedDict this client + accepts. This enables IDE autocomplete and type checking for provider-specific options + when using the typed overloads of get_response and get_streaming_response. + Note: BaseChatClient cannot be instantiated directly as it's an abstract base class. Subclasses must implement ``_inner_get_response()`` and ``_inner_get_streaming_response()``. @@ -308,13 +198,13 @@ class BaseChatClient(SerializationMixin, ABC): class CustomChatClient(BaseChatClient): - async def _inner_get_response(self, *, messages, chat_options, **kwargs): + async def _inner_get_response(self, *, messages, options, **kwargs): # Your custom implementation return ChatResponse( messages=[ChatMessage(role="assistant", text="Hello!")], response_id="custom-response" ) - async def _inner_get_streaming_response(self, *, messages, chat_options, **kwargs): + async def _inner_get_streaming_response(self, *, messages, options, **kwargs): # Your custom streaming implementation from agent_framework import ChatResponseUpdate @@ -336,12 +226,7 @@ def __init__( self, *, middleware: ( - ChatMiddleware - | ChatMiddlewareCallable - | FunctionMiddleware - | FunctionMiddlewareCallable - | list[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable] - | None + Sequence[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable] | None ) = None, additional_properties: dict[str, Any] | None = None, **kwargs: Any, @@ -384,57 +269,6 @@ def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) return result - def _filter_internal_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: - """Filter out internal framework parameters that shouldn't be passed to chat client implementations. - - Keyword Args: - kwargs: The original kwargs dictionary. - - Returns: - A filtered kwargs dictionary without internal parameters. 
- """ - return {k: v for k, v in kwargs.items() if not k.startswith("_")} - - @staticmethod - async def _normalize_tools( - tools: ToolProtocol - | MutableMapping[str, Any] - | Callable[..., Any] - | Sequence[ToolProtocol | MutableMapping[str, Any] | Callable[..., Any]] - | None = None, - ) -> list[ToolProtocol | dict[str, Any] | Callable[..., Any]]: - """Normalize tools input to a consistent list format. - - Expands MCP tools to their constituent functions, connecting them if needed. - - Args: - tools: The tools in various supported formats. - - Returns: - A normalized list of tools. - """ - from typing import cast - - final_tools: list[ToolProtocol | dict[str, Any] | Callable[..., Any]] = [] - if not tools: - return final_tools - # Use cast when a sequence is passed (likely already a list) - tools_list = ( - cast(list[ToolProtocol | MutableMapping[str, Any] | Callable[..., Any]], tools) - if isinstance(tools, Sequence) and not isinstance(tools, (str, bytes)) - else [tools] - ) - for tool in tools_list: # type: ignore[reportUnknownType] - from ._mcp import MCPTool - - if isinstance(tool, MCPTool): - if not tool.is_connected: - await tool.connect() - final_tools.extend(tool.functions) # type: ignore - continue - final_tools.append(tool) # type: ignore - return final_tools - # region Internal methods to be implemented by the derived classes @abstractmethod @@ -442,14 +276,14 @@ async def _inner_get_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> ChatResponse: """Send a chat request to the AI service. Keyword Args: messages: The chat messages to send. - chat_options: The options for the request. + options: The options dict for the request. kwargs: Any additional keyword arguments. Returns: @@ -461,14 +295,14 @@ async def _inner_get_streaming_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: """Send a streaming chat request to the AI service. Keyword Args: messages: The chat messages to send. - chat_options: The chat_options for the request. + options: The options dict for the request. kwargs: Any additional keyword arguments. Yields: @@ -487,222 +321,51 @@ async def _inner_get_streaming_response( async def get_response( self, - messages: str | ChatMessage | list[str] | list[ChatMessage], + messages: str | ChatMessage | Sequence[str | ChatMessage], *, - allow_multiple_tool_calls: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str | int, float] | None = None, - max_tokens: int | None = None, - metadata: dict[str, Any] | None = None, - model_id: str | None = None, - presence_penalty: float | None = None, - response_format: type[BaseModel] | None = None, - seed: int | None = None, - stop: str | Sequence[str] | None = None, - store: bool | None = None, - temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, - tools: ToolProtocol - | Callable[..., Any] - | MutableMapping[str, Any] - | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] - | None = None, - top_p: float | None = None, - user: str | None = None, - additional_properties: dict[str, Any] | None = None, + options: TOptions_co | None = None, **kwargs: Any, ) -> ChatResponse: """Get a response from a chat client. 
- When both ``chat_options`` (in kwargs) and individual parameters are provided, - the individual parameters take precedence and override the corresponding values - in ``chat_options``. Tools from both sources are combined into a single list. - Args: messages: The message or messages to send to the model. - - Keyword Args: - allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response. - frequency_penalty: The frequency penalty to use. - logit_bias: The logit bias to use. - max_tokens: The maximum number of tokens to generate. - metadata: Additional metadata to include in the request. - model_id: The model_id to use for the agent. - presence_penalty: The presence penalty to use. - response_format: The format of the response. - seed: The random seed to use. - stop: The stop sequence(s) for the request. - store: Whether to store the response. - temperature: The sampling temperature to use. - tool_choice: The tool choice for the request. - Default is `auto`. - tools: The tools to use for the request. - top_p: The nucleus sampling probability to use. - user: The user to associate with the request. - additional_properties: Additional properties to include in the request. - Can be used for provider-specific parameters. - kwargs: Any additional keyword arguments. - May include ``chat_options`` which provides base values that can be overridden by direct parameters. + options: Chat options as a TypedDict. + **kwargs: Other keyword arguments, can be used to pass function specific parameters. Returns: - A chat response from the model_id. + A chat response from the model. """ - # Normalize tools and merge with base chat_options - normalized_tools = await self._normalize_tools(tools) - chat_options = _merge_chat_options( - base_chat_options=kwargs.pop("chat_options", None), - model_id=model_id, - allow_multiple_tool_calls=allow_multiple_tool_calls, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - max_tokens=max_tokens, - metadata=metadata, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - store=store, - temperature=temperature, - tool_choice=tool_choice, - tools=normalized_tools, - top_p=top_p, - user=user, - additional_properties=additional_properties, + return await self._inner_get_response( + messages=prepare_messages(messages), + options=await validate_chat_options(dict(options) if options else {}), + **kwargs, ) - if chat_options.instructions: - system_msg = ChatMessage(role="system", text=chat_options.instructions) - prepped_messages = [system_msg, *prepare_messages(messages)] - else: - prepped_messages = prepare_messages(messages) - self._prepare_tool_choice(chat_options=chat_options) - - filtered_kwargs = self._filter_internal_kwargs(kwargs) - return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **filtered_kwargs) - async def get_streaming_response( self, - messages: str | ChatMessage | list[str] | list[ChatMessage], + messages: str | ChatMessage | Sequence[str | ChatMessage], *, - allow_multiple_tool_calls: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str | int, float] | None = None, - max_tokens: int | None = None, - metadata: dict[str, Any] | None = None, - model_id: str | None = None, - presence_penalty: float | None = None, - response_format: type[BaseModel] | None = None, - seed: int | None = None, - stop: str | Sequence[str] | None = None, - store: bool | None = None, - temperature: float | None = None, - tool_choice: 
ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, - tools: ToolProtocol - | Callable[..., Any] - | MutableMapping[str, Any] - | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] - | None = None, - top_p: float | None = None, - user: str | None = None, - additional_properties: dict[str, Any] | None = None, + options: TOptions_co | None = None, **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: """Get a streaming response from a chat client. - When both ``chat_options`` (in kwargs) and individual parameters are provided, - the individual parameters take precedence and override the corresponding values - in ``chat_options``. Tools from both sources are combined into a single list. - Args: messages: The message or messages to send to the model. - - Keyword Args: - allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response. - frequency_penalty: The frequency penalty to use. - logit_bias: The logit bias to use. - max_tokens: The maximum number of tokens to generate. - metadata: Additional metadata to include in the request. - model_id: The model_id to use for the agent. - presence_penalty: The presence penalty to use. - response_format: The format of the response. - seed: The random seed to use. - stop: The stop sequence(s) for the request. - store: Whether to store the response. - temperature: The sampling temperature to use. - tool_choice: The tool choice for the request. - Default is `auto`. - tools: The tools to use for the request. - top_p: The nucleus sampling probability to use. - user: The user to associate with the request. - additional_properties: Additional properties to include in the request. - Can be used for provider-specific parameters. - kwargs: Any additional keyword arguments. - May include ``chat_options`` which provides base values that can be overridden by direct parameters. + options: Chat options as a TypedDict. + **kwargs: Other keyword arguments, can be used to pass function specific parameters. Yields: ChatResponseUpdate: A stream representing the response(s) from the LLM. """ - # Normalize tools and merge with base chat_options - normalized_tools = await self._normalize_tools(tools) - chat_options = _merge_chat_options( - base_chat_options=kwargs.pop("chat_options", None), - model_id=model_id, - allow_multiple_tool_calls=allow_multiple_tool_calls, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - max_tokens=max_tokens, - metadata=metadata, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - store=store, - temperature=temperature, - tool_choice=tool_choice, - tools=normalized_tools, - top_p=top_p, - user=user, - additional_properties=additional_properties, - ) - - if chat_options.instructions: - system_msg = ChatMessage(role="system", text=chat_options.instructions) - prepped_messages = [system_msg, *prepare_messages(messages)] - else: - prepped_messages = prepare_messages(messages) - self._prepare_tool_choice(chat_options=chat_options) - - filtered_kwargs = self._filter_internal_kwargs(kwargs) async for update in self._inner_get_streaming_response( - messages=prepped_messages, chat_options=chat_options, **filtered_kwargs + messages=prepare_messages(messages), + options=await validate_chat_options(dict(options) if options else {}), + **kwargs, ): yield update - def _prepare_tool_choice(self, chat_options: ChatOptions) -> None: - """Prepare the tools and tool choice for the chat options. 
- 
- This function should be overridden by subclasses to customize tool handling,
- as it currently parses only AIFunctions.
- 
- Args:
- chat_options: The chat options to prepare.
- """
- chat_tool_mode = chat_options.tool_choice
- # Explicitly disabled: clear tools and set to NONE
- if chat_tool_mode == ToolMode.NONE or chat_tool_mode == "none":
- chat_options.tools = None
- chat_options.tool_choice = ToolMode.NONE
- return
- # No tools available: set to NONE regardless of requested mode
- if not chat_options.tools:
- chat_options.tool_choice = ToolMode.NONE
- # Tools available but no explicit mode: default to AUTO
- elif chat_tool_mode is None:
- chat_options.tool_choice = ToolMode.AUTO
- # Tools available with explicit mode: preserve the mode
- else:
- chat_options.tool_choice = chat_tool_mode
- 
 def service_url(self) -> str:
 """Get the URL of the service.
@@ -721,33 +384,17 @@ def create_agent(
 name: str | None = None,
 description: str | None = None,
 instructions: str | None = None,
- chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
- context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
- middleware: Middleware | list[Middleware] | None = None,
- allow_multiple_tool_calls: bool | None = None,
- conversation_id: str | None = None,
- frequency_penalty: float | None = None,
- logit_bias: dict[str | int, float] | None = None,
- max_tokens: int | None = None,
- metadata: dict[str, Any] | None = None,
- model_id: str | None = None,
- presence_penalty: float | None = None,
- response_format: type[BaseModel] | None = None,
- seed: int | None = None,
- stop: str | Sequence[str] | None = None,
- store: bool | None = None,
- temperature: float | None = None,
- tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
 tools: ToolProtocol
 | Callable[..., Any]
 | MutableMapping[str, Any]
 | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
 | None = None,
- top_p: float | None = None,
- user: str | None = None,
- additional_chat_options: dict[str, Any] | None = None,
+ default_options: TOptions_co | None = None,
+ chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
+ context_provider: ContextProvider | None = None,
+ middleware: Sequence[Middleware] | None = None,
 **kwargs: Any,
- ) -> "ChatAgent":
+ ) -> "ChatAgent[TOptions_co]":
 """Create a ChatAgent with this client.
 This is a convenience method that creates a ChatAgent instance with this
@@ -759,30 +406,14 @@
 description: A brief description of the agent's purpose.
 instructions: Optional instructions for the agent. These will be put into
 the messages sent to the chat client service as a system message.
+ tools: The tools to use for the request.
+ default_options: A TypedDict containing chat options. When using a typed client like
+ ``OpenAIChatClient``, this enables IDE autocomplete for provider-specific options
+ including temperature, max_tokens, model_id, tool_choice, and more.
 chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol.
 If not provided, the default in-memory store will be used.
- context_providers: Context providers to include during agent invocation.
+ context_provider: The context provider to include during agent invocation.
 middleware: List of middleware to intercept agent and function invocations.
- allow_multiple_tool_calls: Whether to allow multiple tool calls per agent turn. 
- conversation_id: The conversation ID to associate with the agent's messages. - frequency_penalty: The frequency penalty to use. - logit_bias: The logit bias to use. - max_tokens: The maximum number of tokens to generate. - metadata: Additional metadata to include in the request. - model_id: The model_id to use for the agent. - presence_penalty: The presence penalty to use. - response_format: The format of the response. - seed: The random seed to use. - stop: The stop sequence(s) for the request. - store: Whether to store the response. - temperature: The sampling temperature to use. - tool_choice: The tool choice for the request. - tools: The tools to use for the request. - top_p: The nucleus sampling probability to use. - user: The user to associate with the request. - additional_chat_options: A dictionary of other values that will be passed through - to the chat_client ``get_response`` and ``get_streaming_response`` methods. - This can be used to pass provider specific parameters. kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``. Returns: @@ -791,14 +422,16 @@ def create_agent( Examples: .. code-block:: python - from agent_framework.clients import OpenAIChatClient + from agent_framework.openai import OpenAIChatClient # Create a client client = OpenAIChatClient(model_id="gpt-4") # Create an agent using the convenience method agent = client.create_agent( - name="assistant", instructions="You are a helpful assistant.", temperature=0.7 + name="assistant", + instructions="You are a helpful assistant.", + default_options={"temperature": 0.7, "max_tokens": 500}, ) # Run the agent @@ -812,26 +445,10 @@ def create_agent( name=name, description=description, instructions=instructions, + tools=tools, + default_options=default_options, chat_message_store_factory=chat_message_store_factory, - context_providers=context_providers, + context_provider=context_provider, middleware=middleware, - allow_multiple_tool_calls=allow_multiple_tool_calls, - conversation_id=conversation_id, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - max_tokens=max_tokens, - metadata=metadata, - model_id=model_id, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - store=store, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_p=top_p, - user=user, - additional_chat_options=additional_chat_options, **kwargs, ) diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index a25f359a59..3a6d5b818c 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -4,23 +4,29 @@ import re import sys from abc import abstractmethod -from collections.abc import Collection, Sequence +from collections.abc import Callable, Collection, Sequence from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore from datetime import timedelta from functools import partial from typing import TYPE_CHECKING, Any, Literal +import httpx +from anyio import ClosedResourceError from mcp import types from mcp.client.session import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client from mcp.client.websocket import websocket_client from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.session import RequestResponder 
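On the client side, the same consolidation applies: `get_response`, `get_streaming_response`, and `create_agent` now take one `options`/`default_options` TypedDict instead of a long keyword list. A usage sketch; `OpenAIChatClient` is taken from this diff's own docstring example, and the option keys are assumed to be valid fields of that provider's options TypedDict:

```python
# Usage sketch for the dict-based options surface introduced above.
# OpenAIChatClient comes from this diff's docstring example; the option
# keys are assumed fields of the provider's options TypedDict.
import asyncio

from agent_framework.openai import OpenAIChatClient


async def main() -> None:
    client = OpenAIChatClient(model_id="gpt-4")

    # Per-call options travel as one dict instead of many keyword arguments.
    response = await client.get_response(
        "Summarize the latest release notes.",
        options={"temperature": 0.2, "max_tokens": 200},
    )
    print(response)

    # Agent-level defaults use the same shape via create_agent().
    agent = client.create_agent(
        name="assistant",
        instructions="You are a helpful assistant.",
        default_options={"temperature": 0.7, "max_tokens": 500},
    )
    result = await agent.run("What is the capital of France?")
    print(result)


asyncio.run(main())
```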
from pydantic import BaseModel, create_model -from ._tools import AIFunction, HostedMCPSpecificApproval, _build_pydantic_model_from_json_schema +from ._tools import ( + AIFunction, + HostedMCPSpecificApproval, + _build_pydantic_model_from_json_schema, +) from ._types import ( ChatMessage, Contents, @@ -328,7 +334,9 @@ def __init__( approval_mode: (Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None) = None, allowed_tools: Collection[str] | None = None, load_tools: bool = True, + parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True, load_prompts: bool = True, + parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True, session: ClientSession | None = None, request_timeout: int | None = None, chat_client: "ChatClientProtocol | None" = None, @@ -346,7 +354,9 @@ def __init__( self.allowed_tools = allowed_tools self.additional_properties = additional_properties self.load_tools_flag = load_tools + self.parse_tool_results = parse_tool_results self.load_prompts_flag = load_prompts + self.parse_prompt_results = parse_prompt_results self._exit_stack = AsyncExitStack() self.session = session self.request_timeout = request_timeout @@ -366,15 +376,23 @@ def functions(self) -> list[AIFunction[Any, Any]]: return self._functions return [func for func in self._functions if func.name in self.allowed_tools] - async def connect(self) -> None: + async def connect(self, *, reset: bool = False) -> None: """Connect to the MCP server. Establishes a connection to the MCP server, initializes the session, and loads tools and prompts if configured to do so. + Keyword Args: + reset: If True, forces a reconnection even if already connected. + Raises: ToolException: If connection or session initialization fails. """ + if reset: + await self._exit_stack.aclose() + self.session = None + self.is_connected = False + self._exit_stack = AsyncExitStack() if not self.session: try: transport = await self._exit_stack.enter_async_context(self.get_mcp_client()) @@ -564,86 +582,88 @@ async def load_prompts(self) -> None: """Load prompts from the MCP server. Retrieves available prompts from the connected MCP server and converts - them into AIFunction instances. + them into AIFunction instances. Handles pagination automatically. Raises: ToolExecutionException: If the MCP server is not connected. 
""" - if not self.session: - raise ToolExecutionException("MCP server not connected, please call connect() before using this method.") - try: - prompt_list = await self.session.list_prompts() - except Exception as exc: - logger.info( - "Prompt could not be loaded, you can exclude trying to load, by setting: load_prompts=False", - exc_info=exc, - ) - prompt_list = None - # Track existing function names to prevent duplicates existing_names = {func.name for func in self._functions} - for prompt in prompt_list.prompts if prompt_list else []: - local_name = _normalize_mcp_name(prompt.name) - - # Skip if already loaded - if local_name in existing_names: - continue - - input_model = _get_input_model_from_mcp_prompt(prompt) - approval_mode = self._determine_approval_mode(local_name) - func: AIFunction[BaseModel, list[ChatMessage]] = AIFunction( - func=partial(self.get_prompt, prompt.name), - name=local_name, - description=prompt.description or "", - approval_mode=approval_mode, - input_model=input_model, - ) - self._functions.append(func) - existing_names.add(local_name) + params: types.PaginatedRequestParams | None = None + while True: + # Ensure connection is still valid before each page request + await self._ensure_connected() + + prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr] + + for prompt in prompt_list.prompts: + local_name = _normalize_mcp_name(prompt.name) + + # Skip if already loaded + if local_name in existing_names: + continue + + input_model = _get_input_model_from_mcp_prompt(prompt) + approval_mode = self._determine_approval_mode(local_name) + func: AIFunction[BaseModel, list[ChatMessage] | Any | types.GetPromptResult] = AIFunction( + func=partial(self.get_prompt, prompt.name), + name=local_name, + description=prompt.description or "", + approval_mode=approval_mode, + input_model=input_model, + ) + self._functions.append(func) + existing_names.add(local_name) + + # Check if there are more pages + if not prompt_list or not prompt_list.nextCursor: + break + params = types.PaginatedRequestParams(cursor=prompt_list.nextCursor) async def load_tools(self) -> None: """Load tools from the MCP server. Retrieves available tools from the connected MCP server and converts - them into AIFunction instances. + them into AIFunction instances. Handles pagination automatically. Raises: ToolExecutionException: If the MCP server is not connected. 
""" - if not self.session: - raise ToolExecutionException("MCP server not connected, please call connect() before using this method.") - try: - tool_list = await self.session.list_tools() - except Exception as exc: - logger.info( - "Tools could not be loaded, you can exclude trying to load, by setting: load_tools=False", - exc_info=exc, - ) - tool_list = None - # Track existing function names to prevent duplicates existing_names = {func.name for func in self._functions} - for tool in tool_list.tools if tool_list else []: - local_name = _normalize_mcp_name(tool.name) - - # Skip if already loaded - if local_name in existing_names: - continue - - input_model = _get_input_model_from_mcp_tool(tool) - approval_mode = self._determine_approval_mode(local_name) - # Create AIFunctions out of each tool - func: AIFunction[BaseModel, list[Contents]] = AIFunction( - func=partial(self.call_tool, tool.name), - name=local_name, - description=tool.description or "", - approval_mode=approval_mode, - input_model=input_model, - ) - self._functions.append(func) - existing_names.add(local_name) + params: types.PaginatedRequestParams | None = None + while True: + # Ensure connection is still valid before each page request + await self._ensure_connected() + + tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr] + + for tool in tool_list.tools: + local_name = _normalize_mcp_name(tool.name) + + # Skip if already loaded + if local_name in existing_names: + continue + + input_model = _get_input_model_from_mcp_tool(tool) + approval_mode = self._determine_approval_mode(local_name) + # Create AIFunctions out of each tool + func: AIFunction[BaseModel, list[Contents] | Any | types.CallToolResult] = AIFunction( + func=partial(self.call_tool, tool.name), + name=local_name, + description=tool.description or "", + approval_mode=approval_mode, + input_model=input_model, + ) + self._functions.append(func) + existing_names.add(local_name) + + # Check if there are more pages + if not tool_list or not tool_list.nextCursor: + break + params = types.PaginatedRequestParams(cursor=tool_list.nextCursor) async def close(self) -> None: """Disconnect from the MCP server. @@ -663,7 +683,28 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: """ pass - async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents]: + async def _ensure_connected(self) -> None: + """Ensure the connection is valid, reconnecting if necessary. + + This method proactively checks if the connection is valid and + reconnects if it's not, avoiding the need to catch ClosedResourceError. + + Raises: + ToolExecutionException: If reconnection fails. + """ + try: + await self.session.send_ping() # type: ignore[union-attr] + except Exception: + logger.info("MCP connection invalid or closed. Reconnecting...") + try: + await self.connect(reset=True) + except Exception as ex: + raise ToolExecutionException( + "Failed to establish MCP connection.", + inner_exception=ex, + ) from ex + + async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents] | Any | types.CallToolResult: """Call a tool with the given arguments. Args: @@ -679,8 +720,6 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents]: ToolExecutionException: If the MCP server is not connected, tools are not loaded, or the tool call fails. 
""" - if not self.session: - raise ToolExecutionException("MCP server not connected, please call connect() before using this method.") if not self.load_tools_flag: raise ToolExecutionException( "Tools are not loaded for this server, please set load_tools=True in the constructor." @@ -691,16 +730,44 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents]: filtered_kwargs = { k: v for k, v in kwargs.items() if k not in {"chat_options", "tools", "tool_choice", "thread"} } - try: - return _parse_contents_from_mcp_tool_result( - await self.session.call_tool(tool_name, arguments=filtered_kwargs) - ) - except McpError as mcp_exc: - raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc - except Exception as ex: - raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex - async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage]: + # Try the operation, reconnecting once if the connection is closed + for attempt in range(2): + try: + result = await self.session.call_tool(tool_name, arguments=filtered_kwargs) # type: ignore + if self.parse_tool_results is None: + return result + if self.parse_tool_results is True: + return _parse_contents_from_mcp_tool_result(result) + if callable(self.parse_tool_results): + return self.parse_tool_results(result) + return result + except ClosedResourceError as cl_ex: + if attempt == 0: + # First attempt failed, try reconnecting + logger.info("MCP connection closed unexpectedly. Reconnecting...") + try: + await self.connect(reset=True) + continue # Retry the operation + except Exception as reconn_ex: + raise ToolExecutionException( + "Failed to reconnect to MCP server.", + inner_exception=reconn_ex, + ) from reconn_ex + else: + # Second attempt also failed, give up + logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}") + raise ToolExecutionException( + f"Failed to call tool '{tool_name}' - connection lost.", + inner_exception=cl_ex, + ) from cl_ex + except McpError as mcp_exc: + raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc + except Exception as ex: + raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex + raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.") + + async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage] | Any | types.GetPromptResult: """Call a prompt with the given arguments. Args: @@ -716,19 +783,46 @@ async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage] ToolExecutionException: If the MCP server is not connected, prompts are not loaded, or the prompt call fails. """ - if not self.session: - raise ToolExecutionException("MCP server not connected, please call connect() before using this method.") if not self.load_prompts_flag: raise ToolExecutionException( "Prompts are not loaded for this server, please set load_prompts=True in the constructor." 
)
- try:
- prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs)
- return [_parse_message_from_mcp(message) for message in prompt_result.messages]
- except McpError as mcp_exc:
- raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
- except Exception as ex:
- raise ToolExecutionException(f"Failed to call prompt '{prompt_name}'.", inner_exception=ex) from ex
+
+ # Try the operation, reconnecting once if the connection is closed
+ for attempt in range(2):
+ try:
+ prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs) # type: ignore
+ if self.parse_prompt_results is None:
+ return prompt_result
+ if self.parse_prompt_results is True:
+ return [_parse_message_from_mcp(message) for message in prompt_result.messages]
+ if callable(self.parse_prompt_results):
+ return self.parse_prompt_results(prompt_result)
+ return prompt_result
+ except ClosedResourceError as cl_ex:
+ if attempt == 0:
+ # First attempt failed, try reconnecting
+ logger.info("MCP connection closed unexpectedly. Reconnecting...")
+ try:
+ await self.connect(reset=True)
+ continue # Retry the operation
+ except Exception as reconn_ex:
+ raise ToolExecutionException(
+ "Failed to reconnect to MCP server.",
+ inner_exception=reconn_ex,
+ ) from reconn_ex
+ else:
+ # Second attempt also failed, give up
+ logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
+ raise ToolExecutionException(
+ f"Failed to call prompt '{prompt_name}' - connection lost.",
+ inner_exception=cl_ex,
+ ) from cl_ex
+ except McpError as mcp_exc:
+ raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
+ except Exception as ex:
+ raise ToolExecutionException(f"Failed to call prompt '{prompt_name}'.", inner_exception=ex) from ex
+ raise ToolExecutionException(f"Failed to get prompt '{prompt_name}' after retries.")
 async def __aenter__(self) -> Self:
 """Enter the async context manager.
@@ -803,7 +897,9 @@ def __init__(
 command: str,
 *,
 load_tools: bool = True,
+ parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True,
 load_prompts: bool = True,
+ parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True,
 request_timeout: int | None = None,
 session: ClientSession | None = None,
 description: str | None = None,
@@ -829,7 +925,15 @@
 Keyword Args:
 load_tools: Whether to load tools from the MCP server.
+ parse_tool_results: How to parse tool results from the MCP server.
+ Set to True to use the default parser that converts to Agent Framework types.
+ Set to a callable to use a custom parser function.
+ Set to None to return the raw MCP tool result.
 load_prompts: Whether to load prompts from the MCP server.
+ parse_prompt_results: How to parse prompt results from the MCP server.
+ Set to True to use the default parser that converts to Agent Framework types.
+ Set to a callable to use a custom parser function.
+ Set to None to return the raw MCP prompt result.
 request_timeout: The default timeout in seconds for all requests.
 session: The session to use for the MCP connection.
 description: The description of the tool. 
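The new `parse_tool_results` and `parse_prompt_results` hooks accept `True` (default Agent Framework parsing), a callable (custom parsing of the raw MCP result), or `None` (return the raw result untouched). A sketch of the callable and `None` forms; the server command and tool name are hypothetical, and the top-level import path is assumed:

```python
# Sketch of the parse_tool_results hook: True -> framework parsing,
# a callable -> custom parsing, None -> raw MCP result. The command and
# tool name are hypothetical; the import path is assumed.
import asyncio

from mcp import types

from agent_framework import MCPStdioTool


def first_text(result: types.CallToolResult) -> str:
    # Custom parser: pull the first text block out of the raw result.
    for block in result.content:
        if isinstance(block, types.TextContent):
            return block.text
    return ""


async def main() -> None:
    async with MCPStdioTool(
        name="docs",
        command="uvx my-mcp-server",  # hypothetical server
        parse_tool_results=first_text,  # custom parser
        parse_prompt_results=None,  # raw GetPromptResult objects
    ) as docs:
        text = await docs.call_tool("search", query="agents")
        print(text)


asyncio.run(main())
```

If the underlying session has died between calls, the retry path above reconnects once via `connect(reset=True)` before surfacing a `ToolExecutionException`.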
@@ -856,7 +960,9 @@ def __init__(
 session=session,
 chat_client=chat_client,
 load_tools=load_tools,
+ parse_tool_results=parse_tool_results,
 load_prompts=load_prompts,
+ parse_prompt_results=parse_prompt_results,
 request_timeout=request_timeout,
 )
 self.command = command
@@ -897,7 +1003,6 @@ class MCPStreamableHTTPTool(MCPTool):
 mcp_tool = MCPStreamableHTTPTool(
 name="web-api",
 url="https://api.example.com/mcp",
- headers={"Authorization": "Bearer token"},
 description="Web API operations",
 )
@@ -913,27 +1018,27 @@ def __init__(
 url: str,
 *,
 load_tools: bool = True,
+ parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True,
 load_prompts: bool = True,
+ parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True,
 request_timeout: int | None = None,
 session: ClientSession | None = None,
 description: str | None = None,
 approval_mode: (Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None) = None,
 allowed_tools: Collection[str] | None = None,
- headers: dict[str, Any] | None = None,
- timeout: float | None = None,
- sse_read_timeout: float | None = None,
 terminate_on_close: bool | None = None,
 chat_client: "ChatClientProtocol | None" = None,
 additional_properties: dict[str, Any] | None = None,
+ http_client: httpx.AsyncClient | None = None,
 **kwargs: Any,
 ) -> None:
 """Initialize the MCP streamable HTTP tool.
 Note:
- The arguments are used to create a streamable HTTP client.
- See ``mcp.client.streamable_http.streamablehttp_client`` for more details.
- Any extra arguments passed to the constructor will be passed to the
- streamable HTTP client constructor.
+ The arguments are used to create a streamable HTTP client using the
+ new ``mcp.client.streamable_http.streamable_http_client`` API.
+ If an httpx.AsyncClient is provided via ``http_client``, it will be used directly.
+ Otherwise, the ``streamable_http_client`` API will create and manage a default client.
 Args:
 name: The name of the tool.
@@ -941,7 +1046,15 @@
 Keyword Args:
 load_tools: Whether to load tools from the MCP server.
+ parse_tool_results: How to parse tool results from the MCP server.
+ Set to True to use the default parser that converts to Agent Framework types.
+ Set to a callable to use a custom parser function.
+ Set to None to return the raw MCP tool result.
 load_prompts: Whether to load prompts from the MCP server.
+ parse_prompt_results: How to parse prompt results from the MCP server.
+ Set to True to use the default parser that converts to Agent Framework types.
+ Set to a callable to use a custom parser function.
+ Set to None to return the raw MCP prompt result.
 request_timeout: The default timeout in seconds for all requests.
 session: The session to use for the MCP connection.
 description: The description of the tool.
@@ -953,12 +1066,13 @@
 A tool should not be listed in both, if so, it will require approval.
 allowed_tools: A list of tools that are allowed to use this tool.
 additional_properties: Additional properties.
- headers: The headers to send with the request.
- timeout: The timeout for the request.
- sse_read_timeout: The timeout for reading from the SSE stream.
 terminate_on_close: Close the transport when the MCP client is terminated.
 chat_client: The chat client to use for sampling.
- kwargs: Any extra arguments to pass to the SSE client.
+ http_client: Optional httpx.AsyncClient to use. If not provided, the
+ ``streamable_http_client`` API will create and manage a default client. 
+ To configure headers, timeouts, or other HTTP client settings, create
+ and pass your own ``httpx.AsyncClient`` instance.
+ kwargs: Additional keyword arguments (accepted for backward compatibility but not used).
 """
 super().__init__(
 name=name,
@@ -969,15 +1083,14 @@
 session=session,
 chat_client=chat_client,
 load_tools=load_tools,
+ parse_tool_results=parse_tool_results,
 load_prompts=load_prompts,
+ parse_prompt_results=parse_prompt_results,
 request_timeout=request_timeout,
 )
 self.url = url
- self.headers = headers or {}
- self.timeout = timeout
- self.sse_read_timeout = sse_read_timeout
 self.terminate_on_close = terminate_on_close
- self._client_kwargs = kwargs
+ self._httpx_client: httpx.AsyncClient | None = http_client
 def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
 """Get an MCP streamable HTTP client.
@@ -985,20 +1098,12 @@
 Returns:
 An async context manager for the streamable HTTP client transport.
 """
- args: dict[str, Any] = {
- "url": self.url,
- }
- if self.headers:
- args["headers"] = self.headers
- if self.timeout is not None:
- args["timeout"] = self.timeout
- if self.sse_read_timeout is not None:
- args["sse_read_timeout"] = self.sse_read_timeout
- if self.terminate_on_close is not None:
- args["terminate_on_close"] = self.terminate_on_close
- if self._client_kwargs:
- args.update(self._client_kwargs)
- return streamablehttp_client(**args)
+ # Pass the http_client (which may be None) to streamable_http_client
+ return streamable_http_client(
+ url=self.url,
+ http_client=self._httpx_client,
+ terminate_on_close=self.terminate_on_close if self.terminate_on_close is not None else True,
+ )
 class MCPWebsocketTool(MCPTool):
@@ -1028,7 +1133,9 @@ def __init__(
 url: str,
 *,
 load_tools: bool = True,
+ parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True,
 load_prompts: bool = True,
+ parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True,
 request_timeout: int | None = None,
 session: ClientSession | None = None,
 description: str | None = None,
@@ -1052,7 +1159,15 @@
 Keyword Args:
 load_tools: Whether to load tools from the MCP server.
+ parse_tool_results: How to parse tool results from the MCP server.
+ Set to True to use the default parser that converts to Agent Framework types.
+ Set to a callable to use a custom parser function.
+ Set to None to return the raw MCP tool result.
 load_prompts: Whether to load prompts from the MCP server.
+ parse_prompt_results: How to parse prompt results from the MCP server.
+ Set to True to use the default parser that converts to Agent Framework types.
+ Set to a callable to use a custom parser function.
+ Set to None to return the raw MCP prompt result.
 request_timeout: The default timeout in seconds for all requests.
 session: The session to use for the MCP connection.
 description: The description of the tool. 
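With `headers`, `timeout`, and `sse_read_timeout` removed from `MCPStreamableHTTPTool`, that configuration now lives on a caller-owned `httpx.AsyncClient` passed via `http_client`. A migration sketch; the URL and token are placeholders and the top-level import path is assumed:

```python
# Migration sketch: headers/timeouts move onto a caller-owned
# httpx.AsyncClient now that the dedicated constructor arguments are gone.
# The URL and token are placeholders; the import path is assumed.
import asyncio

import httpx

from agent_framework import MCPStreamableHTTPTool


async def main() -> None:
    async with httpx.AsyncClient(
        headers={"Authorization": "Bearer <token>"},
        timeout=httpx.Timeout(30.0, read=300.0),
    ) as http_client:
        async with MCPStreamableHTTPTool(
            name="web-api",
            url="https://api.example.com/mcp",
            http_client=http_client,
        ) as web_api:
            print([f.name for f in web_api.functions])


asyncio.run(main())
```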
@@ -1076,7 +1191,9 @@ def __init__( session=session, chat_client=chat_client, load_tools=load_tools, + parse_tool_results=parse_tool_results, load_prompts=load_prompts, + parse_prompt_results=parse_prompt_results, request_timeout=request_timeout, ) self.url = url diff --git a/python/packages/core/agent_framework/_memory.py b/python/packages/core/agent_framework/_memory.py index a5b53fc39f..5e46b1749d 100644 --- a/python/packages/core/agent_framework/_memory.py +++ b/python/packages/core/agent_framework/_memory.py @@ -1,22 +1,16 @@ # Copyright (c) Microsoft. All rights reserved. -import asyncio import sys from abc import ABC, abstractmethod from collections.abc import MutableSequence, Sequence -from contextlib import AsyncExitStack from types import TracebackType -from typing import TYPE_CHECKING, Any, Final, cast +from typing import TYPE_CHECKING, Any, Final from ._types import ChatMessage if TYPE_CHECKING: from ._tools import ToolProtocol -if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover -else: - from typing_extensions import override # type: ignore[import] # pragma: no cover if sys.version_info >= (3, 11): from typing import Self # pragma: no cover else: @@ -24,7 +18,7 @@ # region Context -__all__ = ["AggregateContextProvider", "Context", "ContextProvider"] +__all__ = ["Context", "ContextProvider"] class Context: @@ -100,7 +94,7 @@ async def invoking(self, messages, **kwargs): # Use with a chat agent async with CustomContextProvider() as provider: - agent = ChatAgent(chat_client=client, name="assistant", context_providers=provider) + agent = ChatAgent(chat_client=client, name="assistant", context_provider=provider) """ # Default prompt to be used by all context providers when assembling memories/instructions @@ -183,130 +177,3 @@ async def __aexit__( exc_tb: The exception traceback if an exception occurred, None otherwise. """ pass - - -# region AggregateContextProvider - - -class AggregateContextProvider(ContextProvider): - """A ContextProvider that contains multiple context providers. - - It delegates events to multiple context providers and aggregates responses from those - events before returning. This allows you to combine multiple context providers into a - single provider. - - Note: - An AggregateContextProvider is created automatically when you pass a single context - provider or a sequence of context providers to the agent constructor. - - Examples: - .. code-block:: python - - from agent_framework import AggregateContextProvider, ChatAgent - - # Create multiple context providers - provider1 = CustomContextProvider1() - provider2 = CustomContextProvider2() - provider3 = CustomContextProvider3() - - # Pass them to the agent - AggregateContextProvider is created automatically - agent = ChatAgent(chat_client=client, name="assistant", context_providers=[provider1, provider2, provider3]) - - # Verify that an AggregateContextProvider was created - assert isinstance(agent.context_providers, AggregateContextProvider) - - # Add additional providers to the agent - provider4 = CustomContextProvider4() - agent.context_providers.add(provider4) - """ - - def __init__(self, context_providers: ContextProvider | Sequence[ContextProvider] | None = None) -> None: - """Initialize the AggregateContextProvider with context providers. - - Args: - context_providers: The context provider(s) to add. 
- """ - if isinstance(context_providers, ContextProvider): - self.providers = [context_providers] - else: - self.providers = cast(list[ContextProvider], context_providers) or [] - self._exit_stack: AsyncExitStack | None = None - - def add(self, context_provider: ContextProvider) -> None: - """Add a new context provider. - - Args: - context_provider: The context provider to add. - """ - self.providers.append(context_provider) - - @override - async def thread_created(self, thread_id: str | None = None) -> None: - await asyncio.gather(*[x.thread_created(thread_id) for x in self.providers]) - - @override - async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context: - contexts = await asyncio.gather(*[provider.invoking(messages, **kwargs) for provider in self.providers]) - instructions: str = "" - return_messages: list[ChatMessage] = [] - tools: list["ToolProtocol"] = [] - for ctx in contexts: - if ctx.instructions: - instructions += ctx.instructions - if ctx.messages: - return_messages.extend(ctx.messages) - if ctx.tools: - tools.extend(ctx.tools) - return Context(instructions=instructions, messages=return_messages, tools=tools) - - @override - async def invoked( - self, - request_messages: ChatMessage | Sequence[ChatMessage], - response_messages: ChatMessage | Sequence[ChatMessage] | None = None, - invoke_exception: Exception | None = None, - **kwargs: Any, - ) -> None: - await asyncio.gather(*[ - x.invoked( - request_messages=request_messages, - response_messages=response_messages, - invoke_exception=invoke_exception, - **kwargs, - ) - for x in self.providers - ]) - - @override - async def __aenter__(self) -> "Self": - """Enter the async context manager and set up all providers. - - Returns: - The AggregateContextProvider instance for chaining. - """ - self._exit_stack = AsyncExitStack() - await self._exit_stack.__aenter__() - - # Enter all context providers - for provider in self.providers: - await self._exit_stack.enter_async_context(provider) - - return self - - @override - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exit the async context manager and clean up all providers. - - Args: - exc_type: The exception type if an exception occurred, None otherwise. - exc_val: The exception value if an exception occurred, None otherwise. - exc_tb: The exception traceback if an exception occurred, None otherwise. 
- """ - if self._exit_stack is not None: - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) - self._exit_stack = None diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 4e36cb764a..0e26565c5a 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -2,13 +2,13 @@ import inspect from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence +from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableSequence, Sequence from enum import Enum from functools import update_wrapper -from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypedDict, TypeVar from ._serialization import SerializationMixin -from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, prepare_messages +from ._types import AgentResponse, AgentResponseUpdate, ChatMessage, prepare_messages from .exceptions import MiddlewareException if TYPE_CHECKING: @@ -18,12 +18,12 @@ from ._clients import ChatClientProtocol from ._threads import AgentThread from ._tools import AIFunction - from ._types import ChatOptions, ChatResponse, ChatResponseUpdate + from ._types import ChatResponse, ChatResponseUpdate __all__ = [ "AgentMiddleware", - "AgentMiddlewares", + "AgentMiddlewareTypes", "AgentRunContext", "ChatContext", "ChatMiddleware", @@ -38,7 +38,7 @@ ] TAgent = TypeVar("TAgent", bound="AgentProtocol") -TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol") +TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol[Any]") TContext = TypeVar("TContext") @@ -67,8 +67,8 @@ class AgentRunContext(SerializationMixin): metadata: Metadata dictionary for sharing data between agent middleware. result: Agent execution result. Can be observed after calling ``next()`` to see the actual execution result or can be set to override the execution result. - For non-streaming: should be AgentRunResponse. - For streaming: should be AsyncIterable[AgentRunResponseUpdate]. + For non-streaming: should be AgentResponse. + For streaming: should be AsyncIterable[AgentResponseUpdate]. terminate: A flag indicating whether to terminate execution after current middleware. When set to True, execution will stop as soon as control returns to framework. kwargs: Additional keyword arguments passed to the agent run method. @@ -105,7 +105,7 @@ def __init__( thread: "AgentThread | None" = None, is_streaming: bool = False, metadata: dict[str, Any] | None = None, - result: AgentRunResponse | AsyncIterable[AgentRunResponseUpdate] | None = None, + result: AgentResponse | AsyncIterable[AgentResponseUpdate] | None = None, terminate: bool = False, kwargs: dict[str, Any] | None = None, ) -> None: @@ -206,7 +206,7 @@ class ChatContext(SerializationMixin): Attributes: chat_client: The chat client being invoked. messages: The messages being sent to the chat client. - chat_options: The options for the chat request. + options: The options for the chat request as a dict. is_streaming: Whether this is a streaming invocation. metadata: Metadata dictionary for sharing data between chat middleware. result: Chat execution result. 
Can be observed after calling ``next()`` @@ -227,7 +227,7 @@ class TokenCounterMiddleware(ChatMiddleware): async def process(self, context: ChatContext, next): print(f"Chat client: {context.chat_client.__class__.__name__}") print(f"Messages: {len(context.messages)}") - print(f"Model: {context.chat_options.model_id}") + print(f"Model: {context.options.get('model_id')}") # Store metadata context.metadata["input_tokens"] = self.count_tokens(context.messages) @@ -246,7 +246,7 @@ def __init__( self, chat_client: "ChatClientProtocol", messages: "MutableSequence[ChatMessage]", - chat_options: "ChatOptions", + options: Mapping[str, Any] | None, is_streaming: bool = False, metadata: dict[str, Any] | None = None, result: "ChatResponse | AsyncIterable[ChatResponseUpdate] | None" = None, @@ -258,7 +258,7 @@ def __init__( Args: chat_client: The chat client being invoked. messages: The messages being sent to the chat client. - chat_options: The options for the chat request. + options: The options for the chat request as a dict. is_streaming: Whether this is a streaming invocation. metadata: Metadata dictionary for sharing data between chat middleware. result: Chat execution result. @@ -267,7 +267,7 @@ def __init__( """ self.chat_client = chat_client self.messages = messages - self.chat_options = chat_options + self.options = options self.is_streaming = is_streaming self.metadata = metadata if metadata is not None else {} self.result = result @@ -305,7 +305,7 @@ async def process(self, context: AgentRunContext, next): # Use with an agent - agent = ChatAgent(chat_client=client, name="assistant", middleware=RetryMiddleware()) + agent = ChatAgent(chat_client=client, name="assistant", middleware=[RetryMiddleware()]) """ @abstractmethod @@ -321,8 +321,8 @@ async def process( Use context.is_streaming to determine if this is a streaming call. Middleware can set context.result to override execution, or observe the actual execution result after calling next(). - For non-streaming: AgentRunResponse - For streaming: AsyncIterable[AgentRunResponseUpdate] + For non-streaming: AgentResponse + For streaming: AsyncIterable[AgentResponseUpdate] next: Function to call the next middleware or final agent execution. Does not return anything - all data flows through the context. 
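Because `ChatContext.options` is now a plain mapping (and may be `None`), middleware should read request settings with dictionary access rather than attribute access. A short sketch against the API shown in this hunk; the middleware class name is illustrative:

```python
from agent_framework import ChatContext, ChatMiddleware  # exports assumed


class ModelLoggingMiddleware(ChatMiddleware):
    """Logs the requested model before the chat client is invoked."""

    async def process(self, context: ChatContext, next) -> None:
        # `options` replaces the old ChatOptions object; guard against None.
        model_id = (context.options or {}).get("model_id", "<unset>")
        print(f"Calling {model_id} with {len(context.messages)} messages")
        await next(context)
```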
@@ -373,7 +373,7 @@ async def process(self, context: FunctionInvocationContext, next): # Use with an agent - agent = ChatAgent(chat_client=client, name="assistant", middleware=CachingMiddleware()) + agent = ChatAgent(chat_client=client, name="assistant", middleware=[CachingMiddleware()]) """ @abstractmethod @@ -432,7 +432,9 @@ async def process(self, context: ChatContext, next): # Use with an agent agent = ChatAgent( - chat_client=client, name="assistant", middleware=SystemPromptMiddleware("You are a helpful assistant.") + chat_client=client, + name="assistant", + middleware=[SystemPromptMiddleware("You are a helpful assistant.")], ) """ @@ -480,7 +482,7 @@ async def process( | ChatMiddleware | ChatMiddlewareCallable ) -AgentMiddlewares: TypeAlias = AgentMiddleware | AgentMiddlewareCallable +AgentMiddlewareTypes: TypeAlias = AgentMiddleware | AgentMiddlewareCallable # region Middleware type markers for decorators @@ -511,7 +513,7 @@ async def logging_middleware(context: AgentRunContext, next): # Use with an agent - agent = ChatAgent(chat_client=client, name="assistant", middleware=logging_middleware) + agent = ChatAgent(chat_client=client, name="assistant", middleware=[logging_middleware]) """ # Add marker attribute to identify this as agent middleware func._middleware_type: MiddlewareType = MiddlewareType.AGENT # type: ignore @@ -544,7 +546,7 @@ async def logging_middleware(context: FunctionInvocationContext, next): # Use with an agent - agent = ChatAgent(chat_client=client, name="assistant", middleware=logging_middleware) + agent = ChatAgent(chat_client=client, name="assistant", middleware=[logging_middleware]) """ # Add marker attribute to identify this as function middleware func._middleware_type: MiddlewareType = MiddlewareType.FUNCTION # type: ignore @@ -577,7 +579,7 @@ async def logging_middleware(context: ChatContext, next): # Use with an agent - agent = ChatAgent(chat_client=client, name="assistant", middleware=logging_middleware) + agent = ChatAgent(chat_client=client, name="assistant", middleware=[logging_middleware]) """ # Add marker attribute to identify this as chat middleware func._middleware_type: MiddlewareType = MiddlewareType.CHAT # type: ignore @@ -609,7 +611,7 @@ class BaseMiddlewarePipeline(ABC): def __init__(self) -> None: """Initialize the base middleware pipeline.""" - self._middlewares: list[Any] = [] + self._middleware: list[Any] = [] @abstractmethod def _register_middleware(self, middleware: Any) -> None: @@ -624,12 +626,12 @@ def _register_middleware(self, middleware: Any) -> None: @property def has_middlewares(self) -> bool: - """Check if there are any middlewares registered. + """Check if there are any middleware registered. Returns: - True if middlewares are registered, False otherwise. + True if middleware are registered, False otherwise. """ - return bool(self._middlewares) + return bool(self._middleware) def _register_middleware_with_wrapper( self, @@ -645,9 +647,9 @@ def _register_middleware_with_wrapper( expected_type: The expected middleware base class type. 
""" if isinstance(middleware, expected_type): - self._middlewares.append(middleware) + self._middleware.append(middleware) elif callable(middleware): - self._middlewares.append(MiddlewareWrapper(middleware)) # type: ignore[arg-type] + self._middleware.append(MiddlewareWrapper(middleware)) # type: ignore[arg-type] def _create_handler_chain( self, @@ -667,7 +669,7 @@ def _create_handler_chain( """ def create_next_handler(index: int) -> Callable[[Any], Awaitable[None]]: - if index >= len(self._middlewares): + if index >= len(self._middleware): async def final_wrapper(c: Any) -> None: # Execute actual handler and populate context for observability @@ -677,7 +679,7 @@ async def final_wrapper(c: Any) -> None: return final_wrapper - middleware = self._middlewares[index] + middleware = self._middleware[index] next_handler = create_next_handler(index + 1) async def current_handler(c: Any) -> None: @@ -705,7 +707,7 @@ def _create_streaming_handler_chain( """ def create_next_handler(index: int) -> Callable[[Any], Awaitable[None]]: - if index >= len(self._middlewares): + if index >= len(self._middleware): async def final_wrapper(c: Any) -> None: # If terminate was set, skip execution @@ -724,7 +726,7 @@ async def final_wrapper(c: Any) -> None: return final_wrapper - middleware = self._middlewares[index] + middleware = self._middleware[index] next_handler = create_next_handler(index + 1) async def current_handler(c: Any) -> None: @@ -745,20 +747,20 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): to process the agent invocation and pass control to the next middleware in the chain. """ - def __init__(self, middlewares: list[AgentMiddleware | AgentMiddlewareCallable] | None = None): + def __init__(self, middleware: Sequence[AgentMiddlewareTypes] | None = None): """Initialize the agent middleware pipeline. Args: - middlewares: The list of agent middleware to include in the pipeline. + middleware: The list of agent middleware to include in the pipeline. """ super().__init__() - self._middlewares: list[AgentMiddleware] = [] + self._middleware: list[AgentMiddleware] = [] - if middlewares: - for middleware in middlewares: - self._register_middleware(middleware) + if middleware: + for mdlware in middleware: + self._register_middleware(mdlware) - def _register_middleware(self, middleware: AgentMiddleware | AgentMiddlewareCallable) -> None: + def _register_middleware(self, middleware: AgentMiddlewareTypes) -> None: """Register an agent middleware item. Args: @@ -771,8 +773,8 @@ async def execute( agent: "AgentProtocol", messages: list[ChatMessage], context: AgentRunContext, - final_handler: Callable[[AgentRunContext], Awaitable[AgentRunResponse]], - ) -> AgentRunResponse | None: + final_handler: Callable[[AgentRunContext], Awaitable[AgentResponse]], + ) -> AgentResponse | None: """Execute the agent middleware pipeline for non-streaming. 
Args: @@ -789,19 +791,19 @@ async def execute( context.messages = messages context.is_streaming = False - if not self._middlewares: + if not self._middleware: return await final_handler(context) # Store the final result - result_container: dict[str, AgentRunResponse | None] = {"result": None} + result_container: dict[str, AgentResponse | None] = {"result": None} # Custom final handler that handles termination and result override - async def agent_final_handler(c: AgentRunContext) -> AgentRunResponse: + async def agent_final_handler(c: AgentRunContext) -> AgentResponse: # If terminate was set, return the result (which might be None) if c.terminate: - if c.result is not None and isinstance(c.result, AgentRunResponse): + if c.result is not None and isinstance(c.result, AgentResponse): return c.result - return AgentRunResponse() + return AgentResponse() # Execute actual handler and populate context for observability return await final_handler(c) @@ -809,13 +811,13 @@ async def agent_final_handler(c: AgentRunContext) -> AgentRunResponse: await first_handler(context) # Return the result from result container or overridden result - if context.result is not None and isinstance(context.result, AgentRunResponse): + if context.result is not None and isinstance(context.result, AgentResponse): return context.result - # If no result was set (next() not called), return empty AgentRunResponse + # If no result was set (next() not called), return empty AgentResponse response = result_container.get("result") if response is None: - return AgentRunResponse() + return AgentResponse() return response async def execute_stream( @@ -823,8 +825,8 @@ async def execute_stream( agent: "AgentProtocol", messages: list[ChatMessage], context: AgentRunContext, - final_handler: Callable[[AgentRunContext], AsyncIterable[AgentRunResponseUpdate]], - ) -> AsyncIterable[AgentRunResponseUpdate]: + final_handler: Callable[[AgentRunContext], AsyncIterable[AgentResponseUpdate]], + ) -> AsyncIterable[AgentResponseUpdate]: """Execute the agent middleware pipeline for streaming. Args: @@ -841,13 +843,13 @@ async def execute_stream( context.messages = messages context.is_streaming = True - if not self._middlewares: + if not self._middleware: async for update in final_handler(context): yield update return # Store the final result - result_container: dict[str, AsyncIterable[AgentRunResponseUpdate] | None] = {"result_stream": None} + result_container: dict[str, AsyncIterable[AgentResponseUpdate] | None] = {"result_stream": None} first_handler = self._create_streaming_handler_chain(final_handler, result_container, "result_stream") await first_handler(context) @@ -874,18 +876,18 @@ class FunctionMiddlewarePipeline(BaseMiddlewarePipeline): to process the function invocation and pass control to the next middleware in the chain. """ - def __init__(self, middlewares: list[FunctionMiddleware | FunctionMiddlewareCallable] | None = None): + def __init__(self, middleware: Sequence[FunctionMiddleware | FunctionMiddlewareCallable] | None = None): """Initialize the function middleware pipeline. Args: - middlewares: The list of function middleware to include in the pipeline. + middleware: The list of function middleware to include in the pipeline. 
""" super().__init__() - self._middlewares: list[FunctionMiddleware] = [] + self._middleware: list[FunctionMiddleware] = [] - if middlewares: - for middleware in middlewares: - self._register_middleware(middleware) + if middleware: + for mdlware in middleware: + self._register_middleware(mdlware) def _register_middleware(self, middleware: FunctionMiddleware | FunctionMiddlewareCallable) -> None: """Register a function middleware item. @@ -917,7 +919,7 @@ async def execute( context.function = function context.arguments = arguments - if not self._middlewares: + if not self._middleware: return await final_handler(context) # Store the final result @@ -947,18 +949,18 @@ class ChatMiddlewarePipeline(BaseMiddlewarePipeline): to process the chat request and pass control to the next middleware in the chain. """ - def __init__(self, middlewares: list[ChatMiddleware | ChatMiddlewareCallable] | None = None): + def __init__(self, middleware: Sequence[ChatMiddleware | ChatMiddlewareCallable] | None = None): """Initialize the chat middleware pipeline. Args: - middlewares: The list of chat middleware to include in the pipeline. + middleware: The list of chat middleware to include in the pipeline. """ super().__init__() - self._middlewares: list[ChatMiddleware] = [] + self._middleware: list[ChatMiddleware] = [] - if middlewares: - for middleware in middlewares: - self._register_middleware(middleware) + if middleware: + for mdlware in middleware: + self._register_middleware(mdlware) def _register_middleware(self, middleware: ChatMiddleware | ChatMiddlewareCallable) -> None: """Register a chat middleware item. @@ -972,7 +974,7 @@ async def execute( self, chat_client: "ChatClientProtocol", messages: "MutableSequence[ChatMessage]", - chat_options: "ChatOptions", + options: Mapping[str, Any] | None, context: ChatContext, final_handler: Callable[[ChatContext], Awaitable["ChatResponse"]], **kwargs: Any, @@ -982,7 +984,7 @@ async def execute( Args: chat_client: The chat client being invoked. messages: The messages being sent to the chat client. - chat_options: The options for the chat request. + options: The options for the chat request as a dict. context: The chat invocation context. final_handler: The final handler that performs the actual chat execution. **kwargs: Additional keyword arguments. @@ -993,9 +995,10 @@ async def execute( # Update context with chat client, messages, and options context.chat_client = chat_client context.messages = messages - context.chat_options = chat_options + if options: + context.options = options - if not self._middlewares: + if not self._middleware: return await final_handler(context) # Store the final result @@ -1021,7 +1024,7 @@ async def execute_stream( self, chat_client: "ChatClientProtocol", messages: "MutableSequence[ChatMessage]", - chat_options: "ChatOptions", + options: Mapping[str, Any] | None, context: ChatContext, final_handler: Callable[[ChatContext], AsyncIterable["ChatResponseUpdate"]], **kwargs: Any, @@ -1031,7 +1034,7 @@ async def execute_stream( Args: chat_client: The chat client being invoked. messages: The messages being sent to the chat client. - chat_options: The options for the chat request. + options: The options for the chat request as a dict. context: The chat invocation context. final_handler: The final handler that performs the actual streaming chat execution. **kwargs: Additional keyword arguments. 
@@ -1042,10 +1045,11 @@ async def execute_stream(
         # Update context with chat client, messages, and options
         context.chat_client = chat_client
         context.messages = messages
-        context.chat_options = chat_options
+        if options:
+            context.options = options
         context.is_streaming = True
 
-        if not self._middlewares:
+        if not self._middleware:
             async for update in final_handler(context):
                 yield update
             return
@@ -1182,8 +1186,8 @@ async def run_stream(self, messages, **kwargs):
     original_run_stream = agent_class.run_stream  # type: ignore[attr-defined]
 
     def _build_middleware_pipelines(
-        agent_level_middlewares: Middleware | list[Middleware] | None,
-        run_level_middlewares: Middleware | list[Middleware] | None = None,
+        agent_level_middlewares: Sequence[Middleware] | None,
+        run_level_middlewares: Sequence[Middleware] | None = None,
     ) -> tuple[AgentMiddlewarePipeline, FunctionMiddlewarePipeline, list[ChatMiddleware | ChatMiddlewareCallable]]:
         """Build fresh agent and function middleware pipelines from the provided middleware lists.
 
@@ -1191,7 +1195,7 @@ def _build_middleware_pipelines(
             agent_level_middlewares: Agent-level middleware (executed first)
             run_level_middlewares: Run-level middleware (executed after agent middleware)
         """
-        middleware = categorize_middleware(agent_level_middlewares, run_level_middlewares)
+        middleware = categorize_middleware(*(agent_level_middlewares or ()), *(run_level_middlewares or ()))
 
         return (
             AgentMiddlewarePipeline(middleware["agent"]),  # type: ignore[arg-type]
@@ -1204,9 +1208,9 @@ async def middleware_enabled_run(
         messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
         *,
         thread: Any = None,
-        middleware: Middleware | list[Middleware] | None = None,
+        middleware: Sequence[Middleware] | None = None,
         **kwargs: Any,
-    ) -> AgentRunResponse:
+    ) -> AgentResponse:
         """Middleware-enabled run method."""
         # Build fresh middleware pipelines from current middleware collection and run-level middleware
         agent_middleware = getattr(self, "middleware", None)
@@ -1233,7 +1237,7 @@ async def middleware_enabled_run(
                 kwargs=kwargs,
             )
 
-            async def _execute_handler(ctx: AgentRunContext) -> AgentRunResponse:
+            async def _execute_handler(ctx: AgentRunContext) -> AgentResponse:
                 return await original_run(self, ctx.messages, thread=thread, **ctx.kwargs)  # type: ignore
 
             result = await agent_pipeline.execute(
@@ -1243,7 +1247,7 @@ async def _execute_handler(ctx: AgentRunContext) -> AgentRunResponse:
                 _execute_handler,
             )
 
-            return result if result else AgentRunResponse()
+            return result if result else AgentResponse()
 
         # No middleware, execute directly
         return await original_run(self, normalized_messages, thread=thread, **kwargs)  # type: ignore[return-value]
@@ -1253,9 +1257,9 @@ def middleware_enabled_run_stream(
         messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
         *,
         thread: Any = None,
-        middleware: Middleware | list[Middleware] | None = None,
+        middleware: Sequence[Middleware] | None = None,
         **kwargs: Any,
-    ) -> AsyncIterable[AgentRunResponseUpdate]:
+    ) -> AsyncIterable[AgentResponseUpdate]:
         """Middleware-enabled run_stream method."""
         # Build fresh middleware pipelines from current middleware collection and run-level middleware
         agent_middleware = getattr(self, "middleware", None)
@@ -1281,11 +1285,11 @@ def middleware_enabled_run_stream(
                 kwargs=kwargs,
             )
 
-            async def _execute_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
+            async def _execute_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]:
                 async for update in original_run_stream(self, ctx.messages, thread=thread, **ctx.kwargs):  # type: ignore[misc]
                     yield update
 
-    async def _stream_generator() -> AsyncIterable[AgentRunResponseUpdate]:
+    async def _stream_generator() -> AsyncIterable[AgentResponseUpdate]:
         async for update in agent_pipeline.execute_stream(
             self,  # type: ignore[arg-type]
             normalized_messages,
@@ -1344,6 +1348,8 @@ async def get_streaming_response(self, messages, **kwargs):
     async def middleware_enabled_get_response(
         self: Any,
         messages: Any,
+        *,
+        options: Mapping[str, Any] | None = None,
         **kwargs: Any,
     ) -> Any:
         """Middleware-enabled get_response method."""
@@ -1364,30 +1370,35 @@ async def middleware_enabled_get_response(
 
         # If no chat middleware, use original method
         if not chat_middleware_list:
-            return await original_get_response(self, messages, **kwargs)
+            return await original_get_response(
+                self,
+                messages,
+                options=options,  # type: ignore[arg-type]
+                **kwargs,
+            )
 
         # Create pipeline and execute with middleware
-        from ._types import ChatOptions
-
-        # Extract chat_options or create default
-        chat_options = kwargs.pop("chat_options", ChatOptions())
-
         pipeline = ChatMiddlewarePipeline(chat_middleware_list)  # type: ignore[arg-type]
         context = ChatContext(
             chat_client=self,
             messages=prepare_messages(messages),
-            chat_options=chat_options,
+            options=options,
             is_streaming=False,
             kwargs=kwargs,
         )
 
         async def final_handler(ctx: ChatContext) -> Any:
-            return await original_get_response(self, list(ctx.messages), chat_options=ctx.chat_options, **ctx.kwargs)
+            return await original_get_response(
+                self,
+                list(ctx.messages),
+                options=ctx.options,  # type: ignore[arg-type]
+                **ctx.kwargs,
+            )
 
         return await pipeline.execute(
             chat_client=self,
             messages=context.messages,
-            chat_options=context.chat_options,
+            options=options,
             context=context,
             final_handler=final_handler,
             **kwargs,
@@ -1396,6 +1407,8 @@ async def final_handler(ctx: ChatContext) -> Any:
     def middleware_enabled_get_streaming_response(
         self: Any,
         messages: Any,
+        *,
+        options: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> Any:
         """Middleware-enabled get_streaming_response method."""
@@ -1416,34 +1429,37 @@ async def _stream_generator() -> Any:
 
         # If no chat middleware, use original method
         if not chat_middleware_list:
-            async for update in original_get_streaming_response(self, messages, **kwargs):
+            async for update in original_get_streaming_response(
+                self,
+                messages,
+                options=options,  # type: ignore[arg-type]
+                **kwargs,
+            ):
                 yield update
             return
 
         # Create pipeline and execute with middleware
-        from ._types import ChatOptions
-
-        # Extract chat_options or create default
-        chat_options = kwargs.pop("chat_options", ChatOptions())
-
         pipeline = ChatMiddlewarePipeline(chat_middleware_list)  # type: ignore[arg-type]
         context = ChatContext(
             chat_client=self,
             messages=prepare_messages(messages),
-            chat_options=chat_options,
+            options=options or {},
             is_streaming=True,
             kwargs=kwargs,
        )
 
         def final_handler(ctx: ChatContext) -> Any:
             return original_get_streaming_response(
-                self, list(ctx.messages), chat_options=ctx.chat_options, **ctx.kwargs
+                self,
+                list(ctx.messages),
+                options=ctx.options,  # type: ignore[arg-type]
+                **ctx.kwargs,
             )
 
         async for update in pipeline.execute_stream(
             chat_client=self,
             messages=context.messages,
-            chat_options=context.chat_options,
+            options=options or {},
             context=context,
             final_handler=final_handler,
             **kwargs,
@@ -1461,9 +1477,15 @@ def final_handler(ctx: ChatContext) -> Any:
     return chat_client_class
 
 
+class MiddlewareDict(TypedDict):
+    agent: list[AgentMiddleware | AgentMiddlewareCallable]
+    function: list[FunctionMiddleware | FunctionMiddlewareCallable]
+    chat: list[ChatMiddleware | ChatMiddlewareCallable]
+
+
 def categorize_middleware(
-    *middleware_sources: Any | list[Any] | None,
-) -> dict[str, list[Any]]:
+    *middleware_sources: Middleware | None,
+) -> MiddlewareDict:
     """Categorize middleware from multiple sources into agent, function, and chat types.
 
     Args:
@@ -1472,7 +1494,7 @@ def categorize_middleware(
     Returns:
         Dict with keys "agent", "function", "chat" containing lists of categorized middleware.
     """
-    result: dict[str, list[Any]] = {"agent": [], "function": [], "chat": []}
+    result: MiddlewareDict = {"agent": [], "function": [], "chat": []}
 
     # Merge all middleware sources into a single list
     all_middleware: list[Any] = []
@@ -1495,11 +1517,11 @@ def categorize_middleware(
         # Always call _determine_middleware_type to ensure proper validation
         middleware_type = _determine_middleware_type(middleware)
         if middleware_type == MiddlewareType.AGENT:
-            result["agent"].append(middleware)
+            result["agent"].append(middleware)  # type: ignore
         elif middleware_type == MiddlewareType.FUNCTION:
-            result["function"].append(middleware)
+            result["function"].append(middleware)  # type: ignore
         elif middleware_type == MiddlewareType.CHAT:
-            result["chat"].append(middleware)
+            result["chat"].append(middleware)  # type: ignore
         else:
             # Fallback to agent middleware for unknown types
             result["agent"].append(middleware)
@@ -1508,7 +1530,7 @@ def categorize_middleware(
 
 def create_function_middleware_pipeline(
-    *middleware_sources: list[Middleware] | None,
+    *middleware_sources: Middleware,
 ) -> FunctionMiddlewarePipeline | None:
     """Create a function middleware pipeline from multiple middleware sources.
 
@@ -1518,28 +1540,10 @@ def create_function_middleware_pipeline(
     Returns:
         A FunctionMiddlewarePipeline if function middleware is found, None otherwise.
     """
-    middleware = categorize_middleware(*middleware_sources)
-    function_middlewares = middleware["function"]
+    function_middlewares = categorize_middleware(*middleware_sources)["function"]
    return FunctionMiddlewarePipeline(function_middlewares) if function_middlewares else None  # type: ignore[arg-type]
 
 
-def _merge_and_filter_chat_middleware(
-    instance_middleware: Any | list[Any] | None,
-    call_middleware: Any | list[Any] | None,
-) -> list[ChatMiddleware | ChatMiddlewareCallable]:
-    """Merge instance-level and call-level middleware, filtering for chat middleware only.
-
-    Args:
-        instance_middleware: Middleware defined at the instance level.
-        call_middleware: Middleware provided at the call level.
-
-    Returns:
-        A merged list of chat middleware only.
- """ - middleware = categorize_middleware(instance_middleware, call_middleware) - return middleware["chat"] # type: ignore[return-value] - - def extract_and_merge_function_middleware( chat_client: Any, kwargs: dict[str, Any] ) -> "FunctionMiddlewarePipeline | None": @@ -1556,7 +1560,7 @@ def extract_and_merge_function_middleware( existing_pipeline: FunctionMiddlewarePipeline | None = kwargs.get("_function_middleware_pipeline") # Get middleware sources - client_middleware = getattr(chat_client, "middleware", None) if hasattr(chat_client, "middleware") else None + client_middleware = getattr(chat_client, "middleware", None) run_level_middleware = kwargs.get("middleware") # If we have an existing pipeline but no additional middleware sources, return it directly @@ -1564,15 +1568,15 @@ def extract_and_merge_function_middleware( return existing_pipeline # If we have an existing pipeline with additional middleware, we need to merge - # Extract existing pipeline middlewares if present - cast to list[Middleware] for type compatibility - existing_middlewares: list[Middleware] | None = list(existing_pipeline._middlewares) if existing_pipeline else None + # Extract existing pipeline middleware if present - cast to list[Middleware] for type compatibility + existing_middleware: list[Middleware] | None = list(existing_pipeline._middleware) if existing_pipeline else None # Create combined pipeline from all sources using existing helper combined_pipeline = create_function_middleware_pipeline( - client_middleware, run_level_middleware, existing_middlewares + *(client_middleware or ()), *(run_level_middleware or ()), *(existing_middleware or ()) ) - # If we have an existing pipeline but combined is None (no new middlewares), return existing + # If we have an existing pipeline but combined is None (no new middleware), return existing if existing_pipeline and combined_pipeline is None: return existing_pipeline diff --git a/python/packages/core/agent_framework/_serialization.py b/python/packages/core/agent_framework/_serialization.py index cf28df2f4f..8aa9b6adcf 100644 --- a/python/packages/core/agent_framework/_serialization.py +++ b/python/packages/core/agent_framework/_serialization.py @@ -259,7 +259,7 @@ def __init__(self, **kwargs): agent = CustomAgent( - context_providers=[...], + context_provider=[...], middleware=[...] ) diff --git a/python/packages/core/agent_framework/_threads.py b/python/packages/core/agent_framework/_threads.py index 92469a78d5..e44c362324 100644 --- a/python/packages/core/agent_framework/_threads.py +++ b/python/packages/core/agent_framework/_threads.py @@ -3,7 +3,7 @@ from collections.abc import MutableMapping, Sequence from typing import Any, Protocol, TypeVar -from ._memory import AggregateContextProvider +from ._memory import ContextProvider from ._serialization import SerializationMixin from ._types import ChatMessage from .exceptions import AgentThreadException @@ -327,7 +327,7 @@ def __init__( *, service_thread_id: str | None = None, message_store: ChatMessageStoreProtocol | None = None, - context_provider: AggregateContextProvider | None = None, + context_provider: ContextProvider | None = None, ) -> None: """Initialize an AgentThread, do not use this method manually, always use: ``agent.get_new_thread()``. 
diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 24481c3b3b..a0d0a13dc2 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -4,7 +4,15 @@ import inspect import json import sys -from collections.abc import AsyncIterable, Awaitable, Callable, Collection, Mapping, MutableMapping, Sequence +from collections.abc import ( + AsyncIterable, + Awaitable, + Callable, + Collection, + Mapping, + MutableMapping, + Sequence, +) from functools import wraps from time import perf_counter, time_ns from typing import ( @@ -18,6 +26,7 @@ Protocol, TypedDict, TypeVar, + Union, cast, get_args, get_origin, @@ -50,21 +59,12 @@ FunctionCallContent, ) -if sys.version_info >= (3, 12): - from typing import ( - TypedDict, # pragma: no cover - override, # type: ignore # pragma: no cover - ) -else: - from typing_extensions import ( - TypedDict, # pragma: no cover - override, # type: ignore[import] # pragma: no cover - ) +from typing import overload -if sys.version_info >= (3, 11): - from typing import overload # pragma: no cover +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover else: - from typing_extensions import overload # pragma: no cover + from typing_extensions import override # type: ignore[import] # pragma: no cover logger = get_logger() @@ -88,7 +88,7 @@ FUNCTION_INVOKING_CHAT_CLIENT_MARKER: Final[str] = "__function_invoking_chat_client__" DEFAULT_MAX_ITERATIONS: Final[int] = 40 DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3 -TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol") +TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol[Any]") # region Helpers ArgsT = TypeVar("ArgsT", bound=BaseModel) @@ -121,7 +121,13 @@ def _parse_inputs( if inputs is None: return [] - from ._types import BaseContent, DataContent, HostedFileContent, HostedVectorStoreContent, UriContent + from ._types import ( + BaseContent, + DataContent, + HostedFileContent, + HostedVectorStoreContent, + UriContent, + ) parsed_inputs: list["Contents"] = [] if not isinstance(inputs, list): @@ -820,8 +826,10 @@ async def invoke( else: logger.info(f"Function {self.name} succeeded.") if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] + from ._types import prepare_function_call_results + try: - json_result = json.dumps(result) + json_result = prepare_function_call_results(result) except (TypeError, OverflowError): span.set_attribute(OtelAttr.TOOL_RESULT, "") logger.debug("Function result: ") @@ -1008,6 +1016,27 @@ def _build_pydantic_model_from_json_schema( if not properties: return create_model(f"{model_name}_input") + def _resolve_literal_type(prop_details: dict[str, Any]) -> type | None: + """Check if property should be a Literal type (const or enum). + + Args: + prop_details: The JSON Schema property details + + Returns: + Literal type if const or enum is present, None otherwise + """ + # const → Literal["value"] + if "const" in prop_details: + return Literal[prop_details["const"]] # type: ignore + + # enum → Literal["a", "b", ...] + if "enum" in prop_details and isinstance(prop_details["enum"], list): + enum_values = prop_details["enum"] + if enum_values: + return Literal[tuple(enum_values)] # type: ignore + + return None + def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: """Resolve JSON Schema type to Python type, handling $ref, nested objects, and typed arrays. 
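The effect of the `_resolve_literal_type` helper added above can be reproduced with pydantic directly. A stand-alone illustration of the const/enum-to-`Literal` mapping (simplified relative to the framework helper):

```python
from typing import Any, Literal

from pydantic import create_model


def literal_from_schema(prop: dict[str, Any]) -> Any:
    """Map a JSON Schema const/enum property to a Literal type, or None."""
    if "const" in prop:
        return Literal[prop["const"]]  # {"const": "c"} -> Literal["c"]
    if isinstance(prop.get("enum"), list) and prop["enum"]:
        return Literal[tuple(prop["enum"])]  # {"enum": [...]} -> Literal[...]
    return None


Unit = literal_from_schema({"enum": ["celsius", "fahrenheit"]})
WeatherArgs = create_model("WeatherArgs", unit=(Unit, ...))

WeatherArgs(unit="celsius")  # validates
# WeatherArgs(unit="kelvin") would raise a ValidationError
```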
@@ -1018,6 +1047,31 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: Returns: Python type annotation (could be int, str, list[str], or a nested Pydantic model) """ + # Handle oneOf + discriminator (polymorphic objects) + if "oneOf" in prop_details and "discriminator" in prop_details: + discriminator = prop_details["discriminator"] + disc_field = discriminator.get("propertyName") + + variants = [] + for variant in prop_details["oneOf"]: + if "$ref" in variant: + ref = variant["$ref"] + if ref.startswith("#/$defs/"): + def_name = ref.split("/")[-1] + resolved = definitions.get(def_name) + if resolved: + variant_model = _resolve_type( + resolved, + parent_name=f"{parent_name}_{def_name}", + ) + variants.append(variant_model) + + if variants and disc_field: + return Annotated[ + Union[tuple(variants)], # type: ignore + Field(discriminator=disc_field), + ] + # Handle $ref by resolving the reference if "$ref" in prop_details: ref = prop_details["$ref"] @@ -1068,9 +1122,15 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: else nested_prop_details ) - nested_python_type = _resolve_type( - nested_prop_details, f"{nested_model_name}_{nested_prop_name}" - ) + # Check for Literal types first (const/enum) + literal_type = _resolve_literal_type(nested_prop_details) + if literal_type is not None: + nested_python_type = literal_type + else: + nested_python_type = _resolve_type( + nested_prop_details, + f"{nested_model_name}_{nested_prop_name}", + ) nested_description = nested_prop_details.get("description", "") # Build field kwargs for nested property @@ -1107,7 +1167,12 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: for prop_name, prop_details in properties.items(): prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details - python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}") + # Check for Literal types first (const/enum) + literal_type = _resolve_literal_type(prop_details) + if literal_type is not None: + python_type = literal_type + else: + python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}") description = prop_details.get("description", "") # Build field kwargs (description, etc.) @@ -1690,19 +1755,19 @@ def _update_conversation_id(kwargs: dict[str, Any], conversation_id: str | None) kwargs["conversation_id"] = conversation_id -def _extract_tools(kwargs: dict[str, Any]) -> Any: - """Extract tools from kwargs or chat_options. +def _extract_tools(options: dict[str, Any] | None) -> Any: + """Extract tools from options dict. + + Args: + options: The options dict containing chat options. 
Returns: ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] | None """ - from ._types import ChatOptions - - tools = kwargs.get("tools") - if not tools and (chat_options := kwargs.get("chat_options")) and isinstance(chat_options, ChatOptions): - tools = chat_options.tools - return tools + if options and isinstance(options, dict): + return options.get("tools") + return None def _collect_approval_responses( @@ -1795,6 +1860,8 @@ def decorator( async def function_invocation_wrapper( self: "ChatClientProtocol", messages: "str | ChatMessage | list[str] | list[ChatMessage]", + *, + options: dict[str, Any] | None = None, **kwargs: Any, ) -> "ChatResponse": from ._middleware import extract_and_merge_function_middleware @@ -1823,7 +1890,7 @@ async def function_invocation_wrapper( for attempt_idx in range(config.max_iterations if config.enabled else 0): fcc_todo = _collect_approval_responses(prepped_messages) if fcc_todo: - tools = _extract_tools(kwargs) + tools = _extract_tools(options) # Only execute APPROVED function calls, not rejected ones approved_responses = [resp for resp in fcc_todo.values() if resp.approved] approved_function_results: list[Contents] = [] @@ -1855,8 +1922,9 @@ async def function_invocation_wrapper( _replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results) # Filter out internal framework kwargs before passing to clients. - filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"} - response = await func(self, messages=prepped_messages, **filtered_kwargs) + # Also exclude tools and tool_choice since they are now in options dict. + filtered_kwargs = {k: v for k, v in kwargs.items() if k not in ("thread", "tools", "tool_choice")} + response = await func(self, messages=prepped_messages, options=options, **filtered_kwargs) # if there are function calls, we will handle them first function_results = { it.call_id for it in response.messages[0].contents if isinstance(it, FunctionResultContent) @@ -1872,7 +1940,7 @@ async def function_invocation_wrapper( prepped_messages = [] # we load the tools here, since middleware might have changed them compared to before calling func. - tools = _extract_tools(kwargs) + tools = _extract_tools(options) if function_calls and tools: # Use the stored middleware pipeline instead of extracting from kwargs # because kwargs may have been modified by the underlying function @@ -1955,11 +2023,13 @@ async def function_invocation_wrapper( return response # Failsafe: give up on tools, ask model for plain answer - kwargs["tool_choice"] = "none" + if options is None: + options = {} + options["tool_choice"] = "none" # Filter out internal framework kwargs before passing to clients. 
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"} - response = await func(self, messages=prepped_messages, **filtered_kwargs) + response = await func(self, messages=prepped_messages, options=options, **filtered_kwargs) if fcc_messages: for msg in reversed(fcc_messages): response.messages.insert(0, msg) @@ -1991,6 +2061,8 @@ def decorator( async def streaming_function_invocation_wrapper( self: "ChatClientProtocol", messages: "str | ChatMessage | list[str] | list[ChatMessage]", + *, + options: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable["ChatResponseUpdate"]: """Wrap the inner get streaming response method to handle tool calls.""" @@ -2019,7 +2091,7 @@ async def streaming_function_invocation_wrapper( for attempt_idx in range(config.max_iterations if config.enabled else 0): fcc_todo = _collect_approval_responses(prepped_messages) if fcc_todo: - tools = _extract_tools(kwargs) + tools = _extract_tools(options) # Only execute APPROVED function calls, not rejected ones approved_responses = [resp for resp in fcc_todo.values() if resp.approved] approved_function_results: list[Contents] = [] @@ -2045,7 +2117,7 @@ async def streaming_function_invocation_wrapper( all_updates: list["ChatResponseUpdate"] = [] # Filter out internal framework kwargs before passing to clients. filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"} - async for update in func(self, messages=prepped_messages, **filtered_kwargs): + async for update in func(self, messages=prepped_messages, options=options, **filtered_kwargs): all_updates.append(update) yield update @@ -2083,7 +2155,7 @@ async def streaming_function_invocation_wrapper( prepped_messages = [] # we load the tools here, since middleware might have changed them compared to before calling func. - tools = _extract_tools(kwargs) + tools = _extract_tools(options) if function_calls and tools: # Use the stored middleware pipeline instead of extracting from kwargs # because kwargs may have been modified by the underlying function @@ -2162,10 +2234,12 @@ async def streaming_function_invocation_wrapper( return # Failsafe: give up on tools, ask model for plain answer - kwargs["tool_choice"] = "none" + if options is None: + options = {} + options["tool_choice"] = "none" # Filter out internal framework kwargs before passing to clients. 
                filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
-                async for update in func(self, messages=prepped_messages, **filtered_kwargs):
+                async for update in func(self, messages=prepped_messages, options=options, **filtered_kwargs):
                     yield update
 
     return streaming_function_invocation_wrapper
diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py
index ebe3d23e6f..4ee640304e 100644
--- a/python/packages/core/agent_framework/_types.py
+++ b/python/packages/core/agent_framework/_types.py
@@ -13,7 +13,7 @@
     Sequence,
 )
 from copy import deepcopy
-from typing import Any, ClassVar, Literal, TypeVar, cast, overload
+from typing import Any, ClassVar, Literal, TypedDict, TypeVar, cast, overload
 
 from pydantic import BaseModel, ValidationError
 
@@ -29,13 +29,14 @@
 
 __all__ = [
-    "AgentRunResponse",
-    "AgentRunResponseUpdate",
+    "AgentResponse",
+    "AgentResponseUpdate",
     "AnnotatedRegions",
     "Annotations",
     "BaseAnnotation",
     "BaseContent",
     "ChatMessage",
+    "ChatOptions",  # Backward compatibility alias
     "ChatOptions",
     "ChatResponse",
     "ChatResponseUpdate",
@@ -64,7 +65,13 @@
     "UriContent",
     "UsageContent",
     "UsageDetails",
+    "merge_chat_options",
+    "normalize_tools",
     "prepare_function_call_results",
+    "prepend_instructions_to_messages",
+    "validate_chat_options",
+    "validate_tool_mode",
+    "validate_tools",
 ]
 
 logger = get_logger("agent_framework")
@@ -182,7 +189,7 @@ def _parse_content_list(contents_data: Sequence[Any]) -> list["Contents"]:
 TEmbedding = TypeVar("TEmbedding")
 TChatResponse = TypeVar("TChatResponse", bound="ChatResponse")
 TToolMode = TypeVar("TToolMode", bound="ToolMode")
-TAgentRunResponse = TypeVar("TAgentRunResponse", bound="AgentRunResponse")
+TAgentRunResponse = TypeVar("TAgentRunResponse", bound="AgentResponse")
 
 CreatedAtT = str  # Use a datetimeoffset type? Or a more specific type like datetime.datetime?
 
@@ -2457,7 +2464,7 @@ def text(self) -> str:
 
 def prepare_messages(
-    messages: str | ChatMessage | list[str] | list[ChatMessage], system_instructions: str | list[str] | None = None
+    messages: str | ChatMessage | Sequence[str | ChatMessage], system_instructions: str | Sequence[str] | None = None
 ) -> list[ChatMessage]:
     """Convert various message input formats into a list of ChatMessage objects.
 
@@ -2488,11 +2495,54 @@
     return return_messages
 
 
+def prepend_instructions_to_messages(
+    messages: list[ChatMessage],
+    instructions: str | Sequence[str] | None,
+    role: Role | Literal["system", "user", "assistant"] = "system",
+) -> list[ChatMessage]:
+    """Prepend instructions to a list of messages with a specified role.
+
+    This is a helper function for chat clients that need to add instructions
+    from options as messages. Different providers support different roles for
+    instructions (e.g., OpenAI uses "system", some providers might use "user").
+
+    Args:
+        messages: The existing list of ChatMessage objects.
+        instructions: The instructions to prepend. Can be a single string or a sequence of strings.
+        role: The role to use for the instruction messages. Defaults to "system".
+
+    Returns:
+        A new list with instruction messages prepended.
+
+    Examples:
+        .. code-block:: python
+
+            from agent_framework import prepend_instructions_to_messages, ChatMessage
+
+            messages = [ChatMessage(role="user", text="Hello")]
+            instructions = "You are a helpful assistant"
+
+            # Prepend as system message (default)
+            messages_with_instructions = prepend_instructions_to_messages(messages, instructions)
+
+            # Or use a different role
+            messages_with_user_instructions = prepend_instructions_to_messages(messages, instructions, role="user")
+    """
+    if instructions is None:
+        return messages
+
+    if isinstance(instructions, str):
+        instructions = [instructions]
+
+    instruction_messages = [ChatMessage(role=role, text=instr) for instr in instructions]
+    return [*instruction_messages, *messages]
+
+
 # region ChatResponse
 
 
 def _process_update(
-    response: "ChatResponse | AgentRunResponse", update: "ChatResponseUpdate | AgentRunResponseUpdate"
+    response: "ChatResponse | AgentResponse", update: "ChatResponseUpdate | AgentResponseUpdate"
 ) -> None:
     """Processes a single update and modifies the response in place."""
     is_new_message = False
@@ -2596,7 +2646,7 @@ def _coalesce_text_content(
     contents.extend(coalesced_contents)
 
 
-def _finalize_response(response: "ChatResponse | AgentRunResponse") -> None:
+def _finalize_response(response: "ChatResponse | AgentResponse") -> None:
     """Finalizes the response by performing any necessary post-processing."""
     for msg in response.messages:
         _coalesce_text_content(msg.contents, TextContent)
@@ -2845,7 +2895,7 @@ async def from_chat_response_generator(
         cls: type[TChatResponse],
         updates: AsyncIterable["ChatResponseUpdate"],
         *,
-        output_format_type: type[BaseModel] | None = None,
+        output_format_type: type[BaseModel] | Mapping[str, Any] | None = None,
     ) -> TChatResponse:
         """Joins multiple updates into a single ChatResponse.
 
@@ -2870,7 +2920,7 @@
         async for update in updates:
             _process_update(msg, update)
         _finalize_response(msg)
-        if output_format_type:
+        if output_format_type and isinstance(output_format_type, type) and issubclass(output_format_type, BaseModel):
             msg.try_parse_value(output_format_type)
         return msg
 
@@ -2884,7 +2934,7 @@ def __str__(self) -> str:
 
     def try_parse_value(self, output_format_type: type[BaseModel]) -> None:
         """If there is a value, does nothing, otherwise tries to parse the text into the value."""
-        if self.value is None:
+        if self.value is None and isinstance(output_format_type, type) and issubclass(output_format_type, BaseModel):
             try:
                 self.value = output_format_type.model_validate_json(self.text)  # type: ignore[reportUnknownMemberType]
             except ValidationError as ex:
@@ -3018,10 +3068,10 @@ def __str__(self) -> str:
         return self.text
 
 
-# region AgentRunResponse
+# region AgentResponse
 
 
-class AgentRunResponse(SerializationMixin):
+class AgentResponse(SerializationMixin):
     """Represents the response to an Agent run request.
 
     Provides one or more response messages and metadata about the response.
 
@@ -3031,11 +3081,11 @@ class AgentRunResponse(SerializationMixin):
     Examples:
         .. code-block:: python
 
-            from agent_framework import AgentRunResponse, ChatMessage
+            from agent_framework import AgentResponse, ChatMessage
 
             # Create agent response
             msg = ChatMessage(role="assistant", text="Task completed successfully.")
-            response = AgentRunResponse(messages=[msg], response_id="run_123")
+            response = AgentResponse(messages=[msg], response_id="run_123")
 
             print(response.text)  # "Task completed successfully."
             # Access user input requests
@@ -3043,20 +3093,20 @@ class AgentRunResponse(SerializationMixin):
             print(len(user_requests))  # 0
 
             # Combine streaming updates
-            updates = [...]  # List of AgentRunResponseUpdate objects
-            response = AgentRunResponse.from_agent_run_response_updates(updates)
+            updates = [...]  # List of AgentResponseUpdate objects
+            response = AgentResponse.from_agent_run_response_updates(updates)
 
             # Serialization - to_dict and from_dict
             response_dict = response.to_dict()
-            # {'type': 'agent_run_response', 'messages': [...], 'response_id': 'run_123',
+            # {'type': 'agent_response', 'messages': [...], 'response_id': 'run_123',
             #  'additional_properties': {}}
-            restored_response = AgentRunResponse.from_dict(response_dict)
+            restored_response = AgentResponse.from_dict(response_dict)
             print(restored_response.response_id)  # "run_123"
 
             # Serialization - to_json and from_json
             response_json = response.to_json()
-            # '{"type": "agent_run_response", "messages": [...], "response_id": "run_123", ...}'
-            restored_from_json = AgentRunResponse.from_json(response_json)
+            # '{"type": "agent_response", "messages": [...], "response_id": "run_123", ...}'
+            restored_from_json = AgentResponse.from_json(response_json)
             print(restored_from_json.text)  # "Task completed successfully."
     """
 
@@ -3078,7 +3128,7 @@ def __init__(
         additional_properties: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
-        """Initialize an AgentRunResponse.
+        """Initialize an AgentResponse.
 
         Keyword Args:
             messages: The list of chat messages in the response.
@@ -3136,14 +3186,14 @@ def user_input_requests(self) -> list[UserInputRequestContents]:
     @classmethod
     def from_agent_run_response_updates(
         cls: type[TAgentRunResponse],
-        updates: Sequence["AgentRunResponseUpdate"],
+        updates: Sequence["AgentResponseUpdate"],
         *,
         output_format_type: type[BaseModel] | None = None,
     ) -> TAgentRunResponse:
-        """Joins multiple updates into a single AgentRunResponse.
+        """Joins multiple updates into a single AgentResponse.
 
         Args:
-            updates: A sequence of AgentRunResponseUpdate objects to combine.
+            updates: A sequence of AgentResponseUpdate objects to combine.
 
         Keyword Args:
             output_format_type: Optional Pydantic model type to parse the response text into structured data.
@@ -3159,14 +3209,14 @@
     @classmethod
     async def from_agent_response_generator(
         cls: type[TAgentRunResponse],
-        updates: AsyncIterable["AgentRunResponseUpdate"],
+        updates: AsyncIterable["AgentResponseUpdate"],
         *,
         output_format_type: type[BaseModel] | None = None,
     ) -> TAgentRunResponse:
-        """Joins multiple updates into a single AgentRunResponse.
+        """Joins multiple updates into a single AgentResponse.
 
         Args:
-            updates: An async iterable of AgentRunResponseUpdate objects to combine.
+            updates: An async iterable of AgentResponseUpdate objects to combine.
 
         Keyword Args:
             output_format_type: Optional Pydantic model type to parse the response text into structured data
@@ -3191,19 +3241,19 @@ def try_parse_value(self, output_format_type: type[BaseModel]) -> None:
             logger.debug("Failed to parse value from agent run response text: %s", ex)
 
 
-# region AgentRunResponseUpdate
+# region AgentResponseUpdate
 
 
-class AgentRunResponseUpdate(SerializationMixin):
+class AgentResponseUpdate(SerializationMixin):
     """Represents a single streaming response chunk from an Agent.
 
     Examples:
         .. code-block:: python
 
-            from agent_framework import AgentRunResponseUpdate, TextContent
+            from agent_framework import AgentResponseUpdate, TextContent
 
             # Create an agent run update
-            update = AgentRunResponseUpdate(
+            update = AgentResponseUpdate(
                 contents=[TextContent(text="Processing...")],
                 role="assistant",
                 response_id="run_123",
             )
 
             print(update.text)  # "Processing..."
 
             # Serialization - to_dict and from_dict
             update_dict = update.to_dict()
-            # {'type': 'agent_run_response_update', 'contents': [{'type': 'text', 'text': 'Processing...'}],
+            # {'type': 'agent_response_update', 'contents': [{'type': 'text', 'text': 'Processing...'}],
             #  'role': {'type': 'role', 'value': 'assistant'}, 'response_id': 'run_123'}
-            restored_update = AgentRunResponseUpdate.from_dict(update_dict)
+            restored_update = AgentResponseUpdate.from_dict(update_dict)
             print(restored_update.response_id)  # "run_123"
 
             # Serialization - to_json and from_json
             update_json = update.to_json()
-            # '{"type": "agent_run_response_update", "contents": [{"type": "text", "text": "Processing..."}], ...}'
-            restored_from_json = AgentRunResponseUpdate.from_json(update_json)
+            # '{"type": "agent_response_update", "contents": [{"type": "text", "text": "Processing..."}], ...}'
+            restored_from_json = AgentResponseUpdate.from_json(update_json)
             print(restored_from_json.text)  # "Processing..."
     """
 
@@ -3243,7 +3293,7 @@ def __init__(
         raw_representation: Any | None = None,
         **kwargs: Any,
     ) -> None:
-        """Initialize an AgentRunResponseUpdate.
+        """Initialize an AgentResponseUpdate.
 
         Keyword Args:
             contents: Optional list of BaseContent items or dicts to include in the update.
@@ -3301,372 +3351,369 @@ def __str__(self) -> str:
 
 # region ChatOptions
 
 
-class ToolMode(SerializationMixin, metaclass=EnumLike):
-    """Defines if and how tools are used in a chat request.
+class ToolMode(TypedDict, total=False):
+    """Tool choice mode for the chat options.
+
+    Fields:
+        mode: One of "auto", "required", or "none".
+        required_function_name: Optional function name when `mode == "required"`.
+    """
+
+    mode: Literal["auto", "required", "none"]
+    required_function_name: str
+
+
+# region TypedDict-based Chat Options
+
+
+class ChatOptions(TypedDict, total=False):
+    """Common request settings for AI services as a TypedDict.
+
+    All fields are optional (total=False) to allow partial specification.
+    Provider-specific TypedDicts extend this with additional options.
+
+    These options represent the common denominator across chat providers.
+    Individual implementations may raise errors for unsupported options.
 
     Examples:
         .. code-block:: python
 
-            from agent_framework import ToolMode
-
-            # Use predefined tool modes
-            auto_mode = ToolMode.AUTO  # Model decides when to use tools
-            required_mode = ToolMode.REQUIRED_ANY  # Model must use a tool
-            none_mode = ToolMode.NONE  # No tools allowed
+            from agent_framework import ChatOptions, ToolMode
 
-            # Require a specific function
-            specific_mode = ToolMode.REQUIRED(function_name="get_weather")
-            print(specific_mode.required_function_name)  # "get_weather"
+            # Type-safe options
+            options: ChatOptions = {
+                "temperature": 0.7,
+                "max_tokens": 1000,
+                "model_id": "gpt-4",
+            }
 
-            # Compare modes
-            print(auto_mode == "auto")  # True
+            # With tools
+            options_with_tools: ChatOptions = {
+                "model_id": "gpt-4",
+                "tool_choice": "auto",
+                "temperature": 0.7,
+            }
+
+            # Used with Unpack for function signatures
+            # async def get_response(self, **options: Unpack[ChatOptions]) -> ChatResponse:
     """
 
-    # Constants configuration for EnumLike metaclass
-    _constants: ClassVar[dict[str, tuple[str, ...]]] = {
-        "AUTO": ("auto",),
-        "REQUIRED_ANY": ("required",),
-        "NONE": ("none",),
-    }
+    # Model selection
+    model_id: str
 
-    # Type annotations for constants
-    AUTO: "ToolMode"
-    REQUIRED_ANY: "ToolMode"
-    NONE: "ToolMode"
+    # Generation parameters
+    temperature: float
+    top_p: float
+    max_tokens: int
+    stop: str | Sequence[str]
+    seed: int
+    logit_bias: dict[str | int, float]
 
-    def __init__(
-        self,
-        mode: Literal["auto", "required", "none"] = "none",
-        *,
-        required_function_name: str | None = None,
-    ) -> None:
-        """Initialize ToolMode.
+    # Penalty parameters
+    frequency_penalty: float
+    presence_penalty: float
 
-        Args:
-            mode: The tool mode - "auto", "required", or "none".
+    # Tool configuration (forward reference to avoid circular import)
+    tools: "ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] | None"  # noqa: E501
+    tool_choice: ToolMode | Literal["auto", "required", "none"]
+    allow_multiple_tool_calls: bool
 
-        Keyword Args:
-            required_function_name: Optional function name for required mode.
- """ - self.mode = mode - self.required_function_name = required_function_name + # Response configuration + response_format: type[BaseModel] | dict[str, Any] - @classmethod - def REQUIRED(cls, function_name: str | None = None) -> "ToolMode": - """Returns a ToolMode that requires the specified function to be called.""" - return cls(mode="required", required_function_name=function_name) + # Metadata + metadata: dict[str, Any] + user: str + store: bool + conversation_id: str - def __eq__(self, other: object) -> bool: - """Checks equality with another ToolMode or string.""" - if isinstance(other, str): - return self.mode == other - if isinstance(other, ToolMode): - return self.mode == other.mode and self.required_function_name == other.required_function_name - return False + # System/instructions + instructions: str - def __hash__(self) -> int: - """Return hash of the ToolMode for use in sets and dicts.""" - return hash((self.mode, self.required_function_name)) - def serialize_model(self) -> str: - """Serializes the ToolMode to just the mode string.""" - return self.mode +# region Chat Options Utility Functions - def __str__(self) -> str: - """Returns the string representation of the mode.""" - return self.mode - def __repr__(self) -> str: - """Returns the string representation of the ToolMode.""" - if self.required_function_name: - return f"ToolMode(mode={self.mode!r}, required_function_name={self.required_function_name!r})" - return f"ToolMode(mode={self.mode!r})" +async def validate_chat_options(options: dict[str, Any]) -> dict[str, Any]: + """Validate and normalize chat options dictionary. + Validates numeric constraints and converts types as needed. -class ChatOptions(SerializationMixin): - """Common request settings for AI services. + Args: + options: The options dictionary to validate. + + Returns: + The validated and normalized options dictionary. + + Raises: + ValueError: If any option value is invalid. Examples: .. 
code-block:: python - from agent_framework import ChatOptions, ai_function + from agent_framework import validate_chat_options - # Create basic chat options - options = ChatOptions( - model_id="gpt-4", - temperature=0.7, - max_tokens=1000, - ) + options = await validate_chat_options({ + "temperature": 0.7, + "max_tokens": 1000, + }) + """ + result = dict(options) # Make a copy + # Validate numeric constraints + if (freq_pen := result.get("frequency_penalty")) is not None: + if not (-2.0 <= freq_pen <= 2.0): + raise ValueError("frequency_penalty must be between -2.0 and 2.0") + result["frequency_penalty"] = float(freq_pen) - # With tools - @ai_function - def get_weather(location: str) -> str: - '''Get weather for a location.''' - return f"Weather in {location}" + if (pres_pen := result.get("presence_penalty")) is not None: + if not (-2.0 <= pres_pen <= 2.0): + raise ValueError("presence_penalty must be between -2.0 and 2.0") + result["presence_penalty"] = float(pres_pen) + if (temp := result.get("temperature")) is not None: + if not (0.0 <= temp <= 2.0): + raise ValueError("temperature must be between 0.0 and 2.0") + result["temperature"] = float(temp) - options = ChatOptions( - model_id="gpt-4", - tools=get_weather, - tool_choice="auto", - ) + if (top_p := result.get("top_p")) is not None: + if not (0.0 <= top_p <= 1.0): + raise ValueError("top_p must be between 0.0 and 1.0") + result["top_p"] = float(top_p) - # Require a specific tool to be called - options_required = ChatOptions( - model_id="gpt-4", - tools=get_weather, - tool_choice=ToolMode.REQUIRED(function_name="get_weather"), - ) + if (max_tokens := result.get("max_tokens")) is not None and max_tokens <= 0: + raise ValueError("max_tokens must be greater than 0") - # Combine options - base_options = ChatOptions(temperature=0.5) - extended_options = ChatOptions(max_tokens=500, tools=get_weather) - combined = base_options & extended_options - """ + # Validate and normalize tools + if "tools" in result: + result["tools"] = await validate_tools(result["tools"]) - DEFAULT_EXCLUDE: ClassVar[set[str]] = {"_tools"} # Internal field, use .tools property + return result - def __init__( - self, - *, - model_id: str | None = None, - allow_multiple_tool_calls: bool | None = None, - conversation_id: str | None = None, - frequency_penalty: float | None = None, - instructions: str | None = None, - logit_bias: MutableMapping[str | int, float] | None = None, - max_tokens: int | None = None, - metadata: MutableMapping[str, str] | None = None, - presence_penalty: float | None = None, - response_format: type[BaseModel] | None = None, - seed: int | None = None, - stop: str | Sequence[str] | None = None, - store: bool | None = None, - temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | Mapping[str, Any] | None = None, - tools: ToolProtocol + +def normalize_tools( + tools: ( + ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] - | None = None, - top_p: float | None = None, - user: str | None = None, - additional_properties: MutableMapping[str, Any] | None = None, - **kwargs: Any, - ): - """Initialize ChatOptions. + | None + ), +) -> list[ToolProtocol | MutableMapping[str, Any]]: + """Normalize tools into a list. - Keyword Args: - model_id: The AI model ID to use. - allow_multiple_tool_calls: Whether to allow multiple tool calls. - conversation_id: The conversation ID. 
- frequency_penalty: The frequency penalty (must be between -2.0 and 2.0). - instructions: the instructions, will be turned into a system or equivalent message. - logit_bias: The logit bias mapping. - max_tokens: The maximum number of tokens (must be > 0). - metadata: Metadata mapping. - presence_penalty: The presence penalty (must be between -2.0 and 2.0). - response_format: Structured output response format schema. Must be a valid Pydantic model. - seed: Random seed for reproducibility. - stop: Stop sequences. - store: Whether to store the conversation. - temperature: The temperature (must be between 0.0 and 2.0). - tool_choice: The tool choice mode. - tools: List of available tools. - top_p: The top-p value (must be between 0.0 and 1.0). - user: The user ID. - additional_properties: Provider-specific additional properties, can also be passed as kwargs. - **kwargs: Additional properties to include in additional_properties. - """ - # Validate numeric constraints and convert types as needed - if frequency_penalty is not None: - if not (-2.0 <= frequency_penalty <= 2.0): - raise ValueError("frequency_penalty must be between -2.0 and 2.0") - frequency_penalty = float(frequency_penalty) - if presence_penalty is not None: - if not (-2.0 <= presence_penalty <= 2.0): - raise ValueError("presence_penalty must be between -2.0 and 2.0") - presence_penalty = float(presence_penalty) - if temperature is not None: - if not (0.0 <= temperature <= 2.0): - raise ValueError("temperature must be between 0.0 and 2.0") - temperature = float(temperature) - if top_p is not None: - if not (0.0 <= top_p <= 1.0): - raise ValueError("top_p must be between 0.0 and 1.0") - top_p = float(top_p) - if max_tokens is not None and max_tokens <= 0: - raise ValueError("max_tokens must be greater than 0") - - if additional_properties is None: - additional_properties = {} - if kwargs: - additional_properties.update(kwargs) - - self.additional_properties = cast(dict[str, Any], additional_properties) - self.model_id = model_id - self.allow_multiple_tool_calls = allow_multiple_tool_calls - self.conversation_id = conversation_id - self.frequency_penalty = frequency_penalty - self.instructions = instructions - self.logit_bias = logit_bias - self.max_tokens = max_tokens - self.metadata = metadata - self.presence_penalty = presence_penalty - self.response_format = response_format - self.seed = seed - self.stop = stop - self.store = store - self.temperature = temperature - self.tool_choice = self._validate_tool_mode(tool_choice) - self._tools = self._validate_tools(tools) - self.top_p = top_p - self.user = user - - def __deepcopy__(self, memo: dict[int, Any]) -> "ChatOptions": - """Create a runtime-safe copy without deep-copying tool instances.""" - clone = type(self).__new__(type(self)) - memo[id(self)] = clone - for key, value in self.__dict__.items(): - if key == "_tools": - setattr(clone, key, list(value) if value is not None else None) - continue - if key in {"logit_bias", "metadata", "additional_properties"}: - setattr(clone, key, self._safe_deepcopy_mapping(value, memo)) - continue - setattr(clone, key, self._safe_deepcopy_value(value, memo)) - return clone + Converts callables to AIFunction objects and ensures all tools are either + ToolProtocol instances or MutableMappings. 
- @staticmethod - def _safe_deepcopy_mapping( - value: MutableMapping[str, Any] | None, memo: dict[int, Any] - ) -> MutableMapping[str, Any] | None: - """Deep copy helper that falls back to a shallow copy for problematic mappings.""" - if value is None: - return None - try: - return deepcopy(value, memo) # type: ignore[arg-type] - except Exception: - return dict(value) + Args: + tools: Tools to normalize - can be a single tool, callable, or sequence. - @staticmethod - def _safe_deepcopy_value(value: Any, memo: dict[int, Any]) -> Any: - """Deep copy helper that avoids failing on non-copyable instances.""" - try: - return deepcopy(value, memo) - except Exception: - return value + Returns: + Normalized list of tools. - @property - def tools(self) -> list[ToolProtocol | MutableMapping[str, Any]] | None: - """Return the tools that are specified.""" - return self._tools + Examples: + .. code-block:: python - @tools.setter - def tools( - self, - new_tools: ToolProtocol + from agent_framework import normalize_tools, ai_function + + + @ai_function + def my_tool(x: int) -> int: + return x * 2 + + + # Single tool + tools = normalize_tools(my_tool) + + # List of tools + tools = normalize_tools([my_tool, another_tool]) + """ + final_tools: list[ToolProtocol | MutableMapping[str, Any]] = [] + if not tools: + return final_tools + if not isinstance(tools, Sequence) or isinstance(tools, (str, MutableMapping)): + # Single tool (not a sequence, or is a mapping which shouldn't be treated as sequence) + if not isinstance(tools, (ToolProtocol, MutableMapping)): + return [ai_function(tools)] + return [tools] + for tool in tools: + if isinstance(tool, (ToolProtocol, MutableMapping)): + final_tools.append(tool) + else: + # Convert callable to AIFunction + final_tools.append(ai_function(tool)) + return final_tools + + +async def validate_tools( + tools: ( + ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] - | None, - ) -> None: - """Set the tools.""" - self._tools = self._validate_tools(new_tools) + | None + ), +) -> list[ToolProtocol | MutableMapping[str, Any]]: + """Validate and normalize tools into a list. - @classmethod - def _validate_tools( - cls, - tools: ( - ToolProtocol - | Callable[..., Any] - | MutableMapping[str, Any] - | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] - | None - ), - ) -> list[ToolProtocol | MutableMapping[str, Any]] | None: - """Parse the tools field.""" - if not tools: - return None - if not isinstance(tools, Sequence): - if not isinstance(tools, (ToolProtocol, MutableMapping)): - return [ai_function(tools)] - return [tools] - return [tool if isinstance(tool, (ToolProtocol, MutableMapping)) else ai_function(tool) for tool in tools] + Converts callables to AIFunction objects, expands MCP tools to their constituent + functions (connecting them if needed), and ensures all tools are either ToolProtocol + instances or MutableMappings. 
-    @classmethod
-    def _validate_tool_mode(
-        cls, tool_choice: ToolMode | Literal["auto", "required", "none"] | Mapping[str, Any] | None
-    ) -> ToolMode | None:
-        """Validates the tool_choice field to ensure it is a valid ToolMode."""
-        if not tool_choice:
-            return None
-        if isinstance(tool_choice, str):
-            match tool_choice:
-                case "auto":
-                    return ToolMode.AUTO
-                case "required":
-                    return ToolMode.REQUIRED_ANY
-                case "none":
-                    return ToolMode.NONE
-                case _:
-                    raise ContentError(f"Invalid tool choice: {tool_choice}")
-        if isinstance(tool_choice, (dict, Mapping)):
-            return ToolMode.from_dict(tool_choice)  # type: ignore
-        return tool_choice
-
-    def __and__(self, other: object) -> "ChatOptions":
-        """Combines two ChatOptions instances.
-
-        The values from the other ChatOptions take precedence.
-        List and dicts are combined.
-        """
-        if not isinstance(other, ChatOptions):
-            return self
-        other_tools = other.tools
-        # tool_choice has a specialized serialize method. Save it here so we can fix it later.
-        tool_choice = other.tool_choice or self.tool_choice
-        # response_format is a class type that can't be serialized. Save it here so we can restore it later.
-        response_format = self.response_format
-        # Start with a shallow copy of self that preserves tool objects
-        combined = ChatOptions.from_dict(self.to_dict())
-        combined.tool_choice = self.tool_choice
-        combined.tools = list(self.tools) if self.tools else None
-        combined.logit_bias = dict(self.logit_bias) if self.logit_bias else None
-        combined.metadata = dict(self.metadata) if self.metadata else None
-        combined.response_format = response_format
-
-        # Apply scalar and mapping updates from the other options
-        updated_data = other.to_dict(exclude_none=True, exclude={"tools"})
-        logit_bias = updated_data.pop("logit_bias", {})
-        metadata = updated_data.pop("metadata", {})
-        additional_properties: dict[str, Any] = updated_data.pop("additional_properties", {})
-
-        for key, value in updated_data.items():
-            setattr(combined, key, value)
-
-        combined.tool_choice = tool_choice
-        # Preserve response_format from other if it exists, otherwise keep self's
-        if other.response_format is not None:
-            combined.response_format = other.response_format
-        if other.instructions:
-            combined.instructions = "\n".join([combined.instructions or "", other.instructions or ""])
-
-        combined.logit_bias = (
-            {**(combined.logit_bias or {}), **logit_bias} if logit_bias or combined.logit_bias else None
-        )
-        combined.metadata = {**(combined.metadata or {}), **metadata} if metadata or combined.metadata else None
-        if combined.additional_properties and additional_properties:
-            combined.additional_properties.update(additional_properties)
+    Args:
+        tools: Tools to validate - can be a single tool, callable, or sequence.
+
+    Returns:
+        Normalized list of tools (an empty list if no tools are provided).
+
+    Examples:
+        ..
code-block:: python + + from agent_framework import validate_tools, ai_function + + + @ai_function + def my_tool(x: int) -> int: + return x * 2 + + + # Single tool + tools = await validate_tools(my_tool) + + # List of tools + tools = await validate_tools([my_tool, another_tool]) + """ + # Use normalize_tools for common sync logic (converts callables to AIFunction) + normalized = normalize_tools(tools) + + # Handle MCP tool expansion (async-only) + final_tools: list[ToolProtocol | MutableMapping[str, Any]] = [] + for tool in normalized: + # Import MCPTool here to avoid circular imports + from ._mcp import MCPTool + + if isinstance(tool, MCPTool): + # Expand MCP tools to their constituent functions + if not tool.is_connected: + await tool.connect() + final_tools.extend(tool.functions) # type: ignore else: - if additional_properties: - combined.additional_properties = additional_properties - if other_tools: - if combined.tools is None: - combined.tools = list(other_tools) + final_tools.append(tool) + + return final_tools + + +def validate_tool_mode( + tool_choice: ToolMode | Literal["auto", "required", "none"] | None, +) -> ToolMode: + """Validate and normalize tool_choice to a ToolMode dict. + + Args: + tool_choice: The tool choice value to validate. + + Returns: + A ToolMode dict (contains keys: "mode", and optionally "required_function_name"). + + Raises: + ContentError: If the tool_choice string is invalid. + """ + if not tool_choice: + return {"mode": "none"} + if isinstance(tool_choice, str): + if tool_choice not in ("auto", "required", "none"): + raise ContentError(f"Invalid tool choice: {tool_choice}") + return {"mode": tool_choice} + if "mode" not in tool_choice: + raise ContentError("tool_choice dict must contain 'mode' key") + if tool_choice["mode"] not in ("auto", "required", "none"): + raise ContentError(f"Invalid tool choice: {tool_choice['mode']}") + if tool_choice["mode"] != "required" and "required_function_name" in tool_choice: + raise ContentError("tool_choice with mode other than 'required' cannot have 'required_function_name'") + return tool_choice + + +def merge_chat_options( + base: dict[str, Any] | None, + override: dict[str, Any] | None, +) -> dict[str, Any]: + """Merge two chat options dictionaries. + + Values from override take precedence over base. + Lists and dicts are combined (not replaced). + Instructions are concatenated with newlines. + + Args: + base: The base options dictionary. + override: The override options dictionary. + + Returns: + A new merged options dictionary. + + Examples: + .. 
code-block:: python + + from agent_framework import merge_chat_options + + base = {"temperature": 0.5, "model_id": "gpt-4"} + override = {"temperature": 0.7, "max_tokens": 1000} + merged = merge_chat_options(base, override) + # {"temperature": 0.7, "model_id": "gpt-4", "max_tokens": 1000} + """ + if not base: + return dict(override) if override else {} + if not override: + return dict(base) + + result: dict[str, Any] = {} + + # Copy base values (shallow copy for simple values, dict copy for dicts) + for key, value in base.items(): + if isinstance(value, dict): + result[key] = dict(value) + elif isinstance(value, list): + result[key] = list(value) + else: + result[key] = value + + # Apply overrides + for key, value in override.items(): + if value is None: + continue + + if key == "instructions": + # Concatenate instructions + base_instructions = result.get("instructions") + if base_instructions: + result["instructions"] = f"{base_instructions}\n{value}" else: - for tool in other_tools: - if tool not in combined.tools: - combined.tools.append(tool) - return combined + result["instructions"] = value + elif key == "tools": + # Merge tools lists + base_tools = result.get("tools") + if base_tools and value: + # Add tools that aren't already present + merged_tools = list(base_tools) + for tool in value if isinstance(value, list) else [value]: + if tool not in merged_tools: + merged_tools.append(tool) + result["tools"] = merged_tools + elif value: + result["tools"] = list(value) if isinstance(value, list) else [value] + elif key in ("logit_bias", "metadata", "additional_properties"): + # Merge dicts + base_dict = result.get(key) + if base_dict and isinstance(value, dict): + result[key] = {**base_dict, **value} + elif value: + result[key] = dict(value) if isinstance(value, dict) else value + elif key == "tool_choice": + # tool_choice from override takes precedence + result["tool_choice"] = value if value else result.get("tool_choice") + elif key == "response_format": + # response_format from override takes precedence if set + result["response_format"] = value + else: + # Simple override + result[key] = value + + return result diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index 42e48c50cf..70706ff827 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -6,6 +6,13 @@ AgentExecutorRequest, AgentExecutorResponse, ) +from ._agent_utils import resolve_agent_id +from ._base_group_chat_orchestrator import ( + BaseGroupChatOrchestrator, + GroupChatRequestMessage, + GroupChatRequestSentEvent, + GroupChatResponseReceivedEvent, +) from ._checkpoint import ( CheckpointStorage, FileCheckpointStorage, @@ -21,6 +28,7 @@ Case, Default, Edge, + EdgeCondition, FanInEdgeGroup, FanOutEdgeGroup, SingleEdgeGroup, @@ -49,43 +57,42 @@ WorkflowStartedEvent, WorkflowStatusEvent, ) +from ._exceptions import ( + WorkflowCheckpointException, + WorkflowConvergenceException, + WorkflowException, + WorkflowRunnerException, +) from ._executor import ( Executor, handler, ) from ._function_executor import FunctionExecutor, executor from ._group_chat import ( - DEFAULT_MANAGER_INSTRUCTIONS, - DEFAULT_MANAGER_STRUCTURED_OUTPUT_PROMPT, + AgentBasedGroupChatOrchestrator, GroupChatBuilder, - GroupChatDirective, - GroupChatStateSnapshot, - ManagerDirectiveModel, - ManagerSelectionRequest, - ManagerSelectionResponse, + GroupChatState, ) -from ._handoff import 
HandoffBuilder, HandoffUserInputRequest +from ._handoff import HandoffAgentUserRequest, HandoffBuilder, HandoffSentEvent from ._magentic import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, ORCH_MSG_KIND_INSTRUCTION, ORCH_MSG_KIND_NOTICE, ORCH_MSG_KIND_TASK_LEDGER, ORCH_MSG_KIND_USER_TASK, MagenticBuilder, MagenticContext, - MagenticHumanInputRequest, - MagenticHumanInterventionDecision, - MagenticHumanInterventionKind, - MagenticHumanInterventionReply, - MagenticHumanInterventionRequest, MagenticManagerBase, - MagenticStallInterventionDecision, - MagenticStallInterventionReply, - MagenticStallInterventionRequest, + MagenticOrchestrator, + MagenticOrchestratorEvent, + MagenticOrchestratorEventType, + MagenticPlanReviewRequest, + MagenticPlanReviewResponse, + MagenticProgressLedger, + MagenticProgressLedgerItem, + MagenticResetSignal, StandardMagenticManager, ) -from ._orchestration_request_info import AgentInputRequest, AgentResponseReviewRequest, RequestInfoInterceptor +from ._orchestration_request_info import AgentRequestInfoResponse from ._orchestration_state import OrchestrationState from ._request_info_mixin import response_handler from ._runner import Runner @@ -108,30 +115,32 @@ from ._workflow import Workflow, WorkflowRunResult from ._workflow_builder import WorkflowBuilder from ._workflow_context import WorkflowContext -from ._workflow_executor import SubWorkflowRequestMessage, SubWorkflowResponseMessage, WorkflowExecutor +from ._workflow_executor import ( + SubWorkflowRequestMessage, + SubWorkflowResponseMessage, + WorkflowExecutor, +) __all__ = [ - "DEFAULT_MANAGER_INSTRUCTIONS", - "DEFAULT_MANAGER_STRUCTURED_OUTPUT_PROMPT", "DEFAULT_MAX_ITERATIONS", - "MAGENTIC_EVENT_TYPE_AGENT_DELTA", - "MAGENTIC_EVENT_TYPE_ORCHESTRATOR", "ORCH_MSG_KIND_INSTRUCTION", "ORCH_MSG_KIND_NOTICE", "ORCH_MSG_KIND_TASK_LEDGER", "ORCH_MSG_KIND_USER_TASK", + "AgentBasedGroupChatOrchestrator", "AgentExecutor", "AgentExecutorRequest", "AgentExecutorResponse", - "AgentInputRequest", - "AgentResponseReviewRequest", + "AgentRequestInfoResponse", "AgentRunEvent", "AgentRunUpdateEvent", + "BaseGroupChatOrchestrator", "Case", "CheckpointStorage", "ConcurrentBuilder", "Default", "Edge", + "EdgeCondition", "EdgeDuplicationError", "Executor", "ExecutorCompletedEvent", @@ -144,30 +153,29 @@ "FunctionExecutor", "GraphConnectivityError", "GroupChatBuilder", - "GroupChatDirective", - "GroupChatStateSnapshot", + "GroupChatRequestMessage", + "GroupChatRequestSentEvent", + "GroupChatResponseReceivedEvent", + "GroupChatState", + "HandoffAgentUserRequest", "HandoffBuilder", - "HandoffUserInputRequest", + "HandoffSentEvent", "InMemoryCheckpointStorage", "InProcRunnerContext", "MagenticBuilder", "MagenticContext", - "MagenticHumanInputRequest", - "MagenticHumanInterventionDecision", - "MagenticHumanInterventionKind", - "MagenticHumanInterventionReply", - "MagenticHumanInterventionRequest", "MagenticManagerBase", - "MagenticStallInterventionDecision", - "MagenticStallInterventionReply", - "MagenticStallInterventionRequest", - "ManagerDirectiveModel", - "ManagerSelectionRequest", - "ManagerSelectionResponse", + "MagenticOrchestrator", + "MagenticOrchestratorEvent", + "MagenticOrchestratorEventType", + "MagenticPlanReviewRequest", + "MagenticPlanReviewResponse", + "MagenticProgressLedger", + "MagenticProgressLedgerItem", + "MagenticResetSignal", "Message", "OrchestrationState", "RequestInfoEvent", - "RequestInfoInterceptor", "Runner", "RunnerContext", "SequentialBuilder", @@ -187,17 +195,21 @@ 
"WorkflowAgent", "WorkflowBuilder", "WorkflowCheckpoint", + "WorkflowCheckpointException", "WorkflowCheckpointSummary", "WorkflowContext", + "WorkflowConvergenceException", "WorkflowErrorDetails", "WorkflowEvent", "WorkflowEventSource", + "WorkflowException", "WorkflowExecutor", "WorkflowFailedEvent", "WorkflowLifecycleEvent", "WorkflowOutputEvent", "WorkflowRunResult", "WorkflowRunState", + "WorkflowRunnerException", "WorkflowStartedEvent", "WorkflowStatusEvent", "WorkflowValidationError", @@ -206,6 +218,7 @@ "executor", "get_checkpoint_summary", "handler", + "resolve_agent_id", "response_handler", "validate_workflow_graph", ] diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index 7eec2472f0..cd768ffc4d 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -9,8 +9,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, BaseAgent, BaseContent, @@ -126,7 +126,7 @@ async def run( checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: """Get a response from the workflow agent (non-streaming). This method collects all streaming updates and merges them into a single response. @@ -146,10 +146,10 @@ async def run( and ai_function tools. Returns: - The final workflow response as an AgentRunResponse. + The final workflow response as an AgentResponse. """ # Collect all streaming updates - response_updates: list[AgentRunResponseUpdate] = [] + response_updates: list[AgentResponseUpdate] = [] input_messages = normalize_messages_input(messages) thread = thread or self.get_new_thread() response_id = str(uuid.uuid4()) @@ -175,7 +175,7 @@ async def run_stream( checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: """Stream response updates from the workflow agent. Args: @@ -193,11 +193,11 @@ async def run_stream( and ai_function tools. Yields: - AgentRunResponseUpdate objects representing the workflow execution progress. + AgentResponseUpdate objects representing the workflow execution progress. """ input_messages = normalize_messages_input(messages) thread = thread or self.get_new_thread() - response_updates: list[AgentRunResponseUpdate] = [] + response_updates: list[AgentResponseUpdate] = [] response_id = str(uuid.uuid4()) async for update in self._run_stream_impl( @@ -220,7 +220,7 @@ async def _run_stream_impl( checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: """Internal implementation of streaming execution. Args: @@ -233,7 +233,7 @@ async def _run_stream_impl( workflow and ai_function tools. Yields: - AgentRunResponseUpdate objects representing the workflow execution progress. + AgentResponseUpdate objects representing the workflow execution progress. 
""" # Determine the event stream based on whether we have function responses if bool(self.pending_requests): @@ -289,8 +289,8 @@ def _convert_workflow_event_to_agent_update( self, response_id: str, event: WorkflowEvent, - ) -> AgentRunResponseUpdate | None: - """Convert a workflow event to an AgentRunResponseUpdate. + ) -> AgentResponseUpdate | None: + """Convert a workflow event to an AgentResponseUpdate. AgentRunUpdateEvent, RequestInfoEvent, and WorkflowOutputEvent are processed. Other workflow events are ignored as they are workflow-internal. @@ -315,24 +315,24 @@ def _convert_workflow_event_to_agent_update( return update return None - case WorkflowOutputEvent(data=data, source_executor_id=source_executor_id): + case WorkflowOutputEvent(data=data, executor_id=executor_id): # Convert workflow output to an agent response update. # Handle different data types appropriately. - # Skip AgentRunResponse from AgentExecutor with output_response=True + # Skip AgentResponse from AgentExecutor with output_response=True # since streaming events already surfaced the content. - if isinstance(data, AgentRunResponse): - executor = self.workflow.executors.get(source_executor_id) + if isinstance(data, AgentResponse): + executor = self.workflow.executors.get(executor_id) if isinstance(executor, AgentExecutor) and executor.output_response: return None - if isinstance(data, AgentRunResponseUpdate): + if isinstance(data, AgentResponseUpdate): return data if isinstance(data, ChatMessage): - return AgentRunResponseUpdate( + return AgentResponseUpdate( contents=list(data.contents), role=data.role, - author_name=data.author_name or source_executor_id, + author_name=data.author_name or executor_id, response_id=response_id, message_id=str(uuid.uuid4()), created_at=datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"), @@ -341,10 +341,10 @@ def _convert_workflow_event_to_agent_update( contents = self._extract_contents(data) if not contents: return None - return AgentRunResponseUpdate( + return AgentResponseUpdate( contents=contents, role=Role.ASSISTANT, - author_name=source_executor_id, + author_name=executor_id, response_id=response_id, message_id=str(uuid.uuid4()), created_at=datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"), @@ -367,7 +367,7 @@ def _convert_workflow_event_to_agent_update( function_call=function_call, additional_properties={"request_id": request_id}, ) - return AgentRunResponseUpdate( + return AgentResponseUpdate( contents=[function_call, approval_request], role=Role.ASSISTANT, author_name=self.name, @@ -441,34 +441,52 @@ def _extract_contents(self, data: Any) -> list[Contents]: class _ResponseState(TypedDict): """State for grouping response updates by message_id.""" - by_msg: dict[str, list[AgentRunResponseUpdate]] - dangling: list[AgentRunResponseUpdate] + by_msg: dict[str, list[AgentResponseUpdate]] + dangling: list[AgentResponseUpdate] @staticmethod - def merge_updates(updates: list[AgentRunResponseUpdate], response_id: str) -> AgentRunResponse: - """Merge streaming updates into a single AgentRunResponse. + def merge_updates(updates: list[AgentResponseUpdate], response_id: str) -> AgentResponse: + """Merge streaming updates into a single AgentResponse. Behavior: - Group updates by response_id; within each response_id, group by message_id and keep a dangling bucket for updates without message_id. 
- - Convert each group (per message and dangling) into an intermediate AgentRunResponse via - AgentRunResponse.from_agent_run_response_updates, then sort by created_at and merge. + - Convert each group (per message and dangling) into an intermediate AgentResponse via + AgentResponse.from_agent_run_response_updates, then sort by created_at and merge. - Append messages from updates without any response_id at the end (global dangling), while aggregating metadata. Args: - updates: The list of AgentRunResponseUpdate objects to merge. - response_id: The response identifier to set on the returned AgentRunResponse. + updates: The list of AgentResponseUpdate objects to merge. + response_id: The response identifier to set on the returned AgentResponse. Returns: - An AgentRunResponse with messages in processing order and aggregated metadata. + An AgentResponse with messages in processing order and aggregated metadata. """ # PHASE 1: GROUP UPDATES BY RESPONSE_ID AND MESSAGE_ID + # First pass: build call_id -> response_id map from FunctionCallContent updates + call_id_to_response_id: dict[str, str] = {} + for u in updates: + if u.response_id: + for content in u.contents: + if isinstance(content, FunctionCallContent) and content.call_id: + call_id_to_response_id[content.call_id] = u.response_id + + # Second pass: group updates, associating FunctionResultContent with their calls states: dict[str, WorkflowAgent._ResponseState] = {} - global_dangling: list[AgentRunResponseUpdate] = [] + global_dangling: list[AgentResponseUpdate] = [] for u in updates: - if u.response_id: - state = states.setdefault(u.response_id, {"by_msg": {}, "dangling": []}) + effective_response_id = u.response_id + # If no response_id, check if this is a FunctionResultContent that matches a call + if not effective_response_id: + for content in u.contents: + if isinstance(content, FunctionResultContent) and content.call_id: + effective_response_id = call_id_to_response_id.get(content.call_id) + if effective_response_id: + break + + if effective_response_id: + state = states.setdefault(effective_response_id, {"by_msg": {}, "dangling": []}) by_msg = state["by_msg"] dangling = state["dangling"] if u.message_id: @@ -497,7 +515,7 @@ def _sum_usage(a: UsageDetails | None, b: UsageDetails | None) -> UsageDetails | return a return a + b - def _merge_responses(current: AgentRunResponse | None, incoming: AgentRunResponse) -> AgentRunResponse: + def _merge_responses(current: AgentResponse | None, incoming: AgentResponse) -> AgentResponse: if current is None: return incoming raw_list: list[object] = [] @@ -512,7 +530,7 @@ def _add_raw(value: object) -> None: _add_raw(current.raw_representation) if incoming.raw_representation is not None: _add_raw(incoming.raw_representation) - return AgentRunResponse( + return AgentResponse( messages=(current.messages or []) + (incoming.messages or []), response_id=current.response_id or incoming.response_id, created_at=incoming.created_at or current.created_at, @@ -533,16 +551,16 @@ def _add_raw(value: object) -> None: by_msg = state["by_msg"] dangling = state["dangling"] - per_message_responses: list[AgentRunResponse] = [] + per_message_responses: list[AgentResponse] = [] for _, msg_updates in by_msg.items(): if msg_updates: - per_message_responses.append(AgentRunResponse.from_agent_run_response_updates(msg_updates)) + per_message_responses.append(AgentResponse.from_agent_run_response_updates(msg_updates)) if dangling: - per_message_responses.append(AgentRunResponse.from_agent_run_response_updates(dangling)) + 
per_message_responses.append(AgentResponse.from_agent_run_response_updates(dangling)) per_message_responses.sort(key=lambda r: _parse_dt(r.created_at)) - aggregated: AgentRunResponse | None = None + aggregated: AgentResponse | None = None for resp in per_message_responses: if resp.response_id and grouped_response_id and resp.response_id != grouped_response_id: resp.response_id = grouped_response_id @@ -569,8 +587,10 @@ def _add_raw(value: object) -> None: raw_representations.append(cast_value) # PHASE 3: HANDLE GLOBAL DANGLING UPDATES (NO RESPONSE_ID) + # These are updates that couldn't be associated with any response_id + # (e.g., orphan FunctionResultContent with no matching FunctionCallContent) if global_dangling: - flattened = AgentRunResponse.from_agent_run_response_updates(global_dangling) + flattened = AgentResponse.from_agent_run_response_updates(global_dangling) final_messages.extend(flattened.messages) if flattened.usage_details: merged_usage = _sum_usage(merged_usage, flattened.usage_details) @@ -591,7 +611,7 @@ def _add_raw(value: object) -> None: raw_representations.append(cast_flat) # PHASE 4: CONSTRUCT FINAL RESPONSE WITH INPUT RESPONSE_ID - return AgentRunResponse( + return AgentResponse( messages=final_messages, response_id=response_id, created_at=latest_created_at, diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 4e0d2058ad..bcd47caca2 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -9,7 +9,8 @@ from .._agents import AgentProtocol, ChatAgent from .._threads import AgentThread -from .._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage +from .._types import AgentResponse, AgentResponseUpdate, ChatMessage +from ._agent_utils import resolve_agent_id from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value from ._const import WORKFLOW_RUN_KWARGS_KEY from ._conversation_state import encode_chat_messages @@ -50,14 +51,14 @@ class AgentExecutorResponse: Attributes: executor_id: The ID of the executor that generated the response. - agent_run_response: The underlying agent run response (unaltered from client). + agent_response: The underlying agent run response (unaltered from client). full_conversation: The full conversation context (prior inputs + all assistant/tool outputs) that should be used when chaining to another AgentExecutor. This prevents downstream agents losing user prompts while keeping the emitted AgentRunEvent text faithful to the raw agent output. """ executor_id: str - agent_run_response: AgentRunResponse + agent_response: AgentResponse full_conversation: list[ChatMessage] | None = None @@ -84,31 +85,35 @@ def __init__( Args: agent: The agent to be wrapped by this executor. agent_thread: The thread to use for running the agent. If None, a new thread will be created. - output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes. + output_response: Whether to yield an AgentResponse as a workflow output when the agent completes. id: A unique identifier for the executor. If None, the agent's name will be used if available. 
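+
+        (Editor's illustrative addition, not part of the original change.) Typical
+        construction, assuming a chat agent instance like those used elsewhere in
+        this diff:
+
+            executor = AgentExecutor(agent, output_response=True)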
""" # Prefer provided id; else use agent.name if present; else generate deterministic prefix - exec_id = id or agent.name + exec_id = id or resolve_agent_id(agent) if not exec_id: - raise ValueError("Agent must have a name or an explicit id must be provided.") + raise ValueError("Agent must have a non-empty name or id or an explicit id must be provided.") super().__init__(exec_id) self._agent = agent self._agent_thread = agent_thread or self._agent.get_new_thread() self._pending_agent_requests: dict[str, FunctionApprovalRequestContent] = {} self._pending_responses_to_agent: list[FunctionApprovalResponseContent] = [] self._output_response = output_response + + # AgentExecutor maintains an internal cache of messages in between runs self._cache: list[ChatMessage] = [] + # This tracks the full conversation after each run + self._full_conversation: list[ChatMessage] = [] @property def output_response(self) -> bool: - """Whether this executor yields AgentRunResponse as workflow output when complete.""" + """Whether this executor yields AgentResponse as workflow output when complete.""" return self._output_response @property def workflow_output_types(self) -> list[type[Any]]: - # Override to declare AgentRunResponse as a possible output type only if enabled. + # Override to declare AgentResponse as a possible output type only if enabled. if self._output_response: - return [AgentRunResponse] + return [AgentResponse] return [] @property @@ -118,7 +123,7 @@ def description(self) -> str | None: @handler async def run( - self, request: AgentExecutorRequest, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse] + self, request: AgentExecutorRequest, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse] ) -> None: """Handle an AgentExecutorRequest (canonical input). @@ -131,22 +136,22 @@ async def run( @handler async def from_response( - self, prior: AgentExecutorResponse, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse] + self, prior: AgentExecutorResponse, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse] ) -> None: """Enable seamless chaining: accept a prior AgentExecutorResponse as input. Strategy: treat the prior response's messages as the conversation state and immediately run the agent to produce a new response. """ - # Replace cache with full conversation if available, else fall back to agent_run_response messages. + # Replace cache with full conversation if available, else fall back to agent_response messages. 
if prior.full_conversation is not None: self._cache = list(prior.full_conversation) else: - self._cache = list(prior.agent_run_response.messages) + self._cache = list(prior.agent_response.messages) await self._run_agent_and_emit(ctx) @handler - async def from_str(self, text: str, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]) -> None: + async def from_str(self, text: str, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse]) -> None: """Accept a raw user prompt string and run the agent (one-shot).""" self._cache = normalize_messages_input(text) await self._run_agent_and_emit(ctx) @@ -155,7 +160,7 @@ async def from_str(self, text: str, ctx: WorkflowContext[AgentExecutorResponse, async def from_message( self, message: ChatMessage, - ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse], + ctx: WorkflowContext[AgentExecutorResponse, AgentResponse], ) -> None: """Accept a single ChatMessage as input.""" self._cache = normalize_messages_input(message) @@ -165,7 +170,7 @@ async def from_message( async def from_messages( self, messages: list[str | ChatMessage], - ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse], + ctx: WorkflowContext[AgentExecutorResponse, AgentResponse], ) -> None: """Accept a list of chat inputs (strings or ChatMessage) as conversation context.""" self._cache = normalize_messages_input(messages) @@ -176,7 +181,7 @@ async def handle_user_input_response( self, original_request: FunctionApprovalRequestContent, response: FunctionApprovalResponseContent, - ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse], + ctx: WorkflowContext[AgentExecutorResponse, AgentResponse], ) -> None: """Handle user input responses for function approvals during agent execution. @@ -227,6 +232,7 @@ async def on_checkpoint_save(self) -> dict[str, Any]: return { "cache": encode_chat_messages(self._cache), + "full_conversation": encode_chat_messages(self._full_conversation), "agent_thread": serialized_thread, "pending_agent_requests": encode_checkpoint_value(self._pending_agent_requests), "pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent), @@ -251,6 +257,16 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: else: self._cache = [] + full_conversation_payload = state.get("full_conversation") + if full_conversation_payload: + try: + self._full_conversation = decode_chat_messages(full_conversation_payload) + except Exception as exc: + logger.warning("Failed to restore full conversation: %s", exc) + self._full_conversation = [] + else: + self._full_conversation = [] + thread_payload = state.get("agent_thread") if thread_payload: try: @@ -276,7 +292,7 @@ def reset(self) -> None: logger.debug("AgentExecutor %s: Resetting cache", self.id) self._cache.clear() - async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]) -> None: + async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse]) -> None: """Execute the underlying agent, emit events, and enqueue response. Checks ctx.is_streaming() to determine whether to emit incremental AgentRunUpdateEvent @@ -289,6 +305,12 @@ async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, # Non-streaming mode: use run() and emit single event response = await self._run_agent(cast(WorkflowContext, ctx)) + # Always extend full conversation with cached messages plus agent outputs + # (agent_response.messages) after each run. 
This is to avoid losing context
+            # when the agent did not complete, since the cache is cleared when responses come back.
+            # Do not mutate response.messages so AgentRunEvent remains faithful to the raw output.
+            self._full_conversation.extend(list(self._cache) + (list(response.messages) if response else []))
+
         if response is None:
             # Agent did not complete (e.g., waiting for user input); do not emit response
             logger.info("AgentExecutor %s: Agent did not complete, awaiting user input", self.id)
@@ -297,23 +319,18 @@
         if self._output_response:
             await ctx.yield_output(response)
 
-        # Always construct a full conversation snapshot from inputs (cache)
-        # plus agent outputs (agent_run_response.messages). Do not mutate
-        # response.messages so AgentRunEvent remains faithful to the raw output.
-        full_conversation: list[ChatMessage] = list(self._cache) + list(response.messages)
-
-        agent_response = AgentExecutorResponse(self.id, response, full_conversation=full_conversation)
+        agent_response = AgentExecutorResponse(self.id, response, full_conversation=self._full_conversation)
         await ctx.send_message(agent_response)
         self._cache.clear()
 
-    async def _run_agent(self, ctx: WorkflowContext) -> AgentRunResponse | None:
+    async def _run_agent(self, ctx: WorkflowContext) -> AgentResponse | None:
         """Execute the underlying agent in non-streaming mode.
 
         Args:
             ctx: The workflow context for emitting events.
 
         Returns:
-            The complete AgentRunResponse, or None if waiting for user input.
+            The complete AgentResponse, or None if waiting for user input.
         """
         run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY)
@@ -333,18 +350,18 @@ async def _run_agent(self, ctx: WorkflowContext) -> AgentRunResponse | None:
 
         return response
 
-    async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentRunResponse | None:
+    async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentResponse | None:
         """Execute the underlying agent in streaming mode and collect the full response.
 
         Args:
             ctx: The workflow context for emitting events.
 
         Returns:
-            The complete AgentRunResponse, or None if waiting for user input.
+            The complete AgentResponse, or None if waiting for user input.
""" run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] user_input_requests: list[FunctionApprovalRequestContent] = [] async for update in self._agent.run_stream( self._cache, @@ -357,15 +374,15 @@ async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentRunResponse | if update.user_input_requests: user_input_requests.extend(update.user_input_requests) - # Build the final AgentRunResponse from the collected updates + # Build the final AgentResponse from the collected updates if isinstance(self._agent, ChatAgent): - response_format = self._agent.chat_options.response_format - response = AgentRunResponse.from_agent_run_response_updates( + response_format = self._agent.default_options.get("response_format") + response = AgentResponse.from_agent_run_response_updates( updates, output_format_type=response_format, ) else: - response = AgentRunResponse.from_agent_run_response_updates(updates) + response = AgentResponse.from_agent_run_response_updates(updates) # Handle any user input requests after the streaming completes if user_input_requests: diff --git a/python/packages/core/agent_framework/_workflows/_agent_utils.py b/python/packages/core/agent_framework/_workflows/_agent_utils.py new file mode 100644 index 0000000000..f296f53ab9 --- /dev/null +++ b/python/packages/core/agent_framework/_workflows/_agent_utils.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft. All rights reserved. + +from .._agents import AgentProtocol + + +def resolve_agent_id(agent: AgentProtocol) -> str: + """Resolve the unique identifier for an agent. + + Prefers the `.name` attribute if set; otherwise falls back to `.id`. + + Args: + agent: The agent whose identifier is to be resolved. + + Returns: + The resolved unique identifier for the agent. 
+    """
+    return agent.name if agent.name else agent.id
diff --git a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py
index 5576246a8e..026933d777 100644
--- a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py
+++ b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py
@@ -2,16 +2,23 @@
 
 """Base class for group chat orchestrators that manages conversation flow and participant selection."""
 
+import asyncio
 import inspect
 import logging
 import sys
-from abc import ABC, abstractmethod
+from abc import ABC
+from collections import OrderedDict
 from collections.abc import Awaitable, Callable, Sequence
-from typing import Any
+from dataclasses import dataclass
+from typing import Any, ClassVar, TypeAlias
 
-from .._types import ChatMessage
-from ._executor import Executor
-from ._orchestrator_helpers import ParticipantRegistry
+from typing_extensions import Never
+
+from .._types import ChatMessage, Role
+from ._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
+from ._events import WorkflowEvent
+from ._executor import Executor, handler
+from ._orchestration_request_info import AgentApprovalExecutor
 from ._workflow_context import WorkflowContext
 
 if sys.version_info >= (3, 12):
@@ -23,6 +30,129 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class GroupChatRequestMessage:
+    """Request envelope sent from the orchestrator to a participant."""
+
+    additional_instruction: str | None = None
+    metadata: dict[str, Any] | None = None
+
+
+@dataclass
+class GroupChatParticipantMessage:
+    """Message envelope containing messages generated by a participant.
+
+    This message envelope is used to broadcast messages from one participant
+    to other participants in the group chat to keep them synchronized.
+    """
+
+    messages: list[ChatMessage]
+
+
+@dataclass
+class GroupChatResponseMessage:
+    """Response envelope emitted by participants back to the orchestrator."""
+
+    message: ChatMessage
+
+
+TerminationCondition: TypeAlias = Callable[[list[ChatMessage]], bool | Awaitable[bool]]
+GroupChatWorkflowContext_T_Out: TypeAlias = AgentExecutorRequest | GroupChatRequestMessage | GroupChatParticipantMessage
+
+
+# region Group chat events
+class GroupChatEvent(WorkflowEvent):
+    """Base class for group chat workflow events."""
+
+    def __init__(self, round_index: int, data: Any | None = None) -> None:
+        """Initialize group chat event.
+
+        Args:
+            round_index: Current round index
+            data: Optional event-specific data
+        """
+        super().__init__(data)
+        self.round_index = round_index
+
+
+class GroupChatResponseReceivedEvent(GroupChatEvent):
+    """Event emitted when a participant response is received."""
+
+    def __init__(self, round_index: int, participant_name: str, data: Any | None = None) -> None:
+        """Initialize response received event.
+
+        Args:
+            round_index: Current round index
+            participant_name: Name of the participant who sent the response
+            data: Optional event-specific data
+        """
+        super().__init__(round_index, data)
+        self.participant_name = participant_name
+
+
+class GroupChatRequestSentEvent(GroupChatEvent):
+    """Event emitted when a request is sent to a participant."""
+
+    def __init__(self, round_index: int, participant_name: str, data: Any | None = None) -> None:
+        """Initialize request sent event.
+
+        Args:
+            round_index: Current round index
+            participant_name: Name of the participant to whom the request was sent
+            data: Optional event-specific data
+        """
+        super().__init__(round_index, data)
+        self.participant_name = participant_name
+
+
+# endregion
+
+
+# region Participant registry
+class ParticipantRegistry:
+    """Simple registry for tracking group chat participants, their types, and other properties."""
+
+    EMPTY_DESCRIPTION_PLACEHOLDER: ClassVar[str] = (
+        ""
+    )
+
+    def __init__(self, participants: Sequence[Executor]) -> None:
+        """Initialize the registry and validate participant IDs.
+
+        Args:
+            participants: List of executors (agents or custom executors) to register
+
+        Raises:
+            ValueError: If there are duplicate or conflicting participant IDs
+        """
+        self._agents: set[str] = set()
+        self._participants: OrderedDict[str, str] = OrderedDict()
+        self._resolve_participants(participants)
+
+    def _resolve_participants(self, participants: Sequence[Executor]) -> None:
+        """Register participants and validate IDs."""
+        for participant in participants:
+            if participant.id in self._participants:
+                raise ValueError(f"Participant ID conflict: '{participant.id}' registered as both agent and executor.")
+
+            if isinstance(participant, AgentExecutor | AgentApprovalExecutor):
+                self._agents.add(participant.id)
+                self._participants[participant.id] = participant.description or self.EMPTY_DESCRIPTION_PLACEHOLDER
+            else:
+                self._participants[participant.id] = self.EMPTY_DESCRIPTION_PLACEHOLDER
+
+    def is_agent(self, name: str) -> bool:
+        """Check if a participant is an agent (vs custom executor)."""
+        return name in self._agents
+
+    @property
+    def participants(self) -> OrderedDict[str, str]:
+        """Get all registered participant names and descriptions in an ordered dictionary."""
+        return self._participants
+
+
+# endregion
+
+
 class BaseGroupChatOrchestrator(Executor, ABC):
     """Abstract base class for group chat orchestrators.
 
@@ -33,36 +163,159 @@ class BaseGroupChatOrchestrator(Executor, ABC):
     inheriting the common participant management infrastructure.
     """
 
-    def __init__(self, executor_id: str) -> None:
+    TERMINATION_CONDITION_MET_MESSAGE: ClassVar[str] = "The group chat has reached its termination condition."
+    MAX_ROUNDS_MET_MESSAGE: ClassVar[str] = "The group chat has reached the maximum number of rounds."
+
+    def __init__(
+        self,
+        id: str,
+        participant_registry: ParticipantRegistry,
+        *,
+        name: str | None = None,
+        max_rounds: int | None = None,
+        termination_condition: TerminationCondition | None = None,
+    ) -> None:
         """Initialize base orchestrator.
 
         Args:
-            executor_id: Unique identifier for this orchestrator executor
+            id: Unique identifier for this orchestrator executor
+            participant_registry: Registry of group chat participants that tracks their types (agents
+                vs custom executors)
+            name: Optional display name for orchestrator messages
+            max_rounds: Optional maximum number of conversation rounds.
+                Must be at least 1 if set; values smaller than 1 will be coerced to 1.
+            termination_condition: Optional callable to determine conversation termination
         """
-        super().__init__(executor_id)
-        self._registry = ParticipantRegistry()
-        # Shared conversation state management
-        self._conversation: list[ChatMessage] = []
+        super().__init__(id)
+        self._name = name or id
+        self._max_rounds = max(1, max_rounds) if max_rounds is not None else None
+        self._termination_condition = termination_condition
         self._round_index: int = 0
-        self._max_rounds: int | None = None
-        self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None
+        self._participant_registry = participant_registry
+        # Shared conversation state management
+        self._full_conversation: list[ChatMessage] = []
+
+    # region Handlers
+
+    @handler
+    async def handle_str(
+        self,
+        task: str,
+        ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]],
+    ) -> None:
+        """Handler for string input as workflow entry point.
+
+        Wraps the string in a USER role ChatMessage and delegates to _handle_messages.
+
+        Args:
+            task: Plain text task description from user
+            ctx: Workflow context

-    def register_participant_entry(
-        self, name: str, *, entry_id: str, is_agent: bool, exit_id: str | None = None
+        Usage:
+            workflow.run("Write a blog post about AI agents")
+        """
+        await self._handle_messages([ChatMessage(role=Role.USER, text=task)], ctx)
+
+    @handler
+    async def handle_message(
+        self,
+        task: ChatMessage,
+        ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]],
     ) -> None:
-        """Record routing details for a participant's entry executor.
+        """Handler for single ChatMessage input as workflow entry point.

-        This method provides a unified interface for registering participants
-        across all orchestrator patterns, whether they are agents or custom executors.
+        Wraps the message in a list and delegates to _handle_messages.

         Args:
-            name: Participant name (used for selection and tracking)
-            entry_id: Executor ID for this participant's entry point
-            is_agent: Whether this is an AgentExecutor (True) or custom Executor (False)
-            exit_id: Executor ID for this participant's exit point (where responses come from).
-                If None, defaults to entry_id.
+            task: ChatMessage from user
+            ctx: Workflow context
+
+        Usage:
+            workflow.run(ChatMessage(role=Role.USER, text="Write a blog post about AI agents"))
         """
-        self._registry.register(name, entry_id=entry_id, is_agent=is_agent, exit_id=exit_id)
+        await self._handle_messages([task], ctx)
+
+    @handler
+    async def handle_messages(
+        self,
+        task: list[ChatMessage],
+        ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]],
+    ) -> None:
+        """Handler for list of ChatMessages as workflow entry point.
+
+        Delegates to _handle_messages.
+
+        Args:
+            task: List of ChatMessages from user
+            ctx: Workflow context
+
+        Usage:
+            workflow.run([
+                ChatMessage(role=Role.USER, text="Write a blog post about AI agents"),
+                ChatMessage(role=Role.USER, text="Make it engaging and informative.")
+            ])
+        """
+        if not task:
+            raise ValueError("At least one ChatMessage is required to start the group chat workflow.")
+        await self._handle_messages(task, ctx)
+
+    @handler
+    async def handle_participant_response(
+        self,
+        response: AgentExecutorResponse | GroupChatResponseMessage,
+        ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]],
+    ) -> None:
+        """Handler for participant responses.
+
+        This method can be overridden by subclasses if specific response handling is needed.
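+
+        The default implementation emits a `GroupChatResponseReceivedEvent` for observability
+        and then delegates to `_handle_response` for pattern-specific handling.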
+
+        Args:
+            response: Response from a participant
+            ctx: Workflow context
+        """
+        await ctx.add_event(
+            GroupChatResponseReceivedEvent(
+                round_index=self._round_index,
+                participant_name=ctx.source_executor_ids[0] if ctx.source_executor_ids else "unknown",
+                data=response,
+            )
+        )
+        await self._handle_response(response, ctx)
+
+    # endregion
+
+    # region Handler methods subclasses must implement
+
+    async def _handle_messages(
+        self,
+        messages: list[ChatMessage],
+        ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]],
+    ) -> None:
+        """Handle task messages from users as workflow entry point.
+
+        Subclasses must implement this method to define pattern-specific orchestration logic.
+
+        Args:
+            messages: Task messages from user
+            ctx: Workflow context
+        """
+        raise NotImplementedError("_handle_messages must be implemented by subclasses.")
+
+    async def _handle_response(
+        self,
+        response: AgentExecutorResponse | GroupChatResponseMessage,
+        ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]],
+    ) -> None:
+        """Handle a participant response.
+
+        Subclasses must implement this method to define pattern-specific response handling logic.
+
+        Args:
+            response: Response from a participant
+            ctx: Workflow context
+        """
+        raise NotImplementedError("_handle_response must be implemented by subclasses.")
+
+    # endregion

     # Conversation state management (shared across all patterns)

@@ -72,7 +325,7 @@ def _append_messages(self, messages: Sequence[ChatMessage]) -> None:
         Args:
             messages: Messages to append
         """
-        self._conversation.extend(messages)
+        self._full_conversation.extend(messages)

     def _get_conversation(self) -> list[ChatMessage]:
         """Get a copy of the current conversation.
@@ -80,11 +333,27 @@
         Returns:
             Cloned conversation list
         """
-        return list(self._conversation)
+        return list(self._full_conversation)
+
+    def _process_participant_response(
+        self, response: AgentExecutorResponse | GroupChatResponseMessage
+    ) -> list[ChatMessage]:
+        """Extract ChatMessage from participant response.
+
+        Args:
+            response: Response from participant
+
+        Returns:
+            List of ChatMessages extracted from the response
+        """
+        if isinstance(response, AgentExecutorResponse):
+            return response.agent_response.messages
+        if isinstance(response, GroupChatResponseMessage):
+            return [response.message]
+        raise TypeError(f"Unsupported response type: {type(response)}")

     def _clear_conversation(self) -> None:
         """Clear the conversation history."""
-        self._conversation.clear()
+        self._full_conversation.clear()

     def _increment_round(self) -> None:
         """Increment the round counter."""
@@ -102,97 +371,121 @@ async def _check_termination(self) -> bool:
             return False

         result = self._termination_condition(self._get_conversation())
-        if inspect.iscoroutine(result) or inspect.isawaitable(result):
+        if inspect.isawaitable(result):
             result = await result

-        return bool(result)
+        return result

-    @abstractmethod
-    def _get_author_name(self) -> str:
-        """Get the author name for orchestrator-generated messages.
+    async def _check_terminate_and_yield(self, ctx: WorkflowContext[Never, list[ChatMessage]]) -> bool:
+        """Check termination conditions and yield completion if met.

-        Subclasses must implement this to provide a stable author name
-        for completion messages and other orchestrator-generated content.
+ Args: + ctx: Workflow context for yielding output Returns: - Author name to use for messages generated by this orchestrator + True if termination condition met and output yielded, False otherwise """ - ... + terminate = await self._check_termination() + if terminate: + self._append_messages([self._create_completion_message(self.TERMINATION_CONDITION_MET_MESSAGE)]) + await ctx.yield_output(self._full_conversation) + return True - def _create_completion_message( - self, - text: str | None = None, - reason: str = "completed", - ) -> ChatMessage: + return False + + def _create_completion_message(self, message: str) -> ChatMessage: """Create a standardized completion message. Args: - text: Optional message text (auto-generated if None) - reason: Completion reason for default text + message: Completion text Returns: ChatMessage with completion content """ - from .._types import Role + return ChatMessage(role=Role.ASSISTANT, text=message, author_name=self._name) + + # Participant routing (shared across all patterns) + + async def _broadcast_messages_to_participants( + self, + messages: list[ChatMessage], + ctx: WorkflowContext[AgentExecutorRequest | GroupChatParticipantMessage], + participants: Sequence[str] | None = None, + ) -> None: + """Broadcast messages to participants. - message_text = text or f"Conversation {reason}." - return ChatMessage( - role=Role.ASSISTANT, - text=message_text, - author_name=self._get_author_name(), + This method sends the given messages to all registered participants + or a specified subset. This acts as a message broadcast mechanism for + participants in the group chat to stay synchronized. + + Args: + messages: Messages to send + ctx: Workflow context for message broadcasting + participants: Optional list of participant names to route to. + If None, routes to all registered participants. + """ + target_participants = ( + participants if participants is not None else list(self._participant_registry.participants) ) - # Participant routing (shared across all patterns) + async def _send_messages(target: str) -> None: + if self._participant_registry.is_agent(target): + # Send messages without requesting a response + await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=False), target_id=target) + else: + # Send messages wrapped in GroupChatParticipantMessage + await ctx.send_message(GroupChatParticipantMessage(messages=messages), target_id=target) - async def _route_to_participant( + await asyncio.gather(*[_send_messages(p) for p in target_participants]) + + async def _send_request_to_participant( self, - participant_name: str, - conversation: list[ChatMessage], - ctx: WorkflowContext[Any, Any], + target: str, + ctx: WorkflowContext[AgentExecutorRequest | GroupChatRequestMessage], *, - instruction: str | None = None, - task: ChatMessage | None = None, + additional_instruction: str | None = None, metadata: dict[str, Any] | None = None, ) -> None: - """Route a conversation to a participant. + """Send a request to a participant. 

         This method handles the dual envelope pattern:
         - AgentExecutors receive AgentExecutorRequest (messages only)
-        - Custom executors receive GroupChatRequestMessage (full context)
+        - Custom executors receive GroupChatRequestMessage (the request envelope; conversation
+          context reaches them separately via participant broadcasts)

         Args:
-            participant_name: Name of the participant to route to
-            conversation: Conversation history to send
+            target: Name of the participant to route to
             ctx: Workflow context for message routing
-            instruction: Optional instruction from manager/orchestrator
-            task: Optional task context
+            additional_instruction: Optional additional instruction for the participant.
+                This can be used to provide guidance to steer the participant's response.
             metadata: Optional metadata dict
-
-        Raises:
-            ValueError: If participant is not registered
         """
-        from ._agent_executor import AgentExecutorRequest
-        from ._orchestrator_helpers import prepare_participant_request
-
-        entry_id = self._registry.get_entry_id(participant_name)
-        if entry_id is None:
-            raise ValueError(f"No registered entry executor for participant '{participant_name}'.")
-
-        if self._registry.is_agent(participant_name):
+        if self._participant_registry.is_agent(target):
             # AgentExecutors receive simple message list
-            await ctx.send_message(
-                AgentExecutorRequest(messages=conversation, should_respond=True),
-                target_id=entry_id,
+            messages: list[ChatMessage] = []
+            if additional_instruction:
+                messages.append(ChatMessage(role=Role.USER, text=additional_instruction))
+            request = AgentExecutorRequest(messages=messages, should_respond=True)
+            await ctx.send_message(request, target_id=target)
+            await ctx.add_event(
+                GroupChatRequestSentEvent(
+                    round_index=self._round_index,
+                    participant_name=target,
+                    data=request,
+                )
             )
         else:
-            # Custom executors receive full context envelope
-            request = prepare_participant_request(
-                participant_name=participant_name,
-                conversation=conversation,
-                instruction=instruction or "",
-                task=task,
-                metadata=metadata,
+            # Custom executors receive the group chat request envelope
+            request = GroupChatRequestMessage(additional_instruction=additional_instruction, metadata=metadata)  # type: ignore[assignment]
+            await ctx.send_message(request, target_id=target)
+            await ctx.add_event(
+                GroupChatRequestSentEvent(
+                    round_index=self._round_index,
+                    participant_name=target,
+                    data=request,
+                )
             )
-            await ctx.send_message(request, target_id=entry_id)

     # Round limit enforcement (shared across all patterns)

@@ -217,6 +510,23 @@ def _check_round_limit(self) -> bool:

         return False

+    async def _check_round_limit_and_yield(self, ctx: WorkflowContext[Never, list[ChatMessage]]) -> bool:
+        """Check round limit and yield completion if reached.
+
+        Args:
+            ctx: Workflow context for yielding output
+
+        Returns:
+            True if round limit reached and output yielded, False otherwise
+        """
+        reach_max_rounds = self._check_round_limit()
+        if reach_max_rounds:
+            self._append_messages([self._create_completion_message(self.MAX_ROUNDS_MET_MESSAGE)])
+            await ctx.yield_output(self._full_conversation)
+            return True
+
+        return False
+
     # State persistence (shared across all patterns)

@@ -234,8 +544,9 @@ async def on_checkpoint_save(self) -> dict[str, Any]:
         from ._orchestration_state import OrchestrationState

         state = OrchestrationState(
-            conversation=list(self._conversation),
+            conversation=list(self._full_conversation),
             round_index=self._round_index,
+            orchestrator_name=self._name,
             metadata=self._snapshot_pattern_metadata(),
         )
         return state.to_dict()
@@ -263,8 +574,9 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
         from ._orchestration_state import OrchestrationState

         orch_state = OrchestrationState.from_dict(state)
-        self._conversation = list(orch_state.conversation)
+        self._full_conversation = list(orch_state.conversation)
         self._round_index = orch_state.round_index
+        self._name = orch_state.orchestrator_name
         self._restore_pattern_metadata(orch_state.metadata)

     def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
diff --git a/python/packages/core/agent_framework/_workflows/_concurrent.py b/python/packages/core/agent_framework/_workflows/_concurrent.py
index 2900254126..033946afff 100644
--- a/python/packages/core/agent_framework/_workflows/_concurrent.py
+++ b/python/packages/core/agent_framework/_workflows/_concurrent.py
@@ -10,11 +10,12 @@

 from agent_framework import AgentProtocol, ChatMessage, Role

-from ._agent_executor import AgentExecutorRequest, AgentExecutorResponse
+from ._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
+from ._agent_utils import resolve_agent_id
 from ._checkpoint import CheckpointStorage
 from ._executor import Executor, handler
 from ._message_utils import normalize_messages_input
-from ._orchestration_request_info import RequestInfoInterceptor
+from ._orchestration_request_info import AgentApprovalExecutor
 from ._workflow import Workflow
 from ._workflow_builder import WorkflowBuilder
 from ._workflow_context import WorkflowContext
@@ -78,7 +79,7 @@ class _AggregateAgentConversations(Executor):

       [ single_user_prompt?, agent1_final_assistant, agent2_final_assistant, ... ]

     - Extracts a single user prompt (first user message seen across results).
-    - For each result, selects the final assistant message (prefers agent_run_response.messages).
+    - For each result, selects the final assistant message (prefers agent_response.messages).
     - Avoids duplicating the same user message per agent.
     """
@@ -106,7 +107,7 @@ def _is_role(msg: Any, role: Role) -> bool:

         assistant_replies: list[ChatMessage] = []

         for r in results:
-            resp_messages = list(getattr(r.agent_run_response, "messages", []) or [])
+            resp_messages = list(getattr(r.agent_response, "messages", []) or [])
             conv = r.full_conversation if r.full_conversation is not None else resp_messages

             logger.debug(
@@ -212,7 +213,7 @@ class ConcurrentBuilder:

     # Custom aggregator via callback (sync or async). The callback receives
     # list[AgentExecutorResponse] and its return value becomes the workflow's output.
     def summarize(results: list[AgentExecutorResponse]) -> str:
-        return " | ".join(r.agent_run_response.messages[-1].text for r in results)
+        return " | ".join(r.agent_response.messages[-1].text for r in results)


     workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).with_aggregator(summarize).build()

@@ -222,7 +223,7 @@

     class MyAggregator(Executor):
         @handler
         async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None:
-            await ctx.yield_output(" | ".join(r.agent_run_response.messages[-1].text for r in results))
+            await ctx.yield_output(" | ".join(r.agent_response.messages[-1].text for r in results))


     workflow = (
@@ -247,6 +248,7 @@ def __init__(self) -> None:
         self._aggregator_factory: Callable[[], Executor] | None = None
         self._checkpoint_storage: CheckpointStorage | None = None
         self._request_info_enabled: bool = False
+        self._request_info_filter: set[str] | None = None

     def register_participants(
         self,
@@ -414,7 +416,7 @@ def with_aggregator(

             class CustomAggregator(Executor):
                 @handler
                 async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext) -> None:
-                    await ctx.yield_output(" | ".join(r.agent_run_response.messages[-1].text for r in results))
+                    await ctx.yield_output(" | ".join(r.agent_response.messages[-1].text for r in results))


             wf = ConcurrentBuilder().participants([a1, a2, a3]).with_aggregator(CustomAggregator()).build()
@@ -422,7 +424,7 @@

             # Callback-based aggregator (string result)
             async def summarize(results: list[AgentExecutorResponse]) -> str:
-                return " | ".join(r.agent_run_response.messages[-1].text for r in results)
+                return " | ".join(r.agent_response.messages[-1].text for r in results)


             wf = ConcurrentBuilder().participants([a1, a2, a3]).with_aggregator(summarize).build()
@@ -430,7 +432,7 @@

             # Callback-based aggregator (yield result)
             async def summarize(results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None:
-                await ctx.yield_output(" | ".join(r.agent_run_response.messages[-1].text for r in results))
+                await ctx.yield_output(" | ".join(r.agent_response.messages[-1].text for r in results))


             wf = ConcurrentBuilder().participants([a1, a2, a3]).with_aggregator(summarize).build()
@@ -461,25 +463,68 @@ def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "Concurre
         self._checkpoint_storage = checkpoint_storage
         return self

-    def with_request_info(self) -> "ConcurrentBuilder":
-        """Enable request info before aggregation in the workflow.
+    def with_request_info(
+        self,
+        *,
+        agents: Sequence[str | AgentProtocol] | None = None,
+    ) -> "ConcurrentBuilder":
+        """Enable request info after agent participant responses.
+
+        This enables human-in-the-loop (HIL) scenarios for the concurrent orchestration.
+        When enabled, the workflow pauses after each agent participant runs, emitting
+        a RequestInfoEvent that allows the caller to review the conversation and optionally
+        inject guidance for the agent participant to iterate. The caller provides input via
+        the standard response_handler/request_info pattern.
+
+        Simulated flow with HIL:
+            Input -> fan-out -> [Agent Participant <-> Request Info] (each branch in parallel) -> fan-in
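+
+        Example (illustrative sketch; assumes two chat agents named `writer` and `reviewer`
+        have already been created):
+
+            .. code-block:: python
+
+                workflow = (
+                    ConcurrentBuilder()
+                    .participants([writer, reviewer])
+                    .with_request_info(agents=["writer"])  # pause for human input only after the writer
+                    .build()
+                )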
-        When enabled, the workflow pauses after all parallel agents complete,
-        emitting a RequestInfoEvent that allows the caller to review and optionally
-        modify the combined results before aggregation. The caller provides feedback
-        via the standard response_handler/request_info pattern.
+        Note: This is only available for agent participants. Executor participants can incorporate
+        request info handling in their own implementation if desired.

-        Note:
-            Unlike SequentialBuilder and GroupChatBuilder, ConcurrentBuilder does not
-            support per-agent filtering since all agents run in parallel and results
-            are collected together. The pause occurs once with all agent outputs received.
+        Args:
+            agents: Optional list of agent names or agent instances to enable request info for.
+                If None, enables HIL for all agent participants.

         Returns:
-            self: The builder instance for fluent chaining.
+            Self for fluent chaining
         """
+        from ._orchestration_request_info import resolve_request_info_filter
+
         self._request_info_enabled = True
+        self._request_info_filter = resolve_request_info_filter(list(agents) if agents else None)
+
         return self

+    def _resolve_participants(self) -> list[Executor]:
+        """Resolve participant instances into Executor objects."""
+        participants: list[Executor | AgentProtocol] = []
+        if self._participant_factories:
+            # Resolve the participant factories now. This doesn't break the factory pattern
+            # since the Concurrent builder still creates new instances per workflow build.
+            for factory in self._participant_factories:
+                p = factory()
+                participants.append(p)
+        else:
+            participants = self._participants
+
+        executors: list[Executor] = []
+        for p in participants:
+            if isinstance(p, Executor):
+                executors.append(p)
+            elif isinstance(p, AgentProtocol):
+                if self._request_info_enabled and (
+                    not self._request_info_filter or resolve_agent_id(p) in self._request_info_filter
+                ):
+                    # Handle request info enabled agents
+                    executors.append(AgentApprovalExecutor(p))
+                else:
+                    executors.append(AgentExecutor(p))
+            else:
+                raise TypeError(f"Participants must be AgentProtocol or Executor instances. Got {type(p).__name__}.")
+
+        return executors
+
     def build(self) -> Workflow:
         r"""Build and validate the concurrent workflow.

@@ -521,29 +566,15 @@
             )
         )

-        participants: list[Executor | AgentProtocol] = []
-        if self._participant_factories:
-            # Resolve the participant factories now. This doesn't break the factory pattern
-            # since the Concurrent builder still creates new instances per workflow build.
- for factory in self._participant_factories: - p = factory() - participants.append(p) - else: - participants = self._participants + # Resolve participants and participant factories to executors + participants: list[Executor] = self._resolve_participants() builder = WorkflowBuilder() builder.set_start_executor(dispatcher) + # Fan-out for parallel execution builder.add_fan_out_edges(dispatcher, participants) - - if self._request_info_enabled: - # Insert interceptor between fan-in and aggregator - # participants -> fan-in -> interceptor -> aggregator - request_info_interceptor = RequestInfoInterceptor(executor_id="request_info") - builder.add_fan_in_edges(participants, request_info_interceptor) - builder.add_edge(request_info_interceptor, aggregator) - else: - # Direct fan-in to aggregator - builder.add_fan_in_edges(participants, aggregator) + # Direct fan-in to aggregator + builder.add_fan_in_edges(participants, aggregator) if self._checkpoint_storage is not None: builder = builder.with_checkpointing(self._checkpoint_storage) diff --git a/python/packages/core/agent_framework/_workflows/_edge.py b/python/packages/core/agent_framework/_workflows/_edge.py index 87a6f7af2b..02ca1722dd 100644 --- a/python/packages/core/agent_framework/_workflows/_edge.py +++ b/python/packages/core/agent_framework/_workflows/_edge.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft. All rights reserved. +import inspect import logging import uuid -from collections.abc import Callable, Sequence +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass, field -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeAlias, TypeVar from ._const import INTERNAL_SOURCE_ID from ._executor import Executor @@ -12,6 +13,10 @@ logger = logging.getLogger(__name__) +# Type alias for edge condition functions. +# Conditions receive the message data and return bool (sync or async). +EdgeCondition: TypeAlias = Callable[[Any], bool | Awaitable[bool]] + def _extract_function_name(func: Callable[..., Any]) -> str: """Map a Python callable to a concise, human-focused identifier. @@ -71,12 +76,13 @@ class Edge(DictConvertible): serialising the edge down to primitives we can reconstruct the topology of a workflow irrespective of the original Python process. + Edge conditions receive the message data and return a boolean (sync or async). + Examples: .. code-block:: python - edge = Edge(source_id="ingest", target_id="score", condition=lambda payload: payload["ready"]) - assert edge.should_route({"ready": True}) is True - assert edge.should_route({"ready": False}) is False + edge = Edge(source_id="ingest", target_id="score", condition=lambda data: data["ready"]) + assert await edge.should_route({"ready": True}) is True """ ID_SEPARATOR: ClassVar[str] = "->" @@ -84,13 +90,13 @@ class Edge(DictConvertible): source_id: str target_id: str condition_name: str | None - _condition: Callable[[Any], bool] | None = field(default=None, repr=False, compare=False) + _condition: EdgeCondition | None = field(default=None, repr=False, compare=False) def __init__( self, source_id: str, target_id: str, - condition: Callable[[Any], bool] | None = None, + condition: EdgeCondition | None = None, *, condition_name: str | None = None, ) -> None: @@ -103,9 +109,9 @@ def __init__( target_id: Canonical identifier of the downstream executor instance. condition: - Optional predicate that receives the message payload and returns - `True` when the edge should be traversed. 
When omitted, the edge is
-                considered unconditionally active.
+                Optional predicate that receives the message data and returns
+                `True` when the edge should be traversed. Can be sync or async.
+                When omitted, the edge is unconditionally active.
             condition_name:
                 Optional override that pins a human-friendly name for the
                 condition when the callable cannot be introspected (for example after
@@ -125,7 +131,9 @@
         self.source_id = source_id
         self.target_id = target_id
         self._condition = condition
-        self.condition_name = _extract_function_name(condition) if condition is not None else condition_name
+        self.condition_name = (
+            _extract_function_name(condition) if condition is not None and condition_name is None else condition_name
+        )

     @property
     def id(self) -> str:
@@ -144,8 +152,16 @@
         """
         return f"{self.source_id}{self.ID_SEPARATOR}{self.target_id}"

-    def should_route(self, data: Any) -> bool:
-        """Evaluate the edge predicate against an incoming payload.
+    @property
+    def has_condition(self) -> bool:
+        """Check if this edge has a condition.
+
+        Returns True if the edge was configured with a condition function.
+        """
+        return self._condition is not None
+
+    async def should_route(self, data: Any) -> bool:
+        """Evaluate the edge predicate against an incoming payload.

         When the edge was defined without an explicit predicate the method
         returns `True`, signalling an unconditional routing rule. Otherwise the
@@ -153,16 +169,27 @@
         this edge. Any exception raised by the callable is deliberately allowed
         to surface to the caller to avoid masking logic bugs.

+        The condition receives the message data and may be sync or async.
+
+        Args:
+            data: The message payload
+
+        Returns:
+            True if the edge should be traversed, False otherwise.
+
         Examples:
             .. code-block:: python

-                edge = Edge("stage1", "stage2", condition=lambda payload: payload["score"] > 0.8)
-                assert edge.should_route({"score": 0.9}) is True
-                assert edge.should_route({"score": 0.4}) is False
+                edge = Edge("stage1", "stage2", condition=lambda data: data["score"] > 0.8)
+                assert await edge.should_route({"score": 0.9}) is True
+                assert await edge.should_route({"score": 0.4}) is False
         """
         if self._condition is None:
             return True
-        return self._condition(data)
+
+        result = self._condition(data)
+        if inspect.isawaitable(result):
+            return bool(await result)
+        return bool(result)

     def to_dict(self) -> dict[str, Any]:
         """Produce a JSON-serialisable view of the edge metadata.
@@ -281,6 +308,8 @@ class EdgeGroup(DictConvertible):

     from builtins import type as builtin_type

+    _T_EdgeGroup = TypeVar("_T_EdgeGroup", bound="EdgeGroup")
+
     _TYPE_REGISTRY: ClassVar[dict[str, builtin_type["EdgeGroup"]]] = {}

     def __init__(
@@ -363,7 +392,7 @@ def to_dict(self) -> dict[str, Any]:
         }

     @classmethod
-    def register(cls, subclass: builtin_type["EdgeGroup"]) -> builtin_type["EdgeGroup"]:
+    def register(cls, subclass: builtin_type[_T_EdgeGroup]) -> builtin_type[_T_EdgeGroup]:
         """Register a subclass so deserialisation can recover the right type.

         Registration is typically performed via the decorator syntax applied to
@@ -443,12 +472,18 @@ def __init__(
         self,
         source_id: str,
         target_id: str,
-        condition: Callable[[Any], bool] | None = None,
+        condition: EdgeCondition | None = None,
         *,
         id: str | None = None,
     ) -> None:
         """Create a one-to-one edge group between two executors.

+        Args:
+            source_id: The source executor ID.
+            target_id: The target executor ID.
+            condition: Optional condition function `(data) -> bool | Awaitable[bool]`.
+ id: Optional explicit ID for the edge group. + Examples: .. code-block:: python diff --git a/python/packages/core/agent_framework/_workflows/_edge_runner.py b/python/packages/core/agent_framework/_workflows/_edge_runner.py index 0aa4139c48..8255f8f79c 100644 --- a/python/packages/core/agent_framework/_workflows/_edge_runner.py +++ b/python/packages/core/agent_framework/_workflows/_edge_runner.py @@ -112,7 +112,9 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R return False if self._can_handle(self._edge.target_id, message): - if self._edge.should_route(message.data): + route_result = await self._edge.should_route(message.data) + + if route_result: span.set_attributes({ OtelAttr.EDGE_GROUP_DELIVERED: True, OtelAttr.EDGE_GROUP_DELIVERY_STATUS: EdgeGroupDeliveryStatus.DELIVERED.value, @@ -162,8 +164,8 @@ def __init__(self, edge_group: FanOutEdgeGroup, executors: dict[str, Executor]) async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: """Send a message through all edges in the fan-out edge group.""" - deliverable_edges = [] - single_target_edge = None + deliverable_edges: list[Edge] = [] + single_target_edge: Edge | None = None # Process routing logic within span with create_edge_group_processing_span( self._edge_group.__class__.__name__, @@ -192,7 +194,9 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R if message.target_id in selection_results: edge = self._target_map.get(message.target_id) if edge and self._can_handle(edge.target_id, message): - if edge.should_route(message.data): + route_result = await edge.should_route(message.data) + + if route_result: span.set_attributes({ OtelAttr.EDGE_GROUP_DELIVERED: True, OtelAttr.EDGE_GROUP_DELIVERY_STATUS: EdgeGroupDeliveryStatus.DELIVERED.value, @@ -223,8 +227,10 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R # If no target ID, send the message to the selected targets for target_id in selection_results: edge = self._target_map[target_id] - if self._can_handle(edge.target_id, message) and edge.should_route(message.data): - deliverable_edges.append(edge) + if self._can_handle(edge.target_id, message): + route_result = await edge.should_route(message.data) + if route_result: + deliverable_edges.append(edge) if len(deliverable_edges) > 0: span.set_attributes({ diff --git a/python/packages/core/agent_framework/_workflows/_events.py b/python/packages/core/agent_framework/_workflows/_events.py index 57c600519d..dcd6ab5866 100644 --- a/python/packages/core/agent_framework/_workflows/_events.py +++ b/python/packages/core/agent_framework/_workflows/_events.py @@ -8,7 +8,7 @@ from enum import Enum from typing import Any, TypeAlias -from agent_framework import AgentRunResponse, AgentRunResponseUpdate +from agent_framework import AgentResponse, AgentResponseUpdate from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value from ._typing_utils import deserialize_type, serialize_type @@ -278,20 +278,20 @@ class WorkflowOutputEvent(WorkflowEvent): def __init__( self, data: Any, - source_executor_id: str, + executor_id: str, ): """Initialize the workflow output event. Args: data: The output yielded by the executor. - source_executor_id: ID of the executor that yielded the output. + executor_id: ID of the executor that yielded the output. 
""" super().__init__(data) - self.source_executor_id = source_executor_id + self.executor_id = executor_id def __repr__(self) -> str: """Return a string representation of the workflow output event.""" - return f"{self.__class__.__name__}(data={self.data}, source_executor_id={self.source_executor_id})" + return f"{self.__class__.__name__}(data={self.data}, executor_id={self.executor_id})" class SuperStepEvent(WorkflowEvent): @@ -367,9 +367,9 @@ def __repr__(self) -> str: # pragma: no cover - representation only class AgentRunUpdateEvent(ExecutorEvent): """Event triggered when an agent is streaming messages.""" - data: AgentRunResponseUpdate | None + data: AgentResponseUpdate - def __init__(self, executor_id: str, data: AgentRunResponseUpdate | None = None): + def __init__(self, executor_id: str, data: AgentResponseUpdate): """Initialize the agent streaming event.""" super().__init__(executor_id, data) @@ -381,9 +381,9 @@ def __repr__(self) -> str: class AgentRunEvent(ExecutorEvent): """Event triggered when an agent run is completed.""" - data: AgentRunResponse | None + data: AgentResponse - def __init__(self, executor_id: str, data: AgentRunResponse | None = None): + def __init__(self, executor_id: str, data: AgentResponse): """Initialize the agent run event.""" super().__init__(executor_id, data) diff --git a/python/packages/core/agent_framework/_workflows/_exceptions.py b/python/packages/core/agent_framework/_workflows/_exceptions.py new file mode 100644 index 0000000000..2c35395da0 --- /dev/null +++ b/python/packages/core/agent_framework/_workflows/_exceptions.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft. All rights reserved. + +from ..exceptions import AgentFrameworkException + + +class WorkflowException(AgentFrameworkException): + """Base exception for workflow errors.""" + + pass + + +class WorkflowRunnerException(WorkflowException): + """Base exception for workflow runner errors.""" + + pass + + +class WorkflowConvergenceException(WorkflowRunnerException): + """Exception raised when a workflow runner fails to converge within the maximum iterations.""" + + pass + + +class WorkflowCheckpointException(WorkflowRunnerException): + """Exception raised for errors related to workflow checkpoints.""" + + pass diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index fad1e5f15e..49f3dafd06 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -250,6 +250,8 @@ async def execute( ): # Find the handler and handler spec that matches the message type. handler = self._find_handler(message) + + original_message = message if isinstance(message, Message): # Unwrap raw data for handler call message = message.data @@ -261,6 +263,9 @@ async def execute( runner_context=runner_context, trace_contexts=trace_contexts, source_span_ids=source_span_ids, + request_id=original_message.original_request_info_event.request_id + if isinstance(original_message, Message) and original_message.original_request_info_event + else None, ) # Invoke the handler with the message and context @@ -291,6 +296,7 @@ def _create_context_for_handler( runner_context: RunnerContext, trace_contexts: list[dict[str, str]] | None = None, source_span_ids: list[str] | None = None, + request_id: str | None = None, ) -> WorkflowContext[Any]: """Create the appropriate WorkflowContext based on the handler's context annotation. 
@@ -300,6 +306,7 @@ def _create_context_for_handler( runner_context: The runner context that provides methods to send messages and events. trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation. source_span_ids: Optional source span IDs from multiple sources for linking. + request_id: Optional request ID if this context is for a `handle_response` handler. Returns: WorkflowContext[Any] based on the handler's context annotation. @@ -312,6 +319,7 @@ def _create_context_for_handler( runner_context=runner_context, trace_contexts=trace_contexts, source_span_ids=source_span_ids, + request_id=request_id, ) def _discover_handlers(self) -> None: @@ -356,7 +364,17 @@ def can_handle(self, message: Message) -> bool: True if the executor can handle the message type, False otherwise. """ if message.type == MessageType.RESPONSE: - return any(is_instance_of(message.data, message_type) for message_type in self._response_handlers) + if message.original_request_info_event is None: + logger.warning( + f"Executor {self.__class__.__name__} received a response message without an original request event." + ) + return False + + return any( + is_instance_of(message.original_request_info_event.data, message_type[0]) + and is_instance_of(message.data, message_type[1]) + for message_type in self._response_handlers + ) return any(is_instance_of(message.data, message_type) for message_type in self._handlers) @@ -427,7 +445,7 @@ def workflow_output_types(self) -> list[type[Any]]: output_types: set[type[Any]] = set() # Collect workflow output types from all handlers - for handler_spec in self._handler_specs: + for handler_spec in self._handler_specs + self._response_handler_specs: handler_workflow_output_types = handler_spec.get("workflow_output_types", []) output_types.update(handler_workflow_output_types) @@ -457,11 +475,15 @@ def _find_handler(self, message: Any) -> Callable[[Any, WorkflowContext[Any, Any f"Executor {self.__class__.__name__} cannot handle message of type {type(message.data)}." ) # Response message case - find response handler based on original request and response types - handler = self._find_response_handler(message.original_request, message.data) + if message.original_request_info_event is None: + raise RuntimeError( + f"Executor {self.__class__.__name__} received a response message without an original request event." + ) + handler = self._find_response_handler(message.original_request_info_event.data, message.data) if not handler: raise RuntimeError( f"Executor {self.__class__.__name__} cannot handle request of type " - f"{type(message.original_request)} and response of type {type(message.data)}." + f"{type(message.original_request_info_event.data)} and response of type {type(message.data)}." 
            )
         return handler
diff --git a/python/packages/core/agent_framework/_workflows/_function_executor.py b/python/packages/core/agent_framework/_workflows/_function_executor.py
index 417a4ee51b..d7b68c10fd 100644
--- a/python/packages/core/agent_framework/_workflows/_function_executor.py
+++ b/python/packages/core/agent_framework/_workflows/_function_executor.py
@@ -17,13 +17,19 @@

 import asyncio
 import inspect
+import sys
 import typing
 from collections.abc import Awaitable, Callable
-from typing import Any, overload
+from typing import Any

 from ._executor import Executor
 from ._workflow_context import WorkflowContext, validate_workflow_context_annotation

+if sys.version_info >= (3, 11):
+    from typing import overload  # pragma: no cover
+else:
+    from typing_extensions import overload  # pragma: no cover
+

 class FunctionExecutor(Executor):
     """Executor that wraps a user-defined function.
diff --git a/python/packages/core/agent_framework/_workflows/_group_chat.py b/python/packages/core/agent_framework/_workflows/_group_chat.py
index 725a5c829c..d75b805514 100644
--- a/python/packages/core/agent_framework/_workflows/_group_chat.py
+++ b/python/packages/core/agent_framework/_workflows/_group_chat.py
@@ -2,15 +2,15 @@

 """Group chat orchestration primitives.

-This module introduces a reusable orchestration surface for manager-directed
+This module introduces a reusable orchestration surface for orchestrator-directed
 multi-agent conversations. The key components are:

 - GroupChatRequestMessage / GroupChatResponseMessage: canonical envelopes used
   between the orchestrator and participants.
-- Group chat managers: minimal asynchronous callables for pluggable coordination logic.
-- GroupChatOrchestratorExecutor: runtime state machine that delegates to a
-  manager to select the next participant or complete the task.
-- GroupChatBuilder: high-level builder that wires managers and participants
+- GroupChatSelectionFunction: sync or async callable for pluggable speaker selection logic.
+- GroupChatOrchestrator: runtime state machine that delegates to a
+  selection function to select the next participant or complete the task.
+- GroupChatBuilder: high-level builder that wires orchestrators and participants
   into a workflow graph. It mirrors the ergonomics of SequentialBuilder and
   ConcurrentBuilder while allowing Magentic to reuse the same infrastructure.
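+
+A minimal selection function sketch (illustrative only; GroupChatState and the selection
+contract are defined later in this module):
+
+.. code-block:: python
+
+    async def pick_next(state: GroupChatState) -> str:
+        # Round-robin over the registered participant names.
+        names = list(state.participants)
+        return names[state.current_round % len(names)]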
@@ -19,1722 +19,723 @@ """ import inspect -import itertools import logging -from collections.abc import Awaitable, Callable, Mapping, Sequence -from dataclasses import dataclass, field -from types import MappingProxyType -from typing import Any, TypeAlias, cast -from uuid import uuid4 +import sys +from collections import OrderedDict +from collections.abc import Awaitable, Callable, Sequence +from dataclasses import dataclass +from typing import Any, ClassVar, cast from pydantic import BaseModel, Field +from typing_extensions import Never from .._agents import AgentProtocol, ChatAgent +from .._threads import AgentThread from .._types import ChatMessage, Role -from ._agent_executor import AgentExecutorRequest, AgentExecutorResponse -from ._base_group_chat_orchestrator import BaseGroupChatOrchestrator +from ._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse +from ._agent_utils import resolve_agent_id +from ._base_group_chat_orchestrator import ( + BaseGroupChatOrchestrator, + GroupChatParticipantMessage, + GroupChatRequestMessage, + GroupChatResponseMessage, + GroupChatWorkflowContext_T_Out, + ParticipantRegistry, + TerminationCondition, +) from ._checkpoint import CheckpointStorage -from ._conversation_history import ensure_author, latest_user_message -from ._executor import Executor, handler -from ._orchestration_request_info import RequestInfoInterceptor -from ._participant_utils import GroupChatParticipantSpec, prepare_participant_metadata, wrap_participant +from ._conversation_state import decode_chat_messages, encode_chat_messages +from ._executor import Executor +from ._orchestration_request_info import AgentApprovalExecutor from ._workflow import Workflow from ._workflow_builder import WorkflowBuilder from ._workflow_context import WorkflowContext -logger = logging.getLogger(__name__) - - -# region Message primitives - - -@dataclass -class _GroupChatRequestMessage: - """Internal: Request envelope sent from the orchestrator to a participant.""" - - agent_name: str - conversation: list[ChatMessage] = field(default_factory=list) # type: ignore - instruction: str = "" - task: ChatMessage | None = None - metadata: dict[str, Any] | None = None - - -@dataclass -class _GroupChatResponseMessage: - """Internal: Response envelope emitted by participants back to the orchestrator.""" - - agent_name: str - message: ChatMessage - - -@dataclass -class _GroupChatTurn: - """Internal: Represents a single turn in the manager-participant conversation.""" - - speaker: str - role: str - message: ChatMessage - - -@dataclass -class GroupChatDirective: - """Instruction emitted by a group chat manager implementation.""" - - agent_name: str | None = None - instruction: str | None = None - metadata: dict[str, Any] | None = None - finish: bool = False - final_message: ChatMessage | None = None +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override - -@dataclass -class ManagerSelectionRequest: - """Request sent to manager agent for next speaker selection. - - This dataclass packages the full conversation state and task context - for the manager agent to analyze and make a speaker selection decision. 
- - Attributes: - task: Original user task message - participants: Mapping of participant names to their descriptions - conversation: Full conversation history including all messages - round_index: Number of manager selection rounds completed so far - metadata: Optional metadata for extensibility - """ - - task: ChatMessage - participants: dict[str, str] # type: ignore - conversation: list[ChatMessage] # type: ignore - round_index: int - metadata: dict[str, Any] | None = None - - def to_dict(self) -> dict[str, Any]: - """Convert to dictionary for serialization.""" - return { - "task": self.task.to_dict(), - "participants": dict(self.participants), - "conversation": [msg.to_dict() for msg in self.conversation], - "round_index": self.round_index, - "metadata": self.metadata, - } - - -class ManagerSelectionResponse(BaseModel): - """Response from manager agent with speaker selection decision. - - The manager agent must produce this structure (or compatible dict/JSON) - to communicate its decision back to the orchestrator. - - Attributes: - selected_participant: Name of participant to speak next (None = finish conversation) - instruction: Optional instruction to provide to the selected participant - finish: Whether the conversation should be completed - final_message: Optional final message string when finishing conversation (will be converted to ChatMessage) - """ - - model_config = { - "extra": "forbid", - # OpenAI strict mode requires all properties to be in required array - "json_schema_extra": {"required": ["selected_participant", "instruction", "finish", "final_message"]}, - } - - selected_participant: str | None = None - instruction: str | None = None - finish: bool = False - final_message: str | None = Field(default=None, description="Optional text content for final message") - - @staticmethod - def from_dict(data: dict[str, Any]) -> "ManagerSelectionResponse": - """Create from dictionary representation.""" - return ManagerSelectionResponse( - selected_participant=data.get("selected_participant"), - instruction=data.get("instruction"), - finish=data.get("finish", False), - final_message=data.get("final_message"), - ) - - def get_final_message_as_chat_message(self) -> ChatMessage | None: - """Convert final_message string to ChatMessage if present.""" - if self.final_message: - return ChatMessage(role=Role.ASSISTANT, text=self.final_message) - return None - - -# endregion - - -# region Manager callable - - -GroupChatStateSnapshot = Mapping[str, Any] -_GroupChatManagerFn = Callable[[GroupChatStateSnapshot], Awaitable[GroupChatDirective]] - - -async def _maybe_await(value: Any) -> Any: - """Await value if it is awaitable; otherwise return as-is.""" - if inspect.isawaitable(value): - return await value - return value - - -_GroupChatParticipantPipeline: TypeAlias = Sequence[Executor] +logger = logging.getLogger(__name__) -@dataclass -class _GroupChatConfig: - """Internal: Configuration passed to factories during workflow assembly. +@dataclass(frozen=True) +class GroupChatState: + """Immutable state of the group chat for the selection function to determine the next speaker. 
     Attributes:
-        manager: Manager callable for orchestration decisions (used by set_select_speakers_func)
-        manager_participant: Manager agent/executor instance (used by set_manager)
-        manager_name: Display name for the manager in conversation history
-        participants: Mapping of participant names to their specifications
-        max_rounds: Optional limit on manager selection rounds to prevent infinite loops
-        termination_condition: Optional callable that halts the conversation when it returns True
-        orchestrator: Orchestrator executor instance (populated during build)
-        participant_aliases: Mapping of aliases to executor IDs
-        participant_executors: Mapping of participant names to their executor instances
+        current_round: The current round index of the group chat, starting from 0.
+        participants: A mapping of participant names to their descriptions in the group chat.
+        conversation: The full conversation history up to this point as a list of ChatMessage.
     """

-    manager: _GroupChatManagerFn | None
-    manager_participant: AgentProtocol | Executor | None
-    manager_name: str
-    participants: Mapping[str, GroupChatParticipantSpec]
-    max_rounds: int | None = None
-    termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None
-    orchestrator: Executor | None = None
-    participant_aliases: dict[str, str] = field(default_factory=dict)  # type: ignore[type-arg]
-    participant_executors: dict[str, Executor] = field(default_factory=dict)  # type: ignore[type-arg]
-
-
-# endregion
-
-
-# region Default participant factory
-
-_GroupChatOrchestratorFactory: TypeAlias = Callable[[_GroupChatConfig], Executor]
-_InterceptorSpec: TypeAlias = tuple[Callable[[_GroupChatConfig], Executor], Callable[[Any], bool]]
-
-
-def _default_participant_factory(
-    spec: GroupChatParticipantSpec,
-    wiring: _GroupChatConfig,
-) -> _GroupChatParticipantPipeline:
-    """Default factory for constructing participant pipeline nodes in the workflow graph.
-
-    Creates a single AgentExecutor node for AgentProtocol participants or a passthrough executor
-    for custom participants. Translation between group-chat envelopes and the agent runtime is now
-    handled inside the orchestrator, removing the need for dedicated ingress/egress adapters.
-
-    Args:
-        spec: Participant specification containing name, instance, and description
-        wiring: GroupChatWiring configuration for accessing cached executors
-
-    Returns:
-        Sequence of executors representing the participant pipeline in execution order
+    # Round index, starting from 0
+    current_round: int
+    # participant name to description mapping as an ordered dict
+    participants: OrderedDict[str, str]
+    # Full conversation history up to this point
+    conversation: list[ChatMessage]

-    Behavior:
-    - AgentProtocol participants are wrapped in AgentExecutor with deterministic IDs
-    - Executor participants are wired directly without additional adapters
-    """
-    participant = spec.participant
-    if isinstance(participant, Executor):
-        return (participant,)

-    cached = wiring.participant_executors.get(spec.name)
-    if cached is not None:
-        return (cached,)
+# region Default orchestrator

-    agent_executor = wrap_participant(participant, executor_id=f"groupchat_agent:{spec.name}")
-    return (agent_executor,)

-# endregion
+# Type alias for the selection function used by the orchestrator to choose the next speaker.
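+# The function receives an immutable GroupChatState snapshot and returns the name of the next
+# participant to speak; both plain functions and coroutine functions are supported.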
+GroupChatSelectionFunction = Callable[[GroupChatState], Awaitable[str] | str] -# endregion +class GroupChatOrchestrator(BaseGroupChatOrchestrator): + """Orchestrator that manages a group chat between multiple participants. -# region Default orchestrator + This group chat orchestrator operates under the direction of a selection function + provided at initialization. The selection function receives the current state of + the group chat and returns the name of the next participant to speak. + This orchestrator drives the conversation loop as follows: + 1. Receives initial messages, saves to history, and broadcasts to all participants + 2. Invokes the selection function to determine the next speaker based on the most recent state + 3. Sends a request to the selected participant to generate a response + 4. Receives the participant's response, saves to history, and broadcasts to all participants + except the one that just spoke + 5. Repeats steps 2-4 until the termination conditions are met -class GroupChatOrchestratorExecutor(BaseGroupChatOrchestrator): - """Executor that orchestrates a group chat between multiple participants using a manager. - - This is the central runtime state machine that drives multi-agent conversations. It - maintains conversation state, delegates speaker selection to a manager, routes messages - to participants, and collects responses in a loop until the manager signals completion. - - Core responsibilities: - - Accept initial input as str, ChatMessage, or list[ChatMessage] - - Maintain conversation history and turn tracking - - Query manager for next action (select participant or finish) - - Route requests to selected participants using AgentExecutorRequest or GroupChatRequestMessage - - Collect participant responses and append to conversation - - Enforce optional round limits to prevent infinite loops - - Yield final completion message and transition to idle state - - State management: - - _conversation: Growing list of all messages (user, manager, agents) - - _history: Turn-by-turn record with speaker attribution and roles - - _task_message: Original user task extracted from input - - _pending_agent: Name of agent currently processing a request - - _round_index: Count of manager selection rounds for limit enforcement - - Manager interaction: - The orchestrator builds immutable state snapshots and passes them to the manager - callable. 
The manager returns a GroupChatDirective indicating either: - - Next participant to speak (with optional instruction) - - Finish signal (with optional final message) - - Message flow topology: - User input -> orchestrator -> manager -> orchestrator -> participant -> orchestrator - (loops until manager returns finish directive) - - Why this design: - - Separates orchestration logic (this class) from selection logic (manager) - - Manager is stateless and testable in isolation - - Orchestrator handles all state mutations and message routing - - Broadcast routing to participants keeps graph structure simple - - Args: - manager: Callable that selects the next participant or finishes based on state snapshot - participants: Mapping of participant names to descriptions (for manager context) - manager_name: Display name for manager in conversation history - max_rounds: Optional limit on manager selection rounds (None = unlimited) - termination_condition: Optional callable that halts the conversation when it returns True - executor_id: Optional custom ID for observability (auto-generated if not provided) + This is the most basic orchestrator, great for getting started with multi-agent + conversations. More advanced orchestrators can be built by extending BaseGroupChatOrchestrator + and implementing custom logic in the message and response handlers. """ def __init__( self, - manager: _GroupChatManagerFn, + id: str, + participant_registry: ParticipantRegistry, + selection_func: GroupChatSelectionFunction, *, - participants: Mapping[str, str], - manager_name: str, + name: str | None = None, max_rounds: int | None = None, - termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None, - executor_id: str | None = None, + termination_condition: TerminationCondition | None = None, ) -> None: - super().__init__(executor_id or f"groupchat_orchestrator_{uuid4().hex[:8]}") - self._manager = manager - self._participants = dict(participants) - self._manager_name = manager_name - self._max_rounds = max_rounds - self._termination_condition = termination_condition - self._history: list[_GroupChatTurn] = [] - self._task_message: ChatMessage | None = None - self._pending_agent: str | None = None - self._pending_finalization: bool = False - # Stashes the initial conversation list until _handle_task_message normalizes it into _conversation. - self._pending_initial_conversation: list[ChatMessage] | None = None - - def _get_author_name(self) -> str: - """Get the manager name for orchestrator-generated messages.""" - return self._manager_name - - def _build_state(self) -> GroupChatStateSnapshot: - """Build a snapshot of current orchestration state for the manager. - - Packages conversation history, participant metadata, and round tracking into - an immutable mapping that the manager uses to make speaker selection decisions. 
- - Returns: - Mapping containing all context needed for manager decision-making - - Raises: - RuntimeError: If called before task message initialization (defensive check) - - When this is called: - - After initial input is processed (first manager query) - - After each participant response (subsequent manager queries) - """ - if self._task_message is None: - raise RuntimeError("GroupChatOrchestratorExecutor state not initialized with task message.") - snapshot: dict[str, Any] = { - "task": self._task_message, - "participants": dict(self._participants), - "conversation": tuple(self._conversation), - "history": tuple(self._history), - "pending_agent": self._pending_agent, - "round_index": self._round_index, - } - return MappingProxyType(snapshot) - - def _snapshot_pattern_metadata(self) -> dict[str, Any]: - """Serialize GroupChat-specific state for checkpointing. - - Returns: - Dict with participants, manager name, history, and pending agent - """ - return { - "participants": dict(self._participants), - "manager_name": self._manager_name, - "pending_agent": self._pending_agent, - "task_message": self._task_message.to_dict() if self._task_message else None, - "history": [ - {"speaker": turn.speaker, "role": turn.role, "message": turn.message.to_dict()} - for turn in self._history - ], - } - - def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None: - """Restore GroupChat-specific state from checkpoint. + """Initialize the GroupChatOrchestrator. Args: - metadata: Pattern-specific state dict - """ - if "participants" in metadata: - self._participants = dict(metadata["participants"]) - if "manager_name" in metadata: - self._manager_name = metadata["manager_name"] - if "pending_agent" in metadata: - self._pending_agent = metadata["pending_agent"] - task_msg = metadata.get("task_message") - if task_msg: - self._task_message = ChatMessage.from_dict(task_msg) - if "history" in metadata: - self._history = [ - _GroupChatTurn( - speaker=turn["speaker"], - role=turn["role"], - message=ChatMessage.from_dict(turn["message"]), - ) - for turn in metadata["history"] - ] - - async def _complete_on_termination( - self, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], - ) -> bool: - """Finish the conversation early when the termination condition is met.""" - if not await self._check_termination(): - return False - - if self._is_manager_agent(): - if self._pending_finalization: - return True - - self._pending_finalization = True - termination_prompt = ChatMessage( - role=Role.SYSTEM, - text="Termination condition met. 
Provide a final manager summary and finish the conversation.", - ) - manager_conversation = [ - self._build_manager_context_message(), - termination_prompt, - *list(self._conversation), - ] - self._pending_agent = self._manager_name - await self._route_to_participant( - participant_name=self._manager_name, - conversation=manager_conversation, - ctx=ctx, - instruction="", - task=self._task_message, - metadata={"termination_condition": True}, - ) - return True - - final_message: ChatMessage | None = None - if self._manager is not None and not self._is_manager_agent(): - try: - directive = await self._manager(self._build_state()) - except Exception: - logger.warning("Manager finalization failed during termination; using default termination message.") - else: - if directive.final_message is not None: - final_message = ensure_author(directive.final_message, self._manager_name) - elif directive.finish: - final_message = ensure_author( - self._create_completion_message( - text="Conversation completed.", - reason="termination_condition_manager_finish", - ), - self._manager_name, - ) + id: Unique executor ID for the orchestrator. The ID must be unique within the workflow. + participant_registry: Registry of participants in the group chat that track executor types + (agents vs. executors) and provide resolution utilities. + selection_func: Function to select the next speaker based on conversation state + name: Optional display name for the orchestrator in the messages, defaults to executor ID. + A more descriptive name that is not an ID could help models better understand the role + of the orchestrator in multi-agent conversations. If the ID is not human-friendly, + providing a name can improve context for the agents. + max_rounds: Optional limit on selection rounds to prevent infinite loops. + termination_condition: Optional callable that halts the conversation when it returns True + + Note: If neither `max_rounds` nor `termination_condition` is provided, the conversation + will continue indefinitely. It is recommended to always set one of these to ensure proper termination. - if final_message is None: - final_message = ensure_author( - self._create_completion_message( - text="Conversation halted after termination condition was met.", - reason="termination_condition", - ), - self._manager_name, - ) - self._conversation.append(final_message) - self._history.append(_GroupChatTurn(self._manager_name, "manager", final_message)) - self._pending_agent = None - await ctx.yield_output(list(self._conversation)) - return True + Example: + .. code-block:: python - async def _apply_directive( - self, - directive: GroupChatDirective, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], - ) -> None: - """Execute a manager directive by either finishing the workflow or routing to a participant. + from agent_framework import GroupChatOrchestrator - This is the core routing logic that interprets manager decisions. It handles two cases: - 1. Finish directive: append final message, update state, yield output, become idle - 2. 
Agent selection: build request envelope, route to participant, increment round counter - Args: - directive: Manager's decision (finish or select next participant) - ctx: Workflow context for sending messages and yielding output - - Behavior for finish directive: - - Uses provided final_message or creates default completion message - - Ensures author_name is set to manager for attribution - - Appends to conversation and history for complete record - - Yields message as workflow output - - Orchestrator becomes idle (no further processing) - - Behavior for agent selection: - - Validates agent_name exists in participants - - Optionally appends manager instruction as USER message - - Prepares full conversation context for the participant - - Routes request directly to the participant entry executor - - Increments round counter and enforces max_rounds if configured - - Round limit enforcement: - If max_rounds is reached, recursively calls _apply_directive with a finish - directive to gracefully terminate the conversation. + async def round_robin_selector(state: GroupChatState) -> str: + # Simple round-robin selection among participants + return state.participants[state.current_round % len(state.participants)] - Raises: - ValueError: If directive lacks agent_name when finish=False, or if - agent_name doesn't match any participant - """ - if directive.finish: - final_message = directive.final_message - if final_message is None: - final_message = self._create_completion_message( - text="Completed without final summary.", - reason="no summary provided", - ) - final_message = ensure_author(final_message, self._manager_name) - - self._conversation.extend((final_message,)) - self._history.append(_GroupChatTurn(self._manager_name, "manager", final_message)) - self._pending_agent = None - await ctx.yield_output(list(self._conversation)) - return - agent_name = directive.agent_name - if not agent_name: - raise ValueError("Directive must include agent_name when finish is False.") - if agent_name not in self._participants: - raise ValueError(f"Manager selected unknown participant '{agent_name}'.") - - instruction = directive.instruction or "" - conversation = list(self._conversation) - if instruction: - manager_message = ensure_author( - self._create_completion_message(text=instruction, reason="instruction"), - self._manager_name, + orchestrator = GroupChatOrchestrator( + id="group_chat_orchestrator_1", + selection_func=round_robin_selector, + participants=["researcher", "writer"], + name="Coordinator", + max_rounds=10, ) - conversation.extend((manager_message,)) - self._conversation.extend((manager_message,)) - self._history.append(_GroupChatTurn(self._manager_name, "manager", manager_message)) - - if await self._complete_on_termination(ctx): - return - - self._pending_agent = agent_name - self._increment_round() - - # Use inherited routing method from BaseGroupChatOrchestrator - await self._route_to_participant( - participant_name=agent_name, - conversation=conversation, - ctx=ctx, - instruction=instruction, - task=self._task_message, - metadata=directive.metadata, + """ + super().__init__( + id, + participant_registry, + name=name, + max_rounds=max_rounds, + termination_condition=termination_condition, ) + self._selection_func = selection_func - if self._check_round_limit(): - await self._apply_directive( - GroupChatDirective( - finish=True, - final_message=self._create_completion_message( - text="Conversation halted after reaching manager round limit.", - reason="max_rounds reached", - ), - ), - ctx, - 
) - - async def _ingest_participant_message( + @override + async def _handle_messages( self, - participant_name: str, - message: ChatMessage, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], - trailing_messages: list[ChatMessage] | None = None, + messages: list[ChatMessage], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: - """Common response ingestion logic shared by agent and custom participants. - - Args: - participant_name: Name of the participant who sent the message - message: The participant's response message - ctx: Workflow context for routing and output - trailing_messages: Optional list of messages to inject after the participant's - message (e.g., additional input from the RequestInfoInterceptor) - """ - if participant_name not in self._participants: - raise ValueError(f"Received response from unknown participant '{participant_name}'.") - - message = ensure_author(message, participant_name) - self._conversation.extend((message,)) - self._history.append(_GroupChatTurn(participant_name, "agent", message)) - - # Inject any trailing messages (e.g., human input) into the conversation - if trailing_messages: - for trailing_msg in trailing_messages: - self._conversation.extend((trailing_msg,)) - # Record as user input in history - author = trailing_msg.author_name or "human" - self._history.append(_GroupChatTurn(author, "user", trailing_msg)) - logger.debug( - f"Injected human input into group chat conversation: " - f"{trailing_msg.text[:50] if trailing_msg.text else '(empty)'}..." - ) - - self._pending_agent = None - - if await self._complete_on_termination(ctx): - return - - if self._check_round_limit(): - final_message = self._create_completion_message( - text="Conversation halted after reaching manager round limit.", - reason="max_rounds reached after response", - ) - self._conversation.extend((final_message,)) - self._history.append(_GroupChatTurn(self._manager_name, "manager", final_message)) - await ctx.yield_output(list(self._conversation)) + """Initialize orchestrator state and start the conversation loop.""" + self._append_messages(messages) + # Termination condition will also be applied to the input messages + if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)): return - # Query manager for next speaker selection - if self._is_manager_agent(): - # Agent-based manager: route request through workflow graph - # Prepend system message with participant context - manager_conversation = [self._build_manager_context_message(), *list(self._conversation)] - await self._route_to_participant( - participant_name=self._manager_name, - conversation=manager_conversation, - ctx=ctx, - instruction="", - task=self._task_message, - metadata=None, - ) - else: - # Callable manager: invoke directly - directive = await self._manager(self._build_state()) - await self._apply_directive(directive, ctx) - - def _is_manager_agent(self) -> bool: - """Check if orchestrator is using an agent-based manager (vs callable manager).""" - return self._registry.is_participant_registered(self._manager_name) - - def _build_manager_context_message(self) -> ChatMessage: - """Build system message with participant context for manager agent. - - This message is prepended to the conversation when querying the manager - to provide up-to-date participant information for selection decisions. 
+ next_speaker = await self._get_next_speaker() - Returns: - System message with participant names and descriptions - """ - participant_list = "\n".join(f"- {name}: {desc}" for name, desc in self._participants.items()) - context_text = ( - "Available participants:\n" - f"{participant_list}\n\n" - "IMPORTANT: Choose only from these exact participant names (case-sensitive)." + # Broadcast messages to all participants for context + await self._broadcast_messages_to_participants( + messages, + cast(WorkflowContext[AgentExecutorRequest | GroupChatParticipantMessage], ctx), ) - return ChatMessage(role=Role.SYSTEM, text=context_text) - - def _parse_manager_selection(self, response: AgentExecutorResponse) -> ManagerSelectionResponse: - """Extract manager selection decision from agent response. - - Attempts to parse structured output from the manager agent using multiple strategies: - 1. response.value (structured output from response_format) - 2. JSON parsing from message text - 3. Fallback error handling - - Args: - response: AgentExecutor response from manager agent - - Returns: - Parsed ManagerSelectionResponse with speaker selection - - Raises: - RuntimeError: If manager response cannot be parsed into valid selection - """ - import json - - # Strategy 1: agent_run_response.value (structured output) - agent_value = response.agent_run_response.value - if agent_value is not None: - if isinstance(agent_value, ManagerSelectionResponse): - return agent_value - if isinstance(agent_value, dict): - return ManagerSelectionResponse.from_dict(cast(dict[str, Any], agent_value)) - if isinstance(agent_value, str): - try: - data = json.loads(agent_value) - return ManagerSelectionResponse.from_dict(data) - except (json.JSONDecodeError, TypeError, KeyError) as e: - raise RuntimeError(f"Manager response.value contains invalid JSON: {e}") from e - - # Strategy 2: Parse from message text - messages = response.agent_run_response.messages or [] - if messages: - last_msg = messages[-1] - text = last_msg.text or "" - try: - return ManagerSelectionResponse.model_validate_json(text) - except (json.JSONDecodeError, TypeError, KeyError): - pass - - # Fallback: Cannot parse manager decision - raise RuntimeError( - "Manager response did not contain valid selection data. " - "Ensure manager agent uses response_format=ManagerSelectionResponse " - "or returns compatible JSON structure." + # Send request to selected participant + await self._send_request_to_participant( + next_speaker, + cast(WorkflowContext[AgentExecutorRequest | GroupChatRequestMessage], ctx), ) + self._increment_round() - async def _handle_manager_response( + @override + async def _handle_response( self, - response: AgentExecutorResponse, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], + response: AgentExecutorResponse | GroupChatResponseMessage, + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: - """Process manager agent's speaker selection decision. - - Parses the manager's response and either finishes the conversation or routes - to the selected participant. This method implements the core orchestration - logic for agent-based managers. - - Also handles any human input that was injected into the response's full_conversation - by the human input hook interceptor. 
- - Args: - response: AgentExecutor response from manager agent - ctx: Workflow context for routing and output + """Handle a participant response.""" + messages = self._process_participant_response(response) + self._append_messages(messages) - Behavior: - - Extracts any human input from the response - - Parses manager selection from response - - If finish=True: yields final message and completes workflow - - If participant selected: routes request to that participant with human input included - - Validates selected participant exists - - Enforces round limits if configured - - Raises: - ValueError: If manager selects invalid/unknown participant - RuntimeError: If manager response cannot be parsed - """ - # Extract any human input that was injected by the human input hook - trailing_user_messages = self._extract_trailing_user_messages(response) - - selection = self._parse_manager_selection(response) - - if self._pending_finalization: - self._pending_finalization = False - final_message_obj = selection.get_final_message_as_chat_message() - if final_message_obj is None: - final_message_obj = self._create_completion_message( - text="Conversation halted after termination condition was met.", - reason="termination_condition_manager", - ) - final_message_obj = ensure_author(final_message_obj, self._manager_name) - - self._conversation.append(final_message_obj) - self._history.append(_GroupChatTurn(self._manager_name, "manager", final_message_obj)) - self._pending_agent = None - await ctx.yield_output(list(self._conversation)) + if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)): return - - if selection.finish: - # Manager decided to complete conversation - final_message_obj = selection.get_final_message_as_chat_message() - if final_message_obj is None: - final_message_obj = self._create_completion_message( - text="Conversation completed.", - reason="manager_finish", - ) - final_message_obj = ensure_author(final_message_obj, self._manager_name) - - self._conversation.append(final_message_obj) - self._history.append(_GroupChatTurn(self._manager_name, "manager", final_message_obj)) - self._pending_agent = None - await ctx.yield_output(list(self._conversation)) + if await self._check_round_limit_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)): return - # Manager selected next participant - selected = selection.selected_participant - if not selected: - raise ValueError("Manager selection missing selected_participant when finish=False.") - if selected not in self._participants: - raise ValueError(f"Manager selected unknown participant: '{selected}'") - - # Route to selected participant - instruction = selection.instruction or "" - conversation = list(self._conversation) - if instruction: - manager_message = ensure_author( - self._create_completion_message(text=instruction, reason="manager_instruction"), - self._manager_name, - ) - conversation.append(manager_message) - self._conversation.append(manager_message) - self._history.append(_GroupChatTurn(self._manager_name, "manager", manager_message)) - - # Inject any human input that was attached to the manager's response - # This ensures the next participant sees the human's guidance - if trailing_user_messages: - for human_msg in trailing_user_messages: - conversation.append(human_msg) - self._conversation.append(human_msg) - author = human_msg.author_name or "human" - self._history.append(_GroupChatTurn(author, "user", human_msg)) - logger.debug( - f"Injected human input after manager 
selection: " - f"{human_msg.text[:50] if human_msg.text else '(empty)'}..." - ) - - if await self._complete_on_termination(ctx): - return + next_speaker = await self._get_next_speaker() - self._pending_agent = selected + # Broadcast participant messages to all participants for context, except + # the participant that just responded + participant = ctx.get_source_executor_id() + await self._broadcast_messages_to_participants( + messages, + cast(WorkflowContext[AgentExecutorRequest | GroupChatParticipantMessage], ctx), + participants=[p for p in self._participant_registry.participants if p != participant], + ) + # Send request to selected participant + await self._send_request_to_participant( + next_speaker, + cast(WorkflowContext[AgentExecutorRequest | GroupChatRequestMessage], ctx), + ) self._increment_round() - await self._route_to_participant( - participant_name=selected, - conversation=conversation, - ctx=ctx, - instruction=instruction, - task=self._task_message, - metadata=None, + async def _get_next_speaker(self) -> str: + """Determine the next speaker using the selection function.""" + group_chat_state = GroupChatState( + current_round=self._round_index, + participants=self._participant_registry.participants, + conversation=self._get_conversation(), ) - if self._check_round_limit(): - await self._apply_directive( - GroupChatDirective( - finish=True, - final_message=self._create_completion_message( - text="Conversation halted after reaching manager round limit.", - reason="max_rounds reached after manager selection", - ), - ), - ctx, - ) + next_speaker = self._selection_func(group_chat_state) + if inspect.isawaitable(next_speaker): + next_speaker = await next_speaker - @staticmethod - def _extract_agent_message(response: AgentExecutorResponse, participant_name: str) -> ChatMessage: - """Select the final assistant message from an AgentExecutor response.""" - from ._orchestrator_helpers import create_completion_message + if next_speaker not in self._participant_registry.participants: + raise RuntimeError(f"Selection function returned unknown participant '{next_speaker}'.") - final_message: ChatMessage | None = None - candidate_sequences: tuple[Sequence[ChatMessage] | None, ...] = ( - response.agent_run_response.messages, - response.full_conversation, - ) - for sequence in candidate_sequences: - if not sequence: - continue - for candidate in reversed(sequence): - if candidate.role == Role.ASSISTANT: - final_message = candidate - break - if final_message is not None: - break - - if final_message is None: - final_message = create_completion_message( - text="", - author_name=participant_name, - reason="empty response", - ) - return ensure_author(final_message, participant_name) + return next_speaker - @staticmethod - def _extract_trailing_user_messages(response: AgentExecutorResponse) -> list[ChatMessage]: - """Extract any user messages that appear after the last assistant message. - This is used to capture human input that was injected by the human input hook - interceptor. The hook adds user messages to full_conversation after the agent's - response, so they appear at the end of the sequence. 
- - Args: - response: AgentExecutor response that may contain trailing user messages - - Returns: - List of user messages that appear after the last assistant message, - or empty list if none found - """ - if not response.full_conversation: - return [] - - # Find index of last assistant message - last_assistant_idx = -1 - for i, msg in enumerate(response.full_conversation): - if msg.role == Role.ASSISTANT: - last_assistant_idx = i +# endregion - if last_assistant_idx < 0: - return [] +# region Agent-based orchestrator - # Collect any user messages after the last assistant message - trailing_user: list[ChatMessage] = [] - for msg in response.full_conversation[last_assistant_idx + 1 :]: - if msg.role == Role.USER: - trailing_user.append(msg) - return trailing_user +class AgentOrchestrationOutput(BaseModel): + """Structured output type for the agent in AgentBasedGroupChatOrchestrator.""" - async def _handle_task_message( - self, - task_message: ChatMessage, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], - ) -> None: - """Initialize orchestrator state and start the manager-directed conversation loop. - - This internal method is called by all public handlers (str, ChatMessage, list[ChatMessage]) - after normalizing their input. It initializes conversation state, queries the manager - for the first action, and applies the resulting directive. + model_config = { + "extra": "forbid", + # OpenAI strict mode requires all properties to be in required array + "json_schema_extra": {"required": ["terminate", "reason", "next_speaker", "final_message"]}, + } - Args: - task_message: The primary user task message (extracted or provided directly) - ctx: Workflow context for sending messages and yielding output - - Behavior: - - Sets task_message for manager context - - Initializes conversation from pending_initial_conversation if present - - Otherwise starts fresh with just the task message - - Builds turn history with speaker attribution - - Resets pending_agent and round_index - - Queries manager for first action - - Applies directive to start the conversation loop - - State initialization: - - _conversation: Full message list for context - - _history: Turn-by-turn record with speaker names and roles - - _pending_agent: None (no active request) - - _round_index: 0 (first manager query) - - Why pending_initial_conversation exists: - The handle_conversation handler supplies an explicit task (the first message in - the list) but still forwards the entire conversation for context. The full list is - stashed in _pending_initial_conversation to preserve all context when initializing state. 
- """ - self._task_message = task_message - if self._pending_initial_conversation: - initial_conversation = list(self._pending_initial_conversation) - self._pending_initial_conversation = None - self._conversation = initial_conversation - self._history = [ - _GroupChatTurn( - msg.author_name or msg.role.value, - msg.role.value, - msg, - ) - for msg in initial_conversation - ] - else: - self._conversation = [task_message] - self._history = [_GroupChatTurn("user", "user", task_message)] - self._pending_agent = None - self._round_index = 0 - - if await self._complete_on_termination(ctx): - return + # Whether to terminate the conversation + terminate: bool + # An explanation for the decision made + reason: str + # Next speaker to select if not terminating + next_speaker: str | None = Field( + default=None, + description="Name of the next participant to speak (if not terminating)", + ) + # Optional final message to send if terminating + final_message: str | None = Field(default=None, description="Optional final message if terminating") - # Query manager for first speaker selection - if self._is_manager_agent(): - # Agent-based manager: route request through workflow graph - # Prepend system message with participant context - manager_conversation = [self._build_manager_context_message(), *list(self._conversation)] - await self._route_to_participant( - participant_name=self._manager_name, - conversation=manager_conversation, - ctx=ctx, - instruction="", - task=self._task_message, - metadata=None, - ) - else: - # Callable manager: invoke directly - directive = await self._manager(self._build_state()) - await self._apply_directive(directive, ctx) - @handler - async def handle_str( - self, - task: str, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], - ) -> None: - """Handler for string input as workflow entry point. +class AgentBasedGroupChatOrchestrator(BaseGroupChatOrchestrator): + """Orchestrator that manages a group chat between multiple participants. - Wraps the string in a USER role ChatMessage and delegates to _handle_task_message. + This group chat orchestrator is driven by an agent that can select the next speaker + intelligently based on the conversation context. - Args: - task: Plain text task description from user - ctx: Workflow context + This orchestrator drives the conversation loop as follows: + 1. Receives initial messages, saves to history, and broadcasts to all participants + 2. Invokes the agent to determine the next speaker based on the most recent state + 3. Sends a request to the selected participant to generate a response + 4. Receives the participant's response, saves to history, and broadcasts to all participants + except the one that just spoke + 5. Repeats steps 2-4 until the termination conditions are met - Usage: - workflow.run("Write a blog post about AI agents") - """ - await self._handle_task_message(ChatMessage(role=Role.USER, text=task), ctx) + Note: The agent will be asked to generate a structured output of type `AgentOrchestrationOutput`, + thus it must be capable of structured output. 
+ """ - @handler - async def handle_chat_message( + def __init__( self, - task_message: ChatMessage, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], + agent: ChatAgent, + participant_registry: ParticipantRegistry, + *, + max_rounds: int | None = None, + termination_condition: TerminationCondition | None = None, + retry_attempts: int | None = None, + thread: AgentThread | None = None, ) -> None: - """Handler for ChatMessage input as workflow entry point. - - Directly delegates to _handle_task_message for state initialization. + """Initialize the GroupChatOrchestrator. Args: - task_message: Structured chat message from user (may include metadata, role, etc.) - ctx: Workflow context - - Usage: - workflow.run(ChatMessage(role=Role.USER, text="Analyze this data")) + agent: Agent that selects the next speaker based on conversation state + participant_registry: Registry of participants in the group chat that track executor types + (agents vs. executors) and provide resolution utilities. + max_rounds: Optional limit on selection rounds to prevent infinite loops. + termination_condition: Optional callable that halts the conversation when it returns True + retry_attempts: Optional number of retry attempts for the agent in case of failure. + thread: Optional agent thread to use for the orchestrator agent. """ - await self._handle_task_message(task_message, ctx) - - @handler - async def handle_conversation( + super().__init__( + resolve_agent_id(agent), + participant_registry, + name=agent.name, + max_rounds=max_rounds, + termination_condition=termination_condition, + ) + self._agent = agent + self._retry_attempts = retry_attempts + self._thread = thread or agent.get_new_thread() + # Cache for messages since last agent invocation + # This is different from the full conversation history maintained by the base orchestrator + self._cache: list[ChatMessage] = [] + + @override + def _append_messages(self, messages: Sequence[ChatMessage]) -> None: + self._cache.extend(messages) + return super()._append_messages(messages) + + @override + async def _handle_messages( self, - conversation: list[ChatMessage], - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], + messages: list[ChatMessage], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: - """Handler for conversation history as workflow entry point. - - Accepts a pre-existing conversation and uses the first message in the list as the task. - Preserves the full conversation for state initialization. 
+ """Initialize orchestrator state and start the conversation loop.""" + self._append_messages(messages) + # Termination condition will also be applied to the input messages + if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)): + return - Args: - conversation: List of chat messages (system, user, assistant) - ctx: Workflow context + agent_orchestration_output = await self._invoke_agent() + if await self._check_agent_terminate_and_yield( + agent_orchestration_output, + cast(WorkflowContext[Never, list[ChatMessage]], ctx), + ): + return - Raises: - ValueError: If conversation list is empty - - Behavior: - - Validates conversation is non-empty - - Clones conversation to avoid mutation - - Extracts task message (most recent USER message) - - Stashes full conversation in _pending_initial_conversation - - Delegates to _handle_task_message for initialization - - Usage: - existing_messages = [ - ChatMessage(role=Role.SYSTEM, text="You are an expert"), - ChatMessage(role=Role.USER, text="Help me with this task") - ] - workflow.run(existing_messages) - """ - if not conversation: - raise ValueError("GroupChat workflow requires at least one chat message.") - self._pending_initial_conversation = list(conversation) - task_message = latest_user_message(conversation) - await self._handle_task_message(task_message, ctx) - - @handler - async def handle_agent_response( - self, - response: _GroupChatResponseMessage, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], - ) -> None: - """Handle responses from custom participant executors.""" - await self._ingest_participant_message(response.agent_name, response.message, ctx) + # Broadcast messages to all participants for context + await self._broadcast_messages_to_participants( + messages, + cast(WorkflowContext[AgentExecutorRequest | GroupChatParticipantMessage], ctx), + ) + # Send request to selected participant + await self._send_request_to_participant( + # If not terminating, next_speaker must be provided thus will not be None + agent_orchestration_output.next_speaker, # type: ignore[arg-type] + cast(WorkflowContext[AgentExecutorRequest | GroupChatRequestMessage], ctx), + ) + self._increment_round() - @handler - async def handle_agent_executor_response( + @override + async def _handle_response( self, - response: AgentExecutorResponse, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], + response: AgentExecutorResponse | GroupChatResponseMessage, + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: - """Handle responses from both manager agent and regular participants. - - Routes responses based on whether they come from the manager or a participant: - - Manager responses: parsed for speaker selection decisions - - Participant responses: ingested as conversation messages - - Also handles any human input that was injected into the response's full_conversation - by the human input hook interceptor. 
- """ - participant_name = self._registry.get_participant_name(response.executor_id) - if participant_name is None: - logger.debug( - "Ignoring response from unregistered agent executor '%s'.", - response.executor_id, - ) + """Handle a participant response.""" + messages = self._process_participant_response(response) + self._append_messages(messages) + if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)): + return + if await self._check_round_limit_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)): return - # Check if response is from manager agent - if participant_name == self._manager_name and self._is_manager_agent(): - await self._handle_manager_response(response, ctx) - else: - # Regular participant response - message = self._extract_agent_message(response, participant_name) - - # Check for human input injected by human input hook - # Human input appears as user messages at the end of full_conversation - # after the agent's assistant message - trailing_user_messages = self._extract_trailing_user_messages(response) - - await self._ingest_participant_message(participant_name, message, ctx, trailing_user_messages) - - -def _default_orchestrator_factory(wiring: _GroupChatConfig) -> Executor: - """Default factory for creating the GroupChatOrchestratorExecutor instance. - - This is the internal implementation used by GroupChatBuilder to instantiate the - orchestrator. It extracts participant descriptions from the wiring configuration - and passes them to the orchestrator for manager context. - - Args: - wiring: Complete workflow configuration assembled by the builder - - Returns: - Initialized GroupChatOrchestratorExecutor ready to coordinate the conversation - - Behavior: - - Extracts participant names and descriptions for manager context - - Forwards manager instance, manager name, max_rounds, and termination_condition settings - - Allows orchestrator to auto-generate its executor ID - - Supports both callable managers (set_select_speakers_func) and agent-based managers (set_manager) - - Why descriptions are extracted: - The manager needs participant descriptions (not full specs) to make informed - selection decisions. The orchestrator doesn't need participant instances directly - since routing is handled by the workflow graph. + agent_orchestration_output = await self._invoke_agent() + if await self._check_agent_terminate_and_yield( + agent_orchestration_output, + cast(WorkflowContext[Never, list[ChatMessage]], ctx), + ): + return - Raises: - RuntimeError: If neither manager nor manager_participant is configured - """ - if wiring.manager is None and wiring.manager_participant is None: - raise RuntimeError( - "Default orchestrator factory requires a manager to be configured. " - "Call set_manager(...) or set_select_speakers_func(...) before build()." 
+ # Broadcast participant messages to all participants for context, except + # the participant that just responded + participant = ctx.get_source_executor_id() + await self._broadcast_messages_to_participants( + messages, + cast(WorkflowContext[AgentExecutorRequest | GroupChatParticipantMessage], ctx), + participants=[p for p in self._participant_registry.participants if p != participant], ) + # Send request to selected participant + await self._send_request_to_participant( + # If not terminating, next_speaker must be provided thus will not be None + agent_orchestration_output.next_speaker, # type: ignore[arg-type] + cast(WorkflowContext[AgentExecutorRequest | GroupChatRequestMessage], ctx), + ) + self._increment_round() - manager_callable = wiring.manager - if manager_callable is None: - # Keep orchestrator signature satisfied; agent managers are routed via the workflow graph - async def _agent_manager_placeholder(_: GroupChatStateSnapshot) -> GroupChatDirective: # noqa: RUF029 - raise RuntimeError( - "Manager callable invoked unexpectedly. Agent-based managers should route through the workflow graph." - ) - - manager_callable = _agent_manager_placeholder - - return GroupChatOrchestratorExecutor( - manager=manager_callable, - participants={name: spec.description for name, spec in wiring.participants.items()}, - manager_name=wiring.manager_name, - max_rounds=wiring.max_rounds, - termination_condition=wiring.termination_condition, - ) + async def _invoke_agent(self) -> AgentOrchestrationOutput: + """Invoke the orchestrator agent to determine the next speaker and termination.""" + + async def _invoke_agent_helper(conversation: list[ChatMessage]) -> AgentOrchestrationOutput: + # Run the agent in non-streaming mode for simplicity + agent_response = await self._agent.run( + messages=conversation, + thread=self._thread, + options={"response_format": AgentOrchestrationOutput}, + ) + # Parse and validate the structured output + agent_orchestration_output = AgentOrchestrationOutput.model_validate_json(agent_response.text) + + if not agent_orchestration_output.terminate and not agent_orchestration_output.next_speaker: + raise ValueError("next_speaker must be provided if not terminating the conversation.") + + return agent_orchestration_output + + # We only need the last message for context since history is maintained in the thread + current_conversation = self._cache.copy() + self._cache.clear() + instruction = ( + "Decide what to do next. 
Respond with a JSON object of the following format:\n" + "{\n" + ' "terminate": ,\n' + ' "reason": "",\n' + ' "next_speaker": "",\n' + ' "final_message": ""\n' + "}\n" + "If not terminating, here are the valid participant names (case-sensitive) and their descriptions:\n" + + "\n".join([ + f"{name}: {description}" for name, description in self._participant_registry.participants.items() + ]) + ) + # Prepend instruction as system message + current_conversation.append(ChatMessage(role=Role.USER, text=instruction)) + retry_attempts = self._retry_attempts + while True: + try: + return await _invoke_agent_helper(current_conversation) + except Exception as ex: + logger.error(f"Agent orchestration invocation failed: {ex}") + if retry_attempts is None or retry_attempts <= 0: + raise + retry_attempts -= 1 + logger.debug(f"Retrying agent orchestration invocation, attempts left: {retry_attempts}") + # We don't need the full conversation since the thread should maintain history + current_conversation = [ + ChatMessage( + role=Role.USER, + text=f"Your input could not be parsed due to an error: {ex}. Please try again.", + ) + ] -def group_chat_orchestrator(factory: _GroupChatOrchestratorFactory | None = None) -> _GroupChatOrchestratorFactory: - """Return a callable orchestrator factory, defaulting to the built-in implementation.""" - return factory or _default_orchestrator_factory - - -def assemble_group_chat_workflow( - *, - wiring: _GroupChatConfig, - participant_factory: Callable[[GroupChatParticipantSpec, _GroupChatConfig], _GroupChatParticipantPipeline], - orchestrator_factory: _GroupChatOrchestratorFactory = _default_orchestrator_factory, - interceptors: Sequence[_InterceptorSpec] | None = None, - checkpoint_storage: CheckpointStorage | None = None, - builder: WorkflowBuilder | None = None, - return_builder: bool = False, -) -> Workflow | tuple[WorkflowBuilder, Executor]: - """Build the workflow graph shared by group-chat style orchestrators.""" - interceptor_specs = interceptors or () - - orchestrator = wiring.orchestrator or orchestrator_factory(wiring) - wiring.orchestrator = orchestrator - - workflow_builder = builder or WorkflowBuilder() - start_executor = getattr(workflow_builder, "_start_executor", None) - if start_executor is None: - workflow_builder = workflow_builder.set_start_executor(orchestrator) - - # Wire manager as participant if agent-based manager is configured - if wiring.manager_participant is not None: - manager_spec = GroupChatParticipantSpec( - name=wiring.manager_name, - participant=wiring.manager_participant, - description="Coordination manager", - ) - manager_pipeline = list(participant_factory(manager_spec, wiring)) - if not manager_pipeline: - raise ValueError("Participant factory returned empty pipeline for manager.") - - manager_entry = manager_pipeline[0] - manager_exit = manager_pipeline[-1] - - # Register manager with orchestrator (with entry and exit IDs for pipeline routing) - register_entry = getattr(orchestrator, "register_participant_entry", None) - if callable(register_entry): - register_entry( - wiring.manager_name, - entry_id=manager_entry.id, - is_agent=not isinstance(wiring.manager_participant, Executor), - exit_id=manager_exit.id if manager_exit is not manager_entry else None, - ) + async def _check_agent_terminate_and_yield( + self, + agent_orchestration_output: AgentOrchestrationOutput, + ctx: WorkflowContext[Never, list[ChatMessage]], + ) -> bool: + """Check if the agent requested termination and yield completion if so. 
- # Wire manager edges: Orchestrator ↔ Manager - workflow_builder = workflow_builder.add_edge(orchestrator, manager_entry) - for upstream, downstream in itertools.pairwise(manager_pipeline): - workflow_builder = workflow_builder.add_edge(upstream, downstream) - if manager_exit is not orchestrator: - workflow_builder = workflow_builder.add_edge(manager_exit, orchestrator) - - # Wire regular participants - for name, spec in wiring.participants.items(): - pipeline = list(participant_factory(spec, wiring)) - if not pipeline: - raise ValueError( - f"Participant factory returned an empty pipeline for '{name}'. " - "Provide at least one executor per participant." - ) - entry_executor = pipeline[0] - exit_executor = pipeline[-1] - - register_entry = getattr(orchestrator, "register_participant_entry", None) - if callable(register_entry): - # Register both entry and exit IDs so responses can be routed correctly - # when interceptors are prepended to the pipeline - register_entry( - name, - entry_id=entry_executor.id, - is_agent=not isinstance(spec.participant, Executor), - exit_id=exit_executor.id if exit_executor is not entry_executor else None, + Args: + agent_orchestration_output: Output from the orchestrator agent + ctx: Workflow context for yielding output + Returns: + True if termination was requested and output was yielded, False otherwise + """ + if agent_orchestration_output.terminate: + final_message = ( + agent_orchestration_output.final_message or "The conversation has been terminated by the agent." ) + self._append_messages([self._create_completion_message(final_message)]) + await ctx.yield_output(self._full_conversation) + return True - workflow_builder = workflow_builder.add_edge(orchestrator, entry_executor) - for upstream, downstream in itertools.pairwise(pipeline): - workflow_builder = workflow_builder.add_edge(upstream, downstream) - if exit_executor is not orchestrator: - workflow_builder = workflow_builder.add_edge(exit_executor, orchestrator) + return False - for factory, condition in interceptor_specs: - interceptor_executor = factory(wiring) - workflow_builder = workflow_builder.add_edge(orchestrator, interceptor_executor, condition=condition) - workflow_builder = workflow_builder.add_edge(interceptor_executor, orchestrator) + @override + async def on_checkpoint_save(self) -> dict[str, Any]: + """Capture current orchestrator state for checkpointing.""" + state = await super().on_checkpoint_save() + state["cache"] = encode_chat_messages(self._cache) + serialized_thread = await self._thread.serialize() + state["thread"] = serialized_thread - if checkpoint_storage is not None: - workflow_builder = workflow_builder.with_checkpointing(checkpoint_storage) + return state - if return_builder: - return workflow_builder, orchestrator - return workflow_builder.build() + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + """Restore executor state from checkpoint.""" + await super().on_checkpoint_restore(state) + self._cache = decode_chat_messages(state.get("cache", [])) + serialized_thread = state.get("thread") + if serialized_thread: + self._thread = await self._agent.deserialize_thread(serialized_thread) # endregion - # region Builder class GroupChatBuilder: - r"""High-level builder for manager-directed group chat workflows with dynamic orchestration. - - GroupChat coordinates multi-agent conversations using a manager that selects which participant - speaks next. 
The manager can be a simple Python function (:py:meth:`GroupChatBuilder.set_select_speakers_func`) - or an agent-based selector via :py:meth:`GroupChatBuilder.set_manager`. These two approaches are - mutually exclusive. - - **Core Workflow:** - 1. Define participants: list of agents (uses their .name) or dict mapping names to agents - 2. Configure speaker selection: :py:meth:`GroupChatBuilder.set_select_speakers_func` OR - :py:meth:`GroupChatBuilder.set_manager` (not both) - 3. Optional: set round limits, checkpointing, termination conditions - 4. Build and run the workflow - - **Speaker Selection Patterns:** + r"""High-level builder for group chat workflows. - *Pattern 1: Simple function-based selection (recommended)* + GroupChat coordinates multi-agent conversations using an orchestrator that can dynamically + select participants to speak at each turn based on the conversation state. - .. code-block:: python + Routing Pattern: + Agents respond in turns as directed by the orchestrator until termination conditions are met. + This provides a centralized approach to multi-agent collaboration, similar to a star topology. - from agent_framework import GroupChatBuilder, GroupChatStateSnapshot + Participants can be a combination of agents and executors. If they are executors, they + must implement the expected handlers for receiving GroupChat messages and returning responses + (Read our official documentation for details on implementing custom participant executors). + The orchestrator can be provided directly, or a simple selection function can be defined + to choose the next speaker based on the current state. The builder wires everything together + into a complete workflow graph that can be executed. - def select_next_speaker(state: GroupChatStateSnapshot) -> str | None: - # state contains: task, participants, conversation, history, round_index - if state["round_index"] >= 5: - return None # Finish - last_speaker = state["history"][-1].speaker if state["history"] else None - if last_speaker == "researcher": - return "writer" - return "researcher" - - - workflow = ( - GroupChatBuilder() - .set_select_speakers_func(select_next_speaker) - .participants([researcher_agent, writer_agent]) # Uses agent.name - .build() - ) - - *Pattern 2: LLM-based selection* - - .. code-block:: python - - from agent_framework import ChatAgent - from agent_framework.azure import AzureOpenAIChatClient + Outputs: + The final conversation history as a list of ChatMessage once the group chat completes. 
+ """ - manager_agent = AzureOpenAIChatClient().create_agent( - instructions="Coordinate the conversation and pick the next speaker.", - name="Coordinator", - temperature=0.3, - seed=42, - max_tokens=500, - ) + DEFAULT_ORCHESTRATOR_ID: ClassVar[str] = "group_chat_orchestrator" - workflow = ( - GroupChatBuilder() - .set_manager(manager_agent, display_name="Coordinator") - .participants([researcher, writer]) # Or use dict: researcher=r, writer=w - .with_max_rounds(10) - .build() - ) + def __init__(self) -> None: + """Initialize the GroupChatBuilder.""" + self._participants: dict[str, AgentProtocol | Executor] = {} - *Pattern 3: Request info for mid-conversation feedback* + # Orchestrator related members + self._orchestrator: BaseGroupChatOrchestrator | None = None + self._selection_func: GroupChatSelectionFunction | None = None + self._agent_orchestrator: ChatAgent | None = None + self._termination_condition: TerminationCondition | None = None + self._max_rounds: int | None = None + self._orchestrator_name: str | None = None - .. code-block:: python + # Checkpoint related members + self._checkpoint_storage: CheckpointStorage | None = None - from agent_framework import GroupChatBuilder + # Request info related members + self._request_info_enabled: bool = False + self._request_info_filter: set[str] = set() - # Pause before all participants - workflow = ( - GroupChatBuilder() - .set_select_speakers_func(select_next_speaker) - .participants([researcher, writer]) - .with_request_info() - .build() - ) + def with_orchestrator(self, orchestrator: BaseGroupChatOrchestrator) -> "GroupChatBuilder": + """Set the orchestrator for this group chat workflow. - # Pause only before specific participants - workflow = ( - GroupChatBuilder() - .set_select_speakers_func(select_next_speaker) - .participants([researcher, writer, editor]) - .with_request_info(agents=[editor]) # Only pause before editor responds - .build() - ) + An group chat orchestrator is responsible for managing the flow of conversation, making + sure all participants are synced and picking the next speaker according to the defined logic + until the termination conditions are met. - **Participant Specification:** + Args: + orchestrator: An instance of BaseGroupChatOrchestrator to manage the group chat. - Two ways to specify participants: - - List form: `[agent1, agent2]` - uses `agent.name` attribute for participant names - - Dict form: `{name1: agent1, name2: agent2}` - explicit name control - - Keyword form: `participants(name1=agent1, name2=agent2)` - explicit name control + Returns: + Self for fluent chaining. - **State Snapshot Structure:** + Raises: + ValueError: If an orchestrator has already been set - The GroupChatStateSnapshot passed to set_select_speakers_func contains: - - `task`: ChatMessage - Original user task - - `participants`: dict[str, str] - Mapping of participant names to descriptions - - `conversation`: tuple[ChatMessage, ...] - Full conversation history - - `history`: tuple[GroupChatTurn, ...] - Turn-by-turn record with speaker attribution - - `round_index`: int - Number of manager selection rounds so far - - `pending_agent`: str | None - Name of agent currently processing (if any) + Example: + .. 
code-block:: python - **Important Constraints:** - - Cannot combine :py:meth:`GroupChatBuilder.set_select_speakers_func` and :py:meth:`GroupChatBuilder.set_manager` - - Participant names must be unique - - When using list form, agents must have a non-empty `name` attribute - """ + from agent_framework import GroupChatBuilder - def __init__( - self, - *, - _orchestrator_factory: _GroupChatOrchestratorFactory | None = None, - _participant_factory: Callable[[GroupChatParticipantSpec, _GroupChatConfig], _GroupChatParticipantPipeline] - | None = None, - ) -> None: - """Initialize the GroupChatBuilder. - Args: - _orchestrator_factory: Internal extension point for custom orchestrator implementations. - Used by Magentic. Not part of public API - subject to change. - _participant_factory: Internal extension point for custom participant pipelines. - Used by Magentic. Not part of public API - subject to change. + orchestrator = CustomGroupChatOrchestrator(...) + workflow = GroupChatBuilder().with_orchestrator(orchestrator).participants([agent1, agent2]).build() """ - self._participants: dict[str, AgentProtocol | Executor] = {} - self._participant_metadata: dict[str, Any] | None = None - self._manager: _GroupChatManagerFn | None = None - self._manager_participant: AgentProtocol | Executor | None = None - self._manager_name: str = "manager" - self._checkpoint_storage: CheckpointStorage | None = None - self._max_rounds: int | None = None - self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None - self._interceptors: list[_InterceptorSpec] = [] - self._orchestrator_factory = group_chat_orchestrator(_orchestrator_factory) - self._participant_factory = _participant_factory or _default_participant_factory - self._request_info_enabled: bool = False - self._request_info_filter: set[str] | None = None - - def _set_manager_function( - self, - manager: _GroupChatManagerFn, - display_name: str | None, - ) -> "GroupChatBuilder": - if self._manager is not None or self._manager_participant is not None: + if self._orchestrator is not None: + raise ValueError("An orchestrator has already been configured. Call with_orchestrator(...) at most once.") + if self._agent_orchestrator is not None: raise ValueError( - "GroupChatBuilder already has a manager configured. " - "Call set_select_speakers_func(...) or set_manager(...) at most once." + "An agent orchestrator has already been configured. " + "Call only one of with_orchestrator(...) or with_agent_orchestrator(...)." + ) + if self._selection_func is not None: + raise ValueError( + "A selection function has already been configured. " + "Call only one of with_orchestrator(...) or with_select_speaker_func(...)." ) - resolved_name = display_name or getattr(manager, "name", None) or "manager" - self._manager = manager - self._manager_name = resolved_name - return self - def set_manager( - self, - manager: AgentProtocol | Executor, - *, - display_name: str | None = None, - ) -> "GroupChatBuilder": - """Configure the manager/coordinator agent for group chat orchestration. + self._orchestrator = orchestrator + return self - The manager coordinates participants by selecting who speaks next based on - conversation state and task requirements. The manager is a full workflow - participant with access to all agent infrastructure (tools, context, observability). + def with_agent_orchestrator(self, agent: ChatAgent) -> "GroupChatBuilder": + """Set an agent-based orchestrator for this group chat workflow. 
- The manager agent must produce structured output compatible with ManagerSelectionResponse - to communicate its speaker selection decisions. Use response_format for reliable parsing. - GroupChatBuilder enforces this when the manager is a ChatAgent and rejects incompatible - response formats. + An agent-based group chat orchestrator uses a ChatAgent to select the next speaker + intelligently based on the conversation context. Args: - manager: Agent or executor responsible for speaker selection and coordination. - Must return ManagerSelectionResponse or compatible dict/JSON structure. - display_name: Optional name for manager messages in conversation history. - If not provided, uses manager.name for AgentProtocol or manager.id for Executor. + agent: An instance of ChatAgent to manage the group chat. Returns: Self for fluent chaining. Raises: - ValueError: If manager is already configured via :py:meth:`GroupChatBuilder.set_select_speakers_func` - TypeError: If manager is not AgentProtocol or Executor instance - - Example: - - .. code-block:: python - - from agent_framework import GroupChatBuilder, ChatAgent - from agent_framework.openai import OpenAIChatClient - - # Coordinator agent - response_format is enforced to ManagerSelectionResponse - coordinator = ChatAgent( - name="Coordinator", - description="Coordinates multi-agent collaboration", - instructions=''' - You coordinate a team conversation. Review the conversation history - and select the next participant to speak. - - When ready to finish, set finish=True and provide a summary in final_message. - ''', - chat_client=OpenAIChatClient(), + ValueError: If an orchestrator has already been set + """ + if self._agent_orchestrator is not None: + raise ValueError( + "Agent orchestrator has already been configured. Call with_agent_orchestrator(...) at most once." ) - - workflow = ( - GroupChatBuilder() - .set_manager(coordinator, display_name="Orchestrator") - .participants([researcher, writer]) - .build() + if self._orchestrator is not None: + raise ValueError( + "An orchestrator has already been configured. " + "Call only one of with_agent_orchestrator(...) or with_orchestrator(...)." ) - - Note: - The manager agent's response_format must be ManagerSelectionResponse for structured output. - Custom response formats raise ValueError instead of being overridden. - - The manager can be included in :py:meth:`with_request_info` to pause before the manager - runs, allowing human steering of orchestration decisions. If no filter is specified, - the manager is included automatically. To filter explicitly:: - - .with_request_info(agents=[manager, writer]) # Pause before manager and writer - """ - if self._manager is not None or self._manager_participant is not None: + if self._selection_func is not None: raise ValueError( - "GroupChatBuilder already has a manager configured. " - "Call set_select_speakers_func(...) or set_manager(...) at most once." + "A selection function has already been configured. " + "Call only one of with_agent_orchestrator(...) or with_select_speaker_func(...)." ) - if not isinstance(manager, (AgentProtocol, Executor)): - raise TypeError(f"Manager must be AgentProtocol or Executor instance. 
Got {type(manager).__name__}.") - - # Infer display name from manager if not provided - if display_name is None: - display_name = manager.id if isinstance(manager, Executor) else manager.name or "manager" - - # Enforce ManagerSelectionResponse for ChatAgent managers - if isinstance(manager, ChatAgent): - configured_format = manager.chat_options.response_format - if configured_format is None: - manager.chat_options.response_format = ManagerSelectionResponse - elif configured_format is not ManagerSelectionResponse: - configured_format_name = getattr(configured_format, "__name__", str(configured_format)) - raise ValueError( - "Manager ChatAgent response_format must be ManagerSelectionResponse. " - f"Received '{configured_format_name}' for manager '{display_name}'." - ) - - self._manager_participant = manager - self._manager_name = display_name + self._agent_orchestrator = agent return self - def set_select_speakers_func( + def with_select_speaker_func( self, - selector: ( - Callable[[GroupChatStateSnapshot], Awaitable[str | None]] | Callable[[GroupChatStateSnapshot], str | None] - ), + selection_func: GroupChatSelectionFunction, *, - display_name: str | None = None, - final_message: ChatMessage | str | Callable[[GroupChatStateSnapshot], Any] | None = None, + orchestrator_name: str | None = None, ) -> "GroupChatBuilder": - """Configure speaker selection using a pure function that examines group chat state. - - This is the primary way to control orchestration flow in a GroupChat. Your selector - function receives an immutable snapshot of the current conversation state and returns - the name of the next participant to speak, or None to finish the conversation. + """Define a custom function to select the next speaker in the group chat. - The selector function can implement any logic including: - - Simple round-robin or rule-based selection - - LLM-based decision making with custom prompts - - Conversation summarization before routing to the next agent - - Custom metadata or context passing - - For advanced scenarios, return a GroupChatDirective instead of a string to include - custom instructions or metadata for the next participant. - - The selector function signature: - def select_next_speaker(state: GroupChatStateSnapshot) -> str | None: - # state contains: task, participants, conversation, history, round_index - # Return participant name to continue, or None to finish - ... + This is a quick way to implement simple orchestration logic without needing a full + GroupChatOrchestrator. The provided function receives the current state of + the group chat and returns the name of the next participant to speak. Args: - selector: Function that takes GroupChatStateSnapshot and returns the next speaker's - name (str) to continue the conversation, or None to finish. May be sync or async. - Can also return GroupChatDirective for advanced control (instruction, metadata). - display_name: Optional name shown in conversation history for orchestrator messages - (defaults to "manager"). - final_message: Optional final message (or factory) emitted when selector returns None - (defaults to "Conversation completed." authored by the manager). + selection_func: Callable that receives the current GroupChatState and returns + the name of the next participant to speak, or None to finish. + orchestrator_name: Optional display name for the orchestrator in the workflow. + If not provided, defaults to `GroupChatBuilder.DEFAULT_ORCHESTRATOR_ID`. Returns: - Self for fluent chaining. 
+            Self for fluent chaining

-        Example (simple):
+        Raises:
+            ValueError: If an orchestrator has already been set

+        Example:

        .. code-block:: python

-                def select_next_speaker(state: GroupChatStateSnapshot) -> str | None:
-                    if state["round_index"] >= 3:
-                        return None  # Finish after 3 rounds
-                    last_speaker = state["history"][-1].speaker if state["history"] else None
-                    if last_speaker == "researcher":
-                        return "writer"
-                    return "researcher"
+                from agent_framework import GroupChatBuilder, GroupChatState
+
+
+                async def round_robin_selector(state: GroupChatState) -> str:
+                    # Simple round-robin selection among participants
+                    return state.participants[state.current_round % len(state.participants)]

                workflow = (
                    GroupChatBuilder()
-                    .set_select_speakers_func(select_next_speaker)
-                    .participants(researcher=researcher_agent, writer=writer_agent)
+                    .with_select_speaker_func(round_robin_selector, orchestrator_name="Coordinator")
+                    .participants([agent1, agent2])
                    .build()
                )
-
-        Example (with LLM and custom instructions):
-
-        .. code-block:: python
-
-            from agent_framework import GroupChatDirective
-
-
-            async def llm_based_selector(state: GroupChatStateSnapshot) -> GroupChatDirective | None:
-                if state["round_index"] >= 5:
-                    return GroupChatDirective(finish=True)
-
-                # Use LLM to decide next speaker and summarize conversation
-                conversation_summary = await summarize_with_llm(state["conversation"])
-                next_agent = await pick_agent_with_llm(state["participants"], state["task"])
-
-                # Pass custom instruction to the selected agent
-                return GroupChatDirective(
-                    agent_name=next_agent,
-                    instruction=f"Context summary: {conversation_summary}",
-                )
-
-
-            workflow = GroupChatBuilder().set_select_speakers_func(llm_based_selector).participants(...).build()
-
-        Note:
-            Cannot be combined with :py:meth:`GroupChatBuilder.set_manager`. Choose one orchestration strategy.
        """
-        manager_name = display_name or "manager"
-        adapter = _SpeakerSelectorAdapter(
-            selector,
-            manager_name=manager_name,
-            final_message=final_message,
-        )
-        return self._set_manager_function(adapter, display_name)
+        if self._selection_func is not None:
+            raise ValueError(
+                "A selection function has already been configured. Call with_select_speaker_func(...) at most once."
+            )
+        if self._orchestrator is not None:
+            raise ValueError(
+                "An orchestrator has already been configured. "
+                "Call only one of with_select_speaker_func(...) or with_orchestrator(...)."
+            )
+        if self._agent_orchestrator is not None:
+            raise ValueError(
+                "An agent orchestrator has already been configured. "
+                "Call only one of with_select_speaker_func(...) or with_agent_orchestrator(...)."
+            )

-    def participants(
-        self,
-        participants: Mapping[str, AgentProtocol | Executor] | Sequence[AgentProtocol | Executor] | None = None,
-        /,
-        **named_participants: AgentProtocol | Executor,
-    ) -> "GroupChatBuilder":
+        self._selection_func = selection_func
+        self._orchestrator_name = orchestrator_name
+        return self
+
+    def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "GroupChatBuilder":
        """Define participants for this group chat workflow.

        Accepts AgentProtocol instances (auto-wrapped as AgentExecutor) or Executor instances.
-        Provide a mapping of name → participant for explicit control, or pass a sequence and
-        names will be inferred from the agent's name attribute (or executor id). 
Args: - participants: Optional mapping or sequence of participant definitions - **named_participants: Keyword arguments mapping names to agent/executor instances + participants: Sequence of participant definitions Returns: Self for fluent chaining Raises: - ValueError: If participants are empty, names are duplicated, or names are empty strings + ValueError: If participants are empty, names are duplicated, or already set + TypeError: If any participant is not AgentProtocol or Executor instance - Usage: + Example: .. code-block:: python from agent_framework import GroupChatBuilder workflow = ( - GroupChatBuilder().set_manager(manager_agent).participants([writer_agent, reviewer_agent]).build() + GroupChatBuilder() + .with_select_speaker_func(my_selection_function) + .participants([agent1, agent2, custom_executor]) + .build() ) """ - combined: dict[str, AgentProtocol | Executor] = {} - - def _add(name: str, participant: AgentProtocol | Executor) -> None: - if not name: - raise ValueError("participant names must be non-empty strings") - if name in combined or name in self._participants: - raise ValueError(f"Duplicate participant name '{name}' supplied.") - if name == self._manager_name: - raise ValueError( - f"Participant name '{name}' conflicts with manager name. " - "Manager is automatically registered as a participant." - ) - combined[name] = participant - - if participants: - if isinstance(participants, Mapping): - for name, participant in participants.items(): - _add(name, participant) + if self._participants: + raise ValueError("participants have already been set. Call participants(...) at most once.") + + if not participants: + raise ValueError("participants cannot be empty.") + + # Name of the executor mapped to participant instance + named: dict[str, AgentProtocol | Executor] = {} + for participant in participants: + if isinstance(participant, Executor): + identifier = participant.id + elif isinstance(participant, AgentProtocol): + if not participant.name: + raise ValueError("AgentProtocol participants must have a non-empty name.") + identifier = participant.name else: - for participant in participants: - inferred_name: str - if isinstance(participant, Executor): - inferred_name = participant.id - else: - name_attr = getattr(participant, "name", None) - if not name_attr: - raise ValueError( - "Agent participants supplied via sequence must define a non-empty 'name' attribute." - ) - inferred_name = str(name_attr) - _add(inferred_name, participant) - - for name, participant in named_participants.items(): - _add(name, participant) - - if not combined: - raise ValueError("participants cannot be empty") - - for name, participant in combined.items(): - self._participants[name] = participant - self._participant_metadata = None - return self - - def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "GroupChatBuilder": - """Enable checkpointing for the built workflow using the provided storage. - - Checkpointing allows the workflow to persist state and resume from interruption - points, enabling long-running conversations and failure recovery. - - Args: - checkpoint_storage: Storage implementation for persisting workflow state - - Returns: - Self for fluent chaining + raise TypeError( + f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." + ) - Usage: + if identifier in named: + raise ValueError(f"Duplicate participant name '{identifier}' detected") - .. 
code-block:: python + named[identifier] = participant - from agent_framework import GroupChatBuilder, MemoryCheckpointStorage + self._participants = named - storage = MemoryCheckpointStorage() - workflow = ( - GroupChatBuilder() - .set_manager(manager_agent) - .participants(agent1=agent1, agent2=agent2) - .with_checkpointing(storage) - .build() - ) - """ - self._checkpoint_storage = checkpoint_storage return self - def with_request_handler( - self, - handler: Callable[[_GroupChatConfig], Executor] | Executor, - *, - condition: Callable[[Any], bool], - ) -> "GroupChatBuilder": - """Register an interceptor factory that creates executors for special requests. + def with_termination_condition(self, termination_condition: TerminationCondition) -> "GroupChatBuilder": + """Set a custom termination condition for the group chat workflow. Args: - handler: Callable that receives the wiring and returns an executor, or a pre-built executor - condition: Filter determining which orchestrator messages the interceptor should process + termination_condition: Callable that receives the conversation history and returns + True to terminate the conversation, False to continue. Returns: Self for fluent chaining - """ - factory: Callable[[_GroupChatConfig], Executor] - if isinstance(handler, Executor): - executor = handler - - def _factory(_: _GroupChatConfig) -> Executor: - return executor - - factory = _factory - else: - factory = handler - - self._interceptors.append((factory, condition)) - return self - - def with_termination_condition( - self, - condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]], - ) -> "GroupChatBuilder": - """Define a custom termination condition for the group chat workflow. - - The condition receives the full conversation (including manager and agent messages) and may be async. - When it returns True, the orchestrator halts the conversation and emits a completion message authored - by the manager. Example: @@ -1751,324 +752,179 @@ def stop_after_two_calls(conversation: list[ChatMessage]) -> bool: specialist_agent = ... workflow = ( GroupChatBuilder() - .set_select_speakers_func(lambda _: "specialist") - .participants(specialist=specialist_agent) + .with_select_speaker_func(my_selection_function) + .participants([agent1, specialist_agent]) .with_termination_condition(stop_after_two_calls) .build() ) """ - self._termination_condition = condition + if self._orchestrator is not None: + logger.warning( + "Orchestrator has already been configured; setting termination condition on builder has no effect." + ) + + self._termination_condition = termination_condition return self def with_max_rounds(self, max_rounds: int | None) -> "GroupChatBuilder": - """Set a maximum number of manager rounds to prevent infinite conversations. + """Set a maximum number of orchestrator rounds to prevent infinite conversations. When the round limit is reached, the workflow automatically completes with a default completion message. Setting to None allows unlimited rounds. Args: - max_rounds: Maximum number of manager selection rounds, or None for unlimited + max_rounds: Maximum number of orchestrator selection rounds, or None for unlimited Returns: Self for fluent chaining - - Usage: - - .. 
code-block:: python - - from agent_framework import GroupChatBuilder - - # Limit to 15 rounds - workflow = ( - GroupChatBuilder() - .set_manager(manager_agent) - .participants(agent1=agent1, agent2=agent2) - .with_max_rounds(15) - .build() - ) - - # Unlimited rounds - workflow = ( - GroupChatBuilder().set_manager(manager_agent).participants(agent1=agent1).with_max_rounds(None).build() - ) """ self._max_rounds = max_rounds return self - def with_request_info( - self, - *, - agents: Sequence[str | AgentProtocol | Executor] | None = None, - ) -> "GroupChatBuilder": - """Enable request info before participants run in the workflow. + def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "GroupChatBuilder": + """Enable checkpointing for the built workflow using the provided storage. - When enabled, the workflow pauses before each participant runs, emitting - a RequestInfoEvent that allows the caller to review the conversation and - optionally inject guidance before the participant responds. The caller provides - input via the standard response_handler/request_info pattern. + Checkpointing allows the workflow to persist state and resume from interruption + points, enabling long-running conversations and failure recovery. Args: - agents: Optional filter - only pause before these specific agents/executors. - Accepts agent names (str), agent instances, or executor instances. - If None (default), pauses before every participant. + checkpoint_storage: Storage implementation for persisting workflow state Returns: - self: The builder instance for fluent chaining. + Self for fluent chaining Example: .. code-block:: python - # Pause before all participants - workflow = ( - GroupChatBuilder() - .set_manager(manager) - .participants([optimist, pragmatist, creative]) - .with_request_info() - .build() - ) + from agent_framework import GroupChatBuilder, MemoryCheckpointStorage - # Pause only before specific participants + storage = MemoryCheckpointStorage() workflow = ( GroupChatBuilder() - .set_manager(manager) - .participants([optimist, pragmatist, creative]) - .with_request_info(agents=[pragmatist]) # Only pause before pragmatist + .with_select_speaker_func(my_selection_function) + .participants([agent1, agent2]) + .with_checkpointing(storage) .build() ) """ - from ._orchestration_request_info import resolve_request_info_filter - - self._request_info_enabled = True - self._request_info_filter = resolve_request_info_filter(list(agents) if agents else None) + self._checkpoint_storage = checkpoint_storage return self - def _get_participant_metadata(self) -> dict[str, Any]: - if self._participant_metadata is None: - self._participant_metadata = prepare_participant_metadata( - self._participants, - executor_id_factory=lambda name, participant: ( - participant.id if isinstance(participant, Executor) else f"groupchat_agent:{name}" - ), - description_factory=lambda name, participant: ( - participant.id if isinstance(participant, Executor) else participant.__class__.__name__ - ), - ) - return self._participant_metadata - - def _build_participant_specs(self) -> dict[str, GroupChatParticipantSpec]: - metadata = self._get_participant_metadata() - descriptions: Mapping[str, str] = metadata["descriptions"] - specs: dict[str, GroupChatParticipantSpec] = {} - for name, participant in self._participants.items(): - specs[name] = GroupChatParticipantSpec( - name=name, - participant=participant, - description=descriptions[name], - ) - return specs + def with_request_info(self, *, agents: Sequence[str | 
AgentProtocol] | None = None) -> "GroupChatBuilder": + """Enable request info after agent participant responses. - def build(self) -> Workflow: - """Build and validate the group chat workflow. + This enables human-in-the-loop (HIL) scenarios for the group chat orchestration. + When enabled, the workflow pauses after each agent participant runs, emitting + a RequestInfoEvent that allows the caller to review the conversation and optionally + inject guidance for the agent participant to iterate. The caller provides input via + the standard response_handler/request_info pattern. - Assembles the orchestrator, participants, and their interconnections into - a complete workflow graph. The orchestrator delegates speaker selection to - the manager, routes requests to the appropriate participants, and collects - their responses to continue or complete the conversation. + Simulated flow with HIL: + Input -> Orchestrator -> [Participant <-> Request Info] -> Orchestrator -> [Participant <-> Request Info] -> ... - Returns: - Validated Workflow instance ready for execution + Note: This is only available for agent participants. Executor participants can incorporate + request info handling in their own implementation if desired. - Raises: - ValueError: If manager or participants are not configured (when using default factory) + Args: + agents: Optional list of agents names to enable request info for. + If None, enables HIL for all agent participants. - Wiring pattern: - - Orchestrator receives initial input (str, ChatMessage, or list[ChatMessage]) - - Orchestrator queries manager for next action (participant selection or finish) - - If participant selected: request routed directly to participant entry node - - Participant pipeline: AgentExecutor for agents or custom executor chains - - Participant response flows back to orchestrator - - Orchestrator updates state and queries manager again - - When manager returns finish directive: orchestrator yields final message and becomes idle + Returns: + Self for fluent chaining + """ + from ._orchestration_request_info import resolve_request_info_filter - Usage: + self._request_info_enabled = True + self._request_info_filter = resolve_request_info_filter(list(agents) if agents else None) - .. code-block:: python + return self - from agent_framework import GroupChatBuilder + def _resolve_orchestrator(self, participants: Sequence[Executor]) -> Executor: + """Determine the orchestrator to use for the workflow. - # Execute the workflow - workflow = GroupChatBuilder().set_manager(manager_agent).participants(agent1=agent1, agent2=agent2).build() - async for message in workflow.run("Solve this problem collaboratively"): - print(message.text) + Args: + participants: List of resolved participant executors """ - # Manager is only required when using the default orchestrator factory - # Custom factories (e.g., MagenticBuilder) provide their own orchestrator with embedded manager - if ( - self._manager is None - and self._manager_participant is None - and self._orchestrator_factory == _default_orchestrator_factory - ): + if self._orchestrator is not None: + return self._orchestrator + + if self._agent_orchestrator is not None and self._selection_func is not None: raise ValueError( - "manager must be configured before build() when using default orchestrator. " - "Call set_manager(...) or set_select_speakers_func(...) before build()." + "Both agent-based orchestrator and selection function are configured; only one can be used at a time." 
) - if not self._participants: - raise ValueError("participants must be configured before build()") - - metadata = self._get_participant_metadata() - participant_specs = self._build_participant_specs() - wiring = _GroupChatConfig( - manager=self._manager, - manager_participant=self._manager_participant, - manager_name=self._manager_name, - participants=participant_specs, - max_rounds=self._max_rounds, - termination_condition=self._termination_condition, - participant_aliases=metadata["aliases"], - participant_executors=metadata["executors"], - ) - # Determine participant factory - wrap if request info is enabled - participant_factory = self._participant_factory - if self._request_info_enabled: - # Create a wrapper factory that adds request info interceptor before each participant - base_factory = participant_factory - agent_filter = self._request_info_filter - - def _factory_with_request_info( - spec: GroupChatParticipantSpec, - config: _GroupChatConfig, - ) -> _GroupChatParticipantPipeline: - pipeline = list(base_factory(spec, config)) - if pipeline: - # Add interceptor executor BEFORE the participant (prepend) - interceptor = RequestInfoInterceptor( - executor_id=f"request_info:{spec.name}", - agent_filter=agent_filter, - ) - pipeline.insert(0, interceptor) - return tuple(pipeline) + if self._selection_func is not None: + return GroupChatOrchestrator( + id=self.DEFAULT_ORCHESTRATOR_ID, + participant_registry=ParticipantRegistry(participants), + selection_func=self._selection_func, + name=self._orchestrator_name, + max_rounds=self._max_rounds, + termination_condition=self._termination_condition, + ) - participant_factory = _factory_with_request_info + if self._agent_orchestrator is not None: + return AgentBasedGroupChatOrchestrator( + agent=self._agent_orchestrator, + participant_registry=ParticipantRegistry(participants), + max_rounds=self._max_rounds, + termination_condition=self._termination_condition, + ) - result = assemble_group_chat_workflow( - wiring=wiring, - participant_factory=participant_factory, - orchestrator_factory=self._orchestrator_factory, - interceptors=self._interceptors, - checkpoint_storage=self._checkpoint_storage, + raise RuntimeError( + "Orchestrator could not be resolved. Please provide one via with_orchestrator(), " + "with_agent_orchestrator(), or with_select_speaker_func()." ) - if not isinstance(result, Workflow): - raise TypeError("Expected Workflow from assemble_group_chat_workflow") - return result - - -# endregion + def _resolve_participants(self) -> list[Executor]: + """Resolve participant instances into Executor objects.""" + executors: list[Executor] = [] + for participant in self._participants.values(): + if isinstance(participant, Executor): + executors.append(participant) + elif isinstance(participant, AgentProtocol): + if self._request_info_enabled and ( + not self._request_info_filter or resolve_agent_id(participant) in self._request_info_filter + ): + # Handle request info enabled agents + executors.append(AgentApprovalExecutor(participant)) + else: + executors.append(AgentExecutor(participant)) + else: + raise TypeError( + f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." + ) -# region Default manager implementation - - -DEFAULT_MANAGER_INSTRUCTIONS = """You are coordinating a team conversation to solve the user's task. -Your role is to orchestrate collaboration between multiple participants by selecting who speaks next. 
-Leverage each participant's unique expertise as described in their descriptions. -Have participants build on each other's contributions - earlier participants gather information, -later ones refine and synthesize. -Only finish the task after multiple relevant participants have contributed their expertise.""" - -DEFAULT_MANAGER_STRUCTURED_OUTPUT_PROMPT = """Return your decision using the following structure: -- next_agent: name of the participant who should act next (use null when finish is true) -- message: instruction for that participant (empty string if not needed) -- finish: boolean indicating if the task is complete -- final_response: when finish is true, provide the final answer to the user""" + return executors + def build(self) -> Workflow: + """Build and validate the group chat workflow. -class ManagerDirectiveModel(BaseModel): - """Pydantic model for structured manager directive output.""" - - next_agent: str | None = Field( - default=None, - description="Name of the participant who should act next (null when finish is true)", - ) - message: str = Field( - default="", - description="Instruction for the selected participant", - ) - finish: bool = Field( - default=False, - description="Whether the task is complete", - ) - final_response: str | None = Field( - default=None, - description="Final answer to the user when finish is true", - ) + Assembles the orchestrator and participants into a complete workflow graph. + The workflow graph consists of bi-directional edges between the orchestrator and each participant, + allowing for message exchanges in both directions. + Returns: + Validated Workflow instance ready for execution + """ + if not self._participants: + raise ValueError("participants must be configured before build()") -class _SpeakerSelectorAdapter: - """Adapter that turns a simple speaker selector into a full manager directive.""" + # Resolve orchestrator and participants to executors + participants: list[Executor] = self._resolve_participants() + orchestrator: Executor = self._resolve_orchestrator(participants) - def __init__( - self, - selector: Callable[[GroupChatStateSnapshot], Awaitable[Any]] | Callable[[GroupChatStateSnapshot], Any], - *, - manager_name: str, - final_message: ChatMessage | str | Callable[[GroupChatStateSnapshot], Any] | None = None, - ) -> None: - self._selector = selector - self._manager_name = manager_name - self._final_message = final_message - self.name = manager_name - - async def __call__(self, state: GroupChatStateSnapshot) -> GroupChatDirective: - result = await _maybe_await(self._selector(state)) - if result is None: - message = await self._resolve_final_message(state) - return GroupChatDirective(finish=True, final_message=message) - - if isinstance(result, Sequence) and not isinstance(result, (str, bytes, bytearray)): - if not result: - message = await self._resolve_final_message(state) - return GroupChatDirective(finish=True, final_message=message) - if len(result) != 1: # type: ignore[arg-type] - raise ValueError("Speaker selector must return a single participant name or None.") - first_item = result[0] # type: ignore[index] - if not isinstance(first_item, str): - raise TypeError("Speaker selector must return a participant name (str) or None.") - result = first_item - - if not isinstance(result, str): - raise TypeError("Speaker selector must return a participant name (str) or None.") - - return GroupChatDirective(agent_name=result) - - async def _resolve_final_message(self, state: GroupChatStateSnapshot) -> ChatMessage: - final_message 
= self._final_message
-        if callable(final_message):
-            value = await _maybe_await(final_message(state))
-        else:
-            value = final_message
-
-        if value is None:
-            message = ChatMessage(
-                role=Role.ASSISTANT,
-                text="Conversation completed.",
-                author_name=self._manager_name,
-            )
-        elif isinstance(value, ChatMessage):
-            message = value
-        else:
-            message = ChatMessage(
-                role=Role.ASSISTANT,
-                text=str(value),
-                author_name=self._manager_name,
-            )
+        # Build workflow graph
+        workflow_builder = WorkflowBuilder().set_start_executor(orchestrator)
+        for participant in participants:
+            # Orchestrator and participant bi-directional edges
+            workflow_builder = workflow_builder.add_edge(orchestrator, participant)
+            workflow_builder = workflow_builder.add_edge(participant, orchestrator)
+        if self._checkpoint_storage is not None:
+            workflow_builder = workflow_builder.with_checkpointing(self._checkpoint_storage)

-        if not message.author_name:
-            patch = message.to_dict()
-            patch["author_name"] = self._manager_name
-            message = ChatMessage.from_dict(patch)
-        return message
+        return workflow_builder.build()


# endregion
diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py
index 33c533c5e5..79e97dfca8 100644
--- a/python/packages/core/agent_framework/_workflows/_handoff.py
+++ b/python/packages/core/agent_framework/_workflows/_handoff.py
@@ -2,58 +2,53 @@

 """High-level builder for conversational handoff workflows.

-The handoff pattern models a coordinator agent that optionally routes
-control to specialist agents before handing the conversation back to the user.
-The flow is intentionally cyclical by default:
+The handoff pattern models a group of agents that can intelligently route
+control to other agents based on the conversation context.

-    user input -> coordinator -> optional specialist -> request user input -> ...
+The flow is typically:

-An autonomous interaction mode can bypass the user input request and continue routing
-responses back to agents until a handoff occurs or termination criteria are met.
+    user input -> Agent A -> Agent B -> Agent C -> Agent A -> ... -> output
+
+Depending on whether request info is enabled, the flow may include user input (except when an agent hands off):
+
+    user input -> [Agent A -> Request info] -> [Agent B -> Request info] -> [Agent C -> Request info] -> ... -> output
+
+The difference between a group chat workflow and a handoff workflow is that in group chat there is
+always an orchestrator that decides who speaks next, while in handoff the agents themselves decide
+whom to hand off to next by invoking a tool call that names the target agent. 
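+
+As a minimal sketch (assuming ``triage_agent`` and ``billing_agent`` are existing
+ChatAgent instances; the executor and configuration types used here are defined in
+this module), two agents could be wired for mutual handoff roughly like this:
+
+.. code-block:: python
+
+    # triage_agent starts the conversation and may hand off to billing_agent;
+    # billing_agent may hand control back to triage_agent.
+    triage = HandoffAgentExecutor(
+        triage_agent,
+        handoffs=[HandoffConfiguration(target=billing_agent, description="Handle billing questions")],
+        is_start_agent=True,
+    )
+    billing = HandoffAgentExecutor(
+        billing_agent,
+        handoffs=[HandoffConfiguration(target=triage_agent)],
+    )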
+
+Group Chat: centralized orchestration of multiple agents
+Handoff: decentralized routing by agents themselves

Key properties:
- The entire conversation is maintained and reused on every hop
-- The coordinator signals a handoff by invoking a tool call that names the specialist
+- Agents signal a handoff by invoking a tool call that names the target agent
- In human_in_loop mode (default), the workflow requests user input after each
  agent response that doesn't trigger a handoff
- In autonomous mode, agents continue responding until they invoke a handoff tool
  or reach a termination condition or turn limit
"""

+import inspect
import logging
-import re
import sys
from collections.abc import Awaitable, Callable, Mapping, Sequence
-from dataclasses import dataclass, field
-from typing import Any, Literal
-
-from agent_framework import (
-    AgentProtocol,
-    AgentRunResponse,
-    AIFunction,
-    ChatMessage,
-    FunctionApprovalRequestContent,
-    FunctionCallContent,
-    FunctionResultContent,
-    Role,
-    ai_function,
-)
-
-from .._agents import ChatAgent
+from dataclasses import dataclass
+from typing import Any, cast
+
+from typing_extensions import Never
+
+from .._agents import AgentProtocol, ChatAgent
from .._middleware import FunctionInvocationContext, FunctionMiddleware
+from .._threads import AgentThread
+from .._tools import AIFunction, ai_function
+from .._types import AgentResponse, ChatMessage, Role
from ._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
-from ._base_group_chat_orchestrator import BaseGroupChatOrchestrator
+from ._agent_utils import resolve_agent_id
+from ._base_group_chat_orchestrator import TerminationCondition
from ._checkpoint import CheckpointStorage
-from ._executor import Executor, handler
-from ._group_chat import (
-    _default_participant_factory,  # type: ignore[reportPrivateUsage]
-    _GroupChatConfig,  # type: ignore[reportPrivateUsage]
-    _GroupChatParticipantPipeline,  # type: ignore[reportPrivateUsage]
-    assemble_group_chat_workflow,
-)
-from ._orchestration_request_info import RequestInfoInterceptor
+from ._events import WorkflowEvent
from ._orchestrator_helpers import clean_conversation_for_handoff
-from ._participant_utils import GroupChatParticipantSpec, prepare_participant_metadata, sanitize_identifier
from ._request_info_mixin import response_handler
from ._workflow import Workflow
from ._workflow_builder import WorkflowBuilder
@@ -67,141 +62,75 @@

logger = logging.getLogger(__name__)

-_HANDOFF_TOOL_PATTERN = re.compile(r"(?:handoff|transfer)[_\s-]*to[_\s-]*(?P<target>[\w-]+)", re.IGNORECASE)
-_DEFAULT_AUTONOMOUS_TURN_LIMIT = 50

+# region Handoff events


+class HandoffSentEvent(WorkflowEvent):
+    """Event emitted when one agent hands off the conversation to another."""

-def _create_handoff_tool(alias: str, description: str | None = None) -> AIFunction[Any, Any]:
-    """Construct the synthetic handoff tool that signals routing to `alias`."""
-    sanitized = sanitize_identifier(alias)
-    tool_name = f"handoff_to_{sanitized}"
-    doc = description or f"Handoff to the {alias} agent."
-
-    # Note: approval_mode is intentionally NOT set for handoff tools.
-    # Handoff tools are framework-internal signals that trigger routing logic,
-    # not actual function executions. They are automatically intercepted by
-    # _AutoHandoffMiddleware which short-circuits execution and provides synthetic
-    # results, so the function body never actually runs in practice. 
- @ai_function(name=tool_name, description=doc) - def _handoff_tool(context: str | None = None) -> str: - """Return a deterministic acknowledgement that encodes the target alias.""" - return f"Handoff to {alias}" - - return _handoff_tool - - -def _clone_chat_agent(agent: ChatAgent) -> ChatAgent: - """Produce a deep copy of the ChatAgent while preserving runtime configuration.""" - options = agent.chat_options - middleware = list(agent.middleware or []) - - # Reconstruct the original tools list by combining regular tools with MCP tools. - # ChatAgent.__init__ separates MCP tools into _local_mcp_tools during initialization, - # so we need to recombine them here to pass the complete tools list to the constructor. - # This makes sure MCP tools are preserved when cloning agents for handoff workflows. - all_tools = list(options.tools) if options.tools else [] - if agent._local_mcp_tools: # type: ignore - all_tools.extend(agent._local_mcp_tools) # type: ignore - - return ChatAgent( - chat_client=agent.chat_client, - instructions=options.instructions, - id=agent.id, - name=agent.name, - description=agent.description, - chat_message_store_factory=agent.chat_message_store_factory, - context_providers=agent.context_provider, - middleware=middleware, - # Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once. - allow_multiple_tool_calls=False, - frequency_penalty=options.frequency_penalty, - logit_bias=dict(options.logit_bias) if options.logit_bias else None, - max_tokens=options.max_tokens, - metadata=dict(options.metadata) if options.metadata else None, - model_id=options.model_id, - presence_penalty=options.presence_penalty, - response_format=options.response_format, - seed=options.seed, - stop=options.stop, - store=options.store, - temperature=options.temperature, - tool_choice=options.tool_choice, # type: ignore[arg-type] - tools=all_tools if all_tools else None, - top_p=options.top_p, - user=options.user, - additional_chat_options=dict(options.additional_properties), - ) + def __init__(self, source: str, target: str, data: Any | None = None) -> None: + """Initialize handoff sent event. + Args: + source: Identifier of the source agent initiating the handoff + target: Identifier of the target agent receiving the handoff + data: Optional event-specific data + """ + super().__init__(data) + self.source = source + self.target = target -@dataclass -class HandoffUserInputRequest: - """Request message emitted when the workflow needs fresh user input. - Note: The conversation field is intentionally excluded from checkpoint serialization - to prevent duplication. The conversation is preserved in the coordinator's state - and will be reconstructed on restore. See issue #2667. - """ +# endregion - conversation: list[ChatMessage] - awaiting_agent_id: str - prompt: str - source_executor_id: str - def to_dict(self) -> dict[str, Any]: - """Serialize to dict, excluding conversation to prevent checkpoint duplication. +@dataclass +class HandoffConfiguration: + """Configuration for handoff routing between agents. - The conversation is already preserved in the workflow coordinator's state. - Including it here would cause duplicate messages when restoring from checkpoint. 
- """ - return { - "awaiting_agent_id": self.awaiting_agent_id, - "prompt": self.prompt, - "source_executor_id": self.source_executor_id, - } + Attributes: + target_id: Identifier of the target agent to hand off to + description: Optional human-readable description of the handoff + """ + + target_id: str + description: str | None = None - @classmethod - def from_dict(cls, data: dict[str, Any]) -> "HandoffUserInputRequest": - """Deserialize from dict, initializing conversation as empty. + def __init__(self, *, target: str | AgentProtocol, description: str | None = None) -> None: + """Initialize HandoffConfiguration. - The conversation will be reconstructed from the coordinator's state on restore. + Args: + target: Target agent identifier or AgentProtocol instance + description: Optional human-readable description of the handoff """ - return cls( - conversation=[], - awaiting_agent_id=data["awaiting_agent_id"], - prompt=data["prompt"], - source_executor_id=data["source_executor_id"], - ) + self.target_id = resolve_agent_id(target) if isinstance(target, AgentProtocol) else target + self.description = description + def __eq__(self, other: Any) -> bool: + """Determine equality based on source_id and target_id.""" + if not isinstance(other, HandoffConfiguration): + return False -@dataclass -class _ConversationWithUserInput: - """Internal message carrying full conversation + new user messages from gateway to coordinator. + return self.target_id == other.target_id - Attributes: - full_conversation: The conversation messages to process. - is_post_restore: If True, indicates this message was created after a checkpoint restore. - The coordinator should append these messages to its existing conversation rather - than replacing it. This prevents duplicate messages (see issue #2667). 
- """ + def __hash__(self) -> int: + """Compute hash based on source_id and target_id.""" + return hash(self.target_id) - full_conversation: list[ChatMessage] = field(default_factory=lambda: []) # type: ignore[misc] - is_post_restore: bool = False +def get_handoff_tool_name(target_id: str) -> str: + """Get the standardized handoff tool name for a given target agent ID.""" + return f"handoff_to_{target_id}" -@dataclass -class _ConversationForUserInput: - """Internal message from coordinator to gateway specifying which agent will receive the response.""" - conversation: list[ChatMessage] - next_agent_id: str +HANDOFF_FUNCTION_RESULT_KEY = "handoff_to" class _AutoHandoffMiddleware(FunctionMiddleware): """Intercept handoff tool invocations and short-circuit execution with synthetic results.""" - def __init__(self, handoff_targets: Mapping[str, str]) -> None: + def __init__(self, handoffs: Sequence[HandoffConfiguration]) -> None: """Initialise middleware with the mapping from tool name to specialist id.""" - self._targets = {name.lower(): target for name, target in handoff_targets.items()} + self._handoff_functions = {get_handoff_tool_name(handoff.target_id): handoff.target_id for handoff in handoffs} async def process( self, @@ -209,782 +138,504 @@ async def process( next: Callable[[FunctionInvocationContext], Awaitable[None]], ) -> None: """Intercept matching handoff tool calls and inject synthetic results.""" - name = getattr(context.function, "name", "") - normalized = name.lower() if name else "" - target = self._targets.get(normalized) - if target is None: + if context.function.name not in self._handoff_functions: await next(context) return # Short-circuit execution and provide deterministic response payload for the tool call. - context.result = {"handoff_to": target} + context.result = {HANDOFF_FUNCTION_RESULT_KEY: self._handoff_functions[context.function.name]} context.terminate = True -class _InputToConversation(Executor): - """Normalizes initial workflow input into a list[ChatMessage].""" +@dataclass +class HandoffAgentUserRequest: + """Request issued to the user after an agent run in a handoff workflow. 
+
+    Attributes:
+        agent_response: The response generated by the agent at the most recent turn
+    """

-    @handler
-    async def from_str(self, prompt: str, ctx: WorkflowContext[list[ChatMessage]]) -> None:
-        """Convert a raw user prompt into a conversation containing a single user message."""
-        await ctx.send_message([ChatMessage(Role.USER, text=prompt)])
+    agent_response: AgentResponse
+
+    @staticmethod
+    def create_response(response: str | list[str] | ChatMessage | list[ChatMessage]) -> list[ChatMessage]:
+        """Normalize a user response into the list of chat messages expected by the workflow."""
+        messages: list[ChatMessage] = []
+        if isinstance(response, str):
+            messages.append(ChatMessage(role=Role.USER, text=response))
+        elif isinstance(response, ChatMessage):
+            messages.append(response)
+        elif isinstance(response, list):
+            for item in response:
+                if isinstance(item, ChatMessage):
+                    messages.append(item)
+                elif isinstance(item, str):
+                    messages.append(ChatMessage(role=Role.USER, text=item))
+                else:
+                    raise TypeError("List items must be either str or ChatMessage instances")
+        else:
+            raise TypeError("Response must be str, list of str, ChatMessage, or list of ChatMessage")

-    @handler
-    async def from_message(self, message: ChatMessage, ctx: WorkflowContext[list[ChatMessage]]) -> None:
-        """Pass through an existing chat message as the initial conversation."""
-        await ctx.send_message([message])
+        return messages

-    @handler
-    async def from_messages(self, messages: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None:
-        """Forward a list of chat messages as the starting conversation history."""
-        await ctx.send_message(list(messages))
+    @staticmethod
+    def terminate() -> list[ChatMessage]:
+        """Create a termination response for the handoff workflow."""
+        return []


-@dataclass
-class _HandoffResolution:
-    """Result of handoff detection containing the target alias and originating call."""
-
-    target: str
-    function_call: FunctionCallContent | None = None
-
-
-def _resolve_handoff_target(agent_response: AgentRunResponse) -> _HandoffResolution | None:
-    """Detect handoff intent from tool call metadata."""
-    for message in agent_response.messages:
-        resolution = _resolution_from_message(message)
-        if resolution:
-            return resolution
-
-    for request in agent_response.user_input_requests:
-        if isinstance(request, FunctionApprovalRequestContent):
-            resolution = _resolution_from_function_call(request.function_call)
-            if resolution:
-                return resolution
-
-    return None
-
-
-def _resolution_from_message(message: ChatMessage) -> _HandoffResolution | None:
-    """Inspect an assistant message for embedded handoff tool metadata."""
-    for content in getattr(message, "contents", ()):
-        if isinstance(content, FunctionApprovalRequestContent):
-            resolution = _resolution_from_function_call(content.function_call)
-            if resolution:
-                return resolution
-        elif isinstance(content, FunctionCallContent):
-            resolution = _resolution_from_function_call(content)
-            if resolution:
-                return resolution
-    return None
-
-
-def _resolution_from_function_call(function_call: FunctionCallContent | None) -> _HandoffResolution | None:
-    """Wrap the target resolved from a function call in a `_HandoffResolution`."""
-    if function_call is None:
-        return None
-    target = _target_from_function_call(function_call)
-    if not target:
-        return None
-    return _HandoffResolution(target=target, function_call=function_call)
-
-
-def _target_from_function_call(function_call: FunctionCallContent) -> str | None:
-    """Extract the handoff target from the tool name or 
structured arguments.""" - name_candidate = _target_from_tool_name(function_call.name) - if name_candidate: - return name_candidate - - arguments = function_call.parse_arguments() - if isinstance(arguments, Mapping): - value = arguments.get("handoff_to") - if isinstance(value, str) and value.strip(): - return value.strip() - elif isinstance(arguments, str): - stripped = arguments.strip() - if stripped: - name_candidate = _target_from_tool_name(stripped) - if name_candidate: - return name_candidate - return stripped - - return None - - -def _target_from_tool_name(name: str | None) -> str | None: - """Parse the specialist alias encoded in a handoff tool's name.""" - if not name: - return None - match = _HANDOFF_TOOL_PATTERN.search(name) - if match: - parsed = match.group("target").strip() - if parsed: - return parsed - return None +# In autonomous mode, the agent continues responding until it requests a handoff +# or reaches a turn limit, after which it requests user input to continue. +_AUTONOMOUS_MODE_DEFAULT_PROMPT = "User did not respond. Continue assisting autonomously." +_DEFAULT_AUTONOMOUS_TURN_LIMIT = 50 +# region Handoff Agent Executor -class _HandoffCoordinator(BaseGroupChatOrchestrator): - """Coordinates agent-to-agent transfers and user turn requests.""" + +class HandoffAgentExecutor(AgentExecutor): + """Specialized AgentExecutor that supports handoff tool interception.""" def __init__( self, + agent: AgentProtocol, + handoffs: Sequence[HandoffConfiguration], *, - starting_agent_id: str, - specialist_ids: Mapping[str, str], - input_gateway_id: str, - termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]], - id: str, - handoff_tool_targets: Mapping[str, str] | None = None, - return_to_previous: bool = False, - interaction_mode: Literal["human_in_loop", "autonomous"] = "human_in_loop", - autonomous_turn_limit: int | None = None, + agent_thread: AgentThread | None = None, + is_start_agent: bool = False, + termination_condition: TerminationCondition | None = None, + autonomous_mode: bool = False, + autonomous_mode_prompt: str | None = None, + autonomous_mode_turn_limit: int | None = None, ) -> None: - """Create a coordinator that manages routing between specialists and the user.""" - super().__init__(id) - self._starting_agent_id = starting_agent_id - self._specialist_by_alias = dict(specialist_ids) - self._specialist_ids = set(specialist_ids.values()) - self._input_gateway_id = input_gateway_id - self._termination_condition = termination_condition - self._handoff_tool_targets = {k.lower(): v for k, v in (handoff_tool_targets or {}).items()} - self._return_to_previous = return_to_previous - self._current_agent_id: str | None = None # Track the current agent handling conversation - self._interaction_mode = interaction_mode - self._autonomous_turn_limit = autonomous_turn_limit - self._autonomous_turns = 0 + """Initialize the HandoffAgentExecutor. - def _get_author_name(self) -> str: - """Get the coordinator name for orchestrator-generated messages.""" - return "handoff_coordinator" + Args: + agent: The agent to execute + handoffs: Sequence of handoff configurations defining target agents + agent_thread: Optional AgentThread that manages the agent's execution context + is_start_agent: Whether this agent is the starting agent in the handoff workflow. + There can only be one starting agent in a handoff workflow. 
+
+            termination_condition: Optional callable that determines when to terminate the workflow
+            autonomous_mode: Whether the agent should continue running without user input after
+                a response that does not trigger a handoff, until a handoff occurs or the turn
+                limit is reached. This allows the agent to perform long-running
+                tasks (e.g., research, coding, analysis) without prematurely returning
+                control to the coordinator or user.
+            autonomous_mode_prompt: Prompt to provide to the agent when continuing in autonomous mode.
+                This will guide the agent in the absence of user input.
+            autonomous_mode_turn_limit: Maximum number of autonomous turns before requesting user input.
+        """
+        cloned_agent = self._prepare_agent_with_handoffs(agent, handoffs)
+        super().__init__(cloned_agent, agent_thread=agent_thread)

-    def _extract_agent_id_from_source(self, source: str | None) -> str | None:
-        """Extract the original agent ID from the source executor ID.
+        self._handoff_targets = {handoff.target_id for handoff in handoffs}
+        self._termination_condition = termination_condition
+        self._is_start_agent = is_start_agent

-        When a request info interceptor is in the pipeline, the source will be
-        like 'request_info:agent_name'. This method extracts the
-        actual agent ID.
+        # Autonomous mode members
+        self._autonomous_mode = autonomous_mode
+        self._autonomous_mode_prompt = autonomous_mode_prompt or _AUTONOMOUS_MODE_DEFAULT_PROMPT
+        self._autonomous_mode_turn_limit = autonomous_mode_turn_limit or _DEFAULT_AUTONOMOUS_TURN_LIMIT
+        self._autonomous_mode_turns = 0
+
+    def _prepare_agent_with_handoffs(
+        self,
+        agent: AgentProtocol,
+        handoffs: Sequence[HandoffConfiguration],
+    ) -> AgentProtocol:
+        """Prepare an agent by adding handoff tools for the specified target agents.

        Args:
-            source: The source executor ID from the workflow context
+            agent: The agent to prepare
+            handoffs: Sequence of handoff configurations defining target agents

        Returns:
-            The actual agent ID, or the original source if not an interceptor
+            A cloned agent with handoff tools added
        """
-        if source is None:
-            return None
-        if source.startswith("request_info:"):
-            return source[len("request_info:") :]
-        # TODO(@moonbox3): Remove legacy prefix support in a separate PR (GA cleanup)
-        if source.startswith("human_review:"):
-            return source[len("human_review:") :]
-        if source.startswith("human_input_interceptor:"):
-            return source[len("human_input_interceptor:") :]
-        return source
-
-    @handler
-    async def handle_agent_response(
-        self,
-        response: AgentExecutorResponse,
-        ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput],
-    ) -> None:
-        """Process an agent's response and determine whether to route, request input, or terminate."""
-        raw_source = ctx.get_source_executor_id()
-        source = self._extract_agent_id_from_source(raw_source)
-        is_starting_agent = source == self._starting_agent_id
-
-        # On first turn of a run, conversation is empty
-        # Track new messages only, build authoritative history incrementally
-        conversation_msgs = self._get_conversation()
-        if not conversation_msgs:
-            # First response from starting agent - initialize with authoritative conversation snapshot
-            # Keep the FULL conversation including tool calls (OpenAI SDK default behavior)
-            full_conv = self._conversation_from_response(response)
-            self._conversation = list(full_conv)
-        else:
-            # Subsequent responses - append only new messages from this agent
-            # Keep ALL messages including tool calls to maintain complete history. 
- # This includes assistant messages with function calls and tool role messages with results. - new_messages = response.agent_run_response.messages or [] - self._conversation.extend(new_messages) - - self._apply_response_metadata(self._conversation, response.agent_run_response) - - conversation = list(self._conversation) - - # Check for handoff from ANY agent (starting agent or specialist) - target = self._resolve_specialist(response.agent_run_response, conversation) - if target is not None: - # Update current agent when handoff occurs - self._current_agent_id = target - self._autonomous_turns = 0 - logger.info(f"Handoff detected: {source} -> {target}. Routing control to specialist '{target}'.") - - # Clean tool-related content before sending to next agent - cleaned = clean_conversation_for_handoff(conversation) - request = AgentExecutorRequest(messages=cleaned, should_respond=True) - await ctx.send_message(request, target_id=target) - return - - # No handoff detected - response must come from starting agent or known specialist - if not is_starting_agent and source not in self._specialist_ids: - raise RuntimeError(f"HandoffCoordinator received response from unknown executor '{source}'.") - - # Update current agent when they respond without handoff - self._current_agent_id = source - if await self._check_termination(): - # Clean the output conversation for display - cleaned_output = clean_conversation_for_handoff(conversation) - await ctx.yield_output(cleaned_output) - return - - if self._interaction_mode == "autonomous": - self._autonomous_turns += 1 - if self._autonomous_turn_limit is not None and self._autonomous_turns >= self._autonomous_turn_limit: - logger.info( - f"Autonomous turn limit reached ({self._autonomous_turn_limit}). " - "Yielding conversation and stopping." - ) - cleaned_output = clean_conversation_for_handoff(conversation) - await ctx.yield_output(cleaned_output) - return - - # In autonomous mode, agents continue iterating until they invoke a handoff tool - logger.info( - f"Agent '{source}' responded without handoff (turn {self._autonomous_turns}). " - "Continuing autonomous execution." + if not isinstance(agent, ChatAgent): + raise TypeError( + "Handoff can only be applied to ChatAgent. Please ensure the agent is a ChatAgent instance." ) - cleaned = clean_conversation_for_handoff(conversation) - request = AgentExecutorRequest(messages=cleaned, should_respond=True) - await ctx.send_message(request, target_id=source) - return - logger.info( - f"Agent '{source}' responded without handoff. " - f"Requesting user input. Return-to-previous: {self._return_to_previous}" - ) + # Clone the agent to avoid mutating the original + cloned_agent = self._clone_chat_agent(agent) # type: ignore + # Add handoff tools to the cloned agent + self._apply_auto_tools(cloned_agent, handoffs) + # Add middleware to handle handoff tool invocations + middleware = _AutoHandoffMiddleware(handoffs) + existing_middleware = list(cloned_agent.middleware or []) + existing_middleware.append(middleware) + cloned_agent.middleware = existing_middleware + + return cloned_agent + + def _clone_chat_agent(self, agent: ChatAgent) -> ChatAgent: + """Produce a deep copy of the ChatAgent while preserving runtime configuration.""" + options = agent.default_options + middleware = list(agent.middleware or []) + + # Reconstruct the original tools list by combining regular tools with MCP tools. 
+        # ChatAgent.__init__ separates MCP tools during initialization,
+        # so we need to recombine them here to pass the complete tools list to the constructor.
+        # This makes sure MCP tools are preserved when cloning agents for handoff workflows.
+        tools_from_options = options.get("tools")
+        all_tools = list(tools_from_options) if tools_from_options else []
+        if agent.mcp_tools:
+            all_tools.extend(agent.mcp_tools)
+
+        logit_bias = options.get("logit_bias")
+        metadata = options.get("metadata")

-        # Clean conversation before sending to gateway for user input request
-        # This removes tool messages that shouldn't be shown to users
-        cleaned_for_display = clean_conversation_for_handoff(conversation)
-
-        # The awaiting_agent_id is the agent that just responded and is awaiting user input
-        # This is the source of the current response (fallback to starting agent if source is unknown)
-        next_agent_id = source or self._starting_agent_id
+        # Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once.
+        cloned_options: dict[str, Any] = {
+            "allow_multiple_tool_calls": False,
+            "frequency_penalty": options.get("frequency_penalty"),
+            "instructions": options.get("instructions"),
+            "logit_bias": dict(logit_bias) if logit_bias else None,
+            "max_tokens": options.get("max_tokens"),
+            "metadata": dict(metadata) if metadata else None,
+            "model_id": options.get("model_id"),
+            "presence_penalty": options.get("presence_penalty"),
+            "response_format": options.get("response_format"),
+            "seed": options.get("seed"),
+            "stop": options.get("stop"),
+            "store": options.get("store"),
+            "temperature": options.get("temperature"),
+            "tool_choice": options.get("tool_choice"),
+            "tools": all_tools if all_tools else None,
+            "top_p": options.get("top_p"),
+            "user": options.get("user"),
+        }

-        message_to_gateway = _ConversationForUserInput(conversation=cleaned_for_display, next_agent_id=next_agent_id)
-        await ctx.send_message(message_to_gateway, target_id=self._input_gateway_id)  # type: ignore[arg-type]
+        return ChatAgent(
+            chat_client=agent.chat_client,
+            id=agent.id,
+            name=agent.name,
+            description=agent.description,
+            chat_message_store_factory=agent.chat_message_store_factory,
+            context_providers=agent.context_provider,
+            middleware=middleware,
+            default_options=cloned_options,  # type: ignore[arg-type]
+        )

-    @handler
-    async def handle_user_input(
-        self,
-        message: _ConversationWithUserInput,
-        ctx: WorkflowContext[AgentExecutorRequest, list[ChatMessage]],
-    ) -> None:
-        """Receive user input from gateway, update history, and route to agent.
+    def _apply_auto_tools(self, agent: ChatAgent, targets: Sequence[HandoffConfiguration]) -> None:
+        """Attach synthetic handoff tools to a chat agent.

-        The message.full_conversation may contain:
-        - Full conversation history + new user messages (normal flow)
-        - Only new user messages (post-checkpoint-restore flow, see issue #2667)
+        Creates handoff tools for each specialist agent that this agent can route to.

-        The gateway sets message.is_post_restore=True when resuming after a checkpoint
-        restore. In that case, we append the new messages to the existing conversation
-        rather than replacing it. 
+        Args:
+            agent: The ChatAgent to add handoff tools to
+            targets: Sequence of handoff configurations defining target agents
        """
-        incoming = message.full_conversation
-
-        if message.is_post_restore and self._conversation:
-            # Post-restore: append new user messages to existing conversation
-            # The coordinator already has its conversation restored from checkpoint
-            self._conversation.extend(incoming)
-        else:
-            # Normal flow: replace with full conversation
-            self._conversation = list(incoming) if incoming else self._conversation
-
-        # Reset autonomous turn counter on new user input
-        self._autonomous_turns = 0
-
-        # Check termination before sending to agent
-        if await self._check_termination():
-            await ctx.yield_output(list(self._conversation))
-            return
+        default_options = agent.default_options
+        existing_tools = list(default_options.get("tools") or [])
+        existing_names = {getattr(tool, "name", "") for tool in existing_tools if hasattr(tool, "name")}

-        # Determine routing target based on return-to-previous setting
-        target_agent_id = self._starting_agent_id
-        if self._return_to_previous and self._current_agent_id:
-            # Route back to the current agent that's handling the conversation
-            target_agent_id = self._current_agent_id
-            logger.info(
-                f"Return-to-previous enabled: routing user input to current agent '{target_agent_id}' "
-                f"(bypassing coordinator '{self._starting_agent_id}')"
-            )
-        else:
-            logger.info(f"Routing user input to coordinator '{target_agent_id}'")
-
-        # Clean conversation before sending to target agent
-        # Removes tool-related messages that shouldn't be resent on every turn
-        cleaned = clean_conversation_for_handoff(self._conversation)
-        request = AgentExecutorRequest(messages=cleaned, should_respond=True)
-        await ctx.send_message(request, target_id=target_agent_id)
-
-    def _resolve_specialist(self, agent_response: AgentRunResponse, conversation: list[ChatMessage]) -> str | None:
-        """Resolve the specialist executor id requested by the agent response, if any."""
-        resolution = _resolve_handoff_target(agent_response)
-        if not resolution:
-            return None
+        new_tools: list[AIFunction[Any, Any]] = []
+        for target in targets:
+            tool = self._create_handoff_tool(target.target_id, target.description)
+            if tool.name in existing_names:
+                raise ValueError(
+                    f"Agent '{resolve_agent_id(agent)}' already has a tool named '{tool.name}'. "
+                    f"Handoff tool name '{tool.name}' conflicts with existing tool. "
+                    "Please rename the existing tool or modify the target agent ID to avoid conflicts."
+                )
+            new_tools.append(tool)

-        candidate = resolution.target
-        normalized = candidate.lower()
-        resolved_id: str | None
-        if normalized in self._handoff_tool_targets:
-            resolved_id = self._handoff_tool_targets[normalized]
+        if new_tools:
+            default_options["tools"] = existing_tools + new_tools  # type: ignore[operator]
        else:
-            resolved_id = self._specialist_by_alias.get(candidate)
+            default_options["tools"] = existing_tools

-        if resolved_id:
-            if resolution.function_call:
-                self._append_tool_acknowledgement(conversation, resolution.function_call, resolved_id)
-            return resolved_id
+    def _create_handoff_tool(self, target_id: str, description: str | None = None) -> AIFunction[Any, Any]:
+        """Construct the synthetic handoff tool that signals routing to `target_id`."""
+        tool_name = get_handoff_tool_name(target_id)
+        doc = description or f"Handoff to the {target_id} agent."

+        # Note: approval_mode is intentionally NOT set for handoff tools. 
+ # Handoff tools are framework-internal signals that trigger routing logic, + # not actual function executions. They are automatically intercepted by + # _AutoHandoffMiddleware which short-circuits execution and provides synthetic + # results, so the function body never actually runs in practice. - lowered = candidate.lower() - for alias, exec_id in self._specialist_by_alias.items(): - if alias.lower() == lowered: - if resolution.function_call: - self._append_tool_acknowledgement(conversation, resolution.function_call, exec_id) - return exec_id + @ai_function(name=tool_name, description=doc) + def _handoff_tool(context: str | None = None) -> str: + """Return a deterministic acknowledgement that encodes the target alias.""" + return f"Handoff to {target_id}" - logger.warning("Handoff requested unknown specialist '%s'.", candidate) - return None - - def _append_tool_acknowledgement( - self, - conversation: list[ChatMessage], - function_call: FunctionCallContent, - resolved_id: str, - ) -> None: - """Append a synthetic tool result acknowledging the resolved specialist id.""" - call_id = getattr(function_call, "call_id", None) - if not call_id: - return - - result_payload: Any = {"handoff_to": resolved_id} - result_content = FunctionResultContent(call_id=call_id, result=result_payload) - tool_message = ChatMessage( - role=Role.TOOL, - contents=[result_content], - author_name=function_call.name, - ) - # Add tool acknowledgement to both the conversation being sent and the full history - conversation.extend((tool_message,)) - self._append_messages((tool_message,)) - - def _conversation_from_response(self, response: AgentExecutorResponse) -> list[ChatMessage]: - """Return the authoritative conversation snapshot from an executor response.""" - conversation = response.full_conversation - if conversation is None: - raise RuntimeError( - "AgentExecutorResponse.full_conversation missing; AgentExecutor must populate it in handoff workflows." - ) - return list(conversation) + return _handoff_tool @override - def _snapshot_pattern_metadata(self) -> dict[str, Any]: - """Serialize pattern-specific state. + async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse]) -> None: + """Override to support handoff.""" + # When the full conversation is empty, it means this is the first run. + # Broadcast the initial cache to all other agents. Subsequent runs won't + # need this since responses are broadcast after each agent run and user input. + if self._is_start_agent and not self._full_conversation: + await self._broadcast_messages(self._cache.copy(), cast(WorkflowContext[AgentExecutorRequest], ctx)) + + # Append the cache to the full conversation history + self._full_conversation.extend(self._cache) + + # Check termination condition before running the agent + if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)): + return - Includes the current agent for return-to-previous routing. 
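+ # From here the override proceeds in phases: run the agent (streaming or not),
+ # clear the per-turn cache, clean and broadcast the response, then either
+ # forward control to a requested handoff target or return the turn to the
+ # user (or the autonomous loop).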
+ # Run the agent
+ if ctx.is_streaming():
+ # Streaming mode: emit incremental updates
+ response = await self._run_agent_streaming(cast(WorkflowContext, ctx))
+ else:
+ # Non-streaming mode: use run() and emit single event
+ response = await self._run_agent(cast(WorkflowContext, ctx))
- Returns:
- Dict containing current agent if return-to-previous is enabled
- """
- metadata: dict[str, Any] = {}
- if self._return_to_previous:
- metadata["current_agent_id"] = self._current_agent_id
- if self._interaction_mode == "autonomous":
- metadata["autonomous_turns"] = self._autonomous_turns
- return metadata
+ # Clear the cache after running the agent
+ self._cache.clear()
- @override
- def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
- """Restore pattern-specific state.
+ # The base AgentExecutor may have issued a function approval request,
+ # in which case the run does not produce a response.
+ if response is None:
+ # Agent did not complete (e.g., waiting for user input); do not emit response
+ logger.debug("AgentExecutor %s: Agent did not complete, awaiting user input", self.id)
+ return
- Restores the current agent for return-to-previous routing.
+ # Remove function call related content from the agent response for full conversation history
+ cleaned_response = clean_conversation_for_handoff(response.messages)
+ # Append the agent response to the full conversation history. This list removes
+ # function call related content such that the result stays consistent regardless
+ # of which agent yields the final output.
+ self._full_conversation.extend(cleaned_response)
+ # Broadcast the cleaned response to all other agents
+ await self._broadcast_messages(cleaned_response, cast(WorkflowContext[AgentExecutorRequest], ctx))
+
+ # Check if a handoff was requested
+ if handoff_target := self._is_handoff_requested(response):
+ if handoff_target not in self._handoff_targets:
+ raise ValueError(
+ f"Agent '{resolve_agent_id(self._agent)}' attempted to hand off to unknown "
+ f"target '{handoff_target}'. 
Valid targets are: {', '.join(self._handoff_targets)}" + ) - Args: - metadata: Pattern-specific state dict - """ - if self._return_to_previous and "current_agent_id" in metadata: - self._current_agent_id = metadata["current_agent_id"] - if self._interaction_mode == "autonomous" and "autonomous_turns" in metadata: - self._autonomous_turns = metadata["autonomous_turns"] - - def _apply_response_metadata(self, conversation: list[ChatMessage], agent_response: AgentRunResponse) -> None: - """Merge top-level response metadata into the latest assistant message.""" - if not agent_response.additional_properties: + await cast(WorkflowContext[AgentExecutorRequest], ctx).send_message( + AgentExecutorRequest(messages=[], should_respond=True), target_id=handoff_target + ) + await ctx.add_event(HandoffSentEvent(source=self.id, target=handoff_target)) + self._autonomous_mode_turns = 0 # Reset autonomous mode turn counter on handoff return - # Find the most recent assistant message contributed by this response - for message in reversed(conversation): - if message.role == Role.ASSISTANT: - metadata = agent_response.additional_properties or {} - if not metadata: - return - # Merge metadata without mutating shared dict from agent response - merged = dict(message.additional_properties or {}) - for key, value in metadata.items(): - merged.setdefault(key, value) - message.additional_properties = merged - break - - -class _UserInputGateway(Executor): - """Bridges conversation context with the request & response cycle and re-enters the loop.""" - - def __init__(self, *, starting_agent_id: str, prompt: str | None, id: str) -> None: - """Initialise the gateway that requests user input and forwards responses.""" - super().__init__(id) - self._starting_agent_id = starting_agent_id - self._prompt = prompt or "Provide your next input for the conversation." - - @handler - async def request_input(self, message: _ConversationForUserInput, ctx: WorkflowContext) -> None: - """Emit a `HandoffUserInputRequest` capturing the conversation snapshot.""" - if not message.conversation: - raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.") - request = HandoffUserInputRequest( - conversation=list(message.conversation), - awaiting_agent_id=message.next_agent_id, - prompt=self._prompt, - source_executor_id=self.id, - ) - await ctx.request_info(request, object) - - @handler - async def request_input_legacy(self, conversation: list[ChatMessage], ctx: WorkflowContext) -> None: - """Legacy handler for backward compatibility - emit user input request with starting agent.""" - if not conversation: - raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.") - request = HandoffUserInputRequest( - conversation=list(conversation), - awaiting_agent_id=self._starting_agent_id, - prompt=self._prompt, - source_executor_id=self.id, - ) - await ctx.request_info(request, object) + # Handle case where no handoff was requested + if self._autonomous_mode and self._autonomous_mode_turns < self._autonomous_mode_turn_limit: + # In autonomous mode, continue running the agent until a handoff is requested + # or a termination condition is met. + # This allows the agent to perform long-running tasks without returning control + # to the coordinator or user prematurely. 
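+ # Illustrative trace (assuming a turn limit of 2): turn 0 ends without a
+ # handoff, so a synthetic USER continuation prompt is cached and the agent
+ # re-runs (turn 1); once _autonomous_mode_turns reaches the limit, the else
+ # branch below requests user input instead.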
+ self._cache.extend([ChatMessage(role=Role.USER, text=self._autonomous_mode_prompt)])
+ self._autonomous_mode_turns += 1
+ await self._run_agent_and_emit(ctx)
+ else:
+ # No handoff and autonomous mode is off: ask the user for input via `handle_response`
+ self._autonomous_mode_turns = 0 # Reset autonomous mode turn counter before requesting user input
+ await ctx.request_info(HandoffAgentUserRequest(response), list[ChatMessage])

@response_handler
- async def resume_from_user(
+ async def handle_response(
self,
- original_request: HandoffUserInputRequest,
- response: object,
- ctx: WorkflowContext[_ConversationWithUserInput],
+ original_request: HandoffAgentUserRequest,
+ response: list[ChatMessage],
+ ctx: WorkflowContext[AgentExecutorResponse, AgentResponse],
) -> None:
- """Convert user input responses back into chat messages and resume the workflow.
+ """Handle the user response to a request issued after the agent runs.
- After checkpoint restore, original_request.conversation will be empty (not serialized
- to prevent duplication - see issue #2667). In this case, we send only the new user
- messages and let the coordinator append them to its already-restored conversation.
- """
- user_messages = _as_user_messages(response)
+ The request only occurs when the agent did not request a handoff and
+ autonomous mode is disabled.
- if original_request.conversation:
- # Normal flow: have conversation history from the original request
- conversation = list(original_request.conversation)
- conversation.extend(user_messages)
- message = _ConversationWithUserInput(full_conversation=conversation, is_post_restore=False)
- else:
- # Post-restore flow: conversation was not serialized, send only new user messages
- # The coordinator will append these to its already-restored conversation
- message = _ConversationWithUserInput(full_conversation=user_messages, is_post_restore=True)
+ Note that this is different from the `handle_user_input_response` method
+ in the base AgentExecutor, which handles function approval responses.
- await ctx.send_message(message, target_id="handoff-coordinator")
+ Args:
+ original_request: The original HandoffAgentUserRequest issued to the user
+ response: The user's response messages
+ ctx: The workflow context
+
+ An empty response indicates termination of the handoff workflow.
+ """
+ if not response:
+ await cast(WorkflowContext[Never, list[ChatMessage]], ctx).yield_output(self._full_conversation)
+ return
-def _as_user_messages(payload: Any) -> list[ChatMessage]:
- """Normalize arbitrary payloads into user-authored chat messages.
+ # Broadcast the user response to all other agents
+ await self._broadcast_messages(response, cast(WorkflowContext[AgentExecutorRequest], ctx))
- Handles various input formats:
- - ChatMessage instances (converted to USER role if needed)
- - List of ChatMessage instances
- - Mapping with 'text' or 'content' key
- - Any other type (converted to string)
+ # Append the user response messages to the cache
+ self._cache.extend(response)
+ await self._run_agent_and_emit(ctx)
- Returns:
- List of ChatMessage instances with USER role. 
- """ - if isinstance(payload, ChatMessage): - if payload.role == Role.USER: - return [payload] - return [ChatMessage(Role.USER, text=payload.text)] - if isinstance(payload, list): - # Check if all items are ChatMessage instances - all_chat_messages = all(isinstance(msg, ChatMessage) for msg in payload) # type: ignore[arg-type] - if all_chat_messages: - messages: list[ChatMessage] = payload # type: ignore[assignment] - return [msg if msg.role == Role.USER else ChatMessage(Role.USER, text=msg.text) for msg in messages] - if isinstance(payload, Mapping): # User supplied structured data - text = payload.get("text") or payload.get("content") # type: ignore[union-attr] - if isinstance(text, str) and text.strip(): - return [ChatMessage(Role.USER, text=text.strip())] - return [ChatMessage(Role.USER, text=str(payload))] # type: ignore[arg-type] - - -def _default_termination_condition(conversation: list[ChatMessage]) -> bool: - """Default termination: stop after 10 user messages.""" - user_message_count = sum(1 for msg in conversation if msg.role == Role.USER) - return user_message_count >= 10 + async def _broadcast_messages( + self, + messages: list[ChatMessage], + ctx: WorkflowContext[AgentExecutorRequest], + ) -> None: + """Broadcast the workflow cache to the agent before running.""" + agent_executor_request = AgentExecutorRequest( + messages=messages, + should_respond=False, # Other agents do not need to respond yet + ) + # Since all agents are connected via fan-out, we can directly send the message + await ctx.send_message(agent_executor_request) + def _is_handoff_requested(self, response: AgentResponse) -> str | None: + """Determine if the agent response includes a handoff request. -class HandoffBuilder: - r"""Fluent builder for conversational handoff workflows with coordinator and specialist agents. - - The handoff pattern enables a coordinator agent to route requests to specialist agents. - Interaction mode controls whether the workflow requests user input after each agent response or - completes autonomously once agents finish responding. A termination condition determines when - the workflow should stop requesting input and complete. - - Routing Patterns: - - **Single-Tier (Default):** Only the coordinator can hand off to specialists. By default, after any specialist - responds, control returns to the user for more input. This creates a cyclical flow: - user -> coordinator -> [optional specialist] -> user -> coordinator -> ... - Use `with_interaction_mode("autonomous")` to skip requesting additional user input and yield the - final conversation when an agent responds without delegating. - - **Multi-Tier (Advanced):** Specialists can hand off to other specialists using `.add_handoff()`. - This provides more flexibility for complex workflows but is less controllable than the single-tier - pattern. Users lose real-time visibility into intermediate steps during specialist-to-specialist - handoffs (though the full conversation history including all handoffs is preserved and can be - inspected afterward). - - - Key Features: - - **Automatic handoff detection**: The coordinator invokes a handoff tool whose - arguments (for example `{"handoff_to": "shipping_agent"}`) identify the specialist to receive control. - - **Auto-generated tools**: By default the builder synthesizes `handoff_to_` tools for the coordinator, - so you don't manually define placeholder functions. 
- - **Full conversation history**: The entire conversation (including any
- `ChatMessage.additional_properties`) is preserved and passed to each agent.
- - **Termination control**: By default, terminates after 10 user messages. Override with
- `.with_termination_condition(lambda conv: ...)` for custom logic (e.g., detect "goodbye").
- - **Interaction modes**: Choose `human_in_loop` (default) to prompt users between agent turns,
- or `autonomous` to continue routing back to agents without prompting for user input until a
- handoff occurs or a termination/turn limit is reached (default autonomous turn limit: 50).
- - **Checkpointing**: Optional persistence for resumable workflows.
-
- Usage (Single-Tier):
-
- .. code-block:: python
-
- from agent_framework import HandoffBuilder
- from agent_framework.openai import OpenAIChatClient
-
- chat_client = OpenAIChatClient()
-
- # Create coordinator and specialist agents
- coordinator = chat_client.create_agent(
- instructions=(
- "You are a frontline support agent. Assess the user's issue and decide "
- "whether to hand off to 'refund_agent' or 'shipping_agent'. When delegation is "
- "required, call the matching handoff tool (for example `handoff_to_refund_agent`)."
- ),
- name="coordinator_agent",
- )
+ If a handoff tool is invoked, the middleware will short-circuit execution
+ and provide a synthetic result that includes the target agent ID. The message
+ that contains the function result will be the last message in the response.
+ """
+ if not response.messages:
+ return None
- refund = chat_client.create_agent(
- instructions="You handle refund requests. Ask for order details and process refunds.",
- name="refund_agent",
- )
+ last_message = response.messages[-1]
+ for content in last_message.contents:
+ # content.type is compared as a string here, which is cheaper than an isinstance check
+ if content.type == "function_result":
+ if content.result and isinstance(content.result, dict):
+ handoff_target = content.result.get(HANDOFF_FUNCTION_RESULT_KEY) # type: ignore
+ if isinstance(handoff_target, str):
+ return handoff_target
- shipping = chat_client.create_agent(
- instructions="You resolve shipping issues. Track packages and update delivery status.",
- name="shipping_agent",
- )
+ return None
- # Build the handoff workflow - default single-tier routing
- workflow = (
- HandoffBuilder(
- name="customer_support",
- participants=[coordinator, refund, shipping],
- )
- .set_coordinator(coordinator)
- .build()
- )
+ async def _check_terminate_and_yield(self, ctx: WorkflowContext[Never, list[ChatMessage]]) -> bool:
+ """Check termination conditions and yield completion if met.
- # Run the workflow
- events = await workflow.run_stream("My package hasn't arrived yet")
- async for event in events:
- if isinstance(event, RequestInfoEvent):
- # Request user input
- user_response = input("You: ")
- await workflow.send_response(event.data.request_id, user_response)
-
- **Multi-Tier Routing with .add_handoff():**
-
- .. 
code-block:: python
-
- # Enable specialist-to-specialist handoffs with fluent API
- workflow = (
- HandoffBuilder(participants=[coordinator, replacement, delivery, billing])
- .set_coordinator(coordinator)
- .add_handoff(coordinator, [replacement, delivery, billing]) # Coordinator routes to all
- .add_handoff(replacement, [delivery, billing]) # Replacement delegates to delivery/billing
- .add_handoff(delivery, billing) # Delivery escalates to billing
- .build()
- )
+ Args:
+ ctx: Workflow context for yielding output
- # Flow: User → Coordinator → Replacement → Delivery → Back to User
- # (Replacement hands off to Delivery without returning to user)
+ Returns:
+ True if termination condition met and output yielded, False otherwise
+ """
+ if self._termination_condition is None:
+ return False
- **Use Participant Factories for State Isolation:**
+ terminated = self._termination_condition(self._full_conversation)
+ if inspect.isawaitable(terminated):
+ terminated = await terminated
- .. code-block:: python
- # Define factories that produce fresh agent instances per workflow run
- def create_coordinator() -> AgentProtocol:
- return chat_client.create_agent(
- instructions="You are the coordinator agent...",
- name="coordinator_agent",
- )
+ if terminated:
+ await ctx.yield_output(self._full_conversation)
+ return True
+ return False
- def create_specialist() -> AgentProtocol:
- return chat_client.create_agent(
- instructions="You are the specialist agent...",
- name="specialist_agent",
- )
+ @override
+ async def on_checkpoint_save(self) -> dict[str, Any]:
+ """Serialize the executor state for checkpointing."""
+ state = await super().on_checkpoint_save()
+ state["_autonomous_mode_turns"] = self._autonomous_mode_turns
+ return state
+
+ @override
+ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
+ """Restore the executor state from a checkpoint."""
+ await super().on_checkpoint_restore(state)
+ if "_autonomous_mode_turns" in state:
+ self._autonomous_mode_turns = state["_autonomous_mode_turns"]
- workflow = (
- HandoffBuilder(
- participant_factories={
- "coordinator": create_coordinator,
- "specialist": create_specialist,
- }
- )
- .set_coordinator("coordinator")
- .build()
- )
- **Custom Termination Condition:**
+# endregion Handoff Agent Executor
- .. code-block:: python
+# region Handoff workflow builder
- # Terminate when user says goodbye or after 5 exchanges
- workflow = (
- HandoffBuilder(participants=[coordinator, refund, shipping])
- .set_coordinator(coordinator)
- .with_termination_condition(
- lambda conv: (
- sum(1 for msg in conv if msg.role.value == "user") >= 5
- or any("goodbye" in msg.text.lower() for msg in conv[-2:])
- )
- )
- .build()
- )
- **Checkpointing:**
+class HandoffBuilder:
+ r"""Fluent builder for conversational handoff workflows with multiple agents.
- .. code-block:: python
+ The handoff pattern enables a group of agents to route control among themselves.
- from agent_framework import InMemoryCheckpointStorage
+ Routing Pattern:
+ Agents hand off control to one another via `.add_handoff()`, providing a decentralized
+ approach to multi-agent collaboration. If no handoffs are configured, every agent can
+ hand off to every other agent by default (a mesh topology).
- storage = InMemoryCheckpointStorage()
- workflow = (
- HandoffBuilder(participants=[coordinator, refund, shipping])
- .set_coordinator(coordinator)
- .with_checkpointing(storage)
- .build()
- )
+ Participants must be agents. 
Support for custom executors is not available in handoff workflows.
- Args:
- name: Optional workflow name for identification and logging.
- participants: List of agents (AgentProtocol) or executors to participate in the handoff.
- The first agent you specify as coordinator becomes the orchestrating agent.
- participant_factories: Mapping of factory names to callables that produce agents or
- executors when invoked. This allows for lazy instantiation
- and state isolation per workflow instance created by this builder.
- description: Optional human-readable description of the workflow.
-
- Raises:
- ValueError: If participants list is empty, contains duplicates, or coordinator not specified.
- TypeError: If participants are not AgentProtocol or Executor instances.
+ Outputs:
+ The final conversation history as a list of ChatMessage once the handoff workflow completes.
+
+ Note:
+ Agents in handoff workflows must be ChatAgent instances and support local tool calls.
"""

def __init__(
self,
*,
name: str | None = None,
- participants: Sequence[AgentProtocol | Executor] | None = None,
- participant_factories: Mapping[str, Callable[[], AgentProtocol | Executor]] | None = None,
+ participants: Sequence[AgentProtocol] | None = None,
+ participant_factories: Mapping[str, Callable[[], AgentProtocol]] | None = None,
description: str | None = None,
) -> None:
r"""Initialize a HandoffBuilder for creating conversational handoff workflows.

The builder starts in an unconfigured state and requires you to call:
1. `.participants([...])` - Register agents
- 2. or `.participant_factories({...})` - Register agent/executor factories
- 3. `.set_coordinator(...)` - Designate which agent receives initial user input
- 4. `.build()` - Construct the final Workflow
+ 2. or `.participant_factories({...})` - Register agent factories
+ 3. `.with_start_agent(...)` - Designate which agent receives the initial input
+ 4. `.build()` - Construct the final Workflow

Optional configuration methods allow you to customize context management,
termination logic, and persistence.

Args:
name: Optional workflow identifier used in logging and debugging.
- If not provided, a default name will be generated.
- participants: Optional list of agents (AgentProtocol) or executors that will
- participate in the handoff workflow. You can also call
- `.participants([...])` later. Each participant must have a
- unique identifier (name for agents, id for executors).
- participant_factories: Optional mapping of factory names to callables that produce agents or
- executors when invoked. This allows for lazy instantiation
- and state isolation per workflow instance created by this builder.
+ If not provided, a default name will be generated.
+ participants: Optional list of agents that will participate in the handoff workflow.
+ You can also call `.participants([...])` later. Each participant must have a
+ unique identifier (`.name` is preferred if set, otherwise `.id` is used).
+ participant_factories: Optional mapping of factory names to callables that produce agents when invoked.
+ This allows for lazy instantiation and state isolation per workflow instance
+ created by this builder.
description: Optional human-readable description explaining the workflow's
- purpose. Useful for documentation and observability.
-
- Note:
- Participants must have stable names/ids because the workflow maps the
- handoff tool arguments to these identifiers. Agent names should match
- the strings emitted by the coordinator's handoff tool (e.g., a tool that
- outputs `{\"handoff_to\": \"billing\"}` requires an agent named `billing`).
+ purpose. 
Useful for documentation and observability. """ self._name = name self._description = description - self._executors: dict[str, Executor] = {} - self._aliases: dict[str, str] = {} - self._starting_agent_id: str | None = None - self._checkpoint_storage: CheckpointStorage | None = None - self._request_prompt: str | None = None - # Termination condition - self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] = ( - _default_termination_condition - ) - self._handoff_config: dict[str, list[str]] = {} # Maps agent_id -> [target_agent_ids] - self._return_to_previous: bool = False - self._interaction_mode: Literal["human_in_loop", "autonomous"] = "human_in_loop" - self._autonomous_turn_limit: int | None = _DEFAULT_AUTONOMOUS_TURN_LIMIT - self._request_info_enabled: bool = False - self._request_info_filter: set[str] | None = None - - self._participant_factories: dict[str, Callable[[], AgentProtocol | Executor]] = {} + + # Participant related members + self._participants: dict[str, AgentProtocol] = {} + self._participant_factories: dict[str, Callable[[], AgentProtocol]] = {} + self._start_id: str | None = None if participant_factories: self.participant_factories(participant_factories) if participants: self.participants(participants) - # region Fluent Configuration Methods + # Handoff related members + self._handoff_config: dict[str, set[HandoffConfiguration]] = {} + + # Checkpoint related members + self._checkpoint_storage: CheckpointStorage | None = None + + # Autonomous mode related + self._autonomous_mode: bool = False + self._autonomous_mode_prompts: dict[str, str] = {} + self._autonomous_mode_turn_limits: dict[str, int] = {} + self._autonomous_mode_enabled_agents: list[str] = [] + + # Termination related members + self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None def participant_factories( - self, participant_factories: Mapping[str, Callable[[], AgentProtocol | Executor]] + self, participant_factories: Mapping[str, Callable[[], AgentProtocol]] ) -> "HandoffBuilder": - """Register factories that produce agents or executors for the handoff workflow. + """Register factories that produce agents for the handoff workflow. - Each factory is a callable that returns an AgentProtocol or Executor instance. + Each factory is a callable that returns an AgentProtocol instance. Factories are invoked when building the workflow, allowing for lazy instantiation and state isolation per workflow instance. Args: - participant_factories: Mapping of factory names to callables that return AgentProtocol or Executor - instances. Each produced participant must have a unique identifier (name for - agents, id for executors). + participant_factories: Mapping of factory names to callables that return AgentProtocol + instances. Each produced participant must have a unique identifier + (`.name` is preferred if set, otherwise `.id` is used). Returns: Self for method chaining. @@ -999,7 +650,7 @@ def participant_factories( from agent_framework import ChatAgent, HandoffBuilder - def create_coordinator() -> ChatAgent: + def create_triage() -> ChatAgent: return ... 
@@ -1012,17 +663,17 @@ def create_billing_agent() -> ChatAgent:

factories = {
- "coordinator": create_coordinator,
+ "triage": create_triage,
"refund": create_refund_agent,
"billing": create_billing_agent,
}

+ # Handoffs are created automatically unless specified otherwise:
+ # the default is a mesh topology where all agents can hand off to all others
builder = HandoffBuilder().participant_factories(factories)
- # Use the factory IDs to create handoffs and set the coordinator
- builder.add_handoff("coordinator", ["refund", "billing"])
- builder.set_coordinator("coordinator")
+ builder.with_start_agent("triage")
"""
- if self._executors:
+ if self._participants:
raise ValueError(
"Cannot mix .participants([...]) and .participant_factories() in the same builder instance."
)
@@ -1036,17 +687,12 @@ def create_billing_agent() -> ChatAgent:
self._participant_factories = dict(participant_factories)
return self

- def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "HandoffBuilder":
- """Register the agents or executors that will participate in the handoff workflow.
-
- Each participant must have a unique identifier (name for agents, id for executors).
- The workflow will automatically create an alias map so agents can be referenced by
- their name, display_name, or executor id when routing.
+ def participants(self, participants: Sequence[AgentProtocol]) -> "HandoffBuilder":
+ """Register the agents that will participate in the handoff workflow.

Args:
- participants: Sequence of AgentProtocol or Executor instances. Each must have
- a unique identifier. For agents, the name attribute is used as the
- primary identifier and must match handoff target strings.
+ participants: Sequence of AgentProtocol instances. Each must have a unique identifier.
+ (`.name` is preferred if set, otherwise `.id` is used).

Returns:
Self for method chaining.
@@ -1054,7 +700,7 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Han

Raises:
ValueError: If participants is empty, contains duplicates, or `.participants(...)`
or `.participant_factories(...)` has already been called.
- TypeError: If participants are not AgentProtocol or Executor instances.
+ TypeError: If participants are not AgentProtocol instances.

Example:

@@ -1064,175 +710,81 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Han

from agent_framework.openai import OpenAIChatClient

client = OpenAIChatClient()
- coordinator = client.create_agent(instructions="...", name="coordinator")
+ triage = client.create_agent(instructions="...", name="triage_agent")
refund = client.create_agent(instructions="...", name="refund_agent")
billing = client.create_agent(instructions="...", name="billing_agent")

- builder = HandoffBuilder().participants([coordinator, refund, billing])
- # Now you can call .set_coordinator() to designate the entry point
-
- Note:
- This method resets any previously configured coordinator, so you must call
- `.set_coordinator(...)` again after changing participants.
+ builder = HandoffBuilder().participants([triage, refund, billing])
+ builder.with_start_agent(triage)
"""
if self._participant_factories:
raise ValueError(
"Cannot mix .participants([...]) and .participant_factories() in the same builder instance." 
) - if self._executors: + if self._participants: raise ValueError("participants have already been assigned") if not participants: raise ValueError("participants cannot be empty") - named: dict[str, AgentProtocol | Executor] = {} + named: dict[str, AgentProtocol] = {} for participant in participants: - if isinstance(participant, Executor): - identifier = participant.id - elif isinstance(participant, AgentProtocol): - identifier = participant.display_name + if isinstance(participant, AgentProtocol): + resolved_id = self._resolve_to_id(participant) else: raise TypeError( f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." ) - if identifier in named: - raise ValueError(f"Duplicate participant name '{identifier}' detected") - named[identifier] = participant + if resolved_id in named: + raise ValueError(f"Duplicate participant name '{resolved_id}' detected") + named[resolved_id] = participant - metadata = prepare_participant_metadata( - named, - description_factory=lambda name, participant: getattr(participant, "description", None) or name, - ) - - wrapped = metadata["executors"] - self._executors = {executor.id: executor for executor in wrapped.values()} - self._aliases = metadata["aliases"] - self._starting_agent_id = None - - return self - - def set_coordinator(self, agent: str | AgentProtocol | Executor) -> "HandoffBuilder": - r"""Designate which agent receives initial user input and orchestrates specialist routing. - - The coordinator agent is responsible for analyzing user requests and deciding whether to: - 1. Handle the request directly and respond to the user, OR - 2. Hand off to a specialist agent by including handoff metadata in the response - - After a specialist responds, the workflow automatically returns control to the user - (default) creating a cyclical flow: user -> coordinator -> [specialist] -> user -> ... - Configure `with_interaction_mode("autonomous")` to continue with the responding agent - without requesting another user turn until a handoff occurs or a termination/turn limit is met. - - Args: - agent: The agent to use as the coordinator. Can be: - - Factory name (str): If using participant factories - - AgentProtocol instance: The actual agent object - - Executor instance: A custom executor wrapping an agent - - Returns: - Self for method chaining. - - Raises: - ValueError: 1) If `agent` is an AgentProtocol or Executor instance but `.participants(...)` hasn't - been called yet, or if it is not in the participants list. - 2) If `agent` is a factory name (str) but `.participant_factories(...)` hasn't been - called yet, or if it is not in the participant_factories list. - TypeError: If `agent` is not a str, AgentProtocol, or Executor instance. - - Example: - - .. code-block:: python - - # Use factory name with `.participant_factories()` - builder = ( - HandoffBuilder() - .participant_factories({ - "coordinator": create_coordinator, - "refund": create_refund_agent, - "billing": create_billing_agent, - }) - .set_coordinator("coordinator") - ) - - # Or pass the agent object directly - builder = HandoffBuilder().participants([coordinator, refund, billing]).set_coordinator(coordinator) - - Note: - The coordinator determines routing by invoking a handoff tool call whose - arguments identify the target specialist (for example `{\"handoff_to\": \"billing\"}`). - Decorate the tool with `approval_mode="always_require"` to ensure the workflow - intercepts the call before execution and can make the transition. 
- """ - if isinstance(agent, (AgentProtocol, Executor)): - if not self._executors: - raise ValueError( - "Call participants(...) before coordinator(...). If using participant_factories, " - "pass the factory name (str) instead of the agent instance." - ) - resolved = self._resolve_to_id(agent) - if resolved not in self._executors: - raise ValueError(f"coordinator '{resolved}' is not part of the participants list") - self._starting_agent_id = resolved - elif isinstance(agent, str): - if agent not in self._participant_factories: - raise ValueError( - f"coordinator factory name '{agent}' is not part of the participant_factories list. If " - "you are using participant instances, call .participants(...) and pass the agent instance instead." - ) - self._starting_agent_id = agent - else: - raise TypeError( - "coordinator must be a factory name (str), AgentProtocol, or Executor instance. " - f"Got {type(agent).__name__}." - ) + self._participants = named return self def add_handoff( self, - source: str | AgentProtocol | Executor, - targets: str | AgentProtocol | Executor | Sequence[str | AgentProtocol | Executor], + source: str | AgentProtocol, + targets: Sequence[str] | Sequence[AgentProtocol], *, - tool_name: str | None = None, - tool_description: str | None = None, + description: str | None = None, ) -> "HandoffBuilder": """Add handoff routing from a source agent to one or more target agents. - This method enables specialist-to-specialist handoffs by configuring which agents - can hand off to which others. Call this method multiple times to build a complete - routing graph. By default, only the starting agent can hand off to all other participants; - use this method to enable additional routing paths. + This method enables agent-to-agent handoffs by configuring which agents + can hand off to which others. Call this method multiple times to build a + complete routing graph. If no handoffs are specified, all agents can hand off + to all others by default (mesh topology). Args: source: The agent that can initiate the handoff. Can be: - Factory name (str): If using participant factories - AgentProtocol instance: The actual agent object - - Executor instance: A custom executor wrapping an agent - Cannot mix factory names and instances across source and targets targets: One or more target agents that the source can hand off to. Can be: - Factory name (str): If using participant factories - AgentProtocol instance: The actual agent object - - Executor instance: A custom executor wrapping an agent - - Single target: "billing_agent" or agent_instance + - Single target: ["billing_agent"] or [agent_instance] - Multiple targets: ["billing_agent", "support_agent"] or [agent1, agent2] - Cannot mix factory names and instances across source and targets - tool_name: Optional custom name for the handoff tool. Currently not used in the - implementation - tools are always auto-generated as "handoff_to_". - Reserved for future enhancement. - tool_description: Optional custom description for the handoff tool. Currently not used - in the implementation - descriptions are always auto-generated as - "Handoff to the agent.". Reserved for future enhancement. + description: Optional custom description for the handoff. If not provided, the description + of the target agent(s) will be used. If the target agent has no description, + no description will be set for the handoff tool, which is not recommended. + If multiple targets are provided, description will be shared among all handoff + tools. 
To configure distinct descriptions for multiple targets, call add_handoff() + separately for each target. Returns: Self for method chaining. Raises: ValueError: 1) If source or targets are not in the participants list, or if - participants(...) hasn't been called yet. + participants(...) hasn't been called yet. 2) If source or targets are factory names (str) but participant_factories(...) - hasn't been called yet, or if they are not in the participant_factories list. + hasn't been called yet, or if they are not in the participant_factories list. TypeError: If mixing factory names (str) and AgentProtocol/Executor instances Examples: @@ -1260,10 +812,9 @@ def add_handoff( workflow = ( HandoffBuilder(participants=[triage, replacement, delivery, billing]) - .set_coordinator(triage) .add_handoff(triage, [replacement, delivery, billing]) .add_handoff(replacement, [delivery, billing]) - .add_handoff(delivery, billing) + .add_handoff(delivery, [billing]) .build() ) @@ -1271,9 +822,7 @@ def add_handoff( - Handoff tools are automatically registered for each source agent - If a source agent is configured multiple times via add_handoff, targets are merged """ - if isinstance(source, str) and ( - isinstance(targets, str) or (isinstance(targets, Sequence) and all(isinstance(t, str) for t in targets)) - ): + if isinstance(source, str) and all(isinstance(t, str) for t in targets): # Both source and targets are factory names if not self._participant_factories: raise ValueError("Call participant_factories(...) before add_handoff(...)") @@ -1281,90 +830,120 @@ def add_handoff( if source not in self._participant_factories: raise ValueError(f"Source factory name '{source}' is not in the participant_factories list") - target_list: list[str] = [targets] if isinstance(targets, str) else list(targets) # type: ignore - for target in target_list: + for target in targets: if target not in self._participant_factories: raise ValueError(f"Target factory name '{target}' is not in the participant_factories list") - self._handoff_config[source] = target_list # type: ignore + # Merge with existing handoff configuration for this source + if source in self._handoff_config: + # Add new targets to existing list, avoiding duplicates + for t in targets: + if t in self._handoff_config[source]: + logger.warning(f"Handoff from '{source}' to '{t}' is already configured; overwriting.") + self._handoff_config[source].add(HandoffConfiguration(target=t, description=description)) + else: + self._handoff_config[source] = set() + for t in targets: + self._handoff_config[source].add(HandoffConfiguration(target=t, description=description)) return self - if isinstance(source, (AgentProtocol, Executor)) and ( - isinstance(targets, (AgentProtocol, Executor)) - or all(isinstance(t, (AgentProtocol, Executor)) for t in targets) - ): + if isinstance(source, (AgentProtocol)) and all(isinstance(t, AgentProtocol) for t in targets): # Both source and targets are instances - if not self._executors: + if not self._participants: raise ValueError("Call participants(...) 
before add_handoff(...)")

# Resolve source agent ID
source_id = self._resolve_to_id(source)
- if source_id not in self._executors:
+ if source_id not in self._participants:
raise ValueError(f"Source agent '{source}' is not in the participants list")

- # Normalize targets to list
- target_list: list[AgentProtocol | Executor] = ( # type: ignore[no-redef]
- [targets] if isinstance(targets, (AgentProtocol, Executor)) else list(targets)
- ) # type: ignore
-
- # Resolve all target IDs
target_ids: list[str] = []
- for target in target_list:
+ for target in targets:
target_id = self._resolve_to_id(target)
- if target_id not in self._executors:
+ if target_id not in self._participants:
raise ValueError(f"Target agent '{target}' is not in the participants list")
target_ids.append(target_id)

# Merge with existing handoff configuration for this source
if source_id in self._handoff_config:
# Merge new targets into the existing set, warning on duplicates
- existing = self._handoff_config[source_id]
- for target_id in target_ids:
- if target_id not in existing:
- existing.append(target_id)
+ for t in target_ids:
+ if t in self._handoff_config[source_id]:
+ logger.warning(f"Handoff from '{source_id}' to '{t}' is already configured; overwriting.")
+ self._handoff_config[source_id].add(HandoffConfiguration(target=t, description=description))
else:
- self._handoff_config[source_id] = target_ids
+ self._handoff_config[source_id] = set()
+ for t in target_ids:
+ self._handoff_config[source_id].add(HandoffConfiguration(target=t, description=description))

return self

raise TypeError(
- "Cannot mix factory names (str) and AgentProtocol/Executor instances "
- "across source and targets in add_handoff()"
+ "Cannot mix factory names (str) and AgentProtocol instances across source and targets in add_handoff()"
)

- def request_prompt(self, prompt: str | None) -> "HandoffBuilder":
- """Set a custom prompt message displayed when requesting user input.
+ def with_start_agent(self, agent: str | AgentProtocol) -> "HandoffBuilder":
+ """Set the agent that will initiate the handoff workflow.

- By default, the workflow uses a generic prompt: "Provide your next input for the
- conversation." Use this method to customize the message shown to users when the
- workflow needs their response.
+ The start agent must be configured before calling `.build()`.

Args:
- prompt: Custom prompt text to display, or None to use the default prompt.
-
+ agent: The agent that will start the workflow. Can be:
+ - Factory name (str): If using participant factories
+ - AgentProtocol instance: The actual agent object
+
Returns:
Self for method chaining.
+ """
+ if isinstance(agent, str):
+ if self._participant_factories:
+ if agent not in self._participant_factories:
+ raise ValueError(f"Start agent factory name '{agent}' is not in the participant_factories list")
+ else:
+ raise ValueError("Call participant_factories(...) before with_start_agent(...)")
+ self._start_id = agent
+ elif isinstance(agent, AgentProtocol):
+ resolved_id = self._resolve_to_id(agent)
+ if self._participants:
+ if resolved_id not in self._participants:
+ raise ValueError(f"Start agent '{resolved_id}' is not in the participants list")
+ else:
+ raise ValueError("Call participants(...) before with_start_agent(...)")
+ self._start_id = resolved_id
+ else:
+ raise TypeError("Start agent must be a factory name (str) or an AgentProtocol instance")

- Example:
-
- .. 
code-block:: python + return self - workflow = ( - HandoffBuilder(participants=[triage, refund, billing]) - .set_coordinator("triage") - .request_prompt("How can we help you today?") - .build() - ) + def with_autonomous_mode( + self, + *, + agents: Sequence[AgentProtocol] | Sequence[str] | None = None, + prompts: dict[str, str] | None = None, + turn_limits: dict[str, int] | None = None, + ) -> "HandoffBuilder": + """Enable autonomous mode for the handoff workflow. - # For more context-aware prompts, you can access the prompt via - # RequestInfoEvent.data.prompt in your event handling loop + Autonomous mode allows agents to continue responding without user input. + The default behavior when autonomous mode is disabled is to return control to the user + after each agent response that does not trigger a handoff. With autonomous mode enabled, + agents can continue the conversation until they request a handoff or the turn limit is reached. - Note: - The prompt is static and set once during workflow construction. If you need - dynamic prompts based on conversation state, you'll need to handle that in - your application's event processing logic. + Args: + agents: Optional list of agents to enable autonomous mode for. Can be: + - Factory names (str): If using participant factories + - AgentProtocol instances: The actual agent objects + - If not provided, all agents will operate in autonomous mode. + prompts: Optional mapping of agent identifiers/factory names to custom prompts to use when continuing + in autonomous mode. If not provided, a default prompt will be used. + turn_limits: Optional mapping of agent identifiers/factory names to maximum number of autonomous turns + before returning control to the user. If not provided, a default turn limit will be used. """ - self._request_prompt = prompt + self._autonomous_mode = True + self._autonomous_mode_prompts = prompts or {} + self._autonomous_mode_turn_limits = turn_limits or {} + self._autonomous_mode_enabled_agents = [self._resolve_to_id(agent) for agent in agents] if agents else [] + return self def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "HandoffBuilder": @@ -1391,12 +970,7 @@ def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "HandoffB from agent_framework import InMemoryCheckpointStorage storage = InMemoryCheckpointStorage() - workflow = ( - HandoffBuilder(participants=[triage, refund, billing]) - .set_coordinator("triage") - .with_checkpointing(storage) - .build() - ) + workflow = HandoffBuilder(participants=[triage, refund, billing]).with_checkpointing(storage).build() # Run workflow with a session ID for resumption async for event in workflow.run_stream("Help me", session_id="user_123"): @@ -1421,16 +995,14 @@ def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "HandoffB self._checkpoint_storage = checkpoint_storage return self - def with_termination_condition( - self, condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] - ) -> "HandoffBuilder": + def with_termination_condition(self, termination_condition: TerminationCondition) -> "HandoffBuilder": """Set a custom termination condition for the handoff workflow. The condition can be either synchronous or asynchronous. Args: - condition: Function that receives the full conversation and returns True - (or awaitable True) if the workflow should terminate (not request further user input). 
+ termination_condition: Function that receives the full conversation and returns True + (or awaitable True) if the workflow should terminate. Returns: Self for chaining. @@ -1453,199 +1025,15 @@ async def check_termination(conv: list[ChatMessage]) -> bool: builder.with_termination_condition(check_termination) """ - self._termination_condition = condition - return self - - def with_interaction_mode( - self, - interaction_mode: Literal["human_in_loop", "autonomous"] = "human_in_loop", - *, - autonomous_turn_limit: int | None = None, - ) -> "HandoffBuilder": - """Choose whether the workflow requests user input or runs autonomously after agent replies. - - In autonomous mode, agents (including specialists) continue iterating on their task - until they explicitly invoke a handoff tool or the turn limit is reached. This allows - specialists to perform long-running autonomous tasks (e.g., research, coding, analysis) - without prematurely returning control to the coordinator or user. - - Args: - interaction_mode: `"human_in_loop"` (default) requests user input after each agent response - that does not trigger a handoff. `"autonomous"` lets agents continue - working until they invoke a handoff tool or the turn limit is reached. - - Keyword Args: - autonomous_turn_limit: Maximum number of agent responses before the workflow yields - when in autonomous mode. Only applicable when interaction_mode - is `"autonomous"`. Default is 50. Set to `None` to disable - the limit (use with caution). Ignored with a warning if provided - when interaction_mode is `"human_in_loop"`. - - Returns: - Self for chaining. - - Example: - - .. code-block:: python - - workflow = ( - HandoffBuilder(participants=[coordinator, research_agent]) - .set_coordinator(coordinator) - .add_handoff(coordinator, research_agent) - .add_handoff(research_agent, coordinator) - .with_interaction_mode("autonomous", autonomous_turn_limit=20) - .build() - ) - - # Flow: User asks a question - # -> Coordinator routes to Research Agent - # -> Research Agent iterates (researches, analyzes, refines) - # -> Research Agent calls handoff_to_coordinator when done - # -> Coordinator provides final response - """ - if interaction_mode not in ("human_in_loop", "autonomous"): - raise ValueError("interaction_mode must be either 'human_in_loop' or 'autonomous'") - self._interaction_mode = interaction_mode - - if autonomous_turn_limit is not None: - if interaction_mode != "autonomous": - logger.warning( - f"autonomous_turn_limit={autonomous_turn_limit} was provided but interaction_mode is " - f"'{interaction_mode}'; ignoring." - ) - elif autonomous_turn_limit <= 0: - raise ValueError("autonomous_turn_limit must be positive when provided") - else: - self._autonomous_turn_limit = autonomous_turn_limit - - return self - - def enable_return_to_previous(self, enabled: bool = True) -> "HandoffBuilder": - """Enable direct return to the current agent after user input, bypassing the coordinator. - - When enabled, after a specialist responds without requesting another handoff, user input - routes directly back to that same specialist instead of always routing back to the - coordinator agent for re-evaluation. - - This is useful when a specialist needs multiple turns with the user to gather information - or resolve an issue, avoiding unnecessary coordinator involvement while maintaining context. - - Flow Comparison: - - **Default (disabled):** - User -> Coordinator -> Specialist -> User -> Coordinator -> Specialist -> ... 
- - **With return_to_previous (enabled):** - User -> Coordinator -> Specialist -> User -> Specialist -> ... - - Args: - enabled: Whether to enable return-to-previous routing. Default is True. - - Returns: - Self for method chaining. - - Example: - - .. code-block:: python - - workflow = ( - HandoffBuilder(participants=[triage, technical_support, billing]) - .set_coordinator("triage") - .add_handoff(triage, [technical_support, billing]) - .enable_return_to_previous() # Enable direct return routing - .build() - ) - - # Flow: User asks question - # -> Triage routes to Technical Support - # -> Technical Support asks clarifying question - # -> User provides more info - # -> Routes back to Technical Support (not Triage) - # -> Technical Support continues helping - - Multi-tier handoff example: - - .. code-block:: python - - workflow = ( - HandoffBuilder(participants=[triage, specialist_a, specialist_b]) - .set_coordinator("triage") - .add_handoff(triage, [specialist_a, specialist_b]) - .add_handoff(specialist_a, specialist_b) - .enable_return_to_previous() - .build() - ) - - # Flow: User asks question - # -> Triage routes to Specialist A - # -> Specialist A hands off to Specialist B - # -> Specialist B asks clarifying question - # -> User provides more info - # -> Routes back to Specialist B (who is currently handling the conversation) - - Note: - This feature routes to whichever agent most recently responded, whether that's - the coordinator or a specialist. The conversation continues with that agent until - they either hand off to another agent or the termination condition is met. - """ - self._return_to_previous = enabled - return self - - def with_request_info( - self, - *, - agents: Sequence[str | AgentProtocol | Executor] | None = None, - ) -> "HandoffBuilder": - """Enable request info before participants run in the workflow. - - When enabled, the workflow pauses before each participant runs, emitting - a RequestInfoEvent that allows the caller to review the conversation and - optionally inject guidance before the participant responds. The caller provides - input via the standard response_handler/request_info pattern. - - Args: - agents: Optional filter - only pause before these specific agents/executors. - Accepts agent names (str), agent instances, or executor instances. - If None (default), pauses before every participant. - - Returns: - self: The builder instance for fluent chaining. - - Example: - - .. code-block:: python - - # Pause before all participants - workflow = ( - HandoffBuilder(participants=[coordinator, refund, shipping]) - .set_coordinator("coordinator_agent") - .with_request_info() - .build() - ) - - # Pause only before specialist agents (not coordinator) - workflow = ( - HandoffBuilder(participants=[coordinator, refund, shipping]) - .set_coordinator("coordinator_agent") - .with_request_info(agents=[refund, shipping]) - .build() - ) - """ - from ._orchestration_request_info import resolve_request_info_filter - - self._request_info_enabled = True - self._request_info_filter = resolve_request_info_filter(list(agents) if agents else None) + self._termination_condition = termination_condition return self def build(self) -> Workflow: """Construct the final Workflow instance from the configured builder. 
This method validates the configuration and assembles all internal components: - - Input normalization executor - Starting agent executor - - Handoff coordinator - Specialist agent executors - - User input gateway - Request/response handling Returns: @@ -1654,388 +1042,192 @@ def build(self) -> Workflow: Raises: ValueError: If participants or coordinator were not configured, or if required configuration is invalid. - - Example (Minimal): - - .. code-block:: python - - workflow = ( - HandoffBuilder(participants=[coordinator, refund, billing]).set_coordinator("coordinator").build() - ) - - # Run the workflow - async for event in workflow.run_stream("I need help"): - # Handle events... - pass - - Example (Full Configuration): - - .. code-block:: python - - from agent_framework import InMemoryCheckpointStorage - - storage = InMemoryCheckpointStorage() - workflow = ( - HandoffBuilder( - name="support_workflow", - participants=[coordinator, refund, billing], - description="Customer support with specialist routing", - ) - .set_coordinator("coordinator") - .with_termination_condition(lambda conv: len(conv) > 20) - .request_prompt("How can we help?") - .with_checkpointing(storage) - .build() - ) - - Note: - After calling build(), the builder instance should not be reused. Create a - new builder if you need to construct another workflow with different configuration. """ - if not self._executors and not self._participant_factories: + if not self._participants and not self._participant_factories: raise ValueError( "No participants or participant_factories have been configured. " "Call participants(...) or participant_factories(...) first." ) - if self._starting_agent_id is None: - raise ValueError("Must call set_coordinator(...) before building the workflow.") - - # Resolve executors, aliases, and handoff tool targets - # This will instantiate participants if using factories, and validate handoff config - start_executor_id, executors, aliases, handoff_tool_targets = self._resolve_executors_and_handoffs() - - specialists = {exec_id: executor for exec_id, executor in executors.items() if exec_id != start_executor_id} - if not specialists: - logger.warning("Handoff workflow has no specialist agents; the coordinator will loop with the user.") - - descriptions = { - exec_id: getattr(executor, "description", None) or exec_id for exec_id, executor in executors.items() - } - participant_specs = { - exec_id: GroupChatParticipantSpec(name=exec_id, participant=executor, description=descriptions[exec_id]) - for exec_id, executor in executors.items() - } - - input_node = _InputToConversation(id="input-conversation") - user_gateway = _UserInputGateway( - starting_agent_id=start_executor_id, - prompt=self._request_prompt, - id="handoff-user-input", - ) - builder = WorkflowBuilder(name=self._name, description=self._description).set_start_executor(input_node) - - specialist_aliases = { - alias: specialists[exec_id].id for alias, exec_id in aliases.items() if exec_id in specialists - } - - def _handoff_orchestrator_factory(_: _GroupChatConfig) -> Executor: - return _HandoffCoordinator( - starting_agent_id=start_executor_id, - specialist_ids=specialist_aliases, - input_gateway_id=user_gateway.id, - termination_condition=self._termination_condition, - id="handoff-coordinator", - handoff_tool_targets=handoff_tool_targets, - return_to_previous=self._return_to_previous, - interaction_mode=self._interaction_mode, - autonomous_turn_limit=self._autonomous_turn_limit, - ) - - wiring = _GroupChatConfig( - manager=None, - 
manager_participant=None,
-            manager_name=self._starting_agent_id,
-            participants=participant_specs,
-            max_rounds=None,
-            participant_aliases=aliases,
-            participant_executors=executors,
-        )
-
-        # Determine participant factory - wrap with request info interceptor if enabled
-        participant_factory: Callable[[GroupChatParticipantSpec, _GroupChatConfig], _GroupChatParticipantPipeline] = (
-            _default_participant_factory
-        )
-        if self._request_info_enabled:
-            base_factory = _default_participant_factory
-            agent_filter = self._request_info_filter
-
-            def _factory_with_request_info(
-                spec: GroupChatParticipantSpec,
-                config: _GroupChatConfig,
-            ) -> _GroupChatParticipantPipeline:
-                pipeline = list(base_factory(spec, config))
-                if pipeline:
-                    # Add interceptor executor BEFORE the participant (prepend)
-                    interceptor = RequestInfoInterceptor(
-                        executor_id=f"request_info:{spec.name}",
-                        agent_filter=agent_filter,
-                    )
-                    pipeline.insert(0, interceptor)
-                return tuple(pipeline)
-
-            participant_factory = _factory_with_request_info
-
-        result = assemble_group_chat_workflow(
-            wiring=wiring,
-            participant_factory=participant_factory,
-            orchestrator_factory=_handoff_orchestrator_factory,
-            interceptors=(),
-            checkpoint_storage=self._checkpoint_storage,
-            builder=builder,
-            return_builder=True,
-        )
-        if not isinstance(result, tuple):
-            raise TypeError("Expected tuple from assemble_group_chat_workflow with return_builder=True")
-        builder, coordinator = result
-
-        # When request_info is enabled, the input should go through the interceptor first
-        if self._request_info_enabled:
-            # Get the entry executor from the builder's registered executors
-            starting_entry_id = f"request_info:{self._starting_agent_id}"
-            starting_entry_executor = builder._executors.get(starting_entry_id)  # type: ignore
-            if starting_entry_executor:
-                builder = builder.add_edge(input_node, starting_entry_executor)
-            else:
-                # Fallback to direct connection if interceptor not found
-                builder = builder.add_edge(input_node, executors[start_executor_id])
-        else:
-            builder = builder.add_edge(input_node, executors[start_executor_id])
-        builder = builder.add_edge(coordinator, user_gateway)
-        builder = builder.add_edge(user_gateway, coordinator)
+        if self._start_id is None:
+            raise ValueError("Must call with_start_agent(...) before building the workflow.")
+
+        # Resolve agents (either from instances or factories)
+        # The returned map keys are either executor IDs or factory names, which are needed to resolve handoff configs
+        resolved_agents = self._resolve_agents()
+        # Resolve handoff configurations to use agent display names
+        # The returned map keys are executor IDs
+        resolved_handoffs = self._resolve_handoffs(resolved_agents)
+        # Resolve agents into executors
+        executors = self._resolve_executors(resolved_agents, resolved_handoffs)
+
+        # Build the workflow graph
+        start_executor = executors[self._resolve_to_id(resolved_agents[self._start_id])]
+        builder = WorkflowBuilder(
+            name=self._name,
+            description=self._description,
+        ).set_start_executor(start_executor)
+
+        # Add the appropriate edges
+        # In handoff workflows, all executors are connected, making a fully connected graph.
+        # This is because, for all agents to stay synchronized, the active agent must be able to
+        # broadcast updates to all the others via edges. Handoffs are controlled internally by the
+        # `HandoffAgentExecutor` instances using handoff tools and middleware.
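For orientation, here is a minimal usage sketch of the rebuilt `build()` flow. The method names `participants(...)` and `with_start_agent(...)` are taken from the error messages in this hunk; the agent objects and the `name` keyword are illustrative assumptions, not confirmed API.

```python
# Hedged sketch only, not the canonical API. triage/billing/support are
# hypothetical AgentProtocol instances; method names follow the error
# messages above (participants(...), with_start_agent(...)).
workflow = (
    HandoffBuilder(name="support_handoff")
    .participants([triage, billing, support])
    .with_start_agent(triage)
    .build()
)
```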
+        for executor in executors.values():
+            targets = [e for e in executors.values() if e.id != executor.id]
+            # Fan-out requires at least 2 targets, so when there is only 1 target
+            # (i.e. only 2 agents total), we add a direct edge instead.
+            if len(targets) > 1:
+                builder = builder.add_fan_out_edges(executor, targets)
+            elif len(targets) == 1:
+                builder = builder.add_edge(executor, targets[0])
+
+        # Configure checkpointing if enabled
+        if self._checkpoint_storage:
+            builder.with_checkpointing(self._checkpoint_storage)
 
         return builder.build()
 
-    # endregion Fluent Configuration Methods
-
     # region Internal Helper Methods
 
-    def _resolve_executors(self) -> tuple[dict[str, Executor], dict[str, str]]:
-        """Resolve participant factories into executor instances.
+    def _resolve_agents(self) -> dict[str, AgentProtocol]:
+        """Resolve participant factories into agent instances.
 
-        If executors were provided directly via participants(...), those are returned as-is.
-        If participant factories were provided via participant_factories(...), those
-        are invoked to create executor instances and aliases.
+        If agent instances were provided directly via participants(...), those are
+        returned as-is. If participant factories were provided via participant_factories(...),
+        those are invoked to create the agent instances.
 
         Returns:
-            Tuple of (executors map, aliases map)
+            Map of executor IDs or factory names to `AgentProtocol` instances
        """
-        if self._executors and self._participant_factories:
+        if self._participants and self._participant_factories:
            raise ValueError("Cannot have both executors and participant_factories configured")
 
-        if self._executors:
-            if self._aliases:
-                # Return existing executors and aliases
-                return self._executors, self._aliases
-            raise ValueError("Aliases is empty despite executors being provided")
+        if self._participants:
+            return self._participants
 
        if self._participant_factories:
            # Invoke each factory to create participant instances
-            executor_ids_to_executors: dict[str, AgentProtocol | Executor] = {}
-            factory_names_to_ids: dict[str, str] = {}
+            factory_names_to_agents: dict[str, AgentProtocol] = {}
            for factory_name, factory in self._participant_factories.items():
-                instance: Executor | AgentProtocol = factory()
-                if isinstance(instance, Executor):
-                    identifier = instance.id
-                elif isinstance(instance, AgentProtocol):
-                    identifier = instance.display_name
+                instance = factory()
+                if isinstance(instance, AgentProtocol):
+                    resolved_id = self._resolve_to_id(instance)
                else:
-                    raise TypeError(
-                        f"Participants must be AgentProtocol or Executor instances. Got {type(instance).__name__}."
-                    )
+                    raise TypeError(f"Participants must be AgentProtocol instances. Got {type(instance).__name__}.")
 
-                if identifier in executor_ids_to_executors:
-                    raise ValueError(f"Duplicate participant name '{identifier}' detected")
-                executor_ids_to_executors[identifier] = instance
-                factory_names_to_ids[factory_name] = identifier
+                if resolved_id in factory_names_to_agents:
+                    raise ValueError(f"Duplicate participant name '{resolved_id}' detected")
 
-            # Prepare metadata and wrap instances as needed
-            metadata = prepare_participant_metadata(
-                executor_ids_to_executors,
-                description_factory=lambda name, participant: getattr(participant, "description", None) or name,
-            )
+                # Map executors by factory name (not executor.id) because handoff configs reference factory names
+                # This allows users to configure handoffs using the factory names they provided
+                factory_names_to_agents[factory_name] = instance
 
-            wrapped = metadata["executors"]
-            # Map executors by factory name (not executor.id) because handoff configs reference factory names
-            # This allows users to configure handoffs using the factory names they provided
-            executors = {
-                factory_name: wrapped[executor_id] for factory_name, executor_id in factory_names_to_ids.items()
-            }
-            aliases = metadata["aliases"]
-
-            return executors, aliases
+            return factory_names_to_agents
 
        raise ValueError("No executors or participant_factories have been configured")
 
-    def _resolve_handoffs(self, executors: Mapping[str, Executor]) -> tuple[dict[str, Executor], dict[str, str]]:
+    def _resolve_handoffs(self, agents: Mapping[str, AgentProtocol]) -> dict[str, list[HandoffConfiguration]]:
        """Handoffs may be specified using factory names or instances; resolve to executor IDs.
 
        Args:
-            executors: Map of executor IDs or factory names to Executor instances
+            agents: Map of agent IDs or factory names to `AgentProtocol` instances
 
        Returns:
-            Tuple of (updated executors map, handoff configuration map)
-            The updated executors map may have modified agents with handoff tools added
-            and maps executor IDs to Executor instances.
-            The handoff configuration map maps executor IDs to lists of target executor IDs.
+            Map of executor IDs to lists of HandoffConfiguration instances
        """
-        handoff_tool_targets: dict[str, str] = {}
-        updated_executors = {executor.id: executor for executor in executors.values()}
-        # Determine which agents should have handoff tools
+        # Updated map that uses resolved agent IDs as keys
+        updated_handoff_configurations: dict[str, list[HandoffConfiguration]] = {}
        if self._handoff_config:
            # Use explicit handoff configuration from add_handoff() calls
-            for source_id, target_ids in self._handoff_config.items():
-                executor = executors.get(source_id)
-                if not executor:
+            for source_id, handoff_configurations in self._handoff_config.items():
+                source_agent = agents.get(source_id)
+                if not source_agent:
                    raise ValueError(
                        f"Handoff source agent '{source_id}' not found. "
                        "Please make sure source has been added as either a participant or participant_factory."
                    )
-
-                if isinstance(executor, AgentExecutor):
-                    # Build targets map for this source agent
-                    targets_map: dict[str, Executor] = {}
-                    for target_id in target_ids:
-                        target_executor = executors.get(target_id)
-                        if not target_executor:
-                            raise ValueError(
-                                f"Handoff target agent '{target_id}' not found. "
-                                "Please make sure target has been added as either a participant or participant_factory."
-                            )
-                        targets_map[target_executor.id] = target_executor
-
-                    # Register handoff tools for this agent
-                    updated_executor, tool_targets = self._prepare_agent_with_handoffs(executor, targets_map)
-                    updated_executors[updated_executor.id] = updated_executor
-                    handoff_tool_targets.update(tool_targets)
+                for handoff_config in handoff_configurations:
+                    target_agent = agents.get(handoff_config.target_id)
+                    if not target_agent:
+                        raise ValueError(
+                            f"Handoff target agent '{handoff_config.target_id}' not found for source '{source_id}'. "
+                            "Please make sure target has been added as either a participant or participant_factory."
+                        )
+
+                    updated_handoff_configurations.setdefault(self._resolve_to_id(source_agent), []).append(
+                        HandoffConfiguration(
+                            target=self._resolve_to_id(target_agent),
+                            description=handoff_config.description or target_agent.description,
+                        )
+                    )
        else:
-            if self._starting_agent_id is None or self._starting_agent_id not in executors:
-                raise RuntimeError("Failed to resolve default handoff configuration due to missing starting agent.")
-
-            # Default behavior: only coordinator gets handoff tools to all specialists
-            starting_executor = executors[self._starting_agent_id]
-            specialists = {
-                executor.id: executor for executor in executors.values() if executor.id != starting_executor.id
-            }
-
-            if isinstance(starting_executor, AgentExecutor) and specialists:
-                starting_executor, tool_targets = self._prepare_agent_with_handoffs(starting_executor, specialists)
-                updated_executors[starting_executor.id] = starting_executor
-                handoff_tool_targets.update(tool_targets)  # Update references after potential agent modifications
+            # Use default handoff configuration: all agents can hand off to all others (mesh topology)
+            for source_id, source_agent in agents.items():
+                for target_id, target_agent in agents.items():
+                    if source_id == target_id:
+                        continue  # Skip self-handoff
+                    updated_handoff_configurations.setdefault(self._resolve_to_id(source_agent), []).append(
+                        HandoffConfiguration(
+                            target=self._resolve_to_id(target_agent),
+                            description=target_agent.description,
+                        )
+                    )
 
-        return updated_executors, handoff_tool_targets
+        return updated_handoff_configurations
 
-    def _resolve_executors_and_handoffs(self) -> tuple[str, dict[str, Executor], dict[str, str], dict[str, str]]:
-        """Resolve participant factories into executor instances and handoff configurations.
+    def _resolve_executors(
+        self,
+        agents: dict[str, AgentProtocol],
+        handoffs: dict[str, list[HandoffConfiguration]],
+    ) -> dict[str, HandoffAgentExecutor]:
+        """Resolve agents into HandoffAgentExecutors.
 
-        If executors were provided directly via participants(...), those are returned as-is.
-        If participant factories were provided via participant_factories(...), those
-        are invoked to create executor instances and aliases.
+        Args:
+            agents: Map of agent IDs or factory names to `AgentProtocol` instances
+            handoffs: Map of executor IDs to lists of HandoffConfiguration instances
 
        Returns:
-            Tuple of (executors map, aliases map, handoff configuration map)
+            Map of executor IDs to `HandoffAgentExecutor` instances
        """
-        # Resolve the participant factories now. This doesn't break the factory pattern
-        # since the Handoff builder still creates new instances per workflow build.
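The mesh default above has a practical consequence worth spelling out: with no explicit routes, every participant can reach every other participant. A hedged sketch of the two configurations, using hypothetical agents and the `add_handoff(source, targets)` shape referenced by the comments in this method:

```python
# Hedged sketch; agents are hypothetical, and the add_handoff(...) call shape
# is inferred from this method's comments rather than confirmed API.

# Explicit routing: only triage gets handoff targets.
explicit = (
    HandoffBuilder()
    .participants([triage, tech_support, billing])
    .with_start_agent(triage)
    .add_handoff(triage, [tech_support, billing])
    .build()
)

# Default routing: no add_handoff(...) calls, so every agent may hand off to
# every other agent (full mesh), each route described by the target's description.
mesh = (
    HandoffBuilder()
    .participants([triage, tech_support, billing])
    .with_start_agent(triage)
    .build()
)
```

The mesh default trades tool-list size for flexibility; explicit `add_handoff(...)` calls are the way to constrain routing in larger rosters.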
- executors, aliases = self._resolve_executors() - # `self._starting_agent_id` is either a factory name or executor ID at this point, - # resolve to executor ID - if self._starting_agent_id in executors: - start_executor_id = executors[self._starting_agent_id].id - else: - raise RuntimeError("Failed to resolve starting agent ID during build.") + executors: dict[str, HandoffAgentExecutor] = {} - # Resolve handoffs - # This will update the `executors` dict to a map of executor IDs to executors - updated_executors, handoff_tool_targets = self._resolve_handoffs(executors) + for id, agent in agents.items(): + # Note that here `id` may be either factory name or agent resolved ID + resolved_id = self._resolve_to_id(agent) + if resolved_id not in handoffs or not handoffs.get(resolved_id): + logger.warning( + f"No handoff configuration found for agent '{resolved_id}'. " + "This agent will not be able to hand off to any other agents and your workflow may get stuck." + ) - return start_executor_id, updated_executors, aliases, handoff_tool_targets + # Autonomous mode is enabled only for specified agents (or all if none specified) + autonomous_mode = self._autonomous_mode and ( + not self._autonomous_mode_enabled_agents or id in self._autonomous_mode_enabled_agents + ) + + executors[resolved_id] = HandoffAgentExecutor( + agent=agent, + handoffs=handoffs.get(resolved_id, []), + is_start_agent=(id == self._start_id), + termination_condition=self._termination_condition, + autonomous_mode=autonomous_mode, + autonomous_mode_prompt=self._autonomous_mode_prompts.get(id, None), + autonomous_mode_turn_limit=self._autonomous_mode_turn_limits.get(id, None), + ) + + return executors - def _resolve_to_id(self, candidate: str | AgentProtocol | Executor) -> str: + def _resolve_to_id(self, candidate: str | AgentProtocol) -> str: """Resolve a participant reference into a concrete executor identifier.""" - if isinstance(candidate, Executor): - return candidate.id if isinstance(candidate, AgentProtocol): - name: str | None = getattr(candidate, "name", None) - if not name: - raise ValueError("AgentProtocol without a name cannot be resolved to an executor id.") - return self._aliases.get(name, name) + return resolve_agent_id(candidate) if isinstance(candidate, str): - if candidate in self._aliases: - return self._aliases[candidate] return candidate - raise TypeError(f"Invalid starting agent reference: {type(candidate).__name__}") - def _apply_auto_tools(self, agent: ChatAgent, specialists: Mapping[str, Executor]) -> dict[str, str]: - """Attach synthetic handoff tools to a chat agent and return the target lookup table. - - Creates handoff tools for each specialist agent that this agent can route to. - The tool_targets dict maps various name formats (tool name, sanitized name, alias) - to executor IDs to enable flexible handoff target resolution. 
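The autonomous-mode gate in `_resolve_executors` composes a global flag with an optional per-agent allowlist, which is easy to mis-read. A self-contained, hedged restatement of the same boolean logic in plain Python (illustrative names, not framework API):

```python
# Hedged restatement of the gating expression above; not framework API.
def is_autonomous(global_flag: bool, enabled_agents: set[str], agent_id: str) -> bool:
    # Enabled globally AND (no per-agent filter, or this agent is in the filter).
    return global_flag and (not enabled_agents or agent_id in enabled_agents)

assert is_autonomous(True, set(), "billing")             # no filter: all agents qualify
assert not is_autonomous(True, {"triage"}, "billing")    # filter excludes this agent
assert not is_autonomous(False, {"billing"}, "billing")  # globally disabled wins
```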
- - Args: - agent: The ChatAgent to add handoff tools to - specialists: Map of executor IDs or factory names to specialist executors this agent can hand off to - - Returns: - Dict mapping tool names (in various formats) to executor IDs for handoff resolution - """ - chat_options = agent.chat_options - existing_tools = list(chat_options.tools or []) - existing_names = {getattr(tool, "name", "") for tool in existing_tools if hasattr(tool, "name")} - - tool_targets: dict[str, str] = {} - new_tools: list[Any] = [] - for executor in specialists.values(): - alias = executor.id - sanitized = sanitize_identifier(alias) - tool = _create_handoff_tool(alias, executor.description if isinstance(executor, AgentExecutor) else None) - if tool.name not in existing_names: - new_tools.append(tool) - # Map multiple name variations to the same executor ID for robust resolution - tool_targets[tool.name.lower()] = executor.id - tool_targets[sanitized] = executor.id - tool_targets[alias.lower()] = executor.id - - if new_tools: - chat_options.tools = existing_tools + new_tools - else: - chat_options.tools = existing_tools - - return tool_targets - - def _prepare_agent_with_handoffs( - self, - executor: AgentExecutor, - target_agents: Mapping[str, Executor], - ) -> tuple[AgentExecutor, dict[str, str]]: - """Prepare an agent by adding handoff tools for the specified target agents. + raise TypeError(f"Invalid starting agent reference: {type(candidate).__name__}") - Args: - executor: The agent executor to prepare - target_agents: Map of executor IDs to target executors this agent can hand off to + # endregion Internal Helper Methods - Returns: - Tuple of (updated executor, tool_targets map) - """ - agent = getattr(executor, "_agent", None) - if not isinstance(agent, ChatAgent): - return executor, {} - - cloned_agent = _clone_chat_agent(agent) - tool_targets = self._apply_auto_tools(cloned_agent, target_agents) - if tool_targets: - middleware = _AutoHandoffMiddleware(tool_targets) - existing_middleware = list(cloned_agent.middleware or []) - existing_middleware.append(middleware) - cloned_agent.middleware = existing_middleware - - new_executor = AgentExecutor( - cloned_agent, - agent_thread=getattr(executor, "_agent_thread", None), - output_response=getattr(executor, "_output_response", False), - id=executor.id, - ) - return new_executor, tool_targets - # endregion Internal Helper Methods +# endregion Handoff workflow builder diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index cdbc79e0c0..052a59766f 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -7,40 +7,36 @@ import re import sys from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Sequence +from collections.abc import Sequence from dataclasses import dataclass, field from enum import Enum -from typing import Any, TypeVar, cast -from uuid import uuid4 +from typing import Any, ClassVar, TypeVar, cast + +from typing_extensions import Never from agent_framework import ( AgentProtocol, - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, ChatMessage, - FunctionApprovalRequestContent, - FunctionResultContent, Role, ) -from ._base_group_chat_orchestrator import BaseGroupChatOrchestrator -from ._checkpoint import CheckpointStorage, WorkflowCheckpoint -from ._const import EXECUTOR_STATE_KEY, WORKFLOW_RUN_KWARGS_KEY -from ._events import 
AgentRunUpdateEvent, WorkflowEvent -from ._executor import Executor, handler -from ._group_chat import ( - GroupChatBuilder, - _GroupChatConfig, # type: ignore[reportPrivateUsage] - _GroupChatParticipantPipeline, # type: ignore[reportPrivateUsage] - _GroupChatRequestMessage, # type: ignore[reportPrivateUsage] - _GroupChatResponseMessage, # type: ignore[reportPrivateUsage] - group_chat_orchestrator, +from ._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse +from ._base_group_chat_orchestrator import ( + BaseGroupChatOrchestrator, + GroupChatParticipantMessage, + GroupChatRequestMessage, + GroupChatResponseMessage, + GroupChatWorkflowContext_T_Out, + ParticipantRegistry, ) -from ._message_utils import normalize_messages_input +from ._checkpoint import CheckpointStorage +from ._events import ExecutorEvent +from ._executor import Executor, handler from ._model_utils import DictConvertible, encode_value -from ._participant_utils import GroupChatParticipantSpec, participant_description from ._request_info_mixin import response_handler -from ._workflow import Workflow, WorkflowRunResult +from ._workflow import Workflow +from ._workflow_builder import WorkflowBuilder from ._workflow_context import WorkflowContext if sys.version_info >= (3, 11): @@ -103,14 +99,6 @@ def _message_from_payload(payload: Any) -> ChatMessage: raise TypeError("Unable to reconstruct ChatMessage from payload") -# region Magentic event metadata constants - -# Event type identifiers for magentic_event_type in additional_properties -MAGENTIC_EVENT_TYPE_ORCHESTRATOR = "orchestrator_message" -MAGENTIC_EVENT_TYPE_AGENT_DELTA = "agent_delta" - -# endregion Magentic event metadata constants - # region Magentic One Prompts ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT = """Below I will present you a request. 
@@ -276,204 +264,6 @@ def _new_participant_descriptions() -> dict[str, str]: return {} -def _new_chat_message_list() -> list[ChatMessage]: - """Typed default factory for ChatMessage list to satisfy type checkers.""" - return [] - - -@dataclass -class _MagenticStartMessage(DictConvertible): - """Internal: A message to start a magentic workflow.""" - - messages: list[ChatMessage] = field(default_factory=_new_chat_message_list) - run_kwargs: dict[str, Any] = field(default_factory=dict) - - def __init__( - self, - messages: str | ChatMessage | Sequence[str] | Sequence[ChatMessage] | None = None, - *, - task: ChatMessage | None = None, - run_kwargs: dict[str, Any] | None = None, - ) -> None: - normalized = normalize_messages_input(messages) - if task is not None: - normalized += normalize_messages_input(task) - if not normalized: - raise ValueError("MagenticStartMessage requires at least one message input.") - self.messages: list[ChatMessage] = normalized - self.run_kwargs: dict[str, Any] = run_kwargs or {} - - @property - def task(self) -> ChatMessage: - """Final user message for the task.""" - return self.messages[-1] - - @classmethod - def from_string(cls, task_text: str) -> "_MagenticStartMessage": - """Create a MagenticStartMessage from a simple string.""" - return cls(task_text) - - def to_dict(self) -> dict[str, Any]: - """Create a dict representation of the message.""" - return { - "messages": [message.to_dict() for message in self.messages], - "task": self.task.to_dict(), - } - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> "_MagenticStartMessage": - """Create from a dict.""" - if "messages" in data: - raw_messages = data["messages"] - if not isinstance(raw_messages, Sequence) or isinstance(raw_messages, (str, bytes)): - raise TypeError("MagenticStartMessage 'messages' must be a sequence.") - messages: list[ChatMessage] = [ChatMessage.from_dict(raw) for raw in raw_messages] # type: ignore[arg-type] - return cls(messages) - if "task" in data: - task = ChatMessage.from_dict(data["task"]) - return cls(task) - raise KeyError("Expected 'messages' or 'task' in MagenticStartMessage payload.") - - -@dataclass -class _MagenticRequestMessage(_GroupChatRequestMessage): - """Internal: A request message type for agents in a magentic workflow.""" - - task_context: str = "" - - -class _MagenticResponseMessage(_GroupChatResponseMessage): - """Internal: A response message type. - - When emitted by the orchestrator you can mark it as a broadcast to all agents, - or target a specific agent by name. 
- """ - - def __init__( - self, - body: ChatMessage, - target_agent: str | None = None, # deliver only to this agent if set - broadcast: bool = False, # deliver to all agents if True - ) -> None: - agent_name = body.author_name or "" - super().__init__( - agent_name=agent_name, - message=body, - ) - self.body = body - self.target_agent = target_agent - self.broadcast = broadcast - - def to_dict(self) -> dict[str, Any]: - """Create a dict representation of the message.""" - return {"body": self.body.to_dict(), "target_agent": self.target_agent, "broadcast": self.broadcast} - - @classmethod - def from_dict(cls, value: dict[str, Any]) -> "_MagenticResponseMessage": - """Create from a dict.""" - body = ChatMessage.from_dict(value["body"]) - target_agent = value.get("target_agent") - broadcast = value.get("broadcast", False) - return cls(body=body, target_agent=target_agent, broadcast=broadcast) - - -# region Human Intervention Types - - -class MagenticHumanInterventionKind(str, Enum): - """The kind of human intervention being requested.""" - - PLAN_REVIEW = "plan_review" # Review and approve/revise the initial plan - TOOL_APPROVAL = "tool_approval" # Approve a tool/function call - STALL = "stall" # Workflow has stalled and needs guidance - - -class MagenticHumanInterventionDecision(str, Enum): - """Decision options for human intervention responses.""" - - APPROVE = "approve" # Approve (plan review, tool approval) - REVISE = "revise" # Request revision with feedback (plan review) - REJECT = "reject" # Reject/deny (tool approval) - CONTINUE = "continue" # Continue with current state (stall) - REPLAN = "replan" # Trigger replanning (stall) - GUIDANCE = "guidance" # Provide guidance text (stall, tool approval) - - -@dataclass -class _MagenticHumanInterventionRequest: - """Unified request for human intervention in a Magentic workflow. - - This request is emitted when the workflow needs human input. The `kind` field - indicates what type of intervention is needed, and the relevant fields are - populated based on the kind. 
- - Attributes: - request_id: Unique identifier for correlating responses - kind: The type of intervention needed (plan_review, tool_approval, stall) - - # Plan review fields - task_text: The task description (plan_review) - facts_text: Extracted facts from the task (plan_review) - plan_text: The proposed or current plan (plan_review, stall) - round_index: Number of review rounds so far (plan_review) - - # Tool approval fields - agent_id: The agent requesting input (tool_approval) - prompt: Description of what input is needed (tool_approval) - context: Additional context (tool_approval) - conversation_snapshot: Recent conversation history (tool_approval) - - # Stall intervention fields - stall_count: Number of consecutive stall rounds (stall) - max_stall_count: Threshold that triggered intervention (stall) - stall_reason: Description of why progress stalled (stall) - last_agent: Last active agent (stall) - """ - - request_id: str = field(default_factory=lambda: str(uuid4())) - kind: MagenticHumanInterventionKind = MagenticHumanInterventionKind.PLAN_REVIEW - - # Plan review fields - task_text: str = "" - facts_text: str = "" - plan_text: str = "" - round_index: int = 0 - - # Tool approval fields - agent_id: str = "" - prompt: str = "" - context: str | None = None - conversation_snapshot: list[ChatMessage] = field(default_factory=list) # type: ignore - - # Stall intervention fields - stall_count: int = 0 - max_stall_count: int = 3 - stall_reason: str = "" - last_agent: str = "" - - -@dataclass -class _MagenticHumanInterventionReply: - """Unified reply to a human intervention request. - - The relevant fields depend on the original request kind and the decision made. - - Attributes: - decision: The human's decision (approve, revise, continue, replan, guidance) - edited_plan_text: New plan text if directly editing (plan_review with approve/revise) - comments: Feedback for revision or guidance text (plan_review, stall with guidance) - response_text: Free-form response text (tool_approval) - """ - - decision: MagenticHumanInterventionDecision - edited_plan_text: str | None = None - comments: str | None = None - response_text: str | None = None - - -# endregion Human Intervention Types - - @dataclass class _MagenticTaskLedger(DictConvertible): """Internal: Task ledger for the Standard Magentic manager.""" @@ -493,7 +283,7 @@ def from_dict(cls, data: dict[str, Any]) -> "_MagenticTaskLedger": @dataclass -class _MagenticProgressLedgerItem(DictConvertible): +class MagenticProgressLedgerItem(DictConvertible): """Internal: A progress ledger item.""" reason: str @@ -503,7 +293,7 @@ def to_dict(self) -> dict[str, Any]: return {"reason": self.reason, "answer": self.answer} @classmethod - def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedgerItem": + def from_dict(cls, data: dict[str, Any]) -> "MagenticProgressLedgerItem": answer_value = data.get("answer") if not isinstance(answer_value, (str, bool)): answer_value = "" # Default to empty string if not str or bool @@ -511,14 +301,14 @@ def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedgerItem": @dataclass -class _MagenticProgressLedger(DictConvertible): +class MagenticProgressLedger(DictConvertible): """Internal: A progress ledger for tracking workflow progress.""" - is_request_satisfied: _MagenticProgressLedgerItem - is_in_loop: _MagenticProgressLedgerItem - is_progress_being_made: _MagenticProgressLedgerItem - next_speaker: _MagenticProgressLedgerItem - instruction_or_question: _MagenticProgressLedgerItem + 
is_request_satisfied: MagenticProgressLedgerItem + is_in_loop: MagenticProgressLedgerItem + is_progress_being_made: MagenticProgressLedgerItem + next_speaker: MagenticProgressLedgerItem + instruction_or_question: MagenticProgressLedgerItem def to_dict(self) -> dict[str, Any]: return { @@ -530,13 +320,13 @@ def to_dict(self) -> dict[str, Any]: } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedger": + def from_dict(cls, data: dict[str, Any]) -> "MagenticProgressLedger": return cls( - is_request_satisfied=_MagenticProgressLedgerItem.from_dict(data.get("is_request_satisfied", {})), - is_in_loop=_MagenticProgressLedgerItem.from_dict(data.get("is_in_loop", {})), - is_progress_being_made=_MagenticProgressLedgerItem.from_dict(data.get("is_progress_being_made", {})), - next_speaker=_MagenticProgressLedgerItem.from_dict(data.get("next_speaker", {})), - instruction_or_question=_MagenticProgressLedgerItem.from_dict(data.get("instruction_or_question", {})), + is_request_satisfied=MagenticProgressLedgerItem.from_dict(data.get("is_request_satisfied", {})), + is_in_loop=MagenticProgressLedgerItem.from_dict(data.get("is_in_loop", {})), + is_progress_being_made=MagenticProgressLedgerItem.from_dict(data.get("is_progress_being_made", {})), + next_speaker=MagenticProgressLedgerItem.from_dict(data.get("next_speaker", {})), + instruction_or_question=MagenticProgressLedgerItem.from_dict(data.get("instruction_or_question", {})), ) @@ -544,7 +334,7 @@ def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedger": class MagenticContext(DictConvertible): """Context for the Magentic manager.""" - task: ChatMessage + task: str chat_history: list[ChatMessage] = field(default_factory=_new_chat_history) participant_descriptions: dict[str, str] = field(default_factory=_new_participant_descriptions) round_count: int = 0 @@ -553,7 +343,7 @@ class MagenticContext(DictConvertible): def to_dict(self) -> dict[str, Any]: return { - "task": _message_to_payload(self.task), + "task": self.task, "chat_history": [_message_to_payload(msg) for msg in self.chat_history], "participant_descriptions": dict(self.participant_descriptions), "round_count": self.round_count, @@ -563,14 +353,27 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> "MagenticContext": + # Validate required fields + # `task` is required + task = data.get("task") + if task is None or not isinstance(task, str): + raise ValueError("MagenticContext requires a 'task' string field.") + # `chat_history` is required chat_history_payload = data.get("chat_history", []) history: list[ChatMessage] = [] for item in chat_history_payload: history.append(_message_from_payload(item)) + # `participant_descriptions` is required + participant_descriptions = data.get("participant_descriptions") + if not isinstance(participant_descriptions, dict) or not participant_descriptions: + raise ValueError("MagenticContext requires a 'participant_descriptions' dictionary field.") + if not all(isinstance(k, str) and isinstance(v, str) for k, v in participant_descriptions.items()): # type: ignore + raise ValueError("MagenticContext 'participant_descriptions' must be a dict of str to str.") + return cls( - task=_message_from_payload(data.get("task")), + task=task, chat_history=history, - participant_descriptions=dict(data.get("participant_descriptions", {})), + participant_descriptions=participant_descriptions, # type: ignore round_count=data.get("round_count", 0), stall_count=data.get("stall_count", 0), 
reset_count=data.get("reset_count", 0), @@ -597,13 +400,6 @@ def _team_block(participants: dict[str, str]) -> str: return "\n".join(f"- {name}: {desc}" for name, desc in participants.items()) -def _first_assistant(messages: list[ChatMessage]) -> ChatMessage | None: - for msg in reversed(messages): - if msg.role == Role.ASSISTANT: - return msg - return None - - def _extract_json(text: str) -> dict[str, Any]: """Potentially temp helper method. @@ -693,7 +489,7 @@ async def replan(self, magentic_context: MagenticContext) -> ChatMessage: ... @abstractmethod - async def create_progress_ledger(self, magentic_context: MagenticContext) -> _MagenticProgressLedger: + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: """Create a progress ledger.""" ... @@ -724,6 +520,8 @@ class StandardMagenticManager(MagenticManagerBase): task_ledger: _MagenticTaskLedger | None + MANAGER_NAME: ClassVar[str] = "StandardMagenticManager" + def __init__( self, agent: AgentProtocol, @@ -796,26 +594,22 @@ async def _complete( The agent's run method is called which applies the agent's configured options (temperature, seed, instructions, etc.). """ - response: AgentRunResponse = await self._agent.run(messages) - out_messages = response.messages if response else None - if out_messages: - last = out_messages[-1] - return ChatMessage( - role=last.role, - text=last.text, - author_name=last.author_name or MAGENTIC_MANAGER_NAME, - ) - return ChatMessage(role=Role.ASSISTANT, text="No output produced.", author_name=MAGENTIC_MANAGER_NAME) + response: AgentResponse = await self._agent.run(messages) + if not response.messages: + raise RuntimeError("Agent returned no messages in response.") + if len(response.messages) > 1: + logger.warning("Agent returned multiple messages; using the last one.") + + return response.messages[-1] async def plan(self, magentic_context: MagenticContext) -> ChatMessage: """Create facts and plan using the model, then render a combined task ledger as a single assistant message.""" - task_text = magentic_context.task.text team_text = _team_block(magentic_context.participant_descriptions) # Gather facts facts_user = ChatMessage( role=Role.USER, - text=self.task_ledger_facts_prompt.format(task=task_text), + text=self.task_ledger_facts_prompt.format(task=magentic_context.task), ) facts_msg = await self._complete([*magentic_context.chat_history, facts_user]) @@ -834,7 +628,7 @@ async def plan(self, magentic_context: MagenticContext) -> ChatMessage: magentic_context.chat_history.extend([facts_user, facts_msg, plan_user, plan_msg]) combined = self.task_ledger_full_prompt.format( - task=task_text, + task=magentic_context.task, team=team_text, facts=facts_msg.text, plan=plan_msg.text, @@ -846,13 +640,14 @@ async def replan(self, magentic_context: MagenticContext) -> ChatMessage: if self.task_ledger is None: raise RuntimeError("replan() called before plan(); call plan() once before requesting a replan.") - task_text = magentic_context.task.text team_text = _team_block(magentic_context.participant_descriptions) # Update facts facts_update_user = ChatMessage( role=Role.USER, - text=self.task_ledger_facts_update_prompt.format(task=task_text, old_facts=self.task_ledger.facts.text), + text=self.task_ledger_facts_update_prompt.format( + task=magentic_context.task, old_facts=self.task_ledger.facts.text + ), ) updated_facts = await self._complete([*magentic_context.chat_history, facts_update_user]) @@ -876,14 +671,14 @@ async def replan(self, magentic_context: MagenticContext) 
-> ChatMessage: magentic_context.chat_history.extend([facts_update_user, updated_facts, plan_update_user, updated_plan]) combined = self.task_ledger_full_prompt.format( - task=task_text, + task=magentic_context.task, team=team_text, facts=updated_facts.text, plan=updated_plan.text, ) return ChatMessage(role=Role.ASSISTANT, text=combined, author_name=MAGENTIC_MANAGER_NAME) - async def create_progress_ledger(self, magentic_context: MagenticContext) -> _MagenticProgressLedger: + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: """Use the model to produce a JSON progress ledger based on the conversation so far. Adds lightweight retries with backoff for transient parse issues and avoids selecting a @@ -897,7 +692,7 @@ async def create_progress_ledger(self, magentic_context: MagenticContext) -> _Ma team_text = _team_block(magentic_context.participant_descriptions) prompt = self.progress_ledger_prompt.format( - task=magentic_context.task.text, + task=magentic_context.task, team=team_text, names=names_csv, ) @@ -910,7 +705,7 @@ async def create_progress_ledger(self, magentic_context: MagenticContext) -> _Ma raw = await self._complete([*magentic_context.chat_history, user_message]) try: ledger_dict = _extract_json(raw.text) - return _coerce_model(_MagenticProgressLedger, ledger_dict) + return _coerce_model(MagenticProgressLedger, ledger_dict) except Exception as ex: last_error = ex attempts += 1 @@ -927,7 +722,7 @@ async def create_progress_ledger(self, magentic_context: MagenticContext) -> _Ma async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: """Ask the model to produce the final answer addressed to the user.""" - prompt = self.final_answer_prompt.format(task=magentic_context.task.text) + prompt = self.final_answer_prompt.format(task=magentic_context.task) user_message = ChatMessage(role=Role.USER, text=prompt) response = await self._complete([*magentic_context.chat_history, user_message]) # Ensure role is assistant @@ -956,627 +751,370 @@ def on_checkpoint_restore(self, state: dict[str, Any]) -> None: # endregion Magentic Manager -# region Magentic Executors +# region Magentic Orchestrator -class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator): - """Magentic orchestrator executor that handles all orchestration logic. +class MagenticResetSignal: + """Signal to indicate that the Magentic workflow should reset. - This executor manages the entire Magentic One workflow including: - - Initial planning and task ledger creation - - Progress tracking and completion detection - - Agent coordination and message routing - - Reset and replanning logic + This signal can be raised within the orchestrator's inner loop to trigger + a reset of the Magentic context, clearing chat history and resetting + stall counts. """ - # Typed attributes (initialized in __init__) - _agent_executors: dict[str, "MagenticAgentExecutor"] - _context: "MagenticContext | None" - _task_ledger: "ChatMessage | None" - _inner_loop_lock: asyncio.Lock - _require_plan_signoff: bool - _plan_review_round: int - _max_plan_review_rounds: int - _terminated: bool - _enable_stall_intervention: bool - - def __init__( - self, - manager: MagenticManagerBase, - participants: dict[str, str], - *, - require_plan_signoff: bool = False, - max_plan_review_rounds: int = 10, - enable_stall_intervention: bool = False, - executor_id: str | None = None, - ) -> None: - """Initializes a new instance of the MagenticOrchestratorExecutor. 
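Downstream consumers can filter for these orchestrator events on the run stream. A hedged sketch, assuming a Magentic `workflow` built elsewhere and the `MagenticOrchestratorEvent` class introduced a few hunks below:

```python
# Hedged sketch; `workflow` is a hypothetical Magentic workflow instance, and
# the event/enum names come from this diff.
async def watch(workflow) -> None:
    async for event in workflow.run_stream("Summarize the quarterly numbers"):
        if isinstance(event, MagenticOrchestratorEvent):
            if event.event_type is MagenticOrchestratorEventType.PLAN_CREATED:
                # For PLAN_CREATED, `data` carries the task-ledger ChatMessage.
                print(f"Plan from {event.executor_id}:\n{event.data.text}")
            elif event.event_type is MagenticOrchestratorEventType.REPLANNED:
                print("Manager produced a revised plan")
```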
- - Args: - manager: The Magentic manager instance. - participants: A dictionary of participant IDs to their names. - require_plan_signoff: Whether to require plan sign-off from a human. - max_plan_review_rounds: The maximum number of plan review rounds. - enable_stall_intervention: Whether to request human input on stalls instead of auto-replan. - executor_id: An optional executor ID. - """ - super().__init__(executor_id or f"magentic_orchestrator_{uuid4().hex[:8]}") - self._manager = manager - self._participants = participants - self._context = None - self._task_ledger = None - self._require_plan_signoff = require_plan_signoff - self._plan_review_round = 0 - self._max_plan_review_rounds = max_plan_review_rounds - self._enable_stall_intervention = enable_stall_intervention - # Registry of agent executors for internal coordination (e.g., resets) - self._agent_executors = {} - # Terminal state marker to stop further processing after completion/limits - self._terminated = False - # Tracks whether checkpoint state has been applied for this run - - def _get_author_name(self) -> str: - """Get the magentic manager name for orchestrator-generated messages.""" - return MAGENTIC_MANAGER_NAME - - def register_agent_executor(self, name: str, executor: "MagenticAgentExecutor") -> None: - """Register an agent executor for internal control (no messages).""" - self._agent_executors[name] = executor - - async def _emit_orchestrator_message( - self, - ctx: WorkflowContext[Any, list[ChatMessage]], - message: ChatMessage, - kind: str, - ) -> None: - """Emit orchestrator message to the workflow event stream. - - Emits an AgentRunUpdateEvent (for agent wrapper consumers) with metadata indicating - the orchestrator event type. - - Args: - ctx: Workflow context for adding events to the stream - message: Orchestrator message to emit (task, plan, instruction, notice) - kind: Message classification (user_task, task_ledger, instruction, notice) - - Example: - async for event in workflow.run_stream("task"): - if isinstance(event, AgentRunUpdateEvent): - props = event.data.additional_properties if event.data else None - if props and props.get("magentic_event_type") == "orchestrator_message": - kind = props.get("orchestrator_message_kind", "") - print(f"Orchestrator {kind}: {event.data.text}") - """ - # Emit AgentRunUpdateEvent with metadata - update = AgentRunResponseUpdate( - text=message.text, - role=message.role, - author_name=self._get_author_name(), - additional_properties={ - "magentic_event_type": MAGENTIC_EVENT_TYPE_ORCHESTRATOR, - "orchestrator_message_kind": kind, - "orchestrator_id": self.id, - }, - ) - await ctx.add_event(AgentRunUpdateEvent(executor_id=self.id, data=update)) - - @override - async def on_checkpoint_save(self) -> dict[str, Any]: - """Capture current orchestrator state for checkpointing. - - Uses OrchestrationState for structure but maintains Magentic's complex metadata - at the top level for backward compatibility with existing checkpoints. 
- - Returns: - Dict ready for checkpoint persistence - """ - state: dict[str, Any] = { - "plan_review_round": self._plan_review_round, - "max_plan_review_rounds": self._max_plan_review_rounds, - "require_plan_signoff": self._require_plan_signoff, - "terminated": self._terminated, - } - if self._context is not None: - state["magentic_context"] = self._context.to_dict() - if self._task_ledger is not None: - state["task_ledger"] = _message_to_payload(self._task_ledger) + pass - try: - state["manager_state"] = self._manager.on_checkpoint_save() - except Exception as exc: - logger.warning(f"Failed to save manager state for checkpoint: {exc}\nSkipping...") - return state +class MagenticOrchestratorEventType(str, Enum): + """Types of Magentic orchestrator events.""" - @override - async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: - """Restore orchestrator state from checkpoint. + PLAN_CREATED = "plan_created" + REPLANNED = "replanned" + PROGRESS_LEDGER_UPDATED = "progress_ledger_updated" - Maintains backward compatibility with existing Magentic checkpoints - while supporting OrchestrationState structure. - Args: - state: Checkpoint data dict - """ - # Support both old format (direct keys) and new format (wrapped in OrchestrationState) - if "metadata" in state and isinstance(state.get("metadata"), dict): - # New OrchestrationState format - extract metadata - from ._orchestration_state import OrchestrationState +@dataclass +class MagenticOrchestratorEvent(ExecutorEvent): + """Base class for Magentic orchestrator events.""" - orch_state = OrchestrationState.from_dict(state) - state = orch_state.metadata + def __init__( + self, + executor_id: str, + event_type: MagenticOrchestratorEventType, + data: ChatMessage | MagenticProgressLedger, + ) -> None: + super().__init__(executor_id, data) + self.event_type = event_type - ctx_payload = state.get("magentic_context") - if ctx_payload is not None: - try: - if isinstance(ctx_payload, dict): - self._context = MagenticContext.from_dict(ctx_payload) # type: ignore[arg-type] - else: - self._context = None - except Exception as exc: # pragma: no cover - defensive - logger.warning(f"Failed to restore magentic context: {exc}") - self._context = None - ledger_payload = state.get("task_ledger") - if ledger_payload is not None: - try: - self._task_ledger = _message_from_payload(ledger_payload) - except Exception as exc: # pragma: no cover - logger.warning(f"Failed to restore task ledger message: {exc}") - self._task_ledger = None + def __repr__(self) -> str: + return f"{self.__class__.__name__}(executor_id={self.executor_id}, event_type={self.event_type})" - if "plan_review_round" in state: - try: - self._plan_review_round = int(state["plan_review_round"]) - except Exception: # pragma: no cover - logger.debug("Ignoring invalid plan_review_round in checkpoint state") - if "max_plan_review_rounds" in state: - self._max_plan_review_rounds = state.get("max_plan_review_rounds") # type: ignore[assignment] - if "require_plan_signoff" in state: - self._require_plan_signoff = bool(state.get("require_plan_signoff")) - if "terminated" in state: - self._terminated = bool(state.get("terminated")) - manager_state = state.get("manager_state") - if manager_state is not None: - try: - self._manager.on_checkpoint_restore(manager_state) - except Exception as exc: # pragma: no cover - logger.warning(f"Failed to restore manager state: {exc}") +# region Request info related types - self._reconcile_restored_participants() - def _reconcile_restored_participants(self) -> 
None: - """Ensure restored participant roster matches the current workflow graph.""" - if self._context is None: - return +@dataclass +class MagenticPlanReviewResponse: + """Response to a human plan review request. - restored = self._context.participant_descriptions or {} - expected = self._participants + Attributes: + review: List of messages containing feedback and suggested revisions. If empty, + the plan is considered approved. + """ - restored_names = set(restored.keys()) - expected_names = set(expected.keys()) + review: list[ChatMessage] - if restored_names != expected_names: - missing = ", ".join(sorted(expected_names - restored_names)) or "none" - unexpected = ", ".join(sorted(restored_names - expected_names)) or "none" - raise RuntimeError( - "Magentic checkpoint restore failed: participant names do not match the checkpoint. " - "Ensure MagenticBuilder.participants keys remain stable across runs. " - f"Missing names: {missing}; unexpected names: {unexpected}." - ) + @staticmethod + def approve() -> "MagenticPlanReviewResponse": + """Create an approval response.""" + return MagenticPlanReviewResponse(review=[]) - # Refresh descriptions so prompt surfaces always reflect the rebuilt workflow inputs. - for name, description in expected.items(): - restored[name] = description + @staticmethod + def revise(feedback: str | list[str] | ChatMessage | list[ChatMessage]) -> "MagenticPlanReviewResponse": + """Create a revision response with feedback.""" + if isinstance(feedback, str): + feedback = [ChatMessage(role=Role.USER, text=feedback)] + elif isinstance(feedback, ChatMessage): + feedback = [feedback] + elif isinstance(feedback, list): + feedback = [ChatMessage(role=Role.USER, text=item) if isinstance(item, str) else item for item in feedback] - @handler - async def handle_start_message( - self, - message: _MagenticStartMessage, - context: WorkflowContext[ - _MagenticResponseMessage | _MagenticRequestMessage | _MagenticHumanInterventionRequest, list[ChatMessage] - ], - ) -> None: - """Handle the initial start message to begin orchestration.""" - if getattr(self, "_terminated", False): - return - logger.info("Magentic Orchestrator: Received start message") + return MagenticPlanReviewResponse(review=feedback) - # Store run_kwargs in SharedState so agent executors can access them - # Always store (even empty dict) so retrieval is deterministic - await context.set_shared_state(WORKFLOW_RUN_KWARGS_KEY, message.run_kwargs or {}) - self._context = MagenticContext( - task=message.task, - participant_descriptions=self._participants, - ) - if message.messages: - self._context.chat_history.extend(message.messages) +@dataclass +class MagenticPlanReviewRequest: + """Request for human review of a proposed plan. - # Non-streaming callback for the orchestrator receipt of the task - await self._emit_orchestrator_message(context, message.task, ORCH_MSG_KIND_USER_TASK) + Attributes: + plan: The proposed plan message. + current_progress: The current progress ledger, if available. + During the initial plan review, this will be None. In subsequent + reviews after replanning (due to stalls), this will contain the + latest progress ledger that determined no progress had been made + or the workflow was in a loop. + is_stalled: Whether the workflow is currently stalled. 
+    """
+
+    plan: ChatMessage
+    current_progress: MagenticProgressLedger | None
+    is_stalled: bool
+
+    def approve(self) -> MagenticPlanReviewResponse:
+        """Create an approval response."""
+        return MagenticPlanReviewResponse.approve()
+
+    def revise(self, feedback: str | list[str] | ChatMessage | list[ChatMessage]) -> MagenticPlanReviewResponse:
+        """Create a revision response with feedback."""
+        return MagenticPlanReviewResponse.revise(feedback)
+
+
+# endregion Request info related types
+
+
+class MagenticOrchestrator(BaseGroupChatOrchestrator):
+    """Magentic orchestrator that defines the workflow structure.
+
+    This orchestrator manages the overall Magentic workflow in the following structure:
+
+    1. Upon receiving the task (a list of messages), it creates the plan using the manager,
+       then runs the inner loop.
+    2. The inner loop is distributed and its implementation is decentralized. The orchestrator
+       is responsible for:
+        - Creating the progress ledger using the manager.
+        - Checking for task completion.
+        - Detecting stalling or looping and triggering replanning if needed.
+        - Sending requests to participants based on the progress ledger's next speaker.
+        - Issuing requests for human intervention if enabled and needed.
+    3. The inner loop waits for responses from the selected participant, then continues the loop.
+    4. The orchestrator breaks out of the inner loop when the replanning or final answer conditions are met.
+    5. The outer loop handles replanning and reenters the inner loop.
+ """ - @handler - async def handle_response_message( + def __init__( self, - message: _MagenticResponseMessage, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], + manager: MagenticManagerBase, + participant_registry: ParticipantRegistry, + *, + require_plan_signoff: bool = False, ) -> None: - """Handle responses from agents.""" - if getattr(self, "_terminated", False): - return - - if self._context is None: - raise RuntimeError("Magentic Orchestrator: Received response but not initialized") + """Initialize the Magentic orchestrator. - logger.debug("Magentic Orchestrator: Received response from agent") + Args: + manager: The Magentic manager instance to use for planning and progress tracking. + participant_registry: Registry of participants involved in the workflow. - # Add transfer message if needed - if message.body.role != Role.USER: - transfer_msg = ChatMessage( - role=Role.USER, - text=f"Transferred to {getattr(message.body, 'author_name', 'agent')}", - ) - self._context.chat_history.append(transfer_msg) + Keyword Args: + require_plan_signoff: If True, requires human approval of the initial plan before proceeding. + """ + super().__init__("magentic_orchestrator", participant_registry) + self._manager = manager + self._require_plan_signoff = require_plan_signoff - # Add agent response to context - self._context.chat_history.append(message.body) + # Task related state + self._magentic_context: MagenticContext | None = None + self._task_ledger: ChatMessage | None = None + self._progress_ledger: MagenticProgressLedger | None = None - # Continue with inner loop - await self._run_inner_loop(context) + # Termination related state + self._terminated: bool = False + self._max_rounds = manager.max_round_count - @response_handler - async def handle_human_intervention_response( + @override + async def _handle_messages( self, - original_request: _MagenticHumanInterventionRequest, - response: _MagenticHumanInterventionReply, - context: WorkflowContext[ - _MagenticResponseMessage | _MagenticRequestMessage | _MagenticHumanInterventionRequest, list[ChatMessage] - ], + messages: list[ChatMessage], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: - """Handle unified human intervention responses. - - Routes the response to the appropriate handler based on the original request kind. - """ - if getattr(self, "_terminated", False): - return - - if self._context is None: - return + """Handle the initial task messages to start the workflow.""" + if self._terminated: + raise RuntimeError( + "This Magentic workflow has already been completed. No further messages can be processed. " + "Use the builder to create a new workflow instance to handle additional tasks." 
+            )
 
-        if original_request.kind == MagenticHumanInterventionKind.PLAN_REVIEW:
-            await self._handle_plan_review_response(original_request, response, context)
-        elif original_request.kind == MagenticHumanInterventionKind.STALL:
-            await self._handle_stall_intervention_response(original_request, response, context)
-        # TOOL_APPROVAL is handled by MagenticAgentExecutor, not the orchestrator
+        if not messages:
+            raise ValueError("Magentic orchestrator requires at least one message to start the workflow.")
 
-    async def _handle_plan_review_response(
-        self,
-        original_request: _MagenticHumanInterventionRequest,
-        response: _MagenticHumanInterventionReply,
-        context: WorkflowContext[
-            _MagenticResponseMessage | _MagenticRequestMessage | _MagenticHumanInterventionRequest, list[ChatMessage]
-        ],
-    ) -> None:
-        """Handle plan review response."""
-        if self._context is None:
-            return
+        if len(messages) > 1:
+            raise ValueError("Magentic only supports a single task message to start the workflow.")
 
-        is_approve = response.decision == MagenticHumanInterventionDecision.APPROVE
-
-        if is_approve:
-            # Close the review loop on approval (no further plan review requests this run)
-            self._require_plan_signoff = False
-            # If the user supplied an edited plan, adopt it
-            if response.edited_plan_text:
-                # Update the manager's internal ledger and rebuild the combined message
-                mgr_ledger = getattr(self._manager, "task_ledger", None)
-                if mgr_ledger is not None:
-                    mgr_ledger.plan.text = response.edited_plan_text
-                team_text = _team_block(self._participants)
-                combined = self._manager.task_ledger_full_prompt.format(
-                    task=self._context.task.text,
-                    team=team_text,
-                    facts=(mgr_ledger.facts.text if mgr_ledger else ""),
-                    plan=response.edited_plan_text,
-                )
-                self._task_ledger = ChatMessage(
-                    role=Role.ASSISTANT,
-                    text=combined,
-                    author_name=MAGENTIC_MANAGER_NAME,
-                )
-            # If approved with comments but no edited text, apply comments via replan and proceed
-            elif response.comments:
-                self._context.chat_history.append(
-                    ChatMessage(role=Role.USER, text=f"Human plan feedback: {response.comments}")
-                )
-                self._task_ledger = await self._manager.replan(self._context.clone(deep=True))
+        if messages[0].text.strip() == "":
+            raise ValueError("Magentic task message must contain non-empty text.")
 
-        # Record the signed-off plan (no broadcast)
-        if self._task_ledger:
-            self._context.chat_history.append(self._task_ledger)
-            await self._emit_orchestrator_message(context, self._task_ledger, ORCH_MSG_KIND_TASK_LEDGER)
+        self._magentic_context = MagenticContext(
+            task=messages[0].text,
+            participant_descriptions=self._participant_registry.participants,
+            chat_history=list(messages),
+        )
 
-        # Enter the normal coordination loop
-        ctx2 = cast(
-            WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]],
-            context,
+        # Initial planning using the manager with real model calls
+        self._task_ledger = await self._manager.plan(self._magentic_context.clone(deep=True))
+        await ctx.add_event(
+            MagenticOrchestratorEvent(
+                executor_id=self.id,
+                event_type=MagenticOrchestratorEventType.PLAN_CREATED,
+                data=self._task_ledger,
            )
-        await self._run_inner_loop(ctx2)
-        return
+        )
 
-        # Otherwise, REVISION round
-        self._plan_review_round += 1
-        if self._plan_review_round > self._max_plan_review_rounds:
-            logger.warning("Magentic Orchestrator: Max plan review rounds reached. 
Proceeding with current plan.") - self._require_plan_signoff = False - notice = ChatMessage( - role=Role.ASSISTANT, - text=( - "Plan review closed after max rounds. Proceeding with the current plan and will no longer " - "prompt for plan approval." - ), - author_name=MAGENTIC_MANAGER_NAME, - ) - self._context.chat_history.append(notice) - await self._emit_orchestrator_message(context, notice, ORCH_MSG_KIND_NOTICE) - if self._task_ledger: - self._context.chat_history.append(self._task_ledger) - ctx2 = cast( - WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], - context, - ) - await self._run_inner_loop(ctx2) + # If a human must sign off, ask now and return. The response handler will resume. + if self._require_plan_signoff: + await self._send_plan_review_request(cast(WorkflowContext, ctx)) return - # If the user provided an edited plan, adopt it and ask for confirmation - if response.edited_plan_text: - mgr_ledger2 = getattr(self._manager, "task_ledger", None) - if mgr_ledger2 is not None: - mgr_ledger2.plan.text = response.edited_plan_text - team_text = _team_block(self._participants) - combined = self._manager.task_ledger_full_prompt.format( - task=self._context.task.text, - team=team_text, - facts=(mgr_ledger2.facts.text if mgr_ledger2 else ""), - plan=response.edited_plan_text, - ) - self._task_ledger = ChatMessage(role=Role.ASSISTANT, text=combined, author_name=MAGENTIC_MANAGER_NAME) - await self._send_plan_review_request(cast(WorkflowContext, context)) - return + # Add task ledger to conversation history + self._magentic_context.chat_history.append(self._task_ledger) - # Else pass comments into the chat history and replan - if response.comments: - self._context.chat_history.append( - ChatMessage(role=Role.USER, text=f"Human plan feedback: {response.comments}") - ) + logger.debug("Task ledger created.") - self._task_ledger = await self._manager.replan(self._context.clone(deep=True)) - await self._send_plan_review_request(cast(WorkflowContext, context)) + # Start the inner loop + await self._run_inner_loop(ctx) - async def _handle_stall_intervention_response( + @override + async def _handle_response( self, - original_request: _MagenticHumanInterventionRequest, - response: _MagenticHumanInterventionReply, - context: WorkflowContext[ - _MagenticResponseMessage | _MagenticRequestMessage | _MagenticHumanInterventionRequest, list[ChatMessage] - ], + response: AgentExecutorResponse | GroupChatResponseMessage, + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: - """Handle stall intervention response.""" - if self._context is None: - return + """Handle a response message from a participant.""" + if self._magentic_context is None or self._task_ledger is None: + raise RuntimeError("Context or task ledger not initialized") - ctx = self._context - logger.info( - f"Stall intervention response: decision={response.decision.value}, " - f"stall_count was {original_request.stall_count}" - ) + messages = self._process_participant_response(response) - if response.decision == MagenticHumanInterventionDecision.CONTINUE: - ctx.stall_count = 0 - ctx2 = cast( - WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], - context, - ) - await self._run_inner_loop(ctx2) - return + self._magentic_context.chat_history.extend(messages) - if response.decision == MagenticHumanInterventionDecision.REPLAN: - ctx2 = cast( - WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], - context, - 
) - await self._reset_and_replan(ctx2) - return + # Broadcast participant messages to all participants for context, except + # the participant that just responded + participant = ctx.get_source_executor_id() + await self._broadcast_messages_to_participants( + messages, + cast(WorkflowContext[AgentExecutorRequest | GroupChatParticipantMessage], ctx), + participants=[p for p in self._participant_registry.participants if p != participant], + ) - if response.decision == MagenticHumanInterventionDecision.GUIDANCE: - ctx.stall_count = 0 - guidance = response.comments or response.response_text - if guidance: - guidance_msg = ChatMessage( - role=Role.USER, - text=f"Human guidance to help with stall: {guidance}", - ) - ctx.chat_history.append(guidance_msg) - await self._emit_orchestrator_message(context, guidance_msg, ORCH_MSG_KIND_NOTICE) - ctx2 = cast( - WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], - context, - ) - await self._run_inner_loop(ctx2) - return + await self._run_inner_loop(ctx) - async def _run_outer_loop( + @response_handler + async def handle_plan_review_response( self, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], + original_request: MagenticPlanReviewRequest, + response: MagenticPlanReviewResponse, + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: - """Run the outer orchestration loop - planning phase.""" - if self._context is None: - raise RuntimeError("Context not initialized") - - logger.info("Magentic Orchestrator: Outer loop - entering inner loop") + """Handle the human response to the plan review request. + + Logic: + There are code paths which will trigger a plan review request to the human: + - Initial plan creation if `require_plan_signoff` is True. + - Potentially during the inner loop if stalling is detected (resetting and replanning). + + The human can either approve the plan or request revisions with comments. + - If approved, proceed to run the outer loop, which simply adds the task ledger + to the conversation and enters the inner loop. + - If revision requested, append the review comments to the chat history, + trigger replanning via the manager, emit a REPLANNED event, then run the outer loop. + """ + if self._magentic_context is None or self._task_ledger is None: + raise RuntimeError("Context or task ledger not initialized") - # Add task ledger to history if not already there - if self._task_ledger and ( - not self._context.chat_history or self._context.chat_history[-1] != self._task_ledger - ): - self._context.chat_history.append(self._task_ledger) + # Case 1: Approved + if len(response.review) == 0: + logger.debug("Magentic Orchestrator: Plan review approved by human.") + await self._run_outer_loop(ctx) + return + # Case 2: Revision requested + logger.debug("Magentic Orchestrator: Plan review revision requested by human.") + self._magentic_context.chat_history.extend(response.review) + self._task_ledger = await self._manager.replan(self._magentic_context.clone(deep=True)) + await ctx.add_event( + MagenticOrchestratorEvent( + executor_id=self.id, + event_type=MagenticOrchestratorEventType.REPLANNED, + data=self._task_ledger, + ) + ) + # Continue the review process by sending the new plan for review again until approved + # We don't need to check if `_require_plan_signoff` is True here, since we are already + # in the review process. 
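For illustration, a host application might answer these plan-review requests as in the sketch below. This is a minimal sketch only: the `MagenticPlanReviewResponse(review=[...])` constructor shape, the import locations, and starting the workflow with a plain task string are assumptions inferred from the handler above, not APIs confirmed by this change.

```python
# Hypothetical host-side plan reviewer (a sketch, not part of this change).
# Assumption: an empty `review` list approves the plan; a non-empty list
# carries revision comments, mirroring the handler logic above.
async def review_plans(workflow, task: str) -> None:
    async for event in workflow.run_stream(task):  # assumed entry point
        if isinstance(event, RequestInfoEvent) and event.request_type is MagenticPlanReviewRequest:
            plan_text = event.data.plan.text  # the task ledger ChatMessage
            if "rm -rf" in plan_text:
                # Request a revision by returning review comments
                reply = MagenticPlanReviewResponse(
                    review=[ChatMessage(role=Role.USER, text="Remove the destructive step, then resubmit.")]
                )
            else:
                reply = MagenticPlanReviewResponse(review=[])  # empty review == approve
            async for _ in workflow.send_responses_streaming({event.request_id: reply}):
                pass  # keep draining events after resuming
```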
+ await self._send_plan_review_request(cast(WorkflowContext, ctx), is_stalled=original_request.is_stalled) - # Optionally surface the updated task ledger via message callback (no broadcast) - if self._task_ledger is not None: - await self._emit_orchestrator_message(context, self._task_ledger, ORCH_MSG_KIND_TASK_LEDGER) + async def _send_plan_review_request(self, ctx: WorkflowContext, is_stalled: bool = False) -> None: + """Send a human intervention request for plan review. - # Start inner loop - await self._run_inner_loop(context) + The response will be handled in the response handler `handle_plan_review_response`. + """ + if self._task_ledger is None: + raise RuntimeError("No task ledger available for plan review request.") + + await ctx.request_info( + MagenticPlanReviewRequest( + plan=self._task_ledger, + current_progress=self._progress_ledger, + is_stalled=is_stalled, + ), + MagenticPlanReviewResponse, + ) async def _run_inner_loop( self, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: """Run the inner orchestration loop. Coordination phase. Serialized with a lock.""" - if self._context is None or self._task_ledger is None: + if self._magentic_context is None or self._task_ledger is None: raise RuntimeError("Context or task ledger not initialized") - await self._run_inner_loop_helper(context) + await self._run_inner_loop_helper(ctx) async def _run_inner_loop_helper( self, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: """Run inner loop with exclusive access.""" # Narrow optional context for the remainder of this method - ctx = self._context - if ctx is None: + if self._magentic_context is None: raise RuntimeError("Context not initialized") # Check limits first - within_limits = await self._check_within_limits_or_complete(context) + within_limits = await self._check_within_limits_or_complete( + cast(WorkflowContext[Never, list[ChatMessage]], ctx) + ) if not within_limits: return - ctx.round_count += 1 - logger.info(f"Magentic Orchestrator: Inner loop - round {ctx.round_count}") + self._magentic_context.round_count += 1 + self._increment_round() + logger.debug(f"Magentic Orchestrator: Inner loop - round {self._round_index}") # Create progress ledger using the manager try: - current_progress_ledger = await self._manager.create_progress_ledger(ctx.clone(deep=True)) + self._progress_ledger = await self._manager.create_progress_ledger(self._magentic_context.clone(deep=True)) except Exception as ex: logger.warning(f"Magentic Orchestrator: Progress ledger creation failed, triggering reset: {ex}") - await self._reset_and_replan(context) + await self._reset_and_replan(ctx) return + await ctx.add_event( + MagenticOrchestratorEvent( + executor_id=self.id, + event_type=MagenticOrchestratorEventType.PROGRESS_LEDGER_UPDATED, + data=self._progress_ledger, + ) + ) + logger.debug( - f"Progress evaluation: satisfied={current_progress_ledger.is_request_satisfied.answer}, " - f"next={current_progress_ledger.next_speaker.answer}" + f"Progress evaluation: satisfied={self._progress_ledger.is_request_satisfied.answer}, " + f"next={self._progress_ledger.next_speaker.answer}" ) # Check for task completion - if current_progress_ledger.is_request_satisfied.answer: + if self._progress_ledger.is_request_satisfied.answer: logger.info("Magentic 
Orchestrator: Task completed") - await self._prepare_final_answer(context) + await self._prepare_final_answer(cast(WorkflowContext[Never, list[ChatMessage]], ctx)) return # Check for stalling or looping - if not current_progress_ledger.is_progress_being_made.answer or current_progress_ledger.is_in_loop.answer: - ctx.stall_count += 1 + if not self._progress_ledger.is_progress_being_made.answer or self._progress_ledger.is_in_loop.answer: + self._magentic_context.stall_count += 1 else: - ctx.stall_count = max(0, ctx.stall_count - 1) - - if ctx.stall_count > self._manager.max_stall_count: - logger.info(f"Magentic Orchestrator: Stalling detected after {ctx.stall_count} rounds") - if self._enable_stall_intervention: - # Request human intervention instead of auto-replan - is_progress = current_progress_ledger.is_progress_being_made.answer - is_loop = current_progress_ledger.is_in_loop.answer - stall_reason = "No progress being made" if not is_progress else "" - if is_loop: - loop_msg = "Agents appear to be in a loop" - stall_reason = f"{stall_reason}; {loop_msg}" if stall_reason else loop_msg - next_speaker_val = current_progress_ledger.next_speaker.answer - last_agent = next_speaker_val if isinstance(next_speaker_val, str) else "" - # Get facts and plan from manager's task ledger - mgr_ledger = getattr(self._manager, "task_ledger", None) - facts_text = mgr_ledger.facts.text if mgr_ledger else "" - plan_text = mgr_ledger.plan.text if mgr_ledger else "" - request = _MagenticHumanInterventionRequest( - kind=MagenticHumanInterventionKind.STALL, - stall_count=ctx.stall_count, - max_stall_count=self._manager.max_stall_count, - task_text=ctx.task.text if ctx.task else "", - facts_text=facts_text, - plan_text=plan_text, - last_agent=last_agent, - stall_reason=stall_reason, - ) - await context.request_info(request, _MagenticHumanInterventionReply) - return - # Default behavior: auto-replan - await self._reset_and_replan(context) + self._magentic_context.stall_count = max(0, self._magentic_context.stall_count - 1) + + if self._magentic_context.stall_count > self._manager.max_stall_count: + logger.debug(f"Magentic Orchestrator: Stalling detected after {self._magentic_context.stall_count} rounds") + await self._reset_and_replan(ctx) return # Determine the next speaker and instruction - answer_val = current_progress_ledger.next_speaker.answer - if not isinstance(answer_val, str): + next_speaker = self._progress_ledger.next_speaker.answer + if not isinstance(next_speaker, str): # Fallback to first participant if ledger returns non-string logger.warning("Next speaker answer was not a string; selecting first participant as fallback") - answer_val = next(iter(self._participants.keys())) - next_speaker_value: str = answer_val - instruction = current_progress_ledger.instruction_or_question.answer + next_speaker = next(iter(self._participant_registry.participants.keys())) + instruction = self._progress_ledger.instruction_or_question.answer - if next_speaker_value not in self._participants: - logger.warning(f"Invalid next speaker: {next_speaker_value}") - await self._prepare_final_answer(context) + if next_speaker not in self._participant_registry.participants: + logger.warning(f"Invalid next speaker: {next_speaker}") + await self._prepare_final_answer(cast(WorkflowContext[Never, list[ChatMessage]], ctx)) return # Add instruction to conversation (assistant guidance) @@ -1585,505 +1123,232 @@ async def _run_inner_loop_helper( text=str(instruction), author_name=MAGENTIC_MANAGER_NAME, ) - 
ctx.chat_history.append(instruction_msg) - await self._emit_orchestrator_message(context, instruction_msg, ORCH_MSG_KIND_INSTRUCTION) - - # Determine the selected agent's executor id - target_executor_id = f"agent_{next_speaker_value}" + self._magentic_context.chat_history.append(instruction_msg) # Request specific agent to respond - logger.debug(f"Magentic Orchestrator: Requesting {next_speaker_value} to respond") - await context.send_message( - _MagenticRequestMessage( - agent_name=next_speaker_value, - instruction=str(instruction), - task_context=ctx.task.text, - ), - target_id=target_executor_id, + logger.debug(f"Magentic Orchestrator: Requesting {next_speaker} to respond") + await self._send_request_to_participant( + next_speaker, + cast(WorkflowContext[AgentExecutorRequest | GroupChatRequestMessage], ctx), + additional_instruction=str(instruction), ) async def _reset_and_replan( self, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: """Reset context and replan.""" - if self._context is None: - return + if self._magentic_context is None: + raise RuntimeError("Context not initialized") - logger.info("Magentic Orchestrator: Resetting and replanning") + logger.debug("Magentic Orchestrator: Resetting and replanning") # Reset context - self._context.reset() + self._magentic_context.reset() - # Replan - self._task_ledger = await self._manager.replan(self._context.clone(deep=True)) - self._context.chat_history.append(self._task_ledger) - await self._emit_orchestrator_message(context, self._task_ledger, ORCH_MSG_KIND_TASK_LEDGER) + # Reset all participant states + await self._reset_participants(cast(WorkflowContext[MagenticResetSignal], ctx)) - # Internally reset all registered agent executors (no handler/messages involved) - for agent in self._agent_executors.values(): - with contextlib.suppress(Exception): - agent.reset() + # Replan + self._task_ledger = await self._manager.replan(self._magentic_context.clone(deep=True)) + await ctx.add_event( + MagenticOrchestratorEvent( + executor_id=self.id, + event_type=MagenticOrchestratorEventType.REPLANNED, + data=self._task_ledger, + ) + ) + # If a human must sign off, ask now and return. The response handler will resume. 
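The orchestrator now surfaces its lifecycle through `MagenticOrchestratorEvent`s (PLAN_CREATED, PROGRESS_LEDGER_UPDATED, REPLANNED). A sketch of how a caller might observe them follows; the `run_stream` entry point and import paths are assumptions, the event fields are taken from the emissions above.

```python
# Sketch: tracing Magentic orchestrator lifecycle events (assumed imports).
async def trace_orchestration(workflow, task: str) -> None:
    async for event in workflow.run_stream(task):
        if not isinstance(event, MagenticOrchestratorEvent):
            continue
        if event.event_type == MagenticOrchestratorEventType.PLAN_CREATED:
            print(f"[{event.executor_id}] plan:\n{event.data.text}")  # data is the task ledger ChatMessage
        elif event.event_type == MagenticOrchestratorEventType.PROGRESS_LEDGER_UPDATED:
            ledger = event.data  # a MagenticProgressLedger
            print(f"next speaker: {ledger.next_speaker.answer}")
        elif event.event_type == MagenticOrchestratorEventType.REPLANNED:
            print(f"[{event.executor_id}] replanned after a reset")
```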
+ if self._require_plan_signoff: + await self._send_plan_review_request(cast(WorkflowContext, ctx), is_stalled=True) + return + + self._magentic_context.chat_history.append(self._task_ledger) # Restart outer loop - await self._run_outer_loop(context) + await self._run_outer_loop(ctx) - async def _prepare_final_answer( + async def _run_outer_loop( self, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: + """Run the outer orchestration loop - planning phase.""" + if self._magentic_context is None: + raise RuntimeError("Context not initialized") + + logger.debug("Magentic Orchestrator: Outer loop - entering inner loop") + + # Add task ledger to history if not already there + if self._task_ledger and ( + not self._magentic_context.chat_history or self._magentic_context.chat_history[-1] != self._task_ledger + ): + self._magentic_context.chat_history.append(self._task_ledger) + + # Start inner loop + await self._run_inner_loop(ctx) + + async def _prepare_final_answer(self, ctx: WorkflowContext[Never, list[ChatMessage]]) -> None: """Prepare the final answer using the manager.""" - if self._context is None: - return + if self._magentic_context is None: + raise RuntimeError("Context not initialized") logger.info("Magentic Orchestrator: Preparing final answer") - final_answer = await self._manager.prepare_final_answer(self._context.clone(deep=True)) + final_answer = await self._manager.prepare_final_answer(self._magentic_context.clone(deep=True)) # Emit a completed event for the workflow - await context.yield_output([final_answer]) + await ctx.yield_output([final_answer]) - async def _check_within_limits_or_complete( - self, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], - ) -> bool: - """Check if orchestrator is within operational limits.""" - if self._context is None: - return False - ctx = self._context + self._terminated = True - hit_round_limit = self._manager.max_round_count is not None and ctx.round_count >= self._manager.max_round_count - hit_reset_limit = self._manager.max_reset_count is not None and ctx.reset_count >= self._manager.max_reset_count + async def _check_within_limits_or_complete(self, ctx: WorkflowContext[Never, list[ChatMessage]]) -> bool: + """Check if orchestrator is within operational limits. - if hit_round_limit or hit_reset_limit: - limit_type = "round" if hit_round_limit else "reset" - logger.error(f"Magentic Orchestrator: Max {limit_type} count reached") + If limits are exceeded, yield a termination message and mark the workflow as terminated. - # Only emit completion once and then mark terminated - if not self._terminated: - self._terminated = True - # Get partial result - partial_result = _first_assistant(ctx.chat_history) - if partial_result is None: - partial_result = ChatMessage( - role=Role.ASSISTANT, - text=f"Stopped due to {limit_type} limit. No partial result available.", - author_name=MAGENTIC_MANAGER_NAME, - ) - - # Yield the partial result and signal completion - await context.yield_output([partial_result]) - return False + Args: + ctx: The workflow context. - return True + Returns: + True if within limits, False if limits exceeded and workflow is terminated. 
+ """ + if self._magentic_context is None: + raise RuntimeError("Context not initialized") - async def _send_plan_review_request(self, context: WorkflowContext) -> None: - """Send a human intervention request for plan review.""" - # If plan sign-off is disabled (e.g., ran out of review rounds), do nothing - if not self._require_plan_signoff: - return - ledger = getattr(self._manager, "task_ledger", None) - facts_text = ledger.facts.text if ledger else "" - plan_text = ledger.plan.text if ledger else "" - task_text = self._context.task.text if self._context else "" - - req = _MagenticHumanInterventionRequest( - kind=MagenticHumanInterventionKind.PLAN_REVIEW, - task_text=task_text, - facts_text=facts_text, - plan_text=plan_text, - round_index=self._plan_review_round, + hit_round_limit = self._max_rounds is not None and self._round_index >= self._max_rounds + hit_reset_limit = ( + self._manager.max_reset_count is not None + and self._magentic_context.reset_count >= self._manager.max_reset_count ) - await context.request_info(req, _MagenticHumanInterventionReply) - -# region Magentic Executors + if hit_round_limit or hit_reset_limit: + limit_type = "round" if hit_round_limit else "reset" + logger.error(f"Magentic Orchestrator: Max {limit_type} count reached") + # Yield the full conversation with an indication of termination due to limits + await ctx.yield_output([ + *self._magentic_context.chat_history, + ChatMessage( + role=Role.ASSISTANT, + text=f"Workflow terminated due to reaching maximum {limit_type} count.", + author_name=MAGENTIC_MANAGER_NAME, + ), + ]) + self._terminated = True -class MagenticAgentExecutor(Executor): - """Magentic agent executor that wraps an agent for participation in workflows. + return False - Leverages enhanced AgentExecutor with conversation injection hooks for: - - Receiving task ledger broadcasts - - Responding to specific agent requests - - Resetting agent state when needed - - Surfacing tool approval requests (user_input_requests) as HITL events - """ + return True - def __init__( - self, - agent: AgentProtocol | Executor, - agent_id: str, - ) -> None: - super().__init__(f"agent_{agent_id}") - self._agent = agent - self._agent_id = agent_id - self._chat_history: list[ChatMessage] = [] - self._pending_human_input_request: _MagenticHumanInterventionRequest | None = None - self._pending_tool_request: FunctionApprovalRequestContent | None = None - self._current_request_message: _MagenticRequestMessage | None = None + async def _reset_participants(self, ctx: WorkflowContext[MagenticResetSignal]) -> None: + """Reset all participant executors.""" + # Orchestrator is connected to all participants. Sending the message without specifying + # a target will broadcast to all. + await ctx.send_message(MagenticResetSignal()) @override async def on_checkpoint_save(self) -> dict[str, Any]: - """Capture current executor state for checkpointing. 
+ """Capture current orchestrator state for checkpointing.""" + state = await super().on_checkpoint_save() + state["terminated"] = self._terminated - Returns: - Dict containing serialized chat history - """ - from ._conversation_state import encode_chat_messages + if self._magentic_context is not None: + state["magentic_context"] = self._magentic_context.to_dict() + if self._task_ledger is not None: + state["task_ledger"] = _message_to_payload(self._task_ledger) + if self._progress_ledger is not None: + state["progress_ledger"] = self._progress_ledger.to_dict() - return { - "chat_history": encode_chat_messages(self._chat_history), - } + try: + state["manager_state"] = self._manager.on_checkpoint_save() + except Exception as exc: + logger.warning(f"Failed to save manager state for checkpoint: {exc}\nSkipping...") + + return state @override async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: - """Restore executor state from checkpoint. + """Restore executor state from checkpoint.""" + await super().on_checkpoint_restore(state) + self._terminated = state.get("terminated", False) - Args: - state: Checkpoint data dict - """ - from ._conversation_state import decode_chat_messages - - history_payload = state.get("chat_history") - if history_payload: + magentic_context_data = state.get("magentic_context") + if magentic_context_data is not None: try: - self._chat_history = decode_chat_messages(history_payload) - except Exception as exc: # pragma: no cover - logger.warning(f"Agent {self._agent_id}: Failed to restore chat history: {exc}") - self._chat_history = [] - else: - self._chat_history = [] - - @handler - async def handle_response_message( - self, message: _MagenticResponseMessage, context: WorkflowContext[_MagenticResponseMessage] - ) -> None: - """Handle response message (task ledger broadcast).""" - logger.debug(f"Agent {self._agent_id}: Received response message") - - # Check if this message is intended for this agent - if message.target_agent is not None and message.target_agent != self._agent_id and not message.broadcast: - # Message is targeted to a different agent, ignore it - logger.debug(f"Agent {self._agent_id}: Ignoring message targeted to {message.target_agent}") - return - - # Add transfer message if needed - if message.body.role != Role.USER: - transfer_msg = ChatMessage( - role=Role.USER, - text=f"Transferred to {getattr(message.body, 'author_name', 'agent')}", - ) - self._chat_history.append(transfer_msg) - - # Add message to agent's history - self._chat_history.append(message.body) - - def _get_persona_adoption_role(self) -> Role: - """Determine the best role for persona adoption messages. - - Uses SYSTEM role if the agent supports it, otherwise falls back to USER. - """ - # Only BaseAgent-derived agents are assumed to support SYSTEM messages reliably. 
- from agent_framework import BaseAgent as _AF_AgentBase # local import to avoid cycles - - if isinstance(self._agent, _AF_AgentBase) and hasattr(self._agent, "chat_client"): - return Role.SYSTEM - # For other agent types or when we can't determine support, use USER - return Role.USER + self._magentic_context = MagenticContext.from_dict(magentic_context_data) + except Exception: # pragma: no cover - defensive + logger.warning("Failed to restore Magentic context from checkpoint data") + self._magentic_context = None - @handler - async def handle_request_message( - self, message: _MagenticRequestMessage, context: WorkflowContext[_MagenticResponseMessage, AgentRunResponse] - ) -> None: - """Handle request to respond.""" - if message.agent_name != self._agent_id: - return + task_ledger_data = state.get("task_ledger") + if task_ledger_data is not None: + try: + self._task_ledger = _message_from_payload(task_ledger_data) + except Exception: # pragma: no cover - defensive + logger.warning("Failed to restore task ledger from checkpoint data") + self._task_ledger = None - logger.info(f"Agent {self._agent_id}: Received request to respond") + progress_ledger_data = state.get("progress_ledger") + if progress_ledger_data is not None: + try: + self._progress_ledger = MagenticProgressLedger.from_dict(progress_ledger_data) + except Exception: # pragma: no cover - defensive + logger.warning("Failed to restore progress ledger from checkpoint data") + self._progress_ledger = None - # Store the original request message for potential continuation after human input - self._current_request_message = message + manager_state = state.get("manager_state") + if manager_state is not None: + try: + self._manager.on_checkpoint_restore(manager_state) + except Exception as exc: + logger.warning(f"Failed to restore manager state from checkpoint: {exc}\nSkipping...") - # Add persona adoption message with appropriate role - persona_role = self._get_persona_adoption_role() - persona_msg = ChatMessage( - role=persona_role, - text=f"Transferred to {self._agent_id}, adopt the persona immediately.", - ) - self._chat_history.append(persona_msg) - # Add the orchestrator's instruction as a USER message so the agent treats it as the prompt - if message.instruction: - self._chat_history.append(ChatMessage(role=Role.USER, text=message.instruction)) - try: - # If the participant is not an invokable BaseAgent, return a no-op response. 
- from agent_framework import BaseAgent as _AF_AgentBase # local import to avoid cycles +# endregion Magentic Orchestrator - if not isinstance(self._agent, _AF_AgentBase): - response: ChatMessage = ChatMessage( - role=Role.ASSISTANT, - text=f"{self._agent_id} is a workflow executor and cannot be invoked directly.", - author_name=self._agent_id, - ) - self._chat_history.append(response) - await self._emit_agent_message_event(context, response) - await context.send_message(_MagenticResponseMessage(body=response)) - else: - # Invoke the agent - agent_response = await self._invoke_agent(context) - if agent_response is None: - # Agent is waiting for human input - don't send response yet - return - self._chat_history.append(agent_response) - # Send response back to orchestrator - await context.send_message(_MagenticResponseMessage(body=agent_response)) - - except Exception as e: - logger.warning(f"Agent {self._agent_id} invoke failed: {e}") - # Fallback response - response = ChatMessage( - role=Role.ASSISTANT, - text=f"Agent {self._agent_id}: Error processing request - {str(e)[:100]}", - ) - self._chat_history.append(response) - await self._emit_agent_message_event(context, response) - await context.send_message(_MagenticResponseMessage(body=response)) +# region Magentic Agent Executor - def reset(self) -> None: - """Reset the internal chat history of the agent (internal operation).""" - logger.debug(f"Agent {self._agent_id}: Resetting chat history") - self._chat_history.clear() - self._pending_human_input_request = None - self._pending_tool_request = None - self._current_request_message = None - @response_handler - async def handle_tool_approval_response( - self, - original_request: _MagenticHumanInterventionRequest, - response: _MagenticHumanInterventionReply, - context: WorkflowContext[_MagenticResponseMessage, AgentRunResponse], - ) -> None: - """Handle human response for tool approval and continue agent execution. +class MagenticAgentExecutor(AgentExecutor): + """Specialized AgentExecutor for Magentic agent participants.""" - When a human provides input in response to a tool approval request, - this handler processes the response based on the decision type: + def __init__(self, agent: AgentProtocol) -> None: + """Initialize a Magentic Agent Executor. - - APPROVE: Execute the tool call with the provided response text - - REJECT: Do not execute the tool, inform the agent of rejection - - GUIDANCE: Execute the tool call with the guidance text as input + This executor wraps an AgentProtocol instance to be used as a participant + in a Magentic One workflow. 
Args: - original_request: The original human intervention request - response: The human's response containing the decision and any text - context: The workflow context - """ - response_text = response.response_text or response.comments or "" - decision = response.decision - logger.info( - f"Agent {original_request.agent_id}: Received tool approval response " - f"(decision={decision.value}): {response_text[:50] if response_text else ''}" - ) - - # Get the pending tool request to extract call_id - pending_tool_request = self._pending_tool_request - self._pending_human_input_request = None - self._pending_tool_request = None - - # Handle REJECT decision - do not execute the tool call - if decision == MagenticHumanInterventionDecision.REJECT: - rejection_reason = response_text or "Tool call rejected by human" - logger.info(f"Agent {self._agent_id}: Tool call rejected: {rejection_reason}") - - if pending_tool_request is not None: - # Create a FunctionResultContent indicating rejection - function_result = FunctionResultContent( - call_id=pending_tool_request.function_call.call_id, - result=f"Tool call was rejected by human reviewer. Reason: {rejection_reason}", - ) - result_msg = ChatMessage( - role=Role.USER, - contents=[function_result], - ) - self._chat_history.append(result_msg) - else: - # Fallback without pending tool request - rejection_msg = ChatMessage( - role=Role.USER, - text=f"Tool call '{original_request.prompt}' was rejected: {rejection_reason}", - author_name="human", - ) - self._chat_history.append(rejection_msg) - - # Re-invoke the agent so it can adapt to the rejection - agent_response = await self._invoke_agent(context) - if agent_response is None: - return - self._chat_history.append(agent_response) - await context.send_message(_MagenticResponseMessage(body=agent_response)) - return + agent: The agent instance to wrap. - # Handle APPROVE and GUIDANCE decisions - execute the tool call - if pending_tool_request is not None: - # Create a FunctionResultContent with the human's response - function_result = FunctionResultContent( - call_id=pending_tool_request.function_call.call_id, - result=response_text, - ) - # Add the function result as a message to continue the conversation - result_msg = ChatMessage( - role=Role.USER, - contents=[function_result], - ) - self._chat_history.append(result_msg) - - # Re-invoke the agent to continue execution - agent_response = await self._invoke_agent(context) - if agent_response is None: - # Agent is waiting for more human input - return - self._chat_history.append(agent_response) - await context.send_message(_MagenticResponseMessage(body=agent_response)) - else: - # Fallback: no pending tool request, just add as text message - logger.warning( - f"Agent {original_request.agent_id}: No pending tool request found for response, " - "using fallback text handling", - ) - human_response_msg = ChatMessage( - role=Role.USER, - text=f"Human response to '{original_request.prompt}': {response_text}", - author_name="human", - ) - self._chat_history.append(human_response_msg) - - # Create a response message indicating human input was received - agent_response_msg = ChatMessage( - role=Role.ASSISTANT, - text=f"Received human input for: {original_request.prompt}. 
Continuing with the task.", - author_name=original_request.agent_id, - ) - self._chat_history.append(agent_response_msg) - await context.send_message(_MagenticResponseMessage(body=agent_response_msg)) - - async def _emit_agent_delta_event( - self, - ctx: WorkflowContext[Any, Any], - update: AgentRunResponseUpdate, - ) -> None: - # Add metadata to identify this as an agent streaming update - props = update.additional_properties - if props is None: - props = {} - update.additional_properties = props - props["magentic_event_type"] = MAGENTIC_EVENT_TYPE_AGENT_DELTA - props["agent_id"] = self._agent_id - - # Emit AgentRunUpdateEvent with the agent response update - await ctx.add_event(AgentRunUpdateEvent(executor_id=self._agent_id, data=update)) - - async def _emit_agent_message_event( - self, - ctx: WorkflowContext[Any, Any], - message: ChatMessage, - ) -> None: - # Agent message completion is already communicated via streaming updates - # No additional event needed - pass - - async def _invoke_agent( - self, - ctx: WorkflowContext[_MagenticResponseMessage, AgentRunResponse], - ) -> ChatMessage | None: - """Invoke the wrapped agent and return a response. - - This method streams the agent's response updates, collects them into an - AgentRunResponse, and handles any human input requests (tool approvals). - - Note: - If multiple user input requests are present in the agent's response, - only the first one is processed. A warning is logged and subsequent - requests are ignored. This is a current limitation of the single-request - pending state architecture. - - Returns: - ChatMessage with the agent's response, or None if waiting for human input. + Notes: The Magentic pattern requires a reset operation upon replanning. This executor + extends the base AgentExecutor to handle resets appropriately: agent threads and + other state are reset when the orchestrator requests it. Because of this, + MagenticAgentExecutor does not support custom threads. """ - logger.debug(f"Agent {self._agent_id}: Running with {len(self._chat_history)} messages") - - run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) - - updates: list[AgentRunResponseUpdate] = [] - # The wrapped participant is guaranteed to be an BaseAgent when this is called. - agent = cast("AgentProtocol", self._agent) - async for update in agent.run_stream(messages=self._chat_history, **run_kwargs): # type: ignore[attr-defined] - updates.append(update) - await self._emit_agent_delta_event(ctx, update) - - run_result: AgentRunResponse = AgentRunResponse.from_agent_run_response_updates(updates) - - # Handle human input requests (tool approval) - process one at a time - if run_result.user_input_requests: - if len(run_result.user_input_requests) > 1: - logger.warning( - f"Agent {self._agent_id}: Multiple user input requests received " - f"({len(run_result.user_input_requests)}), processing only the first one" - ) - - user_input_request = run_result.user_input_requests[0] + super().__init__(agent) - # Build a prompt from the request based on its type - prompt: str - context_text: str | None = None + @handler + async def handle_magentic_reset(self, signal: MagenticResetSignal, ctx: WorkflowContext) -> None: + """Handle reset signal from the Magentic orchestrator. 
- if isinstance(user_input_request, FunctionApprovalRequestContent): - fn_call = user_input_request.function_call - prompt = f"Approve function call: {fn_call.name}" - if fn_call.arguments: - context_text = f"Arguments: {fn_call.arguments}" - else: - # Fallback for unknown request types - request_type = type(user_input_request).__name__ - prompt = f"Agent {self._agent_id} requires human input ({request_type})" - logger.warning(f"Agent {self._agent_id}: Unrecognized user input request type: {request_type}") - - # Store the original FunctionApprovalRequestContent for later use - self._pending_tool_request = user_input_request - - # Create and send the human intervention request for tool approval - request = _MagenticHumanInterventionRequest( - kind=MagenticHumanInterventionKind.TOOL_APPROVAL, - agent_id=self._agent_id, - prompt=prompt, - context=context_text, - conversation_snapshot=list(self._chat_history[-5:]), - ) - self._pending_human_input_request = request - await ctx.request_info(request, _MagenticHumanInterventionReply) - return None # Signal that we're waiting for human input + This method resets the internal state of the agent executor, including - messages: list[ChatMessage] | None = None - with contextlib.suppress(Exception): - messages = list(run_result.messages) # type: ignore[assignment] - if messages and len(messages) > 0: - last: ChatMessage = messages[-1] - author = last.author_name or self._agent_id - role: Role = last.role if last.role else Role.ASSISTANT - text = last.text or "" - msg = ChatMessage(role=role, text=text, author_name=author) - await self._emit_agent_message_event(ctx, msg) - return msg - - msg = ChatMessage( - role=Role.ASSISTANT, - text=f"Agent {self._agent_id}: No output produced", - author_name=self._agent_id, - ) - await self._emit_agent_message_event(ctx, msg) - return msg + any threads or caches, to prepare for a fresh start after replanning. + + Args: + signal: The MagenticResetSignal instance. + ctx: The workflow context. + """ + # Message related + self._cache.clear() + self._full_conversation.clear() + # Request info related + self._pending_agent_requests.clear() + self._pending_responses_to_agent.clear() + # Reset threads + self._agent_thread = self._agent.get_new_thread() -# endregion Magentic Executors +# endregion Magentic Agent Executor # region Magentic Workflow Builder @@ -2108,51 +1373,6 @@ class MagenticBuilder: These emit `MagenticHumanInterventionRequest` events that provide structured decision options (APPROVE, REVISE, CONTINUE, REPLAN, GUIDANCE) appropriate for Magentic's planning-based orchestration. - - Usage: - - .. code-block:: python - - from agent_framework import MagenticBuilder, StandardMagenticManager - from azure.ai.projects.aio import AIProjectClient - - # Create manager with LLM client - project_client = AIProjectClient.from_connection_string(...) - chat_client = project_client.inference.get_chat_completions_client() - - # Build Magentic workflow with agents - workflow = ( - MagenticBuilder() - .participants(researcher=research_agent, writer=writing_agent, coder=coding_agent) - .with_standard_manager(chat_client=chat_client, max_round_count=20, max_stall_count=3) - .with_plan_review(enable=True) - .with_checkpointing(checkpoint_storage) - .build() - ) - - # Execute workflow - async for message in workflow.run("Research and write article about AI agents"): - print(message.text) - - With custom manager: - - .. 
code-block:: python - - # Create custom manager subclass - class MyCustomManager(MagenticManagerBase): - async def plan(self, context: MagenticContext) -> ChatMessage: - # Custom planning logic - ... - - - manager = MyCustomManager() - workflow = MagenticBuilder().participants(agent1=agent1, agent2=agent2).with_standard_manager(manager).build() - - See Also: - - :class:`MagenticManagerBase`: Base class for custom managers - - :class:`StandardMagenticManager`: Default LLM-powered manager - - :class:`MagenticContext`: Context object passed to manager methods - - :class:`MagenticEvent`: Base class for workflow events """ def __init__(self) -> None: @@ -2160,33 +1380,29 @@ def __init__(self) -> None: self._manager: MagenticManagerBase | None = None self._enable_plan_review: bool = False self._checkpoint_storage: CheckpointStorage | None = None - self._enable_stall_intervention: bool = False - def participants(self, **participants: AgentProtocol | Executor) -> Self: - """Add participant agents or executors to the Magentic workflow. + def participants(self, participants: Sequence[AgentProtocol | Executor]) -> Self: + """Define participants for this Magentic workflow. - Participants are the agents that will execute tasks under the manager's direction. - Each participant should have distinct capabilities that complement the team. The - manager will select which participant to invoke based on the current plan and - progress state. + Accepts AgentProtocol instances (auto-wrapped as AgentExecutor) or Executor instances. Args: - **participants: Named agents or executors to add to the workflow. Names should - be descriptive of the agent's role (e.g., researcher=research_agent). - Accepts BaseAgent instances or custom Executor implementations. + participants: Sequence of participant definitions. Returns: Self for method chaining - Usage: + Raises: + ValueError: If participants are empty, names are duplicated, or participants have already been set + TypeError: If any participant is not an AgentProtocol or Executor instance + + Example: .. code-block:: python workflow = ( MagenticBuilder() - .participants( - researcher=research_agent, writer=writing_agent, coder=coding_agent, reviewer=review_agent - ) + .participants([research_agent, writing_agent, coding_agent, review_agent]) .with_standard_manager(agent=manager_agent) .build() ) @@ -2196,7 +1412,33 @@ def participants(self, **participants: AgentProtocol | Executor) -> Self: - Agent descriptions (if available) are extracted and provided to the manager - - Can be called multiple times to add participants incrementally + - Must be called exactly once with the full sequence of participants """ - self._participants.update(participants) + if self._participants: + raise ValueError("participants have already been set. Call participants(...) at most once.") + + if not participants: + raise ValueError("participants cannot be empty.") + + # Name of the executor mapped to participant instance + named: dict[str, AgentProtocol | Executor] = {} + for participant in participants: + if isinstance(participant, Executor): + identifier = participant.id + elif isinstance(participant, AgentProtocol): + if not participant.name: + raise ValueError("AgentProtocol participants must have a non-empty name.") + identifier = participant.name + else: + raise TypeError( + f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." 
+ ) + + if identifier in named: + raise ValueError(f"Duplicate participant name '{identifier}' detected") + + named[identifier] = participant + + self._participants = named + return self def with_plan_review(self, enable: bool = True) -> "MagenticBuilder": @@ -2249,67 +1491,6 @@ def with_plan_review(self, enable: bool = True) -> "MagenticBuilder": self._enable_plan_review = enable return self - def with_human_input_on_stall(self, enable: bool = True) -> "MagenticBuilder": - """Enable human intervention when the workflow detects a stall. - - When enabled, instead of automatically replanning when the workflow detects - that agents are not making progress or are stuck in a loop, the workflow will - pause and emit a MagenticStallInterventionRequest event. A human can then - decide to continue, trigger replanning, or provide guidance. - - This is useful for: - - Workflows where automatic replanning may not resolve the issue - - Scenarios requiring human judgment about workflow direction - - Debugging stuck workflows with human insight - - Complex tasks where human guidance can help agents get back on track - - When stall detection triggers (based on max_stall_count), instead of calling - _reset_and_replan automatically, the workflow will: - 1. Emit a MagenticHumanInterventionRequest with kind=STALL - 2. Wait for human response via send_responses_streaming - 3. Take action based on the human's decision (continue, replan, or guidance) - - Args: - enable: Whether to enable stall intervention (default True) - - Returns: - Self for method chaining - - Usage: - - .. code-block:: python - - workflow = ( - MagenticBuilder() - .participants(agent1=agent1) - .with_standard_manager(agent=manager_agent, max_stall_count=3) - .with_human_input_on_stall(enable=True) - .build() - ) - - # During execution, handle human intervention requests - async for event in workflow.run_stream("task"): - if isinstance(event, RequestInfoEvent): - if event.request_type is MagenticHumanInterventionRequest: - request = event.data - if request.kind == MagenticHumanInterventionKind.STALL: - print(f"Workflow stalled: {request.stall_reason}") - reply = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.GUIDANCE, - comments="Focus on completing the current step first", - ) - responses = {event.request_id: reply} - async for ev in workflow.send_responses_streaming(responses): - ... - - See Also: - - :class:`MagenticHumanInterventionRequest`: Unified request type - - :class:`MagenticHumanInterventionDecision`: Decision options - - :meth:`with_standard_manager`: Configure max_stall_count for stall detection - """ - self._enable_stall_intervention = enable - return self - def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "MagenticBuilder": """Enable workflow state persistence using the provided checkpoint storage. 
@@ -2333,7 +1514,7 @@ def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "Magentic storage = InMemoryCheckpointStorage() workflow = ( MagenticBuilder() - .participants(agent1=agent1) + .participants([agent1]) .with_standard_manager(agent=manager_agent) .with_checkpointing(storage) .build() @@ -2432,7 +1613,7 @@ def with_standard_manager( manager_agent = ChatAgent( name="Coordinator", chat_client=OpenAIChatClient(model_id="gpt-4o"), - chat_options=ChatOptions(temperature=0.3, seed=42), + options=ChatOptions(temperature=0.3, seed=42), instructions="Be concise and focus on accuracy", ) @@ -2506,6 +1687,36 @@ async def plan(self, context: MagenticContext) -> ChatMessage: ) return self + def _resolve_orchestrator(self, participants: Sequence[Executor]) -> Executor: + """Determine the orchestrator to use for the workflow. + + Args: + participants: List of resolved participant executors + """ + if self._manager is None: + raise ValueError("No manager configured. Call with_standard_manager(...) before building the orchestrator.") + + return MagenticOrchestrator( + manager=self._manager, + participant_registry=ParticipantRegistry(participants), + require_plan_signoff=self._enable_plan_review, + ) + + def _resolve_participants(self) -> list[Executor]: + """Resolve participant instances into Executor objects.""" + executors: list[Executor] = [] + for participant in self._participants.values(): + if isinstance(participant, Executor): + executors.append(participant) + elif isinstance(participant, AgentProtocol): + executors.append(MagenticAgentExecutor(participant)) + else: + raise TypeError( + f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." + ) + + return executors + def build(self) -> Workflow: """Build a Magentic workflow with the orchestrator and all agent executors.""" if not self._participants: @@ -2516,305 +1727,19 @@ def build(self) -> Workflow: logger.info(f"Building Magentic workflow with {len(self._participants)} participants") - # Create participant descriptions - participant_descriptions: dict[str, str] = {} - for name, participant in self._participants.items(): - fallback = f"Executor {name}" if isinstance(participant, Executor) else f"Agent {name}" - participant_descriptions[name] = participant_description(participant, fallback) - - # Type narrowing: we already checked self._manager is not None above - manager: MagenticManagerBase = self._manager # type: ignore[assignment] - enable_stall_intervention = self._enable_stall_intervention - - def _orchestrator_factory(wiring: _GroupChatConfig) -> Executor: - return MagenticOrchestratorExecutor( - manager=manager, - participants=participant_descriptions, - require_plan_signoff=self._enable_plan_review, - enable_stall_intervention=enable_stall_intervention, - executor_id="magentic_orchestrator", - ) - - def _participant_factory( - spec: GroupChatParticipantSpec, - wiring: _GroupChatConfig, - ) -> _GroupChatParticipantPipeline: - agent_executor = MagenticAgentExecutor( - spec.participant, - spec.name, - ) - orchestrator = wiring.orchestrator - if isinstance(orchestrator, MagenticOrchestratorExecutor): - orchestrator.register_agent_executor(spec.name, agent_executor) - return (agent_executor,) - - # Magentic provides its own orchestrator via custom factory, so no manager is needed - group_builder = GroupChatBuilder( - _orchestrator_factory=group_chat_orchestrator(_orchestrator_factory), - _participant_factory=_participant_factory, - ).participants(self._participants) + 
participants: list[Executor] = self._resolve_participants() + orchestrator: Executor = self._resolve_orchestrator(participants) + # Build workflow graph + workflow_builder = WorkflowBuilder().set_start_executor(orchestrator) + for participant in participants: + # Orchestrator and participant bi-directional edges + workflow_builder = workflow_builder.add_edge(orchestrator, participant) + workflow_builder = workflow_builder.add_edge(participant, orchestrator) if self._checkpoint_storage is not None: - group_builder = group_builder.with_checkpointing(self._checkpoint_storage) - - return group_builder.build() - - def start_with_string(self, task: str) -> "MagenticWorkflow": - """Build a Magentic workflow and return a wrapper with convenience methods for string tasks. - - Args: - task: The task description as a string. - - Returns: - A MagenticWorkflow wrapper that provides convenience methods for starting with strings. - """ - return MagenticWorkflow(self.build(), task) - - def start_with_message(self, task: ChatMessage) -> "MagenticWorkflow": - """Build a Magentic workflow and return a wrapper with convenience methods for ChatMessage tasks. - - Args: - task: The task as a ChatMessage. - - Returns: - A MagenticWorkflow wrapper that provides convenience methods. - """ - return MagenticWorkflow(self.build(), task.text) - - def start_with(self, task: str | ChatMessage) -> "MagenticWorkflow": - """Build a Magentic workflow and return a wrapper with convenience methods. - - Args: - task: The task description as a string or ChatMessage. + workflow_builder = workflow_builder.with_checkpointing(self._checkpoint_storage) - Returns: - A MagenticWorkflow wrapper that provides convenience methods. - """ - if isinstance(task, str): - return self.start_with_string(task) - return self.start_with_message(task) + return workflow_builder.build() # endregion Magentic Workflow Builder - - -# region Magentic Workflow - - -class MagenticWorkflow: - """A wrapper around the base Workflow that provides convenience methods for Magentic workflows.""" - - def __init__(self, workflow: Workflow, task_text: str | None = None): - self._workflow = workflow - self._task_text = task_text - - @property - def workflow(self) -> Workflow: - """Access the underlying workflow.""" - return self._workflow - - async def run_streaming_with_string(self, task_text: str, **kwargs: Any) -> AsyncIterable[WorkflowEvent]: - """Run the workflow with a task string. - - Args: - task_text: The task description as a string. - **kwargs: Additional keyword arguments to pass through to agent invocations. - These kwargs will be available in @ai_function tools via **kwargs. - - Yields: - WorkflowEvent: The events generated during the workflow execution. - """ - start_message = _MagenticStartMessage.from_string(task_text) - start_message.run_kwargs = kwargs - async for event in self._workflow.run_stream(start_message): - yield event - - async def run_streaming_with_message( - self, task_message: ChatMessage, **kwargs: Any - ) -> AsyncIterable[WorkflowEvent]: - """Run the workflow with a ChatMessage. - - Args: - task_message: The task as a ChatMessage. - **kwargs: Additional keyword arguments to pass through to agent invocations. - These kwargs will be available in @ai_function tools via **kwargs. - - Yields: - WorkflowEvent: The events generated during the workflow execution. 
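Putting the new builder surface together, a minimal end-to-end setup might look like the sketch below. The agent objects are hypothetical, and the combination of keyword arguments is drawn from the docstrings above rather than confirmed by this change.

```python
# Sketch: building a Magentic workflow with the list-based participants API.
# research_agent, writer_agent, and manager_agent are hypothetical agents.
workflow = (
    MagenticBuilder()
    .participants([research_agent, writer_agent])  # names taken from agent.name
    .with_standard_manager(agent=manager_agent)
    .with_plan_review()  # route the initial plan to a human for sign-off
    .with_checkpointing(InMemoryCheckpointStorage())  # optional persistence
    .build()  # star topology: orchestrator <-> each participant
)
```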
- """ - start_message = _MagenticStartMessage(task_message, run_kwargs=kwargs) - async for event in self._workflow.run_stream(start_message): - yield event - - async def run_stream(self, message: Any | None = None, **kwargs: Any) -> AsyncIterable[WorkflowEvent]: - """Run the workflow with either a message object or the preset task string. - - Args: - message: The message to send. If None and task_text was provided during construction, - uses the preset task string. - **kwargs: Additional keyword arguments to pass through to agent invocations. - These kwargs will be available in @ai_function tools via **kwargs. - Example: workflow.run_stream("task", user_id="123", custom_data={...}) - - Yields: - WorkflowEvent: The events generated during the workflow execution. - """ - if message is None: - if self._task_text is None: - raise ValueError("No message provided and no preset task text available") - start_message = _MagenticStartMessage.from_string(self._task_text) - elif isinstance(message, str): - start_message = _MagenticStartMessage.from_string(message) - elif isinstance(message, (ChatMessage, list)): - start_message = _MagenticStartMessage(message) # type: ignore[arg-type] - else: - start_message = message - - # Attach kwargs to the start message - if isinstance(start_message, _MagenticStartMessage): - start_message.run_kwargs = kwargs - - async for event in self._workflow.run_stream(start_message): - yield event - - async def _validate_checkpoint_participants( - self, - checkpoint_id: str, - checkpoint_storage: CheckpointStorage | None = None, - ) -> None: - """Ensure participant roster matches the checkpoint before attempting restoration.""" - orchestrator = next( - ( - executor - for executor in self._workflow.executors.values() - if isinstance(executor, MagenticOrchestratorExecutor) - ), - None, - ) - if orchestrator is None: - return - - expected = getattr(orchestrator, "_participants", None) - if not expected: - return - - checkpoint: WorkflowCheckpoint | None = None - if checkpoint_storage is not None: - try: - checkpoint = await checkpoint_storage.load_checkpoint(checkpoint_id) - except Exception: # pragma: no cover - best effort - checkpoint = None - - if checkpoint is None: - runner_context = getattr(self._workflow, "_runner_context", None) - has_checkpointing = getattr(runner_context, "has_checkpointing", None) - load_checkpoint = getattr(runner_context, "load_checkpoint", None) - try: - if callable(has_checkpointing) and has_checkpointing() and callable(load_checkpoint): - loaded_checkpoint = await load_checkpoint(checkpoint_id) # type: ignore[misc] - if loaded_checkpoint is not None: - checkpoint = cast(WorkflowCheckpoint, loaded_checkpoint) - except Exception: # pragma: no cover - best effort - checkpoint = None - - if checkpoint is None: - return - - # At this point, checkpoint is guaranteed to be WorkflowCheckpoint - executor_states = cast(dict[str, Any], checkpoint.shared_state.get(EXECUTOR_STATE_KEY, {})) - orchestrator_id = getattr(orchestrator, "id", "") - orchestrator_state = cast(Any, executor_states.get(orchestrator_id)) - if orchestrator_state is None: - orchestrator_state = cast(Any, executor_states.get("magentic_orchestrator")) - - if not isinstance(orchestrator_state, dict): - return - - orchestrator_state_dict = cast(dict[str, Any], orchestrator_state) - context_payload = cast(Any, orchestrator_state_dict.get("magentic_context")) - if not isinstance(context_payload, dict): - return - - context_dict = cast(dict[str, Any], context_payload) - restored_participants 
= cast(Any, context_dict.get("participant_descriptions")) - if not isinstance(restored_participants, dict): - return - - participants_dict = cast(dict[str, str], restored_participants) - restored_names: set[str] = set(participants_dict.keys()) - expected_names = set(expected.keys()) - - if restored_names == expected_names: - return - - missing = ", ".join(sorted(expected_names - restored_names)) or "none" - unexpected = ", ".join(sorted(restored_names - expected_names)) or "none" - raise RuntimeError( - "Magentic checkpoint restore failed: participant names do not match the checkpoint. " - "Ensure MagenticBuilder.participants keys remain stable across runs. " - f"Missing names: {missing}; unexpected names: {unexpected}." - ) - - async def run_with_string(self, task_text: str, **kwargs: Any) -> WorkflowRunResult: - """Run the workflow with a task string and return all events. - - Args: - task_text: The task description as a string. - **kwargs: Additional keyword arguments to pass through to agent invocations. - - Returns: - WorkflowRunResult: All events generated during the workflow execution. - """ - events: list[WorkflowEvent] = [] - async for event in self.run_streaming_with_string(task_text, **kwargs): - events.append(event) - return WorkflowRunResult(events) - - async def run_with_message(self, task_message: ChatMessage, **kwargs: Any) -> WorkflowRunResult: - """Run the workflow with a ChatMessage and return all events. - - Args: - task_message: The task as a ChatMessage. - **kwargs: Additional keyword arguments to pass through to agent invocations. - - Returns: - WorkflowRunResult: All events generated during the workflow execution. - """ - events: list[WorkflowEvent] = [] - async for event in self.run_streaming_with_message(task_message, **kwargs): - events.append(event) - return WorkflowRunResult(events) - - async def run(self, message: Any | None = None, **kwargs: Any) -> WorkflowRunResult: - """Run the workflow and return all events. - - Args: - message: The message to send. If None and task_text was provided during construction, - uses the preset task string. - **kwargs: Additional keyword arguments to pass through to agent invocations. - - Returns: - WorkflowRunResult: All events generated during the workflow execution. 
- """ - events: list[WorkflowEvent] = [] - async for event in self.run_stream(message, **kwargs): - events.append(event) - return WorkflowRunResult(events) - - def __getattr__(self, name: str) -> Any: - """Delegate unknown attributes to the underlying workflow.""" - return getattr(self._workflow, name) - - -# endregion Magentic Workflow - -# Public aliases for unified human intervention types -MagenticHumanInterventionRequest = _MagenticHumanInterventionRequest -MagenticHumanInterventionReply = _MagenticHumanInterventionReply - -# Backward compatibility aliases (deprecated) -# Old aliases - point to unified types for compatibility -MagenticHumanInputRequest = _MagenticHumanInterventionRequest # type: ignore -MagenticStallInterventionRequest = _MagenticHumanInterventionRequest # type: ignore -MagenticStallInterventionReply = _MagenticHumanInterventionReply # type: ignore -MagenticStallInterventionDecision = MagenticHumanInterventionDecision # type: ignore diff --git a/python/packages/core/agent_framework/_workflows/_orchestration_request_info.py b/python/packages/core/agent_framework/_workflows/_orchestration_request_info.py index 91c9ec799a..dc1e282a12 100644 --- a/python/packages/core/agent_framework/_workflows/_orchestration_request_info.py +++ b/python/packages/core/agent_framework/_workflows/_orchestration_request_info.py @@ -1,37 +1,20 @@ # Copyright (c) Microsoft. All rights reserved. -"""Request info support for high-level builder APIs. - -This module provides a mechanism for pausing workflows to request external input -before agent turns in `SequentialBuilder`, `ConcurrentBuilder`, `GroupChatBuilder`, -and `HandoffBuilder`. - -The design follows the standard `request_info` pattern used throughout the -workflow system, keeping the API consistent and predictable. - -Key components: -- AgentInputRequest: Request type emitted via RequestInfoEvent for pre-agent steering -- RequestInfoInterceptor: Internal executor that pauses workflow before agent runs -""" - -import logging -import uuid -from dataclasses import dataclass, field -from typing import Any +from dataclasses import dataclass from .._agents import AgentProtocol from .._types import ChatMessage, Role -from ._agent_executor import AgentExecutorRequest +from ._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse +from ._agent_utils import resolve_agent_id from ._executor import Executor, handler from ._request_info_mixin import response_handler +from ._workflow import Workflow +from ._workflow_builder import WorkflowBuilder from ._workflow_context import WorkflowContext - -logger = logging.getLogger(__name__) +from ._workflow_executor import WorkflowExecutor -def resolve_request_info_filter( - agents: list[str | AgentProtocol | Executor] | None, -) -> set[str] | None: +def resolve_request_info_filter(agents: list[str | AgentProtocol] | None) -> set[str]: """Resolve a list of agent/executor references to a set of IDs for filtering. Args: @@ -42,288 +25,122 @@ def resolve_request_info_filter( Set of executor/agent IDs to filter on, or None if no filtering. 
""" if agents is None: - return None + return set() result: set[str] = set() for agent in agents: if isinstance(agent, str): result.add(agent) - elif isinstance(agent, Executor): - result.add(agent.id) elif isinstance(agent, AgentProtocol): - name = getattr(agent, "name", None) - if name: - result.add(name) - else: - logger.warning("AgentProtocol without name cannot be used for request_info filtering") + result.add(resolve_agent_id(agent)) else: - logger.warning(f"Unsupported type for request_info filter: {type(agent).__name__}") + raise TypeError(f"Unsupported type for request_info filter: {type(agent).__name__}") - return result if result else None + return result @dataclass -class AgentInputRequest: - """Request for human input before an agent runs in high-level builder workflows. - - Emitted via RequestInfoEvent when a workflow pauses before an agent executes. - The response is injected into the conversation as a user message to steer - the agent's behavior. - - This is the standard request type used by `.with_request_info()` on - SequentialBuilder, ConcurrentBuilder, GroupChatBuilder, and HandoffBuilder. +class AgentRequestInfoResponse: + """Response containing additional information requested from users for agents. Attributes: - target_agent_id: ID of the agent that is about to run - conversation: Current conversation history the agent will receive - instruction: Optional instruction from the orchestrator (e.g., manager in GroupChat) - metadata: Builder-specific context (stores internal state for resume) + messages: list[ChatMessage]: Additional messages provided by users. If empty, + the agent response is approved as-is. """ - target_agent_id: str | None - conversation: list[ChatMessage] = field(default_factory=lambda: []) - instruction: str | None = None - metadata: dict[str, Any] = field(default_factory=lambda: {}) - - -# Keep legacy name as alias for backward compatibility -AgentResponseReviewRequest = AgentInputRequest + messages: list[ChatMessage] + @staticmethod + def from_messages(messages: list[ChatMessage]) -> "AgentRequestInfoResponse": + """Create an AgentRequestInfoResponse from a list of ChatMessages. -DEFAULT_REQUEST_INFO_ID = "request_info_interceptor" - + Args: + messages: List of ChatMessage instances provided by users. -class RequestInfoInterceptor(Executor): - """Internal executor that pauses workflow for human input before agent runs. + Returns: + AgentRequestInfoResponse instance. + """ + return AgentRequestInfoResponse(messages=messages) - This executor is inserted into the workflow graph by builders when - `.with_request_info()` is called. It intercepts AgentExecutorRequest messages - BEFORE the agent runs and pauses the workflow via `ctx.request_info()` with - an AgentInputRequest. + @staticmethod + def from_strings(texts: list[str]) -> "AgentRequestInfoResponse": + """Create an AgentRequestInfoResponse from a list of string messages. - When a response is received, the response handler injects the input - as a user message into the conversation and forwards the request to the agent. + Args: + texts: List of text messages provided by users. - The optional `agent_filter` parameter allows limiting which agents trigger the pause. - If the target agent's ID is not in the filter set, the request is forwarded - without pausing. - """ + Returns: + AgentRequestInfoResponse instance. 
+ """ + return AgentRequestInfoResponse(messages=[ChatMessage(role=Role.USER, text=text) for text in texts]) - def __init__( - self, - executor_id: str | None = None, - agent_filter: set[str] | None = None, - ) -> None: - """Initialize the request info interceptor executor. + @staticmethod + def approve() -> "AgentRequestInfoResponse": + """Create an AgentRequestInfoResponse that approves the original agent response. - Args: - executor_id: ID for this executor. If None, generates a unique ID - using the format "request_info_interceptor-". - agent_filter: Optional set of agent/executor IDs to filter on. - If provided, only requests to these agents trigger a pause. - If None (default), all requests trigger a pause. - """ - if executor_id is None: - executor_id = f"{DEFAULT_REQUEST_INFO_ID}-{uuid.uuid4().hex[:8]}" - super().__init__(executor_id) - self._agent_filter = agent_filter - - def _should_pause_for_agent(self, agent_id: str | None) -> bool: - """Check if we should pause for the given agent ID.""" - if self._agent_filter is None: - return True - if agent_id is None: - return False - # Check both the full ID and any name portion after a prefix - # e.g., "groupchat_agent:writer" should match filter "writer" - if agent_id in self._agent_filter: - return True - # Extract name from prefixed IDs like "groupchat_agent:writer" or "request_info:writer" - if ":" in agent_id: - name_part = agent_id.split(":", 1)[1] - if name_part in self._agent_filter: - return True - return False - - def _extract_agent_name_from_executor_id(self) -> str | None: - """Extract the agent name from this interceptor's executor ID. - - The interceptor ID is typically "request_info:", so we - extract the agent name to determine which agent we're intercepting for. + Returns: + AgentRequestInfoResponse instance with no additional messages. """ - if ":" in self.id: - return self.id.split(":", 1)[1] - return None - - @handler - async def intercept_agent_request( - self, - request: AgentExecutorRequest, - ctx: WorkflowContext[AgentExecutorRequest, Any], - ) -> None: - """Intercept request before agent runs and pause for human input. + return AgentRequestInfoResponse(messages=[]) - Pauses the workflow and emits a RequestInfoEvent with the current - conversation for steering. If an agent filter is configured and this - agent is not in the filter, the request is forwarded without pausing. 
-        Args:
-            request: The request about to be sent to the agent
-            ctx: Workflow context for requesting info
-        """
-        # Determine the target agent from our executor ID
-        target_agent = self._extract_agent_name_from_executor_id()
-
-        # Check if we should pause for this agent
-        if not self._should_pause_for_agent(target_agent):
-            logger.debug(f"Skipping request_info pause for agent {target_agent} (not in filter)")
-            await ctx.send_message(request)
-            return
-
-        conversation = list(request.messages or [])
-
-        input_request = AgentInputRequest(
-            target_agent_id=target_agent,
-            conversation=conversation,
-            instruction=None,  # Could be extended to include manager instruction
-            metadata={"_original_request": request, "_input_type": "AgentExecutorRequest"},
-        )
-        await ctx.request_info(input_request, str)
+class AgentRequestInfoExecutor(Executor):
+    """Executor for gathering request info from users to assist agents."""

     @handler
-    async def intercept_conversation(
+    async def request_info(self, agent_response: AgentExecutorResponse, ctx: WorkflowContext) -> None:
+        """Handle the agent's response and gather additional info from users."""
+        await ctx.request_info(agent_response, AgentRequestInfoResponse)
+
+    @response_handler
+    async def handle_request_info_response(
         self,
-        messages: list[ChatMessage],
-        ctx: WorkflowContext[list[ChatMessage], Any],
+        original_request: AgentExecutorResponse,
+        response: AgentRequestInfoResponse,
+        ctx: WorkflowContext[AgentExecutorRequest, AgentExecutorResponse],
     ) -> None:
-        """Intercept conversation before agent runs (used by SequentialBuilder).
+        """Process the additional info provided by users."""
+        if response.messages:
+            # User provided additional messages; iterate further on the agent response
+            await ctx.send_message(AgentExecutorRequest(messages=response.messages, should_respond=True))
+        else:
+            # No additional info; approve the original agent response
+            await ctx.yield_output(original_request)

-        SequentialBuilder passes list[ChatMessage] directly to agents. This handler
-        intercepts that flow and pauses for human input.

-        Args:
-            messages: The conversation about to be sent to the agent
-            ctx: Workflow context for requesting info
-        """
-        # Determine the target agent from our executor ID
-        target_agent = self._extract_agent_name_from_executor_id()
-
-        # Check if we should pause for this agent
-        if not self._should_pause_for_agent(target_agent):
-            logger.debug(f"Skipping request_info pause for agent {target_agent} (not in filter)")
-            await ctx.send_message(messages)
-            return
-
-        input_request = AgentInputRequest(
-            target_agent_id=target_agent,
-            conversation=list(messages),
-            instruction=None,
-            metadata={"_original_messages": messages, "_input_type": "list[ChatMessage]"},
-        )
-        await ctx.request_info(input_request, str)
+class AgentApprovalExecutor(WorkflowExecutor):
+    """Executor for enabling scenarios requiring agent approval in an orchestration.

-    @handler
-    async def intercept_concurrent_requests(
-        self,
-        requests: list[AgentExecutorRequest],
-        ctx: WorkflowContext[list[AgentExecutorRequest], Any],
-    ) -> None:
-        """Intercept requests before concurrent agents run.
+    This executor wraps a sub-workflow that contains two executors: an agent executor
+    and a request info executor. The agent executor generates the agent's responses,
+    while the request info executor gathers input from users to further iterate on the
+    agent's output or send the final response to downstream executors in the orchestration.
+ """ - This handler is used by ConcurrentBuilder to get human input before - all parallel agents execute. + def __init__(self, agent: AgentProtocol) -> None: + """Initialize the AgentApprovalExecutor. Args: - requests: List of requests for all concurrent agents - ctx: Workflow context for requesting info + agent: The agent protocol to use for generating responses. """ - # Combine conversations for display - combined_conversation: list[ChatMessage] = [] - if requests: - combined_conversation = list(requests[0].messages or []) - - input_request = AgentInputRequest( - target_agent_id=None, # Multiple agents - conversation=combined_conversation, - instruction=None, - metadata={"_original_requests": requests}, + super().__init__(workflow=self._build_workflow(agent), id=resolve_agent_id(agent), propagate_request=True) + self._description = agent.description + + def _build_workflow(self, agent: AgentProtocol) -> Workflow: + """Build the internal workflow for the AgentApprovalExecutor.""" + agent_executor = AgentExecutor(agent) + request_info_executor = AgentRequestInfoExecutor(id="agent_request_info_executor") + + return ( + WorkflowBuilder() + # Create a loop between agent executor and request info executor + .add_edge(agent_executor, request_info_executor) + .add_edge(request_info_executor, agent_executor) + .set_start_executor(agent_executor) + .build() ) - await ctx.request_info(input_request, str) - @response_handler - async def handle_input_response( - self, - original_request: AgentInputRequest, - # TODO(@moonbox3): Extend to support other content types - response: str, - ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], Any], - ) -> None: - """Handle the human input and forward the modified request to the agent. - - Injects the response as a user message into the conversation - and forwards the modified request to the agent. - - Args: - original_request: The AgentInputRequest that triggered the pause - response: The human input text - ctx: Workflow context for continuing the workflow - - TODO: Consider having each orchestration implement its own response handler - for more specialized behavior. 
- """ - human_message = ChatMessage(role=Role.USER, text=response) - - # Handle concurrent case (list of AgentExecutorRequest) - original_requests: list[AgentExecutorRequest] | None = original_request.metadata.get("_original_requests") - if original_requests is not None: - updated_requests: list[AgentExecutorRequest] = [] - for orig_req in original_requests: - messages = list(orig_req.messages or []) - messages.append(human_message) - updated_requests.append( - AgentExecutorRequest( - messages=messages, - should_respond=orig_req.should_respond, - ) - ) - - logger.debug( - f"Human input received for concurrent workflow, " - f"continuing with {len(updated_requests)} updated requests" - ) - await ctx.send_message(updated_requests) # type: ignore[arg-type] - return - - # Handle list[ChatMessage] case (SequentialBuilder) - original_messages: list[ChatMessage] | None = original_request.metadata.get("_original_messages") - if original_messages is not None: - messages = list(original_messages) - messages.append(human_message) - - logger.debug( - f"Human input received for agent {original_request.target_agent_id}, " - f"forwarding conversation with steering context" - ) - await ctx.send_message(messages) - return - - # Handle AgentExecutorRequest case (GroupChatBuilder) - orig_request: AgentExecutorRequest | None = original_request.metadata.get("_original_request") - if orig_request is not None: - messages = list(orig_request.messages or []) - messages.append(human_message) - - updated_request = AgentExecutorRequest( - messages=messages, - should_respond=orig_request.should_respond, - ) - - logger.debug( - f"Human input received for agent {original_request.target_agent_id}, " - f"forwarding request with steering context" - ) - await ctx.send_message(updated_request) - return - - logger.error("Input response handler missing original request/messages in metadata") - raise RuntimeError("Missing original request or messages in AgentInputRequest metadata") + @property + def description(self) -> str | None: + """Get a description of the underlying agent.""" + return self._description diff --git a/python/packages/core/agent_framework/_workflows/_orchestration_state.py b/python/packages/core/agent_framework/_workflows/_orchestration_state.py index 26c0068e7a..8210d7d4bb 100644 --- a/python/packages/core/agent_framework/_workflows/_orchestration_state.py +++ b/python/packages/core/agent_framework/_workflows/_orchestration_state.py @@ -47,6 +47,7 @@ class OrchestrationState: conversation: list[ChatMessage] = field(default_factory=_new_chat_message_list) round_index: int = 0 + orchestrator_name: str = "" metadata: dict[str, Any] = field(default_factory=_new_metadata_dict) task: ChatMessage | None = None diff --git a/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py b/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py index 14fd68fa46..edcffaa530 100644 --- a/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py +++ b/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py @@ -7,13 +7,9 @@ """ import logging -from typing import TYPE_CHECKING, Any from .._types import ChatMessage, Role -if TYPE_CHECKING: - from ._group_chat import _GroupChatRequestMessage # type: ignore[reportPrivateUsage] - logger = logging.getLogger(__name__) @@ -99,107 +95,3 @@ def create_completion_message( text=message_text, author_name=author_name, ) - - -def prepare_participant_request( - *, - participant_name: str, - conversation: list[ChatMessage], - 
instruction: str | None = None, - task: ChatMessage | None = None, - metadata: dict[str, Any] | None = None, -) -> "_GroupChatRequestMessage": - """Create a standardized participant request message. - - Simple helper to avoid duplicating request construction. - - Args: - participant_name: Name of the target participant - conversation: Conversation history to send - instruction: Optional instruction from manager/orchestrator - task: Optional task context - metadata: Optional metadata dict - - Returns: - GroupChatRequestMessage ready to send - """ - # Import here to avoid circular dependency - from ._group_chat import _GroupChatRequestMessage # type: ignore[reportPrivateUsage] - - return _GroupChatRequestMessage( - agent_name=participant_name, - conversation=list(conversation), - instruction=instruction or "", - task=task, - metadata=metadata, - ) - - -class ParticipantRegistry: - """Simple registry for tracking participant executor IDs and routing info. - - Provides a clean interface for the common pattern of mapping participant names - to executor IDs and tracking which are agents vs custom executors. - - Tracks both entry IDs (where to send requests) and exit IDs (where responses - come from) to support pipeline configurations where these differ. - """ - - def __init__(self) -> None: - self._participant_entry_ids: dict[str, str] = {} - self._agent_executor_ids: dict[str, str] = {} - self._executor_id_to_participant: dict[str, str] = {} - self._non_agent_participants: set[str] = set() - - def register( - self, - name: str, - *, - entry_id: str, - is_agent: bool, - exit_id: str | None = None, - ) -> None: - """Register a participant's routing information. - - Args: - name: Participant name - entry_id: Executor ID for this participant's entry point (where to send) - is_agent: Whether this is an AgentExecutor (True) or custom Executor (False) - exit_id: Executor ID for this participant's exit point (where responses come from). - If None, defaults to entry_id (single-executor pipeline). 
- """ - self._participant_entry_ids[name] = entry_id - actual_exit_id = exit_id if exit_id is not None else entry_id - - if is_agent: - self._agent_executor_ids[name] = entry_id - # Map both entry and exit IDs to participant name for response routing - self._executor_id_to_participant[entry_id] = name - if actual_exit_id != entry_id: - self._executor_id_to_participant[actual_exit_id] = name - else: - self._non_agent_participants.add(name) - - def get_entry_id(self, name: str) -> str | None: - """Get the entry executor ID for a participant name.""" - return self._participant_entry_ids.get(name) - - def get_participant_name(self, executor_id: str) -> str | None: - """Get the participant name for an executor ID (agents only).""" - return self._executor_id_to_participant.get(executor_id) - - def is_agent(self, name: str) -> bool: - """Check if a participant is an agent (vs custom executor).""" - return name in self._agent_executor_ids - - def is_registered(self, name: str) -> bool: - """Check if a participant is registered.""" - return name in self._participant_entry_ids - - def is_participant_registered(self, name: str) -> bool: - """Check if a participant is registered (alias for is_registered for compatibility).""" - return self.is_registered(name) - - def all_participants(self) -> set[str]: - """Get all registered participant names.""" - return set(self._participant_entry_ids.keys()) diff --git a/python/packages/core/agent_framework/_workflows/_participant_utils.py b/python/packages/core/agent_framework/_workflows/_participant_utils.py deleted file mode 100644 index a6f1cf2a84..0000000000 --- a/python/packages/core/agent_framework/_workflows/_participant_utils.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Shared participant helpers for orchestration builders.""" - -import re -from collections.abc import Callable, Iterable, Mapping -from dataclasses import dataclass -from typing import Any - -from .._agents import AgentProtocol -from ._agent_executor import AgentExecutor -from ._executor import Executor - - -@dataclass -class GroupChatParticipantSpec: - """Metadata describing a single participant in group chat orchestrations. - - Used by multiple orchestration patterns (GroupChat, Handoff, Magentic) to describe - participants with consistent structure across different workflow types. - - Attributes: - name: Unique identifier for the participant used by managers for selection - participant: AgentProtocol or Executor instance representing the participant - description: Human-readable description provided to managers for selection context - """ - - name: str - participant: AgentProtocol | Executor - description: str - - -_SANITIZE_PATTERN = re.compile(r"[^0-9a-zA-Z]+") - - -def sanitize_identifier(value: str, *, default: str = "agent") -> str: - """Return a deterministic, lowercase identifier derived from `value`.""" - cleaned = _SANITIZE_PATTERN.sub("_", value).strip("_") - if not cleaned: - cleaned = default - if cleaned[0].isdigit(): - cleaned = f"{default}_{cleaned}" - return cleaned.lower() - - -def wrap_participant(participant: AgentProtocol | Executor, *, executor_id: str | None = None) -> Executor: - """Represent `participant` as an `Executor`.""" - if isinstance(participant, Executor): - return participant - - if not isinstance(participant, AgentProtocol): - raise TypeError( - f"Participants must implement AgentProtocol or be Executor instances. Got {type(participant).__name__}." 
- ) - - executor_id = executor_id or participant.display_name - return AgentExecutor(participant, id=executor_id) - - -def participant_description(participant: AgentProtocol | Executor, fallback: str) -> str: - """Produce a human-readable description for manager context.""" - if isinstance(participant, Executor): - description = getattr(participant, "description", None) - if isinstance(description, str) and description.strip(): - return description.strip() - return fallback - description = getattr(participant, "description", None) - if isinstance(description, str) and description.strip(): - return description.strip() - return fallback - - -def build_alias_map(participant: AgentProtocol | Executor, executor: Executor) -> dict[str, str]: - """Collect canonical and sanitised aliases that should resolve to `executor`.""" - aliases: dict[str, str] = {} - - def _register(values: Iterable[str | None]) -> None: - for value in values: - if not value: - continue - key = str(value) - if key not in aliases: - aliases[key] = executor.id - sanitized = sanitize_identifier(key) - if sanitized not in aliases: - aliases[sanitized] = executor.id - - _register([executor.id]) - - if isinstance(participant, AgentProtocol): - name = getattr(participant, "name", None) - display = getattr(participant, "display_name", None) - _register([name, display]) - else: - display = getattr(participant, "display_name", None) - _register([display]) - - return aliases - - -def merge_alias_maps(maps: Iterable[Mapping[str, str]]) -> dict[str, str]: - """Merge alias mappings, preserving the first occurrence of each alias.""" - merged: dict[str, str] = {} - for mapping in maps: - for key, value in mapping.items(): - merged.setdefault(key, value) - return merged - - -def prepare_participant_metadata( - participants: Mapping[str, AgentProtocol | Executor], - *, - executor_id_factory: Callable[[str, AgentProtocol | Executor], str | None] | None = None, - description_factory: Callable[[str, AgentProtocol | Executor], str] | None = None, -) -> dict[str, dict[str, Any]]: - """Return metadata dicts for participants keyed by participant name.""" - executors: dict[str, Executor] = {} - descriptions: dict[str, str] = {} - alias_maps: list[Mapping[str, str]] = [] - - for name, participant in participants.items(): - desired_id = executor_id_factory(name, participant) if executor_id_factory else None - executor = wrap_participant(participant, executor_id=desired_id) - fallback_description = description_factory(name, participant) if description_factory else executor.id - descriptions[name] = participant_description(participant, fallback_description) - executors[name] = executor - alias_maps.append(build_alias_map(participant, executor)) - - aliases = merge_alias_maps(alias_maps) - return { - "executors": executors, - "descriptions": descriptions, - "aliases": aliases, - } diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 8cc01c23cf..227f0f7fe7 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -7,11 +7,20 @@ from typing import Any from ._checkpoint import CheckpointStorage, WorkflowCheckpoint -from ._checkpoint_encoding import DATACLASS_MARKER, MODEL_MARKER, decode_checkpoint_value +from ._checkpoint_encoding import ( + DATACLASS_MARKER, + MODEL_MARKER, + decode_checkpoint_value, +) from ._const import EXECUTOR_STATE_KEY from ._edge import EdgeGroup from ._edge_runner 
import EdgeRunner, create_edge_runner from ._events import SuperStepCompletedEvent, SuperStepStartedEvent, WorkflowEvent +from ._exceptions import ( + WorkflowCheckpointException, + WorkflowConvergenceException, + WorkflowRunnerException, +) from ._executor import Executor from ._runner_context import ( Message, @@ -72,7 +81,7 @@ def reset_iteration_count(self) -> None: async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: """Run the workflow until no more messages are sent.""" if self._running: - raise RuntimeError("Runner is already running.") + raise WorkflowRunnerException("Runner is already running.") self._running = True try: @@ -134,12 +143,9 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: break if self._iteration >= self._max_iterations and await self._ctx.has_messages(): - raise RuntimeError(f"Runner did not converge after {self._max_iterations} iterations.") + raise WorkflowConvergenceException(f"Runner did not converge after {self._max_iterations} iterations.") logger.info(f"Workflow completed after {self._iteration} supersteps") - # TODO(@taochen): iteration is reset to zero, even in the event of a request info event. - # Should iteration be preserved in the event of a request info event? - self._iteration = 0 self._resumed_from_checkpoint = False # Reset resume flag for next run finally: self._running = False @@ -168,7 +174,8 @@ def _normalize_message_payload(message: Message) -> None: # Route all messages through normal workflow edges associated_edge_runners = self._edge_runner_map.get(source_executor_id, []) if not associated_edge_runners: - logger.warning(f"No outgoing edges found for executor {source_executor_id}; dropping messages.") + # This is expected for terminal nodes (e.g., EndWorkflow, last action in workflow) + logger.debug(f"No outgoing edges found for executor {source_executor_id}; dropping messages.") return for message in messages: @@ -211,7 +218,7 @@ async def restore_from_checkpoint( self, checkpoint_id: str, checkpoint_storage: CheckpointStorage | None = None, - ) -> bool: + ) -> None: """Restore workflow state from a checkpoint. Args: @@ -220,7 +227,10 @@ async def restore_from_checkpoint( runner context itself is not configured with checkpointing. Returns: - True if restoration was successful, False otherwise + None on success. + + Raises: + WorkflowCheckpointException on failure. """ try: # Load the checkpoint @@ -230,18 +240,19 @@ async def restore_from_checkpoint( elif checkpoint_storage is not None: checkpoint = await checkpoint_storage.load_checkpoint(checkpoint_id) else: - logger.warning("Context does not support checkpointing and no external storage was provided") - return False + raise WorkflowCheckpointException( + "Cannot load checkpoint: no checkpointing configured in context or external storage provided." + ) if not checkpoint: logger.error(f"Checkpoint {checkpoint_id} not found") - return False + raise WorkflowCheckpointException(f"Checkpoint {checkpoint_id} not found") # Validate the loaded checkpoint against the workflow graph_hash = getattr(self, "graph_signature_hash", None) checkpoint_hash = (checkpoint.metadata or {}).get("graph_signature") if graph_hash and checkpoint_hash and graph_hash != checkpoint_hash: - raise ValueError( + raise WorkflowCheckpointException( "Workflow graph has changed since the checkpoint was created. " "Please rebuild the original workflow before resuming." 
) @@ -262,12 +273,11 @@ async def restore_from_checkpoint( self._mark_resumed(checkpoint.iteration_count) logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}") - return True - except ValueError: + except WorkflowCheckpointException: raise except Exception as e: logger.error(f"Failed to restore from checkpoint {checkpoint_id}: {e}") - return False + raise WorkflowCheckpointException(f"Failed to restore from checkpoint {checkpoint_id}") from e async def _save_executor_states(self) -> None: """Populate executor state by calling checkpoint hooks on executors. @@ -308,7 +318,7 @@ async def _save_executor_states(self) -> None: try: state_dict = await executor.on_checkpoint_save() except Exception as ex: # pragma: no cover - raise ValueError(f"Executor {exec_id} on_checkpoint_save failed: {ex}") from ex + raise WorkflowCheckpointException(f"Executor {exec_id} on_checkpoint_save failed") from ex try: await self._set_executor_state(exec_id, state_dict) @@ -334,17 +344,19 @@ async def _restore_executor_states(self) -> None: executor_states = await self._shared_state.get(EXECUTOR_STATE_KEY) if not isinstance(executor_states, dict): - raise ValueError("Executor states in shared state is not a dictionary. Unable to restore.") + raise WorkflowCheckpointException("Executor states in shared state is not a dictionary. Unable to restore.") for executor_id, state in executor_states.items(): # pyright: ignore[reportUnknownVariableType] if not isinstance(executor_id, str): - raise ValueError("Executor ID in executor states is not a string. Unable to restore.") + raise WorkflowCheckpointException("Executor ID in executor states is not a string. Unable to restore.") if not isinstance(state, dict) or not all(isinstance(k, str) for k in state): # pyright: ignore[reportUnknownVariableType] - raise ValueError(f"Executor state for {executor_id} is not a dict[str, Any]. Unable to restore.") + raise WorkflowCheckpointException( + f"Executor state for {executor_id} is not a dict[str, Any]. Unable to restore." 
+ ) executor = self._executors.get(executor_id) if not executor: - raise ValueError(f"Executor {executor_id} not found during state restoration.") + raise WorkflowCheckpointException(f"Executor {executor_id} not found during state restoration.") # Try backward compatibility behavior first # TODO(@taochen): Remove backward compatibility @@ -357,7 +369,7 @@ async def _restore_executor_states(self) -> None: await maybe # type: ignore[arg-type] restored = True except Exception as ex: # pragma: no cover - defensive - raise ValueError(f"Executor {executor_id} restore_state failed: {ex}") from ex + raise WorkflowCheckpointException(f"Executor {executor_id} restore_state failed") from ex if not restored: # Try the updated behavior only if backward compatibility did not restore @@ -365,7 +377,7 @@ async def _restore_executor_states(self) -> None: await executor.on_checkpoint_restore(state) # pyright: ignore[reportUnknownArgumentType] restored = True except Exception as ex: # pragma: no cover - defensive - raise ValueError(f"Executor {executor_id} on_checkpoint_restore failed: {ex}") from ex + raise WorkflowCheckpointException(f"Executor {executor_id} on_checkpoint_restore failed") from ex if not restored: logger.debug(f"Executor {executor_id} does not support state restoration; skipping.") @@ -408,7 +420,7 @@ async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> existing_states = {} if not isinstance(existing_states, dict): - raise ValueError("Existing executor states in shared state is not a dictionary.") + raise WorkflowCheckpointException("Existing executor states in shared state is not a dictionary.") existing_states[executor_id] = state await self._shared_state.set(EXECUTOR_STATE_KEY, existing_states) diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index 00318d7021..62f3836617 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -13,6 +13,7 @@ from ._const import INTERNAL_SOURCE_ID from ._events import RequestInfoEvent, WorkflowEvent from ._shared_state import SharedState +from ._typing_utils import is_instance_of logger = logging.getLogger(__name__) @@ -44,7 +45,7 @@ class Message: source_span_ids: list[str] | None = None # Publishing span IDs for linking from multiple sources # For response messages, the original request data - original_request: Any = None + original_request_info_event: RequestInfoEvent | None = None # Backward compatibility properties @property @@ -66,7 +67,7 @@ def to_dict(self) -> dict[str, Any]: "type": self.type.value, "trace_contexts": self.trace_contexts, "source_span_ids": self.source_span_ids, - "original_request": self.original_request, + "original_request_info_event": encode_checkpoint_value(self.original_request_info_event), } @staticmethod @@ -86,7 +87,7 @@ def from_dict(data: dict[str, Any]) -> "Message": type=MessageType(data.get("type", "standard")), trace_contexts=data.get("trace_contexts"), source_span_ids=data.get("source_span_ids"), - original_request=data.get("original_request"), + original_request_info_event=decode_checkpoint_value(data.get("original_request_info_event")), ) @@ -413,16 +414,7 @@ async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None: self._messages.clear() messages_data = checkpoint.messages for source_id, message_list in messages_data.items(): - self._messages[source_id] = [ - Message( - 
data=decode_checkpoint_value(msg.get("data")), - source_id=msg.get("source_id", ""), - target_id=msg.get("target_id"), - trace_contexts=msg.get("trace_contexts"), - source_span_ids=msg.get("source_span_ids"), - ) - for msg in message_list - ] + self._messages[source_id] = [Message.from_dict(msg) for msg in message_list] # Restore pending request info events self._pending_request_info_events.clear() @@ -493,7 +485,7 @@ async def send_request_info_response(self, request_id: str, response: Any) -> No raise ValueError(f"No pending request found for request_id: {request_id}") # Validate response type if specified - if event.response_type and not isinstance(response, event.response_type): + if event.response_type and not is_instance_of(response, event.response_type): raise TypeError( f"Response type mismatch for request_id {request_id}: " f"expected {event.response_type.__name__}, got {type(response).__name__}" @@ -505,7 +497,7 @@ async def send_request_info_response(self, request_id: str, response: Any) -> No source_id=INTERNAL_SOURCE_ID(event.source_executor_id), target_id=event.source_executor_id, type=MessageType.RESPONSE, - original_request=event.data, + original_request_info_event=event, ) await self.send_message(response_msg) diff --git a/python/packages/core/agent_framework/_workflows/_sequential.py b/python/packages/core/agent_framework/_workflows/_sequential.py index 24ae4cda29..11c123d153 100644 --- a/python/packages/core/agent_framework/_workflows/_sequential.py +++ b/python/packages/core/agent_framework/_workflows/_sequential.py @@ -47,13 +47,14 @@ AgentExecutor, AgentExecutorResponse, ) +from ._agent_utils import resolve_agent_id from ._checkpoint import CheckpointStorage from ._executor import ( Executor, handler, ) from ._message_utils import normalize_messages_input -from ._orchestration_request_info import RequestInfoInterceptor +from ._orchestration_request_info import AgentApprovalExecutor from ._workflow import Workflow from ._workflow_builder import WorkflowBuilder from ._workflow_context import WorkflowContext @@ -77,24 +78,33 @@ async def from_messages(self, messages: list[str | ChatMessage], ctx: WorkflowCo await ctx.send_message(normalize_messages_input(messages)) -class _ResponseToConversation(Executor): - """Converts AgentExecutorResponse to list[ChatMessage] conversation for chaining.""" - - @handler - async def convert(self, response: AgentExecutorResponse, ctx: WorkflowContext[list[ChatMessage]]) -> None: - # Always use full_conversation; AgentExecutor guarantees it is populated. - if response.full_conversation is None: # Defensive: indicates a contract violation - raise RuntimeError("AgentExecutorResponse.full_conversation missing. AgentExecutor must populate it.") - await ctx.send_message(list(response.full_conversation)) - - class _EndWithConversation(Executor): """Terminates the workflow by emitting the final conversation context.""" @handler - async def end(self, conversation: list[ChatMessage], ctx: WorkflowContext[Any, list[ChatMessage]]) -> None: + async def end_with_messages( + self, + conversation: list[ChatMessage], + ctx: WorkflowContext[Any, list[ChatMessage]], + ) -> None: + """Handler for ending with a list of ChatMessage. + + This is used when the last participant is a custom executor. 
+ """ await ctx.yield_output(list(conversation)) + @handler + async def end_with_agent_executor_response( + self, + response: AgentExecutorResponse, + ctx: WorkflowContext[Any, list[ChatMessage] | None], + ) -> None: + """Handle case where last participant is an agent. + + The agent is wrapped by AgentExecutor and emits AgentExecutorResponse. + """ + await ctx.yield_output(response.full_conversation) + class SequentialBuilder: r"""High-level builder for sequential agent/executor workflows with shared context. @@ -206,44 +216,65 @@ def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "Sequenti def with_request_info( self, *, - agents: Sequence[str | AgentProtocol | Executor] | None = None, + agents: Sequence[str | AgentProtocol] | None = None, ) -> "SequentialBuilder": - """Enable request info before agents run in the workflow. - - When enabled, the workflow pauses before each agent runs, emitting - a RequestInfoEvent that allows the caller to review the conversation and - optionally inject guidance before the agent responds. The caller provides - input via the standard response_handler/request_info pattern. + """Enable request info after agent participant responses. - Args: - agents: Optional filter - only pause before these specific agents/executors. - Accepts agent names (str), agent instances, or executor instances. - If None (default), pauses before every agent. - - Returns: - self: The builder instance for fluent chaining. + This enables human-in-the-loop (HIL) scenarios for the sequential orchestration. + When enabled, the workflow pauses after each agent participant runs, emitting + a RequestInfoEvent that allows the caller to review the conversation and optionally + inject guidance for the agent participant to iterate. The caller provides input via + the standard response_handler/request_info pattern. - Example: + Simulated flow with HIL: + Input -> [Agent Participant <-> Request Info] -> [Agent Participant <-> Request Info] -> ... - .. code-block:: python + Note: This is only available for agent participants. Executor participants can incorporate + request info handling in their own implementation if desired. - # Pause before all agents - workflow = SequentialBuilder().participants([a1, a2]).with_request_info().build() + Args: + agents: Optional list of agents names or agent factories to enable request info for. + If None, enables HIL for all agent participants. - # Pause only before specific agents - workflow = ( - SequentialBuilder() - .participants([drafter, reviewer, finalizer]) - .with_request_info(agents=[reviewer]) # Only pause before reviewer - .build() - ) + Returns: + Self for fluent chaining """ from ._orchestration_request_info import resolve_request_info_filter self._request_info_enabled = True self._request_info_filter = resolve_request_info_filter(list(agents) if agents else None) + return self + def _resolve_participants(self) -> list[Executor]: + """Resolve participant instances into Executor objects.""" + participants: list[Executor | AgentProtocol] = [] + if self._participant_factories: + # Resolve the participant factories now. This doesn't break the factory pattern + # since the Sequential builder still creates new instances per workflow build. 
+ for factory in self._participant_factories: + p = factory() + participants.append(p) + else: + participants = self._participants + + executors: list[Executor] = [] + for p in participants: + if isinstance(p, Executor): + executors.append(p) + elif isinstance(p, AgentProtocol): + if self._request_info_enabled and ( + not self._request_info_filter or resolve_agent_id(p) in self._request_info_filter + ): + # Handle request info enabled agents + executors.append(AgentApprovalExecutor(p)) + else: + executors.append(AgentExecutor(p)) + else: + raise TypeError(f"Participants must be AgentProtocol or Executor instances. Got {type(p).__name__}.") + + return executors + def build(self) -> Workflow: """Build and validate the sequential workflow. @@ -272,48 +303,17 @@ def build(self) -> Workflow: input_conv = _InputToConversation(id="input-conversation") end = _EndWithConversation(id="end") + # Resolve participants and participant factories to executors + participants: list[Executor] = self._resolve_participants() + builder = WorkflowBuilder() builder.set_start_executor(input_conv) # Start of the chain is the input normalizer prior: Executor | AgentProtocol = input_conv - - participants: list[Executor | AgentProtocol] = [] - if self._participant_factories: - # Resolve the participant factories now. This doesn't break the factory pattern - # since the Sequential builder still creates new instances per workflow build. - for factory in self._participant_factories: - p = factory() - participants.append(p) - else: - participants = self._participants - for p in participants: - if isinstance(p, (AgentProtocol, AgentExecutor)): - label = p.id if isinstance(p, AgentExecutor) else p.display_name - - if self._request_info_enabled: - # Insert request info interceptor BEFORE the agent - interceptor = RequestInfoInterceptor( - executor_id=f"request_info:{label}", - agent_filter=self._request_info_filter, - ) - builder.add_edge(prior, interceptor) - builder.add_edge(interceptor, p) - else: - builder.add_edge(prior, p) - - resp_to_conv = _ResponseToConversation(id=f"to-conversation:{label}") - builder.add_edge(p, resp_to_conv) - prior = resp_to_conv - elif isinstance(p, Executor): - # Custom executor operates on list[ChatMessage] - # If the executor doesn't handle list[ChatMessage] correctly, validation will fail - builder.add_edge(prior, p) - prior = p - else: - raise TypeError(f"Unsupported participant type: {type(p).__name__}") - + builder.add_edge(prior, p) + prior = p # Terminate with the final conversation builder.add_edge(prior, end) diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index 7b446926fc..d6c612bff6 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -5,7 +5,6 @@ import hashlib import json import logging -import sys import uuid from collections.abc import AsyncIterable, Awaitable, Callable from typing import Any @@ -34,12 +33,7 @@ from ._runner import Runner from ._runner_context import RunnerContext from ._shared_state import SharedState - -if sys.version_info >= (3, 11): - pass # pragma: no cover -else: - pass # pragma: no cover - +from ._typing_utils import is_instance_of logger = logging.getLogger(__name__) @@ -425,10 +419,7 @@ async def _execute_with_message_or_checkpoint( "or build workflow with WorkflowBuilder.with_checkpointing(checkpoint_storage)." 
) - restored = await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage) - - if not restored: - raise RuntimeError(f"Failed to restore from checkpoint: {checkpoint_id}") + await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage) # Handle initial message elif message is not None: @@ -734,7 +725,7 @@ async def _send_responses_internal(self, responses: dict[str, Any]) -> None: if request_id not in pending_requests: raise ValueError(f"Response provided for unknown request ID: {request_id}") pending_request = pending_requests[request_id] - if not isinstance(response, pending_request.response_type): + if not is_instance_of(response, pending_request.response_type): raise ValueError( f"Response type mismatch for request ID {request_id}: " f"expected {pending_request.response_type}, got {type(response)}" diff --git a/python/packages/core/agent_framework/_workflows/_workflow_builder.py b/python/packages/core/agent_framework/_workflows/_workflow_builder.py index 60c959823f..8cc31e2cc9 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_builder.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_builder.py @@ -18,6 +18,7 @@ from ._edge import ( Case, Default, + EdgeCondition, EdgeGroup, FanInEdgeGroup, FanOutEdgeGroup, @@ -48,12 +49,12 @@ class _EdgeRegistration: Args: source: The registered source name. target: The registered target name. - condition: An optional condition function for the edge. + condition: An optional condition function `(data) -> bool | Awaitable[bool]`. """ source: str target: str - condition: Callable[[Any], bool] | None = None + condition: EdgeCondition | None = None @dataclass @@ -222,7 +223,7 @@ def _maybe_wrap_agent( Args: candidate: The executor or agent to wrap. agent_thread: The thread to use for running the agent. If None, a new thread will be created. - output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes. + output_response: Whether to yield an AgentResponse as a workflow output when the agent completes. executor_id: A unique identifier for the executor. If None, the agent's name will be used if available. """ try: # Local import to avoid hard dependency at import time @@ -351,7 +352,7 @@ def register_agent( the agent's internal name. But it must be unique within the workflow. agent_thread: The thread to use for running the agent. If None, a new thread will be created when the agent is instantiated. - output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes. + output_response: Whether to yield an AgentResponse as a workflow output when the agent completes. Example: .. code-block:: python @@ -410,7 +411,7 @@ def add_agent( Args: agent: The agent to add to the workflow. agent_thread: The thread to use for running the agent. If None, a new thread will be created. - output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes. + output_response: Whether to yield an AgentResponse as a workflow output when the agent completes. id: A unique identifier for the executor. If None, the agent's name will be used if available. Returns: @@ -437,7 +438,10 @@ def add_agent( "Consider using register_agent() for lazy initialization instead." 
) executor = self._maybe_wrap_agent( - agent, agent_thread=agent_thread, output_response=output_response, executor_id=id + agent, + agent_thread=agent_thread, + output_response=output_response, + executor_id=id, ) self._add_executor(executor) return self @@ -446,7 +450,7 @@ def add_edge( self, source: Executor | AgentProtocol | str, target: Executor | AgentProtocol | str, - condition: Callable[[Any], bool] | None = None, + condition: EdgeCondition | None = None, ) -> Self: """Add a directed edge between two executors. @@ -456,13 +460,14 @@ def add_edge( Args: source: The source executor or registered name of the source factory for the edge. target: The target executor or registered name of the target factory for the edge. - condition: An optional condition function that determines whether the edge - should be traversed based on the message. + condition: An optional condition function `(data) -> bool | Awaitable[bool]` + that determines whether the edge should be traversed. + Example: `lambda data: data["ready"]`. - Note: If instances are provided for both source and target, they will be shared across - all workflow instances created from the built Workflow. To avoid this, consider - registering the executors and agents using `register_executor` and `register_agent` - and referencing them by factory name for lazy initialization instead. + Note: If instances are provided for both source and target, they will be shared across + all workflow instances created from the built Workflow. To avoid this, consider + registering the executors and agents using `register_executor` and `register_agent` + and referencing them by factory name for lazy initialization instead. Returns: Self: The WorkflowBuilder instance for method chaining. @@ -496,12 +501,6 @@ async def process(self, count: int, ctx: WorkflowContext[Never, str]) -> None: .build() ) - - # With a condition - def only_large_numbers(msg: int) -> bool: - return msg > 100 - - workflow = ( WorkflowBuilder() .register_executor(lambda: ProcessorA(id="a"), name="ProcessorA") @@ -529,7 +528,7 @@ def only_large_numbers(msg: int) -> bool: target_exec = self._maybe_wrap_agent(target) # type: ignore[arg-type] source_id = self._add_executor(source_exec) target_id = self._add_executor(target_exec) - self._edge_groups.append(SingleEdgeGroup(source_id, target_id, condition)) # type: ignore[call-arg] + self._edge_groups.append(SingleEdgeGroup(source_id, target_id, condition)) return self def add_fan_out_edges( @@ -1141,7 +1140,9 @@ async def process(self, text: str, ctx: WorkflowContext[Never, str]) -> None: self._checkpoint_storage = checkpoint_storage return self - def _resolve_edge_registry(self) -> tuple[Executor, list[Executor], list[EdgeGroup]]: + def _resolve_edge_registry( + self, + ) -> tuple[Executor, list[Executor], list[EdgeGroup]]: """Resolve deferred edge registrations into executors and edge groups.""" if not self._start_executor: raise ValueError("Starting executor must be set using set_start_executor before building the workflow.") @@ -1211,7 +1212,11 @@ def _get_executor(name: str) -> Executor: if start_executor is None: raise ValueError("Failed to resolve starting executor from registered factories.") - return start_executor, list(executor_id_to_instance.values()), deferred_edge_groups + return ( + start_executor, + list(executor_id_to_instance.values()), + deferred_edge_groups, + ) def build(self) -> Workflow: """Build and return the constructed workflow. 
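Editor's note: the `EdgeCondition` widening above means `add_edge` now documents conditions of shape `(data) -> bool | Awaitable[bool]`. A minimal, hedged sketch of both forms; `classifier`, `escalation`, and `responder` are hypothetical executors, and `WorkflowBuilder` is the class changed in this diff.

```python
# Sketch of the widened EdgeCondition contract: a conditional edge may use
# either a plain-bool or an awaitable predicate over the message data.
def build_routing_workflow(classifier, escalation, responder):
    def is_urgent(msg: dict) -> bool:  # synchronous predicate
        return msg.get("priority", 0) > 7

    async def passes_policy(msg: dict) -> bool:  # asynchronous predicate
        return "blocked" not in msg.get("text", "")  # stand-in for an awaited check

    return (
        WorkflowBuilder()
        .add_edge(classifier, escalation, condition=is_urgent)
        .add_edge(classifier, responder, condition=passes_policy)
        .set_start_executor(classifier)
        .build()
    )
```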
diff --git a/python/packages/core/agent_framework/_workflows/_workflow_context.py b/python/packages/core/agent_framework/_workflows/_workflow_context.py index cffeb02aa0..893f0ccfe9 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_context.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_context.py @@ -269,6 +269,7 @@ def __init__( runner_context: RunnerContext, trace_contexts: list[dict[str, str]] | None = None, source_span_ids: list[str] | None = None, + request_id: str | None = None, ): """Initialize the executor context with the given workflow context. @@ -281,6 +282,7 @@ def __init__( runner_context: The runner context that provides methods to send messages and events. trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation. source_span_ids: Optional source span IDs from multiple sources for linking (not for nesting). + request_id: Optional request ID if this context is for a `handle_response` handler. """ self._executor = executor self._executor_id = executor.id @@ -298,9 +300,21 @@ def __init__( self._trace_contexts = trace_contexts or [] self._source_span_ids = source_span_ids or [] + # request info related + self._request_id: str | None = request_id + if not self._source_executor_ids: raise ValueError("source_executor_ids cannot be empty. At least one source executor ID is required.") + @property + def request_id(self) -> str | None: + """Get the request ID if this context is for a `handle_response` handler. + + Returns: + The request ID string or None if not applicable. + """ + return self._request_id + async def send_message(self, message: T_Out, target_id: str | None = None) -> None: """Send a message to the workflow context. @@ -344,7 +358,7 @@ async def yield_output(self, output: T_W_Out) -> None: self._yielded_outputs.append(copy.deepcopy(output)) with _framework_event_origin(): - event = WorkflowOutputEvent(data=output, source_executor_id=self._executor_id) + event = WorkflowOutputEvent(data=output, executor_id=self._executor_id) await self._runner_context.add_event(event) async def add_event(self, event: WorkflowEvent) -> None: @@ -361,7 +375,7 @@ async def add_event(self, event: WorkflowEvent) -> None: return await self._runner_context.add_event(event) - async def request_info(self, request_data: object, response_type: type) -> None: + async def request_info(self, request_data: object, response_type: type, *, request_id: str | None = None) -> None: """Request information from outside of the workflow. Calling this method will cause the workflow to emit a RequestInfoEvent, carrying the @@ -374,6 +388,8 @@ async def request_info(self, request_data: object, response_type: type) -> None: Args: request_data: The data associated with the information request. response_type: The expected type of the response, used for validation. + request_id: Optional unique identifier for the request. If not provided, + a new UUID will be generated. This allows executors to track requests and responses. 
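Editor's note: a hedged sketch of what the new `request_id` parameter and `ctx.request_id` property enable together: an executor can mint its own correlation ID and read it back in its response handler. `ApprovalGate` and `ApprovalRequest` are hypothetical; `Executor`, `handler`, `response_handler`, and `WorkflowContext` are assumed to be imported from this package as in the modules above.

```python
# Hypothetical executor using the explicit request_id plumbing added here.
from dataclasses import dataclass


@dataclass
class ApprovalRequest:  # illustrative request payload
    text: str


class ApprovalGate(Executor):
    @handler
    async def ask(self, text: str, ctx: WorkflowContext) -> None:
        # Supply a deterministic correlation ID instead of the default UUID.
        await ctx.request_info(ApprovalRequest(text), str, request_id=f"gate:{text[:8]}")

    @response_handler
    async def on_reply(self, original: ApprovalRequest, reply: str, ctx: WorkflowContext) -> None:
        # ctx.request_id exposes the same ID supplied above.
        assert ctx.request_id is not None and ctx.request_id.startswith("gate:")
        await ctx.send_message(reply)
```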
""" request_type: type = type(request_data) if not self._executor.is_request_supported(request_type, response_type): @@ -385,7 +401,7 @@ async def request_info(self, request_data: object, response_type: type) -> None: ) request_info_event = RequestInfoEvent( - request_id=str(uuid.uuid4()), + request_id=request_id or str(uuid.uuid4()), source_executor_id=self._executor_id, request_data=request_data, response_type=response_type, diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index dccd76403b..69f24bcf2c 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -18,10 +18,8 @@ WorkflowFailedEvent, WorkflowRunState, ) -from ._executor import ( - Executor, - handler, -) +from ._executor import Executor, handler +from ._request_info_mixin import response_handler from ._runner_context import Message from ._typing_utils import is_instance_of from ._workflow import WorkflowRunResult @@ -265,7 +263,14 @@ async def handle_subworkflow_request( - Concurrent executions are fully isolated and do not interfere with each other """ - def __init__(self, workflow: "Workflow", id: str, allow_direct_output: bool = False, **kwargs: Any): + def __init__( + self, + workflow: "Workflow", + id: str, + allow_direct_output: bool = False, + propagate_request: bool = False, + **kwargs: Any, + ): """Initialize the WorkflowExecutor. Args: @@ -277,6 +282,11 @@ def __init__(self, workflow: "Workflow", id: str, allow_direct_output: bool = Fa When this is set to true, the outputs are yielded directly from the WorkflowExecutor to the parent workflow's event stream. + propagate_request: Whether to propagate requests from the sub-workflow to the + parent workflow. If set to true, requests from the sub-workflow + will be propagated as the original RequestInfoEvent to the parent + workflow. Otherwise, they will be wrapped in a SubWorkflowRequestMessage, + which should be handled by an executor in the parent workflow. Keyword Args: **kwargs: Additional keyword arguments passed to the parent constructor. @@ -289,6 +299,7 @@ def __init__(self, workflow: "Workflow", id: str, allow_direct_output: bool = Fa self._execution_contexts: dict[str, ExecutionContext] = {} # execution_id -> ExecutionContext # Map request_id to execution_id for response routing self._request_to_execution: dict[str, str] = {} # request_id -> execution_id + self._propagate_request = propagate_request @property def input_types(self) -> list[type[Any]]: @@ -336,8 +347,15 @@ def can_handle(self, message: Message) -> bool: This prevents the WorkflowExecutor from accepting messages that should go to other executors because the handler `process_workflow` has no type restrictions. 
""" - # Always handle SubWorkflowResponseMessage if isinstance(message.data, SubWorkflowResponseMessage): + # Always handle SubWorkflowResponseMessage + return True + + if ( + message.original_request_info_event is not None + and message.original_request_info_event.request_id in self._request_to_execution + ): + # Handle propagated responses for known requests return True # For other messages, only handle if the wrapped workflow can accept them as input @@ -388,7 +406,11 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) del self._execution_contexts[execution_id] @handler - async def handle_response(self, response: SubWorkflowResponseMessage, ctx: WorkflowContext[Any]) -> None: + async def handle_message_wrapped_request_response( + self, + response: SubWorkflowResponseMessage, + ctx: WorkflowContext[Any], + ) -> None: """Handle response from parent for a forwarded request. This handler accumulates responses and only resumes the sub-workflow @@ -398,55 +420,34 @@ async def handle_response(self, response: SubWorkflowResponseMessage, ctx: Workf response: The response to a previous request. ctx: The workflow context. """ - # Find the execution context for this request - original_request = response.source_event - execution_id = self._request_to_execution.get(original_request.request_id) - if not execution_id or execution_id not in self._execution_contexts: - logger.warning( - f"WorkflowExecutor {self.id} received response for unknown request_id: {original_request.request_id}. " - "This response will be ignored." - ) - return - - execution_context = self._execution_contexts[execution_id] - - # Check if we have this pending request in the execution context - if original_request.request_id not in execution_context.pending_requests: - logger.warning( - f"WorkflowExecutor {self.id} received response for unknown request_id: " - f"{original_request.request_id} in execution {execution_id}, ignoring" - ) - return - - # Remove the request from pending list and request mapping - execution_context.pending_requests.pop(original_request.request_id, None) - self._request_to_execution.pop(original_request.request_id, None) - - # Accumulate the response in this execution's context - execution_context.collected_responses[original_request.request_id] = response.data - - # Check if we have all expected responses for this execution - if len(execution_context.collected_responses) < execution_context.expected_response_count: - logger.debug( - f"WorkflowExecutor {self.id} execution {execution_id} waiting for more responses: " - f"{len(execution_context.collected_responses)}/{execution_context.expected_response_count} received" - ) - return # Wait for more responses + await self._handle_response( + request_id=response.source_event.request_id, + response=response.data, + ctx=ctx, + ) - # Send all collected responses to the sub-workflow - responses_to_send = dict(execution_context.collected_responses) - execution_context.collected_responses.clear() # Clear for next batch + @response_handler + async def handle_propagated_request_response( + self, + original_request: Any, + response: object, + ctx: WorkflowContext[Any], + ) -> None: + """Handle response for a request that was propagated to the parent workflow. - try: - # Resume the sub-workflow with all collected responses - result = await self.workflow.send_responses(responses_to_send) + Args: + original_request: The original RequestInfoEvent. + response: The response data. + ctx: The workflow context. 
+ """ + if ctx.request_id is None: + raise RuntimeError("WorkflowExecutor received a propagated response without a request ID in the context.") - # Process the workflow result using shared logic - await self._process_workflow_result(result, execution_context, ctx) - finally: - # Clean up execution context if it's completed (no pending requests) - if not execution_context.pending_requests: - del self._execution_contexts[execution_id] + await self._handle_response( + request_id=ctx.request_id, + response=response, + ctx=ctx, + ) @override async def on_checkpoint_save(self) -> dict[str, Any]: @@ -552,13 +553,15 @@ async def _process_workflow_result( execution_context.pending_requests[event.request_id] = event # Map request to execution for response routing self._request_to_execution[event.request_id] = execution_context.execution_id - # TODO(@taochen): There should be two ways a sub-workflow can make a request: - # 1. In a workflow where the parent workflow has an executor that may intercept the - # request and handle it directly, a message should be sent. - # 2. In a workflow where the parent workflow does not handle the request, the request - # should be propagated via the `request_info` mechanism to an external source. And - # a @response_handler would be required in the WorkflowExecutor to handle the response. - await ctx.send_message(SubWorkflowRequestMessage(source_event=event, executor_id=self.id)) + if self._propagate_request: + # In a workflow where the parent workflow does not handle the request, the request + # should be propagated via the `request_info` mechanism to an external source. And + # a @response_handler would be required in the WorkflowExecutor to handle the response. + await ctx.request_info(event.data, event.response_type, request_id=event.request_id) + else: + # In a workflow where the parent workflow has an executor that may intercept the + # request and handle it directly, a message should be sent. + await ctx.send_message(SubWorkflowRequestMessage(source_event=event, executor_id=self.id)) # Update expected response count for this execution execution_context.expected_response_count = len(request_info_events) @@ -602,3 +605,56 @@ async def _process_workflow_result( ) else: raise RuntimeError(f"Unexpected workflow run state: {workflow_run_state}") + + async def _handle_response( + self, + request_id: str, + response: Any, + ctx: WorkflowContext[Any], + ) -> None: + execution_id = self._request_to_execution.get(request_id) + if not execution_id or execution_id not in self._execution_contexts: + logger.warning( + f"WorkflowExecutor {self.id} received response for unknown request_id: {request_id}. " + "This response will be ignored." 
+ ) + return + + execution_context = self._execution_contexts[execution_id] + + # Check if we have this pending request in the execution context + if request_id not in execution_context.pending_requests: + logger.warning( + f"WorkflowExecutor {self.id} received response for unknown request_id: " + f"{request_id} in execution {execution_id}, ignoring" + ) + return + + # Remove the request from pending list and request mapping + execution_context.pending_requests.pop(request_id, None) + self._request_to_execution.pop(request_id, None) + + # Accumulate the response in this execution's context + execution_context.collected_responses[request_id] = response + # Check if we have all expected responses for this execution + if len(execution_context.collected_responses) < execution_context.expected_response_count: + logger.debug( + f"WorkflowExecutor {self.id} execution {execution_id} waiting for more responses: " + f"{len(execution_context.collected_responses)}/{execution_context.expected_response_count} received" + ) + return # Wait for more responses + + # Send all collected responses to the sub-workflow + responses_to_send = dict(execution_context.collected_responses) + execution_context.collected_responses.clear() # Clear for next batch + + try: + # Resume the sub-workflow with all collected responses + result = await self.workflow.send_responses(responses_to_send) + + # Process the workflow result using shared logic + await self._process_workflow_result(result, execution_context, ctx) + finally: + # Clean up execution context if it's completed (no pending requests) + if not execution_context.pending_requests: + del self._execution_contexts[execution_id] diff --git a/python/packages/core/agent_framework/anthropic/__init__.py b/python/packages/core/agent_framework/anthropic/__init__.py index 2f4decc1eb..ea03e6cdf0 100644 --- a/python/packages/core/agent_framework/anthropic/__init__.py +++ b/python/packages/core/agent_framework/anthropic/__init__.py @@ -5,7 +5,7 @@ IMPORT_PATH = "agent_framework_anthropic" PACKAGE_NAME = "agent-framework-anthropic" -_IMPORTS = ["__version__", "AnthropicClient"] +_IMPORTS = ["__version__", "AnthropicClient", "AnthropicChatOptions"] def __getattr__(name: str) -> Any: diff --git a/python/packages/core/agent_framework/anthropic/__init__.pyi b/python/packages/core/agent_framework/anthropic/__init__.pyi index a86586b98f..3d790ebb07 100644 --- a/python/packages/core/agent_framework/anthropic/__init__.pyi +++ b/python/packages/core/agent_framework/anthropic/__init__.pyi @@ -1,11 +1,13 @@ # Copyright (c) Microsoft. All rights reserved. 
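The `__init__.py` hunk above extends `_IMPORTS` for the lazy re-export shim. A plausible sketch of how such a module-level `__getattr__` resolves names on demand — the body is assumed; only `IMPORT_PATH`, `_IMPORTS`, and the `__getattr__` signature appear in the diff:

```python
import importlib
from typing import Any

IMPORT_PATH = "agent_framework_anthropic"
_IMPORTS = ["__version__", "AnthropicClient", "AnthropicChatOptions"]

def __getattr__(name: str) -> Any:
    # Resolve exported names lazily from the implementation package.
    if name in _IMPORTS:
        return getattr(importlib.import_module(IMPORT_PATH), name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```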
from agent_framework_anthropic import ( + AnthropicChatOptions, AnthropicClient, __version__, ) __all__ = [ + "AnthropicChatOptions", "AnthropicClient", "__version__", ] diff --git a/python/packages/core/agent_framework/azure/__init__.py b/python/packages/core/agent_framework/azure/__init__.py index 7990361c97..ea94d83f0e 100644 --- a/python/packages/core/agent_framework/azure/__init__.py +++ b/python/packages/core/agent_framework/azure/__init__.py @@ -8,14 +8,21 @@ "AgentFunctionApp": ("agent_framework_azurefunctions", "agent-framework-azurefunctions"), "AgentResponseCallbackProtocol": ("agent_framework_azurefunctions", "agent-framework-azurefunctions"), "AzureAIAgentClient": ("agent_framework_azure_ai", "agent-framework-azure-ai"), + "AzureAIAgentOptions": ("agent_framework_azure_ai", "agent-framework-azure-ai"), "AzureAIClient": ("agent_framework_azure_ai", "agent-framework-azure-ai"), + "AzureAIProjectAgentProvider": ("agent_framework_azure_ai", "agent-framework-azure-ai"), "AzureAISearchContextProvider": ("agent_framework_azure_ai_search", "agent-framework-azure-ai-search"), "AzureAISearchSettings": ("agent_framework_azure_ai_search", "agent-framework-azure-ai-search"), "AzureAISettings": ("agent_framework_azure_ai", "agent-framework-azure-ai"), + "AzureAIAgentsProvider": ("agent_framework_azure_ai", "agent-framework-azure-ai"), "AzureOpenAIAssistantsClient": ("agent_framework.azure._assistants_client", "agent-framework-core"), + "AzureOpenAIAssistantsOptions": ("agent_framework.azure._assistants_client", "agent-framework-core"), "AzureOpenAIChatClient": ("agent_framework.azure._chat_client", "agent-framework-core"), + "AzureOpenAIChatOptions": ("agent_framework.azure._chat_client", "agent-framework-core"), "AzureOpenAIResponsesClient": ("agent_framework.azure._responses_client", "agent-framework-core"), + "AzureOpenAIResponsesOptions": ("agent_framework.azure._responses_client", "agent-framework-core"), "AzureOpenAISettings": ("agent_framework.azure._shared", "agent-framework-core"), + "AzureUserSecurityContext": ("agent_framework.azure._chat_client", "agent-framework-core"), "DurableAIAgent": ("agent_framework_azurefunctions", "agent-framework-azurefunctions"), "get_entra_auth_token": ("agent_framework.azure._entra_id_authentication", "agent-framework-core"), } diff --git a/python/packages/core/agent_framework/azure/__init__.pyi b/python/packages/core/agent_framework/azure/__init__.pyi index add9ea1130..155ad5067f 100644 --- a/python/packages/core/agent_framework/azure/__init__.pyi +++ b/python/packages/core/agent_framework/azure/__init__.pyi @@ -1,6 +1,12 @@ # Copyright (c) Microsoft. All rights reserved. 
-from agent_framework_azure_ai import AzureAIAgentClient, AzureAIClient, AzureAISettings +from agent_framework_azure_ai import ( + AzureAIAgentClient, + AzureAIAgentsProvider, + AzureAIClient, + AzureAIProjectAgentProvider, + AzureAISettings, +) from agent_framework_azure_ai_search import AzureAISearchContextProvider, AzureAISearchSettings from agent_framework_azurefunctions import ( AgentCallbackContext, @@ -20,7 +26,9 @@ __all__ = [ "AgentFunctionApp", "AgentResponseCallbackProtocol", "AzureAIAgentClient", + "AzureAIAgentsProvider", "AzureAIClient", + "AzureAIProjectAgentProvider", "AzureAISearchContextProvider", "AzureAISearchSettings", "AzureAISettings", diff --git a/python/packages/core/agent_framework/azure/_assistants_client.py b/python/packages/core/agent_framework/azure/_assistants_client.py index 58d2dbe309..a835310435 100644 --- a/python/packages/core/agent_framework/azure/_assistants_client.py +++ b/python/packages/core/agent_framework/azure/_assistants_client.py @@ -1,22 +1,47 @@ # Copyright (c) Microsoft. All rights reserved. +import sys from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, Generic from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI from pydantic import ValidationError from ..exceptions import ServiceInitializationError from ..openai import OpenAIAssistantsClient +from ..openai._assistants_client import OpenAIAssistantsOptions from ._shared import AzureOpenAISettings if TYPE_CHECKING: from azure.core.credentials import TokenCredential +if sys.version_info >= (3, 13): + from typing import TypeVar # type: ignore # pragma: no cover +else: + from typing_extensions import TypeVar # type: ignore # pragma: no cover + +from typing import TypedDict + __all__ = ["AzureOpenAIAssistantsClient"] -class AzureOpenAIAssistantsClient(OpenAIAssistantsClient): +# region Azure OpenAI Assistants Options TypedDict + + +TAzureOpenAIAssistantsOptions = TypeVar( + "TAzureOpenAIAssistantsOptions", + bound=TypedDict, # type: ignore[valid-type] + default="OpenAIAssistantsOptions", + covariant=True, +) + + +# endregion + + +class AzureOpenAIAssistantsClient( + OpenAIAssistantsClient[TAzureOpenAIAssistantsOptions], Generic[TAzureOpenAIAssistantsOptions] +): """Azure OpenAI Assistants client.""" DEFAULT_AZURE_API_VERSION: ClassVar[str] = "2024-05-01-preview" @@ -95,6 +120,18 @@ def __init__( # Or loading from a .env file client = AzureOpenAIAssistantsClient(env_file_path="path/to/.env") + + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework.azure import AzureOpenAIAssistantsOptions + + + class MyOptions(AzureOpenAIAssistantsOptions, total=False): + my_custom_option: str + + + client: AzureOpenAIAssistantsClient[MyOptions] = AzureOpenAIAssistantsClient() + response = await client.get_response("Hello", options={"my_custom_option": "value"}) """ try: azure_openai_settings = AzureOpenAISettings( diff --git a/python/packages/core/agent_framework/azure/_chat_client.py b/python/packages/core/agent_framework/azure/_chat_client.py index 59f74259a4..248e79ee47 100644 --- a/python/packages/core/agent_framework/azure/_chat_client.py +++ b/python/packages/core/agent_framework/azure/_chat_client.py @@ -4,7 +4,7 @@ import logging import sys from collections.abc import Mapping -from typing import Any, TypeVar +from typing import Any, Generic, TypedDict from azure.core.credentials import TokenCredential from openai.lib.azure import AsyncAzureADTokenProvider, 
AsyncAzureOpenAI @@ -22,13 +22,17 @@ ) from agent_framework.exceptions import ServiceInitializationError from agent_framework.observability import use_instrumentation -from agent_framework.openai._chat_client import OpenAIBaseChatClient +from agent_framework.openai._chat_client import OpenAIBaseChatClient, OpenAIChatOptions from ._shared import ( AzureOpenAIConfigMixin, AzureOpenAISettings, ) +if sys.version_info >= (3, 13): + from typing import TypeVar # type: ignore # pragma: no cover +else: + from typing_extensions import TypeVar # type: ignore # pragma: no cover if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: @@ -36,6 +40,99 @@ logger: logging.Logger = logging.getLogger(__name__) +__all__ = ["AzureOpenAIChatClient", "AzureOpenAIChatOptions", "AzureUserSecurityContext"] + + +# region Azure OpenAI Chat Options TypedDict + + +class AzureUserSecurityContext(TypedDict, total=False): + """User security context for Azure AI applications. + + These fields help security operations teams investigate and mitigate security + incidents by providing context about the application and end user. + + Learn more: https://learn.microsoft.com/azure/well-architected/service-guides/cosmos-db + """ + + application_name: str + """Name of the application making the request.""" + + end_user_id: str + """Unique identifier for the end user (recommend hashing username/email).""" + + end_user_tenant_id: str + """Microsoft 365 tenant ID the end user belongs to. Required for multi-tenant apps.""" + + source_ip: str + """The original client's IP address.""" + + +class AzureOpenAIChatOptions(OpenAIChatOptions, total=False): + """Azure OpenAI-specific chat options dict. + + Extends OpenAIChatOptions with Azure-specific options including + the "On Your Data" feature and enhanced security context. + + See: https://learn.microsoft.com/azure/ai-foundry/openai/reference-preview-latest + + Keys: + # Inherited from OpenAIChatOptions/ChatOptions: + model_id: The model to use for the request, + translates to ``model`` in Azure OpenAI API. + temperature: Sampling temperature between 0 and 2. + top_p: Nucleus sampling parameter. + max_tokens: Maximum number of tokens to generate, + translates to ``max_completion_tokens`` in Azure OpenAI API. + stop: Stop sequences. + seed: Random seed for reproducibility. + frequency_penalty: Frequency penalty between -2.0 and 2.0. + presence_penalty: Presence penalty between -2.0 and 2.0. + tools: List of tools (functions) available to the model. + tool_choice: How the model should use tools. + allow_multiple_tool_calls: Whether to allow parallel tool calls, + translates to ``parallel_tool_calls`` in Azure OpenAI API. + response_format: Structured output schema. + metadata: Request metadata for tracking. + user: End-user identifier for abuse monitoring. + store: Whether to store the conversation. + instructions: System instructions for the model. + logit_bias: Token bias values (-100 to 100). + logprobs: Whether to return log probabilities. + top_logprobs: Number of top log probabilities to return (0-20). + + # Azure-specific options: + data_sources: Azure "On Your Data" data sources configuration. + user_security_context: Enhanced security context for Azure Defender. + n: Number of chat completions to generate (not recommended, incurs costs). + """ + + # Azure-specific options + data_sources: list[dict[str, Any]] + """Azure "On Your Data" data sources for retrieval-augmented generation. 
+ + Supported types: azure_search, azure_cosmos_db, elasticsearch, pinecone, mongo_db. + See: https://learn.microsoft.com/azure/ai-foundry/openai/references/on-your-data + """ + + user_security_context: AzureUserSecurityContext + """Enhanced security context for Azure Defender integration.""" + + n: int + """Number of chat completion choices to generate for each input message. + Note: You will be charged based on tokens across all choices. Keep n=1 to minimize costs.""" + + +TAzureOpenAIChatOptions = TypeVar( + "TAzureOpenAIChatOptions", + bound=TypedDict, # type: ignore[valid-type] + default="AzureOpenAIChatOptions", + covariant=True, +) + + +# endregion + TChatResponse = TypeVar("TChatResponse", ChatResponse, ChatResponseUpdate) TAzureOpenAIChatClient = TypeVar("TAzureOpenAIChatClient", bound="AzureOpenAIChatClient") @@ -43,7 +140,9 @@ @use_function_invocation @use_instrumentation @use_chat_middleware -class AzureOpenAIChatClient(AzureOpenAIConfigMixin, OpenAIBaseChatClient): +class AzureOpenAIChatClient( + AzureOpenAIConfigMixin, OpenAIBaseChatClient[TAzureOpenAIChatOptions], Generic[TAzureOpenAIChatOptions] +): """Azure OpenAI Chat completion class.""" def __init__( @@ -103,17 +202,31 @@ def __init__( # Using environment variables # Set AZURE_OPENAI_ENDPOINT=https://your-endpoint.openai.azure.com - # Set AZURE_OPENAI_CHAT_DEPLOYMENT_NAME=gpt-4 + # Set AZURE_OPENAI_CHAT_DEPLOYMENT_NAME= # Set AZURE_OPENAI_API_KEY=your-key client = AzureOpenAIChatClient() # Or passing parameters directly client = AzureOpenAIChatClient( - endpoint="https://your-endpoint.openai.azure.com", deployment_name="gpt-4", api_key="your-key" + endpoint="https://your-endpoint.openai.azure.com", + deployment_name="", + api_key="your-key", ) # Or loading from a .env file client = AzureOpenAIChatClient(env_file_path="path/to/.env") + + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework.azure import AzureOpenAIChatOptions + + + class MyOptions(AzureOpenAIChatOptions, total=False): + my_custom_option: str + + + client: AzureOpenAIChatClient[MyOptions] = AzureOpenAIChatClient() + response = await client.get_response("Hello", options={"my_custom_option": "value"}) """ try: # Filter out any None values from the arguments diff --git a/python/packages/core/agent_framework/azure/_responses_client.py b/python/packages/core/agent_framework/azure/_responses_client.py index 3f6140eeeb..e4f6989fa0 100644 --- a/python/packages/core/agent_framework/azure/_responses_client.py +++ b/python/packages/core/agent_framework/azure/_responses_client.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. 
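As an aside on the `AzureOpenAIChatOptions` TypedDict introduced in the chat-client hunk above, a minimal usage sketch; the option values are illustrative, and `get_response(..., options=...)` follows the docstring examples in this diff:

```python
from agent_framework.azure import AzureOpenAIChatClient, AzureOpenAIChatOptions

options: AzureOpenAIChatOptions = {
    "temperature": 0.2,
    "user_security_context": {
        "application_name": "contoso-support-bot",
        "end_user_id": "hashed-user-123",  # hash usernames/emails, per the field docs
    },
}

client = AzureOpenAIChatClient()  # reads endpoint/deployment/key from the environment
response = await client.get_response("Summarize my ticket", options=options)
```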
+import sys from collections.abc import Mapping -from typing import Any, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypedDict from urllib.parse import urljoin from azure.core.credentials import TokenCredential @@ -18,13 +19,37 @@ AzureOpenAISettings, ) -TAzureOpenAIResponsesClient = TypeVar("TAzureOpenAIResponsesClient", bound="AzureOpenAIResponsesClient") +if TYPE_CHECKING: + from agent_framework.openai._responses_client import OpenAIResponsesOptions + +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover +if sys.version_info >= (3, 13): + from typing import TypeVar # type: ignore # pragma: no cover +else: + from typing_extensions import TypeVar # type: ignore # pragma: no cover + +__all__ = ["AzureOpenAIResponsesClient"] + + +TAzureOpenAIResponsesOptions = TypeVar( + "TAzureOpenAIResponsesOptions", + bound=TypedDict, # type: ignore[valid-type] + default="OpenAIResponsesOptions", + covariant=True, +) @use_function_invocation @use_instrumentation @use_chat_middleware -class AzureOpenAIResponsesClient(AzureOpenAIConfigMixin, OpenAIBaseResponsesClient): +class AzureOpenAIResponsesClient( + AzureOpenAIConfigMixin, + OpenAIBaseResponsesClient[TAzureOpenAIResponsesOptions], + Generic[TAzureOpenAIResponsesOptions], +): """Azure Responses completion class.""" def __init__( @@ -95,6 +120,18 @@ def __init__( # Or loading from a .env file client = AzureOpenAIResponsesClient(env_file_path="path/to/.env") + + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework.azure import AzureOpenAIResponsesOptions + + + class MyOptions(AzureOpenAIResponsesOptions, total=False): + my_custom_option: str + + + client: AzureOpenAIResponsesClient[MyOptions] = AzureOpenAIResponsesClient() + response = await client.get_response("Hello", options={"my_custom_option": "value"}) """ if model_id := kwargs.pop("model_id", None) and not deployment_name: deployment_name = str(model_id) @@ -144,3 +181,10 @@ def __init__( client=async_client, instruction_role=instruction_role, ) + + @override + def _check_model_presence(self, run_options: dict[str, Any]) -> None: + if not run_options.get("model"): + if not self.model_id: + raise ValueError("deployment_name must be a non-empty string") + run_options["model"] = self.model_id diff --git a/python/packages/core/agent_framework/declarative/__init__.py b/python/packages/core/agent_framework/declarative/__init__.py index d6002b9b0a..4ad557c0f7 100644 --- a/python/packages/core/agent_framework/declarative/__init__.py +++ b/python/packages/core/agent_framework/declarative/__init__.py @@ -5,7 +5,21 @@ IMPORT_PATH = "agent_framework_declarative" PACKAGE_NAME = "agent-framework-declarative" -_IMPORTS = ["__version__", "AgentFactory", "DeclarativeLoaderError", "ProviderLookupError", "ProviderTypeMapping"] +_IMPORTS = [ + "__version__", + "AgentFactory", + "AgentExternalInputRequest", + "AgentExternalInputResponse", + "AgentInvocationError", + "DeclarativeLoaderError", + "DeclarativeWorkflowError", + "ExternalInputRequest", + "ExternalInputResponse", + "ProviderLookupError", + "ProviderTypeMapping", + "WorkflowFactory", + "WorkflowState", +] def __getattr__(name: str) -> Any: diff --git a/python/packages/core/agent_framework/declarative/__init__.pyi b/python/packages/core/agent_framework/declarative/__init__.pyi index 0e19cc8687..8d2b717c99 100644 --- 
a/python/packages/core/agent_framework/declarative/__init__.pyi +++ b/python/packages/core/agent_framework/declarative/__init__.pyi @@ -1,17 +1,33 @@ # Copyright (c) Microsoft. All rights reserved. from agent_framework_declarative import ( + AgentExternalInputRequest, + AgentExternalInputResponse, AgentFactory, + AgentInvocationError, DeclarativeLoaderError, + DeclarativeWorkflowError, + ExternalInputRequest, + ExternalInputResponse, ProviderLookupError, ProviderTypeMapping, + WorkflowFactory, + WorkflowState, __version__, ) __all__ = [ + "AgentExternalInputRequest", + "AgentExternalInputResponse", "AgentFactory", + "AgentInvocationError", "DeclarativeLoaderError", + "DeclarativeWorkflowError", + "ExternalInputRequest", + "ExternalInputResponse", "ProviderLookupError", "ProviderTypeMapping", + "WorkflowFactory", + "WorkflowState", "__version__", ] diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 26c261038b..70564c9354 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -35,8 +35,8 @@ from ._threads import AgentThread from ._tools import AIFunction from ._types import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, ChatMessage, ChatResponse, ChatResponseUpdate, @@ -59,7 +59,7 @@ TAgent = TypeVar("TAgent", bound="AgentProtocol") -TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol") +TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol[Any]") logger = get_logger() @@ -1063,6 +1063,8 @@ def decorator(func: Callable[..., Awaitable["ChatResponse"]]) -> Callable[..., A async def trace_get_response( self: "ChatClientProtocol", messages: "str | ChatMessage | list[str] | list[ChatMessage]", + *, + options: dict[str, Any] | None = None, **kwargs: Any, ) -> "ChatResponse": global OBSERVABILITY_SETTINGS @@ -1071,18 +1073,15 @@ async def trace_get_response( return await func( self, messages=messages, + options=options, **kwargs, ) if "token_usage_histogram" not in self.additional_properties: self.additional_properties["token_usage_histogram"] = _get_token_usage_histogram() if "operation_duration_histogram" not in self.additional_properties: self.additional_properties["operation_duration_histogram"] = _get_duration_histogram() - model_id = ( - kwargs.get("model_id") - or (chat_options.model_id if (chat_options := kwargs.get("chat_options")) else None) - or getattr(self, "model_id", None) - or "unknown" - ) + options = options or {} + model_id = kwargs.get("model_id") or options.get("model_id") or getattr(self, "model_id", None) or "unknown" service_url = str( service_url_func() if (service_url_func := getattr(self, "service_url", None)) and callable(service_url_func) @@ -1101,7 +1100,7 @@ async def trace_get_response( start_time_stamp = perf_counter() end_time_stamp: float | None = None try: - response = await func(self, messages=messages, **kwargs) + response = await func(self, messages=messages, options=options, **kwargs) end_time_stamp = perf_counter() except Exception as exception: end_time_stamp = perf_counter() @@ -1152,12 +1151,16 @@ def decorator( @wraps(func) async def trace_get_streaming_response( - self: "ChatClientProtocol", messages: "str | ChatMessage | list[str] | list[ChatMessage]", **kwargs: Any + self: "ChatClientProtocol", + messages: "str | ChatMessage | list[str] | list[ChatMessage]", + *, + options: dict[str, Any] | None = None, + **kwargs: Any, ) -> 
AsyncIterable["ChatResponseUpdate"]: global OBSERVABILITY_SETTINGS if not OBSERVABILITY_SETTINGS.ENABLED: # If model diagnostics are not enabled, just return the completion - async for update in func(self, messages=messages, **kwargs): + async for update in func(self, messages=messages, options=options, **kwargs): yield update return if "token_usage_histogram" not in self.additional_properties: @@ -1165,12 +1168,8 @@ async def trace_get_streaming_response( if "operation_duration_histogram" not in self.additional_properties: self.additional_properties["operation_duration_histogram"] = _get_duration_histogram() - model_id = ( - kwargs.get("model_id") - or (chat_options.model_id if (chat_options := kwargs.get("chat_options")) else None) - or getattr(self, "model_id", None) - or "unknown" - ) + options = options or {} + model_id = kwargs.get("model_id") or options.get("model_id") or getattr(self, "model_id", None) or "unknown" service_url = str( service_url_func() if (service_url_func := getattr(self, "service_url", None)) and callable(service_url_func) @@ -1194,7 +1193,7 @@ async def trace_get_streaming_response( start_time_stamp = perf_counter() end_time_stamp: float | None = None try: - async for update in func(self, messages=messages, **kwargs): + async for update in func(self, messages=messages, options=options, **kwargs): all_updates.append(update) yield update end_time_stamp = perf_counter() @@ -1316,10 +1315,10 @@ async def get_streaming_response(self, messages, **kwargs): def _trace_agent_run( - run_func: Callable[..., Awaitable["AgentRunResponse"]], + run_func: Callable[..., Awaitable["AgentResponse"]], provider_name: str, capture_usage: bool = True, -) -> Callable[..., Awaitable["AgentRunResponse"]]: +) -> Callable[..., Awaitable["AgentResponse"]]: """Decorator to trace chat completion activities. 
Args: @@ -1335,22 +1334,26 @@ async def trace_run( *, thread: "AgentThread | None" = None, **kwargs: Any, - ) -> "AgentRunResponse": + ) -> "AgentResponse": global OBSERVABILITY_SETTINGS if not OBSERVABILITY_SETTINGS.ENABLED: # If model diagnostics are not enabled, just return the completion return await run_func(self, messages=messages, thread=thread, **kwargs) - filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} + + from ._types import merge_chat_options + + default_options = getattr(self, "default_options", {}) + options = merge_chat_options(default_options, kwargs.get("options", {})) attributes = _get_span_attributes( operation_name=OtelAttr.AGENT_INVOKE_OPERATION, provider_name=provider_name, agent_id=self.id, - agent_name=self.display_name, + agent_name=self.name or self.id, agent_description=self.description, thread_id=thread.service_thread_id if thread else None, - chat_options=getattr(self, "chat_options", None), - **filtered_kwargs, + all_options=options, + **kwargs, ) with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span: if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages: @@ -1358,7 +1361,7 @@ async def trace_run( span=span, provider_name=provider_name, messages=messages, - system_instructions=getattr(getattr(self, "chat_options", None), "instructions", None), + system_instructions=_get_instructions_from_options(options), ) try: response = await run_func(self, messages=messages, thread=thread, **kwargs) @@ -1381,10 +1384,10 @@ async def trace_run( def _trace_agent_run_stream( - run_streaming_func: Callable[..., AsyncIterable["AgentRunResponseUpdate"]], + run_streaming_func: Callable[..., AsyncIterable["AgentResponseUpdate"]], provider_name: str, capture_usage: bool, -) -> Callable[..., AsyncIterable["AgentRunResponseUpdate"]]: +) -> Callable[..., AsyncIterable["AgentResponseUpdate"]]: """Decorator to trace streaming agent run activities. 
Args: @@ -1400,7 +1403,7 @@ async def trace_run_streaming( *, thread: "AgentThread | None" = None, **kwargs: Any, - ) -> AsyncIterable["AgentRunResponseUpdate"]: + ) -> AsyncIterable["AgentResponseUpdate"]: global OBSERVABILITY_SETTINGS if not OBSERVABILITY_SETTINGS.ENABLED: @@ -1409,20 +1412,21 @@ async def trace_run_streaming( yield streaming_agent_response return - from ._types import AgentRunResponse + from ._types import AgentResponse, merge_chat_options - all_updates: list["AgentRunResponseUpdate"] = [] + all_updates: list["AgentResponseUpdate"] = [] - filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} + default_options = getattr(self, "default_options", {}) + options = merge_chat_options(default_options, kwargs.get("options", {})) attributes = _get_span_attributes( operation_name=OtelAttr.AGENT_INVOKE_OPERATION, provider_name=provider_name, agent_id=self.id, - agent_name=self.display_name, + agent_name=self.name or self.id, agent_description=self.description, thread_id=thread.service_thread_id if thread else None, - chat_options=getattr(self, "chat_options", None), - **filtered_kwargs, + all_options=options, + **kwargs, ) with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span: if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages: @@ -1430,7 +1434,7 @@ async def trace_run_streaming( span=span, provider_name=provider_name, messages=messages, - system_instructions=getattr(getattr(self, "chat_options", None), "instructions", None), + system_instructions=_get_instructions_from_options(options), ) try: async for update in run_streaming_func(self, messages=messages, thread=thread, **kwargs): @@ -1440,7 +1444,7 @@ async def trace_run_streaming( capture_exception(span=span, exception=exception, timestamp=time_ns()) raise else: - response = AgentRunResponse.from_agent_run_response_updates(all_updates) + response = AgentResponse.from_agent_run_response_updates(all_updates) attributes = _get_response_attributes(attributes, response, capture_usage=capture_usage) _capture_response(span=span, attributes=attributes) if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages: @@ -1586,7 +1590,9 @@ def _get_span( Note: `attributes` must contain the `span_name_attribute` key. 
""" - span = get_tracer().start_span(f"{attributes[OtelAttr.OPERATION]} {attributes[span_name_attribute]}") + operation = attributes.get(OtelAttr.OPERATION, "operation") + span_name = attributes.get(span_name_attribute, "unknown") + span = get_tracer().start_span(f"{operation} {span_name}") span.set_attributes(attributes) with trace.use_span( span=span, @@ -1597,65 +1603,96 @@ def _get_span( yield current_span +def _get_instructions_from_options(options: Any) -> str | None: + """Extract instructions from options dict.""" + if options is None: + return None + if isinstance(options, dict): + return options.get("instructions") + return None + + +# Mapping configuration for extracting span attributes +# Each entry: source_keys -> (otel_attribute_key, transform_func, check_options_first, default_value) +# - source_keys: single key or list of keys to check (first non-None value wins) +# - otel_attribute_key: target OTEL attribute name +# - transform_func: optional transformation function, can return None to skip attribute +# - check_options_first: whether to check options dict before kwargs +# - default_value: optional default value if key is not found (use None to skip) +OTEL_ATTR_MAP: dict[str | tuple[str, ...], tuple[str, Callable[[Any], Any] | None, bool, Any]] = { + "choice_count": (OtelAttr.CHOICE_COUNT, None, False, 1), + "operation_name": (OtelAttr.OPERATION, None, False, None), + "system_name": (SpanAttributes.LLM_SYSTEM, None, False, None), + "provider_name": (OtelAttr.PROVIDER_NAME, None, False, None), + "service_url": (OtelAttr.ADDRESS, None, False, None), + "conversation_id": (OtelAttr.CONVERSATION_ID, None, True, None), + "seed": (OtelAttr.SEED, None, True, None), + "frequency_penalty": (OtelAttr.FREQUENCY_PENALTY, None, True, None), + "max_tokens": (SpanAttributes.LLM_REQUEST_MAX_TOKENS, None, True, None), + "stop": (OtelAttr.STOP_SEQUENCES, None, True, None), + "temperature": (SpanAttributes.LLM_REQUEST_TEMPERATURE, None, True, None), + "top_p": (SpanAttributes.LLM_REQUEST_TOP_P, None, True, None), + "presence_penalty": (OtelAttr.PRESENCE_PENALTY, None, True, None), + "top_k": (OtelAttr.TOP_K, None, True, None), + "encoding_formats": ( + OtelAttr.ENCODING_FORMATS, + lambda v: json.dumps(v if isinstance(v, list) else [v]), + True, + None, + ), + "agent_id": (OtelAttr.AGENT_ID, None, False, None), + "agent_name": (OtelAttr.AGENT_NAME, None, False, None), + "agent_description": (OtelAttr.AGENT_DESCRIPTION, None, False, None), + # Multiple source keys - checks model_id in options, then model in kwargs, then model_id in kwargs + ("model_id", "model"): (SpanAttributes.LLM_REQUEST_MODEL, None, True, None), + # Tools with validation - returns None if no valid tools + "tools": ( + OtelAttr.TOOL_DEFINITIONS, + lambda tools: ( + json.dumps(tools_dict) + if (tools_dict := __import__("agent_framework._tools", fromlist=["_tools_to_dict"])._tools_to_dict(tools)) + else None + ), + True, + None, + ), + # Error type extraction + "error": (OtelAttr.ERROR_TYPE, lambda e: type(e).__name__, False, None), + # thread_id overrides conversation_id - processed after conversation_id due to dict ordering + "thread_id": (OtelAttr.CONVERSATION_ID, None, False, None), +} + + def _get_span_attributes(**kwargs: Any) -> dict[str, Any]: """Get the span attributes from a kwargs dictionary.""" - from ._tools import _tools_to_dict - from ._types import ChatOptions - attributes: dict[str, Any] = {} - chat_options: ChatOptions | None = kwargs.get("chat_options") - if chat_options is None: - chat_options = ChatOptions() 
- if operation_name := kwargs.get("operation_name"): - attributes[OtelAttr.OPERATION] = operation_name - if choice_count := kwargs.get("choice_count", 1): - attributes[OtelAttr.CHOICE_COUNT] = choice_count - if system_name := kwargs.get("system_name"): - attributes[SpanAttributes.LLM_SYSTEM] = system_name - if provider_name := kwargs.get("provider_name"): - attributes[OtelAttr.PROVIDER_NAME] = provider_name - if model_id := kwargs.get("model", chat_options.model_id): - attributes[SpanAttributes.LLM_REQUEST_MODEL] = model_id - if service_url := kwargs.get("service_url"): - attributes[OtelAttr.ADDRESS] = service_url - if conversation_id := kwargs.get("conversation_id", chat_options.conversation_id): - attributes[OtelAttr.CONVERSATION_ID] = conversation_id - if seed := kwargs.get("seed", chat_options.seed): - attributes[OtelAttr.SEED] = seed - if frequency_penalty := kwargs.get("frequency_penalty", chat_options.frequency_penalty): - attributes[OtelAttr.FREQUENCY_PENALTY] = frequency_penalty - if max_tokens := kwargs.get("max_tokens", chat_options.max_tokens): - attributes[SpanAttributes.LLM_REQUEST_MAX_TOKENS] = max_tokens - if stop := kwargs.get("stop", chat_options.stop): - attributes[OtelAttr.STOP_SEQUENCES] = stop - if temperature := kwargs.get("temperature", chat_options.temperature): - attributes[SpanAttributes.LLM_REQUEST_TEMPERATURE] = temperature - if top_p := kwargs.get("top_p", chat_options.top_p): - attributes[SpanAttributes.LLM_REQUEST_TOP_P] = top_p - if presence_penalty := kwargs.get("presence_penalty", chat_options.presence_penalty): - attributes[OtelAttr.PRESENCE_PENALTY] = presence_penalty - if top_k := kwargs.get("top_k"): - attributes[OtelAttr.TOP_K] = top_k - if encoding_formats := kwargs.get("encoding_formats"): - attributes[OtelAttr.ENCODING_FORMATS] = json.dumps( - encoding_formats if isinstance(encoding_formats, list) else [encoding_formats] - ) - if tools := kwargs.get("tools", chat_options.tools): - tools_as_json_list = _tools_to_dict(tools) - if tools_as_json_list: - attributes[OtelAttr.TOOL_DEFINITIONS] = json.dumps(tools_as_json_list) - if error := kwargs.get("error"): - attributes[OtelAttr.ERROR_TYPE] = type(error).__name__ - # agent attributes - if agent_id := kwargs.get("agent_id"): - attributes[OtelAttr.AGENT_ID] = agent_id - if agent_name := kwargs.get("agent_name"): - attributes[OtelAttr.AGENT_NAME] = agent_name - if agent_description := kwargs.get("agent_description"): - attributes[OtelAttr.AGENT_DESCRIPTION] = agent_description - if thread_id := kwargs.get("thread_id"): - # override if thread is set - attributes[OtelAttr.CONVERSATION_ID] = thread_id + options = kwargs.get("all_options", kwargs.get("options")) + if options is not None and not isinstance(options, dict): + options = None + + for source_keys, (otel_key, transform_func, check_options, default_value) in OTEL_ATTR_MAP.items(): + # Normalize to tuple of keys + keys = (source_keys,) if isinstance(source_keys, str) else source_keys + + value = None + for key in keys: + if check_options and options is not None: + value = options.get(key) + if value is None: + value = kwargs.get(key) + if value is not None: + break + + # Apply default value if no value found + if value is None and default_value is not None: + value = default_value + + if value is not None: + result = transform_func(value) if transform_func else value + # Allow transform_func to return None to skip attribute + if result is not None: + attributes[otel_key] = result + return attributes @@ -1747,7 +1784,7 @@ def _to_otel_part(content: 
"Contents") -> dict[str, Any] | None: def _get_response_attributes( attributes: dict[str, Any], - response: "ChatResponse | AgentRunResponse", + response: "ChatResponse | AgentResponse", duration: float | None = None, *, capture_usage: bool = True, diff --git a/python/packages/core/agent_framework/openai/__init__.py b/python/packages/core/agent_framework/openai/__init__.py index bf15a9e027..daa0542b13 100644 --- a/python/packages/core/agent_framework/openai/__init__.py +++ b/python/packages/core/agent_framework/openai/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +from ._assistant_provider import * # noqa: F403 from ._assistants_client import * # noqa: F403 from ._chat_client import * # noqa: F403 from ._exceptions import * # noqa: F403 diff --git a/python/packages/core/agent_framework/openai/_assistant_provider.py b/python/packages/core/agent_framework/openai/_assistant_provider.py new file mode 100644 index 0000000000..336fe40c72 --- /dev/null +++ b/python/packages/core/agent_framework/openai/_assistant_provider.py @@ -0,0 +1,563 @@ +# Copyright (c) Microsoft. All rights reserved. + +import sys +from collections.abc import Awaitable, Callable, MutableMapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, TypedDict, cast + +from openai import AsyncOpenAI +from openai.types.beta.assistant import Assistant +from pydantic import BaseModel, SecretStr, ValidationError + +from .._agents import ChatAgent +from .._memory import ContextProvider +from .._middleware import Middleware +from .._tools import AIFunction, ToolProtocol +from .._types import normalize_tools +from ..exceptions import ServiceInitializationError +from ._assistants_client import OpenAIAssistantsClient +from ._shared import OpenAISettings, from_assistant_tools, to_assistant_tools + +if TYPE_CHECKING: + from ._assistants_client import OpenAIAssistantsOptions + +if sys.version_info >= (3, 13): + from typing import Self, TypeVar # pragma: no cover +else: + from typing_extensions import Self, TypeVar # pragma: no cover + + +__all__ = ["OpenAIAssistantProvider"] + +# Type variable for options - allows typed ChatAgent[TOptions] returns +# Default matches OpenAIAssistantsClient's default options type +TOptions_co = TypeVar( + "TOptions_co", + bound=TypedDict, # type: ignore[valid-type] + default="OpenAIAssistantsOptions", + covariant=True, +) + +_ToolsType = ( + ToolProtocol + | Callable[..., Any] + | MutableMapping[str, Any] + | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] +) + + +class OpenAIAssistantProvider(Generic[TOptions_co]): + """Provider for creating ChatAgent instances from OpenAI Assistants API. + + This provider allows you to create, retrieve, and wrap OpenAI Assistants + as ChatAgent instances for use in the agent framework. + + Examples: + Basic usage with automatic client creation: + + .. code-block:: python + + from agent_framework.openai import OpenAIAssistantProvider + + # Uses OPENAI_API_KEY environment variable + provider = OpenAIAssistantProvider() + + # Create a new assistant + agent = await provider.create_agent( + name="MyAssistant", + model="gpt-4", + instructions="You are a helpful assistant.", + tools=[my_function], + ) + + result = await agent.run("Hello!") + + Using an existing client: + + .. 
code-block:: python + + from openai import AsyncOpenAI + from agent_framework.openai import OpenAIAssistantProvider + + client = AsyncOpenAI() + provider = OpenAIAssistantProvider(client) + + # Get an existing assistant by ID + agent = await provider.get_agent( + assistant_id="asst_123", + tools=[my_function], # Provide implementations for function tools + ) + + Wrapping an SDK Assistant object: + + .. code-block:: python + + # Fetch assistant directly via SDK + assistant = await client.beta.assistants.retrieve("asst_123") + + # Wrap without additional HTTP call + agent = provider.as_agent(assistant, tools=[my_function]) + """ + + def __init__( + self, + client: AsyncOpenAI | None = None, + *, + api_key: str | SecretStr | Callable[[], str | Awaitable[str]] | None = None, + org_id: str | None = None, + base_url: str | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initialize the OpenAI Assistant Provider. + + Args: + client: An existing AsyncOpenAI client to use. If not provided, + a new client will be created using the other parameters. + + Keyword Args: + api_key: OpenAI API key. Can also be set via OPENAI_API_KEY env var. + org_id: OpenAI organization ID. Can also be set via OPENAI_ORG_ID env var. + base_url: Base URL for the OpenAI API. Can also be set via OPENAI_BASE_URL env var. + env_file_path: Path to .env file for configuration. + env_file_encoding: Encoding of the .env file. + + Raises: + ServiceInitializationError: If no client is provided and API key is missing. + + Examples: + .. code-block:: python + + # Using environment variables + provider = OpenAIAssistantProvider() + + # Using explicit API key + provider = OpenAIAssistantProvider(api_key="sk-...") + + # Using existing client + client = AsyncOpenAI() + provider = OpenAIAssistantProvider(client) + """ + self._client: AsyncOpenAI | None = client + self._should_close_client: bool = client is None + + if client is None: + # Load settings and create client + try: + settings = OpenAISettings( + api_key=api_key, # type: ignore[reportArgumentType] + org_id=org_id, + base_url=base_url, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + except ValidationError as ex: + raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex + + if not settings.api_key: + raise ServiceInitializationError( + "OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable." + ) + + # Get API key value + api_key_value: str | Callable[[], str | Awaitable[str]] | None + if isinstance(settings.api_key, SecretStr): + api_key_value = settings.api_key.get_secret_value() + else: + api_key_value = settings.api_key + + # Create client + client_args: dict[str, Any] = {"api_key": api_key_value} + if settings.org_id: + client_args["organization"] = settings.org_id + if settings.base_url: + client_args["base_url"] = settings.base_url + + self._client = AsyncOpenAI(**client_args) + + async def __aenter__(self) -> "Self": + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + """Async context manager exit.""" + await self.close() + + async def close(self) -> None: + """Close the provider and clean up resources. + + If the provider created its own client, it will be closed. + If an external client was provided, it will not be closed. 
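+
+        Example:
+            .. code-block:: python
+
+                # The provider closes the client it created on exit; an
+                # externally supplied client is left open.
+                async with OpenAIAssistantProvider() as provider:
+                    agent = await provider.get_agent(assistant_id="asst_123")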
+ """ + if self._should_close_client and self._client is not None: + await self._client.close() + + async def create_agent( + self, + *, + name: str, + model: str, + instructions: str | None = None, + description: str | None = None, + tools: _ToolsType | None = None, + metadata: dict[str, str] | None = None, + default_options: TOptions_co | None = None, + middleware: Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + ) -> "ChatAgent[TOptions_co]": + """Create a new assistant on OpenAI and return a ChatAgent. + + This method creates a new assistant on the OpenAI service and wraps it + in a ChatAgent instance. The assistant will persist on OpenAI until deleted. + + Keyword Args: + name: The name of the assistant (required). + model: The model ID to use, e.g., "gpt-4", "gpt-4o" (required). + instructions: System instructions for the assistant. + description: A description of the assistant. + tools: Tools available to the assistant. Can include: + - AIFunction instances or callables decorated with @ai_function + - HostedCodeInterpreterTool for code execution + - HostedFileSearchTool for vector store search + - Raw tool dictionaries + metadata: Metadata to attach to the assistant (max 16 key-value pairs). + default_options: A TypedDict containing default chat options for the agent. + These options are applied to every run unless overridden. + Include ``response_format`` here for structured output responses. + middleware: Middleware for the ChatAgent. + context_provider: Context provider for the ChatAgent. + + Returns: + A ChatAgent instance wrapping the created assistant. + + Raises: + ServiceInitializationError: If assistant creation fails. + + Examples: + .. code-block:: python + + provider = OpenAIAssistantProvider() + + # Create with function tools + agent = await provider.create_agent( + name="WeatherBot", + model="gpt-4", + instructions="You are a helpful weather assistant.", + tools=[get_weather], + ) + + # Create with structured output + agent = await provider.create_agent( + name="StructuredBot", + model="gpt-4", + default_options={"response_format": MyPydanticModel}, + ) + """ + # Normalize tools + normalized_tools = normalize_tools(tools) + api_tools = to_assistant_tools(normalized_tools) if normalized_tools else [] + + # Extract response_format from default_options if present + opts = dict(default_options) if default_options else {} + response_format = opts.get("response_format") + + # Build assistant creation parameters + create_params: dict[str, Any] = { + "model": model, + "name": name, + } + + if instructions is not None: + create_params["instructions"] = instructions + if description is not None: + create_params["description"] = description + if api_tools: + create_params["tools"] = api_tools + if metadata is not None: + create_params["metadata"] = metadata + + # Handle response format for OpenAI API + if response_format is not None and isinstance(response_format, type) and issubclass(response_format, BaseModel): + create_params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": response_format.__name__, + "schema": response_format.model_json_schema(), + "strict": True, + }, + } + + # Create the assistant + if not self._client: + raise ServiceInitializationError("OpenAI client is not initialized.") + + assistant = await self._client.beta.assistants.create(**create_params) + + # Create ChatAgent - pass default_options which contains response_format + return self._create_chat_agent_from_assistant( + 
assistant=assistant, + tools=normalized_tools, + instructions=instructions, + middleware=middleware, + context_provider=context_provider, + default_options=default_options, + ) + + async def get_agent( + self, + assistant_id: str, + *, + tools: _ToolsType | None = None, + instructions: str | None = None, + default_options: TOptions_co | None = None, + middleware: Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + ) -> "ChatAgent[TOptions_co]": + """Retrieve an existing assistant by ID and return a ChatAgent. + + This method fetches an existing assistant from OpenAI by its ID + and wraps it in a ChatAgent instance. + + Args: + assistant_id: The ID of the assistant to retrieve (e.g., "asst_123"). + + Keyword Args: + tools: Function tools to make available. IMPORTANT: If the assistant + was created with function tools, you MUST provide matching + implementations here. Hosted tools (code_interpreter, file_search) + are automatically included. + instructions: Override the assistant's instructions (optional). + default_options: A TypedDict containing default chat options for the agent. + These options are applied to every run unless overridden. + middleware: Middleware for the ChatAgent. + context_provider: Context provider for the ChatAgent. + + Returns: + A ChatAgent instance wrapping the retrieved assistant. + + Raises: + ServiceInitializationError: If the assistant cannot be retrieved. + ValueError: If required function tools are missing. + + Examples: + .. code-block:: python + + provider = OpenAIAssistantProvider() + + # Get assistant without function tools + agent = await provider.get_agent(assistant_id="asst_123") + + # Get assistant with function tools + agent = await provider.get_agent( + assistant_id="asst_456", + tools=[get_weather, search_database], # Implementations required! + ) + """ + # Fetch the assistant + if not self._client: + raise ServiceInitializationError("OpenAI client is not initialized.") + + assistant = await self._client.beta.assistants.retrieve(assistant_id) + + # Use as_agent to wrap it + return self.as_agent( + assistant=assistant, + tools=tools, + instructions=instructions, + default_options=default_options, + middleware=middleware, + context_provider=context_provider, + ) + + def as_agent( + self, + assistant: Assistant, + *, + tools: _ToolsType | None = None, + instructions: str | None = None, + default_options: TOptions_co | None = None, + middleware: Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + ) -> "ChatAgent[TOptions_co]": + """Wrap an existing SDK Assistant object as a ChatAgent. + + This method does NOT make any HTTP calls. It simply wraps an already- + fetched Assistant object in a ChatAgent. + + Args: + assistant: The OpenAI Assistant SDK object to wrap. + + Keyword Args: + tools: Function tools to make available. If the assistant has + function tools defined, you MUST provide matching implementations. + Hosted tools (code_interpreter, file_search) are automatically included. + instructions: Override the assistant's instructions (optional). + default_options: A TypedDict containing default chat options for the agent. + These options are applied to every run unless overridden. + middleware: Middleware for the ChatAgent. + context_provider: Context provider for the ChatAgent. + + Returns: + A ChatAgent instance wrapping the assistant. + + Raises: + ValueError: If required function tools are missing. + + Examples: + .. 
code-block:: python + + client = AsyncOpenAI() + provider = OpenAIAssistantProvider(client) + + # Fetch assistant via SDK + assistant = await client.beta.assistants.retrieve("asst_123") + + # Wrap without additional HTTP call + agent = provider.as_agent( + assistant, + tools=[my_function], + instructions="Custom instructions override", + ) + """ + # Validate that required function tools are provided + self._validate_function_tools(assistant.tools or [], tools) + + # Merge hosted tools with user-provided function tools + merged_tools = self._merge_tools(assistant.tools or [], tools) + + # Create ChatAgent + return self._create_chat_agent_from_assistant( + assistant=assistant, + tools=merged_tools, + instructions=instructions, + default_options=default_options, + middleware=middleware, + context_provider=context_provider, + ) + + def _validate_function_tools( + self, + assistant_tools: list[Any], + provided_tools: _ToolsType | None, + ) -> None: + """Validate that required function tools are provided. + + Args: + assistant_tools: Tools defined on the assistant. + provided_tools: Tools provided by the user. + + Raises: + ValueError: If a required function tool is missing. + """ + # Get function tool names from assistant + required_functions: set[str] = set() + for tool in assistant_tools: + if ( + hasattr(tool, "type") + and tool.type == "function" + and hasattr(tool, "function") + and hasattr(tool.function, "name") + ): + required_functions.add(tool.function.name) + + if not required_functions: + return # No function tools required + + # Get provided function names using normalize_tools + provided_functions: set[str] = set() + if provided_tools is not None: + normalized = normalize_tools(provided_tools) + for tool in normalized: + if isinstance(tool, AIFunction): + provided_functions.add(tool.name) + elif isinstance(tool, MutableMapping) and "function" in tool: + func_spec = tool.get("function", {}) + if isinstance(func_spec, dict): + func_dict = cast(dict[str, Any], func_spec) + if "name" in func_dict: + provided_functions.add(str(func_dict["name"])) + + # Check for missing functions + missing = required_functions - provided_functions + if missing: + missing_list = ", ".join(sorted(missing)) + raise ValueError( + f"Assistant requires function tool(s) '{missing_list}' but no implementation was provided. " + f"Please pass the function implementation(s) in the 'tools' parameter." + ) + + def _merge_tools( + self, + assistant_tools: list[Any], + user_tools: _ToolsType | None, + ) -> list[ToolProtocol | MutableMapping[str, Any]]: + """Merge hosted tools from assistant with user-provided function tools. + + Args: + assistant_tools: Tools defined on the assistant. + user_tools: Tools provided by the user. + + Returns: + A list of all tools (hosted tools + user function implementations). 
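+
+        Example (illustrative, not from the API): an assistant defining a
+        ``code_interpreter`` tool plus a user-supplied ``get_weather``
+        function would merge to ``[HostedCodeInterpreterTool(), get_weather]``.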
+ """ + merged: list[ToolProtocol | MutableMapping[str, Any]] = [] + + # Add hosted tools from assistant using shared conversion + hosted_tools = from_assistant_tools(assistant_tools) + merged.extend(hosted_tools) + + # Add user-provided tools (normalized) + if user_tools is not None: + normalized_user_tools = normalize_tools(user_tools) + merged.extend(normalized_user_tools) + + return merged + + def _create_chat_agent_from_assistant( + self, + assistant: Assistant, + tools: list[ToolProtocol | MutableMapping[str, Any]] | None, + instructions: str | None, + middleware: Sequence[Middleware] | None, + context_provider: ContextProvider | None, + default_options: TOptions_co | None = None, + **kwargs: Any, + ) -> "ChatAgent[TOptions_co]": + """Create a ChatAgent from an Assistant. + + Args: + assistant: The OpenAI Assistant object. + tools: Tools for the agent. + instructions: Instructions override. + middleware: Middleware for the agent. + context_provider: Context provider for the agent. + default_options: Default chat options for the agent (may include response_format). + **kwargs: Additional arguments passed to ChatAgent. + + Returns: + A configured ChatAgent instance. + """ + # Create the chat client with the assistant + chat_client = OpenAIAssistantsClient( + model_id=assistant.model, + assistant_id=assistant.id, + assistant_name=assistant.name, + assistant_description=assistant.description, + async_client=self._client, + ) + + # Use instructions from assistant if not overridden + final_instructions = instructions if instructions is not None else assistant.instructions + + # Create and return ChatAgent + return ChatAgent( + chat_client=chat_client, + id=assistant.id, + name=assistant.name, + description=assistant.description, + instructions=final_instructions, + tools=tools if tools else None, + middleware=middleware, + context_provider=context_provider, + default_options=default_options, # type: ignore[arg-type] + **kwargs, + ) diff --git a/python/packages/core/agent_framework/openai/_assistants_client.py b/python/packages/core/agent_framework/openai/_assistants_client.py index b6f97371b7..d6dd7c251a 100644 --- a/python/packages/core/agent_framework/openai/_assistants_client.py +++ b/python/packages/core/agent_framework/openai/_assistants_client.py @@ -2,8 +2,15 @@ import json import sys -from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence -from typing import Any, cast +from collections.abc import ( + AsyncIterable, + Awaitable, + Callable, + Mapping, + MutableMapping, + MutableSequence, +) +from typing import Any, Generic, Literal, TypedDict, cast from openai import AsyncOpenAI from openai.types.beta.threads import ( @@ -22,7 +29,12 @@ from .._clients import BaseChatClient from .._middleware import use_chat_middleware -from .._tools import AIFunction, HostedCodeInterpreterTool, HostedFileSearchTool, use_function_invocation +from .._tools import ( + AIFunction, + HostedCodeInterpreterTool, + HostedFileSearchTool, + use_function_invocation, +) from .._types import ( ChatMessage, ChatOptions, @@ -35,7 +47,6 @@ MCPServerToolCallContent, Role, TextContent, - ToolMode, UriContent, UsageContent, UsageDetails, @@ -45,19 +56,162 @@ from ..observability import use_instrumentation from ._shared import OpenAIConfigMixin, OpenAISettings +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover 
+else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + if sys.version_info >= (3, 11): from typing import Self # pragma: no cover else: from typing_extensions import Self # pragma: no cover -__all__ = ["OpenAIAssistantsClient"] +__all__ = [ + "AssistantToolResources", + "OpenAIAssistantsClient", + "OpenAIAssistantsOptions", +] + + +# region OpenAI Assistants Options TypedDict + + +class VectorStoreToolResource(TypedDict, total=False): + """Vector store configuration for file search tool resources.""" + + vector_store_ids: list[str] + """IDs of vector stores attached to this assistant.""" + + +class CodeInterpreterToolResource(TypedDict, total=False): + """Code interpreter tool resource configuration.""" + + file_ids: list[str] + """File IDs accessible by the code interpreter tool. Max 20 files per assistant.""" + + +class AssistantToolResources(TypedDict, total=False): + """Tool resources attached to the assistant. + + See: https://platform.openai.com/docs/api-reference/assistants/createAssistant#assistants-createassistant-tool_resources + """ + + code_interpreter: CodeInterpreterToolResource + """Resources for code interpreter tool, including file IDs.""" + + file_search: VectorStoreToolResource + """Resources for file search tool, including vector store IDs.""" + + +class OpenAIAssistantsOptions(ChatOptions, total=False): + """OpenAI Assistants API-specific options dict. + + Extends base ChatOptions with Assistants API-specific parameters + for creating and running assistants. + + See: https://platform.openai.com/docs/api-reference/assistants + + Keys: + # Inherited from ChatOptions: + model_id: The model to use for the assistant, + translates to ``model`` in OpenAI API. + temperature: Sampling temperature between 0 and 2. + top_p: Nucleus sampling parameter. + max_tokens: Maximum number of tokens to generate, + translates to ``max_completion_tokens`` in OpenAI API. + tools: List of tools (functions, code_interpreter, file_search). + tool_choice: How the model should use tools. + allow_multiple_tool_calls: Whether to allow parallel tool calls, + translates to ``parallel_tool_calls`` in OpenAI API. + response_format: Structured output schema. + metadata: Request metadata for tracking. + + # Options not supported in Assistants API (inherited but unused): + stop: Not supported. + seed: Not supported (use assistant-level configuration instead). + frequency_penalty: Not supported. + presence_penalty: Not supported. + user: Not supported. + store: Not supported. + + # Assistants-specific options: + name: Name of the assistant. + description: Description of the assistant. + instructions: System instructions for the assistant. + tool_resources: Resources for tools (file IDs, vector stores). + reasoning_effort: Effort level for o-series reasoning models. + conversation_id: Thread ID to continue conversation in. + """ + + # Assistants-specific options + name: str + """Name of the assistant (max 256 characters).""" + + description: str + """Description of the assistant (max 512 characters).""" + + tool_resources: AssistantToolResources + """Tool-specific resources like file IDs and vector stores.""" + + reasoning_effort: Literal["low", "medium", "high"] + """Effort level for o-series reasoning models (o1, o3-mini). 
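+    For example (illustrative): pass ``{"reasoning_effort": "high"}`` in the options dict.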
+ Higher effort = more reasoning time and potentially better results.""" + + conversation_id: str # type: ignore[misc] + """Thread ID to continue a conversation in an existing thread.""" + + # OpenAI/ChatOptions fields not supported in Assistants API + stop: None # type: ignore[misc] + """Not supported in Assistants API.""" + + seed: None # type: ignore[misc] + """Not supported in Assistants API (use assistant-level configuration).""" + + frequency_penalty: None # type: ignore[misc] + """Not supported in Assistants API.""" + + presence_penalty: None # type: ignore[misc] + """Not supported in Assistants API.""" + + user: None # type: ignore[misc] + """Not supported in Assistants API.""" + + store: None # type: ignore[misc] + """Not supported in Assistants API.""" + + +ASSISTANTS_OPTION_TRANSLATIONS: dict[str, str] = { + "model_id": "model", + "max_tokens": "max_completion_tokens", + "allow_multiple_tool_calls": "parallel_tool_calls", +} +"""Maps ChatOptions keys to OpenAI Assistants API parameter names.""" + +TOpenAIAssistantsOptions = TypeVar( + "TOpenAIAssistantsOptions", + bound=TypedDict, # type: ignore[valid-type] + default="OpenAIAssistantsOptions", + covariant=True, +) + + +# endregion @use_function_invocation @use_instrumentation @use_chat_middleware -class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient): +class OpenAIAssistantsClient( + OpenAIConfigMixin, + BaseChatClient[TOpenAIAssistantsOptions], + Generic[TOpenAIAssistantsOptions], +): """OpenAI Assistants client.""" def __init__( @@ -118,6 +272,18 @@ def __init__( # Or loading from a .env file client = OpenAIAssistantsClient(env_file_path="path/to/.env") + + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework.openai import OpenAIAssistantsOptions + + + class MyOptions(OpenAIAssistantsOptions, total=False): + my_custom_option: str + + + client: OpenAIAssistantsClient[MyOptions] = OpenAIAssistantsClient(model_id="gpt-4") + response = await client.get_response("Hello", options={"my_custom_option": "value"}) """ try: openai_settings = OpenAISettings( @@ -159,7 +325,12 @@ async def __aenter__(self) -> "Self": """Async context manager entry.""" return self - async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: """Async context manager exit - clean up any assistants we created.""" await self.close() @@ -171,34 +342,32 @@ async def close(self) -> None: object.__setattr__(self, "assistant_id", None) object.__setattr__(self, "_should_delete_assistant", False) + @override async def _inner_get_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> ChatResponse: return await ChatResponse.from_chat_response_generator( - updates=self._inner_get_streaming_response(messages=messages, chat_options=chat_options, **kwargs), - output_format_type=chat_options.response_format, + updates=self._inner_get_streaming_response(messages=messages, options=options, **kwargs), + output_format_type=options.get("response_format"), ) + @override async def _inner_get_streaming_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: # prepare - run_options, tool_results = self._prepare_options(messages, chat_options, **kwargs) 
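+        # Illustrative options shape (all keys optional, names assume the
+        # OpenAIAssistantsOptions TypedDict above):
+        #   {"model_id": "gpt-4o", "max_tokens": 200, "tools": [get_weather]}
+        # _prepare_options maps these to Assistants API params via
+        # ASSISTANTS_OPTION_TRANSLATIONS (model_id -> model,
+        # max_tokens -> max_completion_tokens).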
+ run_options, tool_results = self._prepare_options(messages, options, **kwargs) # Get the thread ID - thread_id: str | None = ( - chat_options.conversation_id - if chat_options.conversation_id is not None - else run_options.get("conversation_id", self.thread_id) - ) + thread_id: str | None = options.get("conversation_id", run_options.get("conversation_id", self.thread_id)) if thread_id is None and tool_results is not None: raise ValueError("No thread ID was provided, but chat messages includes tool results.") @@ -256,7 +425,9 @@ async def _create_assistant_stream( if thread_run is not None and tool_run_id is not None and tool_run_id == thread_run.id and tool_outputs: # There's an active run and we have tool results to submit, so submit the results. stream = client.beta.threads.runs.submit_tool_outputs_stream( # type: ignore[reportDeprecated] - run_id=tool_run_id, thread_id=thread_run.thread_id, tool_outputs=tool_outputs + run_id=tool_run_id, + thread_id=thread_run.thread_id, + tool_outputs=tool_outputs, ) final_thread_id = thread_run.thread_id else: @@ -408,7 +579,11 @@ def _parse_function_calls_from_assistants(self, event_data: Run, response_id: st function_name = tool_call.function.name function_arguments = json.loads(tool_call.function.arguments) contents.append( - FunctionCallContent(call_id=call_id, name=function_name, arguments=function_arguments) + FunctionCallContent( + call_id=call_id, + name=function_name, + arguments=function_arguments, + ) ) return contents @@ -416,59 +591,76 @@ def _parse_function_calls_from_assistants(self, event_data: Run, response_id: st def _prepare_options( self, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions | None, + options: dict[str, Any], **kwargs: Any, ) -> tuple[dict[str, Any], list[FunctionResultContent] | None]: + from .._types import validate_tool_mode + run_options: dict[str, Any] = {**kwargs} - if chat_options is not None: - run_options["max_completion_tokens"] = chat_options.max_tokens - run_options["model"] = chat_options.model_id - run_options["top_p"] = chat_options.top_p - run_options["temperature"] = chat_options.temperature - - if chat_options.allow_multiple_tool_calls is not None: - run_options["parallel_tool_calls"] = chat_options.allow_multiple_tool_calls - - if chat_options.tool_choice is not None: - tool_definitions: list[MutableMapping[str, Any]] = [] - if chat_options.tool_choice != "none" and chat_options.tools is not None: - for tool in chat_options.tools: - if isinstance(tool, AIFunction): - tool_definitions.append(tool.to_json_schema_spec()) # type: ignore[reportUnknownArgumentType] - elif isinstance(tool, HostedCodeInterpreterTool): - tool_definitions.append({"type": "code_interpreter"}) - elif isinstance(tool, HostedFileSearchTool): - params: dict[str, Any] = { - "type": "file_search", - } - if tool.max_results is not None: - params["max_num_results"] = tool.max_results - tool_definitions.append(params) - elif isinstance(tool, MutableMapping): - tool_definitions.append(tool) - - if len(tool_definitions) > 0: - run_options["tools"] = tool_definitions - - if chat_options.tool_choice == "none" or chat_options.tool_choice == "auto": - run_options["tool_choice"] = chat_options.tool_choice.mode - elif ( - isinstance(chat_options.tool_choice, ToolMode) - and chat_options.tool_choice == "required" - and chat_options.tool_choice.required_function_name is not None - ): - run_options["tool_choice"] = { - "type": "function", - "function": {"name": chat_options.tool_choice.required_function_name}, + # Extract 
options from the dict + max_tokens = options.get("max_tokens") + model_id = options.get("model_id") + top_p = options.get("top_p") + temperature = options.get("temperature") + allow_multiple_tool_calls = options.get("allow_multiple_tool_calls") + tool_choice = options.get("tool_choice") + tools = options.get("tools") + response_format = options.get("response_format") + + if max_tokens is not None: + run_options["max_completion_tokens"] = max_tokens + if model_id is not None: + run_options["model"] = model_id + if top_p is not None: + run_options["top_p"] = top_p + if temperature is not None: + run_options["temperature"] = temperature + + if allow_multiple_tool_calls is not None: + run_options["parallel_tool_calls"] = allow_multiple_tool_calls + + tool_mode = validate_tool_mode(tool_choice) + tool_definitions: list[MutableMapping[str, Any]] = [] + if tool_mode["mode"] != "none" and tools is not None: + for tool in tools: + if isinstance(tool, AIFunction): + tool_definitions.append(tool.to_json_schema_spec()) # type: ignore[reportUnknownArgumentType] + elif isinstance(tool, HostedCodeInterpreterTool): + tool_definitions.append({"type": "code_interpreter"}) + elif isinstance(tool, HostedFileSearchTool): + params: dict[str, Any] = { + "type": "file_search", } + if tool.max_results is not None: + params["max_num_results"] = tool.max_results + tool_definitions.append(params) + elif isinstance(tool, MutableMapping): + tool_definitions.append(tool) + + if len(tool_definitions) > 0: + run_options["tools"] = tool_definitions + + if (mode := tool_mode["mode"]) == "required" and ( + func_name := tool_mode.get("required_function_name") + ) is not None: + run_options["tool_choice"] = { + "type": "function", + "function": {"name": func_name}, + } + else: + run_options["tool_choice"] = mode - if chat_options.response_format is not None: + if response_format is not None: + if isinstance(response_format, dict): + run_options["response_format"] = response_format + else: run_options["response_format"] = { "type": "json_schema", "json_schema": { - "name": chat_options.response_format.__name__, - "schema": chat_options.response_format.model_json_schema(), + "name": response_format.__name__, + "schema": response_format.model_json_schema(), + "strict": True, }, } diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index a2365b58f2..2d1ef8b463 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence, Sequence from datetime import datetime, timezone from itertools import chain -from typing import Any, TypeVar +from typing import Any, Generic, Literal, TypedDict from openai import AsyncOpenAI, BadRequestError from openai.lib._parsing._completions import type_to_response_format_param @@ -49,34 +49,105 @@ from ._exceptions import OpenAIContentFilterException from ._shared import OpenAIBase, OpenAIConfigMixin, OpenAISettings +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: from typing_extensions import override # type: ignore[import] # pragma: no cover -__all__ = ["OpenAIChatClient"] +__all__ = ["OpenAIChatClient", "OpenAIChatOptions"] logger = 
get_logger("agent_framework.openai") +# region OpenAI Chat Options TypedDict + + +class PredictionTextContent(TypedDict, total=False): + """Prediction text content options for OpenAI Chat completions.""" + + type: Literal["text"] + text: str + + +class Prediction(TypedDict, total=False): + """Prediction options for OpenAI Chat completions.""" + + type: Literal["content"] + content: str | list[PredictionTextContent] + + +class OpenAIChatOptions(ChatOptions, total=False): + """OpenAI-specific chat options dict. + + Extends ChatOptions with options specific to OpenAI's Chat Completions API. + + Keys: + model_id: The model to use for the request, + translates to ``model`` in OpenAI API. + temperature: Sampling temperature between 0 and 2. + top_p: Nucleus sampling parameter. + max_tokens: Maximum number of tokens to generate, + translates to ``max_completion_tokens`` in OpenAI API. + stop: Stop sequences. + seed: Random seed for reproducibility. + frequency_penalty: Frequency penalty between -2.0 and 2.0. + presence_penalty: Presence penalty between -2.0 and 2.0. + tools: List of tools (functions) available to the model. + tool_choice: How the model should use tools. + allow_multiple_tool_calls: Whether to allow parallel tool calls, + translates to ``parallel_tool_calls`` in OpenAI API. + response_format: Structured output schema. + metadata: Request metadata for tracking. + user: End-user identifier for abuse monitoring. + store: Whether to store the conversation. + instructions: System instructions for the model (prepended as system message). + # OpenAI-specific options (supported by all models): + logit_bias: Token bias values (-100 to 100). + logprobs: Whether to return log probabilities. + top_logprobs: Number of top log probabilities to return (0-20). + prediction: Whether to use predicted return tokens. 
+ """ + + # OpenAI-specific generation parameters (supported by all models) + logit_bias: dict[str | int, float] # type: ignore[misc] + logprobs: bool + top_logprobs: int + prediction: Prediction + + +TOpenAIChatOptions = TypeVar("TOpenAIChatOptions", bound=TypedDict, default="OpenAIChatOptions", covariant=True) # type: ignore[valid-type] + +OPTION_TRANSLATIONS: dict[str, str] = { + "model_id": "model", + "allow_multiple_tool_calls": "parallel_tool_calls", + "max_tokens": "max_completion_tokens", +} + + # region Base Client -class OpenAIBaseChatClient(OpenAIBase, BaseChatClient): +class OpenAIBaseChatClient(OpenAIBase, BaseChatClient[TOpenAIChatOptions], Generic[TOpenAIChatOptions]): """OpenAI Chat completion class.""" + @override async def _inner_get_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> ChatResponse: client = await self._ensure_client() # prepare - options_dict = self._prepare_options(messages, chat_options) + options_dict = self._prepare_options(messages, options) try: # execute and process return self._parse_response_from_openai( - await client.chat.completions.create(stream=False, **options_dict), chat_options + await client.chat.completions.create(stream=False, **options_dict), options ) except BadRequestError as ex: if ex.code == "content_filter": @@ -94,16 +165,17 @@ async def _inner_get_response( inner_exception=ex, ) from ex + @override async def _inner_get_streaming_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: client = await self._ensure_client() # prepare - options_dict = self._prepare_options(messages, chat_options) + options_dict = self._prepare_options(messages, options) options_dict["stream_options"] = {"include_usage": True} try: # execute and process @@ -129,49 +201,45 @@ async def _inner_get_streaming_response( # region content creation - def _prepare_tools_for_openai( - self, tools: Sequence[ToolProtocol | MutableMapping[str, Any]] - ) -> list[dict[str, Any]]: + def _prepare_tools_for_openai(self, tools: Sequence[ToolProtocol | MutableMapping[str, Any]]) -> dict[str, Any]: chat_tools: list[dict[str, Any]] = [] + web_search_options: dict[str, Any] | None = None for tool in tools: if isinstance(tool, ToolProtocol): match tool: case AIFunction(): chat_tools.append(tool.to_json_schema_spec()) + case HostedWebSearchTool(): + web_search_options = ( + { + "user_location": { + "approximate": tool.additional_properties.get("user_location", None), + "type": "approximate", + } + } + if tool.additional_properties and "user_location" in tool.additional_properties + else {} + ) case _: logger.debug("Unsupported tool passed (type: %s), ignoring", type(tool)) else: chat_tools.append(tool if isinstance(tool, dict) else dict(tool)) - return chat_tools + ret_dict: dict[str, Any] = {} + if chat_tools: + ret_dict["tools"] = chat_tools + if web_search_options is not None: + ret_dict["web_search_options"] = web_search_options + return ret_dict - def _process_web_search_tool( - self, tools: Sequence[ToolProtocol | MutableMapping[str, Any]] - ) -> dict[str, Any] | None: - for tool in tools: - if isinstance(tool, HostedWebSearchTool): - # Web search tool requires special handling - return ( - { - "user_location": { - "approximate": tool.additional_properties.get("user_location", None), - "type": "approximate", - } - } - if tool.additional_properties and "user_location" in 
tool.additional_properties - else {} - ) + def _prepare_options(self, messages: MutableSequence[ChatMessage], options: dict[str, Any]) -> dict[str, Any]: + # Prepend instructions from options if they exist + from .._types import prepend_instructions_to_messages, validate_tool_mode - return None + if instructions := options.get("instructions"): + messages = prepend_instructions_to_messages(list(messages), instructions, role="system") - def _prepare_options(self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions) -> dict[str, Any]: - run_options = chat_options.to_dict( - exclude={ - "type", - "instructions", # included as system message - "response_format", # handled separately - "additional_properties", # handled separately - } - ) + # Start with a copy of options + run_options = {k: v for k, v in options.items() if v is not None and k not in {"instructions", "tools"}} # messages if messages and "messages" not in run_options: @@ -179,13 +247,8 @@ def _prepare_options(self, messages: MutableSequence[ChatMessage], chat_options: if "messages" not in run_options: raise ServiceInvalidRequestError("Messages are required for chat completions") - # Translation between ChatOptions and Chat Completion API - translations = { - "model_id": "model", - "allow_multiple_tool_calls": "parallel_tool_calls", - "max_tokens": "max_completion_tokens", - } - for old_key, new_key in translations.items(): + # Translation between options keys and Chat Completion API + for old_key, new_key in OPTION_TRANSLATIONS.items(): if old_key in run_options and old_key != new_key: run_options[new_key] = run_options.pop(old_key) @@ -196,32 +259,33 @@ def _prepare_options(self, messages: MutableSequence[ChatMessage], chat_options: run_options["model"] = self.model_id # tools - if chat_options.tools is not None: - # Preprocess web search tool if it exists - if web_search_options := self._process_web_search_tool(chat_options.tools): - run_options["web_search_options"] = web_search_options - run_options["tools"] = self._prepare_tools_for_openai(chat_options.tools) - if not run_options.get("tools", None): - run_options.pop("tools", None) + tools = options.get("tools") + if tools is not None: + run_options.update(self._prepare_tools_for_openai(tools)) + if not run_options.get("tools"): run_options.pop("parallel_tool_calls", None) run_options.pop("tool_choice", None) - # tool_choice: ToolMode serializes to {"type": "tool_mode", "mode": "..."}, extract mode - if (tool_choice := run_options.get("tool_choice")) and isinstance(tool_choice, dict) and "mode" in tool_choice: - run_options["tool_choice"] = tool_choice["mode"] + if tool_choice := run_options.pop("tool_choice", None): + tool_mode = validate_tool_mode(tool_choice) + if (mode := tool_mode.get("mode")) == "required" and ( + func_name := tool_mode.get("required_function_name") + ) is not None: + run_options["tool_choice"] = { + "type": "function", + "function": {"name": func_name}, + } + else: + run_options["tool_choice"] = mode # response format - if chat_options.response_format: - run_options["response_format"] = type_to_response_format_param(chat_options.response_format) - - # additional properties - additional_options = { - key: value for key, value in chat_options.additional_properties.items() if value is not None - } - if additional_options: - run_options.update(additional_options) + if response_format := options.get("response_format"): + if isinstance(response_format, dict): + run_options["response_format"] = response_format + else: + 
run_options["response_format"] = type_to_response_format_param(response_format) return run_options - def _parse_response_from_openai(self, response: ChatCompletion, chat_options: ChatOptions) -> "ChatResponse": + def _parse_response_from_openai(self, response: ChatCompletion, options: dict[str, Any]) -> "ChatResponse": """Parse a response from OpenAI into a ChatResponse.""" response_metadata = self._get_metadata_from_chat_response(response) messages: list[ChatMessage] = [] @@ -246,7 +310,7 @@ def _parse_response_from_openai(self, response: ChatCompletion, chat_options: Ch model_id=response.model, additional_properties=response_metadata, finish_reason=finish_reason, - response_format=chat_options.response_format, + response_format=options.get("response_format"), ) def _parse_response_update_from_openai( @@ -412,8 +476,11 @@ def _prepare_message_for_openai(self, message: ChatMessage) -> list[dict[str, An args["tool_calls"] = [self._prepare_content_for_openai(content)] # type: ignore case FunctionResultContent(): args["tool_call_id"] = content.call_id - if content.result is not None: - args["content"] = prepare_function_call_results(content.result) + # Always include content for tool results - API requires it even if empty + # Functions returning None should still have a tool result message + args["content"] = ( + prepare_function_call_results(content.result) if content.result is not None else "" + ) case TextReasoningContent(protected_data=protected_data) if protected_data is not None: all_messages[-1]["reasoning_details"] = json.loads(protected_data) case _: @@ -499,13 +566,11 @@ def service_url(self) -> str: # region Public client -TOpenAIChatClient = TypeVar("TOpenAIChatClient", bound="OpenAIChatClient") - @use_function_invocation @use_instrumentation @use_chat_middleware -class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient): +class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient[TOpenAIChatOptions], Generic[TOpenAIChatOptions]): """OpenAI Chat completion class.""" def __init__( @@ -549,14 +614,26 @@ def __init__( # Using environment variables # Set OPENAI_API_KEY=sk-... - # Set OPENAI_CHAT_MODEL_ID=gpt-4 + # Set OPENAI_CHAT_MODEL_ID= client = OpenAIChatClient() # Or passing parameters directly - client = OpenAIChatClient(model_id="gpt-4", api_key="sk-...") + client = OpenAIChatClient(model_id="", api_key="sk-...") # Or loading from a .env file client = OpenAIChatClient(env_file_path="path/to/.env") + + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework.openai import OpenAIChatOptions + + + class MyOptions(OpenAIChatOptions, total=False): + my_custom_option: str + + + client: OpenAIChatClient[MyOptions] = OpenAIChatClient(model_id="") + response = await client.get_response("Hello", options={"my_custom_option": "value"}) """ try: openai_settings = OpenAISettings( diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index 6054c91ded..37a35ae9bc 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. 
+import sys from collections.abc import ( AsyncIterable, Awaitable, @@ -11,7 +12,7 @@ ) from datetime import datetime, timezone from itertools import chain -from typing import Any, TypeVar, cast +from typing import Any, Generic, Literal, TypedDict, cast from openai import AsyncOpenAI, BadRequestError from openai.types.responses.file_search_tool_param import FileSearchToolParam @@ -30,9 +31,6 @@ Mcp, ToolParam, ) -from openai.types.responses.web_search_tool_param import ( - UserLocation as WebSearchUserLocation, -) from openai.types.responses.web_search_tool_param import WebSearchToolParam from pydantic import BaseModel, ValidationError @@ -78,6 +76,8 @@ UsageDetails, _parse_content, prepare_function_call_results, + prepend_instructions_to_messages, + validate_tool_mode, ) from ..exceptions import ( ServiceInitializationError, @@ -88,37 +88,150 @@ from ._exceptions import OpenAIContentFilterException from ._shared import OpenAIBase, OpenAIConfigMixin, OpenAISettings +if sys.version_info >= (3, 13): + from typing import TypeVar # type: ignore # pragma: no cover +else: + from typing_extensions import TypeVar # type: ignore # pragma: no cover +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + logger = get_logger("agent_framework.openai") -__all__ = ["OpenAIResponsesClient"] +__all__ = ["OpenAIResponsesClient", "OpenAIResponsesOptions"] + + +# region OpenAI Responses Options TypedDict + + +class ReasoningOptions(TypedDict, total=False): + """Configuration options for reasoning models (gpt-5, o-series). + + See: https://platform.openai.com/docs/guides/reasoning + """ + + effort: Literal["low", "medium", "high"] + """The effort level for reasoning. Higher effort means more reasoning tokens.""" + + summary: Literal["auto", "concise", "detailed"] + """How to summarize reasoning in the response.""" + + +class StreamOptions(TypedDict, total=False): + """Options for streaming responses.""" + + include_usage: bool + """Whether to include usage statistics in stream events.""" + + +class OpenAIResponsesOptions(ChatOptions, total=False): + """OpenAI Responses API-specific chat options. + + Extends ChatOptions with options specific to OpenAI's Responses API. + These options provide fine-grained control over response generation, + reasoning, and API behavior. + + See: https://platform.openai.com/docs/api-reference/responses/create + """ + + # Responses API-specific parameters + + include: list[str] + """Additional output data to include in the response. + Supported values include: + - 'web_search_call.action.sources' + - 'code_interpreter_call.outputs' + - 'file_search_call.results' + - 'message.input_image.image_url' + - 'message.output_text.logprobs' + - 'reasoning.encrypted_content' + """ + + max_tool_calls: int + """Maximum number of total calls to built-in tools in a response.""" + + prompt: dict[str, Any] + """Reference to a prompt template and its variables. + Learn more: https://platform.openai.com/docs/guides/text#reusable-prompts""" + + prompt_cache_key: str + """Used by OpenAI to cache responses for similar requests. + Replaces the deprecated 'user' field for caching purposes.""" + + prompt_cache_retention: Literal["24h"] + """Retention policy for prompt cache. Set to '24h' for extended caching.""" + + reasoning: ReasoningOptions + """Configuration for reasoning models (gpt-5, o-series). 
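+    For example (illustrative): ``{"reasoning": {"effort": "high", "summary": "auto"}}``.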
+ See: https://platform.openai.com/docs/guides/reasoning""" + + safety_identifier: str + """A stable identifier for detecting policy violations. + Recommend hashing username/email to avoid sending identifying info.""" + + service_tier: Literal["auto", "default", "flex", "priority"] + """Processing type for serving the request. + - 'auto': Use project settings + - 'default': Standard pricing/performance + - 'flex': Flexible processing + - 'priority': Priority processing""" + + stream_options: StreamOptions + """Options for streaming responses. Only set when stream=True.""" + + top_logprobs: int + """Number of most likely tokens (0-20) to return at each position.""" + + truncation: Literal["auto", "disabled"] + """Truncation strategy for model response. + - 'auto': Truncate from beginning if exceeds context + - 'disabled': Fail with 400 error if exceeds context""" + + +TOpenAIResponsesOptions = TypeVar( + "TOpenAIResponsesOptions", + bound=TypedDict, # type: ignore[valid-type] + default="OpenAIResponsesOptions", + covariant=True, +) + + +# endregion + # region ResponsesClient -class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient): +class OpenAIBaseResponsesClient( + OpenAIBase, + BaseChatClient[TOpenAIResponsesOptions], + Generic[TOpenAIResponsesOptions], +): """Base class for all OpenAI Responses based API's.""" FILE_SEARCH_MAX_RESULTS: int = 50 # region Inner Methods + @override async def _inner_get_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> ChatResponse: client = await self._ensure_client() # prepare - run_options = await self._prepare_options(messages, chat_options, **kwargs) + run_options = await self._prepare_options(messages, options, **kwargs) try: # execute and process if "text_format" in run_options: response = await client.responses.parse(stream=False, **run_options) else: response = await client.responses.create(stream=False, **run_options) - return self._parse_response_from_openai(response, chat_options=chat_options) + return self._parse_response_from_openai(response, options=options) except BadRequestError as ex: if ex.code == "content_filter": raise OpenAIContentFilterException( @@ -135,16 +248,17 @@ async def _inner_get_response( inner_exception=ex, ) from ex + @override async def _inner_get_streaming_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: client = await self._ensure_client() # prepare - run_options = await self._prepare_options(messages, chat_options, **kwargs) + run_options = await self._prepare_options(messages, options, **kwargs) function_call_ids: dict[int, tuple[str, str]] = {} # output_index: (call_id, name) try: # execute and process @@ -152,7 +266,7 @@ async def _inner_get_streaming_response( async for chunk in await client.responses.create(stream=True, **run_options): yield self._parse_chunk_from_openai( chunk, - chat_options=chat_options, + options=options, function_call_ids=function_call_ids, ) return @@ -160,7 +274,7 @@ async def _inner_get_streaming_response( async for chunk in response: yield self._parse_chunk_from_openai( chunk, - chat_options=chat_options, + options=options, function_call_ids=function_call_ids, ) except BadRequestError as ex: @@ -319,25 +433,30 @@ def _prepare_tools_for_openai( ) ) case HostedWebSearchTool(): - location: dict[str, str] | None = ( + web_search_tool = WebSearchToolParam(type="web_search") + if location 
:= ( tool.additional_properties.get("user_location", None) if tool.additional_properties else None - ) - response_tools.append( - WebSearchToolParam( - type="web_search", - user_location=WebSearchUserLocation( - type="approximate", - city=location.get("city", None), - country=location.get("country", None), - region=location.get("region", None), - timezone=location.get("timezone", None), - ) - if location - else None, - ) - ) + ): + web_search_tool["user_location"] = { + "type": "approximate", + "city": location.get("city", None), + "country": location.get("country", None), + "region": location.get("region", None), + "timezone": location.get("timezone", None), + } + if filters := ( + tool.additional_properties.get("filters", None) if tool.additional_properties else None + ): + web_search_tool["filters"] = filters + if search_context_size := ( + tool.additional_properties.get("search_context_size", None) + if tool.additional_properties + else None + ): + web_search_tool["search_context_size"] = search_context_size + response_tools.append(web_search_tool) case HostedImageGenerationTool(): mapped_tool: dict[str, Any] = {"type": "image_generation"} if tool.options: @@ -389,37 +508,38 @@ def _prepare_mcp_tool(tool: HostedMCPTool) -> Mcp: async def _prepare_options( self, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> dict[str, Any]: - """Take ChatOptions and create the specific options for Responses API.""" - run_options: dict[str, Any] = chat_options.to_dict( - exclude={ - "type", - "presence_penalty", # not supported - "frequency_penalty", # not supported - "logit_bias", # not supported - "seed", # not supported - "stop", # not supported - "instructions", # already added as system message - "response_format", # handled separately - "conversation_id", # handled separately - "additional_properties", # handled separately - } - ) + """Take options dict and create the specific options for Responses API.""" + # Exclude keys that are not supported or handled separately + exclude_keys = { + "type", + "presence_penalty", # not supported + "frequency_penalty", # not supported + "logit_bias", # not supported + "seed", # not supported + "stop", # not supported + "instructions", # already added as system message + "response_format", # handled separately + "conversation_id", # handled separately + "tool_choice", # handled separately + } + run_options: dict[str, Any] = {k: v for k, v in options.items() if k not in exclude_keys and v is not None} + # messages + # Handle instructions by prepending to messages as system message + if instructions := options.get("instructions"): + messages = prepend_instructions_to_messages(list(messages), instructions, role="system") request_input = self._prepare_messages_for_openai(messages) if not request_input: raise ServiceInvalidRequestError("Messages are required for chat completions") run_options["input"] = request_input # model id - if not run_options.get("model"): - if not self.model_id: - raise ValueError("model_id must be a non-empty string") - run_options["model"] = self.model_id + self._check_model_presence(run_options) - # translations between ChatOptions and Responses API + # translations between options and Responses API translations = { "model_id": "model", "allow_multiple_tool_calls": "parallel_tool_calls", @@ -431,7 +551,7 @@ async def _prepare_options( run_options[new_key] = run_options.pop(old_key) # Handle different conversation ID formats - if conversation_id := 
self._get_current_conversation_id(chat_options, **kwargs):
+        if conversation_id := self._get_current_conversation_id(options, **kwargs):
             if conversation_id.startswith("resp_"):
                 # For response IDs, set previous_response_id and remove conversation property
                 run_options["previous_response_id"] = conversation_id
@@ -443,32 +563,27 @@
                 run_options["previous_response_id"] = conversation_id

         # tools
-        if tools := self._prepare_tools_for_openai(chat_options.tools):
+        if tools := self._prepare_tools_for_openai(options.get("tools")):
             run_options["tools"] = tools
+            # tool_choice: convert ToolMode to appropriate format
+            if tool_choice := options.get("tool_choice"):
+                tool_mode = validate_tool_mode(tool_choice)
+                if (mode := tool_mode.get("mode")) == "required" and (
+                    func_name := tool_mode.get("required_function_name")
+                ) is not None:
+                    run_options["tool_choice"] = {
+                        "type": "function",
+                        "name": func_name,
+                    }
+                else:
+                    run_options["tool_choice"] = mode
         else:
             run_options.pop("parallel_tool_calls", None)
             run_options.pop("tool_choice", None)
-        # tool_choice: ToolMode serializes to {"type": "tool_mode", "mode": "..."}, extract mode
-        if (tool_choice := run_options.get("tool_choice")) and isinstance(tool_choice, dict) and "mode" in tool_choice:
-            run_options["tool_choice"] = tool_choice["mode"]
-
-        # additional properties (excluding response_format which is handled separately)
-        additional_options = {
-            key: value
-            for key, value in chat_options.additional_properties.items()
-            if value is not None and key != "response_format"
-        }
-        if additional_options:
-            run_options.update(additional_options)
-
-        # response format and text config (after additional_properties so user can pass text via additional_properties)
-        # Check both chat_options.response_format and additional_properties for response_format
-        response_format: Any = (
-            chat_options.response_format
-            if chat_options.response_format is not None
-            else chat_options.additional_properties.get("response_format")
-        )
-        text_config: Any = run_options.pop("text", None)
+
+        # response format and text config
+        response_format = options.get("response_format")
+        text_config = run_options.pop("text", None)
         response_format, text_config = self._prepare_response_and_text_format(
             response_format=response_format, text_config=text_config
         )
@@ -479,9 +594,19 @@
         return run_options

-    def _get_current_conversation_id(self, chat_options: ChatOptions, **kwargs: Any) -> str | None:
-        """Get the current conversation ID from chat options or kwargs."""
-        return chat_options.conversation_id or kwargs.get("conversation_id")
+    def _check_model_presence(self, options: dict[str, Any]) -> None:
+        """Check that the 'model' param is present, and raise an error if not.
+
+        Since AzureAIClients use a different param for this, this method is overridden in those clients.
+        """
+        if not options.get("model"):
+            if not self.model_id:
+                raise ValueError("model_id must be a non-empty string")
+            options["model"] = self.model_id
+
+    def _get_current_conversation_id(self, options: dict[str, Any], **kwargs: Any) -> str | None:
+        """Get the current conversation ID from options dict or kwargs."""
+        return options.get("conversation_id") or kwargs.get("conversation_id")

     def _prepare_messages_for_openai(self, chat_messages: Sequence[ChatMessage]) -> list[dict[str, Any]]:
         """Prepare the chat messages for a request. 
@@ -673,7 +798,7 @@ def _prepare_content_for_openai( def _parse_response_from_openai( self, response: OpenAIResponse | ParsedResponse[BaseModel], - chat_options: ChatOptions, + options: dict[str, Any], ) -> "ChatResponse": """Parse an OpenAI Responses API response into a ChatResponse.""" structured_response: BaseModel | None = response.output_parsed if isinstance(response, ParsedResponse) else None # type: ignore[reportUnknownMemberType] @@ -911,20 +1036,22 @@ def _parse_response_from_openai( "raw_representation": response, } - if conversation_id := self._get_conversation_id(response, chat_options.store): + if conversation_id := self._get_conversation_id(response, options.get("store")): args["conversation_id"] = conversation_id if response.usage and (usage_details := self._parse_usage_from_openai(response.usage)): args["usage_details"] = usage_details if structured_response: args["value"] = structured_response - elif chat_options.response_format: - args["response_format"] = chat_options.response_format + elif (response_format := options.get("response_format")) and isinstance(response_format, type): + # Only pass response_format to ChatResponse if it's a Pydantic model type, + # not a runtime JSON schema dict + args["response_format"] = response_format return ChatResponse(**args) def _parse_chunk_from_openai( self, event: OpenAIResponseStreamEvent, - chat_options: ChatOptions, + options: dict[str, Any], function_call_ids: dict[int, tuple[str, str]], ) -> ChatResponseUpdate: """Parse an OpenAI Responses API streaming event into a ChatResponseUpdate.""" @@ -1016,13 +1143,13 @@ def _parse_chunk_from_openai( metadata.update(self._get_metadata_from_response(event)) case "response.created": response_id = event.response.id - conversation_id = self._get_conversation_id(event.response, chat_options.store) + conversation_id = self._get_conversation_id(event.response, options.get("store")) case "response.in_progress": response_id = event.response.id - conversation_id = self._get_conversation_id(event.response, chat_options.store) + conversation_id = self._get_conversation_id(event.response, options.get("store")) case "response.completed": response_id = event.response.id - conversation_id = self._get_conversation_id(event.response, chat_options.store) + conversation_id = self._get_conversation_id(event.response, options.get("store")) model = event.response.model if event.response.usage: usage = self._parse_usage_from_openai(event.response.usage) @@ -1289,13 +1416,14 @@ def _get_metadata_from_response(self, output: Any) -> dict[str, Any]: return {} -TOpenAIResponsesClient = TypeVar("TOpenAIResponsesClient", bound="OpenAIResponsesClient") - - @use_function_invocation @use_instrumentation @use_chat_middleware -class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient): +class OpenAIResponsesClient( + OpenAIConfigMixin, + OpenAIBaseResponsesClient[TOpenAIResponsesOptions], + Generic[TOpenAIResponsesOptions], +): """OpenAI Responses client class.""" def __init__( @@ -1348,6 +1476,18 @@ def __init__( # Or loading from a .env file client = OpenAIResponsesClient(env_file_path="path/to/.env") + + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework.openai import OpenAIResponsesOptions + + + class MyOptions(OpenAIResponsesOptions, total=False): + my_custom_option: str + + + client: OpenAIResponsesClient[MyOptions] = OpenAIResponsesClient(model_id="gpt-4o") + response = await client.get_response("Hello", options={"my_custom_option": "value"}) """ try: 
openai_settings = OpenAISettings( diff --git a/python/packages/core/agent_framework/openai/_shared.py b/python/packages/core/agent_framework/openai/_shared.py index 77189168f1..1eef3624b0 100644 --- a/python/packages/core/agent_framework/openai/_shared.py +++ b/python/packages/core/agent_framework/openai/_shared.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import logging -from collections.abc import Awaitable, Callable, Mapping +from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence from copy import copy from typing import Any, ClassVar, Union @@ -24,7 +24,7 @@ from .._pydantic import AFBaseSettings from .._serialization import SerializationMixin from .._telemetry import APP_INFO, USER_AGENT_KEY, prepend_agent_framework_to_user_agent -from .._types import ChatOptions +from .._tools import AIFunction, HostedCodeInterpreterTool, HostedFileSearchTool, ToolProtocol from ..exceptions import ServiceInitializationError logger: logging.Logger = get_logger("agent_framework.openai") @@ -43,7 +43,7 @@ _legacy_response.HttpxBinaryResponseContent, ] -OPTION_TYPE = Union[ChatOptions, dict[str, Any]] +OPTION_TYPE = dict[str, Any] __all__ = ["OpenAISettings"] @@ -276,3 +276,74 @@ def __init__( # Ensure additional_properties and middleware are passed through kwargs to BaseChatClient # These are consumed by BaseChatClient.__init__ via kwargs super().__init__(**args, **kwargs) + + +def to_assistant_tools( + tools: Sequence[ToolProtocol | MutableMapping[str, Any]] | None, +) -> list[dict[str, Any]]: + """Convert Agent Framework tools to OpenAI Assistants API format. + + Args: + tools: Normalized tools (from ChatOptions.tools). + + Returns: + List of tool definitions for OpenAI Assistants API. + """ + if not tools: + return [] + + tool_definitions: list[dict[str, Any]] = [] + + for tool in tools: + if isinstance(tool, AIFunction): + tool_definitions.append(tool.to_json_schema_spec()) + elif isinstance(tool, HostedCodeInterpreterTool): + tool_definitions.append({"type": "code_interpreter"}) + elif isinstance(tool, HostedFileSearchTool): + params: dict[str, Any] = {"type": "file_search"} + if tool.max_results is not None: + params["file_search"] = {"max_num_results": tool.max_results} + tool_definitions.append(params) + elif isinstance(tool, MutableMapping): + # Pass through raw dict definitions + tool_definitions.append(dict(tool)) + + return tool_definitions + + +def from_assistant_tools( + assistant_tools: list[Any] | None, +) -> list[ToolProtocol]: + """Convert OpenAI Assistant tools to Agent Framework format. + + This converts hosted tools (code_interpreter, file_search) from an OpenAI + Assistant definition back to Agent Framework tool instances. + + Note: Function tools are skipped - user must provide implementations separately. + + Args: + assistant_tools: Tools from OpenAI Assistant object (assistant.tools). + + Returns: + List of Agent Framework tool instances for hosted tools. 
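+
+    Example (illustrative): an assistant whose tools are
+    ``[{"type": "code_interpreter"}, {"type": "function", ...}]`` converts to
+    ``[HostedCodeInterpreterTool()]``; the function entry is skipped.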
+ """ + if not assistant_tools: + return [] + + tools: list[ToolProtocol] = [] + + for tool in assistant_tools: + if hasattr(tool, "type"): + tool_type = tool.type + elif isinstance(tool, dict): + tool_type = tool.get("type") + else: + tool_type = None + + if tool_type == "code_interpreter": + tools.append(HostedCodeInterpreterTool()) + elif tool_type == "file_search": + tools.append(HostedFileSearchTool()) + # Skip function tools - user must provide implementations + + return tools diff --git a/python/packages/core/pyproject.toml b/python/packages/core/pyproject.toml index 7096057690..c47f8eb8e6 100644 --- a/python/packages/core/pyproject.toml +++ b/python/packages/core/pyproject.toml @@ -4,7 +4,7 @@ description = "Microsoft Agent Framework for building AI Agents with Python. Thi authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" @@ -34,7 +34,7 @@ dependencies = [ # connectors and functions "openai>=1.99.0", "azure-identity>=1,<2", - "mcp[ws]>=1.23", + "mcp[ws]>=1.24.0,<2", "packaging>=24.1", ] diff --git a/python/packages/core/tests/azure/test_azure_assistants_client.py b/python/packages/core/tests/azure/test_azure_assistants_client.py index 758be68d3b..6c65dac7c1 100644 --- a/python/packages/core/tests/azure/test_azure_assistants_client.py +++ b/python/packages/core/tests/azure/test_azure_assistants_client.py @@ -9,8 +9,8 @@ from pydantic import Field from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, ChatAgent, ChatClientProtocol, @@ -299,8 +299,7 @@ async def test_azure_assistants_client_get_response_tools() -> None: # Test that the client can be used to get a response response = await azure_assistants_client.get_response( messages=messages, - tools=[get_weather], - tool_choice="auto", + options={"tools": [get_weather], "tool_choice": "auto"}, ) assert response is not None @@ -352,8 +351,7 @@ async def test_azure_assistants_client_streaming_tools() -> None: # Test that the client can be used to get a response response = azure_assistants_client.get_streaming_response( messages=messages, - tools=[get_weather], - tool_choice="auto", + options={"tools": [get_weather], "tool_choice": "auto"}, ) full_message: str = "" async for chunk in response: @@ -405,7 +403,7 @@ async def test_azure_assistants_agent_basic_run(): response = await agent.run("Hello! Please respond with 'Hello World' exactly.") # Validate response - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert response.text is not None assert len(response.text) > 0 assert "Hello World" in response.text @@ -422,7 +420,7 @@ async def test_azure_assistants_agent_basic_run_streaming(): full_message: str = "" async for chunk in agent.run_stream("Please respond with exactly: 'This is a streaming response test.'"): assert chunk is not None - assert isinstance(chunk, AgentRunResponseUpdate) + assert isinstance(chunk, AgentResponseUpdate) if chunk.text: full_message += chunk.text @@ -446,14 +444,14 @@ async def test_azure_assistants_agent_thread_persistence(): first_response = await agent.run( "Remember this number: 42. 
What number did I just tell you to remember?", thread=thread ) - assert isinstance(first_response, AgentRunResponse) + assert isinstance(first_response, AgentResponse) assert "42" in first_response.text # Second message - test conversation memory second_response = await agent.run( "What number did I tell you to remember in my previous message?", thread=thread ) - assert isinstance(second_response, AgentRunResponse) + assert isinstance(second_response, AgentResponse) assert "42" in second_response.text # Verify thread has been populated with conversation ID @@ -477,7 +475,7 @@ async def test_azure_assistants_agent_existing_thread_id(): response1 = await agent.run("What's the weather in Paris?", thread=thread) # Validate first response - assert isinstance(response1, AgentRunResponse) + assert isinstance(response1, AgentResponse) assert response1.text is not None assert any(word in response1.text.lower() for word in ["weather", "paris"]) @@ -499,7 +497,7 @@ async def test_azure_assistants_agent_existing_thread_id(): response2 = await agent.run("What was the last city I asked about?", thread=thread) # Validate that the agent remembers the previous conversation - assert isinstance(response2, AgentRunResponse) + assert isinstance(response2, AgentResponse) assert response2.text is not None # Should reference Paris from the previous conversation assert "paris" in response2.text.lower() @@ -519,7 +517,7 @@ async def test_azure_assistants_agent_code_interpreter(): response = await agent.run("Write Python code to calculate the factorial of 5 and show the result.") # Validate response - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert response.text is not None # Factorial of 5 is 120 assert "120" in response.text or "factorial" in response.text.lower() @@ -538,7 +536,7 @@ async def test_azure_assistants_client_agent_level_tool_persistence(): # First run - agent-level tool should be available first_response = await agent.run("What's the weather like in Chicago?") - assert isinstance(first_response, AgentRunResponse) + assert isinstance(first_response, AgentResponse) assert first_response.text is not None # Should use the agent-level weather tool assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) @@ -546,7 +544,7 @@ async def test_azure_assistants_client_agent_level_tool_persistence(): # Second run - agent-level tool should still be available (persistence test) second_response = await agent.run("What's the weather in Miami?") - assert isinstance(second_response, AgentRunResponse) + assert isinstance(second_response, AgentResponse) assert second_response.text is not None # Should use the agent-level weather tool again assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"]) diff --git a/python/packages/core/tests/azure/test_azure_chat_client.py b/python/packages/core/tests/azure/test_azure_chat_client.py index 7da838529f..483b13f14f 100644 --- a/python/packages/core/tests/azure/test_azure_chat_client.py +++ b/python/packages/core/tests/azure/test_azure_chat_client.py @@ -17,8 +17,8 @@ from openai.types.chat.chat_completion_message import ChatCompletionMessage from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, BaseChatClient, ChatAgent, ChatClientProtocol, @@ -212,7 +212,7 @@ async def test_cmc_with_logit_bias( azure_chat_client = AzureOpenAIChatClient() - await azure_chat_client.get_response(messages=chat_history, 
logit_bias=token_bias) + await azure_chat_client.get_response(messages=chat_history, options={"logit_bias": token_bias}) mock_create.assert_awaited_once_with( model=azure_openai_unit_test_env["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], @@ -237,7 +237,7 @@ async def test_cmc_with_stop( azure_chat_client = AzureOpenAIChatClient() - await azure_chat_client.get_response(messages=chat_history, stop=stop) + await azure_chat_client.get_response(messages=chat_history, options={"stop": stop}) mock_create.assert_awaited_once_with( model=azure_openai_unit_test_env["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], @@ -300,7 +300,7 @@ async def test_azure_on_your_data( content = await azure_chat_client.get_response( messages=messages_in, - additional_properties={"extra_body": expected_data_settings}, + options={"extra_body": expected_data_settings}, ) assert len(content.messages) == 1 assert len(content.messages[0].contents) == 1 @@ -370,7 +370,7 @@ async def test_azure_on_your_data_string( content = await azure_chat_client.get_response( messages=messages_in, - additional_properties={"extra_body": expected_data_settings}, + options={"extra_body": expected_data_settings}, ) assert len(content.messages) == 1 assert len(content.messages[0].contents) == 1 @@ -429,7 +429,7 @@ async def test_azure_on_your_data_fail( content = await azure_chat_client.get_response( messages=messages_in, - additional_properties={"extra_body": expected_data_settings}, + options={"extra_body": expected_data_settings}, ) assert len(content.messages) == 1 assert len(content.messages[0].contents) == 1 @@ -652,13 +652,12 @@ async def test_azure_openai_chat_client_response_tools() -> None: # Test that the client can be used to get a response response = await azure_chat_client.get_response( messages=messages, - tools=[get_story_text], - tool_choice="auto", + options={"tools": [get_story_text], "tool_choice": "auto"}, ) assert response is not None assert isinstance(response, ChatResponse) - assert "scientists" in response.text + assert "Emily" in response.text or "David" in response.text @pytest.mark.flaky @@ -693,7 +692,7 @@ async def test_azure_openai_chat_client_streaming() -> None: if isinstance(content, TextContent) and content.text: full_message += content.text - assert "scientists" in full_message + assert "Emily" in full_message or "David" in full_message @pytest.mark.flaky @@ -709,8 +708,7 @@ async def test_azure_openai_chat_client_streaming_tools() -> None: # Test that the client can be used to get a response response = azure_chat_client.get_streaming_response( messages=messages, - tools=[get_story_text], - tool_choice="auto", + options={"tools": [get_story_text], "tool_choice": "auto"}, ) full_message: str = "" async for chunk in response: @@ -720,7 +718,7 @@ async def test_azure_openai_chat_client_streaming_tools() -> None: if isinstance(content, TextContent) and content.text: full_message += content.text - assert "scientists" in full_message + assert "Emily" in full_message or "David" in full_message @pytest.mark.flaky @@ -733,7 +731,7 @@ async def test_azure_openai_chat_client_agent_basic_run(): # Test basic run response = await agent.run("Please respond with exactly: 'This is a response test.'") - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert response.text is not None assert len(response.text) > 0 assert "response test" in response.text.lower() @@ -749,7 +747,7 @@ async def test_azure_openai_chat_client_agent_basic_run_streaming(): # Test streaming run full_text = "" async for chunk in 
agent.run_stream("Please respond with exactly: 'This is a streaming response test.'"): - assert isinstance(chunk, AgentRunResponseUpdate) + assert isinstance(chunk, AgentResponseUpdate) if chunk.text: full_text += chunk.text @@ -771,13 +769,13 @@ async def test_azure_openai_chat_client_agent_thread_persistence(): # First interaction response1 = await agent.run("My name is Alice. Remember this.", thread=thread) - assert isinstance(response1, AgentRunResponse) + assert isinstance(response1, AgentResponse) assert response1.text is not None # Second interaction - test memory response2 = await agent.run("What is my name?", thread=thread) - assert isinstance(response2, AgentRunResponse) + assert isinstance(response2, AgentResponse) assert response2.text is not None assert "alice" in response2.text.lower() @@ -797,7 +795,7 @@ async def test_azure_openai_chat_client_agent_existing_thread(): thread = first_agent.get_new_thread() first_response = await first_agent.run("My name is Alice. Remember this.", thread=thread) - assert isinstance(first_response, AgentRunResponse) + assert isinstance(first_response, AgentResponse) assert first_response.text is not None # Preserve the thread for reuse @@ -812,7 +810,7 @@ async def test_azure_openai_chat_client_agent_existing_thread(): # Reuse the preserved thread second_response = await second_agent.run("What is my name?", thread=preserved_thread) - assert isinstance(second_response, AgentRunResponse) + assert isinstance(second_response, AgentResponse) assert second_response.text is not None assert "alice" in second_response.text.lower() @@ -830,7 +828,7 @@ async def test_azure_chat_client_agent_level_tool_persistence(): # First run - agent-level tool should be available first_response = await agent.run("What's the weather like in Chicago?") - assert isinstance(first_response, AgentRunResponse) + assert isinstance(first_response, AgentResponse) assert first_response.text is not None # Should use the agent-level weather tool assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) @@ -838,7 +836,7 @@ async def test_azure_chat_client_agent_level_tool_persistence(): # Second run - agent-level tool should still be available (persistence test) second_response = await agent.run("What's the weather in Miami?") - assert isinstance(second_response, AgentRunResponse) + assert isinstance(second_response, AgentResponse) assert second_response.text is not None # Should use the agent-level weather tool again assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"]) diff --git a/python/packages/core/tests/azure/test_azure_responses_client.py b/python/packages/core/tests/azure/test_azure_responses_client.py index ec19eaf833..0e1c17f9a8 100644 --- a/python/packages/core/tests/azure/test_azure_responses_client.py +++ b/python/packages/core/tests/azure/test_azure_responses_client.py @@ -1,26 +1,25 @@ # Copyright (c) Microsoft. All rights reserved. 
+import json import os -from typing import Annotated +from typing import Annotated, Any import pytest from azure.identity import AzureCliCredential from pydantic import BaseModel +from pytest import param from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, - AgentThread, + AgentResponse, ChatAgent, ChatClientProtocol, ChatMessage, ChatResponse, - ChatResponseUpdate, HostedCodeInterpreterTool, HostedFileSearchTool, HostedMCPTool, HostedVectorStoreContent, - TextContent, + HostedWebSearchTool, ai_function, ) from agent_framework.azure import AzureOpenAIResponsesClient @@ -74,7 +73,7 @@ async def delete_vector_store(client: AzureOpenAIResponsesClient, file_id: str, def test_init(azure_openai_unit_test_env: dict[str, str]) -> None: # Test successful initialization - azure_responses_client = AzureOpenAIResponsesClient() + azure_responses_client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) assert azure_responses_client.model_id == azure_openai_unit_test_env["AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME"] assert isinstance(azure_responses_client, ChatClientProtocol) @@ -141,283 +140,286 @@ def test_serialize(azure_openai_unit_test_env: dict[str, str]) -> None: assert "User-Agent" not in dumped_settings["default_headers"] -@pytest.mark.flaky -@skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_response() -> None: - """Test azure responses client responses.""" - azure_responses_client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) - - assert isinstance(azure_responses_client, ChatClientProtocol) - - messages: list[ChatMessage] = [] - messages.append( - ChatMessage( - role="user", - text="Emily and David, two passionate scientists, met during a research expedition to Antarctica. " - "Bonded by their love for the natural world and shared curiosity, they uncovered a " - "groundbreaking phenomenon in glaciology that could potentially reshape our understanding " - "of climate change.", - ) - ) - messages.append(ChatMessage(role="user", text="who are Emily and David?")) - - # Test that the client can be used to get a response - response = await azure_responses_client.get_response(messages=messages) - - assert response is not None - assert isinstance(response, ChatResponse) - assert "scientists" in response.text - - messages.clear() - messages.append(ChatMessage(role="user", text="The weather in New York is sunny")) - messages.append(ChatMessage(role="user", text="What is the weather in New York?")) - - # Test that the client can be used to get a structured response - structured_response = await azure_responses_client.get_response( # type: ignore[reportAssignmentType] - messages=messages, - response_format=OutputStruct, - ) - - assert structured_response is not None - assert isinstance(structured_response, ChatResponse) - assert isinstance(structured_response.value, OutputStruct) - assert structured_response.value.location == "New York" - assert "sunny" in structured_response.value.weather.lower() +# region Integration Tests @pytest.mark.flaky @skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_response_tools() -> None: - """Test azure responses client tools.""" - azure_responses_client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) - - assert isinstance(azure_responses_client, ChatClientProtocol) - - messages: list[ChatMessage] = [] - messages.append(ChatMessage(role="user", text="What is the weather in New York?")) - - # Test that the client can be used to get a response - response 
= await azure_responses_client.get_response( - messages=messages, - tools=[get_weather], - tool_choice="auto", - ) +@pytest.mark.parametrize( + "option_name,option_value,needs_validation", + [ + # Simple ChatOptions - just verify they don't fail + param("temperature", 0.7, False, id="temperature"), + param("top_p", 0.9, False, id="top_p"), + param("max_tokens", 500, False, id="max_tokens"), + param("seed", 123, False, id="seed"), + param("user", "test-user-id", False, id="user"), + param("metadata", {"test_key": "test_value"}, False, id="metadata"), + param("frequency_penalty", 0.5, False, id="frequency_penalty"), + param("presence_penalty", 0.3, False, id="presence_penalty"), + param("stop", ["END"], False, id="stop"), + param("allow_multiple_tool_calls", True, False, id="allow_multiple_tool_calls"), + param("tool_choice", "none", True, id="tool_choice_none"), + # OpenAIResponsesOptions - just verify they don't fail + param("safety_identifier", "user-hash-abc123", False, id="safety_identifier"), + param("truncation", "auto", False, id="truncation"), + param("top_logprobs", 5, False, id="top_logprobs"), + param("prompt_cache_key", "test-cache-key", False, id="prompt_cache_key"), + param("max_tool_calls", 3, False, id="max_tool_calls"), + # Complex options requiring output validation + param("tools", [get_weather], True, id="tools_function"), + param("tool_choice", "auto", True, id="tool_choice_auto"), + param( + "tool_choice", + {"mode": "required", "required_function_name": "get_weather"}, + True, + id="tool_choice_required", + ), + param("response_format", OutputStruct, True, id="response_format_pydantic"), + param( + "response_format", + { + "type": "json_schema", + "json_schema": { + "name": "WeatherDigest", + "strict": True, + "schema": { + "title": "WeatherDigest", + "type": "object", + "properties": { + "location": {"type": "string"}, + "conditions": {"type": "string"}, + "temperature_c": {"type": "number"}, + "advisory": {"type": "string"}, + }, + "required": ["location", "conditions", "temperature_c", "advisory"], + "additionalProperties": False, + }, + }, + }, + True, + id="response_format_runtime_json_schema", + ), + ], +) +async def test_integration_options( + option_name: str, + option_value: Any, + needs_validation: bool, +) -> None: + """Parametrized test covering all ChatOptions and OpenAIResponsesOptions. + + Tests both streaming and non-streaming modes for each option to ensure + they don't cause failures. Options marked with needs_validation also + check that the feature actually works correctly. 
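+
+    For example, the non-streaming leg of a single parametrized case behaves
+    roughly like the sketch below (illustrative values only; the real
+    option name/value pairs come from the parametrize list above):
+
+        response = await client.get_response(
+            messages=[ChatMessage(role="user", text="Say 'Hello World' briefly.")],
+            options={"temperature": 0.7},
+        )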
+ """ + client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) + # to ensure toolmode required does not endlessly loop + client.function_invocation_configuration.max_iterations = 1 + + for streaming in [False, True]: + # Prepare test message + if option_name == "tools" or option_name == "tool_choice": + # Use weather-related prompt for tool tests + messages = [ChatMessage(role="user", text="What is the weather in Seattle?")] + elif option_name == "response_format": + # Use prompt that works well with structured output + messages = [ChatMessage(role="user", text="The weather in Seattle is sunny")] + messages.append(ChatMessage(role="user", text="What is the weather in Seattle?")) + else: + # Generic prompt for simple options + messages = [ChatMessage(role="user", text="Say 'Hello World' briefly.")] + + # Build options dict + options: dict[str, Any] = {option_name: option_value} + + # Add tools if testing tool_choice to avoid errors + if option_name == "tool_choice": + options["tools"] = [get_weather] + + if streaming: + # Test streaming mode + response_gen = client.get_streaming_response( + messages=messages, + options=options, + ) - assert response is not None - assert isinstance(response, ChatResponse) - assert "sunny" in response.text + output_format = option_value if option_name == "response_format" else None + response = await ChatResponse.from_chat_response_generator(response_gen, output_format_type=output_format) + else: + # Test non-streaming mode + response = await client.get_response( + messages=messages, + options=options, + ) - messages.clear() - messages.append(ChatMessage(role="user", text="What is the weather in Seattle?")) + assert response is not None + assert isinstance(response, ChatResponse) + assert response.text is not None, f"No text in response for option '{option_name}'" + assert len(response.text) > 0, f"Empty response for option '{option_name}'" + + # Validate based on option type + if needs_validation: + if option_name == "tools" or option_name == "tool_choice": + # Should have called the weather function + text = response.text.lower() + assert "sunny" in text or "seattle" in text, f"Tool not invoked for {option_name}" + elif option_name == "response_format": + if option_value == OutputStruct: + # Should have structured output + assert response.value is not None, "No structured output" + assert isinstance(response.value, OutputStruct) + assert "seattle" in response.value.location.lower() + else: + # Runtime JSON schema + assert response.value is None, "No structured output, can't parse any json." 
+ response_value = json.loads(response.text) + assert isinstance(response_value, dict) + assert "location" in response_value + assert "seattle" in response_value["location"].lower() - # Test that the client can be used to get a response - structured_response: ChatResponse = await azure_responses_client.get_response( # type: ignore[reportAssignmentType] - messages=messages, - tools=[get_weather], - tool_choice="auto", - response_format=OutputStruct, - ) - assert structured_response is not None - assert isinstance(structured_response, ChatResponse) - assert isinstance(structured_response.value, OutputStruct) - assert "Seattle" in structured_response.value.location - assert "sunny" in structured_response.value.weather.lower() +@pytest.mark.flaky +@skip_if_azure_integration_tests_disabled +async def test_integration_web_search() -> None: + client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) + + for streaming in [False, True]: + content = { + "messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.", + "options": { + "tool_choice": "auto", + "tools": [HostedWebSearchTool()], + }, + } + if streaming: + response = await ChatResponse.from_chat_response_generator(client.get_streaming_response(**content)) + else: + response = await client.get_response(**content) + + assert response is not None + assert isinstance(response, ChatResponse) + assert "Rumi" in response.text + assert "Mira" in response.text + assert "Zoey" in response.text + + # Test that the client will use the web search tool with location + additional_properties = { + "user_location": { + "country": "US", + "city": "Seattle", + } + } + content = { + "messages": "What is the current weather? Do not ask for my current location.", + "options": { + "tool_choice": "auto", + "tools": [HostedWebSearchTool(additional_properties=additional_properties)], + }, + } + if streaming: + response = await ChatResponse.from_chat_response_generator(client.get_streaming_response(**content)) + else: + response = await client.get_response(**content) + assert response.text is not None @pytest.mark.flaky @skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_streaming() -> None: - """Test Azure azure responses client streaming responses.""" +async def test_integration_client_file_search() -> None: + """Test Azure responses client with file search tool.""" azure_responses_client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) - - assert isinstance(azure_responses_client, ChatClientProtocol) - - messages: list[ChatMessage] = [] - messages.append( - ChatMessage( - role="user", - text="Emily and David, two passionate scientists, met during a research expedition to Antarctica. " - "Bonded by their love for the natural world and shared curiosity, they uncovered a " - "groundbreaking phenomenon in glaciology that could potentially reshape our understanding " - "of climate change.", + file_id, vector_store = await create_vector_store(azure_responses_client) + try: + # Test that the client will use the file search tool + response = await azure_responses_client.get_response( + messages=[ + ChatMessage( + role="user", + text="What is the weather today? 
Do a file search to find the answer.", + ) + ], + options={"tools": [HostedFileSearchTool(inputs=vector_store)], "tool_choice": "auto"}, ) - ) - messages.append(ChatMessage(role="user", text="who are Emily and David?")) - - # Test that the client can be used to get a response - response = azure_responses_client.get_streaming_response(messages=messages) - - full_message: str = "" - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - - assert "scientists" in full_message - - messages.clear() - messages.append(ChatMessage(role="user", text="The weather in Seattle is sunny")) - messages.append(ChatMessage(role="user", text="What is the weather in Seattle?")) - structured_response = await ChatResponse.from_chat_response_generator( - azure_responses_client.get_streaming_response( - messages=messages, - response_format=OutputStruct, - ), - output_format_type=OutputStruct, - ) - assert structured_response is not None - assert isinstance(structured_response, ChatResponse) - assert isinstance(structured_response.value, OutputStruct) - assert "Seattle" in structured_response.value.location - assert "sunny" in structured_response.value.weather.lower() + assert "sunny" in response.text.lower() + assert "75" in response.text + finally: + await delete_vector_store(azure_responses_client, file_id, vector_store.vector_store_id) @pytest.mark.flaky @skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_streaming_tools() -> None: - """Test azure responses client streaming tools.""" +async def test_integration_client_file_search_streaming() -> None: + """Test Azure responses client with file search tool and streaming.""" azure_responses_client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) + file_id, vector_store = await create_vector_store(azure_responses_client) + # Test that the client will use the file search tool + try: + response = azure_responses_client.get_streaming_response( + messages=[ + ChatMessage( + role="user", + text="What is the weather today? 
Do a file search to find the answer.", + ) + ], + options={"tools": [HostedFileSearchTool(inputs=vector_store)], "tool_choice": "auto"}, + ) - assert isinstance(azure_responses_client, ChatClientProtocol) - - messages: list[ChatMessage] = [ChatMessage(role="user", text="What is the weather in Seattle?")] - - # Test that the client can be used to get a response - response = azure_responses_client.get_streaming_response( - messages=messages, - tools=[get_weather], - tool_choice="auto", - ) - full_message: str = "" - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - - assert "sunny" in full_message - - messages.clear() - messages.append(ChatMessage(role="user", text="What is the weather in Seattle?")) - - structured_response = azure_responses_client.get_streaming_response( - messages=messages, - tools=[get_weather], - tool_choice="auto", - response_format=OutputStruct, - ) - full_message = "" - async for chunk in structured_response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - - output = OutputStruct.model_validate_json(full_message) - assert "Seattle" in output.location - assert "sunny" in output.weather.lower() + assert response is not None + full_response = await ChatResponse.from_chat_response_generator(response) + assert "sunny" in full_response.text.lower() + assert "75" in full_response.text + finally: + await delete_vector_store(azure_responses_client, file_id, vector_store.vector_store_id) @pytest.mark.flaky @skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_agent_basic_run(): - """Test Azure Responses Client agent basic run functionality with AzureOpenAIResponsesClient.""" - agent = AzureOpenAIResponsesClient(credential=AzureCliCredential()).create_agent( - instructions="You are a helpful assistant.", +async def test_integration_client_agent_hosted_mcp_tool() -> None: + """Integration test for HostedMCPTool with Azure Response Agent using Microsoft Learn MCP.""" + client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) + response = await client.get_response( + "How to create an Azure storage account using az cli?", + options={ + # this needs to be high enough to handle the full MCP tool response. + "max_tokens": 5000, + "tools": HostedMCPTool( + name="Microsoft Learn MCP", + url="https://learn.microsoft.com/api/mcp", + description="A Microsoft Learn MCP server for documentation questions", + approval_mode="never_require", + ), + }, ) - - # Test basic run - response = await agent.run("Hello! 
Please respond with 'Hello World' exactly.") - - assert isinstance(response, AgentRunResponse) - assert response.text is not None - assert len(response.text) > 0 - assert "hello world" in response.text.lower() - - -@pytest.mark.flaky -@skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_agent_basic_run_streaming(): - """Test Azure Responses Client agent basic streaming functionality with AzureOpenAIResponsesClient.""" - async with ChatAgent( - chat_client=AzureOpenAIResponsesClient(credential=AzureCliCredential()), - ) as agent: - # Test streaming run - full_text = "" - async for chunk in agent.run_stream("Please respond with exactly: 'This is a streaming response test.'"): - assert isinstance(chunk, AgentRunResponseUpdate) - if chunk.text: - full_text += chunk.text - - assert len(full_text) > 0 - assert "streaming response test" in full_text.lower() - - -@pytest.mark.flaky -@skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_agent_thread_persistence(): - """Test Azure Responses Client agent thread persistence across runs with AzureOpenAIResponsesClient.""" - async with ChatAgent( - chat_client=AzureOpenAIResponsesClient(credential=AzureCliCredential()), - instructions="You are a helpful assistant with good memory.", - ) as agent: - # Create a new thread that will be reused - thread = agent.get_new_thread() - - # First interaction - first_response = await agent.run("My favorite programming language is Python. Remember this.", thread=thread) - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - - # Second interaction - test memory - second_response = await agent.run("What is my favorite programming language?", thread=thread) - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None + assert isinstance(response, ChatResponse) + assert response.text + # Should contain Azure-related content since it's asking about Azure CLI + assert any(term in response.text.lower() for term in ["azure", "storage", "account", "cli"]) @pytest.mark.flaky @skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_agent_thread_storage_with_store_true(): - """Test Azure Responses Client agent with store=True to verify service_thread_id is returned.""" - async with ChatAgent( - chat_client=AzureOpenAIResponsesClient(credential=AzureCliCredential()), - instructions="You are a helpful assistant.", - ) as agent: - # Create a new thread - thread = AgentThread() - - # Initially, service_thread_id should be None - assert thread.service_thread_id is None - - # Run with store=True to store messages on Azure/OpenAI side - response = await agent.run( - "Hello! 
Please remember that my name is Alex.", - thread=thread, - store=True, - ) - - # Validate response - assert isinstance(response, AgentRunResponse) - assert response.text is not None - assert len(response.text) > 0 +async def test_integration_client_agent_hosted_code_interpreter_tool(): + """Test Azure Responses Client agent with HostedCodeInterpreterTool through AzureOpenAIResponsesClient.""" + client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) - # After store=True, service_thread_id should be populated - assert thread.service_thread_id is not None - assert isinstance(thread.service_thread_id, str) - assert len(thread.service_thread_id) > 0 + response = await client.get_response( + "Calculate the sum of numbers from 1 to 10 using Python code.", + options={ + "tools": [HostedCodeInterpreterTool()], + }, + ) + # Should contain calculation result (sum of 1-10 = 55) or code execution content + contains_relevant_content = any( + term in response.text.lower() for term in ["55", "sum", "code", "python", "calculate", "10"] + ) + assert contains_relevant_content or len(response.text.strip()) > 10 @pytest.mark.flaky @skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_agent_existing_thread(): +async def test_integration_client_agent_existing_thread(): """Test Azure Responses Client agent with existing thread to continue conversations across agent instances.""" # First conversation - capture the thread preserved_thread = None @@ -428,9 +430,9 @@ async def test_azure_responses_client_agent_existing_thread(): ) as first_agent: # Start a conversation and capture the thread thread = first_agent.get_new_thread() - first_response = await first_agent.run("My hobby is photography. Remember this.", thread=thread) + first_response = await first_agent.run("My hobby is photography. 
Remember this.", thread=thread, store=True) - assert isinstance(first_response, AgentRunResponse) + assert isinstance(first_response, AgentResponse) assert first_response.text is not None # Preserve the thread for reuse @@ -445,192 +447,6 @@ async def test_azure_responses_client_agent_existing_thread(): # Reuse the preserved thread second_response = await second_agent.run("What is my hobby?", thread=preserved_thread) - assert isinstance(second_response, AgentRunResponse) + assert isinstance(second_response, AgentResponse) assert second_response.text is not None assert "photography" in second_response.text.lower() - - -@pytest.mark.flaky -@skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_agent_hosted_code_interpreter_tool(): - """Test Azure Responses Client agent with HostedCodeInterpreterTool through AzureOpenAIResponsesClient.""" - async with ChatAgent( - chat_client=AzureOpenAIResponsesClient(credential=AzureCliCredential()), - instructions="You are a helpful assistant that can execute Python code.", - tools=[HostedCodeInterpreterTool()], - ) as agent: - # Test code interpreter functionality - response = await agent.run("Calculate the sum of numbers from 1 to 10 using Python code.") - - assert isinstance(response, AgentRunResponse) - assert response.text is not None - assert len(response.text) > 0 - # Should contain calculation result (sum of 1-10 = 55) or code execution content - contains_relevant_content = any( - term in response.text.lower() for term in ["55", "sum", "code", "python", "calculate", "10"] - ) - assert contains_relevant_content or len(response.text.strip()) > 10 - - -@pytest.mark.flaky -@skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_agent_level_tool_persistence(): - """Test that agent-level tools persist across multiple runs with Azure Responses Client.""" - - async with ChatAgent( - chat_client=AzureOpenAIResponsesClient(credential=AzureCliCredential()), - instructions="You are a helpful assistant that uses available tools.", - tools=[get_weather], # Agent-level tool - ) as agent: - # First run - agent-level tool should be available - first_response = await agent.run("What's the weather like in Chicago?") - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - # Should use the agent-level weather tool - assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) - - # Second run - agent-level tool should still be available (persistence test) - second_response = await agent.run("What's the weather in Miami?") - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - # Should use the agent-level weather tool again - assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"]) - - -@pytest.mark.flaky -@skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_agent_chat_options_run_level() -> None: - """Integration test for comprehensive ChatOptions parameter coverage with Azure Response Agent.""" - async with ChatAgent( - chat_client=AzureOpenAIResponsesClient(credential=AzureCliCredential()), - instructions="You are a helpful assistant.", - ) as agent: - response = await agent.run( - "Provide a brief, helpful response.", - max_tokens=100, - temperature=0.7, - top_p=0.9, - seed=123, - user="comprehensive-test-user", - tools=[get_weather], - tool_choice="auto", - ) - - assert isinstance(response, AgentRunResponse) - assert response.text is not None - 
assert len(response.text) > 0 - - -@pytest.mark.flaky -@skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_agent_chat_options_agent_level() -> None: - """Integration test for comprehensive ChatOptions parameter coverage with Azure Response Agent.""" - async with ChatAgent( - chat_client=AzureOpenAIResponsesClient(credential=AzureCliCredential()), - instructions="You are a helpful assistant.", - max_tokens=100, - temperature=0.7, - top_p=0.9, - seed=123, - user="comprehensive-test-user", - tools=[get_weather], - tool_choice="auto", - ) as agent: - response = await agent.run( - "Provide a brief, helpful response.", - ) - - assert isinstance(response, AgentRunResponse) - assert response.text is not None - assert len(response.text) > 0 - - -@pytest.mark.flaky -@skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_agent_hosted_mcp_tool() -> None: - """Integration test for HostedMCPTool with Azure Response Agent using Microsoft Learn MCP.""" - - async with ChatAgent( - chat_client=AzureOpenAIResponsesClient(credential=AzureCliCredential()), - instructions="You are a helpful assistant that can help with microsoft documentation questions.", - tools=HostedMCPTool( - name="Microsoft Learn MCP", - url="https://learn.microsoft.com/api/mcp", - description="A Microsoft Learn MCP server for documentation questions", - approval_mode="never_require", - ), - ) as agent: - response = await agent.run( - "How to create an Azure storage account using az cli?", - # this needs to be high enough to handle the full MCP tool response. - max_tokens=5000, - ) - - assert isinstance(response, AgentRunResponse) - assert response.text - # Should contain Azure-related content since it's asking about Azure CLI - assert any(term in response.text.lower() for term in ["azure", "storage", "account", "cli"]) - - -@pytest.mark.flaky -@skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_file_search() -> None: - """Test Azure responses client with file search tool.""" - azure_responses_client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) - - assert isinstance(azure_responses_client, ChatClientProtocol) - - file_id, vector_store = await create_vector_store(azure_responses_client) - # Test that the client will use the file search tool - response = await azure_responses_client.get_response( - messages=[ - ChatMessage( - role="user", - text="What is the weather today? Do a file search to find the answer.", - ) - ], - tools=[HostedFileSearchTool(inputs=vector_store)], - tool_choice="auto", - ) - - await delete_vector_store(azure_responses_client, file_id, vector_store.vector_store_id) - assert "sunny" in response.text.lower() - assert "75" in response.text - - -@pytest.mark.flaky -@skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_file_search_streaming() -> None: - """Test Azure responses client with file search tool and streaming.""" - azure_responses_client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) - - assert isinstance(azure_responses_client, ChatClientProtocol) - - file_id, vector_store = await create_vector_store(azure_responses_client) - # Test that the client will use the file search tool - response = azure_responses_client.get_streaming_response( - messages=[ - ChatMessage( - role="user", - text="What is the weather today? 
Do a file search to find the answer.", - ) - ], - tools=[HostedFileSearchTool(inputs=vector_store)], - tool_choice="auto", - ) - - assert response is not None - full_message: str = "" - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - - await delete_vector_store(azure_responses_client, file_id, vector_store.vector_store_id) - - assert "sunny" in full_message.lower() - assert "75" in full_message diff --git a/python/packages/core/tests/core/conftest.py b/python/packages/core/tests/core/conftest.py index ca524a4144..1561392214 100644 --- a/python/packages/core/tests/core/conftest.py +++ b/python/packages/core/tests/core/conftest.py @@ -4,7 +4,7 @@ import logging import sys from collections.abc import AsyncIterable, MutableSequence -from typing import Any +from typing import Any, Generic from unittest.mock import patch from uuid import uuid4 @@ -13,12 +13,11 @@ from agent_framework import ( AgentProtocol, - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, BaseChatClient, ChatMessage, - ChatOptions, ChatResponse, ChatResponseUpdate, Role, @@ -28,6 +27,7 @@ use_chat_middleware, use_function_invocation, ) +from agent_framework._clients import TOptions_co if sys.version_info >= (3, 12): from typing import override # type: ignore @@ -113,7 +113,7 @@ async def get_streaming_response( @use_chat_middleware -class MockBaseChatClient(BaseChatClient): +class MockBaseChatClient(BaseChatClient[TOptions_co], Generic[TOptions_co]): """Mock implementation of the BaseChatClient.""" def __init__(self, **kwargs: Any): @@ -127,27 +127,27 @@ async def _inner_get_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> ChatResponse: """Send a chat request to the AI service. Args: messages: The chat messages to send. - chat_options: The options for the request. + options: The options dict for the request. kwargs: Any additional keyword arguments. Returns: The chat response contents representing the response(s). 
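+
+        Example (an illustrative sketch of the options-dict contract only,
+        not a real call site from the tests):
+
+            response = await client._inner_get_response(
+                messages=[ChatMessage(role="user", text="hi")],
+                options={"tool_choice": "none"},
+            )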
""" - logger.debug(f"Running base chat client inner, with: {messages=}, {chat_options=}, {kwargs=}") + logger.debug(f"Running base chat client inner, with: {messages=}, {options=}, {kwargs=}") self.call_count += 1 if not self.run_responses: return ChatResponse(messages=ChatMessage(role="assistant", text=f"test response - {messages[-1].text}")) response = self.run_responses.pop(0) - if chat_options.tool_choice == "none": + if options.get("tool_choice") == "none": return ChatResponse( messages=ChatMessage( role="assistant", @@ -163,14 +163,14 @@ async def _inner_get_streaming_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: - logger.debug(f"Running base chat client inner stream, with: {messages=}, {chat_options=}, {kwargs=}") + logger.debug(f"Running base chat client inner stream, with: {messages=}, {options=}, {kwargs=}") if not self.streaming_responses: yield ChatResponseUpdate(text=f"update - {messages[0].text}", role="assistant") return - if chat_options.tool_choice == "none": + if options.get("tool_choice") == "none": yield ChatResponseUpdate(text="I broke out of the function invocation loop...", role="assistant") return response = self.streaming_responses.pop(0) @@ -221,11 +221,6 @@ def name(self) -> str | None: """Returns the name of the agent.""" return "Name" - @property - def display_name(self) -> str: - """Returns the name of the agent.""" - return "Display Name" - @property def description(self) -> str | None: return "Description" @@ -236,9 +231,9 @@ async def run( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: logger.debug(f"Running mock agent, with: {messages=}, {thread=}, {kwargs=}") - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("Response")])]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("Response")])]) async def run_stream( self, @@ -246,9 +241,9 @@ async def run_stream( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: logger.debug(f"Running mock agent stream, with: {messages=}, {thread=}, {kwargs=}") - yield AgentRunResponseUpdate(contents=[TextContent("Response")]) + yield AgentResponseUpdate(contents=[TextContent("Response")]) def get_new_thread(self) -> AgentThread: return MockAgentThread() diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 7611df0cb0..ee9054c143 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -10,10 +10,9 @@ from agent_framework import ( AgentProtocol, - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, - AggregateContextProvider, ChatAgent, ChatClientProtocol, ChatMessage, @@ -46,7 +45,7 @@ async def test_agent_run(agent: AgentProtocol) -> None: async def test_agent_run_streaming(agent: AgentProtocol) -> None: - async def collect_updates(updates: AsyncIterable[AgentRunResponseUpdate]) -> list[AgentRunResponseUpdate]: + async def collect_updates(updates: AsyncIterable[AgentResponseUpdate]) -> list[AgentResponseUpdate]: return [u async for u in updates] updates = await collect_updates(agent.run_stream(messages="test")) @@ -66,7 +65,6 @@ async def test_chat_client_agent_init(chat_client: ChatClientProtocol) -> None: assert agent.id 
== agent_id assert agent.name is None assert agent.description == "Test" - assert agent.display_name == agent_id # Display name defaults to id if name is None async def test_chat_client_agent_init_with_name(chat_client: ChatClientProtocol) -> None: @@ -76,7 +74,6 @@ async def test_chat_client_agent_init_with_name(chat_client: ChatClientProtocol) assert agent.id == agent_id assert agent.name == "Test Agent" assert agent.description == "Test" - assert agent.display_name == "Test Agent" # Display name is the name if present async def test_chat_client_agent_run(chat_client: ChatClientProtocol) -> None: @@ -90,7 +87,7 @@ async def test_chat_client_agent_run(chat_client: ChatClientProtocol) -> None: async def test_chat_client_agent_run_streaming(chat_client: ChatClientProtocol) -> None: agent = ChatAgent(chat_client=chat_client) - result = await AgentRunResponse.from_agent_response_generator(agent.run_stream("Hello")) + result = await AgentResponse.from_agent_response_generator(agent.run_stream("Hello")) assert result.text == "test streaming response another update" @@ -121,8 +118,8 @@ async def test_prepare_thread_does_not_mutate_agent_chat_options(chat_client: Ch tool = HostedCodeInterpreterTool() agent = ChatAgent(chat_client=chat_client, tools=[tool]) - assert agent.chat_options.tools is not None - base_tools = agent.chat_options.tools + assert agent.default_options.get("tools") is not None + base_tools = agent.default_options["tools"] thread = agent.get_new_thread() _, prepared_chat_options, _ = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] @@ -130,11 +127,11 @@ async def test_prepare_thread_does_not_mutate_agent_chat_options(chat_client: Ch input_messages=[ChatMessage(role=Role.USER, text="Test")], ) - assert prepared_chat_options.tools is not None - assert base_tools is not prepared_chat_options.tools + assert prepared_chat_options.get("tools") is not None + assert base_tools is not prepared_chat_options["tools"] - prepared_chat_options.tools.append(HostedCodeInterpreterTool()) # type: ignore[arg-type] - assert len(agent.chat_options.tools) == 1 + prepared_chat_options["tools"].append(HostedCodeInterpreterTool()) # type: ignore[arg-type] + assert len(agent.default_options["tools"]) == 1 async def test_chat_client_agent_update_thread_id(chat_client_base: ChatClientProtocol) -> None: @@ -255,7 +252,7 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], * async def test_chat_agent_context_providers_model_invoking(chat_client: ChatClientProtocol) -> None: """Test that context providers' invoking is called during agent run.""" mock_provider = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="Test context instructions")]) - agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client, context_provider=mock_provider) await agent.run("Hello") @@ -272,7 +269,7 @@ async def test_chat_agent_context_providers_thread_created(chat_client_base: Cha ) ] - agent = ChatAgent(chat_client=chat_client_base, context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client_base, context_provider=mock_provider) await agent.run("Hello") @@ -283,7 +280,7 @@ async def test_chat_agent_context_providers_thread_created(chat_client_base: Cha async def test_chat_agent_context_providers_messages_adding(chat_client: ChatClientProtocol) -> None: """Test that context providers' invoked is called during agent run.""" mock_provider = MockContextProvider() - agent = 
ChatAgent(chat_client=chat_client, context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client, context_provider=mock_provider) await agent.run("Hello") @@ -295,7 +292,7 @@ async def test_chat_agent_context_providers_messages_adding(chat_client: ChatCli async def test_chat_agent_context_instructions_in_messages(chat_client: ChatClientProtocol) -> None: """Test that AI context instructions are included in messages.""" mock_provider = MockContextProvider(messages=[ChatMessage(role="system", text="Context-specific instructions")]) - agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions", context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions", context_provider=mock_provider) # We need to test the _prepare_thread_and_messages method directly _, _, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] @@ -314,7 +311,7 @@ async def test_chat_agent_context_instructions_in_messages(chat_client: ChatClie async def test_chat_agent_no_context_instructions(chat_client: ChatClientProtocol) -> None: """Test behavior when AI context has no instructions.""" mock_provider = MockContextProvider() - agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions", context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions", context_provider=mock_provider) _, _, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] thread=None, input_messages=[ChatMessage(role=Role.USER, text="Hello")] @@ -329,10 +326,10 @@ async def test_chat_agent_no_context_instructions(chat_client: ChatClientProtoco async def test_chat_agent_run_stream_context_providers(chat_client: ChatClientProtocol) -> None: """Test that context providers work with run_stream method.""" mock_provider = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="Stream context instructions")]) - agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client, context_provider=mock_provider) # Collect all stream updates - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream("Hello"): updates.append(update) @@ -343,44 +340,6 @@ async def test_chat_agent_run_stream_context_providers(chat_client: ChatClientPr assert mock_provider.invoked_called -async def test_chat_agent_multiple_context_providers(chat_client: ChatClientProtocol) -> None: - """Test that multiple context providers work together.""" - provider1 = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="First provider instructions")]) - provider2 = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="Second provider instructions")]) - - agent = ChatAgent(chat_client=chat_client, context_providers=[provider1, provider2]) - - await agent.run("Hello") - - # Both providers should be called - assert provider1.invoking_called - assert not provider1.thread_created_called - assert provider1.invoked_called - - assert provider2.invoking_called - assert not provider2.thread_created_called - assert provider2.invoked_called - - -async def test_chat_agent_aggregate_context_provider_combines_instructions() -> None: - """Test that AggregateContextProvider combines instructions from multiple providers.""" - provider1 = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="First instruction")]) - provider2 = 
MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="Second instruction")]) - - aggregate = AggregateContextProvider() - aggregate.providers.append(provider1) - aggregate.providers.append(provider2) - - # Test invoking combines instructions - result = await aggregate.invoking([ChatMessage(role=Role.USER, text="Test")]) - - assert result.messages - assert isinstance(result.messages[0], ChatMessage) - assert isinstance(result.messages[1], ChatMessage) - assert result.messages[0].text == "First instruction" - assert result.messages[1].text == "Second instruction" - - async def test_chat_agent_context_providers_with_thread_service_id(chat_client_base: ChatClientProtocol) -> None: """Test context providers with service-managed thread.""" mock_provider = MockContextProvider() @@ -391,7 +350,7 @@ async def test_chat_agent_context_providers_with_thread_service_id(chat_client_b ) ] - agent = ChatAgent(chat_client=chat_client_base, context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client_base, context_provider=mock_provider) # Use existing service-managed thread thread = agent.get_new_thread(service_thread_id="existing-thread-id") @@ -481,9 +440,9 @@ async def test_chat_agent_as_tool_with_stream_callback(chat_client: ChatClientPr agent = ChatAgent(chat_client=chat_client, name="StreamingAgent") # Collect streaming updates - collected_updates: list[AgentRunResponseUpdate] = [] + collected_updates: list[AgentResponseUpdate] = [] - def stream_callback(update: AgentRunResponseUpdate) -> None: + def stream_callback(update: AgentResponseUpdate) -> None: collected_updates.append(update) tool = agent.as_tool(stream_callback=stream_callback) @@ -515,9 +474,9 @@ async def test_chat_agent_as_tool_with_async_stream_callback(chat_client: ChatCl agent = ChatAgent(chat_client=chat_client, name="AsyncStreamingAgent") # Collect streaming updates using an async callback - collected_updates: list[AgentRunResponseUpdate] = [] + collected_updates: list[AgentResponseUpdate] = [] - async def async_stream_callback(update: AgentRunResponseUpdate) -> None: + async def async_stream_callback(update: AgentResponseUpdate) -> None: collected_updates.append(update) tool = agent.as_tool(stream_callback=async_stream_callback) @@ -638,61 +597,68 @@ async def test_chat_agent_tool_choice_run_level_overrides_agent_level( chat_client_base: Any, ai_function_tool: Any ) -> None: """Verify that tool_choice passed to run() overrides agent-level tool_choice.""" - from agent_framework import ChatOptions, ToolMode - captured_options: list[ChatOptions] = [] + captured_options: list[dict[str, Any]] = [] # Store the original inner method original_inner = chat_client_base._inner_get_response async def capturing_inner( - *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + *, messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> ChatResponse: - captured_options.append(chat_options) - return await original_inner(messages=messages, chat_options=chat_options, **kwargs) + captured_options.append(options) + return await original_inner(messages=messages, options=options, **kwargs) chat_client_base._inner_get_response = capturing_inner # Create agent with agent-level tool_choice="auto" and a tool (tools required for tool_choice to be meaningful) - agent = ChatAgent(chat_client=chat_client_base, tool_choice="auto", tools=[ai_function_tool]) + agent = ChatAgent( + chat_client=chat_client_base, + tools=[ai_function_tool], + options={"tool_choice": "auto"}, + ) # 
Run with run-level tool_choice="required"
-    await agent.run("Hello", tool_choice="required")
+    await agent.run("Hello", options={"tool_choice": "required"})

     # Verify the client received tool_choice="required", not "auto"
     assert len(captured_options) >= 1
-    assert captured_options[0].tool_choice == "required"
-    assert captured_options[0].tool_choice == ToolMode.REQUIRED_ANY
+    assert captured_options[0]["tool_choice"] == "required"


 async def test_chat_agent_tool_choice_agent_level_used_when_run_level_not_specified(
     chat_client_base: Any, ai_function_tool: Any
 ) -> None:
     """Verify that agent-level tool_choice is used when run() doesn't specify one."""
-    from agent_framework import ChatOptions, ToolMode
-
-    captured_options: list[ChatOptions] = []
+    captured_options: list[dict[str, Any]] = []

     original_inner = chat_client_base._inner_get_response

     async def capturing_inner(
-        *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
+        *, messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
     ) -> ChatResponse:
-        captured_options.append(chat_options)
-        return await original_inner(messages=messages, chat_options=chat_options, **kwargs)
+        captured_options.append(options)
+        return await original_inner(messages=messages, options=options, **kwargs)

     chat_client_base._inner_get_response = capturing_inner

     # Create agent with agent-level tool_choice="required" and a tool
-    agent = ChatAgent(chat_client=chat_client_base, tool_choice="required", tools=[ai_function_tool])
+    agent = ChatAgent(
+        chat_client=chat_client_base,
+        tools=[ai_function_tool],
+        default_options={"tool_choice": "required"},
+    )

     # Run without specifying tool_choice
     await agent.run("Hello")

     # Verify the client received tool_choice="required" from agent-level
     assert len(captured_options) >= 1
-    assert captured_options[0].tool_choice == "required"
-    assert captured_options[0].tool_choice == ToolMode.REQUIRED_ANY
+    assert captured_options[0]["tool_choice"] == "required"


 async def test_chat_agent_tool_choice_none_at_run_preserves_agent_level(
@@ -706,19 +672,23 @@ async def test_chat_agent_tool_choice_none_at_run_preserves_agent_level(
     original_inner = chat_client_base._inner_get_response

     async def capturing_inner(
-        *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
+        *, messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
     ) -> ChatResponse:
-        captured_options.append(chat_options)
-        return await original_inner(messages=messages, chat_options=chat_options, **kwargs)
+        captured_options.append(options)
+        return await original_inner(messages=messages, options=options, **kwargs)

     chat_client_base._inner_get_response = capturing_inner

     # Create agent with agent-level tool_choice="auto" and a tool
-    agent = ChatAgent(chat_client=chat_client_base, tool_choice="auto", tools=[ai_function_tool])
+    agent = ChatAgent(
+        chat_client=chat_client_base,
+        tools=[ai_function_tool],
+        default_options={"tool_choice": "auto"},
+    )

     # Run with explicitly passing None (same as not specifying)
-    await agent.run("Hello", tool_choice=None)
+    await agent.run("Hello", options={"tool_choice": None})

     # Verify the client received tool_choice="auto" from agent-level
     assert len(captured_options) >= 1
-    assert captured_options[0].tool_choice == "auto"
+    assert captured_options[0]["tool_choice"] == "auto"
diff --git 
a/python/packages/core/tests/core/test_chat_agent_integration.py b/python/packages/core/tests/core/test_chat_agent_integration.py new file mode 100644 index 0000000000..574c02fd61 --- /dev/null +++ b/python/packages/core/tests/core/test_chat_agent_integration.py @@ -0,0 +1,433 @@ +# Copyright (c) Microsoft. All rights reserved. + +import json +import os +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from agent_framework import ( + AgentResponse, + AgentResponseUpdate, + AgentThread, + ChatAgent, + HostedCodeInterpreterTool, + HostedImageGenerationTool, + HostedMCPTool, + MCPStreamableHTTPTool, + ai_function, +) +from agent_framework.openai import OpenAIResponsesClient + +skip_if_openai_integration_tests_disabled = pytest.mark.skipif( + os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true" + or os.getenv("OPENAI_API_KEY", "") in ("", "test-dummy-key"), + reason="No real OPENAI_API_KEY provided; skipping integration tests." + if os.getenv("RUN_INTEGRATION_TESTS", "false").lower() == "true" + else "Integration tests are disabled.", +) + + +@ai_function +async def get_weather(location: Annotated[str, "The location as a city name"]) -> str: + """Get the current weather in a given location.""" + # Implementation of the tool to get weather + return f"The current weather in {location} is sunny." + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_agent_basic_run_streaming(): + """Test OpenAI Responses Client agent basic streaming functionality with OpenAIResponsesClient.""" + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + ) as agent: + # Test streaming run + full_text = "" + async for chunk in agent.run_stream("Please respond with exactly: 'This is a streaming response test.'"): + assert isinstance(chunk, AgentResponseUpdate) + if chunk.text: + full_text += chunk.text + + assert len(full_text) > 0 + assert "streaming response test" in full_text.lower() + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_agent_thread_persistence(): + """Test OpenAI Responses Client agent thread persistence across runs with OpenAIResponsesClient.""" + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="You are a helpful assistant with good memory.", + ) as agent: + # Create a new thread that will be reused + thread = agent.get_new_thread() + + # First interaction + first_response = await agent.run("My favorite programming language is Python. Remember this.", thread=thread) + + assert isinstance(first_response, AgentResponse) + assert first_response.text is not None + + # Second interaction - test memory + second_response = await agent.run("What is my favorite programming language?", thread=thread) + + assert isinstance(second_response, AgentResponse) + assert second_response.text is not None + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_agent_thread_storage_with_store_true(): + """Test OpenAI Responses Client agent with store=True to verify service_thread_id is returned.""" + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="You are a helpful assistant.", + ) as agent: + # Create a new thread + thread = AgentThread() + + # Initially, service_thread_id should be None + assert thread.service_thread_id is None + + # Run with store=True to store messages on OpenAI side + response = await agent.run( + "Hello! 
Please remember that my name is Alex.", + thread=thread, + options={"store": True}, + ) + + # Validate response + assert isinstance(response, AgentResponse) + assert response.text is not None + assert len(response.text) > 0 + + # After store=True, service_thread_id should be populated + assert thread.service_thread_id is not None + assert isinstance(thread.service_thread_id, str) + assert len(thread.service_thread_id) > 0 + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_agent_existing_thread(): + """Test OpenAI Responses Client agent with existing thread to continue conversations across agent instances.""" + # First conversation - capture the thread + preserved_thread = None + + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="You are a helpful assistant with good memory.", + ) as first_agent: + # Start a conversation and capture the thread + thread = first_agent.get_new_thread() + first_response = await first_agent.run("My hobby is photography. Remember this.", thread=thread) + + assert isinstance(first_response, AgentResponse) + assert first_response.text is not None + + # Preserve the thread for reuse + preserved_thread = thread + + # Second conversation - reuse the thread in a new agent instance + if preserved_thread: + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="You are a helpful assistant with good memory.", + ) as second_agent: + # Reuse the preserved thread + second_response = await second_agent.run("What is my hobby?", thread=preserved_thread) + + assert isinstance(second_response, AgentResponse) + assert second_response.text is not None + assert "photography" in second_response.text.lower() + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_agent_hosted_code_interpreter_tool(): + """Test OpenAI Responses Client agent with HostedCodeInterpreterTool through OpenAIResponsesClient.""" + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="You are a helpful assistant that can execute Python code.", + tools=[HostedCodeInterpreterTool()], + ) as agent: + # Test code interpreter functionality + response = await agent.run("Calculate the sum of numbers from 1 to 10 using Python code.") + + assert isinstance(response, AgentResponse) + assert response.text is not None + assert len(response.text) > 0 + # Should contain calculation result (sum of 1-10 = 55) or code execution content + contains_relevant_content = any( + term in response.text.lower() for term in ["55", "sum", "code", "python", "calculate", "10"] + ) + assert contains_relevant_content or len(response.text.strip()) > 10 + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_agent_image_generation_tool(): + """Test OpenAI Responses Client agent with raw image_generation tool through OpenAIResponsesClient.""" + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="You are a helpful assistant that can generate images.", + tools=HostedImageGenerationTool(options={"image_size": "1024x1024", "media_type": "png"}), + ) as agent: + # Test image generation functionality + response = await agent.run("Generate an image of a cute red panda sitting on a tree branch in a forest.") + + assert isinstance(response, AgentResponse) + assert response.messages + + # Verify we got image content - look for ImageGenerationToolResultContent + image_content_found = False + for message 
in response.messages: + for content in message.contents: + if content.type == "image_generation_tool_result" and content.outputs: + image_content_found = True + break + if image_content_found: + break + + # The test passes if we got image content + assert image_content_found, "Expected to find image content in response" + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_agent_level_tool_persistence(): + """Test that agent-level tools persist across multiple runs with OpenAI Responses Client.""" + + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="You are a helpful assistant that uses available tools.", + tools=[get_weather], # Agent-level tool + ) as agent: + # First run - agent-level tool should be available + first_response = await agent.run("What's the weather like in Chicago?") + + assert isinstance(first_response, AgentResponse) + assert first_response.text is not None + # Should use the agent-level weather tool + assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) + + # Second run - agent-level tool should still be available (persistence test) + second_response = await agent.run("What's the weather in Miami?") + + assert isinstance(second_response, AgentResponse) + assert second_response.text is not None + # Should use the agent-level weather tool again + assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"]) + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_run_level_tool_isolation(): + """Test that run-level tools are isolated to specific runs and don't persist with OpenAI Responses Client.""" + # Counter to track how many times the weather tool is called + call_count = 0 + + @ai_function + async def get_weather_with_counter( + location: Annotated[str, "The location as a city name"], + ) -> str: + """Get the current weather in a given location.""" + nonlocal call_count + call_count += 1 + return f"The weather in {location} is sunny and 72°F." 
+ + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="You are a helpful assistant.", + ) as agent: + # First run - use run-level tool + first_response = await agent.run( + "What's the weather like in Chicago?", + tools=[get_weather_with_counter], # Run-level tool + ) + + assert isinstance(first_response, AgentResponse) + assert first_response.text is not None + # Should use the run-level weather tool (call count should be 1) + assert call_count == 1 + assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) + + # Second run - run-level tool should NOT persist (key isolation test) + second_response = await agent.run("What's the weather like in Miami?") + + assert isinstance(second_response, AgentResponse) + assert second_response.text is not None + # Should NOT use the weather tool since it was only run-level in previous call + # Call count should still be 1 (no additional calls) + assert call_count == 1 + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_agent_chat_options_agent_level() -> None: + """Integration test for comprehensive ChatOptions parameter coverage with OpenAI Response Agent.""" + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="You are a helpful assistant.", + tools=[get_weather], + default_options={ + "max_tokens": 100, + "temperature": 0.7, + "top_p": 0.9, + "seed": 123, + "user": "comprehensive-test-user", + "tool_choice": "auto", + }, + ) as agent: + response = await agent.run( + "Provide a brief, helpful response.", + ) + + assert isinstance(response, AgentResponse) + assert response.text is not None + assert len(response.text) > 0 + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_agent_hosted_mcp_tool() -> None: + """Integration test for HostedMCPTool with OpenAI Response Agent using Microsoft Learn MCP.""" + + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="You are a helpful assistant that can help with microsoft documentation questions.", + tools=HostedMCPTool( + name="Microsoft Learn MCP", + url="https://learn.microsoft.com/api/mcp", + description="A Microsoft Learn MCP server for documentation questions", + approval_mode="never_require", + ), + ) as agent: + response = await agent.run( + "How to create an Azure storage account using az cli?", + # this needs to be high enough to handle the full MCP tool response. 
+ options={"max_tokens": 5000}, + ) + + assert isinstance(response, AgentResponse) + assert response.text + # Should contain Azure-related content since it's asking about Azure CLI + assert any(term in response.text.lower() for term in ["azure", "storage", "account", "cli"]) + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_agent_local_mcp_tool() -> None: + """Integration test for MCPStreamableHTTPTool with OpenAI Response Agent using Microsoft Learn MCP.""" + + mcp_tool = MCPStreamableHTTPTool( + name="Microsoft Learn MCP", + url="https://learn.microsoft.com/api/mcp", + ) + + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="You are a helpful assistant that can help with microsoft documentation questions.", + tools=[mcp_tool], + ) as agent: + response = await agent.run( + "How to create an Azure storage account using az cli?", + options={"max_tokens": 200}, + ) + + assert isinstance(response, AgentResponse) + assert response.text is not None + assert len(response.text) > 0 + # Should contain Azure-related content since it's asking about Azure CLI + assert any(term in response.text.lower() for term in ["azure", "storage", "account", "cli"]) + + +class ReleaseBrief(BaseModel): + """Structured output model for release brief testing.""" + + title: str + summary: str + highlights: list[str] + model_config = {"extra": "forbid"} + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_agent_with_response_format_pydantic() -> None: + """Integration test for response_format with Pydantic model using OpenAI Responses Client.""" + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="You are a helpful assistant that returns structured JSON responses.", + ) as agent: + response = await agent.run( + "Summarize the following release notes into a ReleaseBrief:\n\n" + "Version 2.0 Release Notes:\n" + "- Added new streaming API for real-time responses\n" + "- Improved error handling with detailed messages\n" + "- Performance boost of 50% in batch processing\n" + "- Fixed memory leak in connection pooling", + options={ + "response_format": ReleaseBrief, + }, + ) + + # Validate response + assert isinstance(response, AgentResponse) + assert response.value is not None + assert isinstance(response.value, ReleaseBrief) + + # Validate structured output fields + brief = response.value + assert len(brief.title) > 0 + assert len(brief.summary) > 0 + assert len(brief.highlights) > 0 + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_openai_responses_client_agent_with_runtime_json_schema() -> None: + """Integration test for response_format with runtime JSON schema using OpenAI Responses Client.""" + runtime_schema = { + "title": "WeatherDigest", + "type": "object", + "properties": { + "location": {"type": "string"}, + "conditions": {"type": "string"}, + "temperature_c": {"type": "number"}, + "advisory": {"type": "string"}, + }, + "required": ["location", "conditions", "temperature_c", "advisory"], + "additionalProperties": False, + } + + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + instructions="Return only JSON that matches the provided schema. 
Do not add commentary.", + ) as agent: + response = await agent.run( + "Give a brief weather digest for Seattle.", + options={ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": runtime_schema["title"], + "strict": True, + "schema": runtime_schema, + }, + }, + }, + ) + + # Validate response + assert isinstance(response, AgentResponse) + assert response.text is not None + + # Parse JSON and validate structure + parsed = json.loads(response.text) + assert "location" in parsed + assert "conditions" in parsed + assert "temperature_c" in parsed + assert "advisory" in parsed diff --git a/python/packages/core/tests/core/test_clients.py b/python/packages/core/tests/core/test_clients.py index 423a7e42b5..67ecd54a8d 100644 --- a/python/packages/core/tests/core/test_clients.py +++ b/python/packages/core/tests/core/test_clients.py @@ -7,7 +7,6 @@ BaseChatClient, ChatClientProtocol, ChatMessage, - ChatOptions, Role, ) @@ -50,12 +49,22 @@ async def test_chat_client_instructions_handling(chat_client_base: ChatClientPro chat_client_base, "_inner_get_response", ) as mock_inner_get_response: - await chat_client_base.get_response("hello", chat_options=ChatOptions(instructions=instructions)) + await chat_client_base.get_response("hello", options={"instructions": instructions}) mock_inner_get_response.assert_called_once() _, kwargs = mock_inner_get_response.call_args messages = kwargs.get("messages", []) - assert len(messages) == 2 - assert messages[0].role == Role.SYSTEM - assert messages[0].text == instructions - assert messages[1].role == Role.USER - assert messages[1].text == "hello" + assert len(messages) == 1 + assert messages[0].role == Role.USER + assert messages[0].text == "hello" + + from agent_framework._types import prepend_instructions_to_messages + + appended_messages = prepend_instructions_to_messages( + [ChatMessage(role=Role.USER, text="hello")], + instructions, + ) + assert len(appended_messages) == 2 + assert appended_messages[0].role == Role.SYSTEM + assert appended_messages[0].text == instructions + assert appended_messages[1].role == Role.USER + assert appended_messages[1].text == "hello" diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index bc96ddcc35..3aa8586a69 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved.
from collections.abc import Awaitable, Callable +from typing import Any import pytest @@ -8,7 +9,6 @@ ChatAgent, ChatClientProtocol, ChatMessage, - ChatOptions, ChatResponse, ChatResponseUpdate, FunctionApprovalRequestContent, @@ -39,7 +39,7 @@ def ai_func(arg1: str) -> str: ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), ] - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[ai_func]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]}) assert exec_counter == 1 assert len(response.messages) == 3 assert response.messages[0].role == Role.ASSISTANT @@ -79,7 +79,7 @@ def ai_func(arg1: str) -> str: ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), ] - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[ai_func]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]}) assert exec_counter == 2 assert len(response.messages) == 5 assert response.messages[0].role == Role.ASSISTANT @@ -121,7 +121,9 @@ def ai_func(arg1: str) -> str: ], ] updates = [] - async for update in chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[ai_func]): + async for update in chat_client_base.get_streaming_response( + "hello", options={"tool_choice": "auto", "tools": [ai_func]} + ): updates.append(update) assert len(updates) == 4 # two updates with the function call, the function result and the final text assert updates[0].contents[0].call_id == "1" @@ -371,18 +373,18 @@ def func_with_approval(arg1: str) -> str: ] # Execute the test - chat_options = ChatOptions(tool_choice="auto", tools=tools) + options: dict[str, Any] = {"tool_choice": "auto", "tools": tools} if thread_type == "service": - # For service threads, we need to pass conversation_id via ChatOptions - chat_options.store = True - chat_options.conversation_id = conversation_id + # For service threads, we need to pass conversation_id via options + options["store"] = True + options["conversation_id"] = conversation_id if not streaming: - response = await chat_client_base.get_response("hello", chat_options=chat_options) + response = await chat_client_base.get_response("hello", options=options) messages = response.messages else: updates = [] - async for update in chat_client_base.get_streaming_response("hello", chat_options=chat_options): + async for update in chat_client_base.get_streaming_response("hello", options=options): updates.append(update) messages = updates @@ -492,7 +494,9 @@ def func_rejected(arg1: str) -> str: ] # Get the response with approval requests - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[func_approved, func_rejected]) + response = await chat_client_base.get_response( + "hello", options={"tool_choice": "auto", "tools": [func_approved, func_rejected]} + ) # Approval requests are now added to the assistant message, not a separate message assert len(response.messages) == 1 # Assistant message should have: 2 FunctionCallContent + 2 FunctionApprovalRequestContent @@ -519,7 +523,9 @@ def func_rejected(arg1: str) -> str: all_messages = response.messages + [ChatMessage(role="user", contents=[approved_response, rejected_response])] # Call get_response which will process the approvals - await chat_client_base.get_response(all_messages, tool_choice="auto", tools=[func_approved, func_rejected]) + await chat_client_base.get_response( + all_messages, options={"tool_choice": "auto", 
"tools": [func_approved, func_rejected]} + ) # Verify the approval/rejection was processed correctly # Find the results in the input messages (modified in-place) @@ -574,7 +580,9 @@ def func_with_approval(arg1: str) -> str: ), ] - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[func_with_approval]) + response = await chat_client_base.get_response( + "hello", options={"tool_choice": "auto", "tools": [func_with_approval]} + ) # Should have one assistant message containing both the call and approval request assert len(response.messages) == 1 @@ -610,7 +618,9 @@ def func_with_approval(arg1: str) -> str: ] # Get approval request - response1 = await chat_client_base.get_response("hello", tool_choice="auto", tools=[func_with_approval]) + response1 = await chat_client_base.get_response( + "hello", options={"tool_choice": "auto", "tools": [func_with_approval]} + ) # Store messages (like a thread would) persisted_messages = [ @@ -628,7 +638,9 @@ def func_with_approval(arg1: str) -> str: persisted_messages.append(ChatMessage(role="user", contents=[approval_response])) # Continue with all persisted messages - response2 = await chat_client_base.get_response(persisted_messages, tool_choice="auto", tools=[func_with_approval]) + response2 = await chat_client_base.get_response( + persisted_messages, options={"tool_choice": "auto", "tools": [func_with_approval]} + ) # Should execute successfully assert response2 is not None @@ -656,7 +668,9 @@ def func_with_approval(arg1: str) -> str: ChatResponse(messages=ChatMessage(role="assistant", text="done")), ] - response1 = await chat_client_base.get_response("hello", tool_choice="auto", tools=[func_with_approval]) + response1 = await chat_client_base.get_response( + "hello", options={"tool_choice": "auto", "tools": [func_with_approval]} + ) approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] approval_response = FunctionApprovalResponseContent( @@ -666,7 +680,7 @@ def func_with_approval(arg1: str) -> str: ) all_messages = response1.messages + [ChatMessage(role="user", contents=[approval_response])] - await chat_client_base.get_response(all_messages, tool_choice="auto", tools=[func_with_approval]) + await chat_client_base.get_response(all_messages, options={"tool_choice": "auto", "tools": [func_with_approval]}) # Count function calls with the same call_id function_call_count = sum( @@ -699,7 +713,9 @@ def func_with_approval(arg1: str) -> str: ChatResponse(messages=ChatMessage(role="assistant", text="done")), ] - response1 = await chat_client_base.get_response("hello", tool_choice="auto", tools=[func_with_approval]) + response1 = await chat_client_base.get_response( + "hello", options={"tool_choice": "auto", "tools": [func_with_approval]} + ) approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] rejection_response = FunctionApprovalResponseContent( @@ -709,7 +725,7 @@ def func_with_approval(arg1: str) -> str: ) all_messages = response1.messages + [ChatMessage(role="user", contents=[rejection_response])] - await chat_client_base.get_response(all_messages, tool_choice="auto", tools=[func_with_approval]) + await chat_client_base.get_response(all_messages, options={"tool_choice": "auto", "tools": [func_with_approval]}) # Find the rejection result rejection_result = next( @@ -753,7 +769,7 @@ def ai_func(arg1: str) -> str: # Set max_iterations to 1 in additional_properties 
chat_client_base.function_invocation_configuration.max_iterations = 1 - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[ai_func]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]}) # With max_iterations=1, we should: # 1. Execute first function call (exec_counter=1) @@ -780,7 +796,7 @@ def ai_func(arg1: str) -> str: # Disable function invocation chat_client_base.function_invocation_configuration.enabled = False - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[ai_func]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]}) # Function should not be executed - when enabled=False, the loop doesn't run assert exec_counter == 0 @@ -827,7 +843,7 @@ def error_func(arg1: str) -> str: # Set max_consecutive_errors to 2 chat_client_base.function_invocation_configuration.max_consecutive_errors_per_request = 2 - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[error_func]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]}) # Should stop after 2 consecutive errors and force a non-tool response error_results = [ @@ -870,7 +886,7 @@ def known_func(arg1: str) -> str: # Set terminate_on_unknown_calls to False (default) chat_client_base.function_invocation_configuration.terminate_on_unknown_calls = False - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[known_func]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]}) # Should have a result message indicating the tool wasn't found assert len(response.messages) == 3 @@ -904,7 +920,7 @@ def known_func(arg1: str) -> str: # Should raise an exception when encountering an unknown function with pytest.raises(KeyError, match='Error: Requested function "unknown_function" not found'): - await chat_client_base.get_response("hello", tool_choice="auto", tools=[known_func]) + await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]}) assert exec_counter == 0 @@ -940,7 +956,7 @@ def hidden_func(arg1: str) -> str: chat_client_base.function_invocation_configuration.additional_tools = [hidden_func] # Only pass visible_func in the tools parameter - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[visible_func]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [visible_func]}) # Additional tools are treated as declaration_only, so not executed # The function call should be in the messages but not executed @@ -976,7 +992,7 @@ def error_func(arg1: str) -> str: # Set include_detailed_errors to False (default) chat_client_base.function_invocation_configuration.include_detailed_errors = False - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[error_func]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]}) # Should have a generic error message error_result = next( @@ -1008,7 +1024,7 @@ def error_func(arg1: str) -> str: # Set include_detailed_errors to True chat_client_base.function_invocation_configuration.include_detailed_errors = True - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[error_func]) + response = await chat_client_base.get_response("hello", 
options={"tool_choice": "auto", "tools": [error_func]}) # Should have detailed error message error_result = next( @@ -1076,7 +1092,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str # Set include_detailed_errors to True chat_client_base.function_invocation_configuration.include_detailed_errors = True - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[typed_func]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]}) # Should have detailed validation error error_result = next( @@ -1108,7 +1124,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str # Set include_detailed_errors to False (default) chat_client_base.function_invocation_configuration.include_detailed_errors = False - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[typed_func]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]}) # Should have generic validation error error_result = next( @@ -1175,7 +1191,7 @@ def test_func(arg1: str) -> str: ] # Get approval request - response1 = await chat_client_base.get_response("hello", tool_choice="auto", tools=[test_func]) + response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]}) approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] @@ -1190,7 +1206,7 @@ def test_func(arg1: str) -> str: all_messages = response1.messages + [ChatMessage(role="user", contents=[rejection_response])] # This should handle the rejection gracefully (not raise ToolException to user) - await chat_client_base.get_response(all_messages, tool_choice="auto", tools=[test_func]) + await chat_client_base.get_response(all_messages, options={"tool_choice": "auto", "tools": [test_func]}) # Should have a rejection result rejection_result = next( @@ -1235,7 +1251,7 @@ def error_func(arg1: str) -> str: chat_client_base.function_invocation_configuration.include_detailed_errors = False # Get approval request - response1 = await chat_client_base.get_response("hello", tool_choice="auto", tools=[error_func]) + response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]}) approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] @@ -1249,7 +1265,7 @@ def error_func(arg1: str) -> str: all_messages = response1.messages + [ChatMessage(role="user", contents=[approval_response])] # Execute the approved function (which will error) - await chat_client_base.get_response(all_messages, tool_choice="auto", tools=[error_func]) + await chat_client_base.get_response(all_messages, options={"tool_choice": "auto", "tools": [error_func]}) # Should have executed the function assert exec_counter == 1 @@ -1299,7 +1315,7 @@ def error_func(arg1: str) -> str: chat_client_base.function_invocation_configuration.include_detailed_errors = True # Get approval request - response1 = await chat_client_base.get_response("hello", tool_choice="auto", tools=[error_func]) + response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]}) approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] @@ -1313,7 +1329,7 @@ def error_func(arg1: str) -> str: all_messages = response1.messages + [ChatMessage(role="user", contents=[approval_response])] # Execute the 
approved function (which will error) - await chat_client_base.get_response(all_messages, tool_choice="auto", tools=[error_func]) + await chat_client_base.get_response(all_messages, options={"tool_choice": "auto", "tools": [error_func]}) # Should have executed the function assert exec_counter == 1 @@ -1361,7 +1377,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str chat_client_base.function_invocation_configuration.include_detailed_errors = True # Get approval request - response1 = await chat_client_base.get_response("hello", tool_choice="auto", tools=[typed_func]) + response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]}) approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] @@ -1375,7 +1391,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str all_messages = response1.messages + [ChatMessage(role="user", contents=[approval_response])] # Execute the approved function (which will fail validation) - await chat_client_base.get_response(all_messages, tool_choice="auto", tools=[typed_func]) + await chat_client_base.get_response(all_messages, options={"tool_choice": "auto", "tools": [typed_func]}) # Should NOT have executed the function (validation failed before execution) assert exec_counter == 0 @@ -1418,7 +1434,7 @@ def success_func(arg1: str) -> str: ] # Get approval request - response1 = await chat_client_base.get_response("hello", tool_choice="auto", tools=[success_func]) + response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [success_func]}) approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] @@ -1432,7 +1448,7 @@ def success_func(arg1: str) -> str: all_messages = response1.messages + [ChatMessage(role="user", contents=[approval_response])] # Execute the approved function - await chat_client_base.get_response(all_messages, tool_choice="auto", tools=[success_func]) + await chat_client_base.get_response(all_messages, options={"tool_choice": "auto", "tools": [success_func]}) # Should have executed successfully assert exec_counter == 1 @@ -1476,7 +1492,9 @@ async def test_declaration_only_tool(chat_client_base: ChatClientProtocol): ChatResponse(messages=ChatMessage(role="assistant", text="done")), ] - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[declaration_func]) + response = await chat_client_base.get_response( + "hello", options={"tool_choice": "auto", "tools": [declaration_func]} + ) # Should have the function call in messages but not a result function_calls = [ @@ -1530,7 +1548,7 @@ async def func2(arg1: str) -> str: ChatResponse(messages=ChatMessage(role="assistant", text="done")), ] - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[func1, func2]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [func1, func2]}) # Both functions should have been executed assert "func1_start" in exec_order @@ -1566,7 +1584,7 @@ def plain_function(arg1: str) -> str: ] # Pass plain function (will be auto-converted) - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[plain_function]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [plain_function]}) # Function should be executed assert exec_counter == 1 @@ -1598,7 +1616,7 @@ def test_func(arg1: str) -> str: ), ] - response 
= await chat_client_base.get_response("hello", tool_choice="auto", tools=[test_func]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]}) # Should have executed the function results = [ @@ -1625,7 +1643,7 @@ def test_func(arg1: str) -> str: ChatResponse(messages=ChatMessage(role="assistant", text="done")), ] - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[test_func]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]}) # Should have messages with both function call and function result assert len(response.messages) >= 2 @@ -1667,7 +1685,7 @@ def sometimes_fails(arg1: str) -> str: ChatResponse(messages=ChatMessage(role="assistant", text="done")), ] - response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[sometimes_fails]) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [sometimes_fails]}) # Should have both an error and a success error_results = [ @@ -1714,7 +1732,7 @@ def func_with_approval(arg1: str) -> str: # Get the streaming response with approval request updates = [] async for update in chat_client_base.get_streaming_response( - "hello", tool_choice="auto", tools=[func_with_approval] + "hello", options={"tool_choice": "auto", "tools": [func_with_approval]} ): updates.append(update) @@ -1770,7 +1788,9 @@ def ai_func(arg1: str) -> str: chat_client_base.function_invocation_configuration.max_iterations = 1 updates = [] - async for update in chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[ai_func]): + async for update in chat_client_base.get_streaming_response( + "hello", options={"tool_choice": "auto", "tools": [ai_func]} + ): updates.append(update) # With max_iterations=1, we should only execute first function @@ -1798,7 +1818,9 @@ def ai_func(arg1: str) -> str: chat_client_base.function_invocation_configuration.enabled = False updates = [] - async for update in chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[ai_func]): + async for update in chat_client_base.get_streaming_response( + "hello", options={"tool_choice": "auto", "tools": [ai_func]} + ): updates.append(update) # Function should not be executed - when enabled=False, the loop doesn't run @@ -1841,7 +1863,9 @@ def error_func(arg1: str) -> str: chat_client_base.function_invocation_configuration.max_consecutive_errors_per_request = 2 updates = [] - async for update in chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[error_func]): + async for update in chat_client_base.get_streaming_response( + "hello", options={"tool_choice": "auto", "tools": [error_func]} + ): updates.append(update) # Should stop after 2 consecutive errors @@ -1887,7 +1911,9 @@ def known_func(arg1: str) -> str: chat_client_base.function_invocation_configuration.terminate_on_unknown_calls = False updates = [] - async for update in chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[known_func]): + async for update in chat_client_base.get_streaming_response( + "hello", options={"tool_choice": "auto", "tools": [known_func]} + ): updates.append(update) # Should have a result message indicating the tool wasn't found @@ -1926,7 +1952,9 @@ def known_func(arg1: str) -> str: # Should raise an exception when encountering an unknown function with pytest.raises(KeyError, match='Error: Requested function "unknown_function" not found'): - 
async for _ in chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[known_func]): + async for _ in chat_client_base.get_streaming_response( + "hello", options={"tool_choice": "auto", "tools": [known_func]} + ): pass assert exec_counter == 0 @@ -1953,7 +1981,9 @@ def error_func(arg1: str) -> str: chat_client_base.function_invocation_configuration.include_detailed_errors = True updates = [] - async for update in chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[error_func]): + async for update in chat_client_base.get_streaming_response( + "hello", options={"tool_choice": "auto", "tools": [error_func]} + ): updates.append(update) # Should have detailed error message @@ -1989,7 +2019,9 @@ def error_func(arg1: str) -> str: chat_client_base.function_invocation_configuration.include_detailed_errors = False updates = [] - async for update in chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[error_func]): + async for update in chat_client_base.get_streaming_response( + "hello", options={"tool_choice": "auto", "tools": [error_func]} + ): updates.append(update) # Should have a generic error message @@ -2023,7 +2055,9 @@ def typed_func(arg1: int) -> str: # Expects int, not str chat_client_base.function_invocation_configuration.include_detailed_errors = True updates = [] - async for update in chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[typed_func]): + async for update in chat_client_base.get_streaming_response( + "hello", options={"tool_choice": "auto", "tools": [typed_func]} + ): updates.append(update) # Should have detailed validation error @@ -2057,7 +2091,9 @@ def typed_func(arg1: int) -> str: # Expects int, not str chat_client_base.function_invocation_configuration.include_detailed_errors = False updates = [] - async for update in chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[typed_func]): + async for update in chat_client_base.get_streaming_response( + "hello", options={"tool_choice": "auto", "tools": [typed_func]} + ): updates.append(update) # Should have generic validation error @@ -2105,7 +2141,9 @@ async def func2(arg1: str) -> str: ] updates = [] - async for update in chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[func1, func2]): + async for update in chat_client_base.get_streaming_response( + "hello", options={"tool_choice": "auto", "tools": [func1, func2]} + ): updates.append(update) # Both functions should have been executed @@ -2144,7 +2182,7 @@ def func_with_approval(arg1: str) -> str: updates = [] async for update in chat_client_base.get_streaming_response( - "hello", tool_choice="auto", tools=[func_with_approval] + "hello", options={"tool_choice": "auto", "tools": [func_with_approval]} ): updates.append(update) @@ -2189,7 +2227,9 @@ def sometimes_fails(arg1: str) -> str: ] updates = [] - async for update in chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[sometimes_fails]): + async for update in chat_client_base.get_streaming_response( + "hello", options={"tool_choice": "auto", "tools": [sometimes_fails]} + ): updates.append(update) # Should have both an error and a success @@ -2246,8 +2286,7 @@ def ai_func(arg1: str) -> str: response = await chat_client_base.get_response( "hello", - tool_choice="auto", - tools=[ai_func], + options={"tool_choice": "auto", "tools": [ai_func]}, middleware=[TerminateLoopMiddleware()], ) @@ -2314,8 +2353,7 @@ def terminating_func(arg1: str) -> str: response = await 
chat_client_base.get_response( "hello", - tool_choice="auto", - tools=[normal_func, terminating_func], + options={"tool_choice": "auto", "tools": [normal_func, terminating_func]}, middleware=[SelectiveTerminateMiddleware()], ) @@ -2366,8 +2404,7 @@ def ai_func(arg1: str) -> str: updates = [] async for update in chat_client_base.get_streaming_response( "hello", - tool_choice="auto", - tools=[ai_func], + options={"tool_choice": "auto", "tools": [ai_func]}, middleware=[TerminateLoopMiddleware()], ): updates.append(update) diff --git a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py new file mode 100644 index 0000000000..fc6acb435d --- /dev/null +++ b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for kwargs propagation from get_response() to @ai_function tools.""" + +from typing import Any + +from agent_framework import ( + ChatMessage, + ChatResponse, + ChatResponseUpdate, + FunctionCallContent, + TextContent, + ai_function, +) +from agent_framework._tools import _handle_function_calls_response, _handle_function_calls_streaming_response + + +class TestKwargsPropagationToAIFunction: + """Test cases for kwargs flowing from get_response() to @ai_function tools.""" + + async def test_kwargs_propagate_to_ai_function_with_kwargs(self) -> None: + """Test that kwargs passed to get_response() are available in @ai_function **kwargs.""" + captured_kwargs: dict[str, Any] = {} + + @ai_function + def capture_kwargs_tool(x: int, **kwargs: Any) -> str: + """A tool that captures kwargs for testing.""" + captured_kwargs.update(kwargs) + return f"result: x={x}" + + # Create a mock client + mock_client = type("MockClient", (), {})() + + call_count = [0] + + async def mock_get_response(self, messages, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + # First call: return a function call + return ChatResponse( + messages=[ + ChatMessage( + role="assistant", + contents=[ + FunctionCallContent(call_id="call_1", name="capture_kwargs_tool", arguments='{"x": 42}') + ], + ) + ] + ) + # Second call: return final response + return ChatResponse(messages=[ChatMessage(role="assistant", text="Done!")]) + + # Wrap the function with function invocation decorator + wrapped = _handle_function_calls_response(mock_get_response) + + # Call with custom kwargs that should propagate to the tool + # Note: tools are passed in options dict, custom kwargs are passed separately + result = await wrapped( + mock_client, + messages=[], + options={"tools": [capture_kwargs_tool]}, + user_id="user-123", + session_token="secret-token", + custom_data={"key": "value"}, + ) + + # Verify the tool was called and received the kwargs + assert "user_id" in captured_kwargs, f"Expected 'user_id' in captured kwargs: {captured_kwargs}" + assert captured_kwargs["user_id"] == "user-123" + assert "session_token" in captured_kwargs + assert captured_kwargs["session_token"] == "secret-token" + assert "custom_data" in captured_kwargs + assert captured_kwargs["custom_data"] == {"key": "value"} + # Verify result + assert result.messages[-1].text == "Done!" 
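# A minimal, self-contained sketch (not part of the test suite) of the pattern
# exercised above: a tool opts into request-scoped context by declaring
# **kwargs, and any keyword passed to get_response() that is not a declared
# parameter is expected to arrive there. The names lookup_orders and user_id
# are hypothetical illustrations, not framework APIs.
from typing import Annotated, Any

from agent_framework import ai_function


@ai_function
def lookup_orders(customer: Annotated[str, "The customer name"], **kwargs: Any) -> str:
    """Summarize a customer's orders, tagging the caller-supplied user_id."""
    user_id = kwargs.get("user_id", "anonymous")  # hypothetical request-scoped value
    return f"orders for {customer} (requested by {user_id})"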
+ + async def test_kwargs_not_forwarded_to_ai_function_without_kwargs(self) -> None: + """Test that kwargs are NOT forwarded to @ai_function that doesn't accept **kwargs.""" + + @ai_function + def simple_tool(x: int) -> str: + """A simple tool without **kwargs.""" + # This should not receive any extra kwargs + return f"result: x={x}" + + mock_client = type("MockClient", (), {})() + + call_count = [0] + + async def mock_get_response(self, messages, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return ChatResponse( + messages=[ + ChatMessage( + role="assistant", + contents=[FunctionCallContent(call_id="call_1", name="simple_tool", arguments='{"x": 99}')], + ) + ] + ) + return ChatResponse(messages=[ChatMessage(role="assistant", text="Completed!")]) + + wrapped = _handle_function_calls_response(mock_get_response) + + # Call with kwargs - the tool should work but not receive them + result = await wrapped( + mock_client, + messages=[], + options={"tools": [simple_tool]}, + user_id="user-123", # This kwarg should be ignored by the tool + ) + + # Verify the tool was called successfully (no error from extra kwargs) + assert result.messages[-1].text == "Completed!" + + async def test_kwargs_isolated_between_function_calls(self) -> None: + """Test that kwargs don't leak between different function call invocations.""" + invocation_kwargs: list[dict[str, Any]] = [] + + @ai_function + def tracking_tool(name: str, **kwargs: Any) -> str: + """A tool that tracks kwargs from each invocation.""" + invocation_kwargs.append(dict(kwargs)) + return f"called with {name}" + + mock_client = type("MockClient", (), {})() + + call_count = [0] + + async def mock_get_response(self, messages, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + # Two function calls in one response + return ChatResponse( + messages=[ + ChatMessage( + role="assistant", + contents=[ + FunctionCallContent( + call_id="call_1", name="tracking_tool", arguments='{"name": "first"}' + ), + FunctionCallContent( + call_id="call_2", name="tracking_tool", arguments='{"name": "second"}' + ), + ], + ) + ] + ) + return ChatResponse(messages=[ChatMessage(role="assistant", text="All done!")]) + + wrapped = _handle_function_calls_response(mock_get_response) + + # Call with kwargs + result = await wrapped( + mock_client, + messages=[], + options={"tools": [tracking_tool]}, + request_id="req-001", + trace_context={"trace_id": "abc"}, + ) + + # Both invocations should have received the same kwargs + assert len(invocation_kwargs) == 2 + for kwargs in invocation_kwargs: + assert kwargs.get("request_id") == "req-001" + assert kwargs.get("trace_context") == {"trace_id": "abc"} + assert result.messages[-1].text == "All done!" 
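# A tiny, framework-free sketch of the design choice behind dict(kwargs) in the
# tracking tool above: appending the live mapping would let a later mutation
# retroactively rewrite every recorded entry, while copying at capture time
# freezes each invocation's view.
records: list[dict[str, object]] = []
ctx: dict[str, object] = {"request_id": "req-001"}
records.append(dict(ctx))  # snapshot: keeps the value seen at capture time
records.append(ctx)        # alias: tracks later mutations
ctx["request_id"] = "req-002"
assert records[0]["request_id"] == "req-001"
assert records[1]["request_id"] == "req-002"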
+ + async def test_streaming_response_kwargs_propagation(self) -> None: + """Test that kwargs propagate to @ai_function in streaming mode.""" + captured_kwargs: dict[str, Any] = {} + + @ai_function + def streaming_capture_tool(value: str, **kwargs: Any) -> str: + """A tool that captures kwargs during streaming.""" + captured_kwargs.update(kwargs) + return f"processed: {value}" + + mock_client = type("MockClient", (), {})() + + call_count = [0] + + async def mock_get_streaming_response(self, messages, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + # First call: return function call update + yield ChatResponseUpdate( + role="assistant", + contents=[ + FunctionCallContent( + call_id="stream_call_1", + name="streaming_capture_tool", + arguments='{"value": "streaming-test"}', + ) + ], + is_finished=True, + ) + else: + # Second call: return final response + yield ChatResponseUpdate(text=TextContent(text="Stream complete!"), role="assistant", is_finished=True) + + wrapped = _handle_function_calls_streaming_response(mock_get_streaming_response) + + # Collect streaming updates + updates: list[ChatResponseUpdate] = [] + async for update in wrapped( + mock_client, + messages=[], + options={"tools": [streaming_capture_tool]}, + streaming_session="session-xyz", + correlation_id="corr-123", + ): + updates.append(update) + + # Verify kwargs were captured by the tool + assert "streaming_session" in captured_kwargs, f"Expected 'streaming_session' in {captured_kwargs}" + assert captured_kwargs["streaming_session"] == "session-xyz" + assert captured_kwargs["correlation_id"] == "corr-123" diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 18c90d64b3..c4e8cb09df 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -1248,6 +1248,75 @@ async def test_streamable_http_integration(): assert result[0].text is not None +@pytest.mark.flaky +@skip_if_mcp_integration_tests_disabled +async def test_mcp_connection_reset_integration(): + """Test that connection reset works correctly with a real MCP server. + + This integration test verifies: + 1. Initial connection and tool execution works + 2. Simulating connection failure triggers automatic reconnection + 3. Tool execution works after reconnection + 4. Exit stack cleanup happens properly during reconnection + """ + url = os.environ.get("LOCAL_MCP_URL") + + tool = MCPStreamableHTTPTool(name="integration_test", url=url) + + async with tool: + # Verify initial connection + assert tool.session is not None + assert tool.is_connected is True + assert len(tool.functions) > 0, "The MCP server should have at least one function." 
+ + # Get the first function and invoke it + func = tool.functions[0] + first_result = await func.invoke(query="What is Agent Framework?") + assert first_result is not None + assert len(first_result) > 0 + + # Store the original session and exit stack for comparison + original_session = tool.session + original_exit_stack = tool._exit_stack + original_call_tool = tool.session.call_tool + + # Simulate connection failure by making call_tool raise ClosedResourceError once + call_count = 0 + + async def call_tool_with_error(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call fails with connection error + from anyio import ClosedResourceError + + raise ClosedResourceError + # After reconnection, delegate to the original method + return await original_call_tool(*args, **kwargs) + + tool.session.call_tool = call_tool_with_error + + # Invoke the function again - this should trigger automatic reconnection on ClosedResourceError + second_result = await func.invoke(query="What is Agent Framework?") + assert second_result is not None + assert len(second_result) > 0 + + # Verify we have a new session and exit stack after reconnection + assert tool.session is not None + assert tool.session is not original_session, "Session should be replaced after reconnection" + assert tool._exit_stack is not original_exit_stack, "Exit stack should be replaced after reconnection" + assert tool.is_connected is True + + # Verify tools are still available after reconnection + assert len(tool.functions) > 0 + + # Both results should be valid (we don't compare content as it may vary) + if hasattr(first_result[0], "text"): + assert first_result[0].text is not None + if hasattr(second_result[0], "text"): + assert second_result[0].text is not None + + async def test_mcp_tool_message_handler_notification(): """Test that message_handler correctly processes tools/list_changed and prompts/list_changed notifications.""" @@ -1512,24 +1581,18 @@ def test_mcp_streamable_http_tool_get_mcp_client_all_params(): tool = MCPStreamableHTTPTool( name="test", url="http://example.com", - headers={"Auth": "token"}, - timeout=30.0, - sse_read_timeout=10.0, terminate_on_close=True, - custom_param="test", ) - with patch("agent_framework._mcp.streamablehttp_client") as mock_http_client: + with patch("agent_framework._mcp.streamable_http_client") as mock_http_client: tool.get_mcp_client() - # Verify all parameters were passed + # Verify streamable_http_client was called with None for http_client + # (since we didn't provide one, the API will create its own) mock_http_client.assert_called_once_with( url="http://example.com", - headers={"Auth": "token"}, - timeout=30.0, - sse_read_timeout=10.0, + http_client=None, terminate_on_close=True, - custom_param="test", ) @@ -1555,7 +1618,6 @@ def test_mcp_websocket_tool_get_mcp_client_with_kwargs(): ) -@pytest.mark.asyncio async def test_mcp_tool_deduplication(): """Test that MCP tools are not duplicated in MCPTool""" from agent_framework._mcp import MCPTool @@ -1617,7 +1679,6 @@ async def test_mcp_tool_deduplication(): assert added_count == 1 # Only 1 new function added -@pytest.mark.asyncio async def test_load_tools_prevents_multiple_calls(): """Test that connect() prevents calling load_tools() multiple times""" from unittest.mock import AsyncMock, MagicMock @@ -1633,6 +1694,7 @@ async def test_load_tools_prevents_multiple_calls(): mock_session = AsyncMock() mock_tool_list = MagicMock() mock_tool_list.tools = [] + mock_tool_list.nextCursor = None # No
pagination mock_session.list_tools = AsyncMock(return_value=mock_tool_list) mock_session.initialize = AsyncMock() @@ -1656,7 +1718,6 @@ async def test_load_tools_prevents_multiple_calls(): assert mock_session.list_tools.call_count == 1 # Still 1, not incremented -@pytest.mark.asyncio async def test_load_prompts_prevents_multiple_calls(): """Test that connect() prevents calling load_prompts() multiple times""" from unittest.mock import AsyncMock, MagicMock @@ -1672,6 +1733,7 @@ async def test_load_prompts_prevents_multiple_calls(): mock_session = AsyncMock() mock_prompt_list = MagicMock() mock_prompt_list.prompts = [] + mock_prompt_list.nextCursor = None # No pagination mock_session.list_prompts = AsyncMock(return_value=mock_prompt_list) tool.session = mock_session @@ -1692,3 +1754,613 @@ async def test_load_prompts_prevents_multiple_calls(): tool._prompts_loaded = True assert mock_session.list_prompts.call_count == 1 # Still 1, not incremented + + +async def test_mcp_streamable_http_tool_httpx_client_cleanup(): + """Test that MCPStreamableHTTPTool properly passes through httpx clients.""" + from unittest.mock import AsyncMock, Mock, patch + + from agent_framework import MCPStreamableHTTPTool + + # Mock the streamable_http_client to avoid actual connections + with ( + patch("agent_framework._mcp.streamable_http_client") as mock_client, + patch("agent_framework._mcp.ClientSession") as mock_session_class, + ): + # Setup mock context manager for streamable_http_client + mock_transport = (Mock(), Mock()) + mock_context_manager = Mock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + mock_client.return_value = mock_context_manager + + # Setup mock session + mock_session = Mock() + mock_session.initialize = AsyncMock() + mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None) + + # Test 1: Tool without provided client (passes None to streamable_http_client) + tool1 = MCPStreamableHTTPTool( + name="test", + url="http://localhost:8081/mcp", + load_tools=False, + load_prompts=False, + terminate_on_close=False, + ) + await tool1.connect() + # When no client is provided, _httpx_client should be None + assert tool1._httpx_client is None, "httpx client should be None when not provided" + + # Test 2: Tool with user-provided client + user_client = Mock() + tool2 = MCPStreamableHTTPTool( + name="test", + url="http://localhost:8081/mcp", + load_tools=False, + load_prompts=False, + terminate_on_close=False, + http_client=user_client, + ) + await tool2.connect() + + # Verify the user-provided client was stored + assert tool2._httpx_client is user_client, "User-provided client should be stored" + + # Verify streamable_http_client was called with the user's client + # Get the last call (should be from tool2.connect()) + call_args = mock_client.call_args + assert call_args.kwargs["http_client"] is user_client, "User's client should be passed through" + + +async def test_load_tools_with_pagination(): + """Test that load_tools handles pagination correctly.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + # Create paginated responses + page1 = MagicMock() + page1.tools = [ + types.Tool( + name="tool_1", + description="First 
tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + types.Tool( + name="tool_2", + description="Second tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + ] + page1.nextCursor = "cursor_page2" + + page2 = MagicMock() + page2.tools = [ + types.Tool( + name="tool_3", + description="Third tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + ] + page2.nextCursor = "cursor_page3" + + page3 = MagicMock() + page3.tools = [ + types.Tool( + name="tool_4", + description="Fourth tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + ] + page3.nextCursor = None # No more pages + + # Mock list_tools to return different pages based on params + async def mock_list_tools(params=None): + if params is None: + return page1 + if params.cursor == "cursor_page2": + return page2 + if params.cursor == "cursor_page3": + return page3 + raise ValueError("Unexpected cursor value") + + mock_session.list_tools = AsyncMock(side_effect=mock_list_tools) + + # Load tools with pagination + await tool.load_tools() + + # Verify all pages were fetched + assert mock_session.list_tools.call_count == 3 + assert len(tool._functions) == 4 + assert [f.name for f in tool._functions] == ["tool_1", "tool_2", "tool_3", "tool_4"] + + +async def test_load_prompts_with_pagination(): + """Test that load_prompts handles pagination correctly.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_prompts_flag = True + + # Create paginated responses + page1 = MagicMock() + page1.prompts = [ + types.Prompt( + name="prompt_1", + description="First prompt", + arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)], + ), + types.Prompt( + name="prompt_2", + description="Second prompt", + arguments=[types.PromptArgument(name="arg2", description="Arg 2", required=True)], + ), + ] + page1.nextCursor = "cursor_page2" + + page2 = MagicMock() + page2.prompts = [ + types.Prompt( + name="prompt_3", + description="Third prompt", + arguments=[types.PromptArgument(name="arg3", description="Arg 3", required=False)], + ), + ] + page2.nextCursor = None # No more pages + + # Mock list_prompts to return different pages based on params + async def mock_list_prompts(params=None): + if params is None: + return page1 + if params.cursor == "cursor_page2": + return page2 + raise ValueError("Unexpected cursor value") + + mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts) + + # Load prompts with pagination + await tool.load_prompts() + + # Verify all pages were fetched + assert mock_session.list_prompts.call_count == 2 + assert len(tool._functions) == 3 + assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2", "prompt_3"] + + +async def test_load_tools_pagination_with_duplicates(): + """Test that load_tools prevents duplicates across paginated results.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + # Create paginated responses with duplicate tool names + page1 = MagicMock() + page1.tools = [ + types.Tool( + name="tool_1", + description="First tool", + inputSchema={"type": "object", 
"properties": {"param": {"type": "string"}}}, + ), + types.Tool( + name="tool_2", + description="Second tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + ] + page1.nextCursor = "cursor_page2" + + page2 = MagicMock() + page2.tools = [ + types.Tool( + name="tool_1", # Duplicate from page1 + description="Duplicate tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + types.Tool( + name="tool_3", + description="Third tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + ] + page2.nextCursor = None + + # Mock list_tools to return different pages + async def mock_list_tools(params=None): + if params is None: + return page1 + if params.cursor == "cursor_page2": + return page2 + raise ValueError("Unexpected cursor value") + + mock_session.list_tools = AsyncMock(side_effect=mock_list_tools) + + # Load tools with pagination + await tool.load_tools() + + # Verify duplicates were skipped + assert mock_session.list_tools.call_count == 2 + assert len(tool._functions) == 3 + assert [f.name for f in tool._functions] == ["tool_1", "tool_2", "tool_3"] + + +async def test_load_prompts_pagination_with_duplicates(): + """Test that load_prompts prevents duplicates across paginated results.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_prompts_flag = True + + # Create paginated responses with duplicate prompt names + page1 = MagicMock() + page1.prompts = [ + types.Prompt( + name="prompt_1", + description="First prompt", + arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)], + ), + ] + page1.nextCursor = "cursor_page2" + + page2 = MagicMock() + page2.prompts = [ + types.Prompt( + name="prompt_1", # Duplicate from page1 + description="Duplicate prompt", + arguments=[types.PromptArgument(name="arg2", description="Arg 2", required=False)], + ), + types.Prompt( + name="prompt_2", + description="Second prompt", + arguments=[types.PromptArgument(name="arg3", description="Arg 3", required=True)], + ), + ] + page2.nextCursor = None + + # Mock list_prompts to return different pages + async def mock_list_prompts(params=None): + if params is None: + return page1 + if params.cursor == "cursor_page2": + return page2 + raise ValueError("Unexpected cursor value") + + mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts) + + # Load prompts with pagination + await tool.load_prompts() + + # Verify duplicates were skipped + assert mock_session.list_prompts.call_count == 2 + assert len(tool._functions) == 2 + assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2"] + + +async def test_load_tools_pagination_exception_handling(): + """Test that load_tools handles exceptions during pagination gracefully.""" + from unittest.mock import AsyncMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + # Mock list_tools to raise an exception on first call + mock_session.list_tools = AsyncMock(side_effect=RuntimeError("Connection error")) + + # Load tools should raise the exception (not handled gracefully) + with pytest.raises(RuntimeError, match="Connection error"): + await tool.load_tools() + + # Verify exception was raised on first call + 
assert mock_session.list_tools.call_count == 1 + assert len(tool._functions) == 0 + + +async def test_load_prompts_pagination_exception_handling(): + """Test that load_prompts handles exceptions during pagination gracefully.""" + from unittest.mock import AsyncMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_prompts_flag = True + + # Mock list_prompts to raise an exception on first call + mock_session.list_prompts = AsyncMock(side_effect=RuntimeError("Connection error")) + + # Load prompts should raise the exception (not handled gracefully) + with pytest.raises(RuntimeError, match="Connection error"): + await tool.load_prompts() + + # Verify exception was raised on first call + assert mock_session.list_prompts.call_count == 1 + assert len(tool._functions) == 0 + + +async def test_load_tools_empty_pagination(): + """Test that load_tools handles empty paginated results.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + # Create empty response + page1 = MagicMock() + page1.tools = [] + page1.nextCursor = None + + mock_session.list_tools = AsyncMock(return_value=page1) + + # Load tools + await tool.load_tools() + + # Verify + assert mock_session.list_tools.call_count == 1 + assert len(tool._functions) == 0 + + +async def test_load_prompts_empty_pagination(): + """Test that load_prompts handles empty paginated results.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_prompts_flag = True + + # Create empty response + page1 = MagicMock() + page1.prompts = [] + page1.nextCursor = None + + mock_session.list_prompts = AsyncMock(return_value=page1) + + # Load prompts + await tool.load_prompts() + + # Verify + assert mock_session.list_prompts.call_count == 1 + assert len(tool._functions) == 0 + + +async def test_mcp_tool_connection_properly_invalidated_after_closed_resource_error(): + """Test that verifies reconnection on ClosedResourceError for issue #2884. + + This test verifies the fix for issue #2884: the tool tries operations optimistically + and only reconnects when ClosedResourceError is encountered, avoiding extra latency. 
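+
+    Covers two paths: a transient ClosedResourceError that succeeds after a
+    single reconnect-and-retry, and a persistent failure where reconnection
+    itself fails and is surfaced as ToolExecutionException.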
+ """ + from unittest.mock import AsyncMock, MagicMock, patch + + from anyio.streams.memory import ClosedResourceError + + from agent_framework._mcp import MCPStdioTool + from agent_framework.exceptions import ToolExecutionException + + # Create a mock MCP tool + tool = MCPStdioTool( + name="test_server", + command="test_command", + args=["arg1"], + load_tools=True, + ) + + # Mock the session + mock_session = MagicMock() + mock_session._request_id = 1 + mock_session.call_tool = AsyncMock() + + # Mock _exit_stack.aclose to track cleanup calls + original_exit_stack = tool._exit_stack + tool._exit_stack.aclose = AsyncMock() + + # Mock connect() to avoid trying to start actual process + with patch.object(tool, "connect", new_callable=AsyncMock) as mock_connect: + + async def restore_session(*, reset=False): + if reset: + await original_exit_stack.aclose() + tool.session = mock_session + tool.is_connected = True + tool._tools_loaded = True + + mock_connect.side_effect = restore_session + + # Simulate initial connection + tool.session = mock_session + tool.is_connected = True + tool._tools_loaded = True + + # First call should work - connection is valid + mock_session.call_tool.return_value = MagicMock(content=[]) + result = await tool.call_tool("test_tool", arg1="value1") + assert result is not None + + # Test Case 1: Connection closed unexpectedly, should reconnect and retry + # Simulate ClosedResourceError on first call, then succeed + call_count = 0 + + async def call_tool_with_error(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ClosedResourceError + return MagicMock(content=[]) + + mock_session.call_tool = call_tool_with_error + + # This call should trigger reconnection after ClosedResourceError + result = await tool.call_tool("test_tool", arg1="value2") + assert result is not None + # Verify reconnect was attempted with reset=True + assert mock_connect.call_count >= 1 + mock_connect.assert_called_with(reset=True) + # Verify _exit_stack.aclose was called during reconnection + original_exit_stack.aclose.assert_called() + + # Test Case 2: Reconnection failure + # Reset counters + call_count = 0 + mock_connect.reset_mock() + original_exit_stack.aclose.reset_mock() + + # Make call_tool always raise ClosedResourceError + async def always_fail(*args, **kwargs): + raise ClosedResourceError + + mock_session.call_tool = always_fail + + # Change mock_connect to simulate failed reconnection + mock_connect.side_effect = Exception("Failed to reconnect") + + # This should raise ToolExecutionException when reconnection fails + with pytest.raises(ToolExecutionException) as exc_info: + await tool.call_tool("test_tool", arg1="value3") + + # Verify reconnection was attempted + assert mock_connect.call_count >= 1 + # Verify error message indicates reconnection failure + assert "failed to reconnect" in str(exc_info.value).lower() + + +async def test_mcp_tool_get_prompt_reconnection_on_closed_resource_error(): + """Test that get_prompt also reconnects on ClosedResourceError. + + This verifies that the fix for issue #2884 applies to get_prompt as well, + and that _exit_stack.aclose() is properly called during reconnection. 
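+
+    As with call_tool above, a failed reconnection is expected to surface as
+    ToolExecutionException rather than the raw ClosedResourceError.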
+ """ + from unittest.mock import AsyncMock, MagicMock, patch + + from anyio.streams.memory import ClosedResourceError + + from agent_framework._mcp import MCPStdioTool + from agent_framework.exceptions import ToolExecutionException + + # Create a mock MCP tool + tool = MCPStdioTool( + name="test_server", + command="test_command", + args=["arg1"], + load_prompts=True, + ) + + # Mock the session + mock_session = MagicMock() + mock_session._request_id = 1 + mock_session.get_prompt = AsyncMock() + + # Mock _exit_stack.aclose to track cleanup calls + original_exit_stack = tool._exit_stack + tool._exit_stack.aclose = AsyncMock() + + # Mock connect() to avoid trying to start actual process + with patch.object(tool, "connect", new_callable=AsyncMock) as mock_connect: + + async def restore_session(*, reset=False): + if reset: + await original_exit_stack.aclose() + tool.session = mock_session + tool.is_connected = True + tool._prompts_loaded = True + + mock_connect.side_effect = restore_session + + # Simulate initial connection + tool.session = mock_session + tool.is_connected = True + tool._prompts_loaded = True + + # First call should work - connection is valid + mock_session.get_prompt.return_value = MagicMock(messages=[]) + result = await tool.get_prompt("test_prompt", arg1="value1") + assert result is not None + + # Test Case 1: Connection closed unexpectedly, should reconnect and retry + # Simulate ClosedResourceError on first call, then succeed + call_count = 0 + + async def get_prompt_with_error(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ClosedResourceError + return MagicMock(messages=[]) + + mock_session.get_prompt = get_prompt_with_error + + # This call should trigger reconnection after ClosedResourceError + result = await tool.get_prompt("test_prompt", arg1="value2") + assert result is not None + # Verify reconnect was attempted with reset=True + assert mock_connect.call_count >= 1 + mock_connect.assert_called_with(reset=True) + # Verify _exit_stack.aclose was called during reconnection + original_exit_stack.aclose.assert_called() + + # Test Case 2: Reconnection failure + # Reset counters + call_count = 0 + mock_connect.reset_mock() + original_exit_stack.aclose.reset_mock() + + # Make get_prompt always raise ClosedResourceError + async def always_fail(*args, **kwargs): + raise ClosedResourceError + + mock_session.get_prompt = always_fail + + # Change mock_connect to simulate failed reconnection + mock_connect.side_effect = Exception("Failed to reconnect") + + # This should raise ToolExecutionException when reconnection fails + with pytest.raises(ToolExecutionException) as exc_info: + await tool.get_prompt("test_prompt", arg1="value3") + + # Verify reconnection was attempted + assert mock_connect.call_count >= 1 + # Verify error message indicates reconnection failure + assert "failed to reconnect" in str(exc_info.value).lower() diff --git a/python/packages/core/tests/core/test_memory.py b/python/packages/core/tests/core/test_memory.py index f3750f20e2..6cc7ba436e 100644 --- a/python/packages/core/tests/core/test_memory.py +++ b/python/packages/core/tests/core/test_memory.py @@ -2,10 +2,9 @@ from collections.abc import MutableSequence from typing import Any -from unittest.mock import AsyncMock, Mock -from agent_framework import ChatMessage, Role, TextContent -from agent_framework._memory import AggregateContextProvider, Context, ContextProvider +from agent_framework import ChatMessage, Role +from agent_framework._memory import Context, 
ContextProvider class MockContextProvider(ContextProvider): @@ -45,252 +44,50 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], * return context -class TestAggregateContextProvider: - """Tests for AggregateContextProvider class.""" - - def test_init_with_no_providers(self) -> None: - """Test initialization with no providers.""" - aggregate = AggregateContextProvider() - assert aggregate.providers == [] - - def test_init_with_none_providers(self) -> None: - """Test initialization with None providers.""" - aggregate = AggregateContextProvider(None) - assert aggregate.providers == [] - - def test_init_with_providers(self) -> None: - """Test initialization with providers.""" - provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")]) - provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")]) - provider3 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 3")]) - providers = [provider1, provider2, provider3] - - aggregate = AggregateContextProvider(providers) - assert len(aggregate.providers) == 3 - assert aggregate.providers[0] is provider1 - assert aggregate.providers[1] is provider2 - assert aggregate.providers[2] is provider3 - - def test_add_provider(self) -> None: - """Test adding a provider.""" - aggregate = AggregateContextProvider() - provider = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions")]) - - aggregate.add(provider) - assert len(aggregate.providers) == 1 - assert aggregate.providers[0] is provider - - def test_add_multiple_providers(self) -> None: - """Test adding multiple providers.""" - aggregate = AggregateContextProvider() - provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")]) - provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")]) - - aggregate.add(provider1) - aggregate.add(provider2) - - assert len(aggregate.providers) == 2 - assert aggregate.providers[0] is provider1 - assert aggregate.providers[1] is provider2 - - async def test_thread_created_with_no_providers(self) -> None: - """Test thread_created with no providers.""" - aggregate = AggregateContextProvider() - - # Should not raise an exception - await aggregate.thread_created("thread-123") - - async def test_thread_created_with_providers(self) -> None: - """Test thread_created calls all providers.""" - provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")]) - provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")]) - aggregate = AggregateContextProvider([provider1, provider2]) - - thread_id = "thread-123" - await aggregate.thread_created(thread_id) - - assert provider1.thread_created_called - assert provider1.thread_created_thread_id == thread_id - assert provider2.thread_created_called - assert provider2.thread_created_thread_id == thread_id - - async def test_thread_created_with_none_thread_id(self) -> None: - """Test thread_created with None thread_id.""" - provider = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions")]) - aggregate = AggregateContextProvider([provider]) - - await aggregate.thread_created(None) +class TestContext: + """Tests for Context class.""" + def test_context_default_values(self) -> None: + """Test Context has correct default values.""" + context = Context() + assert context.instructions is None + assert context.messages == [] + assert context.tools == [] + + def 
test_context_with_values(self) -> None: + """Test Context can be initialized with values.""" + messages = [ChatMessage(role=Role.USER, text="Test message")] + context = Context(instructions="Test instructions", messages=messages) + assert context.instructions == "Test instructions" + assert len(context.messages) == 1 + assert context.messages[0].text == "Test message" + + +class TestContextProvider: + """Tests for ContextProvider class.""" + + async def test_thread_created(self) -> None: + """Test thread_created is called.""" + provider = MockContextProvider() + await provider.thread_created("test-thread-id") assert provider.thread_created_called - assert provider.thread_created_thread_id is None - - async def test_messages_adding_with_no_providers(self) -> None: - """Test invoked with no providers.""" - aggregate = AggregateContextProvider() - message = ChatMessage(text="Hello", role=Role.USER) - - # Should not raise an exception - await aggregate.invoked(message) - - async def test_messages_adding_with_single_message(self) -> None: - """Test invoked with a single message.""" - provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")]) - provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")]) - aggregate = AggregateContextProvider([provider1, provider2]) - - message = ChatMessage(text="Hello", role=Role.USER) - await aggregate.invoked(message) - - assert provider1.invoked_called - assert provider1.new_messages == message - assert provider2.invoked_called - assert provider2.new_messages == message - - async def test_messages_adding_with_message_sequence(self) -> None: - """Test invoked with a sequence of messages.""" - provider = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions")]) - aggregate = AggregateContextProvider([provider]) - - messages = [ - ChatMessage(text="Hello", role=Role.USER), - ChatMessage(text="Hi there", role=Role.ASSISTANT), - ] - await aggregate.invoked(messages) + assert provider.thread_created_thread_id == "test-thread-id" + async def test_invoked(self) -> None: + """Test invoked is called.""" + provider = MockContextProvider() + message = ChatMessage(role=Role.USER, text="Test message") + await provider.invoked(message) assert provider.invoked_called - assert provider.new_messages == messages - - async def test_model_invoking_with_no_providers(self) -> None: - """Test invoking with no providers.""" - aggregate = AggregateContextProvider() - message = ChatMessage(text="Hello", role=Role.USER) - - context = await aggregate.invoking(message) - - assert isinstance(context, Context) - assert not context.messages - - async def test_model_invoking_with_single_provider(self) -> None: - """Test invoking with a single provider.""" - provider = MockContextProvider(messages=[ChatMessage(role="user", text="Test instructions")]) - aggregate = AggregateContextProvider([provider]) - - message = [ChatMessage(text="Hello", role=Role.USER)] - context = await aggregate.invoking(message) + assert provider.new_messages == message + async def test_invoking(self) -> None: + """Test invoking is called and returns context.""" + provider = MockContextProvider(messages=[ChatMessage(role=Role.USER, text="Context message")]) + message = ChatMessage(role=Role.USER, text="Test message") + context = await provider.invoking(message) assert provider.invoking_called assert provider.model_invoking_messages == message - assert isinstance(context, Context) - - assert context.messages - assert 
isinstance(context.messages[0].contents[0], TextContent) - assert context.messages[0].text == "Test instructions" - - async def test_model_invoking_with_multiple_providers(self) -> None: - """Test invoking combines contexts from multiple providers.""" - provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")]) - provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")]) - provider3 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 3")]) - aggregate = AggregateContextProvider([provider1, provider2, provider3]) - - messages = [ChatMessage(text="Hello", role=Role.USER)] - context = await aggregate.invoking(messages) - - assert provider1.invoking_called - assert provider1.model_invoking_messages == messages - assert provider2.invoking_called - assert provider2.model_invoking_messages == messages - assert provider3.invoking_called - assert provider3.model_invoking_messages == messages - - assert isinstance(context, Context) - - assert context.messages - assert isinstance(context.messages[0].contents[0], TextContent) - assert isinstance(context.messages[1].contents[0], TextContent) - assert isinstance(context.messages[2].contents[0], TextContent) - assert context.messages[0].text == "Instructions 1" - assert context.messages[1].text == "Instructions 2" - assert context.messages[2].text == "Instructions 3" - - async def test_model_invoking_with_none_instructions(self) -> None: - """Test invoking filters out None instructions.""" - provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")]) - provider2 = MockContextProvider(messages=None) # None instructions - provider3 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 3")]) - aggregate = AggregateContextProvider([provider1, provider2, provider3]) - - message = ChatMessage(text="Hello", role=Role.USER) - context = await aggregate.invoking(message) - - assert isinstance(context, Context) - assert context.messages - assert isinstance(context.messages[0].contents[0], TextContent) - assert isinstance(context.messages[1].contents[0], TextContent) - assert context.messages[0].text == "Instructions 1" - assert context.messages[1].text == "Instructions 3" - - async def test_model_invoking_with_all_none_instructions(self) -> None: - """Test invoking when all providers return None instructions.""" - provider1 = MockContextProvider(None) - provider2 = MockContextProvider(None) - aggregate = AggregateContextProvider([provider1, provider2]) - - message = ChatMessage(text="Hello", role=Role.USER) - context = await aggregate.invoking(message) - - assert isinstance(context, Context) - assert not context.messages - - async def test_model_invoking_with_mutable_sequence(self) -> None: - """Test invoking with MutableSequence of messages.""" - provider = MockContextProvider(messages=[ChatMessage(role="user", text="Test instructions")]) - aggregate = AggregateContextProvider([provider]) - - messages = [ChatMessage(text="Hello", role=Role.USER)] - context = await aggregate.invoking(messages) - - assert provider.invoking_called - assert provider.model_invoking_messages == messages - assert isinstance(context, Context) - assert context.messages - assert isinstance(context.messages[0].contents[0], TextContent) - assert context.messages[0].text == "Test instructions" - - async def test_async_methods_concurrent_execution(self) -> None: - """Test that async methods execute providers concurrently.""" - # Use AsyncMock to verify 
concurrent execution - provider1 = Mock(spec=ContextProvider) - provider1.thread_created = AsyncMock() - provider1.invoked = AsyncMock() - provider1.invoking = AsyncMock(return_value=Context(messages=[ChatMessage(role="user", text="Test 1")])) - - provider2 = Mock(spec=ContextProvider) - provider2.thread_created = AsyncMock() - provider2.invoked = AsyncMock() - provider2.invoking = AsyncMock(return_value=Context(messages=[ChatMessage(role="user", text="Test 2")])) - - aggregate = AggregateContextProvider([provider1, provider2]) - - # Test thread_created - await aggregate.thread_created("thread-123") - provider1.thread_created.assert_called_once_with("thread-123") - provider2.thread_created.assert_called_once_with("thread-123") - - # Test invoked - message = ChatMessage(text="Hello", role=Role.USER) - await aggregate.invoked(message) - provider1.invoked.assert_called_once_with( - request_messages=message, response_messages=None, invoke_exception=None - ) - provider2.invoked.assert_called_once_with( - request_messages=message, response_messages=None, invoke_exception=None - ) - - # Test invoking - context = await aggregate.invoking(message) - provider1.invoking.assert_called_once_with(message) - provider2.invoking.assert_called_once_with(message) - assert context.messages - assert context.messages[0].text == "Test 1" - assert context.messages[1].text == "Test 2" + assert context.messages is not None + assert len(context.messages) == 1 + assert context.messages[0].text == "Context message" diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index a84c8927d0..ebb833f2b4 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -9,8 +9,8 @@ from agent_framework import ( AgentProtocol, - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, ChatMessage, ChatResponse, ChatResponseUpdate, @@ -29,7 +29,6 @@ FunctionMiddlewarePipeline, ) from agent_framework._tools import AIFunction -from agent_framework._types import ChatOptions class TestAgentRunContext: @@ -100,12 +99,12 @@ class TestChatContext: def test_init_with_defaults(self, mock_chat_client: Any) -> None: """Test ChatContext initialization with default values.""" messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) assert context.chat_client is mock_chat_client assert context.messages == messages - assert context.chat_options is chat_options + assert context.options is chat_options assert context.is_streaming is False assert context.metadata == {} assert context.result is None @@ -114,13 +113,13 @@ def test_init_with_defaults(self, mock_chat_client: Any) -> None: def test_init_with_custom_values(self, mock_chat_client: Any) -> None: """Test ChatContext initialization with custom values.""" messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions(temperature=0.5) + chat_options: dict[str, Any] = {"temperature": 0.5} metadata = {"key": "value"} context = ChatContext( chat_client=mock_chat_client, messages=messages, - chat_options=chat_options, + options=chat_options, is_streaming=True, metadata=metadata, terminate=True, @@ -128,7 +127,7 @@ def test_init_with_custom_values(self, mock_chat_client: Any) 
-> None: assert context.chat_client is mock_chat_client assert context.messages == messages - assert context.chat_options is chat_options + assert context.options is chat_options assert context.is_streaming is True assert context.metadata == metadata assert context.terminate is True @@ -148,7 +147,7 @@ async def process(self, context: AgentRunContext, next: Any) -> None: context.terminate = True def test_init_empty(self) -> None: - """Test AgentMiddlewarePipeline initialization with no middlewares.""" + """Test AgentMiddlewarePipeline initialization with no middleware.""" pipeline = AgentMiddlewarePipeline() assert not pipeline.has_middlewares @@ -173,9 +172,9 @@ async def test_execute_no_middleware(self, mock_agent: AgentProtocol) -> None: messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - expected_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + expected_response = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: return expected_response result = await pipeline.execute(mock_agent, messages, context, final_handler) @@ -201,9 +200,9 @@ async def process( messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - expected_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + expected_response = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: execution_order.append("handler") return expected_response @@ -217,11 +216,11 @@ async def test_execute_stream_no_middleware(self, mock_agent: AgentProtocol) -> messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")]) - yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")]) + async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate(contents=[TextContent(text="chunk1")]) + yield AgentResponseUpdate(contents=[TextContent(text="chunk2")]) - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): updates.append(update) @@ -249,13 +248,13 @@ async def process( messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: execution_order.append("handler_start") - yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")]) - yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")]) + yield AgentResponseUpdate(contents=[TextContent(text="chunk1")]) + yield AgentResponseUpdate(contents=[TextContent(text="chunk2")]) execution_order.append("handler_end") - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in pipeline.execute_stream(mock_agent, 
messages, context, final_handler): updates.append(update) @@ -272,10 +271,10 @@ async def test_execute_with_pre_next_termination(self, mock_agent: AgentProtocol context = AgentRunContext(agent=mock_agent, messages=messages) execution_order: list[str] = [] - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: # Handler should not be executed when terminated before next() execution_order.append("handler") - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) response = await pipeline.execute(mock_agent, messages, context, final_handler) assert response is not None @@ -292,9 +291,9 @@ async def test_execute_with_post_next_termination(self, mock_agent: AgentProtoco context = AgentRunContext(agent=mock_agent, messages=messages) execution_order: list[str] = [] - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: execution_order.append("handler") - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) response = await pipeline.execute(mock_agent, messages, context, final_handler) assert response is not None @@ -311,14 +310,14 @@ async def test_execute_stream_with_pre_next_termination(self, mock_agent: AgentP context = AgentRunContext(agent=mock_agent, messages=messages) execution_order: list[str] = [] - async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: # Handler should not be executed when terminated before next() execution_order.append("handler_start") - yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")]) - yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")]) + yield AgentResponseUpdate(contents=[TextContent(text="chunk1")]) + yield AgentResponseUpdate(contents=[TextContent(text="chunk2")]) execution_order.append("handler_end") - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): updates.append(update) @@ -335,13 +334,13 @@ async def test_execute_stream_with_post_next_termination(self, mock_agent: Agent context = AgentRunContext(agent=mock_agent, messages=messages) execution_order: list[str] = [] - async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: execution_order.append("handler_start") - yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")]) - yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")]) + yield AgentResponseUpdate(contents=[TextContent(text="chunk1")]) + yield AgentResponseUpdate(contents=[TextContent(text="chunk2")]) execution_order.append("handler_end") - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): updates.append(update) @@ -371,9 +370,9 @@ async def process( thread = AgentThread() context = AgentRunContext(agent=mock_agent, messages=messages, thread=thread) - expected_response = 
AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + expected_response = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: return expected_response result = await pipeline.execute(mock_agent, messages, context, final_handler) @@ -397,9 +396,9 @@ async def process( messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages, thread=None) - expected_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + expected_response = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: return expected_response result = await pipeline.execute(mock_agent, messages, context, final_handler) @@ -457,7 +456,7 @@ async def final_handler(ctx: FunctionInvocationContext) -> str: assert execution_order == ["handler"] def test_init_empty(self) -> None: - """Test FunctionMiddlewarePipeline initialization with no middlewares.""" + """Test FunctionMiddlewarePipeline initialization with no middleware.""" pipeline = FunctionMiddlewarePipeline() assert not pipeline.has_middlewares @@ -539,7 +538,7 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai context.terminate = True def test_init_empty(self) -> None: - """Test ChatMiddlewarePipeline initialization with no middlewares.""" + """Test ChatMiddlewarePipeline initialization with no middleware.""" pipeline = ChatMiddlewarePipeline() assert not pipeline.has_middlewares @@ -562,8 +561,8 @@ async def test_execute_no_middleware(self, mock_chat_client: Any) -> None: """Test pipeline execution with no middleware.""" pipeline = ChatMiddlewarePipeline() messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) expected_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) @@ -589,8 +588,8 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai middleware = OrderTrackingChatMiddleware("test") pipeline = ChatMiddlewarePipeline([middleware]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) expected_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) @@ -606,8 +605,8 @@ async def test_execute_stream_no_middleware(self, mock_chat_client: Any) -> None """Test pipeline streaming execution with no middleware.""" pipeline = ChatMiddlewarePipeline() messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) async def final_handler(ctx: ChatContext) -> 
AsyncIterable[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[TextContent(text="chunk1")]) @@ -637,10 +636,8 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai middleware = StreamOrderTrackingChatMiddleware("test") pipeline = ChatMiddlewarePipeline([middleware]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext( - chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True - ) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options, is_streaming=True) async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: execution_order.append("handler_start") @@ -662,8 +659,8 @@ async def test_execute_with_pre_next_termination(self, mock_chat_client: Any) -> middleware = self.PreNextTerminateChatMiddleware() pipeline = ChatMiddlewarePipeline([middleware]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) execution_order: list[str] = [] async def final_handler(ctx: ChatContext) -> ChatResponse: @@ -682,8 +679,8 @@ async def test_execute_with_post_next_termination(self, mock_chat_client: Any) - middleware = self.PostNextTerminateChatMiddleware() pipeline = ChatMiddlewarePipeline([middleware]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) execution_order: list[str] = [] async def final_handler(ctx: ChatContext) -> ChatResponse: @@ -702,10 +699,8 @@ async def test_execute_stream_with_pre_next_termination(self, mock_chat_client: middleware = self.PreNextTerminateChatMiddleware() pipeline = ChatMiddlewarePipeline([middleware]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext( - chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True - ) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options, is_streaming=True) execution_order: list[str] = [] async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: @@ -729,10 +724,8 @@ async def test_execute_stream_with_post_next_termination(self, mock_chat_client: middleware = self.PostNextTerminateChatMiddleware() pipeline = ChatMiddlewarePipeline([middleware]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext( - chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True - ) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options, is_streaming=True) execution_order: list[str] = [] async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: @@ -774,9 +767,9 @@ async def process( messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> 
AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: metadata_updates.append("handler") - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) result = await pipeline.execute(mock_agent, messages, context, final_handler) @@ -837,9 +830,9 @@ async def test_agent_middleware( messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: execution_order.append("handler") - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) result = await pipeline.execute(mock_agent, messages, context, final_handler) @@ -900,9 +893,9 @@ async def function_middleware( messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: execution_order.append("handler") - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) result = await pipeline.execute(mock_agent, messages, context, final_handler) @@ -962,8 +955,8 @@ async def function_chat_middleware( pipeline = ChatMiddlewarePipeline([ClassChatMiddleware(), function_chat_middleware]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) async def final_handler(ctx: ChatContext) -> ChatResponse: execution_order.append("handler") @@ -979,7 +972,7 @@ class TestMultipleMiddlewareOrdering: """Test cases for multiple middleware execution order.""" async def test_agent_middleware_execution_order(self, mock_agent: AgentProtocol) -> None: - """Test that multiple agent middlewares execute in registration order.""" + """Test that multiple agent middleware execute in registration order.""" execution_order: list[str] = [] class FirstMiddleware(AgentMiddleware): @@ -1006,14 +999,14 @@ async def process( await next(context) execution_order.append("third_after") - middlewares = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()] - pipeline = AgentMiddlewarePipeline(middlewares) # type: ignore + middleware = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()] + pipeline = AgentMiddlewarePipeline(middleware) # type: ignore messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: execution_order.append("handler") - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) result = await pipeline.execute(mock_agent, messages, context, final_handler) @@ -1030,7 +1023,7 @@ async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: assert execution_order 
== expected_order async def test_function_middleware_execution_order(self, mock_function: AIFunction[Any, Any]) -> None: - """Test that multiple function middlewares execute in registration order.""" + """Test that multiple function middleware execute in registration order.""" execution_order: list[str] = [] class FirstMiddleware(FunctionMiddleware): @@ -1053,8 +1046,8 @@ async def process( await next(context) execution_order.append("second_after") - middlewares = [FirstMiddleware(), SecondMiddleware()] - pipeline = FunctionMiddlewarePipeline(middlewares) # type: ignore + middleware = [FirstMiddleware(), SecondMiddleware()] + pipeline = FunctionMiddlewarePipeline(middleware) # type: ignore arguments = FunctionTestArgs(name="test") context = FunctionInvocationContext(function=mock_function, arguments=arguments) @@ -1069,7 +1062,7 @@ async def final_handler(ctx: FunctionInvocationContext) -> str: assert execution_order == expected_order async def test_chat_middleware_execution_order(self, mock_chat_client: Any) -> None: - """Test that multiple chat middlewares execute in registration order.""" + """Test that multiple chat middleware execute in registration order.""" execution_order: list[str] = [] class FirstChatMiddleware(ChatMiddleware): @@ -1090,11 +1083,11 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai await next(context) execution_order.append("third_after") - middlewares = [FirstChatMiddleware(), SecondChatMiddleware(), ThirdChatMiddleware()] - pipeline = ChatMiddlewarePipeline(middlewares) # type: ignore + middleware = [FirstChatMiddleware(), SecondChatMiddleware(), ThirdChatMiddleware()] + pipeline = ChatMiddlewarePipeline(middleware) # type: ignore messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) async def final_handler(ctx: ChatContext) -> ChatResponse: execution_order.append("handler") @@ -1149,10 +1142,10 @@ async def process( messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: # Verify metadata was set by middleware assert ctx.metadata.get("validated") is True - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) result = await pipeline.execute(mock_agent, messages, context, final_handler) assert result is not None @@ -1203,7 +1196,7 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai # Verify context has all expected attributes assert hasattr(context, "chat_client") assert hasattr(context, "messages") - assert hasattr(context, "chat_options") + assert hasattr(context, "options") assert hasattr(context, "is_streaming") assert hasattr(context, "metadata") assert hasattr(context, "result") @@ -1216,8 +1209,8 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai assert context.messages[0].text == "test" assert context.is_streaming is False assert isinstance(context.metadata, dict) - assert isinstance(context.chat_options, ChatOptions) - assert context.chat_options.temperature == 0.5 + assert 
isinstance(context.options, dict) + assert context.options.get("temperature") == 0.5 # Add custom metadata context.metadata["validated"] = True @@ -1227,8 +1220,8 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai middleware = ChatContextValidationMiddleware() pipeline = ChatMiddlewarePipeline([middleware]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions(temperature=0.5) - context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + chat_options: dict[str, Any] = {"temperature": 0.5} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) async def final_handler(ctx: ChatContext) -> ChatResponse: # Verify metadata was set by middleware @@ -1260,20 +1253,20 @@ async def process( # Test non-streaming context = AgentRunContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: streaming_flags.append(ctx.is_streaming) - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) await pipeline.execute(mock_agent, messages, context, final_handler) # Test streaming context_stream = AgentRunContext(agent=mock_agent, messages=messages) - async def final_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + async def final_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: streaming_flags.append(ctx.is_streaming) - yield AgentRunResponseUpdate(contents=[TextContent(text="chunk")]) + yield AgentResponseUpdate(contents=[TextContent(text="chunk")]) - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in pipeline.execute_stream(mock_agent, messages, context_stream, final_stream_handler): updates.append(update) @@ -1297,11 +1290,11 @@ async def process( messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - async def final_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + async def final_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: chunks_processed.append("stream_start") - yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")]) + yield AgentResponseUpdate(contents=[TextContent(text="chunk1")]) chunks_processed.append("chunk1_yielded") - yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")]) + yield AgentResponseUpdate(contents=[TextContent(text="chunk2")]) chunks_processed.append("chunk2_yielded") chunks_processed.append("stream_end") @@ -1331,10 +1324,10 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai middleware = ChatStreamingFlagMiddleware() pipeline = ChatMiddlewarePipeline([middleware]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} # Test non-streaming - context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) async def final_handler(ctx: ChatContext) -> ChatResponse: streaming_flags.append(ctx.is_streaming) @@ -1344,7 +1337,7 @@ async def final_handler(ctx: ChatContext) -> ChatResponse: # Test streaming context_stream = ChatContext( - 
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True + chat_client=mock_chat_client, messages=messages, options=chat_options, is_streaming=True ) async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: @@ -1373,10 +1366,8 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai middleware = ChatStreamProcessingMiddleware() pipeline = ChatMiddlewarePipeline([middleware]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext( - chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True - ) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options, is_streaming=True) async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: chunks_processed.append("stream_start") @@ -1461,16 +1452,16 @@ async def process( handler_called = False - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: nonlocal handler_called handler_called = True - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")]) result = await pipeline.execute(mock_agent, messages, context, final_handler) - # Verify no execution happened - should return empty AgentRunResponse + # Verify no execution happened - should return empty AgentResponse assert result is not None - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert result.messages == [] # Empty response assert not handler_called assert context.result is None @@ -1492,13 +1483,13 @@ async def process( handler_called = False - async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: nonlocal handler_called handler_called = True - yield AgentRunResponseUpdate(contents=[TextContent(text="should not execute")]) + yield AgentResponseUpdate(contents=[TextContent(text="should not execute")]) # When middleware doesn't call next(), streaming should yield no updates - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): updates.append(update) @@ -1542,7 +1533,7 @@ async def final_handler(ctx: FunctionInvocationContext) -> str: assert context.result is None async def test_multiple_middlewares_early_stop(self, mock_agent: AgentProtocol) -> None: - """Test that when first middleware doesn't call next(), subsequent middlewares are not called.""" + """Test that when first middleware doesn't call next(), subsequent middleware are not called.""" execution_order: list[str] = [] class FirstMiddleware(AgentMiddleware): @@ -1565,17 +1556,17 @@ async def process( handler_called = False - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: nonlocal handler_called handler_called = True - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")]) result = await pipeline.execute(mock_agent, messages, context, final_handler) # 
Verify only first middleware was called and empty response returned assert execution_order == ["first"] assert result is not None - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert result.messages == [] # Empty response assert not handler_called @@ -1590,8 +1581,8 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai middleware = NoNextChatMiddleware() pipeline = ChatMiddlewarePipeline([middleware]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) handler_called = False @@ -1618,10 +1609,8 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai middleware = NoNextStreamingChatMiddleware() pipeline = ChatMiddlewarePipeline([middleware]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext( - chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True - ) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options, is_streaming=True) handler_called = False @@ -1641,7 +1630,7 @@ async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: assert context.result is None async def test_multiple_chat_middlewares_early_stop(self, mock_chat_client: Any) -> None: - """Test that when first chat middleware doesn't call next(), subsequent middlewares are not called.""" + """Test that when first chat middleware doesn't call next(), subsequent middleware are not called.""" execution_order: list[str] = [] class FirstChatMiddleware(ChatMiddleware): @@ -1656,8 +1645,8 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai pipeline = ChatMiddlewarePipeline([FirstChatMiddleware(), SecondChatMiddleware()]) messages = [ChatMessage(role=Role.USER, text="test")] - chat_options = ChatOptions() - context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + chat_options: dict[str, Any] = {} + context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) handler_called = False diff --git a/python/packages/core/tests/core/test_middleware_context_result.py b/python/packages/core/tests/core/test_middleware_context_result.py index 447ba0d4b9..bfcfb48e5f 100644 --- a/python/packages/core/tests/core/test_middleware_context_result.py +++ b/python/packages/core/tests/core/test_middleware_context_result.py @@ -9,8 +9,8 @@ from agent_framework import ( AgentProtocol, - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, ChatAgent, ChatMessage, Role, @@ -40,7 +40,7 @@ class TestResultOverrideMiddleware: async def test_agent_middleware_response_override_non_streaming(self, mock_agent: AgentProtocol) -> None: """Test that agent middleware can override response for non-streaming execution.""" - override_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="overridden response")]) + override_response = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="overridden response")]) class ResponseOverrideMiddleware(AgentMiddleware): async def process( @@ -57,10 +57,10 @@ async def process( handler_called = False - async def 
final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: nonlocal handler_called handler_called = True - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="original response")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="original response")]) result = await pipeline.execute(mock_agent, messages, context, final_handler) @@ -74,9 +74,9 @@ async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: async def test_agent_middleware_response_override_streaming(self, mock_agent: AgentProtocol) -> None: """Test that agent middleware can override response for streaming execution.""" - async def override_stream() -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate(contents=[TextContent(text="overridden")]) - yield AgentRunResponseUpdate(contents=[TextContent(text=" stream")]) + async def override_stream() -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate(contents=[TextContent(text="overridden")]) + yield AgentResponseUpdate(contents=[TextContent(text=" stream")]) class StreamResponseOverrideMiddleware(AgentMiddleware): async def process( @@ -91,10 +91,10 @@ async def process( messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate(contents=[TextContent(text="original")]) + async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate(contents=[TextContent(text="original")]) - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): updates.append(update) @@ -148,7 +148,7 @@ async def process( await next(context) # Then conditionally override based on content if any("special" in msg.text for msg in context.messages if msg.text): - context.result = AgentRunResponse( + context.result = AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, text="Special response from middleware!")] ) @@ -174,10 +174,10 @@ async def test_chat_agent_middleware_streaming_override(self) -> None: """Test streaming result override functionality with ChatAgent integration.""" mock_chat_client = MockChatClient() - async def custom_stream() -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate(contents=[TextContent(text="Custom")]) - yield AgentRunResponseUpdate(contents=[TextContent(text=" streaming")]) - yield AgentRunResponseUpdate(contents=[TextContent(text=" response!")]) + async def custom_stream() -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate(contents=[TextContent(text="Custom")]) + yield AgentResponseUpdate(contents=[TextContent(text=" streaming")]) + yield AgentResponseUpdate(contents=[TextContent(text=" response!")]) class ChatAgentStreamOverrideMiddleware(AgentMiddleware): async def process( @@ -195,7 +195,7 @@ async def process( # Test streaming override case override_messages = [ChatMessage(role=Role.USER, text="Give me a custom stream")] - override_updates: list[AgentRunResponseUpdate] = [] + override_updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream(override_messages): override_updates.append(update) @@ -206,7 +206,7 @@ async def process( # Test normal streaming case normal_messages = [ChatMessage(role=Role.USER, text="Normal streaming 
request")] - normal_updates: list[AgentRunResponseUpdate] = [] + normal_updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream(normal_messages): normal_updates.append(update) @@ -231,19 +231,19 @@ async def process( handler_called = False - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + async def final_handler(ctx: AgentRunContext) -> AgentResponse: nonlocal handler_called handler_called = True - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="executed response")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="executed response")]) # Test case where next() is NOT called no_execute_messages = [ChatMessage(role=Role.USER, text="Don't run this")] no_execute_context = AgentRunContext(agent=mock_agent, messages=no_execute_messages) no_execute_result = await pipeline.execute(mock_agent, no_execute_messages, no_execute_context, final_handler) - # When middleware doesn't call next(), result should be empty AgentRunResponse + # When middleware doesn't call next(), result should be empty AgentResponse assert no_execute_result is not None - assert isinstance(no_execute_result, AgentRunResponse) + assert isinstance(no_execute_result, AgentResponse) assert no_execute_result.messages == [] # Empty response assert not handler_called assert no_execute_context.result is None @@ -313,7 +313,7 @@ class TestResultObservability: async def test_agent_middleware_response_observability(self, mock_agent: AgentProtocol) -> None: """Test that middleware can observe response after execution.""" - observed_responses: list[AgentRunResponse] = [] + observed_responses: list[AgentResponse] = [] class ObservabilityMiddleware(AgentMiddleware): async def process( @@ -327,7 +327,7 @@ async def process( # Context should now contain the response for observability assert context.result is not None - assert isinstance(context.result, AgentRunResponse) + assert isinstance(context.result, AgentResponse) observed_responses.append(context.result) middleware = ObservabilityMiddleware() @@ -335,8 +335,8 @@ async def process( messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="executed response")]) + async def final_handler(ctx: AgentRunContext) -> AgentResponse: + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="executed response")]) result = await pipeline.execute(mock_agent, messages, context, final_handler) @@ -392,11 +392,11 @@ async def process( # Now observe and conditionally override assert context.result is not None - assert isinstance(context.result, AgentRunResponse) + assert isinstance(context.result, AgentResponse) if "modify" in context.result.messages[0].text: # Override after observing - context.result = AgentRunResponse( + context.result = AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, text="modified after execution")] ) @@ -405,8 +405,8 @@ async def process( messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response to modify")]) + async def final_handler(ctx: AgentRunContext) -> AgentResponse: + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response to 
modify")]) result = await pipeline.execute(mock_agent, messages, context, final_handler) diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 6cb41f674b..5cfea39287 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -6,7 +6,7 @@ import pytest from agent_framework import ( - AgentRunResponseUpdate, + AgentResponseUpdate, ChatAgent, ChatContext, ChatMessage, @@ -372,7 +372,7 @@ async def process( # Execute streaming messages = [ChatMessage(role=Role.USER, text="test message")] - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream(messages): updates.append(update) @@ -418,7 +418,7 @@ class TestChatAgentMultipleMiddlewareOrdering: """Test cases for multiple middleware execution order with ChatAgent.""" async def test_multiple_agent_middleware_execution_order(self, chat_client: "MockChatClient") -> None: - """Test that multiple agent middlewares execute in correct order with ChatAgent.""" + """Test that multiple agent middleware execute in correct order with ChatAgent.""" execution_order: list[str] = [] class OrderedMiddleware(AgentMiddleware): @@ -432,12 +432,12 @@ async def process( await next(context) execution_order.append(f"{self.name}_after") - # Create multiple middlewares + # Create multiple middleware middleware1 = OrderedMiddleware("first") middleware2 = OrderedMiddleware("second") middleware3 = OrderedMiddleware("third") - # Create ChatAgent with multiple middlewares + # Create ChatAgent with multiple middleware agent = ChatAgent(chat_client=chat_client, middleware=[middleware1, middleware2, middleware3]) # Execute the agent @@ -453,7 +453,7 @@ async def process( assert execution_order == expected_order async def test_mixed_middleware_types_with_chat_agent(self, chat_client: "MockChatClient") -> None: - """Test mixed class and function-based middlewares with ChatAgent.""" + """Test mixed class and function-based middleware with ChatAgent.""" execution_order: list[str] = [] class ClassAgentMiddleware(AgentMiddleware): @@ -507,8 +507,8 @@ async def function_function_middleware( assert response is not None assert chat_client.call_count == 1 - # Verify that agent middlewares were executed in correct order - # (Function middlewares won't execute since no functions are called) + # Verify that agent middleware were executed in correct order + # (Function middleware won't execute since no functions are called) expected_order = ["class_agent_before", "function_agent_before", "function_agent_after", "class_agent_after"] assert execution_order == expected_order @@ -734,7 +734,7 @@ async def process( async def test_function_middleware_can_access_and_override_custom_kwargs( self, chat_client: "MockChatClient" ) -> None: - """Test that function middleware can access and override custom parameters like temperature.""" + """Test that function middleware can access and override custom parameters.""" captured_kwargs: dict[str, Any] = {} modified_kwargs: dict[str, Any] = {} middleware_called = False @@ -747,38 +747,20 @@ async def kwargs_middleware( middleware_called = True # Capture the original kwargs - captured_kwargs["has_chat_options"] = "chat_options" in context.kwargs captured_kwargs["has_custom_param"] = "custom_param" in context.kwargs captured_kwargs["custom_param"] = context.kwargs.get("custom_param") - # Capture original chat_options 
values if present - if "chat_options" in context.kwargs: - chat_options = context.kwargs["chat_options"] - captured_kwargs["original_temperature"] = getattr(chat_options, "temperature", None) - captured_kwargs["original_max_tokens"] = getattr(chat_options, "max_tokens", None) - # Modify some kwargs context.kwargs["temperature"] = 0.9 context.kwargs["max_tokens"] = 500 context.kwargs["new_param"] = "added_by_middleware" - # Also modify chat_options if present - if "chat_options" in context.kwargs: - context.kwargs["chat_options"].temperature = 0.9 - context.kwargs["chat_options"].max_tokens = 500 - # Store modified kwargs for verification modified_kwargs["temperature"] = context.kwargs.get("temperature") modified_kwargs["max_tokens"] = context.kwargs.get("max_tokens") modified_kwargs["new_param"] = context.kwargs.get("new_param") modified_kwargs["custom_param"] = context.kwargs.get("custom_param") - # Capture modified chat_options values if present - if "chat_options" in context.kwargs: - chat_options = context.kwargs["chat_options"] - modified_kwargs["chat_options_temperature"] = getattr(chat_options, "temperature", None) - modified_kwargs["chat_options_max_tokens"] = getattr(chat_options, "max_tokens", None) - await next(context) chat_client.responses = [ @@ -800,9 +782,9 @@ async def kwargs_middleware( # Create ChatAgent with function middleware agent = ChatAgent(chat_client=chat_client, middleware=[kwargs_middleware], tools=[sample_tool_function]) - # Execute the agent with custom parameters + # Execute the agent with custom parameters passed as kwargs messages = [ChatMessage(role=Role.USER, text="test message")] - response = await agent.run(messages, temperature=0.7, max_tokens=100, custom_param="test_value") + response = await agent.run(messages, custom_param="test_value") # Verify response assert response is not None @@ -812,19 +794,14 @@ async def kwargs_middleware( assert middleware_called, "Function middleware was not called" # Verify middleware captured the original kwargs - assert captured_kwargs["has_chat_options"] is True assert captured_kwargs["has_custom_param"] is True assert captured_kwargs["custom_param"] == "test_value" - assert captured_kwargs["original_temperature"] == 0.7 - assert captured_kwargs["original_max_tokens"] == 100 # Verify middleware could modify the kwargs assert modified_kwargs["temperature"] == 0.9 assert modified_kwargs["max_tokens"] == 500 assert modified_kwargs["new_param"] == "added_by_middleware" assert modified_kwargs["custom_param"] == "test_value" - assert modified_kwargs["chat_options_temperature"] == 0.9 - assert modified_kwargs["chat_options_max_tokens"] == 500 class TestMiddlewareDynamicRebuild: @@ -901,7 +878,7 @@ async def test_middleware_dynamic_rebuild_streaming(self, chat_client: "MockChat agent = ChatAgent(chat_client=chat_client, middleware=[middleware1]) # First streaming execution - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream("Test stream message 1"): updates.append(update) @@ -999,7 +976,7 @@ async def test_run_level_middleware_isolation(self, chat_client: "MockChatClient # Clear execution log execution_log.clear() - # Fourth run with both run middlewares - should see both + # Fourth run with both run middleware - should see both await agent.run("Test message 4", middleware=[run_middleware1, run_middleware2]) assert execution_log == ["run1_start", "run2_start", "run2_end", "run1_end"] @@ -1108,7 +1085,7 @@ async def process( run_middleware = 
StreamingTrackingMiddleware("run_stream") # Execute streaming with run middleware - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream("Test streaming", middleware=[run_middleware]): updates.append(update) @@ -1734,7 +1711,7 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai # Execute streaming messages = [ChatMessage(role=Role.USER, text="test message")] - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream(messages): updates.append(update) diff --git a/python/packages/core/tests/core/test_middleware_with_chat.py b/python/packages/core/tests/core/test_middleware_with_chat.py index 91c501de6b..9d395284ea 100644 --- a/python/packages/core/tests/core/test_middleware_with_chat.py +++ b/python/packages/core/tests/core/test_middleware_with_chat.py @@ -366,7 +366,7 @@ def sample_tool(location: str) -> str: # Execute the chat client directly with tools - this should trigger function invocation and middleware messages = [ChatMessage(role=Role.USER, text="What's the weather in San Francisco?")] - response = await chat_client.get_response(messages, tools=[sample_tool]) + response = await chat_client.get_response(messages, options={"tools": [sample_tool]}) # Verify response assert response is not None @@ -423,7 +423,7 @@ def sample_tool(location: str) -> str: # Execute the chat client directly with run-level middleware and tools messages = [ChatMessage(role=Role.USER, text="What's the weather in New York?")] response = await chat_client.get_response( - messages, tools=[sample_tool], middleware=[run_level_function_middleware] + messages, options={"tools": [sample_tool]}, middleware=[run_level_function_middleware] ) # Verify response diff --git a/python/packages/core/tests/core/test_observability.py b/python/packages/core/tests/core/test_observability.py index 8528295406..95f234efd4 100644 --- a/python/packages/core/tests/core/test_observability.py +++ b/python/packages/core/tests/core/test_observability.py @@ -13,11 +13,10 @@ from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, AgentProtocol, - AgentRunResponse, + AgentResponse, AgentThread, BaseChatClient, ChatMessage, - ChatOptions, ChatResponse, ChatResponseUpdate, Role, @@ -215,7 +214,7 @@ def service_url(self): return "https://test.example.com" async def _inner_get_response( - self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + self, *, messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ): return ChatResponse( messages=[ChatMessage(role=Role.ASSISTANT, text="Test response")], @@ -224,7 +223,7 @@ async def _inner_get_response( ) async def _inner_get_streaming_response( - self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + self, *, messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ): yield ChatResponseUpdate(text="Hello", role=Role.ASSISTANT) yield ChatResponseUpdate(text=" world", role=Role.ASSISTANT) @@ -342,7 +341,6 @@ class MockChatClientAgent: def __init__(self): self.id = "test_agent_id" self.name = "test_agent" - self.display_name = "Test Agent" self.description = "Test agent description" async def run(self, messages=None, *, thread=None, **kwargs): @@ -384,7 +382,6 @@ class MockAgent: def __init__(self): self.id = "test_agent_id" self.name = "test_agent" - self.display_name = "Test Agent" async def run(self, messages=None, *, 
thread=None, **kwargs): return Mock() @@ -406,12 +403,11 @@ class MockChatClientAgent: def __init__(self): self.id = "test_agent_id" self.name = "test_agent" - self.display_name = "Test Agent" self.description = "Test agent description" - self.chat_options = ChatOptions(model_id="TestModel") + self.default_options: dict[str, Any] = {"model_id": "TestModel"} async def run(self, messages=None, *, thread=None, **kwargs): - return AgentRunResponse( + return AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, text="Agent response")], usage_details=UsageDetails(input_token_count=15, output_token_count=25), response_id="test_response_id", @@ -419,10 +415,10 @@ async def run(self, messages=None, *, thread=None, **kwargs): ) async def run_stream(self, messages=None, *, thread=None, **kwargs): - from agent_framework import AgentRunResponseUpdate + from agent_framework import AgentResponseUpdate - yield AgentRunResponseUpdate(text="Hello", role=Role.ASSISTANT) - yield AgentRunResponseUpdate(text=" from agent", role=Role.ASSISTANT) + yield AgentResponseUpdate(text="Hello", role=Role.ASSISTANT) + yield AgentResponseUpdate(text=" from agent", role=Role.ASSISTANT) return MockChatClientAgent @@ -441,10 +437,10 @@ async def test_agent_instrumentation_enabled( spans = span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0] - assert span.name == "invoke_agent Test Agent" + assert span.name == "invoke_agent test_agent" assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.AGENT_INVOKE_OPERATION assert span.attributes[OtelAttr.AGENT_ID] == "test_agent_id" - assert span.attributes[OtelAttr.AGENT_NAME] == "Test Agent" + assert span.attributes[OtelAttr.AGENT_NAME] == "test_agent" assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description" assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "TestModel" assert span.attributes[OtelAttr.INPUT_TOKENS] == 15 @@ -469,10 +465,10 @@ async def test_agent_streaming_response_with_diagnostics_enabled_via_decorator( spans = span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0] - assert span.name == "invoke_agent Test Agent" + assert span.name == "invoke_agent test_agent" assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.AGENT_INVOKE_OPERATION assert span.attributes[OtelAttr.AGENT_ID] == "test_agent_id" - assert span.attributes[OtelAttr.AGENT_NAME] == "Test Agent" + assert span.attributes[OtelAttr.AGENT_NAME] == "test_agent" assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description" assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "TestModel" if enable_sensitive_data: diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index f70e6ddb56..77442be322 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -5,7 +5,7 @@ import pytest from opentelemetry import trace from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from agent_framework import ( AIFunction, @@ -15,7 +15,11 @@ ToolProtocol, ai_function, ) -from agent_framework._tools import _parse_annotation, _parse_inputs +from agent_framework._tools import ( + _build_pydantic_model_from_json_schema, + _parse_annotation, + _parse_inputs, +) from agent_framework.exceptions import ToolException from agent_framework.observability import OtelAttr @@ -425,7 +429,7 @@ async def 
simple_tool(message: str) -> str: result = await simple_tool.invoke( arguments=args, api_token="secret-token", - chat_options={"model_id": "dummy"}, + options={"model_id": "dummy"}, ) assert result == "HELLO WORLD" @@ -1031,7 +1035,7 @@ async def mock_get_response(self, messages, **kwargs): wrapped = _handle_function_calls_response(mock_get_response) # Execute - result = await wrapped(mock_client, messages=[], tools=[no_approval_tool]) + result = await wrapped(mock_client, messages=[], options={"tools": [no_approval_tool]}) # Verify: should have 3 messages: function call, function result, final answer assert len(result.messages) == 3 @@ -1071,7 +1075,7 @@ async def mock_get_response(self, messages, **kwargs): wrapped = _handle_function_calls_response(mock_get_response) # Execute - result = await wrapped(mock_client, messages=[], tools=[requires_approval_tool]) + result = await wrapped(mock_client, messages=[], options={"tools": [requires_approval_tool]}) # Verify: should return 1 message with function call and approval request from agent_framework import FunctionApprovalRequestContent @@ -1117,7 +1121,7 @@ async def mock_get_response(self, messages, **kwargs): wrapped = _handle_function_calls_response(mock_get_response) # Execute - result = await wrapped(mock_client, messages=[], tools=[no_approval_tool]) + result = await wrapped(mock_client, messages=[], options={"tools": [no_approval_tool]}) # Verify: should have function calls, results, and final answer from agent_framework import FunctionResultContent @@ -1163,7 +1167,7 @@ async def mock_get_response(self, messages, **kwargs): wrapped = _handle_function_calls_response(mock_get_response) # Execute - result = await wrapped(mock_client, messages=[], tools=[requires_approval_tool]) + result = await wrapped(mock_client, messages=[], options={"tools": [requires_approval_tool]}) # Verify: should return 1 message with function calls and approval requests from agent_framework import FunctionApprovalRequestContent @@ -1209,7 +1213,7 @@ async def mock_get_response(self, messages, **kwargs): wrapped = _handle_function_calls_response(mock_get_response) # Execute - result = await wrapped(mock_client, messages=[], tools=[no_approval_tool, requires_approval_tool]) + result = await wrapped(mock_client, messages=[], options={"tools": [no_approval_tool, requires_approval_tool]}) # Verify: should return approval requests for both (when one needs approval, all are sent for approval) from agent_framework import FunctionApprovalRequestContent @@ -1249,7 +1253,7 @@ async def mock_get_streaming_response(self, messages, **kwargs): # Execute and collect updates updates = [] - async for update in wrapped(mock_client, messages=[], tools=[no_approval_tool]): + async for update in wrapped(mock_client, messages=[], options={"tools": [no_approval_tool]}): updates.append(update) # Verify: should have function call update, tool result update (injected), and final update @@ -1294,7 +1298,7 @@ async def mock_get_streaming_response(self, messages, **kwargs): # Execute and collect updates updates = [] - async for update in wrapped(mock_client, messages=[], tools=[requires_approval_tool]): + async for update in wrapped(mock_client, messages=[], options={"tools": [requires_approval_tool]}): updates.append(update) # Verify: should yield function call and then approval request @@ -1339,7 +1343,7 @@ async def mock_get_streaming_response(self, messages, **kwargs): # Execute and collect updates updates = [] - async for update in wrapped(mock_client, messages=[], 
tools=[no_approval_tool]): + async for update in wrapped(mock_client, messages=[], options={"tools": [no_approval_tool]}): updates.append(update) # Verify: should have both function calls, one tool result update with both results, and final message @@ -1388,7 +1392,7 @@ async def mock_get_streaming_response(self, messages, **kwargs): # Execute and collect updates updates = [] - async for update in wrapped(mock_client, messages=[], tools=[requires_approval_tool]): + async for update in wrapped(mock_client, messages=[], options={"tools": [requires_approval_tool]}): updates.append(update) # Verify: should yield both function calls and then approval requests @@ -1435,7 +1439,9 @@ async def mock_get_streaming_response(self, messages, **kwargs): # Execute and collect updates updates = [] - async for update in wrapped(mock_client, messages=[], tools=[no_approval_tool, requires_approval_tool]): + async for update in wrapped( + mock_client, messages=[], options={"tools": [no_approval_tool, requires_approval_tool]} + ): updates.append(update) # Verify: should yield both function calls and then approval requests (when one needs approval, all wait) @@ -1548,4 +1554,467 @@ def test_parse_annotation_with_annotated_and_literal(): assert get_args(literal_type) == ("A", "B", "C") +def test_build_pydantic_model_from_json_schema_array_of_objects_issue(): + """Test for Tools with complex input schema (array of objects). + + This test verifies that JSON schemas with array properties containing nested objects + are properly parsed, ensuring that the nested object schema is preserved + and not reduced to a bare dict. + + Example from issue: + ``` + const SalesOrderItemSchema = z.object({ + customerMaterialNumber: z.string().optional(), + quantity: z.number(), + unitOfMeasure: z.string() + }); + + const CreateSalesOrderInputSchema = z.object({ + contract: z.string(), + items: z.array(SalesOrderItemSchema) + }); + ``` + + The issue was that agents only saw: + ``` + {"contract": "str", "items": "list[dict]"} + ``` + + Instead of the proper nested schema with all fields. 
+ """ + # Schema matching the issue description + schema = { + "type": "object", + "properties": { + "contract": {"type": "string", "description": "Reference contract number"}, + "items": { + "type": "array", + "description": "Sales order line items", + "items": { + "type": "object", + "properties": { + "customerMaterialNumber": { + "type": "string", + "description": "Customer's material number", + }, + "quantity": {"type": "number", "description": "Order quantity"}, + "unitOfMeasure": { + "type": "string", + "description": "Unit of measure (e.g., 'ST', 'KG', 'TO')", + }, + }, + "required": ["quantity", "unitOfMeasure"], + }, + }, + }, + "required": ["contract", "items"], + } + + model = _build_pydantic_model_from_json_schema("create_sales_order", schema) + + # Test valid data + valid_data = { + "contract": "CONTRACT-123", + "items": [ + { + "customerMaterialNumber": "MAT-001", + "quantity": 10, + "unitOfMeasure": "ST", + }, + {"quantity": 5.5, "unitOfMeasure": "KG"}, + ], + } + + instance = model(**valid_data) + + # Verify the data was parsed correctly + assert instance.contract == "CONTRACT-123" + assert len(instance.items) == 2 + + # Verify first item + assert instance.items[0].customerMaterialNumber == "MAT-001" + assert instance.items[0].quantity == 10 + assert instance.items[0].unitOfMeasure == "ST" + + # Verify second item (optional field not provided) + assert instance.items[1].quantity == 5.5 + assert instance.items[1].unitOfMeasure == "KG" + + # Verify that items are proper BaseModel instances, not bare dicts + assert isinstance(instance.items[0], BaseModel) + assert isinstance(instance.items[1], BaseModel) + + # Verify that the nested object has the expected fields + assert hasattr(instance.items[0], "customerMaterialNumber") + assert hasattr(instance.items[0], "quantity") + assert hasattr(instance.items[0], "unitOfMeasure") + + # CRITICAL: Validate using the same methods that actual chat clients use + # This is what would actually be sent to the LLM + + # Create an AIFunction wrapper to access the client-facing APIs + def dummy_func(**kwargs): + return kwargs + + test_func = AIFunction( + func=dummy_func, + name="create_sales_order", + description="Create a sales order", + input_model=model, + ) + + # Test 1: Anthropic client uses tool.parameters() directly + anthropic_schema = test_func.parameters() + + # Verify contract property + assert "contract" in anthropic_schema["properties"] + assert anthropic_schema["properties"]["contract"]["type"] == "string" + + # Verify items array property exists + assert "items" in anthropic_schema["properties"] + items_prop = anthropic_schema["properties"]["items"] + assert items_prop["type"] == "array" + + # THE KEY TEST for Anthropic: array items must have proper object schema + assert "items" in items_prop, "Array should have 'items' schema definition" + array_items_schema = items_prop["items"] + + # Resolve schema if using $ref + if "$ref" in array_items_schema: + ref_path = array_items_schema["$ref"] + assert ref_path.startswith("#/$defs/") or ref_path.startswith("#/definitions/") + ref_name = ref_path.split("/")[-1] + defs = anthropic_schema.get("$defs", anthropic_schema.get("definitions", {})) + assert ref_name in defs, f"Referenced schema '{ref_name}' should exist" + item_schema = defs[ref_name] + else: + item_schema = array_items_schema + + # Verify the nested object has all properties defined + assert "properties" in item_schema, "Array items should have properties (not bare dict)" + item_properties = item_schema["properties"] + + # 
All three fields must be present in schema sent to LLM + assert "customerMaterialNumber" in item_properties, "customerMaterialNumber missing from LLM schema" + assert "quantity" in item_properties, "quantity missing from LLM schema" + assert "unitOfMeasure" in item_properties, "unitOfMeasure missing from LLM schema" + + # Verify types are correct + assert item_properties["customerMaterialNumber"]["type"] == "string" + assert item_properties["quantity"]["type"] in ["number", "integer"] + assert item_properties["unitOfMeasure"]["type"] == "string" + + # Test 2: OpenAI client uses tool.to_json_schema_spec() + openai_spec = test_func.to_json_schema_spec() + + assert openai_spec["type"] == "function" + assert "function" in openai_spec + openai_schema = openai_spec["function"]["parameters"] + + # Verify the same structure is present in OpenAI format + assert "items" in openai_schema["properties"] + openai_items_prop = openai_schema["properties"]["items"] + assert openai_items_prop["type"] == "array" + assert "items" in openai_items_prop + + openai_array_items = openai_items_prop["items"] + if "$ref" in openai_array_items: + ref_path = openai_array_items["$ref"] + ref_name = ref_path.split("/")[-1] + defs = openai_schema.get("$defs", openai_schema.get("definitions", {})) + openai_item_schema = defs[ref_name] + else: + openai_item_schema = openai_array_items + + assert "properties" in openai_item_schema + openai_props = openai_item_schema["properties"] + assert "customerMaterialNumber" in openai_props + assert "quantity" in openai_props + assert "unitOfMeasure" in openai_props + + # Test validation - missing required quantity + with pytest.raises(ValidationError): + model( + contract="CONTRACT-456", + items=[ + { + "customerMaterialNumber": "MAT-002", + "unitOfMeasure": "TO", + # Missing required 'quantity' + } + ], + ) + + # Test validation - missing required unitOfMeasure + with pytest.raises(ValidationError): + model( + contract="CONTRACT-789", + items=[ + { + "quantity": 20 + # Missing required 'unitOfMeasure' + } + ], + ) + + +def test_one_of_discriminator_polymorphism(): + """Test that oneOf with discriminator creates proper polymorphic union types. + + Tests that oneOf + discriminator patterns are properly converted to Pydantic discriminated unions. 
+ """ + schema = { + "$defs": { + "CreateProject": { + "description": "Action: Create an Azure DevOps project.", + "properties": { + "name": { + "const": "create_project", + "default": "create_project", + "type": "string", + }, + "params": {"$ref": "#/$defs/CreateProjectParams"}, + }, + "required": ["params"], + "type": "object", + }, + "CreateProjectParams": { + "description": "Parameters for the create_project action.", + "properties": { + "orgUrl": {"minLength": 1, "type": "string"}, + "projectName": {"minLength": 1, "type": "string"}, + "description": {"default": "", "type": "string"}, + "template": {"default": "Agile", "type": "string"}, + "sourceControl": { + "default": "Git", + "enum": ["Git", "Tfvc"], + "type": "string", + }, + "visibility": {"default": "private", "type": "string"}, + }, + "required": ["orgUrl", "projectName"], + "type": "object", + }, + "DeployRequest": { + "description": "Request to deploy Azure DevOps resources.", + "properties": { + "projectName": {"minLength": 1, "type": "string"}, + "organization": {"minLength": 1, "type": "string"}, + "actions": { + "items": { + "discriminator": { + "mapping": { + "create_project": "#/$defs/CreateProject", + "hello_world": "#/$defs/HelloWorld", + }, + "propertyName": "name", + }, + "oneOf": [ + {"$ref": "#/$defs/HelloWorld"}, + {"$ref": "#/$defs/CreateProject"}, + ], + }, + "type": "array", + }, + }, + "required": ["projectName", "organization"], + "type": "object", + }, + "HelloWorld": { + "description": "Action: Prints a greeting message.", + "properties": { + "name": { + "const": "hello_world", + "default": "hello_world", + "type": "string", + }, + "params": {"$ref": "#/$defs/HelloWorldParams"}, + }, + "required": ["params"], + "type": "object", + }, + "HelloWorldParams": { + "description": "Parameters for the hello_world action.", + "properties": { + "name": { + "description": "Name to greet", + "minLength": 1, + "type": "string", + } + }, + "required": ["name"], + "type": "object", + }, + }, + "properties": {"params": {"$ref": "#/$defs/DeployRequest"}}, + "required": ["params"], + "type": "object", + } + + # Build the model + model = _build_pydantic_model_from_json_schema("deploy_tool", schema) + + # Verify the model structure + assert model is not None + assert issubclass(model, BaseModel) + + # Test with HelloWorld action + hello_world_data = { + "params": { + "projectName": "MyProject", + "organization": "MyOrg", + "actions": [ + { + "name": "hello_world", + "params": {"name": "Alice"}, + } + ], + } + } + + instance = model(**hello_world_data) + assert instance.params.projectName == "MyProject" + assert instance.params.organization == "MyOrg" + assert len(instance.params.actions) == 1 + assert instance.params.actions[0].name == "hello_world" + assert instance.params.actions[0].params.name == "Alice" + + # Test with CreateProject action + create_project_data = { + "params": { + "projectName": "MyProject", + "organization": "MyOrg", + "actions": [ + { + "name": "create_project", + "params": { + "orgUrl": "https://dev.azure.com/myorg", + "projectName": "NewProject", + "sourceControl": "Git", + }, + } + ], + } + } + + instance2 = model(**create_project_data) + assert instance2.params.actions[0].name == "create_project" + assert instance2.params.actions[0].params.projectName == "NewProject" + assert instance2.params.actions[0].params.sourceControl == "Git" + + # Test with mixed actions + mixed_data = { + "params": { + "projectName": "MyProject", + "organization": "MyOrg", + "actions": [ + {"name": "hello_world", "params": 
{"name": "Bob"}}, + { + "name": "create_project", + "params": { + "orgUrl": "https://dev.azure.com/myorg", + "projectName": "AnotherProject", + }, + }, + ], + } + } + + instance3 = model(**mixed_data) + assert len(instance3.params.actions) == 2 + assert instance3.params.actions[0].name == "hello_world" + assert instance3.params.actions[1].name == "create_project" + + +def test_const_creates_literal(): + """Test that const in JSON Schema creates Literal type.""" + schema = { + "properties": { + "action": { + "const": "create", + "type": "string", + "description": "Action type", + }, + "value": {"type": "integer"}, + }, + "required": ["action", "value"], + } + + model = _build_pydantic_model_from_json_schema("test_const", schema) + + # Verify valid const value works + instance = model(action="create", value=42) + assert instance.action == "create" + assert instance.value == 42 + + # Verify incorrect const value fails + with pytest.raises(ValidationError): + model(action="delete", value=42) + + +def test_enum_creates_literal(): + """Test that enum in JSON Schema creates Literal type.""" + schema = { + "properties": { + "status": { + "enum": ["pending", "approved", "rejected"], + "type": "string", + "description": "Status", + }, + "priority": {"enum": [1, 2, 3], "type": "integer"}, + }, + "required": ["status"], + } + + model = _build_pydantic_model_from_json_schema("test_enum", schema) + + # Verify valid enum values work + instance = model(status="approved", priority=2) + assert instance.status == "approved" + assert instance.priority == 2 + + # Verify invalid enum value fails + with pytest.raises(ValidationError): + model(status="unknown") + + with pytest.raises(ValidationError): + model(status="pending", priority=5) + + +def test_nested_object_with_const_and_enum(): + """Test that const and enum work in nested objects.""" + schema = { + "properties": { + "config": { + "type": "object", + "properties": { + "type": { + "const": "production", + "default": "production", + "type": "string", + }, + "level": {"enum": ["low", "medium", "high"], "type": "string"}, + }, + "required": ["level"], + } + }, + "required": ["config"], + } + + model = _build_pydantic_model_from_json_schema("test_nested", schema) + + # Valid data + instance = model(config={"type": "production", "level": "high"}) + assert instance.config.type == "production" + assert instance.config.level == "high" + + # Invalid const in nested object + with pytest.raises(ValidationError): + model(config={"type": "development", "level": "low"}) + + # Invalid enum in nested object + with pytest.raises(ValidationError): + model(config={"type": "production", "level": "critical"}) + + # endregion diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index 6e6e5bfee7..c5187fd960 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -10,8 +10,8 @@ from pytest import fixture, mark, raises from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, BaseContent, ChatMessage, ChatOptions, @@ -43,6 +43,7 @@ UsageContent, UsageDetails, ai_function, + merge_chat_options, prepare_function_call_results, ) from agent_framework.exceptions import AdditionItemMismatch, ContentError @@ -866,117 +867,149 @@ async def gen() -> AsyncIterable[ChatResponseUpdate]: def test_chat_tool_mode(): """Test the ToolMode class to ensure it initializes correctly.""" # Create instances of ToolMode - auto_mode = 
ToolMode.AUTO - required_any = ToolMode.REQUIRED_ANY - required_mode = ToolMode.REQUIRED("example_function") - none_mode = ToolMode.NONE + auto_mode: ToolMode = {"mode": "auto"} + required_any: ToolMode = {"mode": "required"} + required_mode: ToolMode = {"mode": "required", "required_function_name": "example_function"} + none_mode: ToolMode = {"mode": "none"} # Check the type and content - assert auto_mode.mode == "auto" - assert auto_mode.required_function_name is None - assert required_any.mode == "required" - assert required_any.required_function_name is None - assert required_mode.mode == "required" - assert required_mode.required_function_name == "example_function" - assert none_mode.mode == "none" - assert none_mode.required_function_name is None - - # Ensure the instances are of type ToolMode - assert isinstance(auto_mode, ToolMode) - assert isinstance(required_any, ToolMode) - assert isinstance(required_mode, ToolMode) - assert isinstance(none_mode, ToolMode) - - assert ToolMode.REQUIRED("example_function") == ToolMode.REQUIRED("example_function") - # serializer returns just the mode - assert ToolMode.REQUIRED_ANY.serialize_model() == "required" + assert auto_mode["mode"] == "auto" + assert "required_function_name" not in auto_mode + assert required_any["mode"] == "required" + assert "required_function_name" not in required_any + assert required_mode["mode"] == "required" + assert required_mode["required_function_name"] == "example_function" + assert none_mode["mode"] == "none" + assert "required_function_name" not in none_mode + + # equality of dicts + assert {"mode": "required", "required_function_name": "example_function"} == { + "mode": "required", + "required_function_name": "example_function", + } def test_chat_tool_mode_from_dict(): """Test creating ToolMode from a dictionary.""" - mode_dict = {"mode": "required", "required_function_name": "example_function"} - mode = ToolMode(**mode_dict) + mode: ToolMode = {"mode": "required", "required_function_name": "example_function"} # Check the type and content - assert mode.mode == "required" - assert mode.required_function_name == "example_function" - - # Ensure the instance is of type ToolMode - assert isinstance(mode, ToolMode) + assert mode["mode"] == "required" + assert mode["required_function_name"] == "example_function" # region ChatOptions def test_chat_options_init() -> None: - options = ChatOptions() - assert options.model_id is None + """Test that ChatOptions can be created as a TypedDict.""" + options: ChatOptions = {} + assert options.get("model_id") is None + + # With values + options_with_model: ChatOptions = {"model_id": "gpt-4o", "temperature": 0.7} + assert options_with_model.get("model_id") == "gpt-4o" + assert options_with_model.get("temperature") == 0.7 + + +def test_chat_options_tool_choice_validation(): + """Test validate_tool_mode utility function.""" + from agent_framework._types import validate_tool_mode + + # Valid string values + assert validate_tool_mode("auto") == {"mode": "auto"} + assert validate_tool_mode("required") == {"mode": "required"} + assert validate_tool_mode("none") == {"mode": "none"} + + # Valid ToolMode dict values + assert validate_tool_mode({"mode": "auto"}) == {"mode": "auto"} + assert validate_tool_mode({"mode": "required"}) == {"mode": "required"} + assert validate_tool_mode({"mode": "required", "required_function_name": "example_function"}) == { + "mode": "required", + "required_function_name": "example_function", + } + assert validate_tool_mode({"mode": "none"}) == {"mode": 
"none"} + # None should return mode==none + assert validate_tool_mode(None) == {"mode": "none"} -def test_chat_options_tool_choice_validation_errors(): - with raises((ContentError, TypeError)): - ChatOptions(tool_choice="invalid-choice") + with raises(ContentError): + validate_tool_mode("invalid_mode") + with raises(ContentError): + validate_tool_mode({"mode": "invalid_mode"}) + with raises(ContentError): + validate_tool_mode({"mode": "auto", "required_function_name": "should_not_be_here"}) -def test_chat_options_and(ai_function_tool, ai_tool) -> None: - options1 = ChatOptions(model_id="gpt-4o", tools=[ai_function_tool], logit_bias={"x": 1}, metadata={"a": "b"}) - options2 = ChatOptions(model_id="gpt-4.1", tools=[ai_tool], additional_properties={"p": 1}) +def test_chat_options_merge(ai_function_tool, ai_tool) -> None: + """Test merge_chat_options utility function.""" + from agent_framework import merge_chat_options + + options1: ChatOptions = { + "model_id": "gpt-4o", + "tools": [ai_function_tool], + "logit_bias": {"x": 1}, + "metadata": {"a": "b"}, + } + options2: ChatOptions = {"model_id": "gpt-4.1", "tools": [ai_tool]} assert options1 != options2 - options3 = options1 & options2 - assert options3.model_id == "gpt-4.1" - assert options3.tools == [ai_function_tool, ai_tool] - assert options3.logit_bias == {"x": 1} - assert options3.metadata == {"a": "b"} - assert options3.additional_properties.get("p") == 1 + # Merge options - override takes precedence for non-collection fields + options3 = merge_chat_options(options1, options2) + + assert options3.get("model_id") == "gpt-4.1" + assert options3.get("tools") == [ai_function_tool, ai_tool] # tools are combined + assert options3.get("logit_bias") == {"x": 1} # base value preserved + assert options3.get("metadata") == {"a": "b"} # base value preserved def test_chat_options_and_tool_choice_override() -> None: """Test that tool_choice from other takes precedence in ChatOptions merge.""" # Agent-level defaults to "auto" - agent_options = ChatOptions(model_id="gpt-4o", tool_choice="auto") + agent_options: ChatOptions = {"model_id": "gpt-4o", "tool_choice": "auto"} # Run-level specifies "required" - run_options = ChatOptions(tool_choice="required") + run_options: ChatOptions = {"tool_choice": "required"} - merged = agent_options & run_options + merged = merge_chat_options(agent_options, run_options) # Run-level should override agent-level - assert merged.tool_choice == "required" - assert merged.model_id == "gpt-4o" # Other fields preserved + assert merged.get("tool_choice") == "required" + assert merged.get("model_id") == "gpt-4o" # Other fields preserved def test_chat_options_and_tool_choice_none_in_other_uses_self() -> None: """Test that when other.tool_choice is None, self.tool_choice is used.""" - agent_options = ChatOptions(tool_choice="auto") - run_options = ChatOptions(model_id="gpt-4.1") # tool_choice is None + agent_options: ChatOptions = {"tool_choice": "auto"} + run_options: ChatOptions = {"model_id": "gpt-4.1"} # tool_choice is None - merged = agent_options & run_options + merged = merge_chat_options(agent_options, run_options) # Should keep agent-level tool_choice since run-level is None - assert merged.tool_choice == "auto" - assert merged.model_id == "gpt-4.1" + assert merged.get("tool_choice") == "auto" + assert merged.get("model_id") == "gpt-4.1" def test_chat_options_and_tool_choice_with_tool_mode() -> None: """Test ChatOptions merge with ToolMode objects.""" - agent_options = ChatOptions(tool_choice=ToolMode.AUTO) - run_options 
= ChatOptions(tool_choice=ToolMode.REQUIRED_ANY) + agent_options: ChatOptions = {"tool_choice": "auto"} + run_options: ChatOptions = {"tool_choice": "required"} - merged = agent_options & run_options + merged = merge_chat_options(agent_options, run_options) - assert merged.tool_choice == ToolMode.REQUIRED_ANY - assert merged.tool_choice == "required" # ToolMode equality with string + # Run-level "required" should override the agent-level "auto" + assert merged.get("tool_choice") == "required" def test_chat_options_and_tool_choice_required_specific_function() -> None: """Test ChatOptions merge with required specific function.""" - agent_options = ChatOptions(tool_choice="auto") - run_options = ChatOptions(tool_choice=ToolMode.REQUIRED(function_name="get_weather")) + agent_options: ChatOptions = {"tool_choice": "auto"} + run_options: ChatOptions = {"tool_choice": {"mode": "required", "required_function_name": "get_weather"}} - merged = agent_options & run_options + merged = merge_chat_options(agent_options, run_options) - assert merged.tool_choice == "required" - assert merged.tool_choice.required_function_name == "get_weather" + tool_choice = merged.get("tool_choice") + assert tool_choice == {"mode": "required", "required_function_name": "get_weather"} + assert tool_choice["required_function_name"] == "get_weather" # region Agent Response Fixtures @@ -993,90 +1026,90 @@ def text_content() -> TextContent: @fixture -def agent_run_response(chat_message: ChatMessage) -> AgentRunResponse: - return AgentRunResponse(messages=chat_message) +def agent_response(chat_message: ChatMessage) -> AgentResponse: + return AgentResponse(messages=chat_message) @fixture -def agent_run_response_update(text_content: TextContent) -> AgentRunResponseUpdate: - return AgentRunResponseUpdate(role=Role.ASSISTANT, contents=[text_content]) +def agent_response_update(text_content: TextContent) -> AgentResponseUpdate: + return AgentResponseUpdate(role=Role.ASSISTANT, contents=[text_content]) -# region AgentRunResponse +# region AgentResponse def test_agent_run_response_init_single_message(chat_message: ChatMessage) -> None: - response = AgentRunResponse(messages=chat_message) + response = AgentResponse(messages=chat_message) assert response.messages == [chat_message] def test_agent_run_response_init_list_messages(chat_message: ChatMessage) -> None: - response = AgentRunResponse(messages=[chat_message, chat_message]) + response = AgentResponse(messages=[chat_message, chat_message]) assert len(response.messages) == 2 assert response.messages[0] == chat_message def test_agent_run_response_init_none_messages() -> None: - response = AgentRunResponse() + response = AgentResponse() assert response.messages == [] def test_agent_run_response_text_property(chat_message: ChatMessage) -> None: - response = AgentRunResponse(messages=[chat_message, chat_message]) + response = AgentResponse(messages=[chat_message, chat_message]) assert response.text == "HelloHello" def test_agent_run_response_text_property_empty() -> None: - response = AgentRunResponse() + response = AgentResponse() assert response.text == "" -def test_agent_run_response_from_updates(agent_run_response_update: AgentRunResponseUpdate) -> None: - updates = [agent_run_response_update, agent_run_response_update] - response = AgentRunResponse.from_agent_run_response_updates(updates) +def test_agent_run_response_from_updates(agent_response_update: AgentResponseUpdate) -> None: + updates = [agent_response_update, agent_response_update] + response = 
AgentResponse.from_agent_run_response_updates(updates) assert len(response.messages) > 0 assert response.text == "Test contentTest content" def test_agent_run_response_str_method(chat_message: ChatMessage) -> None: - response = AgentRunResponse(messages=chat_message) + response = AgentResponse(messages=chat_message) assert str(response) == "Hello" -# region AgentRunResponseUpdate +# region AgentResponseUpdate def test_agent_run_response_update_init_content_list(text_content: TextContent) -> None: - update = AgentRunResponseUpdate(contents=[text_content, text_content]) + update = AgentResponseUpdate(contents=[text_content, text_content]) assert len(update.contents) == 2 assert update.contents[0] == text_content def test_agent_run_response_update_init_none_content() -> None: - update = AgentRunResponseUpdate() + update = AgentResponseUpdate() assert update.contents == [] def test_agent_run_response_update_text_property(text_content: TextContent) -> None: - update = AgentRunResponseUpdate(contents=[text_content, text_content]) + update = AgentResponseUpdate(contents=[text_content, text_content]) assert update.text == "Test contentTest content" def test_agent_run_response_update_text_property_empty() -> None: - update = AgentRunResponseUpdate() + update = AgentResponseUpdate() assert update.text == "" def test_agent_run_response_update_str_method(text_content: TextContent) -> None: - update = AgentRunResponseUpdate(contents=[text_content]) + update = AgentResponseUpdate(contents=[text_content]) assert str(update) == "Test content" def test_agent_run_response_update_created_at() -> None: - """Test that AgentRunResponseUpdate properly handles created_at timestamps.""" + """Test that AgentResponseUpdate properly handles created_at timestamps.""" # Test with a properly formatted UTC timestamp utc_timestamp = "2024-12-01T00:31:30.000000Z" - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[TextContent(text="test")], role=Role.ASSISTANT, created_at=utc_timestamp, @@ -1087,7 +1120,7 @@ def test_agent_run_response_update_created_at() -> None: # Verify that we can generate a proper UTC timestamp now_utc = datetime.now(tz=timezone.utc) formatted_utc = now_utc.strftime("%Y-%m-%dT%H:%M:%S.%fZ") - update_with_now = AgentRunResponseUpdate( + update_with_now = AgentResponseUpdate( contents=[TextContent(text="test")], role=Role.ASSISTANT, created_at=formatted_utc, @@ -1097,10 +1130,10 @@ def test_agent_run_response_update_created_at() -> None: def test_agent_run_response_created_at() -> None: - """Test that AgentRunResponse properly handles created_at timestamps.""" + """Test that AgentResponse properly handles created_at timestamps.""" # Test with a properly formatted UTC timestamp utc_timestamp = "2024-12-01T00:31:30.000000Z" - response = AgentRunResponse( + response = AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, text="Hello")], created_at=utc_timestamp, ) @@ -1110,7 +1143,7 @@ def test_agent_run_response_created_at() -> None: # Verify that we can generate a proper UTC timestamp now_utc = datetime.now(tz=timezone.utc) formatted_utc = now_utc.strftime("%Y-%m-%dT%H:%M:%S.%fZ") - response_with_now = AgentRunResponse( + response_with_now = AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, text="Hello")], created_at=formatted_utc, ) @@ -1249,23 +1282,23 @@ def test_function_call_content_parse_numeric_or_list(): def test_chat_tool_mode_eq_with_string(): - assert ToolMode.AUTO == "auto" + assert {"mode": "auto"} == {"mode": "auto"} -# region AgentRunResponse +# region 
AgentResponse @fixture -def agent_run_response_async() -> AgentRunResponse: - return AgentRunResponse(messages=[ChatMessage(role="user", text="Hello")]) +def agent_run_response_async() -> AgentResponse: + return AgentResponse(messages=[ChatMessage(role="user", text="Hello")]) async def test_agent_run_response_from_async_generator(): async def gen(): - yield AgentRunResponseUpdate(contents=[TextContent("A")]) - yield AgentRunResponseUpdate(contents=[TextContent("B")]) + yield AgentResponseUpdate(contents=[TextContent("A")]) + yield AgentResponseUpdate(contents=[TextContent("B")]) - r = await AgentRunResponse.from_agent_response_generator(gen()) + r = await AgentResponse.from_agent_response_generator(gen()) assert r.text == "AB" @@ -1437,30 +1470,6 @@ def test_chat_message_from_dict_with_mixed_content(): assert len(message_dict["contents"]) == 3 -def test_chat_options_edge_cases(): - """Test ChatOptions with edge cases for better coverage.""" - - # Test with tools conversion - def sample_tool(): - return "test" - - options = ChatOptions(tools=[sample_tool], tool_choice="auto") - assert options.tool_choice == ToolMode.AUTO - - # Test to_dict with ToolMode - options_dict = options.to_dict() - assert "tool_choice" in options_dict - - # Test from_dict with tool_choice dict - data_with_dict_tool_choice = { - "model_id": "gpt-4", - "tool_choice": {"mode": "required", "required_function_name": "test_func"}, - } - options_from_dict = ChatOptions.from_dict(data_with_dict_tool_choice) - assert options_from_dict.tool_choice.mode == "required" - assert options_from_dict.tool_choice.required_function_name == "test_func" - - def test_text_content_add_type_error(): """Test TextContent __add__ raises TypeError for incompatible types.""" t1 = TextContent("Hello") @@ -1501,30 +1510,6 @@ def test_comprehensive_serialization_methods(): assert result_content.result == "success" -def test_chat_options_tool_choice_variations(): - """Test ChatOptions from_dict and to_dict with various tool_choice values.""" - - # Test with string tool_choice - data = {"model_id": "gpt-4", "tool_choice": "auto", "temperature": 0.7} - options = ChatOptions.from_dict(data) - assert options.tool_choice == ToolMode.AUTO - - # Test with dict tool_choice - data_dict = { - "model_id": "gpt-4", - "tool_choice": {"mode": "required", "required_function_name": "test_func"}, - "temperature": 0.7, - } - options_dict = ChatOptions.from_dict(data_dict) - assert options_dict.tool_choice.mode == "required" - assert options_dict.tool_choice.required_function_name == "test_func" - - # Test to_dict with ToolMode - options_dict_serialized = options_dict.to_dict() - assert "tool_choice" in options_dict_serialized - assert isinstance(options_dict_serialized["tool_choice"], dict) - - def test_chat_message_complex_content_serialization(): """Test ChatMessage serialization with various content types.""" @@ -1683,7 +1668,7 @@ def test_chat_response_update_all_content_types(): def test_agent_run_response_complex_serialization(): - """Test AgentRunResponse from_dict and to_dict with messages and usage_details.""" + """Test AgentResponse from_dict and to_dict with messages and usage_details.""" response_data = { "messages": [ @@ -1698,7 +1683,7 @@ def test_agent_run_response_complex_serialization(): }, } - response = AgentRunResponse.from_dict(response_data) + response = AgentResponse.from_dict(response_data) assert len(response.messages) == 2 assert isinstance(response.messages[0], ChatMessage) assert isinstance(response.usage_details, UsageDetails) @@ -1711,7 
+1696,7 @@ def test_agent_run_response_complex_serialization(): def test_agent_run_response_update_all_content_types(): - """Test AgentRunResponseUpdate from_dict with all content types and role handling.""" + """Test AgentResponseUpdate from_dict with all content types and role handling.""" update_data = { "contents": [ @@ -1740,7 +1725,7 @@ def test_agent_run_response_update_all_content_types(): "role": {"value": "assistant"}, # Test role as dict } - update = AgentRunResponseUpdate.from_dict(update_data) + update = AgentResponseUpdate.from_dict(update_data) assert len(update.contents) == 12 # unknown_type is logged and ignored assert isinstance(update.role, Role) assert update.role.value == "assistant" @@ -1753,7 +1738,7 @@ def test_agent_run_response_update_all_content_types(): # Test role as string conversion update_data_str_role = update_data.copy() update_data_str_role["role"] = "user" - update_str = AgentRunResponseUpdate.from_dict(update_data_str_role) + update_str = AgentResponseUpdate.from_dict(update_data_str_role) assert isinstance(update_str.role, Role) assert update_str.role.value == "user" @@ -1937,7 +1922,7 @@ def test_agent_run_response_update_all_content_types(): id="chat_response_update", ), pytest.param( - AgentRunResponse, + AgentResponse, { "messages": [ { @@ -1957,10 +1942,10 @@ def test_agent_run_response_update_all_content_types(): "total_token_count": 8, }, }, - id="agent_run_response", + id="agent_response", ), pytest.param( - AgentRunResponseUpdate, + AgentResponseUpdate, { "contents": [ {"type": "text", "text": "Streaming"}, @@ -1971,7 +1956,7 @@ def test_agent_run_response_update_all_content_types(): "response_id": "run-123", "author_name": "Agent", }, - id="agent_run_response_update", + id="agent_response_update", ), ], ) diff --git a/python/packages/core/tests/openai/test_assistant_provider.py b/python/packages/core/tests/openai/test_assistant_provider.py new file mode 100644 index 0000000000..bb29691a07 --- /dev/null +++ b/python/packages/core/tests/openai/test_assistant_provider.py @@ -0,0 +1,814 @@ +# Copyright (c) Microsoft. All rights reserved. 
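+"""Unit tests for OpenAIAssistantProvider: initialization and context management,
+agent creation/retrieval/wrapping, tool conversion, validation, and merging."""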
+ +import os +from typing import Annotated, Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from openai.types.beta.assistant import Assistant +from pydantic import BaseModel, Field + +from agent_framework import ChatAgent, HostedCodeInterpreterTool, HostedFileSearchTool, ai_function, normalize_tools +from agent_framework.exceptions import ServiceInitializationError +from agent_framework.openai import OpenAIAssistantProvider +from agent_framework.openai._shared import from_assistant_tools, to_assistant_tools + +# region Test Helpers + + +def create_mock_assistant( + assistant_id: str = "asst_test123", + name: str = "TestAssistant", + model: str = "gpt-4", + instructions: str | None = "You are a helpful assistant.", + description: str | None = None, + tools: list[Any] | None = None, +) -> Assistant: + """Create a mock Assistant object.""" + mock = MagicMock(spec=Assistant) + mock.id = assistant_id + mock.name = name + mock.model = model + mock.instructions = instructions + mock.description = description + mock.tools = tools or [] + return mock + + +def create_function_tool(name: str, description: str = "A test function") -> MagicMock: + """Create a mock FunctionTool.""" + mock = MagicMock() + mock.type = "function" + mock.function = MagicMock() + mock.function.name = name + mock.function.description = description + return mock + + +def create_code_interpreter_tool() -> MagicMock: + """Create a mock CodeInterpreterTool.""" + mock = MagicMock() + mock.type = "code_interpreter" + return mock + + +def create_file_search_tool() -> MagicMock: + """Create a mock FileSearchTool.""" + mock = MagicMock() + mock.type = "file_search" + return mock + + +@pytest.fixture +def mock_async_openai() -> MagicMock: + """Mock AsyncOpenAI client.""" + mock_client = MagicMock() + + # Mock beta.assistants + mock_client.beta.assistants.create = AsyncMock( + return_value=create_mock_assistant(assistant_id="asst_created123", name="CreatedAssistant") + ) + mock_client.beta.assistants.retrieve = AsyncMock( + return_value=create_mock_assistant(assistant_id="asst_retrieved123", name="RetrievedAssistant") + ) + mock_client.beta.assistants.delete = AsyncMock() + + # Mock close method + mock_client.close = AsyncMock() + + return mock_client + + +# Test function for tool validation +def get_weather(location: Annotated[str, Field(description="The location")]) -> str: + """Get the weather for a location.""" + return f"Weather in {location}: sunny" + + +def search_database(query: Annotated[str, Field(description="Search query")]) -> str: + """Search the database.""" + return f"Results for: {query}" + + +# Pydantic model for structured output tests +class WeatherResponse(BaseModel): + location: str + temperature: float + conditions: str + + +# endregion + + +# region Initialization Tests + + +class TestOpenAIAssistantProviderInit: + """Tests for provider initialization.""" + + def test_init_with_client(self, mock_async_openai: MagicMock) -> None: + """Test initialization with existing AsyncOpenAI client.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + assert provider._client is mock_async_openai # type: ignore[reportPrivateUsage] + assert provider._should_close_client is False # type: ignore[reportPrivateUsage] + + def test_init_without_client_creates_one(self, openai_unit_test_env: dict[str, str]) -> None: + """Test initialization creates client from settings.""" + provider = OpenAIAssistantProvider() + + assert provider._client is not None # type: ignore[reportPrivateUsage] + assert 
provider._should_close_client is True # type: ignore[reportPrivateUsage] + + def test_init_with_api_key(self) -> None: + """Test initialization with explicit API key.""" + provider = OpenAIAssistantProvider(api_key="sk-test-key") + + assert provider._client is not None # type: ignore[reportPrivateUsage] + assert provider._should_close_client is True # type: ignore[reportPrivateUsage] + + def test_init_fails_without_api_key(self) -> None: + """Test initialization fails without API key when settings return None.""" + from unittest.mock import patch + + # Mock OpenAISettings to return None for api_key + with patch("agent_framework.openai._assistant_provider.OpenAISettings") as mock_settings: + mock_settings.return_value.api_key = None + + with pytest.raises(ServiceInitializationError) as exc_info: + OpenAIAssistantProvider() + + assert "API key is required" in str(exc_info.value) + + def test_init_with_org_id_and_base_url(self) -> None: + """Test initialization with organization ID and base URL.""" + provider = OpenAIAssistantProvider( + api_key="sk-test-key", + org_id="org-123", + base_url="https://custom.openai.com", + ) + + assert provider._client is not None # type: ignore[reportPrivateUsage] + + +class TestOpenAIAssistantProviderContextManager: + """Tests for async context manager.""" + + async def test_context_manager_enter_exit(self, mock_async_openai: MagicMock) -> None: + """Test async context manager entry and exit.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + async with provider as p: + assert p is provider + + async def test_context_manager_closes_owned_client(self, openai_unit_test_env: dict[str, str]) -> None: + """Test that owned client is closed on exit.""" + provider = OpenAIAssistantProvider() + client = provider._client # type: ignore[reportPrivateUsage] + assert client is not None + client.close = AsyncMock() + + async with provider: + pass + + client.close.assert_called_once() + + async def test_context_manager_does_not_close_external_client(self, mock_async_openai: MagicMock) -> None: + """Test that external client is not closed on exit.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + async with provider: + pass + + mock_async_openai.close.assert_not_called() + + +# endregion + + +# region create_agent Tests + + +class TestOpenAIAssistantProviderCreateAgent: + """Tests for create_agent method.""" + + async def test_create_agent_basic(self, mock_async_openai: MagicMock) -> None: + """Test basic assistant creation.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + agent = await provider.create_agent( + name="TestAgent", + model="gpt-4", + instructions="You are helpful.", + ) + + assert isinstance(agent, ChatAgent) + assert agent.name == "CreatedAssistant" + mock_async_openai.beta.assistants.create.assert_called_once() + + # Verify create was called with correct parameters + call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs + assert call_kwargs["name"] == "TestAgent" + assert call_kwargs["model"] == "gpt-4" + assert call_kwargs["instructions"] == "You are helpful." 
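+        # Illustrative only, not asserted: against a real client (`from openai import
+        # AsyncOpenAI`) rather than the MagicMock above, the same flow would look like:
+        #     async with OpenAIAssistantProvider(AsyncOpenAI()) as provider:
+        #         agent = await provider.create_agent(name="TestAgent", model="gpt-4")
+        #         print(await agent.run("Hello"))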
+ + async def test_create_agent_with_description(self, mock_async_openai: MagicMock) -> None: + """Test assistant creation with description.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + await provider.create_agent( + name="TestAgent", + model="gpt-4", + description="A test agent description", + ) + + call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs + assert call_kwargs["description"] == "A test agent description" + + async def test_create_agent_with_function_tools(self, mock_async_openai: MagicMock) -> None: + """Test assistant creation with function tools.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + agent = await provider.create_agent( + name="WeatherAgent", + model="gpt-4", + tools=[get_weather], + ) + + assert isinstance(agent, ChatAgent) + + # Verify tools were passed to create + call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs + assert "tools" in call_kwargs + assert len(call_kwargs["tools"]) == 1 + assert call_kwargs["tools"][0]["type"] == "function" + assert call_kwargs["tools"][0]["function"]["name"] == "get_weather" + + async def test_create_agent_with_ai_function(self, mock_async_openai: MagicMock) -> None: + """Test assistant creation with AIFunction.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + @ai_function + def my_function(x: int) -> int: + """Double a number.""" + return x * 2 + + await provider.create_agent( + name="TestAgent", + model="gpt-4", + tools=[my_function], + ) + + call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs + assert call_kwargs["tools"][0]["function"]["name"] == "my_function" + + async def test_create_agent_with_code_interpreter(self, mock_async_openai: MagicMock) -> None: + """Test assistant creation with code interpreter.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + await provider.create_agent( + name="CodeAgent", + model="gpt-4", + tools=[HostedCodeInterpreterTool()], + ) + + call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs + assert {"type": "code_interpreter"} in call_kwargs["tools"] + + async def test_create_agent_with_file_search(self, mock_async_openai: MagicMock) -> None: + """Test assistant creation with file search.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + await provider.create_agent( + name="SearchAgent", + model="gpt-4", + tools=[HostedFileSearchTool()], + ) + + call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs + assert any(t["type"] == "file_search" for t in call_kwargs["tools"]) + + async def test_create_agent_with_file_search_max_results(self, mock_async_openai: MagicMock) -> None: + """Test assistant creation with file search and max_results.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + await provider.create_agent( + name="SearchAgent", + model="gpt-4", + tools=[HostedFileSearchTool(max_results=10)], + ) + + call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs + file_search_tool = next(t for t in call_kwargs["tools"] if t["type"] == "file_search") + assert file_search_tool.get("file_search", {}).get("max_num_results") == 10 + + async def test_create_agent_with_mixed_tools(self, mock_async_openai: MagicMock) -> None: + """Test assistant creation with multiple tool types.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + await provider.create_agent( + name="MultiToolAgent", + model="gpt-4", + tools=[get_weather, HostedCodeInterpreterTool(), HostedFileSearchTool()], + ) + + call_kwargs = 
mock_async_openai.beta.assistants.create.call_args.kwargs + assert len(call_kwargs["tools"]) == 3 + + async def test_create_agent_with_metadata(self, mock_async_openai: MagicMock) -> None: + """Test assistant creation with metadata.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + await provider.create_agent( + name="TestAgent", + model="gpt-4", + metadata={"env": "test", "version": "1.0"}, + ) + + call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs + assert call_kwargs["metadata"] == {"env": "test", "version": "1.0"} + + async def test_create_agent_with_response_format_pydantic(self, mock_async_openai: MagicMock) -> None: + """Test assistant creation with Pydantic response format via default_options.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + await provider.create_agent( + name="StructuredAgent", + model="gpt-4", + default_options={"response_format": WeatherResponse}, + ) + + call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs + assert call_kwargs["response_format"]["type"] == "json_schema" + assert call_kwargs["response_format"]["json_schema"]["name"] == "WeatherResponse" + + async def test_create_agent_returns_chat_agent(self, mock_async_openai: MagicMock) -> None: + """Test that create_agent returns a ChatAgent instance.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + agent = await provider.create_agent( + name="TestAgent", + model="gpt-4", + ) + + assert isinstance(agent, ChatAgent) + + +# endregion + + +# region get_agent Tests + + +class TestOpenAIAssistantProviderGetAgent: + """Tests for get_agent method.""" + + async def test_get_agent_basic(self, mock_async_openai: MagicMock) -> None: + """Test retrieving an existing assistant.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + agent = await provider.get_agent(assistant_id="asst_123") + + assert isinstance(agent, ChatAgent) + mock_async_openai.beta.assistants.retrieve.assert_called_once_with("asst_123") + + async def test_get_agent_with_instructions_override(self, mock_async_openai: MagicMock) -> None: + """Test retrieving assistant with instruction override.""" + provider = OpenAIAssistantProvider(mock_async_openai) + + agent = await provider.get_agent( + assistant_id="asst_123", + instructions="Custom instructions", + ) + + # Agent should be created successfully with the custom instructions + assert isinstance(agent, ChatAgent) + assert agent.id == "asst_retrieved123" + + async def test_get_agent_with_function_tools(self, mock_async_openai: MagicMock) -> None: + """Test retrieving assistant with function tools provided.""" + # Setup assistant with function tool + assistant = create_mock_assistant(tools=[create_function_tool("get_weather")]) + mock_async_openai.beta.assistants.retrieve = AsyncMock(return_value=assistant) + + provider = OpenAIAssistantProvider(mock_async_openai) + + agent = await provider.get_agent( + assistant_id="asst_123", + tools=[get_weather], + ) + + assert isinstance(agent, ChatAgent) + + async def test_get_agent_validates_missing_function_tools(self, mock_async_openai: MagicMock) -> None: + """Test that missing function tools raise ValueError.""" + # Setup assistant with function tool + assistant = create_mock_assistant(tools=[create_function_tool("get_weather")]) + mock_async_openai.beta.assistants.retrieve = AsyncMock(return_value=assistant) + + provider = OpenAIAssistantProvider(mock_async_openai) + + with pytest.raises(ValueError) as exc_info: + await provider.get_agent(assistant_id="asst_123") + + 
assert "get_weather" in str(exc_info.value) + assert "no implementation was provided" in str(exc_info.value) + + async def test_get_agent_validates_multiple_missing_function_tools(self, mock_async_openai: MagicMock) -> None: + """Test validation with multiple missing function tools.""" + assistant = create_mock_assistant( + tools=[create_function_tool("get_weather"), create_function_tool("search_database")] + ) + mock_async_openai.beta.assistants.retrieve = AsyncMock(return_value=assistant) + + provider = OpenAIAssistantProvider(mock_async_openai) + + with pytest.raises(ValueError) as exc_info: + await provider.get_agent(assistant_id="asst_123") + + error_msg = str(exc_info.value) + assert "get_weather" in error_msg or "search_database" in error_msg + + async def test_get_agent_merges_hosted_tools(self, mock_async_openai: MagicMock) -> None: + """Test that hosted tools are automatically included.""" + assistant = create_mock_assistant(tools=[create_code_interpreter_tool(), create_file_search_tool()]) + mock_async_openai.beta.assistants.retrieve = AsyncMock(return_value=assistant) + + provider = OpenAIAssistantProvider(mock_async_openai) + + agent = await provider.get_agent(assistant_id="asst_123") + + # Hosted tools should be merged automatically + assert isinstance(agent, ChatAgent) + + +# endregion + + +# region as_agent Tests + + +class TestOpenAIAssistantProviderAsAgent: + """Tests for as_agent method.""" + + def test_as_agent_no_http_call(self, mock_async_openai: MagicMock) -> None: + """Test that as_agent doesn't make HTTP calls.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant = create_mock_assistant() + + agent = provider.as_agent(assistant) + + assert isinstance(agent, ChatAgent) + # Verify no HTTP calls were made + mock_async_openai.beta.assistants.create.assert_not_called() + mock_async_openai.beta.assistants.retrieve.assert_not_called() + + def test_as_agent_wraps_assistant(self, mock_async_openai: MagicMock) -> None: + """Test wrapping an SDK Assistant object.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant = create_mock_assistant( + assistant_id="asst_wrap123", + name="WrappedAssistant", + instructions="Original instructions", + ) + + agent = provider.as_agent(assistant) + + assert agent.id == "asst_wrap123" + assert agent.name == "WrappedAssistant" + # Instructions are passed to ChatOptions, not exposed as attribute + assert isinstance(agent, ChatAgent) + + def test_as_agent_with_instructions_override(self, mock_async_openai: MagicMock) -> None: + """Test as_agent with instruction override.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant = create_mock_assistant(instructions="Original") + + agent = provider.as_agent(assistant, instructions="Override") + + # Agent should be created successfully with override instructions + assert isinstance(agent, ChatAgent) + + def test_as_agent_validates_function_tools(self, mock_async_openai: MagicMock) -> None: + """Test that missing function tools raise ValueError.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant = create_mock_assistant(tools=[create_function_tool("get_weather")]) + + with pytest.raises(ValueError) as exc_info: + provider.as_agent(assistant) + + assert "get_weather" in str(exc_info.value) + + def test_as_agent_with_function_tools_provided(self, mock_async_openai: MagicMock) -> None: + """Test as_agent with function tools provided.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant = 
create_mock_assistant(tools=[create_function_tool("get_weather")]) + + agent = provider.as_agent(assistant, tools=[get_weather]) + + assert isinstance(agent, ChatAgent) + + def test_as_agent_merges_hosted_tools(self, mock_async_openai: MagicMock) -> None: + """Test that hosted tools are merged automatically.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant = create_mock_assistant(tools=[create_code_interpreter_tool()]) + + agent = provider.as_agent(assistant) + + assert isinstance(agent, ChatAgent) + + def test_as_agent_hosted_tools_not_required(self, mock_async_openai: MagicMock) -> None: + """Test that hosted tools don't require user implementations.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant = create_mock_assistant(tools=[create_code_interpreter_tool(), create_file_search_tool()]) + + # Should not raise - hosted tools don't need implementations + agent = provider.as_agent(assistant) + + assert isinstance(agent, ChatAgent) + + +# endregion + + +# region Tool Conversion Tests + + +class TestToolConversion: + """Tests for tool conversion utilities (shared functions).""" + + def test_to_assistant_tools_ai_function(self) -> None: + """Test AIFunction conversion to API format.""" + + @ai_function + def test_func(x: int) -> int: + """Test function.""" + return x + + # Normalize tools first, then convert + normalized = normalize_tools([test_func]) + api_tools = to_assistant_tools(normalized) + + assert len(api_tools) == 1 + assert api_tools[0]["type"] == "function" + assert api_tools[0]["function"]["name"] == "test_func" + + def test_to_assistant_tools_callable(self) -> None: + """Test raw callable conversion via normalize_tools.""" + # normalize_tools converts callables to AIFunction + normalized = normalize_tools([get_weather]) + api_tools = to_assistant_tools(normalized) + + assert len(api_tools) == 1 + assert api_tools[0]["type"] == "function" + assert api_tools[0]["function"]["name"] == "get_weather" + + def test_to_assistant_tools_code_interpreter(self) -> None: + """Test HostedCodeInterpreterTool conversion.""" + api_tools = to_assistant_tools([HostedCodeInterpreterTool()]) + + assert len(api_tools) == 1 + assert api_tools[0] == {"type": "code_interpreter"} + + def test_to_assistant_tools_file_search(self) -> None: + """Test HostedFileSearchTool conversion.""" + api_tools = to_assistant_tools([HostedFileSearchTool()]) + + assert len(api_tools) == 1 + assert api_tools[0]["type"] == "file_search" + + def test_to_assistant_tools_file_search_with_max_results(self) -> None: + """Test HostedFileSearchTool with max_results conversion.""" + api_tools = to_assistant_tools([HostedFileSearchTool(max_results=5)]) + + assert api_tools[0]["file_search"]["max_num_results"] == 5 + + def test_to_assistant_tools_dict(self) -> None: + """Test raw dict tool passthrough.""" + raw_tool = {"type": "function", "function": {"name": "custom", "description": "Custom tool"}} + + api_tools = to_assistant_tools([raw_tool]) + + assert len(api_tools) == 1 + assert api_tools[0] == raw_tool + + def test_to_assistant_tools_empty(self) -> None: + """Test conversion with no tools.""" + api_tools = to_assistant_tools(None) + + assert api_tools == [] + + def test_from_assistant_tools_code_interpreter(self) -> None: + """Test converting code_interpreter tool from OpenAI format.""" + assistant_tools = [create_code_interpreter_tool()] + + tools = from_assistant_tools(assistant_tools) + + assert len(tools) == 1 + assert isinstance(tools[0], HostedCodeInterpreterTool) + + def 
test_from_assistant_tools_file_search(self) -> None: + """Test converting file_search tool from OpenAI format.""" + assistant_tools = [create_file_search_tool()] + + tools = from_assistant_tools(assistant_tools) + + assert len(tools) == 1 + assert isinstance(tools[0], HostedFileSearchTool) + + def test_from_assistant_tools_function_skipped(self) -> None: + """Test that function tools are skipped (no implementations).""" + assistant_tools = [create_function_tool("test_func")] + + tools = from_assistant_tools(assistant_tools) + + assert len(tools) == 0 # Function tools are skipped + + def test_from_assistant_tools_empty(self) -> None: + """Test conversion with no tools.""" + tools = from_assistant_tools(None) + + assert tools == [] + + +# endregion + + +# region Tool Validation Tests + + +class TestToolValidation: + """Tests for tool validation.""" + + def test_validate_missing_function_tool_raises(self, mock_async_openai: MagicMock) -> None: + """Test that missing function tools raise ValueError.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant_tools = [create_function_tool("my_function")] + + with pytest.raises(ValueError) as exc_info: + provider._validate_function_tools(assistant_tools, None) # type: ignore[reportPrivateUsage] + + assert "my_function" in str(exc_info.value) + + def test_validate_all_tools_provided_passes(self, mock_async_openai: MagicMock) -> None: + """Test that validation passes when all tools provided.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant_tools = [create_function_tool("get_weather")] + + # Should not raise + provider._validate_function_tools(assistant_tools, [get_weather]) # type: ignore[reportPrivateUsage] + + def test_validate_hosted_tools_not_required(self, mock_async_openai: MagicMock) -> None: + """Test that hosted tools don't require implementations.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant_tools = [create_code_interpreter_tool(), create_file_search_tool()] + + # Should not raise + provider._validate_function_tools(assistant_tools, None) # type: ignore[reportPrivateUsage] + + def test_validate_with_ai_function(self, mock_async_openai: MagicMock) -> None: + """Test validation with AIFunction.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant_tools = [create_function_tool("get_weather")] + + wrapped = ai_function(get_weather) + + # Should not raise + provider._validate_function_tools(assistant_tools, [wrapped]) # type: ignore[reportPrivateUsage] + + def test_validate_partial_tools_raises(self, mock_async_openai: MagicMock) -> None: + """Test that partial tool provision raises error.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant_tools = [ + create_function_tool("get_weather"), + create_function_tool("search_database"), + ] + + with pytest.raises(ValueError) as exc_info: + provider._validate_function_tools(assistant_tools, [get_weather]) # type: ignore[reportPrivateUsage] + + assert "search_database" in str(exc_info.value) + + +# endregion + + +# region Tool Merging Tests + + +class TestToolMerging: + """Tests for tool merging.""" + + def test_merge_code_interpreter(self, mock_async_openai: MagicMock) -> None: + """Test merging code interpreter tool.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant_tools = [create_code_interpreter_tool()] + + merged = provider._merge_tools(assistant_tools, None) # type: ignore[reportPrivateUsage] + + assert len(merged) == 1 + assert isinstance(merged[0], HostedCodeInterpreterTool) + 
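+    # These tests document the observed contract of the private _merge_tools helper:
+    # hosted assistant tool entries are mapped back to framework tool objects, and any
+    # user-supplied tools are appended after them (see test_merge_with_user_tools).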
+ def test_merge_file_search(self, mock_async_openai: MagicMock) -> None: + """Test merging file search tool.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant_tools = [create_file_search_tool()] + + merged = provider._merge_tools(assistant_tools, None) # type: ignore[reportPrivateUsage] + + assert len(merged) == 1 + assert isinstance(merged[0], HostedFileSearchTool) + + def test_merge_with_user_tools(self, mock_async_openai: MagicMock) -> None: + """Test merging hosted and user tools.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant_tools = [create_code_interpreter_tool()] + + merged = provider._merge_tools(assistant_tools, [get_weather]) # type: ignore[reportPrivateUsage] + + assert len(merged) == 2 + assert isinstance(merged[0], HostedCodeInterpreterTool) + + def test_merge_multiple_hosted_tools(self, mock_async_openai: MagicMock) -> None: + """Test merging multiple hosted tools.""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant_tools = [create_code_interpreter_tool(), create_file_search_tool()] + + merged = provider._merge_tools(assistant_tools, None) # type: ignore[reportPrivateUsage] + + assert len(merged) == 2 + + def test_merge_single_user_tool(self, mock_async_openai: MagicMock) -> None: + """Test merging with single user tool (not list).""" + provider = OpenAIAssistantProvider(mock_async_openai) + assistant_tools: list[Any] = [] + + merged = provider._merge_tools(assistant_tools, get_weather) # type: ignore[reportPrivateUsage] + + assert len(merged) == 1 + + +# endregion + + +# region Integration Tests + + +skip_if_openai_integration_tests_disabled = pytest.mark.skipif( + os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true" + or os.getenv("OPENAI_API_KEY", "") in ("", "test-dummy-key"), + reason="No real OPENAI_API_KEY provided; skipping integration tests." + if os.getenv("RUN_INTEGRATION_TESTS", "false").lower() == "true" + else "Integration tests are disabled.", +) + + +@skip_if_openai_integration_tests_disabled +class TestOpenAIAssistantProviderIntegration: + """Integration tests requiring real OpenAI API.""" + + async def test_create_and_run_agent(self) -> None: + """End-to-end test of creating and running an agent.""" + provider = OpenAIAssistantProvider() + + agent = await provider.create_agent( + name="IntegrationTestAgent", + model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"), + instructions="You are a helpful assistant. Respond briefly.", + ) + + try: + result = await agent.run("Say 'hello' and nothing else.") + result_text = str(result) + assert "hello" in result_text.lower() + finally: + # Clean up the assistant + await provider._client.beta.assistants.delete(agent.id) # type: ignore[reportPrivateUsage, union-attr] + + async def test_create_agent_with_function_tools_integration(self) -> None: + """Integration test with function tools.""" + provider = OpenAIAssistantProvider() + + def get_current_time() -> str: + """Get the current time.""" + from datetime import datetime + + return datetime.now().strftime("%H:%M") + + agent = await provider.create_agent( + name="TimeAgent", + model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"), + instructions="You are a helpful assistant.", + tools=[get_current_time], + ) + + try: + result = await agent.run("What time is it? 
Use the get_current_time function.") + result_text = str(result) + # The response should contain time information + assert ":" in result_text or "time" in result_text.lower() + finally: + await provider._client.beta.assistants.delete(agent.id) # type: ignore[reportPrivateUsage, union-attr] + + +# endregion diff --git a/python/packages/core/tests/openai/test_openai_assistants_client.py b/python/packages/core/tests/openai/test_openai_assistants_client.py index 861ccc73d1..424c1cc044 100644 --- a/python/packages/core/tests/openai/test_openai_assistants_client.py +++ b/python/packages/core/tests/openai/test_openai_assistants_client.py @@ -11,13 +11,12 @@ from pydantic import Field from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, ChatAgent, ChatClientProtocol, ChatMessage, - ChatOptions, ChatResponse, ChatResponseUpdate, FunctionCallContent, @@ -27,7 +26,6 @@ HostedVectorStoreContent, Role, TextContent, - ToolMode, UriContent, UsageContent, ai_function, @@ -43,6 +41,8 @@ else "Integration tests are disabled.", ) +INTEGRATION_TEST_MODEL = "gpt-4.1-nano" + def create_test_openai_assistants_client( mock_async_openai: MagicMock, @@ -117,7 +117,7 @@ def mock_async_openai() -> MagicMock: return mock_client -def test_openai_assistants_client_init_with_client(mock_async_openai: MagicMock) -> None: +def test_init_with_client(mock_async_openai: MagicMock) -> None: """Test OpenAIAssistantsClient initialization with existing client.""" chat_client = create_test_openai_assistants_client( mock_async_openai, model_id="gpt-4", assistant_id="existing-assistant-id", thread_id="test-thread-id" @@ -131,7 +131,7 @@ def test_openai_assistants_client_init_with_client(mock_async_openai: MagicMock) assert isinstance(chat_client, ChatClientProtocol) -def test_openai_assistants_client_init_auto_create_client( +def test_init_auto_create_client( openai_unit_test_env: dict[str, str], mock_async_openai: MagicMock, ) -> None: @@ -151,7 +151,7 @@ def test_openai_assistants_client_init_auto_create_client( assert not chat_client._should_delete_assistant # type: ignore -def test_openai_assistants_client_init_validation_fail() -> None: +def test_init_validation_fail() -> None: """Test OpenAIAssistantsClient initialization with validation failure.""" with pytest.raises(ServiceInitializationError): # Force failure by providing invalid model ID type - this should cause validation to fail @@ -159,7 +159,7 @@ def test_openai_assistants_client_init_validation_fail() -> None: @pytest.mark.parametrize("exclude_list", [["OPENAI_CHAT_MODEL_ID"]], indirect=True) -def test_openai_assistants_client_init_missing_model_id(openai_unit_test_env: dict[str, str]) -> None: +def test_init_missing_model_id(openai_unit_test_env: dict[str, str]) -> None: """Test OpenAIAssistantsClient initialization with missing model ID.""" with pytest.raises(ServiceInitializationError): OpenAIAssistantsClient( @@ -168,13 +168,13 @@ def test_openai_assistants_client_init_missing_model_id(openai_unit_test_env: di @pytest.mark.parametrize("exclude_list", [["OPENAI_API_KEY"]], indirect=True) -def test_openai_assistants_client_init_missing_api_key(openai_unit_test_env: dict[str, str]) -> None: +def test_init_missing_api_key(openai_unit_test_env: dict[str, str]) -> None: """Test OpenAIAssistantsClient initialization with missing API key.""" with pytest.raises(ServiceInitializationError): OpenAIAssistantsClient(model_id="gpt-4", env_file_path="nonexistent.env") -def 
test_openai_assistants_client_init_with_default_headers(openai_unit_test_env: dict[str, str]) -> None: +def test_init_with_default_headers(openai_unit_test_env: dict[str, str]) -> None: """Test OpenAIAssistantsClient initialization with default headers.""" default_headers = {"X-Unit-Test": "test-guid"} @@ -193,7 +193,7 @@ def test_openai_assistants_client_init_with_default_headers(openai_unit_test_env assert chat_client.client.default_headers[key] == value -async def test_openai_assistants_client_get_assistant_id_or_create_existing_assistant( +async def test_get_assistant_id_or_create_existing_assistant( mock_async_openai: MagicMock, ) -> None: """Test _get_assistant_id_or_create when assistant_id is already provided.""" @@ -206,7 +206,7 @@ async def test_openai_assistants_client_get_assistant_id_or_create_existing_assi mock_async_openai.beta.assistants.create.assert_not_called() -async def test_openai_assistants_client_get_assistant_id_or_create_create_new( +async def test_get_assistant_id_or_create_create_new( mock_async_openai: MagicMock, ) -> None: """Test _get_assistant_id_or_create when creating a new assistant.""" @@ -221,7 +221,7 @@ async def test_openai_assistants_client_get_assistant_id_or_create_create_new( mock_async_openai.beta.assistants.create.assert_called_once() -async def test_openai_assistants_client_aclose_should_not_delete( +async def test_aclose_should_not_delete( mock_async_openai: MagicMock, ) -> None: """Test close when assistant should not be deleted.""" @@ -236,7 +236,7 @@ async def test_openai_assistants_client_aclose_should_not_delete( assert not chat_client._should_delete_assistant # type: ignore -async def test_openai_assistants_client_aclose_should_delete(mock_async_openai: MagicMock) -> None: +async def test_aclose_should_delete(mock_async_openai: MagicMock) -> None: """Test close method calls cleanup.""" chat_client = create_test_openai_assistants_client( mock_async_openai, assistant_id="assistant-to-delete", should_delete_assistant=True @@ -249,7 +249,7 @@ async def test_openai_assistants_client_aclose_should_delete(mock_async_openai: assert not chat_client._should_delete_assistant # type: ignore -async def test_openai_assistants_client_async_context_manager(mock_async_openai: MagicMock) -> None: +async def test_async_context_manager(mock_async_openai: MagicMock) -> None: """Test async context manager functionality.""" chat_client = create_test_openai_assistants_client( mock_async_openai, assistant_id="assistant-to-delete", should_delete_assistant=True @@ -263,7 +263,7 @@ async def test_openai_assistants_client_async_context_manager(mock_async_openai: mock_async_openai.beta.assistants.delete.assert_called_once_with("assistant-to-delete") -def test_openai_assistants_client_serialize(openai_unit_test_env: dict[str, str]) -> None: +def test_serialize(openai_unit_test_env: dict[str, str]) -> None: """Test serialization of OpenAIAssistantsClient.""" default_headers = {"X-Unit-Test": "test-guid"} @@ -294,7 +294,7 @@ def test_openai_assistants_client_serialize(openai_unit_test_env: dict[str, str] assert "User-Agent" not in dumped_settings["default_headers"] -async def test_openai_assistants_client_get_active_thread_run_none_thread_id(mock_async_openai: MagicMock) -> None: +async def test_get_active_thread_run_none_thread_id(mock_async_openai: MagicMock) -> None: """Test _get_active_thread_run with None thread_id returns None.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -305,7 +305,7 @@ async def 
test_openai_assistants_client_get_active_thread_run_none_thread_id(moc mock_async_openai.beta.threads.runs.list.assert_not_called() -async def test_openai_assistants_client_get_active_thread_run_with_active_run(mock_async_openai: MagicMock) -> None: +async def test_get_active_thread_run_with_active_run(mock_async_openai: MagicMock) -> None: """Test _get_active_thread_run finds an active run.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -326,7 +326,7 @@ async def mock_runs_list(*args: Any, **kwargs: Any) -> Any: mock_async_openai.beta.threads.runs.list.assert_called_once_with(thread_id="thread-123", limit=1, order="desc") -async def test_openai_assistants_client_prepare_thread_create_new(mock_async_openai: MagicMock) -> None: +async def test_prepare_thread_create_new(mock_async_openai: MagicMock) -> None: """Test _prepare_thread creates new thread when thread_id is None.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -353,7 +353,7 @@ async def test_openai_assistants_client_prepare_thread_create_new(mock_async_ope ) -async def test_openai_assistants_client_prepare_thread_cancel_existing_run(mock_async_openai: MagicMock) -> None: +async def test_prepare_thread_cancel_existing_run(mock_async_openai: MagicMock) -> None: """Test _prepare_thread cancels existing run when provided.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -369,7 +369,7 @@ async def test_openai_assistants_client_prepare_thread_cancel_existing_run(mock_ mock_async_openai.beta.threads.runs.cancel.assert_called_once_with(run_id="run-456", thread_id="thread-123") -async def test_openai_assistants_client_prepare_thread_existing_no_run(mock_async_openai: MagicMock) -> None: +async def test_prepare_thread_existing_no_run(mock_async_openai: MagicMock) -> None: """Test _prepare_thread with existing thread_id but no active run.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -382,7 +382,7 @@ async def test_openai_assistants_client_prepare_thread_existing_no_run(mock_asyn mock_async_openai.beta.threads.runs.cancel.assert_not_called() -async def test_openai_assistants_client_process_stream_events_thread_run_created(mock_async_openai: MagicMock) -> None: +async def test_process_stream_events_thread_run_created(mock_async_openai: MagicMock) -> None: """Test _process_stream_events with thread.run.created event.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -415,7 +415,7 @@ async def async_iterator() -> Any: assert update.raw_representation == mock_response.data -async def test_openai_assistants_client_process_stream_events_message_delta_text(mock_async_openai: MagicMock) -> None: +async def test_process_stream_events_message_delta_text(mock_async_openai: MagicMock) -> None: """Test _process_stream_events with thread.message.delta event containing text.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -459,7 +459,7 @@ async def async_iterator() -> Any: assert update.raw_representation == mock_message_delta -async def test_openai_assistants_client_process_stream_events_requires_action(mock_async_openai: MagicMock) -> None: +async def test_process_stream_events_requires_action(mock_async_openai: MagicMock) -> None: """Test _process_stream_events with thread.run.requires_action event.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -502,7 +502,7 @@ async def async_iterator() -> Any: 
chat_client._parse_function_calls_from_assistants.assert_called_once_with(mock_run, None) # type: ignore -async def test_openai_assistants_client_process_stream_events_run_step_created(mock_async_openai: MagicMock) -> None: +async def test_process_stream_events_run_step_created(mock_async_openai: MagicMock) -> None: """Test _process_stream_events with thread.run.step.created event.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -534,7 +534,7 @@ async def async_iterator() -> Any: assert len(updates) == 0 -async def test_openai_assistants_client_process_stream_events_run_completed_with_usage( +async def test_process_stream_events_run_completed_with_usage( mock_async_openai: MagicMock, ) -> None: """Test _process_stream_events with thread.run.completed event containing usage.""" @@ -585,7 +585,7 @@ async def async_iterator() -> Any: assert update.raw_representation == mock_run -def test_openai_assistants_client_parse_function_calls_from_assistants_basic(mock_async_openai: MagicMock) -> None: +def test_parse_function_calls_from_assistants_basic(mock_async_openai: MagicMock) -> None: """Test _parse_function_calls_from_assistants with a simple function call.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -614,22 +614,22 @@ def test_openai_assistants_client_parse_function_calls_from_assistants_basic(moc assert contents[0].arguments == {"location": "Seattle"} -def test_openai_assistants_client_prepare_options_basic(mock_async_openai: MagicMock) -> None: +def test_prepare_options_basic(mock_async_openai: MagicMock) -> None: """Test _prepare_options with basic chat options.""" chat_client = create_test_openai_assistants_client(mock_async_openai) - # Create basic chat options - chat_options = ChatOptions( - max_tokens=100, - model_id="gpt-4", - temperature=0.7, - top_p=0.9, - ) + # Create basic chat options as a dict + options = { + "max_tokens": 100, + "model_id": "gpt-4", + "temperature": 0.7, + "top_p": 0.9, + } messages = [ChatMessage(role=Role.USER, text="Hello")] # Call the method - run_options, tool_results = chat_client._prepare_options(messages, chat_options) # type: ignore + run_options, tool_results = chat_client._prepare_options(messages, options) # type: ignore # Check basic options were set assert run_options["max_completion_tokens"] == 100 @@ -639,7 +639,7 @@ def test_openai_assistants_client_prepare_options_basic(mock_async_openai: Magic assert tool_results is None -def test_openai_assistants_client_prepare_options_with_ai_function_tool(mock_async_openai: MagicMock) -> None: +def test_prepare_options_with_ai_function_tool(mock_async_openai: MagicMock) -> None: """Test _prepare_options with AIFunction tool.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -650,15 +650,15 @@ def test_function(query: str) -> str: """A test function.""" return f"Result for {query}" - chat_options = ChatOptions( - tools=[test_function], - tool_choice="auto", - ) + options = { + "tools": [test_function], + "tool_choice": "auto", + } messages = [ChatMessage(role=Role.USER, text="Hello")] # Call the method - run_options, tool_results = chat_client._prepare_options(messages, chat_options) # type: ignore + run_options, tool_results = chat_client._prepare_options(messages, options) # type: ignore # Check tools were set correctly assert "tools" in run_options @@ -668,22 +668,22 @@ def test_function(query: str) -> str: assert run_options["tool_choice"] == "auto" -def 
test_openai_assistants_client_prepare_options_with_code_interpreter(mock_async_openai: MagicMock) -> None: +def test_prepare_options_with_code_interpreter(mock_async_openai: MagicMock) -> None: """Test _prepare_options with HostedCodeInterpreterTool.""" chat_client = create_test_openai_assistants_client(mock_async_openai) # Create a real HostedCodeInterpreterTool code_tool = HostedCodeInterpreterTool() - chat_options = ChatOptions( - tools=[code_tool], - tool_choice="auto", - ) + options = { + "tools": [code_tool], + "tool_choice": "auto", + } messages = [ChatMessage(role=Role.USER, text="Calculate something")] # Call the method - run_options, tool_results = chat_client._prepare_options(messages, chat_options) # type: ignore + run_options, tool_results = chat_client._prepare_options(messages, options) # type: ignore # Check code interpreter tool was set correctly assert "tools" in run_options @@ -692,39 +692,39 @@ def test_openai_assistants_client_prepare_options_with_code_interpreter(mock_asy assert run_options["tool_choice"] == "auto" -def test_openai_assistants_client_prepare_options_tool_choice_none(mock_async_openai: MagicMock) -> None: +def test_prepare_options_tool_choice_none(mock_async_openai: MagicMock) -> None: """Test _prepare_options with tool_choice set to 'none'.""" chat_client = create_test_openai_assistants_client(mock_async_openai) - chat_options = ChatOptions( - tool_choice="none", - ) + options = { + "tool_choice": "none", + } messages = [ChatMessage(role=Role.USER, text="Hello")] # Call the method - run_options, tool_results = chat_client._prepare_options(messages, chat_options) # type: ignore + run_options, tool_results = chat_client._prepare_options(messages, options) # type: ignore # Should set tool_choice to none and not include tools assert run_options["tool_choice"] == "none" assert "tools" not in run_options -def test_openai_assistants_client_prepare_options_required_function(mock_async_openai: MagicMock) -> None: +def test_prepare_options_required_function(mock_async_openai: MagicMock) -> None: """Test _prepare_options with required function tool choice.""" chat_client = create_test_openai_assistants_client(mock_async_openai) - # Create a required function tool choice - tool_choice = ToolMode(mode="required", required_function_name="specific_function") + # Create a required function tool choice as dict + tool_choice = {"mode": "required", "required_function_name": "specific_function"} - chat_options = ChatOptions( - tool_choice=tool_choice, - ) + options = { + "tool_choice": tool_choice, + } messages = [ChatMessage(role=Role.USER, text="Hello")] # Call the method - run_options, tool_results = chat_client._prepare_options(messages, chat_options) # type: ignore + run_options, tool_results = chat_client._prepare_options(messages, options) # type: ignore # Check required function tool choice was set correctly expected_tool_choice = { @@ -734,7 +734,7 @@ def test_openai_assistants_client_prepare_options_required_function(mock_async_o assert run_options["tool_choice"] == expected_tool_choice -def test_openai_assistants_client_prepare_options_with_file_search_tool(mock_async_openai: MagicMock) -> None: +def test_prepare_options_with_file_search_tool(mock_async_openai: MagicMock) -> None: """Test _prepare_options with HostedFileSearchTool.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -742,15 +742,15 @@ def test_openai_assistants_client_prepare_options_with_file_search_tool(mock_asy # Create a HostedFileSearchTool with max_results 
file_search_tool = HostedFileSearchTool(max_results=10) - chat_options = ChatOptions( - tools=[file_search_tool], - tool_choice="auto", - ) + options = { + "tools": [file_search_tool], + "tool_choice": "auto", + } messages = [ChatMessage(role=Role.USER, text="Search for information")] # Call the method - run_options, tool_results = chat_client._prepare_options(messages, chat_options) # type: ignore + run_options, tool_results = chat_client._prepare_options(messages, options) # type: ignore # Check file search tool was set correctly assert "tools" in run_options @@ -760,22 +760,22 @@ def test_openai_assistants_client_prepare_options_with_file_search_tool(mock_asy assert run_options["tool_choice"] == "auto" -def test_openai_assistants_client_prepare_options_with_mapping_tool(mock_async_openai: MagicMock) -> None: +def test_prepare_options_with_mapping_tool(mock_async_openai: MagicMock) -> None: """Test _prepare_options with MutableMapping tool.""" chat_client = create_test_openai_assistants_client(mock_async_openai) # Create a tool as a MutableMapping (dict) mapping_tool = {"type": "custom_tool", "parameters": {"setting": "value"}} - chat_options = ChatOptions( - tools=[mapping_tool], # type: ignore - tool_choice="auto", - ) + options = { + "tools": [mapping_tool], # type: ignore + "tool_choice": "auto", + } messages = [ChatMessage(role=Role.USER, text="Use custom tool")] # Call the method - run_options, tool_results = chat_client._prepare_options(messages, chat_options) # type: ignore + run_options, tool_results = chat_client._prepare_options(messages, options) # type: ignore # Check mapping tool was set correctly assert "tools" in run_options @@ -784,7 +784,28 @@ def test_openai_assistants_client_prepare_options_with_mapping_tool(mock_async_o assert run_options["tool_choice"] == "auto" -def test_openai_assistants_client_prepare_options_with_system_message(mock_async_openai: MagicMock) -> None: +def test_prepare_options_with_pydantic_response_format(mock_async_openai: MagicMock) -> None: + """Test _prepare_options sets strict=True for Pydantic response_format.""" + from pydantic import BaseModel, ConfigDict + + class TestResponse(BaseModel): + name: str + value: int + model_config = ConfigDict(extra="forbid") + + chat_client = create_test_openai_assistants_client(mock_async_openai) + messages = [ChatMessage(role=Role.USER, text="Test")] + options = {"response_format": TestResponse} + + run_options, _ = chat_client._prepare_options(messages, options) # type: ignore + + assert "response_format" in run_options + assert run_options["response_format"]["type"] == "json_schema" + assert run_options["response_format"]["json_schema"]["name"] == "TestResponse" + assert run_options["response_format"]["json_schema"]["strict"] is True + + +def test_prepare_options_with_system_message(mock_async_openai: MagicMock) -> None: """Test _prepare_options with system message converted to instructions.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -794,7 +815,7 @@ def test_openai_assistants_client_prepare_options_with_system_message(mock_async ] # Call the method - run_options, tool_results = chat_client._prepare_options(messages, None) # type: ignore + run_options, tool_results = chat_client._prepare_options(messages, {}) # type: ignore # Check that additional_messages only contains the user message # System message should be converted to instructions (though this is handled internally) @@ -803,7 +824,7 @@ def 
test_openai_assistants_client_prepare_options_with_system_message(mock_async assert run_options["additional_messages"][0]["role"] == "user" -def test_openai_assistants_client_prepare_options_with_image_content(mock_async_openai: MagicMock) -> None: +def test_prepare_options_with_image_content(mock_async_openai: MagicMock) -> None: """Test _prepare_options with image content.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -813,7 +834,7 @@ def test_openai_assistants_client_prepare_options_with_image_content(mock_async_ messages = [ChatMessage(role=Role.USER, contents=[image_content])] # Call the method - run_options, tool_results = chat_client._prepare_options(messages, None) # type: ignore + run_options, tool_results = chat_client._prepare_options(messages, {}) # type: ignore # Check that image content was processed assert "additional_messages" in run_options @@ -825,7 +846,7 @@ def test_openai_assistants_client_prepare_options_with_image_content(mock_async_ assert message["content"][0]["image_url"]["url"] == "https://example.com/image.jpg" -def test_openai_assistants_client_prepare_tool_outputs_for_assistants_empty(mock_async_openai: MagicMock) -> None: +def test_prepare_tool_outputs_for_assistants_empty(mock_async_openai: MagicMock) -> None: """Test _prepare_tool_outputs_for_assistants with empty list.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -835,7 +856,7 @@ def test_openai_assistants_client_prepare_tool_outputs_for_assistants_empty(mock assert tool_outputs is None -def test_openai_assistants_client_prepare_tool_outputs_for_assistants_valid(mock_async_openai: MagicMock) -> None: +def test_prepare_tool_outputs_for_assistants_valid(mock_async_openai: MagicMock) -> None: """Test _prepare_tool_outputs_for_assistants with valid function results.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -851,7 +872,7 @@ def test_openai_assistants_client_prepare_tool_outputs_for_assistants_valid(mock assert tool_outputs[0].get("output") == "Function executed successfully" -def test_openai_assistants_client_prepare_tool_outputs_for_assistants_mismatched_run_ids( +def test_prepare_tool_outputs_for_assistants_mismatched_run_ids( mock_async_openai: MagicMock, ) -> None: """Test _prepare_tool_outputs_for_assistants with mismatched run IDs.""" @@ -872,7 +893,7 @@ def test_openai_assistants_client_prepare_tool_outputs_for_assistants_mismatched assert tool_outputs[0].get("tool_call_id") == "call-456" -def test_openai_assistants_client_update_agent_name_and_description(mock_async_openai: MagicMock) -> None: +def test_update_agent_name_and_description(mock_async_openai: MagicMock) -> None: """Test _update_agent_name_and_description method updates assistant_name when not already set.""" # Test updating agent name when assistant_name is None chat_client = create_test_openai_assistants_client(mock_async_openai, assistant_name=None) @@ -883,7 +904,7 @@ def test_openai_assistants_client_update_agent_name_and_description(mock_async_o assert chat_client.assistant_name == "New Assistant Name" -def test_openai_assistants_client_update_agent_name_and_description_existing(mock_async_openai: MagicMock) -> None: +def test_update_agent_name_and_description_existing(mock_async_openai: MagicMock) -> None: """Test _update_agent_name_and_description method doesn't override existing assistant_name.""" # Test that existing assistant_name is not overridden chat_client = create_test_openai_assistants_client(mock_async_openai, 
assistant_name="Existing Assistant") @@ -895,7 +916,7 @@ def test_openai_assistants_client_update_agent_name_and_description_existing(moc assert chat_client.assistant_name == "Existing Assistant" -def test_openai_assistants_client_update_agent_name_and_description_none(mock_async_openai: MagicMock) -> None: +def test_update_agent_name_and_description_none(mock_async_openai: MagicMock) -> None: """Test _update_agent_name_and_description method with None agent_name parameter.""" # Test that None agent_name doesn't change anything chat_client = create_test_openai_assistants_client(mock_async_openai, assistant_name=None) @@ -916,9 +937,9 @@ def get_weather( @pytest.mark.flaky @skip_if_openai_integration_tests_disabled -async def test_openai_assistants_client_get_response() -> None: +async def test_get_response() -> None: """Test OpenAI Assistants Client response.""" - async with OpenAIAssistantsClient() as openai_assistants_client: + async with OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL) as openai_assistants_client: assert isinstance(openai_assistants_client, ChatClientProtocol) messages: list[ChatMessage] = [] @@ -941,9 +962,9 @@ async def test_openai_assistants_client_get_response() -> None: @pytest.mark.flaky @skip_if_openai_integration_tests_disabled -async def test_openai_assistants_client_get_response_tools() -> None: +async def test_get_response_tools() -> None: """Test OpenAI Assistants Client response with tools.""" - async with OpenAIAssistantsClient() as openai_assistants_client: + async with OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL) as openai_assistants_client: assert isinstance(openai_assistants_client, ChatClientProtocol) messages: list[ChatMessage] = [] @@ -952,8 +973,7 @@ async def test_openai_assistants_client_get_response_tools() -> None: # Test that the client can be used to get a response response = await openai_assistants_client.get_response( messages=messages, - tools=[get_weather], - tool_choice="auto", + options={"tools": [get_weather], "tool_choice": "auto"}, ) assert response is not None @@ -963,9 +983,9 @@ async def test_openai_assistants_client_get_response_tools() -> None: @pytest.mark.flaky @skip_if_openai_integration_tests_disabled -async def test_openai_assistants_client_streaming() -> None: +async def test_streaming() -> None: """Test OpenAI Assistants Client streaming response.""" - async with OpenAIAssistantsClient() as openai_assistants_client: + async with OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL) as openai_assistants_client: assert isinstance(openai_assistants_client, ChatClientProtocol) messages: list[ChatMessage] = [] @@ -994,9 +1014,9 @@ async def test_openai_assistants_client_streaming() -> None: @pytest.mark.flaky @skip_if_openai_integration_tests_disabled -async def test_openai_assistants_client_streaming_tools() -> None: +async def test_streaming_tools() -> None: """Test OpenAI Assistants Client streaming response with tools.""" - async with OpenAIAssistantsClient() as openai_assistants_client: + async with OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL) as openai_assistants_client: assert isinstance(openai_assistants_client, ChatClientProtocol) messages: list[ChatMessage] = [] @@ -1005,8 +1025,10 @@ async def test_openai_assistants_client_streaming_tools() -> None: # Test that the client can be used to get a response response = openai_assistants_client.get_streaming_response( messages=messages, - tools=[get_weather], - tool_choice="auto", + options={ + "tools": [get_weather], + "tool_choice": "auto", 
+ }, ) full_message: str = "" async for chunk in response: @@ -1021,10 +1043,10 @@ async def test_openai_assistants_client_streaming_tools() -> None: @pytest.mark.flaky @skip_if_openai_integration_tests_disabled -async def test_openai_assistants_client_with_existing_assistant() -> None: +async def test_with_existing_assistant() -> None: """Test OpenAI Assistants Client with existing assistant ID.""" # First create an assistant to use in the test - async with OpenAIAssistantsClient() as temp_client: + async with OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL) as temp_client: # Get the assistant ID by triggering assistant creation messages = [ChatMessage(role="user", text="Hello")] await temp_client.get_response(messages=messages) @@ -1032,7 +1054,7 @@ async def test_openai_assistants_client_with_existing_assistant() -> None: # Now test using the existing assistant async with OpenAIAssistantsClient( - model_id="gpt-4o-mini", assistant_id=assistant_id + model_id=INTEGRATION_TEST_MODEL, assistant_id=assistant_id ) as openai_assistants_client: assert isinstance(openai_assistants_client, ChatClientProtocol) assert openai_assistants_client.assistant_id == assistant_id @@ -1050,9 +1072,9 @@ async def test_openai_assistants_client_with_existing_assistant() -> None: @pytest.mark.flaky @skip_if_openai_integration_tests_disabled @pytest.mark.skip(reason="OpenAI file search functionality is currently broken - tracked in GitHub issue") -async def test_openai_assistants_client_file_search() -> None: +async def test_file_search() -> None: """Test OpenAI Assistants Client response.""" - async with OpenAIAssistantsClient() as openai_assistants_client: + async with OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL) as openai_assistants_client: assert isinstance(openai_assistants_client, ChatClientProtocol) messages: list[ChatMessage] = [] @@ -1061,8 +1083,10 @@ async def test_openai_assistants_client_file_search() -> None: file_id, vector_store = await create_vector_store(openai_assistants_client) response = await openai_assistants_client.get_response( messages=messages, - tools=[HostedFileSearchTool()], - tool_resources={"file_search": {"vector_store_ids": [vector_store.vector_store_id]}}, + options={ + "tools": [HostedFileSearchTool()], + "tool_resources": {"file_search": {"vector_store_ids": [vector_store.vector_store_id]}}, + }, ) await delete_vector_store(openai_assistants_client, file_id, vector_store.vector_store_id) @@ -1074,9 +1098,9 @@ async def test_openai_assistants_client_file_search() -> None: @pytest.mark.flaky @skip_if_openai_integration_tests_disabled @pytest.mark.skip(reason="OpenAI file search functionality is currently broken - tracked in GitHub issue") -async def test_openai_assistants_client_file_search_streaming() -> None: +async def test_file_search_streaming() -> None: """Test OpenAI Assistants Client response.""" - async with OpenAIAssistantsClient() as openai_assistants_client: + async with OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL) as openai_assistants_client: assert isinstance(openai_assistants_client, ChatClientProtocol) messages: list[ChatMessage] = [] @@ -1085,8 +1109,10 @@ async def test_openai_assistants_client_file_search_streaming() -> None: file_id, vector_store = await create_vector_store(openai_assistants_client) response = openai_assistants_client.get_streaming_response( messages=messages, - tools=[HostedFileSearchTool()], - tool_resources={"file_search": {"vector_store_ids": [vector_store.vector_store_id]}}, + options={ + "tools": 
[HostedFileSearchTool()], + "tool_resources": {"file_search": {"vector_store_ids": [vector_store.vector_store_id]}}, + }, ) assert response is not None @@ -1107,13 +1133,13 @@ async def test_openai_assistants_client_file_search_streaming() -> None: async def test_openai_assistants_agent_basic_run(): """Test ChatAgent basic run functionality with OpenAIAssistantsClient.""" async with ChatAgent( - chat_client=OpenAIAssistantsClient(), + chat_client=OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL), ) as agent: # Run a simple query response = await agent.run("Hello! Please respond with 'Hello World' exactly.") # Validate response - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert response.text is not None assert len(response.text) > 0 assert "Hello World" in response.text @@ -1124,13 +1150,13 @@ async def test_openai_assistants_agent_basic_run(): async def test_openai_assistants_agent_basic_run_streaming(): """Test ChatAgent basic streaming functionality with OpenAIAssistantsClient.""" async with ChatAgent( - chat_client=OpenAIAssistantsClient(), + chat_client=OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL), ) as agent: # Run streaming query full_message: str = "" async for chunk in agent.run_stream("Please respond with exactly: 'This is a streaming response test.'"): assert chunk is not None - assert isinstance(chunk, AgentRunResponseUpdate) + assert isinstance(chunk, AgentResponseUpdate) if chunk.text: full_message += chunk.text @@ -1144,7 +1170,7 @@ async def test_openai_assistants_agent_basic_run_streaming(): async def test_openai_assistants_agent_thread_persistence(): """Test ChatAgent thread persistence across runs with OpenAIAssistantsClient.""" async with ChatAgent( - chat_client=OpenAIAssistantsClient(), + chat_client=OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL), instructions="You are a helpful assistant with good memory.", ) as agent: # Create a new thread that will be reused @@ -1154,14 +1180,14 @@ async def test_openai_assistants_agent_thread_persistence(): first_response = await agent.run( "Remember this number: 42. 
What number did I just tell you to remember?", thread=thread ) - assert isinstance(first_response, AgentRunResponse) + assert isinstance(first_response, AgentResponse) assert "42" in first_response.text # Second message - test conversation memory second_response = await agent.run( "What number did I tell you to remember in my previous message?", thread=thread ) - assert isinstance(second_response, AgentRunResponse) + assert isinstance(second_response, AgentResponse) assert "42" in second_response.text # Verify thread has been populated with conversation ID @@ -1176,7 +1202,7 @@ async def test_openai_assistants_agent_existing_thread_id(): existing_thread_id = None async with ChatAgent( - chat_client=OpenAIAssistantsClient(), + chat_client=OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL), instructions="You are a helpful weather agent.", tools=[get_weather], ) as agent: @@ -1185,7 +1211,7 @@ async def test_openai_assistants_agent_existing_thread_id(): response1 = await agent.run("What's the weather in Paris?", thread=thread) # Validate first response - assert isinstance(response1, AgentRunResponse) + assert isinstance(response1, AgentResponse) assert response1.text is not None assert any(word in response1.text.lower() for word in ["weather", "paris"]) @@ -1207,7 +1233,7 @@ async def test_openai_assistants_agent_existing_thread_id(): response2 = await agent.run("What was the last city I asked about?", thread=thread) # Validate that the agent remembers the previous conversation - assert isinstance(response2, AgentRunResponse) + assert isinstance(response2, AgentResponse) assert response2.text is not None # Should reference Paris from the previous conversation assert "paris" in response2.text.lower() @@ -1219,7 +1245,7 @@ async def test_openai_assistants_agent_code_interpreter(): """Test ChatAgent with code interpreter through OpenAIAssistantsClient.""" async with ChatAgent( - chat_client=OpenAIAssistantsClient(), + chat_client=OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL), instructions="You are a helpful assistant that can write and execute Python code.", tools=[HostedCodeInterpreterTool()], ) as agent: @@ -1227,7 +1253,7 @@ async def test_openai_assistants_agent_code_interpreter(): response = await agent.run("Write Python code to calculate the factorial of 5 and show the result.") # Validate response - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert response.text is not None # Factorial of 5 is 120 assert "120" in response.text or "factorial" in response.text.lower() @@ -1235,18 +1261,18 @@ async def test_openai_assistants_agent_code_interpreter(): @pytest.mark.flaky @skip_if_openai_integration_tests_disabled -async def test_openai_assistants_client_agent_level_tool_persistence(): +async def test_agent_level_tool_persistence(): """Test that agent-level tools persist across multiple runs with OpenAI Assistants Client.""" async with ChatAgent( - chat_client=OpenAIAssistantsClient(), + chat_client=OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL), instructions="You are a helpful assistant that uses available tools.", tools=[get_weather], # Agent-level tool ) as agent: # First run - agent-level tool should be available first_response = await agent.run("What's the weather like in Chicago?") - assert isinstance(first_response, AgentRunResponse) + assert isinstance(first_response, AgentResponse) assert first_response.text is not None # Should use the agent-level weather tool assert any(term in first_response.text.lower() for 
term in ["chicago", "sunny", "72"]) @@ -1254,14 +1280,14 @@ async def test_openai_assistants_client_agent_level_tool_persistence(): # Second run - agent-level tool should still be available (persistence test) second_response = await agent.run("What's the weather in Miami?") - assert isinstance(second_response, AgentRunResponse) + assert isinstance(second_response, AgentResponse) assert second_response.text is not None # Should use the agent-level weather tool again assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"]) # Callable API Key Tests -def test_openai_assistants_client_with_callable_api_key() -> None: +def test_with_callable_api_key() -> None: """Test OpenAIAssistantsClient initialization with callable API key.""" async def get_api_key() -> str: diff --git a/python/packages/core/tests/openai/test_openai_chat_client.py b/python/packages/core/tests/openai/test_openai_chat_client.py index 18854799fd..1f1d624345 100644 --- a/python/packages/core/tests/openai/test_openai_chat_client.py +++ b/python/packages/core/tests/openai/test_openai_chat_client.py @@ -1,25 +1,22 @@ # Copyright (c) Microsoft. All rights reserved. +import json import os -from typing import Annotated +from typing import Any from unittest.mock import MagicMock, patch import pytest from openai import BadRequestError +from pydantic import BaseModel +from pytest import param from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, - ChatAgent, ChatClientProtocol, ChatMessage, - ChatOptions, ChatResponse, - ChatResponseUpdate, DataContent, FunctionResultContent, HostedWebSearchTool, - TextContent, ToolProtocol, ai_function, prepare_function_call_results, @@ -170,7 +167,7 @@ async def test_content_filter_exception_handling(openai_unit_test_env: dict[str, patch.object(client.client.chat.completions, "create", side_effect=mock_error), pytest.raises(OpenAIContentFilterException), ): - await client._inner_get_response(messages=messages, chat_options=ChatOptions()) # type: ignore + await client._inner_get_response(messages=messages, options={}) # type: ignore def test_unsupported_tool_handling(openai_unit_test_env: dict[str, str]) -> None: @@ -183,12 +180,12 @@ def test_unsupported_tool_handling(openai_unit_test_env: dict[str, str]) -> None # This should ignore the unsupported ToolProtocol and return empty list result = client._prepare_tools_for_openai([unsupported_tool]) # type: ignore - assert result == [] + assert result == {} # Also test with a non-ToolProtocol that should be converted to dict dict_tool = {"type": "function", "name": "test"} result = client._prepare_tools_for_openai([dict_tool]) # type: ignore - assert result == [dict_tool] + assert result["tools"] == [dict_tool] @ai_function @@ -208,407 +205,6 @@ def get_weather(location: str) -> str: return f"The weather in {location} is sunny and 72°F." -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_completion_response() -> None: - """Test OpenAI chat completion responses.""" - openai_chat_client = OpenAIChatClient() - - assert isinstance(openai_chat_client, ChatClientProtocol) - - messages: list[ChatMessage] = [] - messages.append( - ChatMessage( - role="user", - text="Emily and David, two passionate scientists, met during a research expedition to Antarctica. 
" - "Bonded by their love for the natural world and shared curiosity, they uncovered a " - "groundbreaking phenomenon in glaciology that could potentially reshape our understanding " - "of climate change.", - ) - ) - messages.append(ChatMessage(role="user", text="who are Emily and David?")) - - # Test that the client can be used to get a response - response = await openai_chat_client.get_response(messages=messages) - - assert response is not None - assert isinstance(response, ChatResponse) - assert "scientists" in response.text - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_completion_response_params() -> None: - """Test OpenAI chat completion responses.""" - openai_chat_client = OpenAIChatClient() - - assert isinstance(openai_chat_client, ChatClientProtocol) - - messages: list[ChatMessage] = [] - messages.append( - ChatMessage( - role="user", - text="Emily and David, two passionate scientists, met during a research expedition to Antarctica. " - "Bonded by their love for the natural world and shared curiosity, they uncovered a " - "groundbreaking phenomenon in glaciology that could potentially reshape our understanding " - "of climate change.", - ) - ) - messages.append(ChatMessage(role="user", text="who are Emily and David?")) - - # Test that the client can be used to get a response - response = await openai_chat_client.get_response( - messages=messages, chat_options=ChatOptions(max_tokens=150, temperature=0.7, top_p=0.9) - ) - - assert response is not None - assert isinstance(response, ChatResponse) - assert "scientists" in response.text - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_completion_response_tools() -> None: - """Test OpenAI chat completion responses.""" - openai_chat_client = OpenAIChatClient() - - assert isinstance(openai_chat_client, ChatClientProtocol) - - messages: list[ChatMessage] = [] - messages.append(ChatMessage(role="user", text="who are Emily and David?")) - - # Test that the client can be used to get a response - response = await openai_chat_client.get_response( - messages=messages, - tools=[get_story_text], - tool_choice="auto", - ) - - assert response is not None - assert isinstance(response, ChatResponse) - assert "scientists" in response.text - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_client_streaming() -> None: - """Test Azure OpenAI chat completion responses.""" - openai_chat_client = OpenAIChatClient() - - assert isinstance(openai_chat_client, ChatClientProtocol) - - messages: list[ChatMessage] = [] - messages.append( - ChatMessage( - role="user", - text="Emily and David, two passionate scientists, met during a research expedition to Antarctica. 
" - "Bonded by their love for the natural world and shared curiosity, they uncovered a " - "groundbreaking phenomenon in glaciology that could potentially reshape our understanding " - "of climate change.", - ) - ) - messages.append(ChatMessage(role="user", text="who are Emily and David?")) - - # Test that the client can be used to get a response - response = openai_chat_client.get_streaming_response(messages=messages) - - full_message: str = "" - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - assert chunk.message_id is not None - assert chunk.response_id is not None - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - - assert "scientists" in full_message - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_client_streaming_tools() -> None: - """Test AzureOpenAI chat completion responses.""" - openai_chat_client = OpenAIChatClient() - - assert isinstance(openai_chat_client, ChatClientProtocol) - - messages: list[ChatMessage] = [] - messages.append(ChatMessage(role="user", text="who are Emily and David?")) - - # Test that the client can be used to get a response - response = openai_chat_client.get_streaming_response( - messages=messages, - tools=[get_story_text], - tool_choice="auto", - ) - full_message: str = "" - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - - assert "scientists" in full_message - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_client_web_search() -> None: - # Currently only a select few models support web search tool calls - openai_chat_client = OpenAIChatClient(model_id="gpt-4o-search-preview") - - assert isinstance(openai_chat_client, ChatClientProtocol) - - # Test that the client will use the web search tool - response = await openai_chat_client.get_response( - messages=[ - ChatMessage( - role="user", - text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.", - ) - ], - tools=[HostedWebSearchTool()], - tool_choice="auto", - ) - - assert response is not None - assert isinstance(response, ChatResponse) - assert "Rumi" in response.text - assert "Mira" in response.text - assert "Zoey" in response.text - - # Test that the client will use the web search tool with location - additional_properties = { - "user_location": { - "country": "US", - "city": "Seattle", - } - } - response = await openai_chat_client.get_response( - messages=[ChatMessage(role="user", text="What is the current weather? Do not ask for my current location.")], - tools=[HostedWebSearchTool(additional_properties=additional_properties)], - tool_choice="auto", - ) - assert response.text is not None - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_client_web_search_streaming() -> None: - openai_chat_client = OpenAIChatClient(model_id="gpt-4o-search-preview") - - assert isinstance(openai_chat_client, ChatClientProtocol) - - # Test that the client will use the web search tool - response = openai_chat_client.get_streaming_response( - messages=[ - ChatMessage( - role="user", - text="Who are the main characters of Kpop Demon Hunters? 
Do a web search to find the answer.", - ) - ], - tools=[HostedWebSearchTool()], - tool_choice="auto", - ) - - assert response is not None - full_message: str = "" - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - assert "Rumi" in full_message - assert "Mira" in full_message - assert "Zoey" in full_message - - # Test that the client will use the web search tool with location - additional_properties = { - "user_location": { - "country": "US", - "city": "Seattle", - } - } - response = openai_chat_client.get_streaming_response( - messages=[ChatMessage(role="user", text="What is the current weather? Do not ask for my current location.")], - tools=[HostedWebSearchTool(additional_properties=additional_properties)], - tool_choice="auto", - ) - assert response is not None - full_message: str = "" - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - assert full_message is not None - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_client_agent_basic_run(): - """Test OpenAI chat client agent basic run functionality with OpenAIChatClient.""" - async with ChatAgent( - chat_client=OpenAIChatClient(model_id="gpt-4o-search-preview"), - ) as agent: - # Test basic run - response = await agent.run("Hello! Please respond with 'Hello World' exactly.") - - assert isinstance(response, AgentRunResponse) - assert response.text is not None - assert len(response.text) > 0 - assert "hello world" in response.text.lower() - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_client_agent_basic_run_streaming(): - """Test OpenAI chat client agent basic streaming functionality with OpenAIChatClient.""" - async with ChatAgent( - chat_client=OpenAIChatClient(model_id="gpt-4o-search-preview"), - ) as agent: - # Test streaming run - full_text = "" - async for chunk in agent.run_stream("Please respond with exactly: 'This is a streaming response test.'"): - assert isinstance(chunk, AgentRunResponseUpdate) - if chunk.text: - full_text += chunk.text - - assert len(full_text) > 0 - assert "streaming response test" in full_text.lower() - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_client_agent_thread_persistence(): - """Test OpenAI chat client agent thread persistence across runs with OpenAIChatClient.""" - async with ChatAgent( - chat_client=OpenAIChatClient(model_id="gpt-4o-search-preview"), - instructions="You are a helpful assistant with good memory.", - ) as agent: - # Create a new thread that will be reused - thread = agent.get_new_thread() - - # First interaction - response1 = await agent.run("My name is Alice. 
Remember this.", thread=thread) - - assert isinstance(response1, AgentRunResponse) - assert response1.text is not None - - # Second interaction - test memory - response2 = await agent.run("What is my name?", thread=thread) - - assert isinstance(response2, AgentRunResponse) - assert response2.text is not None - assert "alice" in response2.text.lower() - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_client_agent_existing_thread(): - """Test OpenAI chat client agent with existing thread to continue conversations across agent instances.""" - # First conversation - capture the thread - preserved_thread = None - - async with ChatAgent( - chat_client=OpenAIChatClient(model_id="gpt-4o-search-preview"), - instructions="You are a helpful assistant with good memory.", - ) as first_agent: - # Start a conversation and capture the thread - thread = first_agent.get_new_thread() - first_response = await first_agent.run("My name is Alice. Remember this.", thread=thread) - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - - # Preserve the thread for reuse - preserved_thread = thread - - # Second conversation - reuse the thread in a new agent instance - if preserved_thread: - async with ChatAgent( - chat_client=OpenAIChatClient(model_id="gpt-4o-search-preview"), - instructions="You are a helpful assistant with good memory.", - ) as second_agent: - # Reuse the preserved thread - second_response = await second_agent.run("What is my name?", thread=preserved_thread) - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - assert "alice" in second_response.text.lower() - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_client_agent_level_tool_persistence(): - """Test that agent-level tools persist across multiple runs with OpenAI Chat Client.""" - - async with ChatAgent( - chat_client=OpenAIChatClient(model_id="gpt-4.1"), - instructions="You are a helpful assistant that uses available tools.", - tools=[get_weather], # Agent-level tool - ) as agent: - # First run - agent-level tool should be available - first_response = await agent.run("What's the weather like in Chicago?") - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - # Should use the agent-level weather tool - assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) - - # Second run - agent-level tool should still be available (persistence test) - second_response = await agent.run("What's the weather in Miami?") - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - # Should use the agent-level weather tool again - assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"]) - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_chat_client_run_level_tool_isolation(): - """Test that run-level tools are isolated to specific runs and don't persist with OpenAI Chat Client.""" - # Counter to track how many times the weather tool is called - call_count = 0 - - @ai_function - async def get_weather_with_counter(location: Annotated[str, "The location as a city name"]) -> str: - """Get the current weather in a given location.""" - nonlocal call_count - call_count += 1 - return f"The weather in {location} is sunny and 72°F." 
- - async with ChatAgent( - chat_client=OpenAIChatClient(model_id="gpt-4.1"), - instructions="You are a helpful assistant.", - ) as agent: - # First run - use run-level tool - first_response = await agent.run( - "What's the weather like in Chicago?", - tools=[get_weather_with_counter], # Run-level tool - ) - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - # Should use the run-level weather tool (call count should be 1) - assert call_count == 1 - assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) - - # Second run - run-level tool should NOT persist (key isolation test) - second_response = await agent.run("What's the weather like in Miami?") - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - # Should NOT use the weather tool since it was only run-level in previous call - # Call count should still be 1 (no additional calls) - assert call_count == 1 - - async def test_exception_message_includes_original_error_details() -> None: """Test that exception messages include original error details in the new format.""" client = OpenAIChatClient(model_id="test-model", api_key="test-key") @@ -627,7 +223,7 @@ async def test_exception_message_includes_original_error_details() -> None: patch.object(client.client.chat.completions, "create", side_effect=mock_error), pytest.raises(ServiceResponseException) as exc_info, ): - await client._inner_get_response(messages=messages, chat_options=ChatOptions()) # type: ignore + await client._inner_get_response(messages=messages, options={}) # type: ignore exception_message = str(exc_info.value) assert "service failed to complete the prompt:" in exception_message @@ -667,7 +263,7 @@ def test_chat_response_content_order_text_before_tool_calls(openai_unit_test_env ) client = OpenAIChatClient() - response = client._parse_response_from_openai(mock_response, ChatOptions()) + response = client._parse_response_from_openai(mock_response, {}) # Verify we have both text and tool call content assert len(response.messages) == 1 @@ -894,3 +490,191 @@ def test_prepare_content_for_openai_document_file_mapping(openai_unit_test_env: assert result["type"] == "file" assert "filename" not in result["file"] # None filename should be omitted + + +# region Integration Tests + + +class OutputStruct(BaseModel): + """A structured output for testing purposes.""" + + location: str + weather: str | None = None + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +@pytest.mark.parametrize( + "option_name,option_value,needs_validation", + [ + # Simple ChatOptions - just verify they don't fail + param("temperature", 0.7, False, id="temperature"), + param("top_p", 0.9, False, id="top_p"), + param("max_tokens", 500, False, id="max_tokens"), + param("seed", 123, False, id="seed"), + param("user", "test-user-id", False, id="user"), + param("frequency_penalty", 0.5, False, id="frequency_penalty"), + param("presence_penalty", 0.3, False, id="presence_penalty"), + param("stop", ["END"], False, id="stop"), + param("allow_multiple_tool_calls", True, False, id="allow_multiple_tool_calls"), + # OpenAIChatOptions - just verify they don't fail + param("logit_bias", {"50256": -1}, False, id="logit_bias"), + param("prediction", {"type": "content", "content": "hello world"}, False, id="prediction"), + # Complex options requiring output validation + param("tools", [get_weather], True, id="tools_function"), + param("tool_choice", "auto", True, id="tool_choice_auto"), + 
param("tool_choice", "none", True, id="tool_choice_none"), + param("tool_choice", "required", True, id="tool_choice_required_any"), + param( + "tool_choice", + {"mode": "required", "required_function_name": "get_weather"}, + True, + id="tool_choice_required", + ), + param("response_format", OutputStruct, True, id="response_format_pydantic"), + param( + "response_format", + { + "type": "json_schema", + "json_schema": { + "name": "WeatherDigest", + "strict": True, + "schema": { + "title": "WeatherDigest", + "type": "object", + "properties": { + "location": {"type": "string"}, + "conditions": {"type": "string"}, + "temperature_c": {"type": "number"}, + "advisory": {"type": "string"}, + }, + "required": ["location", "conditions", "temperature_c", "advisory"], + "additionalProperties": False, + }, + }, + }, + True, + id="response_format_runtime_json_schema", + ), + ], +) +async def test_integration_options( + option_name: str, + option_value: Any, + needs_validation: bool, +) -> None: + """Parametrized test covering all ChatOptions and OpenAIChatOptions. + + Tests both streaming and non-streaming modes for each option to ensure + they don't cause failures. Options marked with needs_validation also + check that the feature actually works correctly. + """ + client = OpenAIChatClient() + # to ensure toolmode required does not endlessly loop + client.function_invocation_configuration.max_iterations = 1 + + for streaming in [False, True]: + # Prepare test message + if option_name.startswith("tools") or option_name.startswith("tool_choice"): + # Use weather-related prompt for tool tests + messages = [ChatMessage(role="user", text="What is the weather in Seattle?")] + elif option_name.startswith("response_format"): + # Use prompt that works well with structured output + messages = [ChatMessage(role="user", text="The weather in Seattle is sunny")] + messages.append(ChatMessage(role="user", text="What is the weather in Seattle?")) + else: + # Generic prompt for simple options + messages = [ChatMessage(role="user", text="Say 'Hello World' briefly.")] + + # Build options dict + options: dict[str, Any] = {option_name: option_value} + + # Add tools if testing tool_choice to avoid errors + if option_name.startswith("tool_choice"): + options["tools"] = [get_weather] + + if streaming: + # Test streaming mode + response_gen = client.get_streaming_response( + messages=messages, + options=options, + ) + + output_format = option_value if option_name.startswith("response_format") else None + response = await ChatResponse.from_chat_response_generator(response_gen, output_format_type=output_format) + else: + # Test non-streaming mode + response = await client.get_response( + messages=messages, + options=options, + ) + + assert response is not None + assert isinstance(response, ChatResponse) + assert response.text is not None, f"No text in response for option '{option_name}'" + assert len(response.text) > 0, f"Empty response for option '{option_name}'" + + # Validate based on option type + if needs_validation: + if option_name.startswith("tools") or option_name.startswith("tool_choice"): + # Should have called the weather function + text = response.text.lower() + assert "sunny" in text or "seattle" in text, f"Tool not invoked for {option_name}" + elif option_name.startswith("response_format"): + if option_value == OutputStruct: + # Should have structured output + assert response.value is not None, "No structured output" + assert isinstance(response.value, OutputStruct) + assert "seattle" in 
response.value.location.lower() + else: + # Runtime JSON schema + assert response.value is None, "No structured output, can't parse any json." + response_value = json.loads(response.text) + assert isinstance(response_value, dict) + assert "location" in response_value + assert "seattle" in response_value["location"].lower() + + +@pytest.mark.flaky +@skip_if_openai_integration_tests_disabled +async def test_integration_web_search() -> None: + client = OpenAIChatClient(model_id="gpt-4o-search-preview") + + for streaming in [False, True]: + content = { + "messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.", + "options": { + "tool_choice": "auto", + "tools": [HostedWebSearchTool()], + }, + } + if streaming: + response = await ChatResponse.from_chat_response_generator(client.get_streaming_response(**content)) + else: + response = await client.get_response(**content) + + assert response is not None + assert isinstance(response, ChatResponse) + assert "Rumi" in response.text + assert "Mira" in response.text + assert "Zoey" in response.text + + # Test that the client will use the web search tool with location + additional_properties = { + "user_location": { + "country": "US", + "city": "Seattle", + } + } + content = { + "messages": "What is the current weather? Do not ask for my current location.", + "options": { + "tool_choice": "auto", + "tools": [HostedWebSearchTool(additional_properties=additional_properties)], + }, + } + if streaming: + response = await ChatResponse.from_chat_response_generator(client.get_streaming_response(**content)) + else: + response = await client.get_response(**content) + assert response.text is not None diff --git a/python/packages/core/tests/openai/test_openai_chat_client_base.py b/python/packages/core/tests/openai/test_openai_chat_client_base.py index 3e48899509..3c9a432db0 100644 --- a/python/packages/core/tests/openai/test_openai_chat_client_base.py +++ b/python/packages/core/tests/openai/test_openai_chat_client_base.py @@ -115,7 +115,6 @@ async def test_cmc_no_fcc_in_response( openai_chat_completion = OpenAIChatClient() await openai_chat_completion.get_response( messages=chat_history, - arguments={}, ) mock_create.assert_awaited_once_with( model=openai_unit_test_env["OPENAI_CHAT_MODEL_ID"], @@ -199,7 +198,7 @@ async def test_cmc_additional_properties( chat_history.append(ChatMessage(role="user", text="hello world")) openai_chat_completion = OpenAIChatClient() - await openai_chat_completion.get_response(messages=chat_history, additional_properties={"reasoning_effort": "low"}) + await openai_chat_completion.get_response(messages=chat_history, options={"reasoning_effort": "low"}) mock_create.assert_awaited_once_with( model=openai_unit_test_env["OPENAI_CHAT_MODEL_ID"], stream=False, @@ -382,8 +381,6 @@ def test_chat_response_created_at_uses_utc(openai_unit_test_env: dict[str, str]) This is a regression test for the issue where created_at was using local time but labeling it as UTC (with 'Z' suffix). 
""" - from agent_framework import ChatOptions - # Use a specific Unix timestamp: 1733011890 = 2024-12-01T00:31:30Z (UTC) # This ensures we test that the timestamp is actually converted to UTC utc_timestamp = 1733011890 @@ -399,7 +396,7 @@ def test_chat_response_created_at_uses_utc(openai_unit_test_env: dict[str, str]) ) client = OpenAIChatClient() - response = client._parse_response_from_openai(mock_response, ChatOptions()) + response = client._parse_response_from_openai(mock_response, {}) # Verify that created_at is correctly formatted as UTC assert response.created_at is not None diff --git a/python/packages/core/tests/openai/test_openai_responses_client.py b/python/packages/core/tests/openai/test_openai_responses_client.py index 778ce843ee..c91297d7df 100644 --- a/python/packages/core/tests/openai/test_openai_responses_client.py +++ b/python/packages/core/tests/openai/test_openai_responses_client.py @@ -2,28 +2,35 @@ import asyncio import base64 +import json import os from datetime import datetime, timezone -from typing import Annotated +from typing import Annotated, Any from unittest.mock import MagicMock, patch import pytest from openai import BadRequestError from openai.types.responses.response_reasoning_item import Summary -from openai.types.responses.response_reasoning_summary_text_delta_event import ResponseReasoningSummaryTextDeltaEvent -from openai.types.responses.response_reasoning_summary_text_done_event import ResponseReasoningSummaryTextDoneEvent -from openai.types.responses.response_reasoning_text_delta_event import ResponseReasoningTextDeltaEvent -from openai.types.responses.response_reasoning_text_done_event import ResponseReasoningTextDoneEvent +from openai.types.responses.response_reasoning_summary_text_delta_event import ( + ResponseReasoningSummaryTextDeltaEvent, +) +from openai.types.responses.response_reasoning_summary_text_done_event import ( + ResponseReasoningSummaryTextDoneEvent, +) +from openai.types.responses.response_reasoning_text_delta_event import ( + ResponseReasoningTextDeltaEvent, +) +from openai.types.responses.response_reasoning_text_done_event import ( + ResponseReasoningTextDoneEvent, +) from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent from pydantic import BaseModel +from pytest import param from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, - AgentThread, - ChatAgent, ChatClientProtocol, ChatMessage, + ChatOptions, ChatResponse, ChatResponseUpdate, CodeInterpreterToolCallContent, @@ -42,15 +49,17 @@ HostedWebSearchTool, ImageGenerationToolCallContent, ImageGenerationToolResultContent, - MCPStreamableHTTPTool, Role, TextContent, TextReasoningContent, UriContent, ai_function, ) -from agent_framework._types import ChatOptions -from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError, ServiceResponseException +from agent_framework.exceptions import ( + ServiceInitializationError, + ServiceInvalidRequestError, + ServiceResponseException, +) from agent_framework.openai import OpenAIResponsesClient from agent_framework.openai._exceptions import OpenAIContentFilterException @@ -70,10 +79,13 @@ class OutputStruct(BaseModel): weather: str | None = None -async def create_vector_store(client: OpenAIResponsesClient) -> tuple[str, HostedVectorStoreContent]: +async def create_vector_store( + client: OpenAIResponsesClient, +) -> tuple[str, HostedVectorStoreContent]: """Create a vector store with sample documents for testing.""" file = await client.client.files.create( 
- file=("todays_weather.txt", b"The weather today is sunny with a high of 75F."), purpose="user_data" + file=("todays_weather.txt", b"The weather today is sunny with a high of 75F."), + purpose="user_data", ) vector_store = await client.client.vector_stores.create( name="knowledge_base", @@ -217,25 +229,27 @@ def test_get_response_with_all_parameters() -> None: asyncio.run( client.get_response( messages=[ChatMessage(role="user", text="Test message")], - include=["message.output_text.logprobs"], - instructions="You are a helpful assistant", - max_tokens=100, - parallel_tool_calls=True, - model_id="gpt-4", - previous_response_id="prev-123", - reasoning={"chain_of_thought": "enabled"}, - service_tier="auto", - response_format=OutputStruct, - seed=42, - store=True, - temperature=0.7, - tool_choice="auto", - tools=[get_weather], - top_p=0.9, - user="test-user", - truncation="auto", - timeout=30.0, - additional_properties={"custom": "value"}, + options={ + "include": ["message.output_text.logprobs"], + "instructions": "You are a helpful assistant", + "max_tokens": 100, + "parallel_tool_calls": True, + "model_id": "gpt-4", + "previous_response_id": "prev-123", + "reasoning": {"chain_of_thought": "enabled"}, + "service_tier": "auto", + "response_format": OutputStruct, + "seed": 42, + "store": True, + "temperature": 0.7, + "tool_choice": "auto", + "tools": [get_weather], + "top_p": 0.9, + "user": "test-user", + "truncation": "auto", + "timeout": 30.0, + "additional_properties": {"custom": "value"}, + }, ) ) @@ -247,7 +261,12 @@ def test_web_search_tool_with_location() -> None: # Test web search tool with location web_search_tool = HostedWebSearchTool( additional_properties={ - "user_location": {"country": "US", "city": "Seattle", "region": "WA", "timezone": "America/Los_Angeles"} + "user_location": { + "country": "US", + "city": "Seattle", + "region": "WA", + "timezone": "America/Los_Angeles", + } } ) @@ -256,8 +275,7 @@ def test_web_search_tool_with_location() -> None: asyncio.run( client.get_response( messages=[ChatMessage(role="user", text="What's the weather?")], - tools=[web_search_tool], - tool_choice="auto", + options={"tools": [web_search_tool], "tool_choice": "auto"}, ) ) @@ -272,7 +290,10 @@ def test_file_search_tool_with_invalid_inputs() -> None: # Should raise an error due to invalid inputs with pytest.raises(ValueError, match="HostedFileSearchTool requires inputs to be of type"): asyncio.run( - client.get_response(messages=[ChatMessage(role="user", text="Search files")], tools=[file_search_tool]) + client.get_response( + messages=[ChatMessage(role="user", text="Search files")], + options={"tools": [file_search_tool]}, + ) ) @@ -285,7 +306,10 @@ def test_code_interpreter_tool_variations() -> None: with pytest.raises(ServiceResponseException): asyncio.run( - client.get_response(messages=[ChatMessage(role="user", text="Run some code")], tools=[code_tool_empty]) + client.get_response( + messages=[ChatMessage(role="user", text="Run some code")], + options={"tools": [code_tool_empty]}, + ) ) # Test code interpreter with files @@ -296,7 +320,8 @@ def test_code_interpreter_tool_variations() -> None: with pytest.raises(ServiceResponseException): asyncio.run( client.get_response( - messages=[ChatMessage(role="user", text="Process these files")], tools=[code_tool_with_files] + messages=[ChatMessage(role="user", text="Process these files")], + options={"tools": [code_tool_with_files]}, ) ) @@ -330,7 +355,10 @@ def test_hosted_file_search_tool_validation() -> None: with pytest.raises((ValueError, 
ServiceInvalidRequestError)): asyncio.run( - client.get_response(messages=[ChatMessage(role="user", text="Test")], tools=[empty_file_search_tool]) + client.get_response( + messages=[ChatMessage(role="user", text="Test")], + options={"tools": [empty_file_search_tool]}, + ) ) @@ -377,7 +405,8 @@ async def test_response_format_parse_path() -> None: with patch.object(client.client.responses, "parse", return_value=mock_parsed_response): response = await client.get_response( - messages=[ChatMessage(role="user", text="Test message")], response_format=OutputStruct, store=True + messages=[ChatMessage(role="user", text="Test message")], + options={"response_format": OutputStruct, "store": True}, ) assert response.response_id == "parsed_response_123" assert response.conversation_id == "parsed_response_123" @@ -403,7 +432,8 @@ async def test_response_format_parse_path_with_conversation_id() -> None: with patch.object(client.client.responses, "parse", return_value=mock_parsed_response): response = await client.get_response( - messages=[ChatMessage(role="user", text="Test message")], response_format=OutputStruct, store=True + messages=[ChatMessage(role="user", text="Test message")], + options={"response_format": OutputStruct, "store": True}, ) assert response.response_id == "parsed_response_123" assert response.conversation_id == "conversation_456" @@ -425,7 +455,8 @@ async def test_bad_request_error_non_content_filter() -> None: with patch.object(client.client.responses, "parse", side_effect=mock_error): with pytest.raises(ServiceResponseException) as exc_info: await client.get_response( - messages=[ChatMessage(role="user", text="Test message")], response_format=OutputStruct + messages=[ChatMessage(role="user", text="Test message")], + options={"response_format": OutputStruct}, ) assert "failed to complete the prompt" in str(exc_info.value) @@ -450,41 +481,6 @@ async def test_streaming_content_filter_exception_handling() -> None: break -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_get_streaming_response_with_all_parameters() -> None: - """Test get_streaming_response with all possible parameters.""" - client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") - - # Should fail due to invalid API key - with pytest.raises(ServiceResponseException): - response = client.get_streaming_response( - messages=[ChatMessage(role="user", text="Test streaming")], - include=["file_search_call.results"], - instructions="Stream response test", - max_tokens=50, - parallel_tool_calls=False, - model_id="gpt-4", - previous_response_id="stream-prev-123", - reasoning={"mode": "stream"}, - service_tier="default", - response_format=OutputStruct, - seed=123, - store=False, - temperature=0.5, - tool_choice="none", - tools=[], - top_p=0.8, - user="stream-user", - truncation="last_messages", - timeout=15.0, - additional_properties={"stream_custom": "stream_value"}, - ) - # Just iterate once to trigger the logic - async for _ in response: - break - - def test_response_content_creation_with_annotations() -> None: """Test _parse_response_from_openai with different annotation types.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") @@ -517,7 +513,7 @@ def test_response_content_creation_with_annotations() -> None: mock_response.output = [mock_message_item] with patch.object(client, "_get_metadata_from_response", return_value={}): - response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore + response = 
client._parse_response_from_openai(mock_response, options={}) # type: ignore assert len(response.messages[0].contents) >= 1 assert isinstance(response.messages[0].contents[0], TextContent) @@ -548,7 +544,7 @@ def test_response_content_creation_with_refusal() -> None: mock_response.output = [mock_message_item] - response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, options={}) # type: ignore assert len(response.messages[0].contents) == 1 assert isinstance(response.messages[0].contents[0], TextContent) @@ -578,7 +574,7 @@ def test_response_content_creation_with_reasoning() -> None: mock_response.output = [mock_reasoning_item] - response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, options={}) # type: ignore assert len(response.messages[0].contents) == 2 assert isinstance(response.messages[0].contents[0], TextReasoningContent) @@ -614,7 +610,7 @@ def test_response_content_creation_with_code_interpreter() -> None: mock_response.output = [mock_code_interpreter_item] - response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, options={}) # type: ignore assert len(response.messages[0].contents) == 2 call_content, result_content = response.messages[0].contents @@ -649,7 +645,7 @@ def test_response_content_creation_with_function_call() -> None: mock_response.output = [mock_function_call_item] - response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, options={}) # type: ignore assert len(response.messages[0].contents) == 1 assert isinstance(response.messages[0].contents[0], FunctionCallContent) @@ -710,7 +706,7 @@ def test_parse_response_from_openai_with_mcp_approval_request() -> None: mock_response.output = [mock_item] - response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, options={}) # type: ignore assert isinstance(response.messages[0].contents[0], FunctionApprovalRequestContent) req = response.messages[0].contents[0] @@ -720,7 +716,9 @@ def test_parse_response_from_openai_with_mcp_approval_request() -> None: assert req.function_call.additional_properties["server_label"] == "My_MCP" -def test_responses_client_created_at_uses_utc(openai_unit_test_env: dict[str, str]) -> None: +def test_responses_client_created_at_uses_utc( + openai_unit_test_env: dict[str, str], +) -> None: """Test that ChatResponse from responses client uses UTC timestamp. 
This is a regression test for the issue where created_at was using local time @@ -751,7 +749,7 @@ def test_responses_client_created_at_uses_utc(openai_unit_test_env: dict[str, st mock_response.output = [mock_message_item] with patch.object(client, "_get_metadata_from_response", return_value={}): - response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, options={}) # type: ignore # Verify that created_at is correctly formatted as UTC assert response.created_at is not None @@ -1203,7 +1201,7 @@ def test_service_response_exception_includes_original_error_details() -> None: patch.object(client.client.responses, "parse", side_effect=mock_error), pytest.raises(ServiceResponseException) as exc_info, ): - asyncio.run(client.get_response(messages=messages, response_format=OutputStruct)) + asyncio.run(client.get_response(messages=messages, options={"response_format": OutputStruct})) exception_message = str(exc_info.value) assert "service failed to complete the prompt:" in exception_message @@ -1219,7 +1217,7 @@ def test_get_streaming_response_with_response_format() -> None: with pytest.raises(ServiceResponseException): async def run_streaming(): - async for _ in client.get_streaming_response(messages=messages, response_format=OutputStruct): + async for _ in client.get_streaming_response(messages=messages, options={"response_format": OutputStruct}): pass asyncio.run(run_streaming()) @@ -1518,7 +1516,7 @@ def test_parse_response_from_openai_image_generation_raw_base64(): mock_response.output = [mock_item] with patch.object(client, "_get_metadata_from_response", return_value={}): - response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, options={}) # type: ignore # Verify the response contains call + result with DataContent output assert len(response.messages[0].contents) == 2 @@ -1555,7 +1553,7 @@ def test_parse_response_from_openai_image_generation_existing_data_uri(): mock_response.output = [mock_item] with patch.object(client, "_get_metadata_from_response", return_value={}): - response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, options={}) # type: ignore # Verify the response contains call + result with DataContent output assert len(response.messages[0].contents) == 2 @@ -1591,7 +1589,7 @@ def test_parse_response_from_openai_image_generation_format_detection(): mock_response_jpeg.output = [mock_item_jpeg] with patch.object(client, "_get_metadata_from_response", return_value={}): - response_jpeg = client._parse_response_from_openai(mock_response_jpeg, chat_options=ChatOptions()) # type: ignore + response_jpeg = client._parse_response_from_openai(mock_response_jpeg, options={}) # type: ignore result_contents = response_jpeg.messages[0].contents assert isinstance(result_contents[1], ImageGenerationToolResultContent) outputs = result_contents[1].outputs @@ -1617,7 +1615,7 @@ def test_parse_response_from_openai_image_generation_format_detection(): mock_response_webp.output = [mock_item_webp] with patch.object(client, "_get_metadata_from_response", return_value={}): - response_webp = client._parse_response_from_openai(mock_response_webp, chat_options=ChatOptions()) # type: ignore + response_webp = client._parse_response_from_openai(mock_response_webp, options={}) # type: 
ignore outputs_webp = response_webp.messages[0].contents[1].outputs assert outputs_webp and isinstance(outputs_webp, DataContent) assert outputs_webp.media_type == "image/webp" @@ -1647,7 +1645,7 @@ def test_parse_response_from_openai_image_generation_fallback(): mock_response.output = [mock_item] with patch.object(client, "_get_metadata_from_response", return_value={}): - response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, options={}) # type: ignore # Verify it falls back to PNG format for unrecognized binary data assert len(response.messages[0].contents) == 2 @@ -1684,7 +1682,7 @@ async def test_prepare_options_store_parameter_handling() -> None: assert "previous_response_id" not in options -def test_openai_responses_client_with_callable_api_key() -> None: +def test_with_callable_api_key() -> None: """Test OpenAIResponsesClient initialization with callable API key.""" async def get_api_key() -> str: @@ -1698,278 +1696,189 @@ async def get_api_key() -> str: assert client.client is not None -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_response() -> None: - """Test OpenAI chat completion responses.""" - openai_responses_client = OpenAIResponsesClient() - - assert isinstance(openai_responses_client, ChatClientProtocol) - - messages: list[ChatMessage] = [] - messages.append( - ChatMessage( - role="user", - text="Emily and David, two passionate scientists, met during a research expedition to Antarctica. " - "Bonded by their love for the natural world and shared curiosity, they uncovered a " - "groundbreaking phenomenon in glaciology that could potentially reshape our understanding " - "of climate change.", - ) - ) - messages.append(ChatMessage(role="user", text="who are Emily and David?")) - - # Test that the client can be used to get a response - response = await openai_responses_client.get_response(messages=messages) - - assert response is not None - assert isinstance(response, ChatResponse) - assert "scientists" in response.text - - messages.clear() - messages.append(ChatMessage(role="user", text="The weather in Seattle is sunny")) - messages.append(ChatMessage(role="user", text="What is the weather in Seattle?")) - - # Test that the client can be used to get a response - response = await openai_responses_client.get_response( - messages=messages, - response_format=OutputStruct, - ) - - assert response is not None - assert isinstance(response, ChatResponse) - output = response.value - assert output is not None, "Response value is None" - assert "seattle" in output.location.lower() - assert output.weather is not None +# region Integration Tests @pytest.mark.flaky @skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_response_tools() -> None: - """Test OpenAI chat completion responses.""" - openai_responses_client = OpenAIResponsesClient() - - assert isinstance(openai_responses_client, ChatClientProtocol) - - messages: list[ChatMessage] = [] - messages.append(ChatMessage(role="user", text="What is the weather in New York?")) - - # Test that the client can be used to get a response - response = await openai_responses_client.get_response( - messages=messages, - tools=[get_weather], - tool_choice="auto", - ) - - assert response is not None - assert isinstance(response, ChatResponse) - assert "sunny" in response.text.lower() - - messages.clear() - messages.append(ChatMessage(role="user", text="What 
is the weather in Seattle?")) - - # Test that the client can be used to get a response - response = await openai_responses_client.get_response( - messages=messages, - tools=[get_weather], - tool_choice="auto", - response_format=OutputStruct, - ) - - assert response is not None - assert isinstance(response, ChatResponse) - output = OutputStruct.model_validate_json(response.text) - assert "seattle" in output.location.lower() - assert "sunny" in output.weather.lower() - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_streaming() -> None: - """Test OpenAI chat completion responses.""" - openai_responses_client = OpenAIResponsesClient() - - assert isinstance(openai_responses_client, ChatClientProtocol) - - messages: list[ChatMessage] = [] - messages.append( - ChatMessage( - role="user", - text="Emily and David, two passionate scientists, met during a research expedition to Antarctica. " - "Bonded by their love for the natural world and shared curiosity, they uncovered a " - "groundbreaking phenomenon in glaciology that could potentially reshape our understanding " - "of climate change.", - ) - ) - messages.append(ChatMessage(role="user", text="who are Emily and David?")) - - # Test that the client can be used to get a response - response = await ChatResponse.from_chat_response_generator( - openai_responses_client.get_streaming_response(messages=messages) - ) - - assert "scientists" in response.text - - messages.clear() - messages.append(ChatMessage(role="user", text="The weather in Seattle is sunny")) - messages.append(ChatMessage(role="user", text="What is the weather in Seattle?")) - - response = openai_responses_client.get_streaming_response( - messages=messages, - response_format=OutputStruct, - ) - chunks = [] - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - chunks.append(chunk) - full_message = ChatResponse.from_chat_response_updates(chunks, output_format_type=OutputStruct) - output = full_message.value - assert output is not None, "Response value is None" - assert "seattle" in output.location.lower() - assert output.weather is not None - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_streaming_tools() -> None: - """Test OpenAI chat completion responses.""" - openai_responses_client = OpenAIResponsesClient() - - assert isinstance(openai_responses_client, ChatClientProtocol) - - messages: list[ChatMessage] = [ChatMessage(role="user", text="What is the weather in Seattle?")] - - # Test that the client can be used to get a response - response = openai_responses_client.get_streaming_response( - messages=messages, - tools=[get_weather], - tool_choice="auto", - ) - full_message: str = "" - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - - assert "sunny" in full_message.lower() - - messages.clear() - messages.append(ChatMessage(role="user", text="What is the weather in Seattle?")) - - response = openai_responses_client.get_streaming_response( - messages=messages, - tools=[get_weather], - tool_choice="auto", - response_format=OutputStruct, - ) - chunks = [] - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - chunks.append(chunk) - - full_message = ChatResponse.from_chat_response_updates(chunks, 
output_format_type=OutputStruct) - output = full_message.value - assert output is not None, "Response value is None" - assert "seattle" in output.location.lower() - assert "sunny" in output.weather.lower() - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_web_search() -> None: +@pytest.mark.parametrize( + "option_name,option_value,needs_validation", + [ + # Simple ChatOptions - just verify they don't fail + param("temperature", 0.7, False, id="temperature"), + param("top_p", 0.9, False, id="top_p"), + param("max_tokens", 500, False, id="max_tokens"), + param("seed", 123, False, id="seed"), + param("user", "test-user-id", False, id="user"), + param("metadata", {"test_key": "test_value"}, False, id="metadata"), + param("frequency_penalty", 0.5, False, id="frequency_penalty"), + param("presence_penalty", 0.3, False, id="presence_penalty"), + param("stop", ["END"], False, id="stop"), + param("allow_multiple_tool_calls", True, False, id="allow_multiple_tool_calls"), + param("tool_choice", "none", True, id="tool_choice_none"), + # OpenAIResponsesOptions - just verify they don't fail + param("safety_identifier", "user-hash-abc123", False, id="safety_identifier"), + param("truncation", "auto", False, id="truncation"), + param("top_logprobs", 5, False, id="top_logprobs"), + param("prompt_cache_key", "test-cache-key", False, id="prompt_cache_key"), + param("max_tool_calls", 3, False, id="max_tool_calls"), + # Complex options requiring output validation + param("tools", [get_weather], True, id="tools_function"), + param("tool_choice", "auto", True, id="tool_choice_auto"), + param("tool_choice", "required", True, id="tool_choice_required_any"), + param( + "tool_choice", + {"mode": "required", "required_function_name": "get_weather"}, + True, + id="tool_choice_required", + ), + param("response_format", OutputStruct, True, id="response_format_pydantic"), + param( + "response_format", + { + "type": "json_schema", + "json_schema": { + "name": "WeatherDigest", + "strict": True, + "schema": { + "title": "WeatherDigest", + "type": "object", + "properties": { + "location": {"type": "string"}, + "conditions": {"type": "string"}, + "temperature_c": {"type": "number"}, + "advisory": {"type": "string"}, + }, + "required": ["location", "conditions", "temperature_c", "advisory"], + "additionalProperties": False, + }, + }, + }, + True, + id="response_format_runtime_json_schema", + ), + ], +) +async def test_integration_options( + option_name: str, + option_value: Any, + needs_validation: bool, +) -> None: + """Parametrized test covering all ChatOptions and OpenAIResponsesOptions. + + Tests both streaming and non-streaming modes for each option to ensure + they don't cause failures. Options marked with needs_validation also + check that the feature actually works correctly. 
+ """ openai_responses_client = OpenAIResponsesClient() + # to ensure toolmode required does not endlessly loop + openai_responses_client.function_invocation_configuration.max_iterations = 1 + + for streaming in [False, True]: + # Prepare test message + if option_name.startswith("tools") or option_name.startswith("tool_choice"): + # Use weather-related prompt for tool tests + messages = [ChatMessage(role="user", text="What is the weather in Seattle?")] + elif option_name.startswith("response_format"): + # Use prompt that works well with structured output + messages = [ChatMessage(role="user", text="The weather in Seattle is sunny")] + messages.append(ChatMessage(role="user", text="What is the weather in Seattle?")) + else: + # Generic prompt for simple options + messages = [ChatMessage(role="user", text="Say 'Hello World' briefly.")] + + # Build options dict + options: dict[str, Any] = {option_name: option_value} + + # Add tools if testing tool_choice to avoid errors + if option_name.startswith("tool_choice"): + options["tools"] = [get_weather] + + if streaming: + # Test streaming mode + response_gen = openai_responses_client.get_streaming_response( + messages=messages, + options=options, + ) - assert isinstance(openai_responses_client, ChatClientProtocol) - - # Test that the client will use the web search tool - response = await openai_responses_client.get_response( - messages=[ - ChatMessage( - role="user", - text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.", + output_format = option_value if option_name.startswith("response_format") else None + response = await ChatResponse.from_chat_response_generator(response_gen, output_format_type=output_format) + else: + # Test non-streaming mode + response = await openai_responses_client.get_response( + messages=messages, + options=options, ) - ], - tools=[HostedWebSearchTool()], - tool_choice="auto", - ) - assert response is not None - assert isinstance(response, ChatResponse) - assert "Rumi" in response.text - assert "Mira" in response.text - assert "Zoey" in response.text - - # Test that the client will use the web search tool with location - additional_properties = { - "user_location": { - "country": "US", - "city": "Seattle", - } - } - response = await openai_responses_client.get_response( - messages=[ChatMessage(role="user", text="What is the current weather? Do not ask for my current location.")], - tools=[HostedWebSearchTool(additional_properties=additional_properties)], - tool_choice="auto", - ) - assert response.text is not None + assert response is not None + assert isinstance(response, ChatResponse) + assert response.text is not None, f"No text in response for option '{option_name}'" + assert len(response.text) > 0, f"Empty response for option '{option_name}'" + + # Validate based on option type + if needs_validation: + if option_name.startswith("tools") or option_name.startswith("tool_choice"): + # Should have called the weather function + text = response.text.lower() + assert "sunny" in text or "seattle" in text, f"Tool not invoked for {option_name}" + elif option_name.startswith("response_format"): + if option_value == OutputStruct: + # Should have structured output + assert response.value is not None, "No structured output" + assert isinstance(response.value, OutputStruct) + assert "seattle" in response.value.location.lower() + else: + # Runtime JSON schema + assert response.value is None, "No structured output, can't parse any json." 
+ response_value = json.loads(response.text) + assert isinstance(response_value, dict) + assert "location" in response_value + assert "seattle" in response_value["location"].lower() @pytest.mark.flaky @skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_web_search_streaming() -> None: - openai_responses_client = OpenAIResponsesClient() - - assert isinstance(openai_responses_client, ChatClientProtocol) - - # Test that the client will use the web search tool - response = openai_responses_client.get_streaming_response( - messages=[ - ChatMessage( - role="user", - text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.", - ) - ], - tools=[HostedWebSearchTool()], - tool_choice="auto", - ) - - assert response is not None - full_message: str = "" - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - assert "Rumi" in full_message - assert "Mira" in full_message - assert "Zoey" in full_message - - # Test that the client will use the web search tool with location - additional_properties = { - "user_location": { - "country": "US", - "city": "Seattle", +async def test_integration_web_search() -> None: + client = OpenAIResponsesClient(model_id="gpt-5") + + for streaming in [False, True]: + content = { + "messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.", + "options": { + "tool_choice": "auto", + "tools": [HostedWebSearchTool()], + }, } - } - response = openai_responses_client.get_streaming_response( - messages=[ChatMessage(role="user", text="What is the current weather? Do not ask for my current location.")], - tools=[HostedWebSearchTool(additional_properties=additional_properties)], - tool_choice="auto", - ) - assert response is not None - full_message: str = "" - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - assert full_message is not None + if streaming: + response = await ChatResponse.from_chat_response_generator(client.get_streaming_response(**content)) + else: + response = await client.get_response(**content) + + assert response is not None + assert isinstance(response, ChatResponse) + assert "Rumi" in response.text + assert "Mira" in response.text + assert "Zoey" in response.text + + # Test that the client will use the web search tool with location + additional_properties = { + "user_location": { + "country": "US", + "city": "Seattle", + } + } + content = { + "messages": "What is the current weather? 
Do not ask for my current location.", + "options": { + "tool_choice": "auto", + "tools": [HostedWebSearchTool(additional_properties=additional_properties)], + }, + } + if streaming: + response = await ChatResponse.from_chat_response_generator(client.get_streaming_response(**content)) + else: + response = await client.get_response(**content) + assert response.text is not None @pytest.mark.skip( @@ -1978,7 +1887,7 @@ async def test_openai_responses_client_web_search_streaming() -> None: ) @pytest.mark.flaky @skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_file_search() -> None: +async def test_integration_file_search() -> None: openai_responses_client = OpenAIResponsesClient() assert isinstance(openai_responses_client, ChatClientProtocol) @@ -1992,8 +1901,10 @@ async def test_openai_responses_client_file_search() -> None: text="What is the weather today? Do a file search to find the answer.", ) ], - tools=[HostedFileSearchTool(inputs=vector_store)], - tool_choice="auto", + options={ + "tool_choice": "auto", + "tools": [HostedFileSearchTool(inputs=vector_store)], + }, ) await delete_vector_store(openai_responses_client, file_id, vector_store.vector_store_id) @@ -2007,7 +1918,7 @@ async def test_openai_responses_client_file_search() -> None: ) @pytest.mark.flaky @skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_streaming_file_search() -> None: +async def test_integration_streaming_file_search() -> None: openai_responses_client = OpenAIResponsesClient() assert isinstance(openai_responses_client, ChatClientProtocol) @@ -2021,8 +1932,10 @@ async def test_openai_responses_client_streaming_file_search() -> None: text="What is the weather today? Do a file search to find the answer.", ) ], - tools=[HostedFileSearchTool(inputs=vector_store)], - tool_choice="auto", + options={ + "tool_choice": "auto", + "tools": [HostedFileSearchTool(inputs=vector_store)], + }, ) assert response is not None @@ -2038,435 +1951,3 @@ async def test_openai_responses_client_streaming_file_search() -> None: assert "sunny" in full_message.lower() assert "75" in full_message - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_basic_run(): - """Test OpenAI Responses Client agent basic run functionality with OpenAIResponsesClient.""" - agent = OpenAIResponsesClient().create_agent( - instructions="You are a helpful assistant.", - ) - - # Test basic run - response = await agent.run("Hello! 
Please respond with 'Hello World' exactly.") - - assert isinstance(response, AgentRunResponse) - assert response.text is not None - assert len(response.text) > 0 - assert "hello world" in response.text.lower() - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_basic_run_streaming(): - """Test OpenAI Responses Client agent basic streaming functionality with OpenAIResponsesClient.""" - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - ) as agent: - # Test streaming run - full_text = "" - async for chunk in agent.run_stream("Please respond with exactly: 'This is a streaming response test.'"): - assert isinstance(chunk, AgentRunResponseUpdate) - if chunk.text: - full_text += chunk.text - - assert len(full_text) > 0 - assert "streaming response test" in full_text.lower() - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_thread_persistence(): - """Test OpenAI Responses Client agent thread persistence across runs with OpenAIResponsesClient.""" - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant with good memory.", - ) as agent: - # Create a new thread that will be reused - thread = agent.get_new_thread() - - # First interaction - first_response = await agent.run("My favorite programming language is Python. Remember this.", thread=thread) - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - - # Second interaction - test memory - second_response = await agent.run("What is my favorite programming language?", thread=thread) - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_thread_storage_with_store_true(): - """Test OpenAI Responses Client agent with store=True to verify service_thread_id is returned.""" - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant.", - ) as agent: - # Create a new thread - thread = AgentThread() - - # Initially, service_thread_id should be None - assert thread.service_thread_id is None - - # Run with store=True to store messages on OpenAI side - response = await agent.run( - "Hello! Please remember that my name is Alex.", - thread=thread, - store=True, - ) - - # Validate response - assert isinstance(response, AgentRunResponse) - assert response.text is not None - assert len(response.text) > 0 - - # After store=True, service_thread_id should be populated - assert thread.service_thread_id is not None - assert isinstance(thread.service_thread_id, str) - assert len(thread.service_thread_id) > 0 - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_existing_thread(): - """Test OpenAI Responses Client agent with existing thread to continue conversations across agent instances.""" - # First conversation - capture the thread - preserved_thread = None - - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant with good memory.", - ) as first_agent: - # Start a conversation and capture the thread - thread = first_agent.get_new_thread() - first_response = await first_agent.run("My hobby is photography. 
Remember this.", thread=thread) - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - - # Preserve the thread for reuse - preserved_thread = thread - - # Second conversation - reuse the thread in a new agent instance - if preserved_thread: - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant with good memory.", - ) as second_agent: - # Reuse the preserved thread - second_response = await second_agent.run("What is my hobby?", thread=preserved_thread) - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - assert "photography" in second_response.text.lower() - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_hosted_code_interpreter_tool(): - """Test OpenAI Responses Client agent with HostedCodeInterpreterTool through OpenAIResponsesClient.""" - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant that can execute Python code.", - tools=[HostedCodeInterpreterTool()], - ) as agent: - # Test code interpreter functionality - response = await agent.run("Calculate the sum of numbers from 1 to 10 using Python code.") - - assert isinstance(response, AgentRunResponse) - assert response.text is not None - assert len(response.text) > 0 - # Should contain calculation result (sum of 1-10 = 55) or code execution content - contains_relevant_content = any( - term in response.text.lower() for term in ["55", "sum", "code", "python", "calculate", "10"] - ) - assert contains_relevant_content or len(response.text.strip()) > 10 - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_image_generation_tool(): - """Test OpenAI Responses Client agent with raw image_generation tool through OpenAIResponsesClient.""" - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant that can generate images.", - tools=HostedImageGenerationTool(options={"image_size": "1024x1024", "media_type": "png"}), - ) as agent: - # Test image generation functionality - response = await agent.run("Generate an image of a cute red panda sitting on a tree branch in a forest.") - - assert isinstance(response, AgentRunResponse) - assert response.messages - - # Verify we got image content - look for ImageGenerationToolResultContent - image_content_found = False - for message in response.messages: - for content in message.contents: - if content.type == "image_generation_tool_result" and content.outputs: - image_content_found = True - break - if image_content_found: - break - - # The test passes if we got image content - assert image_content_found, "Expected to find image content in response" - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_level_tool_persistence(): - """Test that agent-level tools persist across multiple runs with OpenAI Responses Client.""" - - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant that uses available tools.", - tools=[get_weather], # Agent-level tool - ) as agent: - # First run - agent-level tool should be available - first_response = await agent.run("What's the weather like in Chicago?") - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - # Should use the agent-level weather tool - assert 
any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) - - # Second run - agent-level tool should still be available (persistence test) - second_response = await agent.run("What's the weather in Miami?") - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - # Should use the agent-level weather tool again - assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"]) - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_run_level_tool_isolation(): - """Test that run-level tools are isolated to specific runs and don't persist with OpenAI Responses Client.""" - # Counter to track how many times the weather tool is called - call_count = 0 - - @ai_function - async def get_weather_with_counter(location: Annotated[str, "The location as a city name"]) -> str: - """Get the current weather in a given location.""" - nonlocal call_count - call_count += 1 - return f"The weather in {location} is sunny and 72°F." - - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant.", - ) as agent: - # First run - use run-level tool - first_response = await agent.run( - "What's the weather like in Chicago?", - tools=[get_weather_with_counter], # Run-level tool - ) - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - # Should use the run-level weather tool (call count should be 1) - assert call_count == 1 - assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) - - # Second run - run-level tool should NOT persist (key isolation test) - second_response = await agent.run("What's the weather like in Miami?") - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - # Should NOT use the weather tool since it was only run-level in previous call - # Call count should still be 1 (no additional calls) - assert call_count == 1 - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_chat_options_run_level() -> None: - """Integration test for comprehensive ChatOptions parameter coverage with OpenAI Response Agent.""" - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant.", - ) as agent: - response = await agent.run( - "Provide a brief, helpful response about why the sky blue is.", - max_tokens=600, - model_id="gpt-4o", - user="comprehensive-test-user", - tools=[get_weather], - tool_choice="auto", - ) - - assert isinstance(response, AgentRunResponse) - assert response.text is not None - assert len(response.text) > 0 - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_chat_options_agent_level() -> None: - """Integration test for comprehensive ChatOptions parameter coverage with OpenAI Response Agent.""" - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant.", - max_tokens=100, - temperature=0.7, - top_p=0.9, - seed=123, - user="comprehensive-test-user", - tools=[get_weather], - tool_choice="auto", - ) as agent: - response = await agent.run( - "Provide a brief, helpful response.", - ) - - assert isinstance(response, AgentRunResponse) - assert response.text is not None - assert len(response.text) > 0 - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def 
test_openai_responses_client_agent_hosted_mcp_tool() -> None: - """Integration test for HostedMCPTool with OpenAI Response Agent using Microsoft Learn MCP.""" - - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant that can help with microsoft documentation questions.", - tools=HostedMCPTool( - name="Microsoft Learn MCP", - url="https://learn.microsoft.com/api/mcp", - description="A Microsoft Learn MCP server for documentation questions", - approval_mode="never_require", - ), - ) as agent: - response = await agent.run( - "How to create an Azure storage account using az cli?", - # this needs to be high enough to handle the full MCP tool response. - max_tokens=5000, - ) - - assert isinstance(response, AgentRunResponse) - assert response.text - # Should contain Azure-related content since it's asking about Azure CLI - assert any(term in response.text.lower() for term in ["azure", "storage", "account", "cli"]) - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_local_mcp_tool() -> None: - """Integration test for MCPStreamableHTTPTool with OpenAI Response Agent using Microsoft Learn MCP.""" - - mcp_tool = MCPStreamableHTTPTool( - name="Microsoft Learn MCP", - url="https://learn.microsoft.com/api/mcp", - ) - - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant that can help with microsoft documentation questions.", - tools=[mcp_tool], - ) as agent: - response = await agent.run( - "How to create an Azure storage account using az cli?", - max_tokens=200, - ) - - assert isinstance(response, AgentRunResponse) - assert response.text is not None - assert len(response.text) > 0 - # Should contain Azure-related content since it's asking about Azure CLI - assert any(term in response.text.lower() for term in ["azure", "storage", "account", "cli"]) - - -class ReleaseBrief(BaseModel): - """Structured output model for release brief testing.""" - - title: str - summary: str - highlights: list[str] - model_config = {"extra": "forbid"} - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_with_response_format_pydantic() -> None: - """Integration test for response_format with Pydantic model using OpenAI Responses Client.""" - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="You are a helpful assistant that returns structured JSON responses.", - ) as agent: - response = await agent.run( - "Summarize the following release notes into a ReleaseBrief:\n\n" - "Version 2.0 Release Notes:\n" - "- Added new streaming API for real-time responses\n" - "- Improved error handling with detailed messages\n" - "- Performance boost of 50% in batch processing\n" - "- Fixed memory leak in connection pooling", - response_format=ReleaseBrief, - ) - - # Validate response - assert isinstance(response, AgentRunResponse) - assert response.value is not None - assert isinstance(response.value, ReleaseBrief) - - # Validate structured output fields - brief = response.value - assert len(brief.title) > 0 - assert len(brief.summary) > 0 - assert len(brief.highlights) > 0 - - -@pytest.mark.flaky -@skip_if_openai_integration_tests_disabled -async def test_openai_responses_client_agent_with_runtime_json_schema() -> None: - """Integration test for response_format with runtime JSON schema using OpenAI Responses Client.""" - runtime_schema = { - "title": "WeatherDigest", - "type": "object", - 
"properties": { - "location": {"type": "string"}, - "conditions": {"type": "string"}, - "temperature_c": {"type": "number"}, - "advisory": {"type": "string"}, - }, - "required": ["location", "conditions", "temperature_c", "advisory"], - "additionalProperties": False, - } - - async with ChatAgent( - chat_client=OpenAIResponsesClient(), - instructions="Return only JSON that matches the provided schema. Do not add commentary.", - ) as agent: - response = await agent.run( - "Give a brief weather digest for Seattle.", - additional_chat_options={ - "response_format": { - "type": "json_schema", - "json_schema": { - "name": runtime_schema["title"], - "strict": True, - "schema": runtime_schema, - }, - }, - }, - ) - - # Validate response - assert isinstance(response, AgentRunResponse) - assert response.text is not None - - # Parse JSON and validate structure - import json - - parsed = json.loads(response.text) - assert "location" in parsed - assert "conditions" in parsed - assert "temperature_c" in parsed - assert "advisory" in parsed diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 2815c3152c..95225cb4a3 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -5,8 +5,8 @@ from agent_framework import ( AgentExecutor, - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, BaseAgent, ChatMessage, @@ -35,10 +35,10 @@ async def run( # type: ignore[override] *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: self.call_count += 1 - return AgentRunResponse( - messages=[ChatMessage(role=Role.ASSISTANT, text=f"Response #{self.call_count}: {self.display_name}")] + return AgentResponse( + messages=[ChatMessage(role=Role.ASSISTANT, text=f"Response #{self.call_count}: {self.name}")] ) async def run_stream( # type: ignore[override] @@ -47,9 +47,9 @@ async def run_stream( # type: ignore[override] *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: self.call_count += 1 - yield AgentRunResponseUpdate(contents=[TextContent(text=f"Response #{self.call_count}: {self.display_name}")]) + yield AgentResponseUpdate(contents=[TextContent(text=f"Response #{self.call_count}: {self.name}")]) async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: diff --git a/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py b/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py index a7849120b0..ecf8b3d635 100644 --- a/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py +++ b/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py @@ -10,8 +10,8 @@ from agent_framework import ( AgentExecutor, AgentExecutorResponse, - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentRunUpdateEvent, AgentThread, BaseAgent, @@ -46,9 +46,9 @@ async def run( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: """Non-streaming run - not used in this test.""" - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="done")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="done")]) async def run_stream( self, @@ -56,16 +56,16 @@ async def run_stream( *, thread: AgentThread | None = None, **kwargs: 
Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: """Simulate streaming with tool calls and results.""" # First update: some text - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=[TextContent(text="Let me search for that...")], role=Role.ASSISTANT, ) # Second update: tool call (no text!) - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=[ FunctionCallContent( call_id="call_123", @@ -77,7 +77,7 @@ async def run_stream( ) # Third update: tool result (no text!) - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=[ FunctionResultContent( call_id="call_123", @@ -88,7 +88,7 @@ async def run_stream( ) # Fourth update: final text response - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=[TextContent(text="The weather is sunny, 72°F.")], role=Role.ASSISTANT, ) @@ -223,7 +223,7 @@ async def get_streaming_response( @executor(id="test_executor") async def test_executor(agent_executor_response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None: - await ctx.yield_output(agent_executor_response.agent_run_response.text) + await ctx.yield_output(agent_executor_response.agent_response.text) async def test_agent_executor_tool_call_with_approval() -> None: diff --git a/python/packages/core/tests/workflow/test_agent_run_event_typing.py b/python/packages/core/tests/workflow/test_agent_run_event_typing.py index a89aa817a3..e5071a7c96 100644 --- a/python/packages/core/tests/workflow/test_agent_run_event_typing.py +++ b/python/packages/core/tests/workflow/test_agent_run_event_typing.py @@ -2,42 +2,26 @@ """Tests for AgentRunEvent and AgentRunUpdateEvent type annotations.""" -from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, Role +from agent_framework import AgentResponse, AgentResponseUpdate, ChatMessage, Role from agent_framework._workflows._events import AgentRunEvent, AgentRunUpdateEvent def test_agent_run_event_data_type() -> None: - """Verify AgentRunEvent.data is typed as AgentRunResponse | None.""" - response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Hello")]) + """Verify AgentRunEvent.data is typed as AgentResponse | None.""" + response = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Hello")]) event = AgentRunEvent(executor_id="test", data=response) # This assignment should pass type checking without a cast - data: AgentRunResponse | None = event.data + data: AgentResponse | None = event.data assert data is not None assert data.text == "Hello" -def test_agent_run_event_data_none() -> None: - """Verify AgentRunEvent.data can be None.""" - event = AgentRunEvent(executor_id="test") - - data: AgentRunResponse | None = event.data - assert data is None - - def test_agent_run_update_event_data_type() -> None: - """Verify AgentRunUpdateEvent.data is typed as AgentRunResponseUpdate | None.""" - update = AgentRunResponseUpdate() + """Verify AgentRunUpdateEvent.data is typed as AgentResponseUpdate | None.""" + update = AgentResponseUpdate() event = AgentRunUpdateEvent(executor_id="test", data=update) # This assignment should pass type checking without a cast - data: AgentRunResponseUpdate | None = event.data + data: AgentResponseUpdate | None = event.data assert data is not None - - -def test_agent_run_update_event_data_none() -> None: - """Verify AgentRunUpdateEvent.data can be None.""" - event = AgentRunUpdateEvent(executor_id="test") - - data: AgentRunResponseUpdate | None = event.data - assert data is 
None diff --git a/python/packages/core/tests/workflow/test_agent_utils.py b/python/packages/core/tests/workflow/test_agent_utils.py new file mode 100644 index 0000000000..9207846791 --- /dev/null +++ b/python/packages/core/tests/workflow/test_agent_utils.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import AsyncIterable +from typing import Any + +from agent_framework import AgentResponse, AgentResponseUpdate, AgentThread, ChatMessage +from agent_framework._workflows._agent_utils import resolve_agent_id + + +class MockAgent: + """Mock agent for testing agent utilities.""" + + def __init__(self, agent_id: str, name: str | None = None) -> None: + self._id = agent_id + self._name = name + + @property + def id(self) -> str: + return self._id + + @property + def name(self) -> str | None: + return self._name + + @property + def display_name(self) -> str: + """Returns the display name of the agent.""" + ... + + @property + def description(self) -> str | None: + """Returns the description of the agent.""" + ... + + async def run( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentResponse: ... + + def run_stream( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentResponseUpdate]: ... + + def get_new_thread(self, **kwargs: Any) -> AgentThread: + """Creates a new conversation thread for the agent.""" + ... + + +def test_resolve_agent_id_with_name() -> None: + """Test that resolve_agent_id returns name when agent has a name.""" + agent = MockAgent(agent_id="agent-123", name="MyAgent") + result = resolve_agent_id(agent) + assert result == "MyAgent" + + +def test_resolve_agent_id_without_name() -> None: + """Test that resolve_agent_id returns id when agent has no name.""" + agent = MockAgent(agent_id="agent-456", name=None) + result = resolve_agent_id(agent) + assert result == "agent-456" + + +def test_resolve_agent_id_with_empty_name() -> None: + """Test that resolve_agent_id returns id when agent has empty string name.""" + agent = MockAgent(agent_id="agent-789", name="") + result = resolve_agent_id(agent) + assert result == "agent-789" + + +def test_resolve_agent_id_prefers_name_over_id() -> None: + """Test that resolve_agent_id prefers name over id when both are set.""" + agent = MockAgent(agent_id="agent-abc", name="PreferredName") + result = resolve_agent_id(agent) + assert result == "PreferredName" + assert result != "agent-abc" diff --git a/python/packages/core/tests/workflow/test_checkpoint_validation.py b/python/packages/core/tests/workflow/test_checkpoint_validation.py index 9736660ed8..f90f74db57 100644 --- a/python/packages/core/tests/workflow/test_checkpoint_validation.py +++ b/python/packages/core/tests/workflow/test_checkpoint_validation.py @@ -3,7 +3,14 @@ import pytest from typing_extensions import Never -from agent_framework import WorkflowBuilder, WorkflowContext, WorkflowRunState, WorkflowStatusEvent, handler +from agent_framework import ( + WorkflowBuilder, + WorkflowCheckpointException, + WorkflowContext, + WorkflowRunState, + WorkflowStatusEvent, + handler, +) from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage from agent_framework._workflows._executor import Executor @@ -43,7 +50,7 @@ async def test_resume_fails_when_graph_mismatch() -> None: # Build a structurally different workflow 
(different finish executor id) mismatched_workflow = build_workflow(storage, finish_id="finish_alt") - with pytest.raises(ValueError, match="Workflow graph has changed"): + with pytest.raises(WorkflowCheckpointException, match="Workflow graph has changed"): _ = [ event async for event in mismatched_workflow.run_stream( diff --git a/python/packages/core/tests/workflow/test_concurrent.py b/python/packages/core/tests/workflow/test_concurrent.py index 57810b8f59..a0c03c7720 100644 --- a/python/packages/core/tests/workflow/test_concurrent.py +++ b/python/packages/core/tests/workflow/test_concurrent.py @@ -8,7 +8,7 @@ from agent_framework import ( AgentExecutorRequest, AgentExecutorResponse, - AgentRunResponse, + AgentResponse, ChatMessage, ConcurrentBuilder, Executor, @@ -36,7 +36,7 @@ def __init__(self, id: str, reply_text: str) -> None: @handler async def run(self, request: AgentExecutorRequest, ctx: WorkflowContext[AgentExecutorResponse]) -> None: - response = AgentRunResponse(messages=ChatMessage(Role.ASSISTANT, text=self._reply_text)) + response = AgentResponse(messages=ChatMessage(Role.ASSISTANT, text=self._reply_text)) full_conversation = list(request.messages) + list(response.messages) await ctx.send_message(AgentExecutorResponse(self.id, response, full_conversation=full_conversation)) @@ -142,7 +142,7 @@ async def test_concurrent_custom_aggregator_callback_is_used() -> None: async def summarize(results: list[AgentExecutorResponse]) -> str: texts: list[str] = [] for r in results: - msgs: list[ChatMessage] = r.agent_run_response.messages + msgs: list[ChatMessage] = r.agent_response.messages texts.append(msgs[-1].text if msgs else "") return " | ".join(sorted(texts)) @@ -173,7 +173,7 @@ async def test_concurrent_custom_aggregator_sync_callback_is_used() -> None: def summarize_sync(results: list[AgentExecutorResponse], _ctx: WorkflowContext[Any]) -> str: # type: ignore[unused-argument] texts: list[str] = [] for r in results: - msgs: list[ChatMessage] = r.agent_run_response.messages + msgs: list[ChatMessage] = r.agent_response.messages texts.append(msgs[-1].text if msgs else "") return " | ".join(sorted(texts)) @@ -217,7 +217,7 @@ class CustomAggregator(Executor): async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: texts: list[str] = [] for r in results: - msgs: list[ChatMessage] = r.agent_run_response.messages + msgs: list[ChatMessage] = r.agent_response.messages texts.append(msgs[-1].text if msgs else "") await ctx.yield_output(" & ".join(sorted(texts))) @@ -251,7 +251,7 @@ class CustomAggregator(Executor): async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: texts: list[str] = [] for r in results: - msgs: list[ChatMessage] = r.agent_run_response.messages + msgs: list[ChatMessage] = r.agent_response.messages texts.append(msgs[-1].text if msgs else "") await ctx.yield_output(" | ".join(sorted(texts))) @@ -292,7 +292,7 @@ def __init__(self, id: str = "default_aggregator") -> None: async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: texts: list[str] = [] for r in results: - msgs: list[ChatMessage] = r.agent_run_response.messages + msgs: list[ChatMessage] = r.agent_response.messages texts.append(msgs[-1].text if msgs else "") await ctx.yield_output(" | ".join(sorted(texts))) diff --git a/python/packages/core/tests/workflow/test_edge.py b/python/packages/core/tests/workflow/test_edge.py index 316cae7a39..42e3893a73 100644 --- 
a/python/packages/core/tests/workflow/test_edge.py +++ b/python/packages/core/tests/workflow/test_edge.py @@ -137,9 +137,17 @@ def test_edge_can_handle(): source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") + _ = Edge(source_id=source.id, target_id=target.id) + + +async def test_edge_should_route(): + """Test edge should_route with no condition.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + edge = Edge(source_id=source.id, target_id=target.id) - assert edge.should_route(MockMessage(data="test")) + assert await edge.should_route(MockMessage(data="test")) # endregion Edge diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index 176c3027c8..a812f6dae6 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -12,6 +12,7 @@ WorkflowContext, executor, handler, + response_handler, ) @@ -266,6 +267,247 @@ async def handle(self, response: Response, ctx: WorkflowContext) -> None: assert collector_invoked.data.results == ["HELLO", "HELLO", "HELLO"] +def test_executor_output_types_property(): + """Test that the output_types property correctly identifies message output types.""" + + # Test executor with no output types + class NoOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext) -> None: + pass + + executor = NoOutputExecutor(id="no_output") + assert executor.output_types == [] + + # Test executor with single output type + class SingleOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int]) -> None: + pass + + executor = SingleOutputExecutor(id="single_output") + assert int in executor.output_types + assert len(executor.output_types) == 1 + + # Test executor with union output types + class UnionOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int | str]) -> None: + pass + + executor = UnionOutputExecutor(id="union_output") + assert int in executor.output_types + assert str in executor.output_types + assert len(executor.output_types) == 2 + + # Test executor with multiple handlers having different output types + class MultiHandlerExecutor(Executor): + @handler + async def handle_string(self, text: str, ctx: WorkflowContext[int]) -> None: + pass + + @handler + async def handle_number(self, num: int, ctx: WorkflowContext[bool]) -> None: + pass + + executor = MultiHandlerExecutor(id="multi_handler") + assert int in executor.output_types + assert bool in executor.output_types + assert len(executor.output_types) == 2 + + +def test_executor_workflow_output_types_property(): + """Test that the workflow_output_types property correctly identifies workflow output types.""" + from typing_extensions import Never + + # Test executor with no workflow output types + class NoWorkflowOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int]) -> None: + pass + + executor = NoWorkflowOutputExecutor(id="no_workflow_output") + assert executor.workflow_output_types == [] + + # Test executor with workflow output type (second type parameter) + class WorkflowOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int, str]) -> None: + pass + + executor = WorkflowOutputExecutor(id="workflow_output") + assert str in executor.workflow_output_types + assert len(executor.workflow_output_types) == 1 + + 
# Test executor with union workflow output types + class UnionWorkflowOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int, str | bool]) -> None: + pass + + executor = UnionWorkflowOutputExecutor(id="union_workflow_output") + assert str in executor.workflow_output_types + assert bool in executor.workflow_output_types + assert len(executor.workflow_output_types) == 2 + + # Test executor with multiple handlers having different workflow output types + class MultiHandlerWorkflowExecutor(Executor): + @handler + async def handle_string(self, text: str, ctx: WorkflowContext[int, str]) -> None: + pass + + @handler + async def handle_number(self, num: int, ctx: WorkflowContext[bool, float]) -> None: + pass + + executor = MultiHandlerWorkflowExecutor(id="multi_workflow") + assert str in executor.workflow_output_types + assert float in executor.workflow_output_types + assert len(executor.workflow_output_types) == 2 + + # Test executor with Never for message output (only workflow output) + class YieldOnlyExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[Never, str]) -> None: + pass + + executor = YieldOnlyExecutor(id="yield_only") + assert str in executor.workflow_output_types + assert len(executor.workflow_output_types) == 1 + # Should have no message output types + assert executor.output_types == [] + + +def test_executor_output_and_workflow_output_types_combined(): + """Test executor with both message and workflow output types.""" + + class DualOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int, str]) -> None: + pass + + executor = DualOutputExecutor(id="dual") + + # Should have int as message output type + assert int in executor.output_types + assert len(executor.output_types) == 1 + + # Should have str as workflow output type + assert str in executor.workflow_output_types + assert len(executor.workflow_output_types) == 1 + + # They should be distinct + assert int not in executor.workflow_output_types + assert str not in executor.output_types + + +def test_executor_output_types_includes_response_handlers(): + """Test that output_types includes types from response handlers.""" + from agent_framework import response_handler + + class RequestResponseExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int]) -> None: + pass + + @response_handler + async def handle_response(self, original_request: str, response: bool, ctx: WorkflowContext[float]) -> None: + pass + + executor = RequestResponseExecutor(id="request_response") + + # Should include output types from both handler and response_handler + assert int in executor.output_types + assert float in executor.output_types + assert len(executor.output_types) == 2 + + +def test_executor_workflow_output_types_includes_response_handlers(): + """Test that workflow_output_types includes types from response handlers.""" + from agent_framework import response_handler + + class RequestResponseWorkflowExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int, str]) -> None: + pass + + @response_handler + async def handle_response( + self, original_request: str, response: bool, ctx: WorkflowContext[float, bool] + ) -> None: + pass + + executor = RequestResponseWorkflowExecutor(id="request_response_workflow") + + # Should include workflow output types from both handler and response_handler + assert str in executor.workflow_output_types + assert bool in 
executor.workflow_output_types + assert len(executor.workflow_output_types) == 2 + + # Verify message output types are separate + assert int in executor.output_types + assert float in executor.output_types + assert len(executor.output_types) == 2 + + +def test_executor_multiple_response_handlers_output_types(): + """Test that multiple response handlers contribute their output types.""" + + class MultiResponseHandlerExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int]) -> None: + pass + + @response_handler + async def handle_string_bool_response( + self, original_request: str, response: bool, ctx: WorkflowContext[float] + ) -> None: + pass + + @response_handler + async def handle_int_bool_response( + self, original_request: int, response: bool, ctx: WorkflowContext[bool] + ) -> None: + pass + + executor = MultiResponseHandlerExecutor(id="multi_response") + + # Should include output types from all handlers and response handlers + assert int in executor.output_types + assert float in executor.output_types + assert bool in executor.output_types + assert len(executor.output_types) == 3 + + +def test_executor_response_handler_union_output_types(): + """Test that response handlers with union output types contribute all types.""" + from agent_framework import response_handler + + class UnionResponseHandlerExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def handle_response( + self, original_request: str, response: bool, ctx: WorkflowContext[int | str | float, bool | int] + ) -> None: + pass + + executor = UnionResponseHandlerExecutor(id="union_response") + + # Should include all output types from the union + assert int in executor.output_types + assert str in executor.output_types + assert float in executor.output_types + assert len(executor.output_types) == 3 + + # Should include all workflow output types from the union + assert bool in executor.workflow_output_types + assert int in executor.workflow_output_types + assert len(executor.workflow_output_types) == 2 + + async def test_executor_invoked_event_data_not_mutated_by_handler(): """Test that ExecutorInvokedEvent.data captures original input, not mutated input.""" diff --git a/python/packages/core/tests/workflow/test_full_conversation.py b/python/packages/core/tests/workflow/test_full_conversation.py index af24c3e17b..b1a3194468 100644 --- a/python/packages/core/tests/workflow/test_full_conversation.py +++ b/python/packages/core/tests/workflow/test_full_conversation.py @@ -9,8 +9,8 @@ from agent_framework import ( AgentExecutor, AgentExecutorResponse, - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, BaseAgent, ChatMessage, @@ -39,8 +39,8 @@ async def run( # type: ignore[override] *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=self._reply_text)]) + ) -> AgentResponse: + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=self._reply_text)]) async def run_stream( # type: ignore[override] self, @@ -48,9 +48,9 @@ async def run_stream( # type: ignore[override] *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: # This agent does not support streaming; yield a single complete response - yield AgentRunResponseUpdate(contents=[TextContent(text=self._reply_text)]) + 
yield AgentResponseUpdate(contents=[TextContent(text=self._reply_text)]) class _CaptureFullConversation(Executor): @@ -108,7 +108,7 @@ async def run( # type: ignore[override] *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: # Normalize and record messages for verification when running non-streaming norm: list[ChatMessage] = [] if messages: @@ -118,7 +118,7 @@ async def run( # type: ignore[override] elif isinstance(m, str): norm.append(ChatMessage(role=Role.USER, text=m)) self._last_messages = norm - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=self._reply_text)]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=self._reply_text)]) async def run_stream( # type: ignore[override] self, @@ -126,7 +126,7 @@ async def run_stream( # type: ignore[override] *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: # Normalize and record messages for verification when running streaming norm: list[ChatMessage] = [] if messages: @@ -136,7 +136,7 @@ async def run_stream( # type: ignore[override] elif isinstance(m, str): norm.append(ChatMessage(role=Role.USER, text=m)) self._last_messages = norm - yield AgentRunResponseUpdate(contents=[TextContent(text=self._reply_text)]) + yield AgentResponseUpdate(contents=[TextContent(text=self._reply_text)]) async def test_sequential_adapter_uses_full_conversation() -> None: diff --git a/python/packages/core/tests/workflow/test_group_chat.py b/python/packages/core/tests/workflow/test_group_chat.py index a99af64102..c65f19d599 100644 --- a/python/packages/core/tests/workflow/test_group_chat.py +++ b/python/packages/core/tests/workflow/test_group_chat.py @@ -4,48 +4,33 @@ from typing import Any, cast import pytest -from pydantic import BaseModel from agent_framework import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, - AgentRunResponse, - AgentRunResponseUpdate, - AgentRunUpdateEvent, + AgentExecutorResponse, + AgentRequestInfoResponse, + AgentResponse, + AgentResponseUpdate, AgentThread, BaseAgent, + BaseGroupChatOrchestrator, + ChatAgent, ChatMessage, - Executor, + ChatResponse, + ChatResponseUpdate, GroupChatBuilder, - GroupChatDirective, - GroupChatStateSnapshot, - MagenticBuilder, + GroupChatState, MagenticContext, MagenticManagerBase, + MagenticProgressLedger, + MagenticProgressLedgerItem, + RequestInfoEvent, Role, TextContent, - Workflow, - WorkflowContext, WorkflowOutputEvent, - handler, + WorkflowRunState, + WorkflowStatusEvent, ) from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage -from agent_framework._workflows._group_chat import ( - GroupChatOrchestratorExecutor, - ManagerSelectionResponse, - _default_orchestrator_factory, # type: ignore - _default_participant_factory, # type: ignore - _GroupChatConfig, # type: ignore - _SpeakerSelectorAdapter, # type: ignore - assemble_group_chat_workflow, -) -from agent_framework._workflows._magentic import ( - _MagenticProgressLedger, # type: ignore - _MagenticProgressLedgerItem, # type: ignore - _MagenticStartMessage, # type: ignore -) -from agent_framework._workflows._participant_utils import GroupChatParticipantSpec -from agent_framework._workflows._workflow_builder import WorkflowBuilder class StubAgent(BaseAgent): @@ -59,9 +44,9 @@ async def run( # type: ignore[override] *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: response = 
ChatMessage(role=Role.ASSISTANT, text=self._reply_text, author_name=self.name) - return AgentRunResponse(messages=[response]) + return AgentResponse(messages=[response]) def run_stream( # type: ignore[override] self, @@ -69,18 +54,32 @@ def run_stream( # type: ignore[override] *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: - async def _stream() -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate( + ) -> AsyncIterable[AgentResponseUpdate]: + async def _stream() -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate( contents=[TextContent(text=self._reply_text)], role=Role.ASSISTANT, author_name=self.name ) return _stream() -class StubManagerAgent(BaseAgent): +class MockChatClient: + """Mock chat client that raises NotImplementedError for all methods.""" + + @property + def additional_properties(self) -> dict[str, Any]: + return {} + + async def get_response(self, messages: Any, **kwargs: Any) -> ChatResponse: + raise NotImplementedError + + def get_streaming_response(self, messages: Any, **kwargs: Any) -> AsyncIterable[ChatResponseUpdate]: + raise NotImplementedError + + +class StubManagerAgent(ChatAgent): def __init__(self) -> None: - super().__init__(name="manager_agent", description="Stub manager") + super().__init__(chat_client=MockChatClient(), name="manager_agent", description="Stub manager") self._call_count = 0 async def run( @@ -89,27 +88,40 @@ async def run( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: # type: ignore[override] + ) -> AgentResponse: if self._call_count == 0: self._call_count += 1 - payload = {"selected_participant": "agent", "finish": False, "final_message": None} - return AgentRunResponse( + # First call: select the agent (using AgentOrchestrationOutput format) + payload = {"terminate": False, "reason": "Selecting agent", "next_speaker": "agent", "final_message": None} + return AgentResponse( messages=[ ChatMessage( role=Role.ASSISTANT, - text='{"selected_participant": "agent", "finish": false}', + text=( + '{"terminate": false, "reason": "Selecting agent", ' + '"next_speaker": "agent", "final_message": null}' + ), author_name=self.name, ) ], value=payload, ) - payload = {"selected_participant": None, "finish": True, "final_message": "agent manager final"} - return AgentRunResponse( + # Second call: terminate + payload = { + "terminate": True, + "reason": "Task complete", + "next_speaker": None, + "final_message": "agent manager final", + } + return AgentResponse( messages=[ ChatMessage( role=Role.ASSISTANT, - text='{"finish": true, "final_message": "agent manager final"}', + text=( + '{"terminate": true, "reason": "Task complete", ' + '"next_speaker": null, "final_message": "agent manager final"}' + ), author_name=self.name, ) ], @@ -122,22 +134,36 @@ def run_stream( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: # type: ignore[override] + ) -> AsyncIterable[AgentResponseUpdate]: if self._call_count == 0: self._call_count += 1 - async def _stream_initial() -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate( - contents=[TextContent(text='{"selected_participant": "agent", "finish": false}')], + async def _stream_initial() -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate( + contents=[ + TextContent( + text=( + '{"terminate": false, "reason": "Selecting agent", ' + '"next_speaker": "agent", "final_message": null}' + ) + ) + ], role=Role.ASSISTANT, 
author_name=self.name, ) return _stream_initial() - async def _stream_final() -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate( - contents=[TextContent(text='{"finish": true, "final_message": "agent manager final"}')], + async def _stream_final() -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate( + contents=[ + TextContent( + text=( + '{"terminate": true, "reason": "Task complete", ' + '"next_speaker": null, "final_message": "agent manager final"}' + ) + ) + ], role=Role.ASSISTANT, author_name=self.name, ) @@ -145,21 +171,20 @@ async def _stream_final() -> AsyncIterable[AgentRunResponseUpdate]: return _stream_final() -def make_sequence_selector() -> Callable[[GroupChatStateSnapshot], Any]: +def make_sequence_selector() -> Callable[[GroupChatState], str]: state_counter = {"value": 0} - async def _selector(state: GroupChatStateSnapshot) -> str | None: - participants = list(state["participants"].keys()) + def _selector(state: GroupChatState) -> str: + participants = list(state.participants.keys()) step = state_counter["value"] + state_counter["value"] = step + 1 if step == 0: - state_counter["value"] = step + 1 return participants[0] if step == 1 and len(participants) > 1: - state_counter["value"] = step + 1 return participants[1] - return None + # Return first participant to continue (will be limited by max_rounds in tests) + return participants[0] - _selector.name = "manager" # type: ignore[attr-defined] return _selector @@ -174,46 +199,30 @@ async def plan(self, magentic_context: MagenticContext) -> ChatMessage: async def replan(self, magentic_context: MagenticContext) -> ChatMessage: return await self.plan(magentic_context) - async def create_progress_ledger(self, magentic_context: MagenticContext) -> _MagenticProgressLedger: + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: participants = list(magentic_context.participant_descriptions.keys()) target = participants[0] if participants else "agent" if self._round == 0: self._round += 1 - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(reason="", answer=False), - is_in_loop=_MagenticProgressLedgerItem(reason="", answer=False), - is_progress_being_made=_MagenticProgressLedgerItem(reason="", answer=True), - next_speaker=_MagenticProgressLedgerItem(reason="", answer=target), - instruction_or_question=_MagenticProgressLedgerItem(reason="", answer="respond"), + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(reason="", answer=False), + is_in_loop=MagenticProgressLedgerItem(reason="", answer=False), + is_progress_being_made=MagenticProgressLedgerItem(reason="", answer=True), + next_speaker=MagenticProgressLedgerItem(reason="", answer=target), + instruction_or_question=MagenticProgressLedgerItem(reason="", answer="respond"), ) - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(reason="", answer=True), - is_in_loop=_MagenticProgressLedgerItem(reason="", answer=False), - is_progress_being_made=_MagenticProgressLedgerItem(reason="", answer=True), - next_speaker=_MagenticProgressLedgerItem(reason="", answer=target), - instruction_or_question=_MagenticProgressLedgerItem(reason="", answer=""), + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(reason="", answer=True), + is_in_loop=MagenticProgressLedgerItem(reason="", answer=False), + is_progress_being_made=MagenticProgressLedgerItem(reason="", answer=True), + 
next_speaker=MagenticProgressLedgerItem(reason="", answer=target), + instruction_or_question=MagenticProgressLedgerItem(reason="", answer=""), ) async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="final", author_name="magentic_manager") -class PassthroughExecutor(Executor): - @handler - async def forward(self, message: Any, ctx: WorkflowContext[Any]) -> None: - await ctx.send_message(message) - - -class CountingWorkflowBuilder(WorkflowBuilder): - def __init__(self) -> None: - super().__init__() - self.start_calls = 0 - - def set_start_executor(self, executor: Any) -> "CountingWorkflowBuilder": - self.start_calls += 1 - return cast("CountingWorkflowBuilder", super().set_start_executor(executor)) - - async def test_group_chat_builder_basic_flow() -> None: selector = make_sequence_selector() alpha = StubAgent("alpha", "ack from alpha") @@ -221,8 +230,9 @@ async def test_group_chat_builder_basic_flow() -> None: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector, display_name="manager", final_message="done") - .participants(alpha=alpha, beta=beta) + .with_select_speaker_func(selector, orchestrator_name="manager") + .participants([alpha, beta]) + .with_max_rounds(2) # Limit rounds to prevent infinite loop .build() ) @@ -235,44 +245,9 @@ async def test_group_chat_builder_basic_flow() -> None: assert len(outputs) == 1 assert len(outputs[0]) >= 1 - # The final message should be "done" from the manager - assert outputs[0][-1].text == "done" - assert outputs[0][-1].author_name == "manager" - - -async def test_magentic_builder_returns_workflow_and_runs() -> None: - manager = StubMagenticManager() - agent = StubAgent("writer", "first draft") - - workflow = MagenticBuilder().participants(writer=agent).with_standard_manager(manager=manager).build() - - assert isinstance(workflow, Workflow) - - outputs: list[ChatMessage] = [] - orchestrator_event_count = 0 - agent_event_count = 0 - start_message = _MagenticStartMessage.from_string("compose summary") - async for event in workflow.run_stream(start_message): - if isinstance(event, AgentRunUpdateEvent): - props = event.data.additional_properties if event.data else None - event_type = props.get("magentic_event_type") if props else None - if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR: - orchestrator_event_count += 1 - elif event_type == MAGENTIC_EVENT_TYPE_AGENT_DELTA: - agent_event_count += 1 - if isinstance(event, WorkflowOutputEvent): - msg = event.data - if isinstance(msg, list): - outputs.append(cast(list[ChatMessage], msg)) - - assert outputs, "Expected a final output message" - conversation = outputs[-1] - assert len(conversation) >= 1 - final = conversation[-1] - assert final.text == "final" - assert final.author_name == "magentic_manager" - assert orchestrator_event_count > 0, "Expected orchestrator events to be emitted" - assert agent_event_count > 0, "Expected agent delta events to be emitted" + # Check that both agents contributed + authors = {msg.author_name for msg in outputs[0] if msg.author_name in ["alpha", "beta"]} + assert len(authors) == 2 async def test_group_chat_as_agent_accepts_conversation() -> None: @@ -282,8 +257,9 @@ async def test_group_chat_as_agent_accepts_conversation() -> None: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector, display_name="manager", final_message="done") - .participants(alpha=alpha, beta=beta) + .with_select_speaker_func(selector, orchestrator_name="manager") + .participants([alpha, beta]) 
+ .with_max_rounds(2) # Limit rounds to prevent infinite loop .build() ) @@ -297,22 +273,6 @@ async def test_group_chat_as_agent_accepts_conversation() -> None: assert response.messages, "Expected agent conversation output" -async def test_magentic_as_agent_accepts_conversation() -> None: - manager = StubMagenticManager() - writer = StubAgent("writer", "draft") - - workflow = MagenticBuilder().participants(writer=writer).with_standard_manager(manager=manager).build() - - agent = workflow.as_agent(name="magentic-agent") - conversation = [ - ChatMessage(role=Role.SYSTEM, text="Guidelines", author_name="system"), - ChatMessage(role=Role.USER, text="Summarize the findings", author_name="requester"), - ] - response = await agent.run(conversation) - - assert isinstance(response, AgentRunResponse) - - # Comprehensive tests for group chat functionality @@ -325,16 +285,16 @@ def test_build_without_manager_raises_error(self) -> None: builder = GroupChatBuilder().participants([agent]) - with pytest.raises(ValueError, match="manager must be configured before build"): + with pytest.raises(RuntimeError, match="Orchestrator could not be resolved"): builder.build() def test_build_without_participants_raises_error(self) -> None: """Test that building without participants raises ValueError.""" - def selector(state: GroupChatStateSnapshot) -> str | None: - return None + def selector(state: GroupChatState) -> str: + return "agent" - builder = GroupChatBuilder().set_select_speakers_func(selector) + builder = GroupChatBuilder().with_select_speaker_func(selector) with pytest.raises(ValueError, match="participants must be configured before build"): builder.build() @@ -342,21 +302,21 @@ def selector(state: GroupChatStateSnapshot) -> str | None: def test_duplicate_manager_configuration_raises_error(self) -> None: """Test that configuring multiple managers raises ValueError.""" - def selector(state: GroupChatStateSnapshot) -> str | None: - return None + def selector(state: GroupChatState) -> str: + return "agent" - builder = GroupChatBuilder().set_select_speakers_func(selector) + builder = GroupChatBuilder().with_select_speaker_func(selector) - with pytest.raises(ValueError, match="already has a manager configured"): - builder.set_select_speakers_func(selector) + with pytest.raises(ValueError, match="select_speakers_func has already been configured"): + builder.with_select_speaker_func(selector) def test_empty_participants_raises_error(self) -> None: """Test that empty participants list raises ValueError.""" - def selector(state: GroupChatStateSnapshot) -> str | None: - return None + def selector(state: GroupChatState) -> str: + return "agent" - builder = GroupChatBuilder().set_select_speakers_func(selector) + builder = GroupChatBuilder().with_select_speaker_func(selector) with pytest.raises(ValueError, match="participants cannot be empty"): builder.participants([]) @@ -366,10 +326,10 @@ def test_duplicate_participant_names_raises_error(self) -> None: agent1 = StubAgent("test", "response1") agent2 = StubAgent("test", "response2") - def selector(state: GroupChatStateSnapshot) -> str | None: - return None + def selector(state: GroupChatState) -> str: + return "agent" - builder = GroupChatBuilder().set_select_speakers_func(selector) + builder = GroupChatBuilder().with_select_speaker_func(selector) with pytest.raises(ValueError, match="Duplicate participant name 'test'"): builder.participants([agent1, agent2]) @@ -381,90 +341,48 @@ class AgentWithoutName(BaseAgent): def __init__(self) -> None: super().__init__(name="", 
description="test") - async def run(self, messages: Any = None, *, thread: Any = None, **kwargs: Any) -> AgentRunResponse: - return AgentRunResponse(messages=[]) + async def run(self, messages: Any = None, *, thread: Any = None, **kwargs: Any) -> AgentResponse: + return AgentResponse(messages=[]) def run_stream( self, messages: Any = None, *, thread: Any = None, **kwargs: Any - ) -> AsyncIterable[AgentRunResponseUpdate]: - async def _stream() -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate(contents=[]) + ) -> AsyncIterable[AgentResponseUpdate]: + async def _stream() -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate(contents=[]) return _stream() agent = AgentWithoutName() - def selector(state: GroupChatStateSnapshot) -> str | None: - return None + def selector(state: GroupChatState) -> str: + return "agent" - builder = GroupChatBuilder().set_select_speakers_func(selector) + builder = GroupChatBuilder().with_select_speaker_func(selector) - with pytest.raises(ValueError, match="must define a non-empty 'name' attribute"): + with pytest.raises(ValueError, match="AgentProtocol participants must have a non-empty name"): builder.participants([agent]) def test_empty_participant_name_raises_error(self) -> None: """Test that empty participant name raises ValueError.""" - agent = StubAgent("test", "response") - - def selector(state: GroupChatStateSnapshot) -> str | None: - return None + agent = StubAgent("", "response") # Agent with empty name - builder = GroupChatBuilder().set_select_speakers_func(selector) - - with pytest.raises(ValueError, match="participant names must be non-empty strings"): - builder.participants({"": agent}) - - def test_assemble_group_chat_respects_existing_start_executor(self) -> None: - """Ensure assemble_group_chat_workflow does not override preconfigured start executor.""" - - async def manager(_: GroupChatStateSnapshot) -> GroupChatDirective: - return GroupChatDirective(finish=True) - - builder = CountingWorkflowBuilder() - entry = PassthroughExecutor(id="entry") - builder = builder.set_start_executor(entry) - - participant = PassthroughExecutor(id="participant") - participant_spec = GroupChatParticipantSpec( - name="participant", - participant=participant, - description="participant", - ) + def selector(state: GroupChatState) -> str: + return "agent" - wiring = _GroupChatConfig( - manager=manager, - manager_participant=None, - manager_name="manager", - participants={"participant": participant_spec}, - max_rounds=None, - termination_condition=None, - participant_aliases={}, - participant_executors={"participant": participant}, - ) + builder = GroupChatBuilder().with_select_speaker_func(selector) - result = assemble_group_chat_workflow( - wiring=wiring, - participant_factory=_default_participant_factory, - orchestrator_factory=_default_orchestrator_factory, - builder=builder, - return_builder=True, - ) - - assert isinstance(result, tuple) - assembled_builder, _ = result - assert assembled_builder is builder - assert builder.start_calls == 1 - assert assembled_builder._start_executor is entry # type: ignore + with pytest.raises(ValueError, match="AgentProtocol participants must have a non-empty name"): + builder.participants([agent]) -class TestGroupChatOrchestrator: - """Tests for GroupChatOrchestratorExecutor core functionality.""" +class TestGroupChatWorkflow: + """Tests for GroupChat workflow functionality.""" async def test_max_rounds_enforcement(self) -> None: """Test that max_rounds properly limits conversation rounds.""" 
call_count = {"value": 0} - def selector(state: GroupChatStateSnapshot) -> str | None: + def selector(state: GroupChatState) -> str: call_count["value"] += 1 # Always return the agent name to try to continue indefinitely return "agent" @@ -473,7 +391,7 @@ def selector(state: GroupChatStateSnapshot) -> str | None: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector) + .with_select_speaker_func(selector) .participants([agent]) .with_max_rounds(2) # Limit to 2 rounds .build() @@ -492,12 +410,12 @@ def selector(state: GroupChatStateSnapshot) -> str | None: conversation = outputs[-1] assert len(conversation) >= 1 final_output = conversation[-1] - assert "round limit" in final_output.text.lower() + assert "maximum number of rounds" in final_output.text.lower() async def test_termination_condition_halts_conversation(self) -> None: """Test that a custom termination condition stops the workflow.""" - def selector(state: GroupChatStateSnapshot) -> str | None: + def selector(state: GroupChatState) -> str: return "agent" def termination_condition(conversation: list[ChatMessage]) -> bool: @@ -508,7 +426,7 @@ def termination_condition(conversation: list[ChatMessage]) -> bool: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector) + .with_select_speaker_func(selector) .participants([agent]) .with_termination_condition(termination_condition) .build() @@ -526,46 +444,17 @@ def termination_condition(conversation: list[ChatMessage]) -> bool: agent_replies = [msg for msg in conversation if msg.author_name == "agent" and msg.role == Role.ASSISTANT] assert len(agent_replies) == 2 final_output = conversation[-1] - assert final_output.author_name == "manager" + # The orchestrator uses its ID as author_name by default assert "termination condition" in final_output.text.lower() - async def test_termination_condition_uses_manager_final_message(self) -> None: - """Test that manager-provided final message is used on termination.""" - - async def selector(state: GroupChatStateSnapshot) -> str | None: - return None - - agent = StubAgent("agent", "response") - final_text = "manager summary on termination" - - workflow = ( - GroupChatBuilder() - .set_select_speakers_func(selector, final_message=final_text) - .participants([agent]) - .with_termination_condition(lambda _: True) - .build() - ) - - outputs: list[list[ChatMessage]] = [] - async for event in workflow.run_stream("test task"): - if isinstance(event, WorkflowOutputEvent): - data = event.data - if isinstance(data, list): - outputs.append(cast(list[ChatMessage], data)) - - assert outputs, "Expected termination to yield output" - conversation = outputs[-1] - assert conversation[-1].text == final_text - assert conversation[-1].author_name == "manager" - async def test_termination_condition_agent_manager_finalizes(self) -> None: - """Test that agent-based manager can provide final message on termination.""" + """Test that termination condition with agent orchestrator produces default termination message.""" manager = StubManagerAgent() worker = StubAgent("agent", "response") workflow = ( GroupChatBuilder() - .set_manager(manager, display_name="Manager") + .with_agent_orchestrator(manager) .participants([worker]) .with_termination_condition(lambda conv: any(msg.author_name == "agent" for msg in conv)) .build() @@ -580,167 +469,23 @@ async def test_termination_condition_agent_manager_finalizes(self) -> None: assert outputs, "Expected termination to yield output" conversation = outputs[-1] - assert conversation[-1].text == "agent manager 
final" - assert conversation[-1].author_name == "Manager" + assert conversation[-1].text == BaseGroupChatOrchestrator.TERMINATION_CONDITION_MET_MESSAGE + assert conversation[-1].author_name == manager.name async def test_unknown_participant_error(self) -> None: - """Test that _apply_directive raises error for unknown participants.""" + """Test that unknown participant selection raises error.""" - def selector(state: GroupChatStateSnapshot) -> str | None: + def selector(state: GroupChatState) -> str: return "unknown_agent" # Return non-existent participant agent = StubAgent("agent", "response") - workflow = GroupChatBuilder().set_select_speakers_func(selector).participants([agent]).build() + workflow = GroupChatBuilder().with_select_speaker_func(selector).participants([agent]).build() - with pytest.raises(ValueError, match="Manager selected unknown participant 'unknown_agent'"): + with pytest.raises(RuntimeError, match="Selection function returned unknown participant 'unknown_agent'"): async for _ in workflow.run_stream("test task"): pass - async def test_directive_without_agent_name_raises_error(self) -> None: - """Test that directive without agent_name raises error when finish=False.""" - - def bad_selector(state: GroupChatStateSnapshot) -> GroupChatDirective: - # Return a GroupChatDirective object instead of string to trigger error - return GroupChatDirective(finish=False, agent_name=None) # type: ignore - - agent = StubAgent("agent", "response") - - # The _SpeakerSelectorAdapter will catch this and raise TypeError - workflow = GroupChatBuilder().set_select_speakers_func(bad_selector).participants([agent]).build() # type: ignore - - # This should raise a TypeError because selector doesn't return str or None - with pytest.raises(TypeError, match="must return a participant name \\(str\\) or None"): - async for _ in workflow.run_stream("test"): - pass - - async def test_handle_empty_conversation_raises_error(self) -> None: - """Test that empty conversation list raises ValueError.""" - - def selector(state: GroupChatStateSnapshot) -> str | None: - return None - - agent = StubAgent("agent", "response") - - workflow = GroupChatBuilder().set_select_speakers_func(selector).participants([agent]).build() - - with pytest.raises(ValueError, match="requires at least one chat message"): - async for _ in workflow.run_stream([]): - pass - - async def test_unknown_participant_response_raises_error(self) -> None: - """Test that responses from unknown participants raise errors.""" - - def selector(state: GroupChatStateSnapshot) -> str | None: - return "agent" - - # Create orchestrator to test _ingest_participant_message directly - orchestrator = GroupChatOrchestratorExecutor( - manager=selector, # type: ignore - participants={"agent": "test agent"}, - manager_name="test_manager", # type: ignore - ) - - # Mock the workflow context - class MockContext: - async def yield_output(self, message: ChatMessage) -> None: - pass - - ctx = MockContext() - - # Initialize orchestrator state - orchestrator._task_message = ChatMessage(role=Role.USER, text="test") # type: ignore - orchestrator._conversation = [orchestrator._task_message] # type: ignore - orchestrator._history = [] # type: ignore - orchestrator._pending_agent = None # type: ignore - orchestrator._round_index = 0 # type: ignore - - # Test with unknown participant - message = ChatMessage(role=Role.ASSISTANT, text="response") - - with pytest.raises(ValueError, match="Received response from unknown participant 'unknown'"): - await 
orchestrator._ingest_participant_message("unknown", message, ctx) # type: ignore - - async def test_state_build_before_initialization_raises_error(self) -> None: - """Test that _build_state raises error before task message initialization.""" - - def selector(state: GroupChatStateSnapshot) -> str | None: - return None - - orchestrator = GroupChatOrchestratorExecutor( - manager=selector, # type: ignore - participants={"agent": "test agent"}, - manager_name="test_manager", # type: ignore - ) - - with pytest.raises(RuntimeError, match="state not initialized with task message"): - orchestrator._build_state() # type: ignore - - -class TestSpeakerSelectorAdapter: - """Tests for _SpeakerSelectorAdapter functionality.""" - - async def test_selector_returning_list_with_multiple_items_raises_error(self) -> None: - """Test that selector returning list with multiple items raises error.""" - - def bad_selector(state: GroupChatStateSnapshot) -> list[str]: - return ["agent1", "agent2"] # Multiple items - - adapter = _SpeakerSelectorAdapter(bad_selector, manager_name="manager") - - state = { - "participants": {"agent1": "desc1", "agent2": "desc2"}, - "task": ChatMessage(role=Role.USER, text="test"), - "conversation": (), - "history": (), - "round_index": 0, - "pending_agent": None, - } - - with pytest.raises(ValueError, match="must return a single participant name"): - await adapter(state) - - async def test_selector_returning_non_string_raises_error(self) -> None: - """Test that selector returning non-string raises TypeError.""" - - def bad_selector(state: GroupChatStateSnapshot) -> int: - return 42 # Not a string - - adapter = _SpeakerSelectorAdapter(bad_selector, manager_name="manager") - - state = { - "participants": {"agent": "desc"}, - "task": ChatMessage(role=Role.USER, text="test"), - "conversation": (), - "history": (), - "round_index": 0, - "pending_agent": None, - } - - with pytest.raises(TypeError, match="must return a participant name \\(str\\) or None"): - await adapter(state) - - async def test_selector_returning_empty_list_finishes(self) -> None: - """Test that selector returning empty list finishes conversation.""" - - def empty_selector(state: GroupChatStateSnapshot) -> list[str]: - return [] # Empty list should finish - - adapter = _SpeakerSelectorAdapter(empty_selector, manager_name="manager") - - state = { - "participants": {"agent": "desc"}, - "task": ChatMessage(role=Role.USER, text="test"), - "conversation": (), - "history": (), - "round_index": 0, - "pending_agent": None, - } - - directive = await adapter(state) - assert directive.finish is True - assert directive.final_message is not None - class TestCheckpointing: """Tests for checkpointing functionality.""" @@ -748,9 +493,7 @@ class TestCheckpointing: async def test_workflow_with_checkpointing(self) -> None: """Test that workflow works with checkpointing enabled.""" - def selector(state: GroupChatStateSnapshot) -> str | None: - if state["round_index"] >= 1: - return None + def selector(state: GroupChatState) -> str: return "agent" agent = StubAgent("agent", "response") @@ -758,8 +501,9 @@ def selector(state: GroupChatStateSnapshot) -> str | None: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector) + .with_select_speaker_func(selector) .participants([agent]) + .with_max_rounds(1) .with_checkpointing(storage) .build() ) @@ -774,87 +518,40 @@ def selector(state: GroupChatStateSnapshot) -> str | None: assert len(outputs) == 1 # Should complete normally -class TestAgentManagerConfiguration: - """Tests for 
agent-based manager configuration.""" - - async def test_set_manager_configures_response_format(self) -> None: - """Ensure ChatAgent managers receive default ManagerSelectionResponse formatting.""" - from unittest.mock import MagicMock - - from agent_framework import ChatAgent - - chat_client = MagicMock() - manager_agent = ChatAgent(chat_client=chat_client, name="Coordinator") - assert manager_agent.chat_options.response_format is None - - worker = StubAgent("worker", "response") - - builder = GroupChatBuilder().set_manager(manager_agent).participants([worker]) - - assert manager_agent.chat_options.response_format is ManagerSelectionResponse - assert builder._manager_participant is manager_agent # type: ignore[attr-defined] - - async def test_set_manager_accepts_agent_manager(self) -> None: - """Verify agent-based manager can be set and workflow builds.""" - from unittest.mock import MagicMock - - from agent_framework import ChatAgent - - chat_client = MagicMock() - manager_agent = ChatAgent(chat_client=chat_client, name="Coordinator") - worker = StubAgent("worker", "response") - - builder = GroupChatBuilder().set_manager(manager_agent, display_name="Orchestrator") - builder = builder.participants([worker]).with_max_rounds(1) - - assert builder._manager_participant is manager_agent # type: ignore[attr-defined] - assert "worker" in builder._participants # type: ignore[attr-defined] - - async def test_set_manager_rejects_custom_response_format(self) -> None: - """Reject custom response_format on ChatAgent managers.""" - from unittest.mock import MagicMock - - from agent_framework import ChatAgent - - class CustomResponse(BaseModel): - value: str - - chat_client = MagicMock() - manager_agent = ChatAgent(chat_client=chat_client, name="Coordinator", response_format=CustomResponse) - worker = StubAgent("worker", "response") - - with pytest.raises(ValueError, match="response_format must be ManagerSelectionResponse"): - GroupChatBuilder().set_manager(manager_agent).participants([worker]) - - assert manager_agent.chat_options.response_format is CustomResponse - +class TestConversationHandling: + """Tests for different conversation input types.""" -class TestFactoryFunctions: - """Tests for factory functions.""" + async def test_handle_empty_conversation_raises_error(self) -> None: + """Test that empty conversation list raises ValueError.""" - def test_default_orchestrator_factory_without_manager_raises_error(self) -> None: - """Test that default factory requires manager to be set.""" - config = _GroupChatConfig(manager=None, manager_participant=None, manager_name="test", participants={}) + def selector(state: GroupChatState) -> str: + return "agent" - with pytest.raises(RuntimeError, match="requires a manager to be configured"): - _default_orchestrator_factory(config) + agent = StubAgent("agent", "response") + workflow = ( + GroupChatBuilder().with_select_speaker_func(selector).participants([agent]).with_max_rounds(1).build() + ) -class TestConversationHandling: - """Tests for different conversation input types.""" + with pytest.raises(ValueError, match="At least one ChatMessage is required to start the group chat workflow."): + async for _ in workflow.run_stream([]): + pass async def test_handle_string_input(self) -> None: """Test handling string input creates proper ChatMessage.""" - def selector(state: GroupChatStateSnapshot) -> str | None: - # Verify the task was properly converted - assert state["task"].role == Role.USER - assert state["task"].text == "test string" - return None + def 
selector(state: GroupChatState) -> str: + # Verify the conversation has the user message + assert len(state.conversation) > 0 + assert state.conversation[0].role == Role.USER + assert state.conversation[0].text == "test string" + return "agent" agent = StubAgent("agent", "response") - workflow = GroupChatBuilder().set_select_speakers_func(selector).participants([agent]).build() + workflow = ( + GroupChatBuilder().with_select_speaker_func(selector).participants([agent]).with_max_rounds(1).build() + ) outputs: list[list[ChatMessage]] = [] async for event in workflow.run_stream("test string"): @@ -869,14 +566,17 @@ async def test_handle_chat_message_input(self) -> None: """Test handling ChatMessage input directly.""" task_message = ChatMessage(role=Role.USER, text="test message") - def selector(state: GroupChatStateSnapshot) -> str | None: - # Verify the task message was preserved - assert state["task"] == task_message - return None + def selector(state: GroupChatState) -> str: + # Verify the task message was preserved in conversation + assert len(state.conversation) > 0 + assert state.conversation[0] == task_message + return "agent" agent = StubAgent("agent", "response") - workflow = GroupChatBuilder().set_select_speakers_func(selector).participants([agent]).build() + workflow = ( + GroupChatBuilder().with_select_speaker_func(selector).participants([agent]).with_max_rounds(1).build() + ) outputs: list[list[ChatMessage]] = [] async for event in workflow.run_stream(task_message): @@ -894,15 +594,17 @@ async def test_handle_conversation_list_input(self) -> None: ChatMessage(role=Role.USER, text="user message"), ] - def selector(state: GroupChatStateSnapshot) -> str | None: + def selector(state: GroupChatState) -> str: # Verify conversation context is preserved - assert len(state["conversation"]) == 2 - assert state["task"].text == "user message" - return None + assert len(state.conversation) >= 2 + assert state.conversation[-1].text == "user message" + return "agent" agent = StubAgent("agent", "response") - workflow = GroupChatBuilder().set_select_speakers_func(selector).participants([agent]).build() + workflow = ( + GroupChatBuilder().with_select_speaker_func(selector).participants([agent]).with_max_rounds(1).build() + ) outputs: list[list[ChatMessage]] = [] async for event in workflow.run_stream(conversation): @@ -918,10 +620,10 @@ class TestRoundLimitEnforcement: """Tests for round limit checking functionality.""" async def test_round_limit_in_apply_directive(self) -> None: - """Test round limit enforcement in _apply_directive.""" + """Test round limit enforcement.""" rounds_called = {"count": 0} - def selector(state: GroupChatStateSnapshot) -> str | None: + def selector(state: GroupChatState) -> str: rounds_called["count"] += 1 # Keep trying to select agent to test limit enforcement return "agent" @@ -930,7 +632,7 @@ def selector(state: GroupChatStateSnapshot) -> str | None: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector) + .with_select_speaker_func(selector) .participants([agent]) .with_max_rounds(1) # Very low limit .build() @@ -949,13 +651,13 @@ def selector(state: GroupChatStateSnapshot) -> str | None: conversation = outputs[-1] assert len(conversation) >= 1 final_output = conversation[-1] - assert "round limit" in final_output.text.lower() + assert "maximum number of rounds" in final_output.text.lower() async def test_round_limit_in_ingest_participant_message(self) -> None: """Test round limit enforcement after participant response.""" responses_received = 
{"count": 0} - def selector(state: GroupChatStateSnapshot) -> str | None: + def selector(state: GroupChatState) -> str: responses_received["count"] += 1 if responses_received["count"] == 1: return "agent" # First call selects agent @@ -965,7 +667,7 @@ def selector(state: GroupChatStateSnapshot) -> str | None: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector) + .with_select_speaker_func(selector) .participants([agent]) .with_max_rounds(1) # Hit limit after first response .build() @@ -984,25 +686,29 @@ def selector(state: GroupChatStateSnapshot) -> str | None: conversation = outputs[-1] assert len(conversation) >= 1 final_output = conversation[-1] - assert "round limit" in final_output.text.lower() + assert "maximum number of rounds" in final_output.text.lower() async def test_group_chat_checkpoint_runtime_only() -> None: """Test checkpointing configured ONLY at runtime, not at build time.""" - from agent_framework import WorkflowRunState, WorkflowStatusEvent - storage = InMemoryCheckpointStorage() agent_a = StubAgent("agentA", "Reply from A") agent_b = StubAgent("agentB", "Reply from B") selector = make_sequence_selector() - wf = GroupChatBuilder().participants([agent_a, agent_b]).set_select_speakers_func(selector).build() + wf = ( + GroupChatBuilder() + .participants([agent_a, agent_b]) + .with_select_speaker_func(selector) + .with_max_rounds(2) + .build() + ) baseline_output: list[ChatMessage] | None = None async for ev in wf.run_stream("runtime checkpoint test", checkpoint_storage=storage): if isinstance(ev, WorkflowOutputEvent): - baseline_output = cast(list[ChatMessage], ev.data) if isinstance(ev.data, list) else None + baseline_output = cast(list[ChatMessage], ev.data) if isinstance(ev.data, list) else None # type: ignore if isinstance(ev, WorkflowStatusEvent) and ev.state in ( WorkflowRunState.IDLE, WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, @@ -1020,7 +726,6 @@ async def test_group_chat_checkpoint_runtime_overrides_buildtime() -> None: import tempfile with tempfile.TemporaryDirectory() as temp_dir1, tempfile.TemporaryDirectory() as temp_dir2: - from agent_framework import WorkflowRunState, WorkflowStatusEvent from agent_framework._workflows._checkpoint import FileCheckpointStorage buildtime_storage = FileCheckpointStorage(temp_dir1) @@ -1033,15 +738,15 @@ async def test_group_chat_checkpoint_runtime_overrides_buildtime() -> None: wf = ( GroupChatBuilder() .participants([agent_a, agent_b]) - .set_select_speakers_func(selector) + .with_select_speaker_func(selector) + .with_max_rounds(2) .with_checkpointing(buildtime_storage) .build() ) - baseline_output: list[ChatMessage] | None = None async for ev in wf.run_stream("override test", checkpoint_storage=runtime_storage): if isinstance(ev, WorkflowOutputEvent): - baseline_output = cast(list[ChatMessage], ev.data) if isinstance(ev.data, list) else None + baseline_output = cast(list[ChatMessage], ev.data) if isinstance(ev.data, list) else None # type: ignore if isinstance(ev, WorkflowStatusEvent) and ev.state in ( WorkflowRunState.IDLE, WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, @@ -1057,37 +762,8 @@ async def test_group_chat_checkpoint_runtime_overrides_buildtime() -> None: assert len(buildtime_checkpoints) == 0, "Build-time storage should have no checkpoints when overridden" -class _StubExecutor(Executor): - """Minimal executor used to satisfy workflow wiring in tests.""" - - def __init__(self, id: str) -> None: - super().__init__(id=id) - - @handler - async def handle(self, message: object, ctx: 
WorkflowContext[ChatMessage]) -> None: - await ctx.yield_output(message) - - -def test_set_manager_builds_with_agent_manager() -> None: - """GroupChatBuilder should build when using an agent-based manager.""" - - manager = _StubExecutor("manager_executor") - participant = _StubExecutor("participant_executor") - - workflow = ( - GroupChatBuilder().set_manager(manager, display_name="Moderator").participants({"worker": participant}).build() - ) - - orchestrator = workflow.get_start_executor() - - assert isinstance(orchestrator, GroupChatOrchestratorExecutor) - assert orchestrator._is_manager_agent() - - async def test_group_chat_with_request_info_filtering(): """Test that with_request_info(agents=[...]) only pauses before specified agents run.""" - from agent_framework import AgentInputRequest, RequestInfoEvent - # Create agents - we want to verify only beta triggers pause alpha = StubAgent("alpha", "response from alpha") beta = StubAgent("beta", "response from beta") @@ -1095,19 +771,21 @@ async def test_group_chat_with_request_info_filtering(): # Manager that selects alpha first, then beta, then finishes call_count = 0 - async def selector(state: GroupChatStateSnapshot) -> str | None: + async def selector(state: GroupChatState) -> str: nonlocal call_count call_count += 1 if call_count == 1: return "alpha" if call_count == 2: return "beta" - return None + # Return to alpha to continue + return "alpha" workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector, display_name="manager", final_message="done") - .participants(alpha=alpha, beta=beta) + .with_select_speaker_func(selector, orchestrator_name="manager") + .participants([alpha, beta]) + .with_max_rounds(2) .with_request_info(agents=["beta"]) # Only pause before beta runs .build() ) @@ -1115,7 +793,7 @@ async def selector(state: GroupChatStateSnapshot) -> str | None: # Run until we get a request info event (should be before beta, not alpha) request_events: list[RequestInfoEvent] = [] async for event in workflow.run_stream("test task"): - if isinstance(event, RequestInfoEvent) and isinstance(event.data, AgentInputRequest): + if isinstance(event, RequestInfoEvent) and isinstance(event.data, AgentExecutorResponse): request_events.append(event) # Don't break - let stream complete naturally when paused @@ -1123,13 +801,15 @@ async def selector(state: GroupChatStateSnapshot) -> str | None: assert len(request_events) == 1 request_event = request_events[0] - # The target agent should be beta's executor ID (groupchat_agent:beta) - assert request_event.data.target_agent_id is not None - assert "beta" in request_event.data.target_agent_id + # The target agent should be beta's executor ID + assert isinstance(request_event.data, AgentExecutorResponse) + assert request_event.source_executor_id == "beta" # Continue the workflow with a response outputs: list[WorkflowOutputEvent] = [] - async for event in workflow.send_responses_streaming({request_event.request_id: "continue please"}): + async for event in workflow.send_responses_streaming({ + request_event.request_id: AgentRequestInfoResponse.approve() + }): if isinstance(event, WorkflowOutputEvent): outputs.append(event) @@ -1139,25 +819,25 @@ async def selector(state: GroupChatStateSnapshot) -> str | None: async def test_group_chat_with_request_info_no_filter_pauses_all(): """Test that with_request_info() without agents pauses before all participants.""" - from agent_framework import AgentInputRequest, RequestInfoEvent - # Create agents alpha = StubAgent("alpha", "response from alpha") # 
Manager selects alpha then finishes call_count = 0 - async def selector(state: GroupChatStateSnapshot) -> str | None: + async def selector(state: GroupChatState) -> str: nonlocal call_count call_count += 1 if call_count == 1: return "alpha" - return None + # Keep returning alpha to continue + return "alpha" workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector, display_name="manager", final_message="done") - .participants(alpha=alpha) + .with_select_speaker_func(selector, orchestrator_name="manager") + .participants([alpha]) + .with_max_rounds(1) .with_request_info() # No filter - pause for all .build() ) @@ -1165,14 +845,13 @@ async def selector(state: GroupChatStateSnapshot) -> str | None: # Run until we get a request info event request_events: list[RequestInfoEvent] = [] async for event in workflow.run_stream("test task"): - if isinstance(event, RequestInfoEvent) and isinstance(event.data, AgentInputRequest): + if isinstance(event, RequestInfoEvent) and isinstance(event.data, AgentExecutorResponse): request_events.append(event) break # Should pause before alpha assert len(request_events) == 1 - assert request_events[0].data.target_agent_id is not None - assert "alpha" in request_events[0].data.target_agent_id + assert request_events[0].source_executor_id == "alpha" def test_group_chat_builder_with_request_info_returns_self(): diff --git a/python/packages/core/tests/workflow/test_handoff.py b/python/packages/core/tests/workflow/test_handoff.py index d0d5092323..268f89d513 100644 --- a/python/packages/core/tests/workflow/test_handoff.py +++ b/python/packages/core/tests/workflow/test_handoff.py @@ -1,151 +1,78 @@ # Copyright (c) Microsoft. All rights reserved. -from collections.abc import AsyncIterable, AsyncIterator -from dataclasses import dataclass +from collections.abc import AsyncIterable from typing import Any, cast -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, - BaseAgent, ChatAgent, ChatMessage, + ChatResponse, + ChatResponseUpdate, FunctionCallContent, + HandoffAgentUserRequest, HandoffBuilder, - HandoffUserInputRequest, RequestInfoEvent, Role, TextContent, WorkflowEvent, WorkflowOutputEvent, + resolve_agent_id, + use_function_invocation, ) -from agent_framework._mcp import MCPTool -from agent_framework._workflows import AgentRunEvent -from agent_framework._workflows import _handoff as handoff_module # type: ignore -from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value -from agent_framework._workflows._handoff import ( - _clone_chat_agent, # type: ignore[reportPrivateUsage] - _ConversationWithUserInput, - _UserInputGateway, -) -from agent_framework._workflows._workflow_builder import WorkflowBuilder - - -class _CountingWorkflowBuilder(WorkflowBuilder): - created: list["_CountingWorkflowBuilder"] = [] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.start_calls = 0 - _CountingWorkflowBuilder.created.append(self) - - def set_start_executor(self, executor: Any) -> "_CountingWorkflowBuilder": # type: ignore[override] - self.start_calls += 1 - return cast("_CountingWorkflowBuilder", super().set_start_executor(executor)) - - -@dataclass -class _ComplexMetadata: - reason: str - payload: dict[str, str] - - -@pytest.fixture -def complex_metadata() -> _ComplexMetadata: - return _ComplexMetadata(reason="route", payload={"code": "X1"}) - - -def 
_metadata_from_conversation(conversation: list[ChatMessage], key: str) -> list[object]: - return [msg.additional_properties[key] for msg in conversation if key in msg.additional_properties] -def _conversation_debug(conversation: list[ChatMessage]) -> list[tuple[str, str | None, str]]: - return [ - (msg.role.value if hasattr(msg.role, "value") else str(msg.role), msg.author_name, msg.text) - for msg in conversation - ] +@use_function_invocation +class MockChatClient: + """Mock chat client for testing handoff workflows.""" + additional_properties: dict[str, Any] -class _RecordingAgent(BaseAgent): def __init__( self, - *, name: str, + *, handoff_to: str | None = None, - text_handoff: bool = False, - extra_properties: dict[str, object] | None = None, ) -> None: - super().__init__(id=name, name=name, display_name=name) - self._agent_name = name - self.handoff_to = handoff_to - self.calls: list[list[ChatMessage]] = [] - self._text_handoff = text_handoff - self._extra_properties = dict(extra_properties or {}) + """Initialize the mock chat client. + + Args: + name: The name of the agent using this chat client. + handoff_to: The name of the agent to hand off to, or None for no handoff. + This is hardcoded for testing purposes so that the agent always attempts to hand off. + """ + self._name = name + self._handoff_to = handoff_to self._call_index = 0 - async def run( # type: ignore[override] - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: Any = None, - **kwargs: Any, - ) -> AgentRunResponse: - conversation = _normalise(messages) - self.calls.append(conversation) - additional_properties = _merge_additional_properties( - self.handoff_to, self._text_handoff, self._extra_properties - ) - contents = _build_reply_contents(self._agent_name, self.handoff_to, self._text_handoff, self._next_call_id()) + async def get_response(self, messages: Any, **kwargs: Any) -> ChatResponse: + contents = _build_reply_contents(self._name, self._handoff_to, self._next_call_id()) reply = ChatMessage( role=Role.ASSISTANT, contents=contents, - author_name=self.display_name, - additional_properties=additional_properties, ) - return AgentRunResponse(messages=[reply]) + return ChatResponse(messages=reply, response_id="mock_response") - async def run_stream( # type: ignore[override] - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: Any = None, - **kwargs: Any, - ) -> AsyncIterator[AgentRunResponseUpdate]: - conversation = _normalise(messages) - self.calls.append(conversation) - additional_props = _merge_additional_properties(self.handoff_to, self._text_handoff, self._extra_properties) - contents = _build_reply_contents(self._agent_name, self.handoff_to, self._text_handoff, self._next_call_id()) - yield AgentRunResponseUpdate( - contents=contents, - role=Role.ASSISTANT, - additional_properties=additional_props, - ) + def get_streaming_response(self, messages: Any, **kwargs: Any) -> AsyncIterable[ChatResponseUpdate]: + async def _stream() -> AsyncIterable[ChatResponseUpdate]: + contents = _build_reply_contents(self._name, self._handoff_to, self._next_call_id()) + yield ChatResponseUpdate(contents=contents, role=Role.ASSISTANT) + + return _stream() def _next_call_id(self) -> str | None: - if not self.handoff_to: + if not self._handoff_to: return None - call_id = f"{self.id}-handoff-{self._call_index}" + call_id = f"{self._name}-handoff-{self._call_index}" self._call_index += 1 return call_id -def _merge_additional_properties( - 
handoff_to: str | None, use_text_hint: bool, extras: dict[str, object] -) -> dict[str, object]: - additional_properties: dict[str, object] = {} - if handoff_to and not use_text_hint: - additional_properties["handoff_to"] = handoff_to - additional_properties.update(extras) - return additional_properties - - def _build_reply_contents( agent_name: str, handoff_to: str | None, - use_text_hint: bool, call_id: str | None, ) -> list[TextContent | FunctionCallContent]: contents: list[TextContent | FunctionCallContent] = [] @@ -154,161 +81,89 @@ def _build_reply_contents( FunctionCallContent(call_id=call_id, name=f"handoff_to_{handoff_to}", arguments={"handoff_to": handoff_to}) ) text = f"{agent_name} reply" - if use_text_hint and handoff_to: - text += f"\nHANDOFF_TO: {handoff_to}" contents.append(TextContent(text=text)) return contents -def _normalise(messages: str | ChatMessage | list[str] | list[ChatMessage] | None) -> list[ChatMessage]: - if isinstance(messages, list): - result: list[ChatMessage] = [] - for msg in messages: - if isinstance(msg, ChatMessage): - result.append(msg) - elif isinstance(msg, str): - result.append(ChatMessage(Role.USER, text=msg)) - return result - if isinstance(messages, ChatMessage): - return [messages] - if isinstance(messages, str): - return [ChatMessage(Role.USER, text=messages)] - return [] +class MockHandoffAgent(ChatAgent): + """Mock agent that can hand off to another agent.""" + + def __init__( + self, + *, + name: str, + handoff_to: str | None = None, + ) -> None: + """Initialize the mock handoff agent. + + Args: + name: The name of the agent. + handoff_to: The name of the agent to hand off to, or None for no handoff. + This is hardcoded for testing purposes so that the agent always attempts to hand off. + """ + super().__init__(chat_client=MockChatClient(name, handoff_to=handoff_to), name=name, id=name) async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]: return [event async for event in stream] -async def test_specialist_to_specialist_handoff(): - """Test that specialists can hand off to other specialists via .add_handoff() configuration.""" - triage = _RecordingAgent(name="triage", handoff_to="specialist") - specialist = _RecordingAgent(name="specialist", handoff_to="escalation") - escalation = _RecordingAgent(name="escalation") +async def test_handoff(): + """Test that agents can hand off to each other.""" + + # `triage` hands off to `specialist`, who then hands off to `escalation`. + # `escalation` has no handoff, so the workflow should request user input to continue. + triage = MockHandoffAgent(name="triage", handoff_to="specialist") + specialist = MockHandoffAgent(name="specialist", handoff_to="escalation") + escalation = MockHandoffAgent(name="escalation") + # Without explicitly defining handoffs, the builder will create connections + # between all agents. workflow = ( HandoffBuilder(participants=[triage, specialist, escalation]) - .set_coordinator(triage) - .add_handoff(triage, [specialist, escalation]) - .add_handoff(specialist, escalation) + .with_start_agent(triage) .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2) .build() ) - # Start conversation - triage hands off to specialist + # Start conversation - triage hands off to specialist then escalation + # escalation won't trigger a handoff, so the response from it will become + # a request for user input because autonomous mode is not enabled by default. 
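    # Editor's sketch (illustrative; not part of the patch): resuming such a
    # paused run goes through the request id, as the async-termination test at
    # the end of this file does. Assuming only APIs shown in this diff:
    #
    #     reply = {request.request_id: [ChatMessage(role=Role.USER, text="more details")]}
    #     async for event in workflow.send_responses_streaming(reply):
    #         ...  # yields WorkflowOutputEvent once the termination condition is met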
events = await _drain(workflow.run_stream("Need technical support")) requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] - assert requests - - # Specialist should have been called - assert len(specialist.calls) > 0 - - # Second user message - specialist hands off to escalation - events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "This is complex"})) - outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] - assert outputs - - # Escalation should have been called - assert len(escalation.calls) > 0 - - -async def test_handoff_preserves_complex_additional_properties(complex_metadata: _ComplexMetadata): - triage = _RecordingAgent(name="triage", handoff_to="specialist", extra_properties={"complex": complex_metadata}) - specialist = _RecordingAgent(name="specialist") - - # Sanity check: agent response contains complex metadata before entering workflow - triage_response = await triage.run([ChatMessage(role=Role.USER, text="Need help with a return")]) - assert triage_response.messages - assert "complex" in triage_response.messages[0].additional_properties - - workflow = ( - HandoffBuilder(participants=[triage, specialist]) - .set_coordinator(triage) - .with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role == Role.USER) >= 2) - .build() - ) - - # Initial run should preserve complex metadata in the triage response - events = await _drain(workflow.run_stream("Need help with a return")) - agent_events = [ev for ev in events if isinstance(ev, AgentRunEvent)] - if agent_events: - first_agent_event = agent_events[0] - first_agent_event_data = first_agent_event.data - if first_agent_event_data and first_agent_event_data.messages: - first_agent_message = first_agent_event_data.messages[0] - assert "complex" in first_agent_message.additional_properties, "Agent event lost complex metadata" - requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] - assert requests, "Workflow should request additional user input" - - request_data = requests[-1].data - assert isinstance(request_data, HandoffUserInputRequest) - conversation_snapshot = request_data.conversation - metadata_values = _metadata_from_conversation(conversation_snapshot, "complex") - assert metadata_values, ( - "Expected triage message in conversation, found " - f"additional_properties={[msg.additional_properties for msg in conversation_snapshot]}," - f" messages={_conversation_debug(conversation_snapshot)}" - ) - assert any(isinstance(value, _ComplexMetadata) for value in metadata_values), ( - "Complex metadata lost after first hop" - ) - restored_meta = next(value for value in metadata_values if isinstance(value, _ComplexMetadata)) - assert restored_meta.payload["code"] == "X1" - - # Respond and ensure metadata survives subsequent cycles - follow_up_events = await _drain( - workflow.send_responses_streaming({requests[-1].request_id: "Here are more details"}) - ) - follow_up_requests = [ev for ev in follow_up_events if isinstance(ev, RequestInfoEvent)] - outputs = [ev for ev in follow_up_events if isinstance(ev, WorkflowOutputEvent)] - - follow_up_conversation: list[ChatMessage] - if follow_up_requests: - follow_up_request_data = follow_up_requests[-1].data - assert isinstance(follow_up_request_data, HandoffUserInputRequest) - follow_up_conversation = follow_up_request_data.conversation - else: - assert outputs, "Workflow produced neither follow-up request nor output" - output_data = outputs[-1].data - follow_up_conversation = cast(list[ChatMessage], 
output_data) if isinstance(output_data, list) else [] - - metadata_values_after = _metadata_from_conversation(follow_up_conversation, "complex") - assert metadata_values_after, "Expected triage message after follow-up" - assert any(isinstance(value, _ComplexMetadata) for value in metadata_values_after), ( - "Complex metadata lost after restore" - ) - - restored_meta_after = next(value for value in metadata_values_after if isinstance(value, _ComplexMetadata)) - assert restored_meta_after.payload["code"] == "X1" + assert requests + assert len(requests) == 1 -async def test_tool_call_handoff_detection_with_text_hint(): - triage = _RecordingAgent(name="triage", handoff_to="specialist", text_handoff=True) - specialist = _RecordingAgent(name="specialist") - - workflow = HandoffBuilder(participants=[triage, specialist]).set_coordinator(triage).build() - - await _drain(workflow.run_stream("Package arrived broken")) - - assert specialist.calls, "Specialist should be invoked using handoff tool call" - assert len(specialist.calls[0]) >= 2 + request = requests[0] + assert isinstance(request.data, HandoffAgentUserRequest) + assert request.source_executor_id == escalation.name -async def test_autonomous_interaction_mode_yields_output_without_user_request(): +async def test_autonomous_mode_yields_output_without_user_request(): """Ensure autonomous interaction mode yields output without requesting user input.""" - triage = _RecordingAgent(name="triage", handoff_to="specialist") - specialist = _RecordingAgent(name="specialist") + triage = MockHandoffAgent(name="triage", handoff_to="specialist") + specialist = MockHandoffAgent(name="specialist") workflow = ( HandoffBuilder(participants=[triage, specialist]) - .set_coordinator(triage) - .with_interaction_mode("autonomous", autonomous_turn_limit=1) + .with_start_agent(triage) + # Since specialist has no handoff, the specialist will be generating normal responses. + # With autonomous mode, this should continue until the termination condition is met. + .with_autonomous_mode( + agents=[specialist], + turn_limits={resolve_agent_id(specialist): 1}, + ) + # This termination condition ensures the workflow runs through both agents. + # First message is the user message to triage, second is triage's response, which + # is a handoff to specialist, third is specialist's response that should not request + # user input due to autonomous mode. Fourth message will come from the specialist + # again and will trigger termination. 
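        # (Editorial note; not part of the patch: an equivalent condition could
        # count user turns instead, e.g.
        #     lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2,
        # as test_handoff above does; total message count is used here because
        # autonomous mode adds assistant turns without any new user input.)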
+        .with_termination_condition(lambda conv: len(conv) >= 4)
         .build()
     )
 
     events = await _drain(workflow.run_stream("Package arrived broken"))
-    assert len(triage.calls) == 1
-    assert len(specialist.calls) == 1
 
     requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
     assert not requests, "Autonomous mode should not request additional user input"
@@ -323,117 +178,31 @@ async def test_autonomous_interaction_mode_yields_output_without_user_request():
     )
 
 
-async def test_autonomous_continues_without_handoff_until_termination():
-    """Autonomous mode should keep invoking the same agent when no handoff occurs."""
-    worker = _RecordingAgent(name="worker")
-
-    workflow = (
-        HandoffBuilder(participants=[worker])
-        .set_coordinator(worker)
-        .with_interaction_mode("autonomous", autonomous_turn_limit=3)
-        .with_termination_condition(lambda conv: False)
-        .build()
-    )
-
-    events = await _drain(workflow.run_stream("Start"))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs, "Autonomous mode should yield output after termination condition"
-    assert len(worker.calls) == 3, "Worker should be invoked multiple times without user input"
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert not requests, "Autonomous mode should not request user input"
-
-
-async def test_autonomous_turn_limit_stops_loop():
-    """Autonomous mode should stop when the configured turn limit is reached."""
-    worker = _RecordingAgent(name="worker")
+async def test_autonomous_mode_resumes_user_input_on_turn_limit():
+    """Autonomous mode should fall back to requesting user input when the turn limit is reached."""
+    triage = MockHandoffAgent(name="triage", handoff_to="worker")
+    worker = MockHandoffAgent(name="worker")
 
     workflow = (
-        HandoffBuilder(participants=[worker])
-        .set_coordinator(worker)
-        .with_interaction_mode("autonomous", autonomous_turn_limit=2)
+        HandoffBuilder(participants=[triage, worker])
+        .with_start_agent(triage)
+        .with_autonomous_mode(agents=[worker], turn_limits={resolve_agent_id(worker): 2})
         .with_termination_condition(lambda conv: False)
         .build()
     )
 
     events = await _drain(workflow.run_stream("Start"))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs, "Turn limit should force a workflow output"
-    assert len(worker.calls) == 2, "Worker should stop after reaching the turn limit"
     requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert not requests, "Autonomous mode should not request user input"
-
-
-async def test_autonomous_routes_back_to_coordinator_when_specialist_stops():
-    """Specialist without handoff should route back to coordinator in autonomous mode."""
-    triage = _RecordingAgent(name="triage", handoff_to="specialist")
-    specialist = _RecordingAgent(name="specialist")
-
-    workflow = (
-        HandoffBuilder(participants=[triage, specialist])
-        .set_coordinator(triage)
-        .add_handoff(triage, specialist)
-        .with_interaction_mode("autonomous", autonomous_turn_limit=3)
-        .with_termination_condition(lambda conv: len(conv) >= 4)
-        .build()
-    )
-
-    events = await _drain(workflow.run_stream("Issue"))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs, "Workflow should complete without user input"
-    assert len(specialist.calls) >= 1, "Specialist should run without handoff"
-
-
-async def test_autonomous_mode_with_inline_turn_limit():
-    """Autonomous mode should respect turn limit passed via with_interaction_mode."""
-    worker = _RecordingAgent(name="worker")
-
-    workflow = (
-        HandoffBuilder(participants=[worker])
-        .set_coordinator(worker)
-        .with_interaction_mode("autonomous", autonomous_turn_limit=2)
-        .with_termination_condition(lambda conv: False)
-        .build()
-    )
-
-    events = await _drain(workflow.run_stream("Start"))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs, "Turn limit should force a workflow output"
-    assert len(worker.calls) == 2, "Worker should stop after reaching the inline turn limit"
-
-
-def test_autonomous_turn_limit_ignored_in_human_in_loop_mode(caplog):
-    """Verify that autonomous_turn_limit logs a warning when mode is human_in_loop."""
-    worker = _RecordingAgent(name="worker")
-
-    # Should not raise, but should log a warning
-    HandoffBuilder(participants=[worker]).set_coordinator(worker).with_interaction_mode(
-        "human_in_loop", autonomous_turn_limit=10
-    )
-
-    assert "autonomous_turn_limit=10 was provided but interaction_mode is 'human_in_loop'; ignoring." in caplog.text
+    assert requests and len(requests) == 1, "Turn limit should force a user input request"
+    assert requests[0].source_executor_id == worker.name
 
 
-def test_autonomous_turn_limit_must_be_positive():
-    """Verify that autonomous_turn_limit raises an error when <= 0."""
-    worker = _RecordingAgent(name="worker")
+def test_build_fails_without_start_agent():
+    """Verify that build() raises ValueError when with_start_agent() was not called."""
+    triage = MockHandoffAgent(name="triage")
+    specialist = MockHandoffAgent(name="specialist")
 
-    with pytest.raises(ValueError, match="autonomous_turn_limit must be positive"):
-        HandoffBuilder(participants=[worker]).set_coordinator(worker).with_interaction_mode(
-            "autonomous", autonomous_turn_limit=0
-        )
-
-    with pytest.raises(ValueError, match="autonomous_turn_limit must be positive"):
-        HandoffBuilder(participants=[worker]).set_coordinator(worker).with_interaction_mode(
-            "autonomous", autonomous_turn_limit=-5
-        )
-
-
-def test_build_fails_without_coordinator():
-    """Verify that build() raises ValueError when set_coordinator() was not called."""
-    triage = _RecordingAgent(name="triage")
-    specialist = _RecordingAgent(name="specialist")
-
-    with pytest.raises(ValueError, match=r"Must call set_coordinator\(...\) before building the workflow."):
+    with pytest.raises(ValueError, match=r"Must call with_start_agent\(...\) before building the workflow."):
         HandoffBuilder(participants=[triage, specialist]).build()
 
 
@@ -453,11 +222,12 @@ async def async_termination(conv: list[ChatMessage]) -> bool:
         user_count = sum(1 for msg in conv if msg.role == Role.USER)
         return user_count >= 2
 
-    coordinator = _RecordingAgent(name="coordinator")
+    coordinator = MockHandoffAgent(name="coordinator", handoff_to="worker")
+    worker = MockHandoffAgent(name="worker")
 
     workflow = (
-        HandoffBuilder(participants=[coordinator])
-        .set_coordinator(coordinator)
+        HandoffBuilder(participants=[coordinator, worker])
+        .with_start_agent(coordinator)
         .with_termination_condition(async_termination)
         .build()
     )
@@ -466,7 +236,11 @@ async def async_termination(conv: list[ChatMessage]) -> bool:
     requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
     assert requests
 
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Second user message"}))
+    events = await _drain(
+        workflow.send_responses_streaming({
+            requests[-1].request_id: [ChatMessage(role=Role.USER, text="Second user message")]
+        })
+    )
 
     outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
     assert len(outputs) == 1
@@ -478,195 +252,14 @@ async def async_termination(conv: list[ChatMessage]) -> bool:
     assert termination_call_count > 0
 
 
-async def test_clone_chat_agent_preserves_mcp_tools() -> None:
-    """Test that _clone_chat_agent preserves MCP tools when cloning an agent."""
-    mock_chat_client = MagicMock()
-
-    mock_mcp_tool = MagicMock(spec=MCPTool)
-    mock_mcp_tool.name = "test_mcp_tool"
-
-    def sample_function() -> str:
-        return "test"
-
-    original_agent = ChatAgent(
-        chat_client=mock_chat_client,
-        name="TestAgent",
-        instructions="Test instructions",
-        tools=[mock_mcp_tool, sample_function],
-    )
-
-    assert hasattr(original_agent, "_local_mcp_tools")
-    assert len(original_agent._local_mcp_tools) == 1  # type: ignore[reportPrivateUsage]
-    assert original_agent._local_mcp_tools[0] == mock_mcp_tool  # type: ignore[reportPrivateUsage]
-
-    cloned_agent = _clone_chat_agent(original_agent)
-
-    assert hasattr(cloned_agent, "_local_mcp_tools")
-    assert len(cloned_agent._local_mcp_tools) == 1  # type: ignore[reportPrivateUsage]
-    assert cloned_agent._local_mcp_tools[0] == mock_mcp_tool  # type: ignore[reportPrivateUsage]
-    assert cloned_agent.chat_options.tools is not None
-    assert len(cloned_agent.chat_options.tools) == 1
-
-
-async def test_return_to_previous_routing():
-    """Test that return-to-previous routes back to the current specialist handling the conversation."""
-    triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
-    specialist_a = _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
-    specialist_b = _RecordingAgent(name="specialist_b")
-
-    workflow = (
-        HandoffBuilder(participants=[triage, specialist_a, specialist_b])
-        .set_coordinator(triage)
-        .add_handoff(triage, [specialist_a, specialist_b])
-        .add_handoff(specialist_a, specialist_b)
-        .enable_return_to_previous(True)
-        .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 4)
-        .build()
-    )
-
-    # Start conversation - triage hands off to specialist_a
-    events = await _drain(workflow.run_stream("Initial request"))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-    assert len(specialist_a.calls) > 0
-
-    # Specialist_a should have been called with initial request
-    initial_specialist_a_calls = len(specialist_a.calls)
-
-    # Second user message - specialist_a hands off to specialist_b
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need more help"}))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-
-    # Specialist_b should have been called
-    assert len(specialist_b.calls) > 0
-    initial_specialist_b_calls = len(specialist_b.calls)
-
-    # Third user message - with return_to_previous, should route back to specialist_b (current agent)
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
-    third_requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-
-    # Specialist_b should have been called again (return-to-previous routes to current agent)
-    assert len(specialist_b.calls) > initial_specialist_b_calls, (
-        "Specialist B should be called again due to return-to-previous routing to current agent"
-    )
-
-    # Specialist_a should NOT be called again (it's no longer the current agent)
-    assert len(specialist_a.calls) == initial_specialist_a_calls, (
-        "Specialist A should not be called again - specialist_b is the current agent"
-    )
-
-    # Triage should only have been called once at the start
-    assert len(triage.calls) == 1, "Triage should only be called once (initial routing)"
-
-    # Verify awaiting_agent_id is set to specialist_b (the agent that just responded)
-    if third_requests:
-        user_input_req = third_requests[-1].data
-        assert isinstance(user_input_req, HandoffUserInputRequest)
-        assert user_input_req.awaiting_agent_id == "specialist_b", (
-            f"Expected awaiting_agent_id 'specialist_b' but got '{user_input_req.awaiting_agent_id}'"
-        )
-
-
-async def test_return_to_previous_disabled_routes_to_coordinator():
-    """Test that with return-to-previous disabled, routing goes back to coordinator."""
-    triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
-    specialist_a = _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
-    specialist_b = _RecordingAgent(name="specialist_b")
-
-    workflow = (
-        HandoffBuilder(participants=[triage, specialist_a, specialist_b])
-        .set_coordinator(triage)
-        .add_handoff(triage, [specialist_a, specialist_b])
-        .add_handoff(specialist_a, specialist_b)
-        .enable_return_to_previous(False)
-        .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
-        .build()
-    )
-
-    # Start conversation - triage hands off to specialist_a
-    events = await _drain(workflow.run_stream("Initial request"))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-    assert len(triage.calls) == 1
-
-    # Second user message - specialist_a hands off to specialist_b
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need more help"}))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-
-    # Third user message - without return_to_previous, should route back to triage
-    await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
-
-    # Triage should have been called twice total: initial + after specialist_b responds
-    assert len(triage.calls) == 2, "Triage should be called twice (initial + default routing to coordinator)"
-
-
-async def test_return_to_previous_enabled():
-    """Verify that enable_return_to_previous() keeps control with the current specialist."""
-    triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
-    specialist_a = _RecordingAgent(name="specialist_a")
-    specialist_b = _RecordingAgent(name="specialist_b")
-
-    workflow = (
-        HandoffBuilder(participants=[triage, specialist_a, specialist_b])
-        .set_coordinator(triage)
-        .enable_return_to_previous(True)
-        .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
-        .build()
-    )
-
-    # Start conversation - triage hands off to specialist_a
-    events = await _drain(workflow.run_stream("Initial request"))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-    assert len(triage.calls) == 1
-    assert len(specialist_a.calls) == 1
-
-    # Second user message - with return_to_previous, should route to specialist_a (not triage)
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-
-    # Triage should only have been called once (initial) - specialist_a handles follow-up
-    assert len(triage.calls) == 1, "Triage should only be called once (initial)"
-    assert len(specialist_a.calls) == 2, "Specialist A should handle follow-up with return_to_previous enabled"
-
-
-def test_handoff_builder_sets_start_executor_once(monkeypatch: pytest.MonkeyPatch) -> None:
"""Ensure HandoffBuilder.build sets the start executor only once when assembling the workflow.""" - _CountingWorkflowBuilder.created.clear() - monkeypatch.setattr(handoff_module, "WorkflowBuilder", _CountingWorkflowBuilder) - - coordinator = _RecordingAgent(name="coordinator") - specialist = _RecordingAgent(name="specialist") - - workflow = ( - HandoffBuilder(participants=[coordinator, specialist]) - .set_coordinator(coordinator) - .with_termination_condition(lambda conv: len(conv) > 0) - .build() - ) - - assert workflow is not None - assert _CountingWorkflowBuilder.created, "Expected CountingWorkflowBuilder to be instantiated" - builder = _CountingWorkflowBuilder.created[-1] - assert builder.start_calls == 1, "set_start_executor should be invoked exactly once" - - async def test_tool_choice_preserved_from_agent_config(): """Verify that agent-level tool_choice configuration is preserved and not overridden.""" - from unittest.mock import AsyncMock - - from agent_framework import ChatResponse, ToolMode - # Create a mock chat client that records the tool_choice used recorded_tool_choices: list[Any] = [] - async def mock_get_response(messages: Any, **kwargs: Any) -> ChatResponse: - chat_options = kwargs.get("chat_options") - if chat_options: - recorded_tool_choices.append(chat_options.tool_choice) + async def mock_get_response(messages: Any, options: dict[str, Any] | None = None, **kwargs: Any) -> ChatResponse: + if options: + recorded_tool_choices.append(options.get("tool_choice")) return ChatResponse( messages=[ChatMessage(role=Role.ASSISTANT, text="Response")], response_id="test_response", @@ -675,11 +268,11 @@ async def mock_get_response(messages: Any, **kwargs: Any) -> ChatResponse: mock_client = MagicMock() mock_client.get_response = AsyncMock(side_effect=mock_get_response) - # Create agent with specific tool_choice configuration + # Create agent with specific tool_choice configuration via default_options agent = ChatAgent( chat_client=mock_client, name="test_agent", - tool_choice=ToolMode(mode="required"), # type: ignore[arg-type] + default_options={"tool_choice": {"mode": "required"}}, ) # Run the agent @@ -689,97 +282,7 @@ async def mock_get_response(messages: Any, **kwargs: Any) -> ChatResponse: assert len(recorded_tool_choices) > 0, "No tool_choice recorded" last_tool_choice = recorded_tool_choices[-1] assert last_tool_choice is not None, "tool_choice should not be None" - assert str(last_tool_choice) == "required", f"Expected 'required', got {last_tool_choice}" - - -async def test_handoff_builder_with_request_info(): - """Test that HandoffBuilder supports request info via with_request_info().""" - from agent_framework import AgentInputRequest, RequestInfoEvent - - # Create test agents - coordinator = _RecordingAgent(name="coordinator") - specialist = _RecordingAgent(name="specialist") - - # Build workflow with request info enabled - workflow = ( - HandoffBuilder(participants=[coordinator, specialist]) - .set_coordinator(coordinator) - .with_termination_condition(lambda conv: len([m for m in conv if m.role == Role.USER]) >= 1) - .with_request_info() - .build() - ) - - # Run workflow until it pauses for request info - request_event: RequestInfoEvent | None = None - async for event in workflow.run_stream("Hello"): - if isinstance(event, RequestInfoEvent) and isinstance(event.data, AgentInputRequest): - request_event = event - - # Verify request info was emitted - assert request_event is not None, "Request info should have been emitted" - assert isinstance(request_event.data, 
AgentInputRequest) - - # Provide response and continue - output_events: list[WorkflowOutputEvent] = [] - async for event in workflow.send_responses_streaming({request_event.request_id: "approved"}): - if isinstance(event, WorkflowOutputEvent): - output_events.append(event) - - # Verify we got output events - assert len(output_events) > 0, "Should produce output events after response" - - -async def test_handoff_builder_with_request_info_method_chaining(): - """Test that with_request_info returns self for method chaining.""" - coordinator = _RecordingAgent(name="coordinator") - - builder = HandoffBuilder(participants=[coordinator]) - result = builder.with_request_info() - - assert result is builder, "with_request_info should return self for chaining" - assert builder._request_info_enabled is True # type: ignore - - -async def test_return_to_previous_state_serialization(): - """Test that return_to_previous state is properly serialized/deserialized for checkpointing.""" - from agent_framework._workflows._handoff import _HandoffCoordinator # type: ignore[reportPrivateUsage] - - # Create a coordinator with return_to_previous enabled - coordinator = _HandoffCoordinator( - starting_agent_id="triage", - specialist_ids={"specialist_a": "specialist_a", "specialist_b": "specialist_b"}, - input_gateway_id="gateway", - termination_condition=lambda conv: False, - id="test-coordinator", - return_to_previous=True, - ) - - # Set the current agent (simulating a handoff scenario) - coordinator._current_agent_id = "specialist_a" # type: ignore[reportPrivateUsage] - - # Snapshot the state - state = await coordinator.on_checkpoint_save() - - # Verify pattern metadata includes current_agent_id - assert "metadata" in state - assert "current_agent_id" in state["metadata"] - assert state["metadata"]["current_agent_id"] == "specialist_a" - - # Create a new coordinator and restore state - coordinator2 = _HandoffCoordinator( - starting_agent_id="triage", - specialist_ids={"specialist_a": "specialist_a", "specialist_b": "specialist_b"}, - input_gateway_id="gateway", - termination_condition=lambda conv: False, - id="test-coordinator", - return_to_previous=True, - ) - - # Restore state - await coordinator2.on_checkpoint_restore(state) - - # Verify current_agent_id was restored - assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint" # type: ignore[reportPrivateUsage] + assert last_tool_choice == {"mode": "required"}, f"Expected 'required', got {last_tool_choice}" # region Participant Factory Tests @@ -797,43 +300,43 @@ def test_handoff_builder_rejects_empty_participant_factories(): def test_handoff_builder_rejects_mixing_participants_and_factories(): """Test that mixing participants and participant_factories in __init__ raises an error.""" - triage = _RecordingAgent(name="triage") + triage = MockHandoffAgent(name="triage") with pytest.raises(ValueError, match="Cannot mix .participants"): HandoffBuilder(participants=[triage], participant_factories={"triage": lambda: triage}) def test_handoff_builder_rejects_mixing_participants_and_participant_factories_methods(): """Test that mixing .participants() and .participant_factories() raises an error.""" - triage = _RecordingAgent(name="triage") + triage = MockHandoffAgent(name="triage") # Case 1: participants first, then participant_factories with pytest.raises(ValueError, match="Cannot mix .participants"): HandoffBuilder(participants=[triage]).participant_factories({ - "specialist": lambda: _RecordingAgent(name="specialist") 
+ "specialist": lambda: MockHandoffAgent(name="specialist") }) # Case 2: participant_factories first, then participants with pytest.raises(ValueError, match="Cannot mix .participants"): HandoffBuilder(participant_factories={"triage": lambda: triage}).participants([ - _RecordingAgent(name="specialist") + MockHandoffAgent(name="specialist") ]) # Case 3: participants(), then participant_factories() with pytest.raises(ValueError, match="Cannot mix .participants"): HandoffBuilder().participants([triage]).participant_factories({ - "specialist": lambda: _RecordingAgent(name="specialist") + "specialist": lambda: MockHandoffAgent(name="specialist") }) # Case 4: participant_factories(), then participants() with pytest.raises(ValueError, match="Cannot mix .participants"): HandoffBuilder().participant_factories({"triage": lambda: triage}).participants([ - _RecordingAgent(name="specialist") + MockHandoffAgent(name="specialist") ]) # Case 5: mix during initialization with pytest.raises(ValueError, match="Cannot mix .participants"): HandoffBuilder( - participants=[triage], participant_factories={"specialist": lambda: _RecordingAgent(name="specialist")} + participants=[triage], participant_factories={"specialist": lambda: MockHandoffAgent(name="specialist")} ) @@ -842,60 +345,49 @@ def test_handoff_builder_rejects_multiple_calls_to_participant_factories(): with pytest.raises(ValueError, match=r"participant_factories\(\) has already been called"): ( HandoffBuilder() - .participant_factories({"agent1": lambda: _RecordingAgent(name="agent1")}) - .participant_factories({"agent2": lambda: _RecordingAgent(name="agent2")}) + .participant_factories({"agent1": lambda: MockHandoffAgent(name="agent1")}) + .participant_factories({"agent2": lambda: MockHandoffAgent(name="agent2")}) ) def test_handoff_builder_rejects_multiple_calls_to_participants(): """Test that multiple calls to .participants() raises an error.""" with pytest.raises(ValueError, match="participants have already been assigned"): - (HandoffBuilder().participants([_RecordingAgent(name="agent1")]).participants([_RecordingAgent(name="agent2")])) - - -def test_handoff_builder_rejects_duplicate_factories(): - """Test that multiple calls to participant_factories are rejected.""" - factories = { - "triage": lambda: _RecordingAgent(name="triage"), - "specialist": lambda: _RecordingAgent(name="specialist"), - } - - # Multiple calls to participant_factories should fail - builder = HandoffBuilder(participant_factories=factories) - with pytest.raises(ValueError, match=r"participant_factories\(\) has already been called"): - builder.participant_factories({"triage": lambda: _RecordingAgent(name="triage2")}) + ( + HandoffBuilder() + .participants([MockHandoffAgent(name="agent1")]) + .participants([MockHandoffAgent(name="agent2")]) + ) def test_handoff_builder_rejects_instance_coordinator_with_factories(): """Test that using an agent instance for set_coordinator when using factories raises an error.""" - def create_triage() -> _RecordingAgent: - return _RecordingAgent(name="triage") + def create_triage() -> MockHandoffAgent: + return MockHandoffAgent(name="triage") - def create_specialist() -> _RecordingAgent: - return _RecordingAgent(name="specialist") + def create_specialist() -> MockHandoffAgent: + return MockHandoffAgent(name="specialist") # Create an agent instance - coordinator_instance = _RecordingAgent(name="coordinator") + coordinator_instance = MockHandoffAgent(name="coordinator") - with pytest.raises(ValueError, match=r"Call participants\(\.\.\.\) before 
coordinator\(\.\.\.\)"): + with pytest.raises(ValueError, match=r"Call participants\(\.\.\.\) before with_start_agent\(\.\.\.\)"): ( HandoffBuilder( participant_factories={"triage": create_triage, "specialist": create_specialist} - ).set_coordinator(coordinator_instance) # Instance, not factory name + ).with_start_agent(coordinator_instance) # Instance, not factory name ) def test_handoff_builder_rejects_factory_name_coordinator_with_instances(): """Test that using a factory name for set_coordinator when using instances raises an error.""" - triage = _RecordingAgent(name="triage") - specialist = _RecordingAgent(name="specialist") + triage = MockHandoffAgent(name="triage") + specialist = MockHandoffAgent(name="specialist") - with pytest.raises( - ValueError, match="coordinator factory name 'triage' is not part of the participant_factories list" - ): + with pytest.raises(ValueError, match="Call participant_factories.*before with_start_agent"): ( - HandoffBuilder(participants=[triage, specialist]).set_coordinator( + HandoffBuilder(participants=[triage, specialist]).with_start_agent( "triage" ) # String factory name, not instance ) @@ -903,28 +395,28 @@ def test_handoff_builder_rejects_factory_name_coordinator_with_instances(): def test_handoff_builder_rejects_mixed_types_in_add_handoff_source(): """Test that add_handoff rejects factory name source with instance-based participants.""" - triage = _RecordingAgent(name="triage") - specialist = _RecordingAgent(name="specialist") + triage = MockHandoffAgent(name="triage") + specialist = MockHandoffAgent(name="specialist") - with pytest.raises(TypeError, match="Cannot mix factory names \\(str\\) and AgentProtocol/Executor instances"): + with pytest.raises(TypeError, match="Cannot mix factory names \\(str\\) and AgentProtocol.*instances"): ( HandoffBuilder(participants=[triage, specialist]) - .set_coordinator(triage) - .add_handoff("triage", specialist) # String source with instance participants + .with_start_agent(triage) + .add_handoff("triage", [specialist]) # String source with instance participants ) def test_handoff_builder_accepts_all_factory_names_in_add_handoff(): """Test that add_handoff accepts all factory names when using participant_factories.""" - def create_triage() -> _RecordingAgent: - return _RecordingAgent(name="triage") + def create_triage() -> MockHandoffAgent: + return MockHandoffAgent(name="triage") - def create_specialist_a() -> _RecordingAgent: - return _RecordingAgent(name="specialist_a") + def create_specialist_a() -> MockHandoffAgent: + return MockHandoffAgent(name="specialist_a") - def create_specialist_b() -> _RecordingAgent: - return _RecordingAgent(name="specialist_b") + def create_specialist_b() -> MockHandoffAgent: + return MockHandoffAgent(name="specialist_b") # This should work - all strings with participant_factories builder = ( @@ -935,7 +427,7 @@ def create_specialist_b() -> _RecordingAgent: "specialist_b": create_specialist_b, } ) - .set_coordinator("triage") + .with_start_agent("triage") .add_handoff("triage", ["specialist_a", "specialist_b"]) ) @@ -947,14 +439,14 @@ def create_specialist_b() -> _RecordingAgent: def test_handoff_builder_accepts_all_instances_in_add_handoff(): """Test that add_handoff accepts all instances when using participants.""" - triage = _RecordingAgent(name="triage", handoff_to="specialist_a") - specialist_a = _RecordingAgent(name="specialist_a") - specialist_b = _RecordingAgent(name="specialist_b") + triage = MockHandoffAgent(name="triage", handoff_to="specialist_a") + specialist_a = 
MockHandoffAgent(name="specialist_a") + specialist_b = MockHandoffAgent(name="specialist_b") # This should work - all instances with participants builder = ( HandoffBuilder(participants=[triage, specialist_a, specialist_b]) - .set_coordinator(triage) + .with_start_agent(triage) .add_handoff(triage, [specialist_a, specialist_b]) ) @@ -968,19 +460,19 @@ async def test_handoff_with_participant_factories(): """Test workflow creation using participant_factories.""" call_count = 0 - def create_triage() -> _RecordingAgent: + def create_triage() -> MockHandoffAgent: nonlocal call_count call_count += 1 - return _RecordingAgent(name="triage", handoff_to="specialist") + return MockHandoffAgent(name="triage", handoff_to="specialist") - def create_specialist() -> _RecordingAgent: + def create_specialist() -> MockHandoffAgent: nonlocal call_count call_count += 1 - return _RecordingAgent(name="specialist") + return MockHandoffAgent(name="specialist") workflow = ( HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist}) - .set_coordinator("triage") + .with_start_agent("triage") .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2) .build() ) @@ -993,7 +485,9 @@ def create_specialist() -> _RecordingAgent: assert requests # Follow-up message - events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "More details"})) + events = await _drain( + workflow.send_responses_streaming({requests[-1].request_id: [ChatMessage(role=Role.USER, text="More details")]}) + ) outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] assert outputs @@ -1002,19 +496,19 @@ async def test_handoff_participant_factories_reusable_builder(): """Test that the builder can be reused to build multiple workflows with factories.""" call_count = 0 - def create_triage() -> _RecordingAgent: + def create_triage() -> MockHandoffAgent: nonlocal call_count call_count += 1 - return _RecordingAgent(name="triage", handoff_to="specialist") + return MockHandoffAgent(name="triage", handoff_to="specialist") - def create_specialist() -> _RecordingAgent: + def create_specialist() -> MockHandoffAgent: nonlocal call_count call_count += 1 - return _RecordingAgent(name="specialist") + return MockHandoffAgent(name="specialist") builder = HandoffBuilder( participant_factories={"triage": create_triage, "specialist": create_specialist} - ).set_coordinator("triage") + ).with_start_agent("triage") # Build first workflow wf1 = builder.build() @@ -1032,14 +526,14 @@ def create_specialist() -> _RecordingAgent: async def test_handoff_with_participant_factories_and_add_handoff(): """Test that .add_handoff() works correctly with participant_factories.""" - def create_triage() -> _RecordingAgent: - return _RecordingAgent(name="triage", handoff_to="specialist_a") + def create_triage() -> MockHandoffAgent: + return MockHandoffAgent(name="triage", handoff_to="specialist_a") - def create_specialist_a() -> _RecordingAgent: - return _RecordingAgent(name="specialist_a", handoff_to="specialist_b") + def create_specialist_a() -> MockHandoffAgent: + return MockHandoffAgent(name="specialist_a", handoff_to="specialist_b") - def create_specialist_b() -> _RecordingAgent: - return _RecordingAgent(name="specialist_b") + def create_specialist_b() -> MockHandoffAgent: + return MockHandoffAgent(name="specialist_b") workflow = ( HandoffBuilder( @@ -1049,9 +543,9 @@ def create_specialist_b() -> _RecordingAgent: "specialist_b": create_specialist_b, } ) - 
.set_coordinator("triage") + .with_start_agent("triage") .add_handoff("triage", ["specialist_a", "specialist_b"]) - .add_handoff("specialist_a", "specialist_b") + .add_handoff("specialist_a", ["specialist_b"]) .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3) .build() ) @@ -1065,7 +559,11 @@ def create_specialist_b() -> _RecordingAgent: assert "specialist_a" in workflow.executors # Second user message - specialist_a hands off to specialist_b - events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need escalation"})) + events = await _drain( + workflow.send_responses_streaming({ + requests[-1].request_id: [ChatMessage(role=Role.USER, text="Need escalation")] + }) + ) requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] assert requests @@ -1079,15 +577,15 @@ async def test_handoff_participant_factories_with_checkpointing(): storage = InMemoryCheckpointStorage() - def create_triage() -> _RecordingAgent: - return _RecordingAgent(name="triage", handoff_to="specialist") + def create_triage() -> MockHandoffAgent: + return MockHandoffAgent(name="triage", handoff_to="specialist") - def create_specialist() -> _RecordingAgent: - return _RecordingAgent(name="specialist") + def create_specialist() -> MockHandoffAgent: + return MockHandoffAgent(name="specialist") workflow = ( HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist}) - .set_coordinator("triage") + .with_start_agent("triage") .with_checkpointing(storage) .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2) .build() @@ -1098,7 +596,9 @@ def create_specialist() -> _RecordingAgent: requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] assert requests - events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "follow up"})) + events = await _drain( + workflow.send_responses_streaming({requests[-1].request_id: [ChatMessage(role=Role.USER, text="follow up")]}) + ) outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] assert outputs, "Should have workflow output after termination condition is met" @@ -1110,15 +610,15 @@ def create_specialist() -> _RecordingAgent: def test_handoff_set_coordinator_with_factory_name(): """Test that set_coordinator accepts factory name as string.""" - def create_triage() -> _RecordingAgent: - return _RecordingAgent(name="triage") + def create_triage() -> MockHandoffAgent: + return MockHandoffAgent(name="triage") - def create_specialist() -> _RecordingAgent: - return _RecordingAgent(name="specialist") + def create_specialist() -> MockHandoffAgent: + return MockHandoffAgent(name="specialist") builder = HandoffBuilder( participant_factories={"triage": create_triage, "specialist": create_specialist} - ).set_coordinator("triage") + ).with_start_agent("triage") workflow = builder.build() assert "triage" in workflow.executors @@ -1127,14 +627,14 @@ def create_specialist() -> _RecordingAgent: def test_handoff_add_handoff_with_factory_names(): """Test that add_handoff accepts factory names as strings.""" - def create_triage() -> _RecordingAgent: - return _RecordingAgent(name="triage", handoff_to="specialist_a") + def create_triage() -> MockHandoffAgent: + return MockHandoffAgent(name="triage", handoff_to="specialist_a") - def create_specialist_a() -> _RecordingAgent: - return _RecordingAgent(name="specialist_a") + def create_specialist_a() -> MockHandoffAgent: + return MockHandoffAgent(name="specialist_a") - 
-    def create_specialist_b() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist_b")
+    def create_specialist_b() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist_b")
 
     builder = (
         HandoffBuilder(
@@ -1144,7 +644,7 @@ def create_specialist_b() -> _RecordingAgent:
                 "specialist_b": create_specialist_b,
             }
         )
-        .set_coordinator("triage")
+        .with_start_agent("triage")
         .add_handoff("triage", ["specialist_a", "specialist_b"])
     )
 
@@ -1157,516 +657,53 @@ def create_specialist_b() -> _RecordingAgent:
 
 async def test_handoff_participant_factories_autonomous_mode():
     """Test autonomous mode with participant_factories."""
 
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage", handoff_to="specialist")
+    def create_triage() -> MockHandoffAgent:
+        return MockHandoffAgent(name="triage", handoff_to="specialist")
 
-    def create_specialist() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist")
+    def create_specialist() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist")
 
     workflow = (
         HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist})
-        .set_coordinator("triage")
-        .with_interaction_mode("autonomous", autonomous_turn_limit=2)
+        .with_start_agent("triage")
+        .with_autonomous_mode(agents=["specialist"], turn_limits={"specialist": 1})
         .build()
     )
 
     events = await _drain(workflow.run_stream("Issue"))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs, "Autonomous mode should yield output"
     requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert not requests, "Autonomous mode should not request user input"
-
-
-async def test_handoff_participant_factories_with_request_info():
-    """Test that .with_request_info() works with participant_factories."""
-
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage")
-
-    def create_specialist() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist")
-
-    builder = (
-        HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist})
-        .set_coordinator("triage")
-        .with_request_info(agents=["triage"])
-    )
-
-    workflow = builder.build()
-    assert "triage" in workflow.executors
+    assert requests and len(requests) == 1
+    assert requests[0].source_executor_id == "specialist"
 
 
 def test_handoff_participant_factories_invalid_coordinator_name():
     """Test that set_coordinator raises error for non-existent factory name."""
 
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage")
+    def create_triage() -> MockHandoffAgent:
+        return MockHandoffAgent(name="triage")
 
     with pytest.raises(
-        ValueError, match="coordinator factory name 'nonexistent' is not part of the participant_factories list"
+        ValueError, match="Start agent factory name 'nonexistent' is not in the participant_factories list"
     ):
-        (HandoffBuilder(participant_factories={"triage": create_triage}).set_coordinator("nonexistent").build())
+        (HandoffBuilder(participant_factories={"triage": create_triage}).with_start_agent("nonexistent").build())
 
 
 def test_handoff_participant_factories_invalid_handoff_target():
     """Test that add_handoff raises error for non-existent target factory name."""
 
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage")
+    def create_triage() -> MockHandoffAgent:
+        return MockHandoffAgent(name="triage")
 
-    def create_specialist() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist")
+    def create_specialist() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist")
 
     with pytest.raises(ValueError, match="Target factory name 'nonexistent' is not in the participant_factories list"):
         (
             HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist})
-            .set_coordinator("triage")
-            .add_handoff("triage", "nonexistent")
+            .with_start_agent("triage")
+            .add_handoff("triage", ["nonexistent"])
             .build()
         )
 
 
-async def test_handoff_participant_factories_enable_return_to_previous():
-    """Test return_to_previous works with participant_factories."""
-
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage", handoff_to="specialist_a")
-
-    def create_specialist_a() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
-
-    def create_specialist_b() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist_b")
-
-    workflow = (
-        HandoffBuilder(
-            participant_factories={
-                "triage": create_triage,
-                "specialist_a": create_specialist_a,
-                "specialist_b": create_specialist_b,
-            }
-        )
-        .set_coordinator("triage")
-        .add_handoff("triage", ["specialist_a", "specialist_b"])
-        .add_handoff("specialist_a", "specialist_b")
-        .enable_return_to_previous(True)
-        .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
-        .build()
-    )
-
-    # Start conversation - triage hands off to specialist_a
-    events = await _drain(workflow.run_stream("Initial request"))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-
-    # Second user message - specialist_a hands off to specialist_b
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need escalation"}))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-
-    # Third user message - should route back to specialist_b (return to previous)
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up"}))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs or [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-
-
 # endregion Participant Factory Tests
-
-
-async def test_handoff_user_input_request_checkpoint_excludes_conversation():
-    """Test that HandoffUserInputRequest serialization excludes conversation to prevent duplication.
-
-    Issue #2667: When checkpointing a workflow with a pending HandoffUserInputRequest,
-    the conversation field gets serialized twice: once in the RequestInfoEvent's data
-    and once in the coordinator's conversation state. On restore, this causes duplicate
-    messages.
-
-    The fix is to exclude the conversation field during checkpoint serialization since
-    the conversation is already preserved in the coordinator's state.
- """ - # Create a conversation history - conversation = [ - ChatMessage(role=Role.USER, text="Hello"), - ChatMessage(role=Role.ASSISTANT, text="Hi there!"), - ChatMessage(role=Role.USER, text="Help me"), - ] - - # Create a HandoffUserInputRequest with the conversation - request = HandoffUserInputRequest( - conversation=conversation, - awaiting_agent_id="specialist_agent", - prompt="Please provide your input", - source_executor_id="gateway", - ) - - # Encode the request (simulating checkpoint save) - encoded = encode_checkpoint_value(request) - - # Verify conversation is NOT in the encoded output - # The fix should exclude conversation from serialization - assert isinstance(encoded, dict) - - # If using MODEL_MARKER strategy (to_dict/from_dict) - if "__af_model__" in encoded or "__af_dataclass__" in encoded: - value = encoded.get("value", {}) - assert "conversation" not in value, "conversation should be excluded from checkpoint serialization" - - # Decode the request (simulating checkpoint restore) - decoded = decode_checkpoint_value(encoded) - - # Verify the decoded request is a HandoffUserInputRequest - assert isinstance(decoded, HandoffUserInputRequest) - - # Verify other fields are preserved - assert decoded.awaiting_agent_id == "specialist_agent" - assert decoded.prompt == "Please provide your input" - assert decoded.source_executor_id == "gateway" - - # Conversation should be an empty list after deserialization - # (will be reconstructed from coordinator state on restore) - assert decoded.conversation == [] - - -async def test_handoff_user_input_request_roundtrip_preserves_metadata(): - """Test that non-conversation fields survive checkpoint roundtrip.""" - request = HandoffUserInputRequest( - conversation=[ChatMessage(role=Role.USER, text="test")], - awaiting_agent_id="test_agent", - prompt="Enter your response", - source_executor_id="test_gateway", - ) - - # Roundtrip through checkpoint encoding - encoded = encode_checkpoint_value(request) - decoded = decode_checkpoint_value(encoded) - - assert isinstance(decoded, HandoffUserInputRequest) - assert decoded.awaiting_agent_id == request.awaiting_agent_id - assert decoded.prompt == request.prompt - assert decoded.source_executor_id == request.source_executor_id - - -async def test_request_info_event_with_handoff_user_input_request(): - """Test RequestInfoEvent serialization with HandoffUserInputRequest data.""" - conversation = [ - ChatMessage(role=Role.USER, text="Hello"), - ChatMessage(role=Role.ASSISTANT, text="How can I help?"), - ] - - request = HandoffUserInputRequest( - conversation=conversation, - awaiting_agent_id="specialist", - prompt="Provide input", - source_executor_id="gateway", - ) - - # Create a RequestInfoEvent wrapping the request - event = RequestInfoEvent( - request_id="test-request-123", - source_executor_id="gateway", - request_data=request, - response_type=object, - ) - - # Serialize the event - event_dict = event.to_dict() - - # Verify the data field doesn't contain conversation - data_encoded = event_dict["data"] - if isinstance(data_encoded, dict) and ("__af_model__" in data_encoded or "__af_dataclass__" in data_encoded): - value = data_encoded.get("value", {}) - assert "conversation" not in value - - # Deserialize and verify - restored_event = RequestInfoEvent.from_dict(event_dict) - assert isinstance(restored_event.data, HandoffUserInputRequest) - assert restored_event.data.awaiting_agent_id == "specialist" - assert restored_event.data.conversation == [] - - -async def 
test_handoff_user_input_request_to_dict_excludes_conversation(): - """Test that to_dict() method excludes conversation field.""" - conversation = [ - ChatMessage(role=Role.USER, text="Hello"), - ChatMessage(role=Role.ASSISTANT, text="Hi!"), - ] - - request = HandoffUserInputRequest( - conversation=conversation, - awaiting_agent_id="agent1", - prompt="Enter input", - source_executor_id="gateway", - ) - - # Call to_dict directly - data = request.to_dict() - - # Verify conversation is excluded - assert "conversation" not in data - assert data["awaiting_agent_id"] == "agent1" - assert data["prompt"] == "Enter input" - assert data["source_executor_id"] == "gateway" - - -async def test_handoff_user_input_request_from_dict_creates_empty_conversation(): - """Test that from_dict() creates an instance with empty conversation.""" - data = { - "awaiting_agent_id": "agent1", - "prompt": "Enter input", - "source_executor_id": "gateway", - } - - request = HandoffUserInputRequest.from_dict(data) - - assert request.conversation == [] - assert request.awaiting_agent_id == "agent1" - assert request.prompt == "Enter input" - assert request.source_executor_id == "gateway" - - -async def test_user_input_gateway_resume_handles_empty_conversation(): - """Test that _UserInputGateway.resume_from_user handles post-restore scenario. - - After checkpoint restore, the HandoffUserInputRequest will have an empty - conversation. The gateway should handle this by sending only the new user - messages to the coordinator. - """ - from unittest.mock import AsyncMock - - # Create a gateway - gateway = _UserInputGateway( - starting_agent_id="coordinator", - prompt="Enter input", - id="test-gateway", - ) - - # Simulate post-restore: request with empty conversation - restored_request = HandoffUserInputRequest( - conversation=[], # Empty after restore - awaiting_agent_id="specialist", - prompt="Enter input", - source_executor_id="test-gateway", - ) - - # Create mock context - mock_ctx = MagicMock() - mock_ctx.send_message = AsyncMock() - - # Call resume_from_user with a user response - await gateway.resume_from_user(restored_request, "New user message", mock_ctx) - - # Verify send_message was called - mock_ctx.send_message.assert_called_once() - - # Get the message that was sent - call_args = mock_ctx.send_message.call_args - sent_message = call_args[0][0] - - # Verify it's a _ConversationWithUserInput - assert isinstance(sent_message, _ConversationWithUserInput) - - # Verify it contains only the new user message (not any history) - assert len(sent_message.full_conversation) == 1 - assert sent_message.full_conversation[0].role == Role.USER - assert sent_message.full_conversation[0].text == "New user message" - - -async def test_user_input_gateway_resume_with_full_conversation(): - """Test that _UserInputGateway.resume_from_user handles normal flow correctly. - - In normal flow (no checkpoint restore), the HandoffUserInputRequest has - the full conversation. The gateway should send the full conversation - plus the new user messages. 
- """ - from unittest.mock import AsyncMock - - # Create a gateway - gateway = _UserInputGateway( - starting_agent_id="coordinator", - prompt="Enter input", - id="test-gateway", - ) - - # Normal flow: request with full conversation - normal_request = HandoffUserInputRequest( - conversation=[ - ChatMessage(role=Role.USER, text="Hello"), - ChatMessage(role=Role.ASSISTANT, text="Hi!"), - ], - awaiting_agent_id="specialist", - prompt="Enter input", - source_executor_id="test-gateway", - ) - - # Create mock context - mock_ctx = MagicMock() - mock_ctx.send_message = AsyncMock() - - # Call resume_from_user with a user response - await gateway.resume_from_user(normal_request, "Follow up message", mock_ctx) - - # Verify send_message was called - mock_ctx.send_message.assert_called_once() - - # Get the message that was sent - call_args = mock_ctx.send_message.call_args - sent_message = call_args[0][0] - - # Verify it's a _ConversationWithUserInput - assert isinstance(sent_message, _ConversationWithUserInput) - - # Verify it contains the full conversation plus new user message - assert len(sent_message.full_conversation) == 3 - assert sent_message.full_conversation[0].text == "Hello" - assert sent_message.full_conversation[1].text == "Hi!" - assert sent_message.full_conversation[2].text == "Follow up message" - - -async def test_coordinator_handle_user_input_post_restore(): - """Test that _HandoffCoordinator.handle_user_input handles post-restore correctly. - - After checkpoint restore, the coordinator has its conversation restored, - and the gateway sends only the new user messages. The coordinator should - append these to its existing conversation rather than replacing. - """ - from unittest.mock import AsyncMock - - from agent_framework._workflows._handoff import _HandoffCoordinator - - # Create a coordinator with pre-existing conversation (simulating restored state) - coordinator = _HandoffCoordinator( - starting_agent_id="triage", - specialist_ids={"specialist_a": "specialist_a"}, - input_gateway_id="gateway", - termination_condition=lambda conv: False, - id="test-coordinator", - ) - - # Simulate restored conversation - coordinator._conversation = [ - ChatMessage(role=Role.USER, text="Hello"), - ChatMessage(role=Role.ASSISTANT, text="Hi there!"), - ChatMessage(role=Role.USER, text="Help me"), - ChatMessage(role=Role.ASSISTANT, text="Sure, what do you need?"), - ] - - # Create mock context - mock_ctx = MagicMock() - mock_ctx.send_message = AsyncMock() - - # Simulate post-restore: only new user message with explicit flag - incoming = _ConversationWithUserInput( - full_conversation=[ChatMessage(role=Role.USER, text="I need shipping help")], - is_post_restore=True, - ) - - # Handle the user input - await coordinator.handle_user_input(incoming, mock_ctx) - - # Verify conversation was appended, not replaced - assert len(coordinator._conversation) == 5 - assert coordinator._conversation[0].text == "Hello" - assert coordinator._conversation[1].text == "Hi there!" - assert coordinator._conversation[2].text == "Help me" - assert coordinator._conversation[3].text == "Sure, what do you need?" - assert coordinator._conversation[4].text == "I need shipping help" - - -async def test_coordinator_handle_user_input_normal_flow(): - """Test that _HandoffCoordinator.handle_user_input handles normal flow correctly. - - In normal flow (no restore), the gateway sends the full conversation. - The coordinator should replace its conversation with the incoming one. 
- """ - from unittest.mock import AsyncMock - - from agent_framework._workflows._handoff import _HandoffCoordinator - - # Create a coordinator - coordinator = _HandoffCoordinator( - starting_agent_id="triage", - specialist_ids={"specialist_a": "specialist_a"}, - input_gateway_id="gateway", - termination_condition=lambda conv: False, - id="test-coordinator", - ) - - # Set some initial conversation - coordinator._conversation = [ - ChatMessage(role=Role.USER, text="Old message"), - ] - - # Create mock context - mock_ctx = MagicMock() - mock_ctx.send_message = AsyncMock() - - # Normal flow: full conversation including new user message (is_post_restore=False by default) - incoming = _ConversationWithUserInput( - full_conversation=[ - ChatMessage(role=Role.USER, text="Hello"), - ChatMessage(role=Role.ASSISTANT, text="Hi!"), - ChatMessage(role=Role.USER, text="New message"), - ], - is_post_restore=False, - ) - - # Handle the user input - await coordinator.handle_user_input(incoming, mock_ctx) - - # Verify conversation was replaced (normal flow with full history) - assert len(coordinator._conversation) == 3 - assert coordinator._conversation[0].text == "Hello" - assert coordinator._conversation[1].text == "Hi!" - assert coordinator._conversation[2].text == "New message" - - -async def test_coordinator_handle_user_input_multiple_consecutive_user_messages(): - """Test that multiple consecutive USER messages in normal flow are handled correctly. - - This is a regression test for the edge case where a user submits multiple consecutive - USER messages. The explicit is_post_restore flag ensures this doesn't get incorrectly - detected as a post-restore scenario. - """ - from unittest.mock import AsyncMock - - from agent_framework._workflows._handoff import _HandoffCoordinator - - # Create a coordinator with existing conversation - coordinator = _HandoffCoordinator( - starting_agent_id="triage", - specialist_ids={"specialist_a": "specialist_a"}, - input_gateway_id="gateway", - termination_condition=lambda conv: False, - id="test-coordinator", - ) - - # Set existing conversation with 4 messages - coordinator._conversation = [ - ChatMessage(role=Role.USER, text="Original message 1"), - ChatMessage(role=Role.ASSISTANT, text="Response 1"), - ChatMessage(role=Role.USER, text="Original message 2"), - ChatMessage(role=Role.ASSISTANT, text="Response 2"), - ] - - # Create mock context - mock_ctx = MagicMock() - mock_ctx.send_message = AsyncMock() - - # Normal flow: User sends multiple consecutive USER messages - # This should REPLACE the conversation, not append to it - incoming = _ConversationWithUserInput( - full_conversation=[ - ChatMessage(role=Role.USER, text="New user message 1"), - ChatMessage(role=Role.USER, text="New user message 2"), - ], - is_post_restore=False, # Explicit flag - this is normal flow - ) - - # Handle the user input - await coordinator.handle_user_input(incoming, mock_ctx) - - # Verify conversation was REPLACED (not appended) - # Without the explicit flag, the old heuristic might incorrectly append - assert len(coordinator._conversation) == 2 - assert coordinator._conversation[0].text == "New user message 1" - assert coordinator._conversation[1].text == "New user message 2" diff --git a/python/packages/core/tests/workflow/test_magentic.py b/python/packages/core/tests/workflow/test_magentic.py index 4ee16ddb5f..7e4a5bb48e 100644 --- a/python/packages/core/tests/workflow/test_magentic.py +++ b/python/packages/core/tests/workflow/test_magentic.py @@ -3,44 +3,43 @@ import sys from 
 from collections.abc import AsyncIterable
 from dataclasses import dataclass
-from typing import Any, cast
+from typing import Any, ClassVar, cast
 
 import pytest
 
 from agent_framework import (
-    AgentRunResponse,
-    AgentRunResponseUpdate,
+    AgentProtocol,
+    AgentResponse,
+    AgentResponseUpdate,
     AgentRunUpdateEvent,
+    AgentThread,
     BaseAgent,
     ChatMessage,
     Executor,
+    GroupChatRequestMessage,
     MagenticBuilder,
-    MagenticHumanInterventionDecision,
-    MagenticHumanInterventionReply,
-    MagenticHumanInterventionRequest,
+    MagenticContext,
     MagenticManagerBase,
+    MagenticOrchestrator,
+    MagenticOrchestratorEvent,
+    MagenticPlanReviewRequest,
+    MagenticProgressLedger,
+    MagenticProgressLedgerItem,
     RequestInfoEvent,
     Role,
+    StandardMagenticManager,
     TextContent,
+    Workflow,
     WorkflowCheckpoint,
+    WorkflowCheckpointException,
     WorkflowContext,
-    WorkflowEvent,  # type: ignore  # noqa: E402
+    WorkflowEvent,
     WorkflowOutputEvent,
     WorkflowRunState,
     WorkflowStatusEvent,
     handler,
 )
-from agent_framework._workflows import _group_chat as group_chat_module  # type: ignore
 from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
-from agent_framework._workflows._magentic import (  # type: ignore[reportPrivateUsage]
-    MagenticAgentExecutor,
-    MagenticContext,
-    MagenticOrchestratorExecutor,
-    _MagenticProgressLedger,  # type: ignore
-    _MagenticProgressLedgerItem,  # type: ignore
-    _MagenticStartMessage,  # type: ignore
-)
-from agent_framework._workflows._workflow_builder import WorkflowBuilder
 
 if sys.version_info >= (3, 12):
     from typing import override
@@ -48,40 +47,9 @@
     from typing_extensions import override
 
 
-def test_magentic_start_message_from_string():
-    msg = _MagenticStartMessage.from_string("Do the thing")
-    assert isinstance(msg, _MagenticStartMessage)
-    assert isinstance(msg.task, ChatMessage)
-    assert msg.task.role == Role.USER
-    assert msg.task.text == "Do the thing"
-
-
-def test_human_intervention_request_defaults_and_reply_variants():
-    from agent_framework._workflows._magentic import MagenticHumanInterventionKind
-
-    req = MagenticHumanInterventionRequest(kind=MagenticHumanInterventionKind.PLAN_REVIEW)
-    assert hasattr(req, "request_id")
-    assert req.task_text == "" and req.facts_text == "" and req.plan_text == ""
-    assert isinstance(req.round_index, int) and req.round_index == 0
-
-    # Replies: approve, revise with comments, revise with edited text
-    approve = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.APPROVE)
-    revise_comments = MagenticHumanInterventionReply(
-        decision=MagenticHumanInterventionDecision.REVISE, comments="Tighten scope"
-    )
-    revise_text = MagenticHumanInterventionReply(
-        decision=MagenticHumanInterventionDecision.REVISE,
-        edited_plan_text="- Step 1\n- Step 2",
-    )
-
-    assert approve.decision == MagenticHumanInterventionDecision.APPROVE
-    assert revise_comments.comments == "Tighten scope"
-    assert revise_text.edited_plan_text is not None and revise_text.edited_plan_text.startswith("- Step 1")
-
-
 def test_magentic_context_reset_behavior():
     ctx = MagenticContext(
-        task=ChatMessage(role=Role.USER, text="task"),
+        task="task",
         participant_descriptions={"Alice": "Researcher"},
     )
     # seed context state
@@ -105,10 +73,24 @@ class _SimpleLedger:
 
 class FakeManager(MagenticManagerBase):
     """Deterministic manager for tests that avoids real LLM calls."""
 
-    task_ledger: _SimpleLedger | None = None
-    satisfied_after_signoff: bool = True
-    next_speaker_name: str = "agentA"
-    instruction_text: str = "Proceed with step 1"
+    FINAL_ANSWER: ClassVar[str] = "FINAL"
+
+    def __init__(
+        self,
+        *,
+        max_stall_count: int = 3,
+        max_reset_count: int | None = None,
+        max_round_count: int | None = None,
+    ) -> None:
+        super().__init__(
+            max_stall_count=max_stall_count,
+            max_reset_count=max_reset_count,
+            max_round_count=max_round_count,
+        )
+        self.name = "magentic_manager"
+        self.task_ledger: _SimpleLedger | None = None
+        self.next_speaker_name: str = "agentA"
+        self.instruction_text: str = "Proceed with step 1"
 
     @override
     def on_checkpoint_save(self) -> dict[str, Any]:
@@ -141,47 +123,117 @@ async def plan(self, magentic_context: MagenticContext) -> ChatMessage:
         facts = ChatMessage(role=Role.ASSISTANT, text="GIVEN OR VERIFIED FACTS\n- A\n")
         plan = ChatMessage(role=Role.ASSISTANT, text="- Do X\n- Do Y\n")
         self.task_ledger = _SimpleLedger(facts=facts, plan=plan)
-        combined = f"Task: {magentic_context.task.text}\n\nFacts:\n{facts.text}\n\nPlan:\n{plan.text}"
-        return ChatMessage(role=Role.ASSISTANT, text=combined, author_name="magentic_manager")
+        combined = f"Task: {magentic_context.task}\n\nFacts:\n{facts.text}\n\nPlan:\n{plan.text}"
+        return ChatMessage(role=Role.ASSISTANT, text=combined, author_name=self.name)
 
     async def replan(self, magentic_context: MagenticContext) -> ChatMessage:
         facts = ChatMessage(role=Role.ASSISTANT, text="GIVEN OR VERIFIED FACTS\n- A2\n")
         plan = ChatMessage(role=Role.ASSISTANT, text="- Do Z\n")
         self.task_ledger = _SimpleLedger(facts=facts, plan=plan)
-        combined = f"Task: {magentic_context.task.text}\n\nFacts:\n{facts.text}\n\nPlan:\n{plan.text}"
-        return ChatMessage(role=Role.ASSISTANT, text=combined, author_name="magentic_manager")
-
-    async def create_progress_ledger(self, magentic_context: MagenticContext) -> _MagenticProgressLedger:
-        is_satisfied = self.satisfied_after_signoff and len(magentic_context.chat_history) > 0
-        return _MagenticProgressLedger(
-            is_request_satisfied=_MagenticProgressLedgerItem(reason="test", answer=is_satisfied),
-            is_in_loop=_MagenticProgressLedgerItem(reason="test", answer=False),
-            is_progress_being_made=_MagenticProgressLedgerItem(reason="test", answer=True),
-            next_speaker=_MagenticProgressLedgerItem(reason="test", answer=self.next_speaker_name),
-            instruction_or_question=_MagenticProgressLedgerItem(reason="test", answer=self.instruction_text),
+        combined = f"Task: {magentic_context.task}\n\nFacts:\n{facts.text}\n\nPlan:\n{plan.text}"
+        return ChatMessage(role=Role.ASSISTANT, text=combined, author_name=self.name)
+
+    async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger:
+        # At least two messages in chat history means request is satisfied for testing
+        is_satisfied = len(magentic_context.chat_history) > 1
+        return MagenticProgressLedger(
+            is_request_satisfied=MagenticProgressLedgerItem(reason="test", answer=is_satisfied),
+            is_in_loop=MagenticProgressLedgerItem(reason="test", answer=False),
+            is_progress_being_made=MagenticProgressLedgerItem(reason="test", answer=True),
+            next_speaker=MagenticProgressLedgerItem(reason="test", answer=self.next_speaker_name),
+            instruction_or_question=MagenticProgressLedgerItem(reason="test", answer=self.instruction_text),
         )
 
     async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage:
-        return ChatMessage(role=Role.ASSISTANT, text="FINAL", author_name="magentic_manager")
+        return ChatMessage(role=Role.ASSISTANT, text=self.FINAL_ANSWER, author_name=self.name)
+
+
+class StubAgent(BaseAgent):
+    def __init__(self, agent_name: str, reply_text: str, **kwargs: Any) -> None:
super().__init__(name=agent_name, description=f"Stub agent {agent_name}", **kwargs) + self._reply_text = reply_text + + async def run( # type: ignore[override] + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentResponse: + response = ChatMessage(role=Role.ASSISTANT, text=self._reply_text, author_name=self.name) + return AgentResponse(messages=[response]) + + def run_stream( # type: ignore[override] + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentResponseUpdate]: + async def _stream() -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate( + contents=[TextContent(text=self._reply_text)], role=Role.ASSISTANT, author_name=self.name + ) + + return _stream() + + +class DummyExec(Executor): + def __init__(self, name: str) -> None: + super().__init__(name) + + @handler + async def _noop( + self, message: GroupChatRequestMessage, ctx: WorkflowContext[ChatMessage] + ) -> None: # pragma: no cover - not called + pass + + +async def test_magentic_builder_returns_workflow_and_runs() -> None: + manager = FakeManager() + agent = StubAgent(manager.next_speaker_name, "first draft") + + workflow = MagenticBuilder().participants([agent]).with_standard_manager(manager).build() + assert isinstance(workflow, Workflow) + + outputs: list[ChatMessage] = [] + orchestrator_event_count = 0 + async for event in workflow.run_stream("compose summary"): + if isinstance(event, WorkflowOutputEvent): + msg = event.data + if isinstance(msg, list): + outputs.extend(cast(list[ChatMessage], msg)) + elif isinstance(event, MagenticOrchestratorEvent): + orchestrator_event_count += 1 + + assert outputs, "Expected a final output message" + assert len(outputs) >= 1 + final = outputs[-1] + assert final.text == manager.FINAL_ANSWER + assert final.author_name == manager.name + assert orchestrator_event_count > 0, "Expected orchestrator events to be emitted" -class _CountingWorkflowBuilder(WorkflowBuilder): - created: list["_CountingWorkflowBuilder"] = [] - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.start_calls = 0 - _CountingWorkflowBuilder.created.append(self) +async def test_magentic_as_agent_does_not_accept_conversation() -> None: + manager = FakeManager() + writer = StubAgent(manager.next_speaker_name, "summary response") + + workflow = MagenticBuilder().participants([writer]).with_standard_manager(manager).build() - def set_start_executor(self, executor: Any) -> "_CountingWorkflowBuilder": # type: ignore[override] - self.start_calls += 1 - return cast("_CountingWorkflowBuilder", super().set_start_executor(executor)) + agent = workflow.as_agent(name="magentic-agent") + conversation = [ + ChatMessage(role=Role.SYSTEM, text="Guidelines", author_name="system"), + ChatMessage(role=Role.USER, text="Summarize the findings", author_name="requester"), + ] + with pytest.raises(ValueError, match="Magentic only support a single task message to start the workflow."): + await agent.run(conversation) async def test_standard_manager_plan_and_replan_combined_ledger(): - manager = FakeManager(max_round_count=10, max_stall_count=3, max_reset_count=2) + manager = FakeManager() ctx = MagenticContext( - task=ChatMessage(role=Role.USER, text="demo task"), + task="demo task", participant_descriptions={"agentA": "Agent A"}, ) @@ -193,55 +245,34 @@ async def 
test_standard_manager_plan_and_replan_combined_ledger(): assert "A2" in replanned.text or "Do Z" in replanned.text -async def test_standard_manager_progress_ledger_and_fallback(): - manager = FakeManager(max_round_count=10) - ctx = MagenticContext( - task=ChatMessage(role=Role.USER, text="demo"), - participant_descriptions={"agentA": "Agent A"}, - ) - - ledger = await manager.create_progress_ledger(ctx.clone()) - assert isinstance(ledger, _MagenticProgressLedger) - assert ledger.next_speaker.answer == "agentA" - - manager.satisfied_after_signoff = False - ledger2 = await manager.create_progress_ledger(ctx.clone()) - assert ledger2.is_request_satisfied.answer is False - - async def test_magentic_workflow_plan_review_approval_to_completion(): - manager = FakeManager(max_round_count=10) - wf = ( - MagenticBuilder() - .participants(agentA=_DummyExec("agentA")) - .with_standard_manager(manager) - .with_plan_review() - .build() - ) + manager = FakeManager() + wf = MagenticBuilder().participants([DummyExec("agentA")]).with_standard_manager(manager).with_plan_review().build() req_event: RequestInfoEvent | None = None async for ev in wf.run_stream("do work"): - if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticHumanInterventionRequest: + if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticPlanReviewRequest: req_event = ev assert req_event is not None + assert isinstance(req_event.data, MagenticPlanReviewRequest) completed = False output: list[ChatMessage] | None = None - reply = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.APPROVE) - async for ev in wf.send_responses_streaming(responses={req_event.request_id: reply}): + async for ev in wf.send_responses_streaming(responses={req_event.request_id: req_event.data.approve()}): if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: completed = True elif isinstance(ev, WorkflowOutputEvent): output = ev.data # type: ignore[assignment] if completed and output is not None: break + assert completed assert output is not None assert isinstance(output, list) assert all(isinstance(msg, ChatMessage) for msg in output) -async def test_magentic_plan_review_approve_with_comments_replans_and_proceeds(): +async def test_magentic_plan_review_with_revise(): class CountingManager(FakeManager): # Declare as a model field so assignment is allowed under Pydantic replan_count: int = 0 @@ -253,10 +284,10 @@ async def replan(self, magentic_context: MagenticContext) -> ChatMessage: # typ self.replan_count += 1 return await super().replan(magentic_context) - manager = CountingManager(max_round_count=10) + manager = CountingManager() wf = ( MagenticBuilder() - .participants(agentA=_DummyExec("agentA")) + .participants([DummyExec(name=manager.next_speaker_name)]) .with_standard_manager(manager) .with_plan_review() .build() @@ -265,30 +296,32 @@ async def replan(self, magentic_context: MagenticContext) -> ChatMessage: # typ # Wait for the initial plan review request req_event: RequestInfoEvent | None = None async for ev in wf.run_stream("do work"): - if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticHumanInterventionRequest: + if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticPlanReviewRequest: req_event = ev assert req_event is not None + assert isinstance(req_event.data, MagenticPlanReviewRequest) - # Reply APPROVE with comments (no edited text). Expect one replan and no second review round. 
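+    # Unlike an approval, a revise reply should trigger a replan and a second
+    # plan-review round before the workflow can complete.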
+ # Send a revise response saw_second_review = False completed = False async for ev in wf.send_responses_streaming( - responses={ - req_event.request_id: MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.APPROVE, - comments="Looks good; consider Z", - ) - } + responses={req_event.request_id: req_event.data.revise("Looks good; consider Z")} ): - if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticHumanInterventionRequest: + if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticPlanReviewRequest: saw_second_review = True + req_event = ev + + # Approve the second review + async for ev in wf.send_responses_streaming( + responses={req_event.request_id: req_event.data.approve()} # type: ignore[union-attr] + ): if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: completed = True break assert completed assert manager.replan_count >= 1 - assert saw_second_review is False + assert saw_second_review is True # Replan from FakeManager updates facts/plan to include A2 / Do Z assert manager.task_ledger is not None combined_text = (manager.task_ledger.facts.text or "") + (manager.task_ledger.plan.text or "") @@ -297,19 +330,20 @@ async def replan(self, magentic_context: MagenticContext) -> ChatMessage: # typ async def test_magentic_orchestrator_round_limit_produces_partial_result(): manager = FakeManager(max_round_count=1) - manager.satisfied_after_signoff = False - wf = MagenticBuilder().participants(agentA=_DummyExec("agentA")).with_standard_manager(manager).build() - - from agent_framework import WorkflowEvent # type: ignore + wf = ( + MagenticBuilder() + .participants([DummyExec(name=manager.next_speaker_name)]) + .with_standard_manager(manager) + .build() + ) events: list[WorkflowEvent] = [] async for ev in wf.run_stream("round limit test"): events.append(ev) - if len(events) > 50: - break idle_status = next( - (e for e in events if isinstance(e, WorkflowStatusEvent) and e.state == WorkflowRunState.IDLE), None + (e for e in events if isinstance(e, WorkflowStatusEvent) and e.state == WorkflowRunState.IDLE), + None, ) assert idle_status is not None # Check that we got workflow output via WorkflowOutputEvent @@ -317,18 +351,18 @@ async def test_magentic_orchestrator_round_limit_produces_partial_result(): assert output_event is not None data = output_event.data assert isinstance(data, list) - assert all(isinstance(msg, ChatMessage) for msg in data) - assert len(data) > 0 - assert data[-1].role == Role.ASSISTANT + assert len(data) > 0 # type: ignore + assert data[-1].role == Role.ASSISTANT # type: ignore + assert all(isinstance(msg, ChatMessage) for msg in data) # type: ignore async def test_magentic_checkpoint_resume_round_trip(): storage = InMemoryCheckpointStorage() - manager1 = FakeManager(max_round_count=10) + manager1 = FakeManager() wf = ( MagenticBuilder() - .participants(agentA=_DummyExec("agentA")) + .participants([DummyExec(name=manager1.next_speaker_name)]) .with_standard_manager(manager1) .with_plan_review() .with_checkpointing(storage) @@ -338,99 +372,52 @@ async def test_magentic_checkpoint_resume_round_trip(): task_text = "checkpoint task" req_event: RequestInfoEvent | None = None async for ev in wf.run_stream(task_text): - if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticHumanInterventionRequest: + if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticPlanReviewRequest: req_event = ev assert req_event is not None + assert isinstance(req_event.data, MagenticPlanReviewRequest) 
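+    # The latest checkpoint captures the pending plan-review request, so the
+    # resumed workflow below re-emits it before the approval completes the run.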
checkpoints = await storage.list_checkpoints() assert checkpoints checkpoints.sort(key=lambda cp: cp.timestamp) resume_checkpoint = checkpoints[-1] - manager2 = FakeManager(max_round_count=10) + manager2 = FakeManager() wf_resume = ( MagenticBuilder() - .participants(agentA=_DummyExec("agentA")) + .participants([DummyExec(name=manager2.next_speaker_name)]) .with_standard_manager(manager2) .with_plan_review() .with_checkpointing(storage) .build() ) - orchestrator = next(exec for exec in wf_resume.executors.values() if isinstance(exec, MagenticOrchestratorExecutor)) - - reply = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.APPROVE) completed: WorkflowOutputEvent | None = None req_event = None async for event in wf_resume.run_stream( resume_checkpoint.checkpoint_id, ): - if isinstance(event, RequestInfoEvent) and event.request_type is MagenticHumanInterventionRequest: + if isinstance(event, RequestInfoEvent) and event.request_type is MagenticPlanReviewRequest: req_event = event assert req_event is not None + assert isinstance(req_event.data, MagenticPlanReviewRequest) - responses = {req_event.request_id: reply} + responses = {req_event.request_id: req_event.data.approve()} async for event in wf_resume.send_responses_streaming(responses=responses): if isinstance(event, WorkflowOutputEvent): completed = event assert completed is not None - assert orchestrator._context is not None # type: ignore[reportPrivateUsage] - assert orchestrator._context.chat_history # type: ignore[reportPrivateUsage] + orchestrator = next(exec for exec in wf_resume.executors.values() if isinstance(exec, MagenticOrchestrator)) + assert orchestrator._magentic_context is not None # type: ignore[reportPrivateUsage] + assert orchestrator._magentic_context.chat_history # type: ignore[reportPrivateUsage] assert orchestrator._task_ledger is not None # type: ignore[reportPrivateUsage] assert manager2.task_ledger is not None # Latest entry in chat history should be the task ledger plan - assert orchestrator._context.chat_history[-1].text == orchestrator._task_ledger.text # type: ignore[reportPrivateUsage] - - -class _DummyExec(Executor): - def __init__(self, name: str) -> None: - super().__init__(name) - - @handler - async def _noop(self, message: object, ctx: WorkflowContext[object]) -> None: # pragma: no cover - not called - pass - - -def test_magentic_builder_sets_start_executor_once(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure MagenticBuilder wiring sets the start executor only once.""" - _CountingWorkflowBuilder.created.clear() - monkeypatch.setattr(group_chat_module, "WorkflowBuilder", _CountingWorkflowBuilder) - - manager = FakeManager() - - workflow = ( - MagenticBuilder().participants(agentA=_DummyExec("agentA")).with_standard_manager(manager=manager).build() - ) - - assert workflow is not None - assert _CountingWorkflowBuilder.created, "Expected CountingWorkflowBuilder to be instantiated" - builder = _CountingWorkflowBuilder.created[-1] - assert builder.start_calls == 1, "set_start_executor should be called exactly once" + assert orchestrator._magentic_context.chat_history[-1].text == orchestrator._task_ledger.text # type: ignore[reportPrivateUsage] -async def test_magentic_agent_executor_on_checkpoint_save_and_restore_roundtrip(): - backing_executor = _DummyExec("backing") - agent_exec = MagenticAgentExecutor(backing_executor, "agentA") - agent_exec._chat_history.extend([ # type: ignore[reportPrivateUsage] - ChatMessage(role=Role.USER, text="hello"), - 
ChatMessage(role=Role.ASSISTANT, text="world", author_name="agentA"), - ]) - - state = await agent_exec.on_checkpoint_save() - - restored_executor = MagenticAgentExecutor(_DummyExec("backing2"), "agentA") - await restored_executor.on_checkpoint_restore(state) - - assert len(restored_executor._chat_history) == 2 # type: ignore[reportPrivateUsage] - assert restored_executor._chat_history[0].text == "hello" # type: ignore[reportPrivateUsage] - assert restored_executor._chat_history[1].author_name == "agentA" # type: ignore[reportPrivateUsage] - - -from agent_framework import StandardMagenticManager # noqa: E402 - - -class _StubManagerAgent(BaseAgent): +class StubManagerAgent(BaseAgent): """Stub agent for testing StandardMagenticManager.""" async def run( @@ -439,8 +426,8 @@ async def run( *, thread: Any = None, **kwargs: Any, - ) -> AgentRunResponse: - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="ok")]) + ) -> AgentResponse: + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="ok")]) def run_stream( self, @@ -448,15 +435,15 @@ def run_stream( *, thread: Any = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: - async def _gen() -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate(message_deltas=[ChatMessage(role=Role.ASSISTANT, text="ok")]) + ) -> AsyncIterable[AgentResponseUpdate]: + async def _gen() -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate(message_deltas=[ChatMessage(role=Role.ASSISTANT, text="ok")]) return _gen() async def test_standard_manager_plan_and_replan_via_complete_monkeypatch(): - mgr = StandardMagenticManager(agent=_StubManagerAgent()) + mgr = StandardMagenticManager(StubManagerAgent()) async def fake_complete_plan(messages: list[ChatMessage], **kwargs: Any) -> ChatMessage: # Return a different response depending on call order length @@ -467,10 +454,7 @@ async def fake_complete_plan(messages: list[ChatMessage], **kwargs: Any) -> Chat # First, patch to produce facts then plan mgr._complete = fake_complete_plan # type: ignore[attr-defined] - ctx = MagenticContext( - task=ChatMessage(role=Role.USER, text="T"), - participant_descriptions={"A": "desc"}, - ) + ctx = MagenticContext(task="T", participant_descriptions={"A": "desc"}) combined = await mgr.plan(ctx.clone()) # Assert structural headings and that steps appear in the combined ledger output. 
assert "We are working to address the following user request:" in combined.text @@ -489,11 +473,8 @@ async def fake_complete_replan(messages: list[ChatMessage], **kwargs: Any) -> Ch async def test_standard_manager_progress_ledger_success_and_error(): - mgr = StandardMagenticManager(agent=_StubManagerAgent()) - ctx = MagenticContext( - task=ChatMessage(role=Role.USER, text="task"), - participant_descriptions={"alice": "desc"}, - ) + mgr = StandardMagenticManager(agent=StubManagerAgent()) + ctx = MagenticContext(task="task", participant_descriptions={"alice": "desc"}) # Success path: valid JSON async def fake_complete_ok(messages: list[ChatMessage], **kwargs: Any) -> ChatMessage: @@ -530,24 +511,24 @@ async def plan(self, magentic_context: MagenticContext) -> ChatMessage: async def replan(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="re-ledger") - async def create_progress_ledger(self, magentic_context: MagenticContext) -> _MagenticProgressLedger: + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: if not self._invoked: # First round: ask agentA to respond self._invoked = True - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(reason="r", answer=False), - is_in_loop=_MagenticProgressLedgerItem(reason="r", answer=False), - is_progress_being_made=_MagenticProgressLedgerItem(reason="r", answer=True), - next_speaker=_MagenticProgressLedgerItem(reason="r", answer="agentA"), - instruction_or_question=_MagenticProgressLedgerItem(reason="r", answer="say hi"), + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(reason="r", answer=False), + is_in_loop=MagenticProgressLedgerItem(reason="r", answer=False), + is_progress_being_made=MagenticProgressLedgerItem(reason="r", answer=True), + next_speaker=MagenticProgressLedgerItem(reason="r", answer="agentA"), + instruction_or_question=MagenticProgressLedgerItem(reason="r", answer="say hi"), ) # Next round: mark satisfied so run can conclude - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(reason="r", answer=True), - is_in_loop=_MagenticProgressLedgerItem(reason="r", answer=False), - is_progress_being_made=_MagenticProgressLedgerItem(reason="r", answer=True), - next_speaker=_MagenticProgressLedgerItem(reason="r", answer="agentA"), - instruction_or_question=_MagenticProgressLedgerItem(reason="r", answer="done"), + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(reason="r", answer=True), + is_in_loop=MagenticProgressLedgerItem(reason="r", answer=False), + is_progress_being_made=MagenticProgressLedgerItem(reason="r", answer=True), + next_speaker=MagenticProgressLedgerItem(reason="r", answer="agentA"), + instruction_or_question=MagenticProgressLedgerItem(reason="r", answer="done"), ) async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: @@ -555,15 +536,18 @@ async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatM class StubThreadAgent(BaseAgent): + def __init__(self, name: str | None = None) -> None: + super().__init__(name=name or "agentA") + async def run_stream(self, messages=None, *, thread=None, **kwargs): # type: ignore[override] - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=[TextContent(text="thread-ok")], - author_name="agentA", + author_name=self.name, role=Role.ASSISTANT, ) async def run(self, messages=None, *, thread=None, **kwargs): 
# type: ignore[override]
-        return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="thread-ok", author_name="agentA")])
+        return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="thread-ok", author_name=self.name)])
 
 
 class StubAssistantsClient:
@@ -574,29 +558,24 @@ class StubAssistantsAgent(BaseAgent):
     chat_client: object | None = None  # allow assignment via Pydantic field
 
     def __init__(self) -> None:
-        super().__init__()
+        super().__init__(name="agentA")
        self.chat_client = StubAssistantsClient()  # type name contains 'AssistantsClient'
 
     async def run_stream(self, messages=None, *, thread=None, **kwargs):  # type: ignore[override]
-        yield AgentRunResponseUpdate(
+        yield AgentResponseUpdate(
             contents=[TextContent(text="assistants-ok")],
-            author_name="agentA",
+            author_name=self.name,
             role=Role.ASSISTANT,
         )
 
     async def run(self, messages=None, *, thread=None, **kwargs):  # type: ignore[override]
-        return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="assistants-ok", author_name="agentA")])
+        return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="assistants-ok", author_name=self.name)])
 
 
-async def _collect_agent_responses_setup(participant_obj: object):
+async def _collect_agent_responses_setup(participant: AgentProtocol) -> list[ChatMessage]:
     captured: list[ChatMessage] = []
-    wf = (
-        MagenticBuilder()
-        .participants(agentA=participant_obj)  # type: ignore[arg-type]
-        .with_standard_manager(InvokeOnceManager())
-        .build()
-    )
+    wf = MagenticBuilder().participants([participant]).with_standard_manager(InvokeOnceManager()).build()
 
     # Run a bounded stream to allow one invoke and then completion
     events: list[WorkflowEvent] = []
@@ -604,30 +583,34 @@ async def _collect_agent_responses_setup(participant_obj: object):
         events.append(ev)
         if isinstance(ev, WorkflowOutputEvent):
             break
-        if isinstance(ev, AgentRunUpdateEvent) and ev.data is not None:
+        if isinstance(ev, AgentRunUpdateEvent):
             captured.append(
                 ChatMessage(
-                    role=ev.data.role or Role.ASSISTANT, text=ev.data.text or "", author_name=ev.data.author_name
+                    role=ev.data.role or Role.ASSISTANT,
+                    text=ev.data.text or "",
+                    author_name=ev.data.author_name,
                 )
             )
-        if len(events) > 50:
-            break
 
     return captured
 
 
 async def test_agent_executor_invoke_with_thread_chat_client():
-    captured = await _collect_agent_responses_setup(StubThreadAgent())
+    agent = StubThreadAgent()
+    captured = await _collect_agent_responses_setup(agent)
     # Should have at least one response from agentA via the Magentic agent executor path
-    assert any((m.author_name == "agentA" and "ok" in (m.text or "")) for m in captured)
+    assert any((m.author_name == agent.name and "ok" in (m.text or "")) for m in captured)
 
 
 async def test_agent_executor_invoke_with_assistants_client_messages():
-    captured = await _collect_agent_responses_setup(StubAssistantsAgent())
-    assert any((m.author_name == "agentA" and "ok" in (m.text or "")) for m in captured)
+    agent = StubAssistantsAgent()
+    captured = await _collect_agent_responses_setup(agent)
+    assert any((m.author_name == agent.name and "ok" in (m.text or "")) for m in captured)
 
 
-async def _collect_checkpoints(storage: InMemoryCheckpointStorage) -> list[WorkflowCheckpoint]:
+async def _collect_checkpoints(
+    storage: InMemoryCheckpointStorage,
+) -> list[WorkflowCheckpoint]:
     checkpoints = await storage.list_checkpoints()
     assert checkpoints
     checkpoints.sort(key=lambda cp: cp.timestamp)
@@ -639,7 +622,7 @@ async def test_magentic_checkpoint_resume_inner_loop_superstep():
     workflow = (
MagenticBuilder() - .participants(agentA=StubThreadAgent()) + .participants([StubThreadAgent()]) .with_standard_manager(InvokeOnceManager()) .with_checkpointing(storage) .build() @@ -654,7 +637,7 @@ async def test_magentic_checkpoint_resume_inner_loop_superstep(): resumed = ( MagenticBuilder() - .participants(agentA=StubThreadAgent()) + .participants([StubThreadAgent()]) .with_standard_manager(InvokeOnceManager()) .with_checkpointing(storage) .build() @@ -668,7 +651,8 @@ async def test_magentic_checkpoint_resume_inner_loop_superstep(): assert completed is not None -async def test_magentic_checkpoint_resume_after_reset(): +async def test_magentic_checkpoint_resume_from_saved_state(): + """Test that we can resume workflow execution from a saved checkpoint.""" storage = InMemoryCheckpointStorage() # Use the working InvokeOnceManager first to get a completed workflow @@ -676,27 +660,24 @@ async def test_magentic_checkpoint_resume_after_reset(): workflow = ( MagenticBuilder() - .participants(agentA=StubThreadAgent()) + .participants([StubThreadAgent()]) .with_standard_manager(manager) .with_checkpointing(storage) .build() ) - async for event in workflow.run_stream("reset task"): + async for event in workflow.run_stream("checkpoint resume task"): if isinstance(event, WorkflowOutputEvent): break checkpoints = await _collect_checkpoints(storage) - # For this test, we just need to verify that we can resume from any checkpoint - # The original test intention was to test resuming after a reset has occurred - # Since we can't easily simulate a reset in the test environment without causing hangs, - # we'll test the basic checkpoint resume functionality which is the core requirement + # Verify we can resume from the last saved checkpoint resumed_state = checkpoints[-1] # Use the last checkpoint resumed_workflow = ( MagenticBuilder() - .participants(agentA=StubThreadAgent()) + .participants([StubThreadAgent()]) .with_standard_manager(InvokeOnceManager()) .with_checkpointing(storage) .build() @@ -717,7 +698,7 @@ async def test_magentic_checkpoint_resume_rejects_participant_renames(): workflow = ( MagenticBuilder() - .participants(agentA=StubThreadAgent()) + .participants([StubThreadAgent()]) .with_standard_manager(manager) .with_plan_review() .with_checkpointing(storage) @@ -726,24 +707,25 @@ async def test_magentic_checkpoint_resume_rejects_participant_renames(): req_event: RequestInfoEvent | None = None async for event in workflow.run_stream("task"): - if isinstance(event, RequestInfoEvent) and event.request_type is MagenticHumanInterventionRequest: + if isinstance(event, RequestInfoEvent) and event.request_type is MagenticPlanReviewRequest: req_event = event assert req_event is not None + assert isinstance(req_event.data, MagenticPlanReviewRequest) checkpoints = await _collect_checkpoints(storage) target_checkpoint = checkpoints[-1] renamed_workflow = ( MagenticBuilder() - .participants(agentB=StubThreadAgent()) + .participants([StubThreadAgent(name="renamedAgent")]) .with_standard_manager(InvokeOnceManager()) .with_plan_review() .with_checkpointing(storage) .build() ) - with pytest.raises(ValueError, match="Workflow graph has changed"): + with pytest.raises(WorkflowCheckpointException, match="Workflow graph has changed"): async for _ in renamed_workflow.run_stream( checkpoint_id=target_checkpoint.checkpoint_id, # type: ignore[reportUnknownMemberType] ): @@ -761,39 +743,40 @@ async def plan(self, magentic_context: MagenticContext) -> ChatMessage: async def replan(self, magentic_context: 
MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="re-ledger") - async def create_progress_ledger(self, magentic_context: MagenticContext) -> _MagenticProgressLedger: - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(reason="r", answer=False), - is_in_loop=_MagenticProgressLedgerItem(reason="r", answer=True), - is_progress_being_made=_MagenticProgressLedgerItem(reason="r", answer=False), - next_speaker=_MagenticProgressLedgerItem(reason="r", answer="agentA"), - instruction_or_question=_MagenticProgressLedgerItem(reason="r", answer="done"), + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(reason="r", answer=False), + is_in_loop=MagenticProgressLedgerItem(reason="r", answer=True), + is_progress_being_made=MagenticProgressLedgerItem(reason="r", answer=False), + next_speaker=MagenticProgressLedgerItem(reason="r", answer="agentA"), + instruction_or_question=MagenticProgressLedgerItem(reason="r", answer="done"), ) async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="final") -async def test_magentic_stall_and_reset_successfully(): +async def test_magentic_stall_and_reset_reach_limits(): manager = NotProgressingManager(max_round_count=10, max_stall_count=0, max_reset_count=1) - wf = MagenticBuilder().participants(agentA=_DummyExec("agentA")).with_standard_manager(manager).build() + wf = MagenticBuilder().participants([DummyExec("agentA")]).with_standard_manager(manager).build() events: list[WorkflowEvent] = [] async for ev in wf.run_stream("test limits"): events.append(ev) idle_status = next( - (e for e in events if isinstance(e, WorkflowStatusEvent) and e.state == WorkflowRunState.IDLE), None + (e for e in events if isinstance(e, WorkflowStatusEvent) and e.state == WorkflowRunState.IDLE), + None, ) assert idle_status is not None output_event = next((e for e in events if isinstance(e, WorkflowOutputEvent)), None) assert output_event is not None assert isinstance(output_event.data, list) - assert all(isinstance(msg, ChatMessage) for msg in output_event.data) - assert len(output_event.data) > 0 - assert output_event.data[-1].text is not None - assert output_event.data[-1].text == "re-ledger" + assert all(isinstance(msg, ChatMessage) for msg in output_event.data) # type: ignore + assert len(output_event.data) > 0 # type: ignore + assert output_event.data[-1].text is not None # type: ignore + assert output_event.data[-1].text == "Workflow terminated due to reaching maximum reset count." 
# type: ignore async def test_magentic_checkpoint_runtime_only() -> None: @@ -801,8 +784,7 @@ async def test_magentic_checkpoint_runtime_only() -> None: storage = InMemoryCheckpointStorage() manager = FakeManager(max_round_count=10) - manager.satisfied_after_signoff = True - wf = MagenticBuilder().participants(agentA=_DummyExec("agentA")).with_standard_manager(manager).build() + wf = MagenticBuilder().participants([DummyExec("agentA")]).with_standard_manager(manager).build() baseline_output: ChatMessage | None = None async for ev in wf.run_stream("runtime checkpoint test", checkpoint_storage=storage): @@ -824,17 +806,19 @@ async def test_magentic_checkpoint_runtime_overrides_buildtime() -> None: """Test that runtime checkpoint storage overrides build-time configuration.""" import tempfile - with tempfile.TemporaryDirectory() as temp_dir1, tempfile.TemporaryDirectory() as temp_dir2: + with ( + tempfile.TemporaryDirectory() as temp_dir1, + tempfile.TemporaryDirectory() as temp_dir2, + ): from agent_framework._workflows._checkpoint import FileCheckpointStorage buildtime_storage = FileCheckpointStorage(temp_dir1) runtime_storage = FileCheckpointStorage(temp_dir2) manager = FakeManager(max_round_count=10) - manager.satisfied_after_signoff = True wf = ( MagenticBuilder() - .participants(agentA=_DummyExec("agentA")) + .participants([DummyExec("agentA")]) .with_standard_manager(manager) .with_checkpointing(buildtime_storage) .build() @@ -859,127 +843,12 @@ async def test_magentic_checkpoint_runtime_overrides_buildtime() -> None: assert len(buildtime_checkpoints) == 0, "Build-time storage should have no checkpoints when overridden" -def test_magentic_builder_does_not_have_human_input_hook(): - """Test that MagenticBuilder does not expose with_human_input_hook (uses specialized HITL instead). - - Magentic uses specialized human intervention mechanisms: - - with_plan_review() for plan approval - - with_human_input_on_stall() for stall intervention - - Tool approval via FunctionApprovalRequestContent - - These emit MagenticHumanInterventionRequest events with structured decision options. - """ - builder = MagenticBuilder() - - # MagenticBuilder should NOT have the generic human input hook mixin - assert not hasattr(builder, "with_human_input_hook"), ( - "MagenticBuilder should not have with_human_input_hook - " - "use with_plan_review() or with_human_input_on_stall() instead" - ) - - # region Message Deduplication Tests -async def test_magentic_no_duplicate_messages_with_conversation_history(): - """Test that passing list[ChatMessage] does not create duplicate messages in chat_history. - - When a frontend passes conversation history as list[ChatMessage], the last message - (task) should not be duplicated in the orchestrator's chat_history. 
- """ - manager = FakeManager(max_round_count=10) - manager.satisfied_after_signoff = True # Complete immediately after first agent response - - wf = MagenticBuilder().participants(agentA=_DummyExec("agentA")).with_standard_manager(manager).build() - - # Simulate frontend passing conversation history - conversation: list[ChatMessage] = [ - ChatMessage(role=Role.USER, text="previous question"), - ChatMessage(role=Role.ASSISTANT, text="previous answer"), - ChatMessage(role=Role.USER, text="current task"), - ] - - # Get orchestrator to inspect chat_history after run - orchestrator = None - for executor in wf.executors.values(): - if isinstance(executor, MagenticOrchestratorExecutor): - orchestrator = executor - break - - events: list[WorkflowEvent] = [] - async for event in wf.run_stream(conversation): - events.append(event) - if isinstance(event, WorkflowStatusEvent) and event.state in ( - WorkflowRunState.IDLE, - WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, - ): - break - - assert orchestrator is not None - assert orchestrator._context is not None # type: ignore[reportPrivateUsage] - - # Count occurrences of each message text in chat_history - history = orchestrator._context.chat_history # type: ignore[reportPrivateUsage] - user_task_count = sum(1 for msg in history if msg.text == "current task") - prev_question_count = sum(1 for msg in history if msg.text == "previous question") - prev_answer_count = sum(1 for msg in history if msg.text == "previous answer") - - # Each input message should appear exactly once (no duplicates) - assert prev_question_count == 1, f"Expected 1 'previous question', got {prev_question_count}" - assert prev_answer_count == 1, f"Expected 1 'previous answer', got {prev_answer_count}" - assert user_task_count == 1, f"Expected 1 'current task', got {user_task_count}" - - -async def test_magentic_agent_executor_no_duplicate_messages_on_broadcast(): - """Test that MagenticAgentExecutor does not duplicate messages from broadcasts. - - When the orchestrator broadcasts the task ledger to all agents, each agent - should receive it exactly once, not multiple times. - """ - backing_executor = _DummyExec("backing") - agent_exec = MagenticAgentExecutor(backing_executor, "agentA") - - # Simulate orchestrator sending a broadcast message - broadcast_msg = ChatMessage( - role=Role.ASSISTANT, - text="Task ledger content", - author_name="magentic_manager", - ) - - # Simulate the same message being received multiple times (e.g., from checkpoint restore + live) - from agent_framework._workflows._magentic import _MagenticResponseMessage - - response1 = _MagenticResponseMessage(body=broadcast_msg, broadcast=True) - response2 = _MagenticResponseMessage(body=broadcast_msg, broadcast=True) - - # Create a mock context - from unittest.mock import AsyncMock, MagicMock - - mock_context = MagicMock() - mock_context.send_message = AsyncMock() - - # Call the handler twice with the same message - await agent_exec.handle_response_message(response1, mock_context) # type: ignore[arg-type] - await agent_exec.handle_response_message(response2, mock_context) # type: ignore[arg-type] - - # Count how many times the broadcast message appears - history = agent_exec._chat_history # type: ignore[reportPrivateUsage] - broadcast_count = sum(1 for msg in history if msg.text == "Task ledger content") - - # Each broadcast should be recorded (this is expected behavior - broadcasts are additive) - # The test documents current behavior. If dedup is needed, this assertion would change. 
- assert broadcast_count == 2, ( - f"Expected 2 broadcasts (current behavior is additive), got {broadcast_count}. " - "If deduplication is required, update the handler logic." - ) - - async def test_magentic_context_no_duplicate_on_reset(): """Test that MagenticContext.reset() clears chat_history without leaving duplicates.""" - ctx = MagenticContext( - task=ChatMessage(role=Role.USER, text="task"), - participant_descriptions={"Alice": "Researcher"}, - ) + ctx = MagenticContext(task="task", participant_descriptions={"Alice": "Researcher"}) # Add some history ctx.chat_history.append(ChatMessage(role=Role.ASSISTANT, text="response1")) @@ -997,24 +866,6 @@ async def test_magentic_context_no_duplicate_on_reset(): assert len(ctx.chat_history) == 1, "Should have exactly 1 message after adding to reset context" -async def test_magentic_start_message_messages_list_integrity(): - """Test that _MagenticStartMessage preserves message list without internal duplication.""" - conversation: list[ChatMessage] = [ - ChatMessage(role=Role.USER, text="msg1"), - ChatMessage(role=Role.ASSISTANT, text="msg2"), - ChatMessage(role=Role.USER, text="msg3"), - ] - - start_msg = _MagenticStartMessage(conversation) - - # Verify messages list is preserved - assert len(start_msg.messages) == 3, f"Expected 3 messages, got {len(start_msg.messages)}" - - # Verify task is the last message (not a copy) - assert start_msg.task is start_msg.messages[-1], "task should be the same object as messages[-1]" - assert start_msg.task.text == "msg3" - - async def test_magentic_checkpoint_restore_no_duplicate_history(): """Test that checkpoint restore does not create duplicate messages in chat_history.""" manager = FakeManager(max_round_count=10) @@ -1022,7 +873,7 @@ async def test_magentic_checkpoint_restore_no_duplicate_history(): wf = ( MagenticBuilder() - .participants(agentA=_DummyExec("agentA")) + .participants([DummyExec("agentA")]) .with_standard_manager(manager) .with_checkpointing(storage) .build() @@ -1030,7 +881,6 @@ async def test_magentic_checkpoint_restore_no_duplicate_history(): # Run with conversation history to create initial checkpoint conversation: list[ChatMessage] = [ - ChatMessage(role=Role.USER, text="history_msg"), ChatMessage(role=Role.USER, text="task_msg"), ] @@ -1054,18 +904,18 @@ async def test_magentic_checkpoint_restore_no_duplicate_history(): # Check the magentic_context in the checkpoint for _, executor_state in checkpoint_data.metadata.items(): if isinstance(executor_state, dict) and "magentic_context" in executor_state: - ctx_data = executor_state["magentic_context"] - chat_history = ctx_data.get("chat_history", []) + ctx_data: dict[str, Any] = executor_state["magentic_context"] # type: ignore + chat_history = ctx_data.get("chat_history", []) # type: ignore # Count unique messages by text - texts = [ - msg.get("text") or (msg.get("contents", [{}])[0].get("text") if msg.get("contents") else None) - for msg in chat_history + texts = [ # type: ignore + msg.get("text") or (msg.get("contents", [{}])[0].get("text") if msg.get("contents") else None) # type: ignore + for msg in chat_history # type: ignore ] text_counts: dict[str, int] = {} - for text in texts: + for text in texts: # type: ignore if text: - text_counts[text] = text_counts.get(text, 0) + 1 + text_counts[text] = text_counts.get(text, 0) + 1 # type: ignore # Input messages should not be duplicated assert text_counts.get("history_msg", 0) <= 1, ( diff --git a/python/packages/core/tests/workflow/test_orchestration_request_info.py 
b/python/packages/core/tests/workflow/test_orchestration_request_info.py index e5f4d7a11f..24b2239757 100644 --- a/python/packages/core/tests/workflow/test_orchestration_request_info.py +++ b/python/packages/core/tests/workflow/test_orchestration_request_info.py @@ -1,59 +1,51 @@ # Copyright (c) Microsoft. All rights reserved. -"""Unit tests for request info support in high-level builders.""" +"""Unit tests for orchestration request info support.""" +from collections.abc import AsyncIterable from typing import Any -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock + +import pytest from agent_framework import ( - AgentInputRequest, AgentProtocol, - AgentResponseReviewRequest, + AgentResponse, + AgentResponseUpdate, + AgentThread, ChatMessage, - RequestInfoInterceptor, Role, ) -from agent_framework._workflows._executor import Executor, handler -from agent_framework._workflows._orchestration_request_info import resolve_request_info_filter +from agent_framework._workflows._agent_executor import AgentExecutorRequest, AgentExecutorResponse +from agent_framework._workflows._orchestration_request_info import ( + AgentApprovalExecutor, + AgentRequestInfoExecutor, + AgentRequestInfoResponse, + resolve_request_info_filter, +) from agent_framework._workflows._workflow_context import WorkflowContext -class DummyExecutor(Executor): - """Dummy executor with a handler for testing.""" - - @handler - async def handle(self, data: str, ctx: WorkflowContext[Any, Any]) -> None: - pass - - class TestResolveRequestInfoFilter: """Tests for resolve_request_info_filter function.""" - def test_returns_none_for_none_input(self): - """Test that None input returns None (no filtering).""" + def test_returns_empty_set_for_none_input(self): + """Test that None input returns empty set (no filtering).""" result = resolve_request_info_filter(None) - assert result is None + assert result == set() - def test_returns_none_for_empty_list(self): - """Test that empty list returns None.""" + def test_returns_empty_set_for_empty_list(self): + """Test that empty list returns empty set.""" result = resolve_request_info_filter([]) - assert result is None + assert result == set() def test_resolves_string_names(self): """Test resolving string agent names.""" result = resolve_request_info_filter(["agent1", "agent2"]) assert result == {"agent1", "agent2"} - def test_resolves_executor_ids(self): - """Test resolving Executor instances by ID.""" - exec1 = DummyExecutor(id="executor1") - exec2 = DummyExecutor(id="executor2") - - result = resolve_request_info_filter([exec1, exec2]) - assert result == {"executor1", "executor2"} - - def test_resolves_agent_names(self): - """Test resolving AgentProtocol-like objects by name attribute.""" + def test_resolves_agent_display_names(self): + """Test resolving AgentProtocol instances by name attribute.""" agent1 = MagicMock(spec=AgentProtocol) agent1.name = "writer" agent2 = MagicMock(spec=AgentProtocol) @@ -63,106 +55,205 @@ def test_resolves_agent_names(self): assert result == {"writer", "reviewer"} def test_mixed_types(self): - """Test resolving a mix of strings, agents, and executors.""" + """Test resolving a mix of strings and agents.""" agent = MagicMock(spec=AgentProtocol) agent.name = "writer" - executor = DummyExecutor(id="custom_exec") - result = resolve_request_info_filter(["manual_name", agent, executor]) - assert result == {"manual_name", "writer", "custom_exec"} + result = resolve_request_info_filter(["manual_name", agent]) + assert result == {"manual_name", 
"writer"} + + def test_raises_on_unsupported_type(self): + """Test that unsupported types raise TypeError.""" + with pytest.raises(TypeError, match="Unsupported type for request_info filter"): + resolve_request_info_filter([123]) # type: ignore + + +class TestAgentRequestInfoResponse: + """Tests for AgentRequestInfoResponse dataclass.""" + + def test_create_response_with_messages(self): + """Test creating an AgentRequestInfoResponse with messages.""" + messages = [ChatMessage(role=Role.USER, text="Additional info")] + response = AgentRequestInfoResponse(messages=messages) + + assert response.messages == messages + + def test_from_messages_factory(self): + """Test creating response from ChatMessage list.""" + messages = [ + ChatMessage(role=Role.USER, text="Message 1"), + ChatMessage(role=Role.USER, text="Message 2"), + ] + response = AgentRequestInfoResponse.from_messages(messages) + + assert response.messages == messages - def test_skips_agent_without_name(self): - """Test that agents without names are skipped.""" - agent_with_name = MagicMock(spec=AgentProtocol) - agent_with_name.name = "valid" - agent_without_name = MagicMock(spec=AgentProtocol) - agent_without_name.name = None + def test_from_strings_factory(self): + """Test creating response from string list.""" + texts = ["First message", "Second message"] + response = AgentRequestInfoResponse.from_strings(texts) - result = resolve_request_info_filter([agent_with_name, agent_without_name]) - assert result == {"valid"} + assert len(response.messages) == 2 + assert response.messages[0].role == Role.USER + assert response.messages[0].text == "First message" + assert response.messages[1].role == Role.USER + assert response.messages[1].text == "Second message" + def test_approve_factory(self): + """Test creating an approval response (empty messages).""" + response = AgentRequestInfoResponse.approve() -class TestAgentInputRequest: - """Tests for AgentInputRequest dataclass (formerly AgentResponseReviewRequest).""" + assert response.messages == [] - def test_create_request(self): - """Test creating an AgentInputRequest with all fields.""" - conversation = [ChatMessage(role=Role.USER, text="Hello")] - request = AgentInputRequest( - target_agent_id="test_agent", - conversation=conversation, - instruction="Review this", - metadata={"key": "value"}, + +class TestAgentRequestInfoExecutor: + """Tests for AgentRequestInfoExecutor.""" + + @pytest.mark.asyncio + async def test_request_info_handler(self): + """Test that request_info handler calls ctx.request_info.""" + executor = AgentRequestInfoExecutor(id="test_executor") + + agent_response = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Agent response")]) + agent_response = AgentExecutorResponse( + executor_id="test_agent", + agent_response=agent_response, + ) + + ctx = MagicMock(spec=WorkflowContext) + ctx.request_info = AsyncMock() + + await executor.request_info(agent_response, ctx) + + ctx.request_info.assert_called_once_with(agent_response, AgentRequestInfoResponse) + + @pytest.mark.asyncio + async def test_handle_request_info_response_with_messages(self): + """Test response handler when user provides additional messages.""" + executor = AgentRequestInfoExecutor(id="test_executor") + + agent_response = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Original")]) + original_request = AgentExecutorResponse( + executor_id="test_agent", + agent_response=agent_response, ) - assert request.target_agent_id == "test_agent" - assert request.conversation == conversation 
- assert request.instruction == "Review this" - assert request.metadata == {"key": "value"} - - def test_create_request_defaults(self): - """Test creating an AgentInputRequest with default values.""" - request = AgentInputRequest(target_agent_id="test_agent") - - assert request.target_agent_id == "test_agent" - assert request.conversation == [] - assert request.instruction is None - assert request.metadata == {} - - def test_backward_compatibility_alias(self): - """Test that AgentResponseReviewRequest is an alias for AgentInputRequest.""" - assert AgentResponseReviewRequest is AgentInputRequest - - -class TestRequestInfoInterceptor: - """Tests for RequestInfoInterceptor executor.""" - - def test_interceptor_creation_generates_unique_id(self): - """Test creating a RequestInfoInterceptor generates unique IDs.""" - interceptor1 = RequestInfoInterceptor() - interceptor2 = RequestInfoInterceptor() - assert interceptor1.id.startswith("request_info_interceptor-") - assert interceptor2.id.startswith("request_info_interceptor-") - assert interceptor1.id != interceptor2.id - - def test_interceptor_with_custom_id(self): - """Test creating a RequestInfoInterceptor with custom ID.""" - interceptor = RequestInfoInterceptor(executor_id="custom_review") - assert interceptor.id == "custom_review" - - def test_interceptor_with_agent_filter(self): - """Test creating a RequestInfoInterceptor with agent filter.""" - agent_filter = {"agent1", "agent2"} - interceptor = RequestInfoInterceptor( - executor_id="filtered_review", - agent_filter=agent_filter, + response = AgentRequestInfoResponse.from_strings(["Additional input"]) + + ctx = MagicMock(spec=WorkflowContext) + ctx.send_message = AsyncMock() + + await executor.handle_request_info_response(original_request, response, ctx) + + # Should send new request with additional messages + ctx.send_message.assert_called_once() + call_args = ctx.send_message.call_args[0][0] + assert isinstance(call_args, AgentExecutorRequest) + assert call_args.should_respond is True + assert len(call_args.messages) == 1 + assert call_args.messages[0].text == "Additional input" + + @pytest.mark.asyncio + async def test_handle_request_info_response_approval(self): + """Test response handler when user approves (no additional messages).""" + executor = AgentRequestInfoExecutor(id="test_executor") + + agent_response = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Original")]) + original_request = AgentExecutorResponse( + executor_id="test_agent", + agent_response=agent_response, ) - assert interceptor.id == "filtered_review" - assert interceptor._agent_filter == agent_filter - - def test_should_pause_for_agent_no_filter(self): - """Test that interceptor pauses for all agents when no filter is set.""" - interceptor = RequestInfoInterceptor() - assert interceptor._should_pause_for_agent("any_agent") is True - assert interceptor._should_pause_for_agent("another_agent") is True - assert interceptor._should_pause_for_agent(None) is True - - def test_should_pause_for_agent_with_filter(self): - """Test that interceptor only pauses for agents in the filter.""" - agent_filter = {"writer", "reviewer"} - interceptor = RequestInfoInterceptor(agent_filter=agent_filter) - - assert interceptor._should_pause_for_agent("writer") is True - assert interceptor._should_pause_for_agent("reviewer") is True - assert interceptor._should_pause_for_agent("drafter") is False - assert interceptor._should_pause_for_agent(None) is False - - def test_should_pause_for_agent_with_prefixed_id(self): - 
"""Test that filter matches agent names in prefixed executor IDs.""" - agent_filter = {"writer"} - interceptor = RequestInfoInterceptor(agent_filter=agent_filter) - - # Should match the name portion after the colon - assert interceptor._should_pause_for_agent("groupchat_agent:writer") is True - assert interceptor._should_pause_for_agent("request_info:writer") is True - assert interceptor._should_pause_for_agent("groupchat_agent:editor") is False + + response = AgentRequestInfoResponse.approve() + + ctx = MagicMock(spec=WorkflowContext) + ctx.yield_output = AsyncMock() + + await executor.handle_request_info_response(original_request, response, ctx) + + # Should yield original response without modification + ctx.yield_output.assert_called_once_with(original_request) + + +class _TestAgent: + """Simple test agent implementation.""" + + def __init__(self, id: str, name: str | None = None, description: str | None = None): + self._id = id + self._name = name + self._description = description + + @property + def id(self) -> str: + return self._id + + @property + def name(self) -> str | None: + return self._name + + @property + def display_name(self) -> str: + return self._name or self._id + + @property + def description(self) -> str | None: + return self._description + + async def run( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentResponse: + """Dummy run method.""" + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Test response")]) + + def run_stream( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentResponseUpdate]: + """Dummy run_stream method.""" + + async def generator(): + yield AgentResponseUpdate(messages=[ChatMessage(role=Role.ASSISTANT, text="Test response stream")]) + + return generator() + + def get_new_thread(self, **kwargs: Any) -> AgentThread: + """Creates a new conversation thread for the agent.""" + return AgentThread(**kwargs) + + +class TestAgentApprovalExecutor: + """Tests for AgentApprovalExecutor.""" + + def test_initialization(self): + """Test that AgentApprovalExecutor initializes correctly.""" + agent = _TestAgent(id="test_id", name="test_agent", description="Test agent description") + + executor = AgentApprovalExecutor(agent) + + assert executor.id == "test_agent" + assert executor.description == "Test agent description" + + def test_builds_workflow_with_agent_and_request_info_executors(self): + """Test that the internal workflow is created successfully.""" + agent = _TestAgent(id="test_id", name="test_agent", description="Test description") + + executor = AgentApprovalExecutor(agent) + + # Verify the executor has a workflow + assert executor.workflow is not None + assert executor.id == "test_agent" + + def test_propagate_request_enabled(self): + """Test that AgentApprovalExecutor has propagate_request enabled.""" + agent = _TestAgent(id="test_id", name="test_agent", description="Test description") + + executor = AgentApprovalExecutor(agent) + + assert executor._propagate_request is True # type: ignore diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index 73fe20d834..f6a031e5a3 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -7,18 +7,24 @@ from agent_framework import ( 
AgentExecutorResponse, - AgentRunResponse, + AgentResponse, Executor, WorkflowContext, + WorkflowConvergenceException, WorkflowEvent, WorkflowOutputEvent, + WorkflowRunnerException, WorkflowRunState, WorkflowStatusEvent, handler, ) from agent_framework._workflows._edge import SingleEdgeGroup from agent_framework._workflows._runner import Runner -from agent_framework._workflows._runner_context import InProcRunnerContext, Message, RunnerContext +from agent_framework._workflows._runner_context import ( + InProcRunnerContext, + Message, + RunnerContext, +) from agent_framework._workflows._shared_state import SharedState @@ -52,7 +58,10 @@ def test_create_runner(): SingleEdgeGroup(executor_b.id, executor_a.id), ] - executors: dict[str, Executor] = {executor_a.id: executor_a, executor_b.id: executor_b} + executors: dict[str, Executor] = { + executor_a.id: executor_a, + executor_b.id: executor_b, + } runner = Runner(edge_groups, executors, shared_state=SharedState(), ctx=InProcRunnerContext()) @@ -70,7 +79,10 @@ async def test_runner_run_until_convergence(): SingleEdgeGroup(executor_b.id, executor_a.id), ] - executors: dict[str, Executor] = {executor_a.id: executor_a, executor_b.id: executor_b} + executors: dict[str, Executor] = { + executor_a.id: executor_a, + executor_b.id: executor_b, + } shared_state = SharedState() ctx = InProcRunnerContext() @@ -90,6 +102,9 @@ async def test_runner_run_until_convergence(): assert result is not None and result == 10 + # iteration count shouldn't be reset after convergence + assert runner._iteration == 10 # type: ignore + async def test_runner_run_until_convergence_not_completed(): """Test running the runner with a simple workflow.""" @@ -102,7 +117,10 @@ async def test_runner_run_until_convergence_not_completed(): SingleEdgeGroup(executor_b.id, executor_a.id), ] - executors: dict[str, Executor] = {executor_a.id: executor_a, executor_b.id: executor_b} + executors: dict[str, Executor] = { + executor_a.id: executor_a, + executor_b.id: executor_b, + } shared_state = SharedState() ctx = InProcRunnerContext() @@ -114,7 +132,10 @@ async def test_runner_run_until_convergence_not_completed(): shared_state, # shared_state ctx, # runner_context ) - with pytest.raises(RuntimeError, match="Runner did not converge after 5 iterations."): + with pytest.raises( + WorkflowConvergenceException, + match="Runner did not converge after 5 iterations.", + ): async for event in runner.run_until_convergence(): assert not isinstance(event, WorkflowStatusEvent) or event.state != WorkflowRunState.IDLE @@ -130,7 +151,10 @@ async def test_runner_already_running(): SingleEdgeGroup(executor_b.id, executor_a.id), ] - executors: dict[str, Executor] = {executor_a.id: executor_a, executor_b.id: executor_b} + executors: dict[str, Executor] = { + executor_a.id: executor_a, + executor_b.id: executor_b, + } shared_state = SharedState() ctx = InProcRunnerContext() @@ -143,7 +167,7 @@ async def test_runner_already_running(): ctx, # runner_context ) - with pytest.raises(RuntimeError, match="Runner is already running."): + with pytest.raises(WorkflowRunnerException, match="Runner is already running."): async def _run(): async for _ in runner.run_until_convergence(): @@ -158,7 +182,7 @@ async def test_runner_emits_runner_completion_for_agent_response_without_targets await ctx.send_message( Message( - data=AgentExecutorResponse("agent", AgentRunResponse()), + data=AgentExecutorResponse("agent", AgentResponse()), source_id="agent", ) ) diff --git a/python/packages/core/tests/workflow/test_sequential.py 
b/python/packages/core/tests/workflow/test_sequential.py index 8ff0098c38..d104eb8a02 100644 --- a/python/packages/core/tests/workflow/test_sequential.py +++ b/python/packages/core/tests/workflow/test_sequential.py @@ -6,8 +6,9 @@ import pytest from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentExecutorResponse, + AgentResponse, + AgentResponseUpdate, AgentThread, BaseAgent, ChatMessage, @@ -34,8 +35,8 @@ async def run( # type: ignore[override] *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.display_name} reply")]) + ) -> AgentResponse: + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.name} reply")]) async def run_stream( # type: ignore[override] self, @@ -43,16 +44,17 @@ async def run_stream( # type: ignore[override] *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: # Minimal async generator with one assistant update - yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.display_name} reply")]) + yield AgentResponseUpdate(contents=[TextContent(text=f"{self.name} reply")]) class _SummarizerExec(Executor): """Custom executor that summarizes by appending a short assistant message.""" @handler - async def summarize(self, conversation: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: + async def summarize(self, agent_response: AgentExecutorResponse, ctx: WorkflowContext[list[ChatMessage]]) -> None: + conversation = agent_response.full_conversation or [] user_texts = [m.text for m in conversation if m.role == Role.USER] agents = [m.author_name or m.role for m in conversation if m.role == Role.ASSISTANT] summary = ChatMessage(role=Role.ASSISTANT, text=f"Summary of users:{len(user_texts)} agents:{len(agents)}") diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index 059c94803e..5a0f54a24e 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -11,9 +11,9 @@ from agent_framework import ( AgentExecutor, + AgentResponse, + AgentResponseUpdate, AgentRunEvent, - AgentRunResponse, - AgentRunResponseUpdate, AgentRunUpdateEvent, AgentThread, BaseAgent, @@ -25,7 +25,9 @@ Role, TextContent, WorkflowBuilder, + WorkflowCheckpointException, WorkflowContext, + WorkflowConvergenceException, WorkflowEvent, WorkflowOutputEvent, WorkflowRunState, @@ -143,7 +145,7 @@ async def test_workflow_run_stream_not_completed(): .build() ) - with pytest.raises(RuntimeError): + with pytest.raises(WorkflowConvergenceException): async for _ in workflow.run_stream(NumberMessage(data=0)): pass @@ -181,7 +183,7 @@ async def test_workflow_run_not_completed(): .build() ) - with pytest.raises(RuntimeError): + with pytest.raises(WorkflowConvergenceException): await workflow.run(NumberMessage(data=0)) @@ -289,7 +291,9 @@ async def test_workflow_with_checkpointing_enabled(simple_executor: Executor): assert result is not None -async def test_workflow_checkpointing_not_enabled_for_external_restore(simple_executor: Executor): +async def test_workflow_checkpointing_not_enabled_for_external_restore( + simple_executor: Executor, +): """Test that external checkpoint restoration fails when workflow doesn't support checkpointing.""" # Build workflow WITHOUT checkpointing workflow = ( @@ -308,7 +312,9 @@ async def 
test_workflow_checkpointing_not_enabled_for_external_restore(simple_ex assert "either provide checkpoint_storage parameter" in str(e) -async def test_workflow_run_stream_from_checkpoint_no_checkpointing_enabled(simple_executor: Executor): +async def test_workflow_run_stream_from_checkpoint_no_checkpointing_enabled( + simple_executor: Executor, +): # Build workflow WITHOUT checkpointing workflow = ( WorkflowBuilder() @@ -327,7 +333,9 @@ async def test_workflow_run_stream_from_checkpoint_no_checkpointing_enabled(simp assert "either provide checkpoint_storage parameter" in str(e) -async def test_workflow_run_stream_from_checkpoint_invalid_checkpoint(simple_executor: Executor): +async def test_workflow_run_stream_from_checkpoint_invalid_checkpoint( + simple_executor: Executor, +): """Test that attempting to restore from a non-existent checkpoint fails appropriately.""" with tempfile.TemporaryDirectory() as temp_dir: storage = FileCheckpointStorage(temp_dir) @@ -345,12 +353,14 @@ async def test_workflow_run_stream_from_checkpoint_invalid_checkpoint(simple_exe try: async for _ in workflow.run_stream(checkpoint_id="nonexistent_checkpoint_id"): pass - raise AssertionError("Expected RuntimeError to be raised") - except RuntimeError as e: - assert "Failed to restore from checkpoint" in str(e) + raise AssertionError("Expected WorkflowCheckpointException to be raised") + except WorkflowCheckpointException as e: + assert str(e) == "Checkpoint nonexistent_checkpoint_id not found" -async def test_workflow_run_stream_from_checkpoint_with_external_storage(simple_executor: Executor): +async def test_workflow_run_stream_from_checkpoint_with_external_storage( + simple_executor: Executor, +): """Test that external checkpoint storage can be provided for restoration.""" with tempfile.TemporaryDirectory() as temp_dir: storage = FileCheckpointStorage(temp_dir) @@ -416,7 +426,9 @@ async def test_workflow_run_from_checkpoint_non_streaming(simple_executor: Execu assert hasattr(result, "get_outputs") # Should have WorkflowRunResult methods -async def test_workflow_run_stream_from_checkpoint_with_responses(simple_executor: Executor): +async def test_workflow_run_stream_from_checkpoint_with_responses( + simple_executor: Executor, +): """Test that workflow can be resumed from checkpoint with pending RequestInfoEvents.""" with tempfile.TemporaryDirectory() as temp_dir: storage = FileCheckpointStorage(temp_dir) @@ -475,7 +487,9 @@ class StateTrackingExecutor(Executor): @handler async def handle_message( - self, message: StateTrackingMessage, ctx: WorkflowContext[StateTrackingMessage, list[str]] + self, + message: StateTrackingMessage, + ctx: WorkflowContext[StateTrackingMessage, list[str]], ) -> None: """Handle the message and track it in shared state.""" # Get existing messages from shared state @@ -537,7 +551,9 @@ async def test_workflow_multiple_runs_no_state_collision(): assert outputs1[0] != outputs3[0] -async def test_workflow_checkpoint_runtime_only_configuration(simple_executor: Executor): +async def test_workflow_checkpoint_runtime_only_configuration( + simple_executor: Executor, +): """Test that checkpointing can be configured ONLY at runtime, not at build time.""" with tempfile.TemporaryDirectory() as temp_dir: storage = FileCheckpointStorage(temp_dir) @@ -574,12 +590,20 @@ async def test_workflow_checkpoint_runtime_only_configuration(simple_executor: E checkpoint_id=resume_checkpoint.checkpoint_id, checkpoint_storage=storage ) assert result_resumed is not None - assert result_resumed.get_final_state() in 
(WorkflowRunState.IDLE, WorkflowRunState.IDLE_WITH_PENDING_REQUESTS) + assert result_resumed.get_final_state() in ( + WorkflowRunState.IDLE, + WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, + ) -async def test_workflow_checkpoint_runtime_overrides_buildtime(simple_executor: Executor): +async def test_workflow_checkpoint_runtime_overrides_buildtime( + simple_executor: Executor, +): """Test that runtime checkpoint storage overrides build-time configuration.""" - with tempfile.TemporaryDirectory() as temp_dir1, tempfile.TemporaryDirectory() as temp_dir2: + with ( + tempfile.TemporaryDirectory() as temp_dir1, + tempfile.TemporaryDirectory() as temp_dir2, + ): buildtime_storage = FileCheckpointStorage(temp_dir1) runtime_storage = FileCheckpointStorage(temp_dir2) @@ -740,7 +764,10 @@ async def run_workflow(): await asyncio.sleep(0.01) # Try to start a second concurrent execution - this should fail - with pytest.raises(RuntimeError, match="Workflow is already running. Concurrent executions are not allowed."): + with pytest.raises( + RuntimeError, + match="Workflow is already running. Concurrent executions are not allowed.", + ): await workflow.run(NumberMessage(data=0)) # Wait for the first task to complete @@ -773,7 +800,10 @@ async def consume_stream_slowly(): await asyncio.sleep(0.02) # Try to start a second concurrent execution - this should fail - with pytest.raises(RuntimeError, match="Workflow is already running. Concurrent executions are not allowed."): + with pytest.raises( + RuntimeError, + match="Workflow is already running. Concurrent executions are not allowed.", + ): await workflow.run(NumberMessage(data=0)) # Wait for the first task to complete @@ -803,10 +833,16 @@ async def consume_stream(): await asyncio.sleep(0.02) # Let it start # Try different execution methods - all should fail - with pytest.raises(RuntimeError, match="Workflow is already running. Concurrent executions are not allowed."): + with pytest.raises( + RuntimeError, + match="Workflow is already running. Concurrent executions are not allowed.", + ): await workflow.run(NumberMessage(data=0)) - with pytest.raises(RuntimeError, match="Workflow is already running. Concurrent executions are not allowed."): + with pytest.raises( + RuntimeError, + match="Workflow is already running. 
Concurrent executions are not allowed.", + ): async for _ in workflow.run_stream(NumberMessage(data=0)): break @@ -831,9 +867,9 @@ async def run( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: """Non-streaming run - returns complete response.""" - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=self._reply_text)]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=self._reply_text)]) async def run_stream( self, @@ -841,11 +877,11 @@ async def run_stream( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: """Streaming run - yields incremental updates.""" # Simulate streaming by yielding character by character for char in self._reply_text: - yield AgentRunResponseUpdate(contents=[TextContent(text=char)]) + yield AgentResponseUpdate(contents=[TextContent(text=char)]) async def test_agent_streaming_vs_non_streaming() -> None: @@ -923,7 +959,9 @@ async def test_workflow_run_parameter_validation(simple_executor: Executor) -> N pass -async def test_workflow_run_stream_parameter_validation(simple_executor: Executor) -> None: +async def test_workflow_run_stream_parameter_validation( + simple_executor: Executor, +) -> None: """Test run_stream() specific parameter validation scenarios.""" workflow = WorkflowBuilder().add_edge(simple_executor, simple_executor).set_start_executor(simple_executor).build() diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index d2ed8d1394..7e47a82c9c 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -8,8 +8,8 @@ from agent_framework import ( AgentProtocol, - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentRunUpdateEvent, AgentThread, ChatMessage, @@ -19,6 +19,7 @@ FunctionApprovalRequestContent, FunctionApprovalResponseContent, FunctionCallContent, + FunctionResultContent, Role, TextContent, UriContent, @@ -52,7 +53,7 @@ async def handle_message(self, message: list[ChatMessage], ctx: WorkflowContext[ response_message = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=response_text)]) # Emit update event. 
- streaming_update = AgentRunResponseUpdate( + streaming_update = AgentResponseUpdate( contents=[TextContent(text=response_text)], role=Role.ASSISTANT, message_id=str(uuid.uuid4()) ) await ctx.add_event(AgentRunUpdateEvent(executor_id=self.id, data=streaming_update)) @@ -74,7 +75,7 @@ async def handle_request_response( self, original_request: str, response: str, ctx: WorkflowContext[ChatMessage] ) -> None: # Handle the response and emit completion response - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[TextContent(text="Request completed successfully")], role=Role.ASSISTANT, message_id=str(uuid.uuid4()), @@ -100,7 +101,7 @@ async def handle_message(self, messages: list[ChatMessage], ctx: WorkflowContext response_message = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=response_text)]) - streaming_update = AgentRunResponseUpdate( + streaming_update = AgentResponseUpdate( contents=[TextContent(text=response_text)], role=Role.ASSISTANT, message_id=str(uuid.uuid4()) ) await ctx.add_event(AgentRunUpdateEvent(executor_id=self.id, data=streaming_update)) @@ -124,7 +125,7 @@ async def test_end_to_end_basic_workflow(self): result = await agent.run("Hello World") # Verify we got responses from both executors - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert len(result.messages) >= 2, f"Expected at least 2 messages, got {len(result.messages)}" # Find messages from each executor @@ -162,7 +163,7 @@ async def test_end_to_end_basic_workflow_streaming(self): agent = WorkflowAgent(workflow=workflow, name="Streaming Test Agent") # Execute workflow streaming to capture streaming events - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream("Test input"): updates.append(update) @@ -191,13 +192,13 @@ async def test_end_to_end_request_info_handling(self): agent = WorkflowAgent(workflow=workflow, name="Request Test Agent") # Execute workflow streaming to get request info event - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream("Start request"): updates.append(update) # Should have received an approval request for the request info assert len(updates) > 0 - approval_update: AgentRunResponseUpdate | None = None + approval_update: AgentResponseUpdate | None = None for update in updates: if any(isinstance(content, FunctionApprovalRequestContent) for content in update.contents): approval_update = update @@ -248,7 +249,7 @@ async def test_end_to_end_request_info_handling(self): continuation_result = await agent.run(response_message) # Should complete successfully - assert isinstance(continuation_result, AgentRunResponse) + assert isinstance(continuation_result, AgentResponse) # Verify cleanup - pending requests should be cleared after function response handling assert len(agent.pending_requests) == 0 @@ -293,7 +294,7 @@ async def test_workflow_as_agent_yield_output_surfaces_as_agent_response(self) - """Test that ctx.yield_output() in a workflow executor surfaces as agent output when using .as_agent(). This validates the fix for issue #2813: WorkflowOutputEvent should be converted to - AgentRunResponseUpdate when the workflow is wrapped via .as_agent(). + AgentResponseUpdate when the workflow is wrapped via .as_agent(). 
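Condensed from the two `as_agent` tests above, a sketch of the issue-#2813 behavior: output yielded inside a workflow executor surfaces as agent output once the workflow is wrapped via `.as_agent()`. It assumes `executor` is importable from `agent_framework` alongside the names this test file already imports:

```python
from agent_framework import ChatMessage, Role, WorkflowBuilder, WorkflowContext, executor


@executor
async def echoing_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None:
    # ctx.yield_output() inside the workflow becomes agent output after .as_agent().
    await ctx.yield_output(ChatMessage(role=Role.ASSISTANT, text=f"processed: {messages[-1].text}"))


workflow = WorkflowBuilder().set_start_executor(echoing_executor).build()
agent = workflow.as_agent("echo-agent")
# agent.run("hello") then returns an AgentResponse whose single message reads
# "processed: hello"; agent.run_stream("hello") yields AgentResponseUpdate items.
```
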
""" @executor @@ -314,12 +315,12 @@ async def yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) - agent = workflow.as_agent("test-agent") agent_result = await agent.run("hello") - assert isinstance(agent_result, AgentRunResponse) + assert isinstance(agent_result, AgentResponse) assert len(agent_result.messages) == 1 assert agent_result.messages[0].text == "processed: hello" async def test_workflow_as_agent_yield_output_surfaces_in_run_stream(self) -> None: - """Test that ctx.yield_output() surfaces as AgentRunResponseUpdate when streaming.""" + """Test that ctx.yield_output() surfaces as AgentResponseUpdate when streaming.""" @executor async def yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: @@ -329,7 +330,7 @@ async def yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) - workflow = WorkflowBuilder().set_start_executor(yielding_executor).build() agent = workflow.as_agent("test-agent") - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream("hello"): updates.append(update) @@ -353,7 +354,7 @@ async def content_yielding_executor(messages: list[ChatMessage], ctx: WorkflowCo result = await agent.run("test") - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert len(result.messages) == 3 # Verify each content type is preserved @@ -410,7 +411,7 @@ async def raw_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContex workflow = WorkflowBuilder().set_start_executor(raw_yielding_executor).build() agent = workflow.as_agent("raw-test-agent") - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream("test"): updates.append(update) @@ -448,7 +449,7 @@ async def list_yielding_executor(messages: list[ChatMessage], ctx: WorkflowConte agent = workflow.as_agent("list-msg-agent") # Verify streaming returns the update with all 4 contents before coalescing - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream("test"): updates.append(update) @@ -460,7 +461,7 @@ async def list_yielding_executor(messages: list[ChatMessage], ctx: WorkflowConte # Verify run() coalesces text contents (expected behavior) result = await agent.run("test") - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert len(result.messages) == 1 # TextContent items are coalesced into one assert len(result.messages[0].contents) == 1 @@ -587,17 +588,17 @@ def description(self) -> str | None: def get_new_thread(self) -> AgentThread: return AgentThread() - async def run(self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any) -> AgentRunResponse: - return AgentRunResponse( + async def run(self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any) -> AgentResponse: + return AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, text=self._response_text)], text=self._response_text, ) async def run_stream( self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: for word in self._response_text.split(): - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=[TextContent(text=word + " ")], role=Role.ASSISTANT, author_name=self._name, @@ -661,16 +662,16 @@ def description(self) -> str | None: def get_new_thread(self) -> AgentThread: return 
AgentThread() - async def run(self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any) -> AgentRunResponse: - return AgentRunResponse( + async def run(self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any) -> AgentResponse: + return AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, text=self._response_text)], text=self._response_text, ) async def run_stream( self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any - ) -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate( + ) -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate( contents=[TextContent(text=self._response_text)], role=Role.ASSISTANT, author_name=self._name, @@ -717,7 +718,7 @@ async def test_agent_run_update_event_gets_executor_id_as_author_name(self): agent = WorkflowAgent(workflow=workflow, name="Test Agent") # Collect streaming updates - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream("Hello"): updates.append(update) @@ -736,7 +737,7 @@ class AuthorNameExecutor(Executor): @handler async def handle_message(self, message: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: # Emit update with explicit author_name - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[TextContent(text="Response with author")], role=Role.ASSISTANT, author_name="custom_author_name", # Explicitly set @@ -749,7 +750,7 @@ async def handle_message(self, message: list[ChatMessage], ctx: WorkflowContext[ agent = WorkflowAgent(workflow=workflow, name="Test Agent") # Collect streaming updates - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream("Hello"): updates.append(update) @@ -767,7 +768,7 @@ async def test_multiple_executors_have_distinct_author_names(self): agent = WorkflowAgent(workflow=workflow, name="Multi-Executor Agent") # Collect streaming updates - updates: list[AgentRunResponseUpdate] = [] + updates: list[AgentResponseUpdate] = [] async for update in agent.run_stream("Hello"): updates.append(update) @@ -788,7 +789,7 @@ def test_merge_updates_ordering_by_response_and_message_id(self): # Create updates with different response_ids and message_ids in non-chronological order updates = [ # Response B, Message 2 (latest in resp B) - AgentRunResponseUpdate( + AgentResponseUpdate( contents=[TextContent(text="RespB-Msg2")], role=Role.ASSISTANT, response_id="resp-b", @@ -796,7 +797,7 @@ def test_merge_updates_ordering_by_response_and_message_id(self): created_at="2024-01-01T12:02:00Z", ), # Response A, Message 1 (earliest overall) - AgentRunResponseUpdate( + AgentResponseUpdate( contents=[TextContent(text="RespA-Msg1")], role=Role.ASSISTANT, response_id="resp-a", @@ -804,7 +805,7 @@ def test_merge_updates_ordering_by_response_and_message_id(self): created_at="2024-01-01T12:00:00Z", ), # Response B, Message 1 (earlier in resp B) - AgentRunResponseUpdate( + AgentResponseUpdate( contents=[TextContent(text="RespB-Msg1")], role=Role.ASSISTANT, response_id="resp-b", @@ -812,7 +813,7 @@ def test_merge_updates_ordering_by_response_and_message_id(self): created_at="2024-01-01T12:01:00Z", ), # Response A, Message 2 (later in resp A) - AgentRunResponseUpdate( + AgentResponseUpdate( contents=[TextContent(text="RespA-Msg2")], role=Role.ASSISTANT, response_id="resp-a", @@ -820,7 +821,7 @@ def test_merge_updates_ordering_by_response_and_message_id(self): created_at="2024-01-01T12:00:30Z", ), 
# Global dangling update (no response_id) - should go at end - AgentRunResponseUpdate( + AgentResponseUpdate( contents=[TextContent(text="Global-Dangling")], role=Role.ASSISTANT, response_id=None, @@ -891,7 +892,7 @@ def test_merge_updates_metadata_aggregation(self): """Test that merge_updates correctly aggregates usage details, timestamps, and additional properties.""" # Create updates with various metadata including usage details updates = [ - AgentRunResponseUpdate( + AgentResponseUpdate( contents=[ TextContent(text="First"), UsageContent( @@ -904,7 +905,7 @@ def test_merge_updates_metadata_aggregation(self): created_at="2024-01-01T12:00:00Z", additional_properties={"source": "executor1", "priority": "high"}, ), - AgentRunResponseUpdate( + AgentResponseUpdate( contents=[ TextContent(text="Second"), UsageContent( @@ -917,7 +918,7 @@ def test_merge_updates_metadata_aggregation(self): created_at="2024-01-01T12:01:00Z", # Later timestamp additional_properties={"source": "executor2", "category": "analysis"}, ), - AgentRunResponseUpdate( + AgentResponseUpdate( contents=[ TextContent(text="Third"), UsageContent(details=UsageDetails(input_token_count=5, output_token_count=3, total_token_count=8)), @@ -957,3 +958,255 @@ def test_merge_updates_metadata_aggregation(self): # properties only include final merged result from its own updates } assert result.additional_properties == expected_properties + + def test_merge_updates_function_result_ordering_github_2977(self): + """Test that FunctionResultContent updates are placed after their FunctionCallContent. + + This test reproduces GitHub issue #2977: When using a thread with WorkflowAgent, + FunctionResultContent updates without response_id were being added to global_dangling + and placed at the end of messages. This caused OpenAI to reject the conversation because + "An assistant message with 'tool_calls' must be followed by tool messages responding + to each 'tool_call_id'." 
+ + The expected ordering should be: + - User Question + - FunctionCallContent (assistant) + - FunctionResultContent (tool) + - Assistant Answer + + NOT: + - User Question + - FunctionCallContent (assistant) + - Assistant Answer + - FunctionResultContent (tool) <-- This was the bug + """ + call_id = "call_F09je20iUue6DlFRDLLh3dGK" + + updates = [ + # User question + AgentResponseUpdate( + contents=[TextContent(text="What is the weather?")], + role=Role.USER, + response_id="resp-1", + message_id="msg-1", + created_at="2024-01-01T12:00:00Z", + ), + # Assistant with function call + AgentResponseUpdate( + contents=[FunctionCallContent(call_id=call_id, name="get_weather", arguments='{"location": "NYC"}')], + role=Role.ASSISTANT, + response_id="resp-1", + message_id="msg-2", + created_at="2024-01-01T12:00:01Z", + ), + # Function result: no response_id previously caused this to go to global_dangling + # and be placed at the end (the bug); fix now correctly associates via call_id + AgentResponseUpdate( + contents=[FunctionResultContent(call_id=call_id, result="Sunny, 72F")], + role=Role.TOOL, + response_id=None, + message_id="msg-3", + created_at="2024-01-01T12:00:02Z", + ), + # Final assistant answer + AgentResponseUpdate( + contents=[TextContent(text="The weather in NYC is sunny and 72F.")], + role=Role.ASSISTANT, + response_id="resp-1", + message_id="msg-4", + created_at="2024-01-01T12:00:03Z", + ), + ] + + result = WorkflowAgent.merge_updates(updates, "final-response") + + assert len(result.messages) == 4 + + # Extract content types for verification + content_sequence = [] + for msg in result.messages: + for content in msg.contents: + if isinstance(content, TextContent): + content_sequence.append(("text", msg.role)) + elif isinstance(content, FunctionCallContent): + content_sequence.append(("function_call", msg.role)) + elif isinstance(content, FunctionResultContent): + content_sequence.append(("function_result", msg.role)) + + # Verify correct ordering: user -> function_call -> function_result -> assistant_answer + expected_sequence = [ + ("text", Role.USER), + ("function_call", Role.ASSISTANT), + ("function_result", Role.TOOL), + ("text", Role.ASSISTANT), + ] + + assert content_sequence == expected_sequence, ( + f"FunctionResultContent should come immediately after FunctionCallContent. " + f"Got: {content_sequence}, Expected: {expected_sequence}" + ) + + # Additional check: verify FunctionResultContent call_id matches FunctionCallContent + function_call_idx = None + function_result_idx = None + for i, msg in enumerate(result.messages): + for content in msg.contents: + if isinstance(content, FunctionCallContent): + function_call_idx = i + assert content.call_id == call_id + elif isinstance(content, FunctionResultContent): + function_result_idx = i + assert content.call_id == call_id + + assert function_call_idx is not None + assert function_result_idx is not None + assert function_result_idx == function_call_idx + 1, ( + f"FunctionResultContent at index {function_result_idx} should immediately follow " + f"FunctionCallContent at index {function_call_idx}" + ) + + def test_merge_updates_multiple_function_results_ordering_github_2977(self): + """Test ordering with multiple FunctionCallContent/FunctionResultContent pairs. + + Validates that multiple tool calls and results appear before the final assistant + answer, even when results arrive without response_id and in different order than calls. 
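Reduced to its core, the scenario both #2977 tests encode: a tool result that arrives without a `response_id` must be re-associated with its call via `call_id` instead of sinking to the end of the merged messages. A sketch using only the types and the `merge_updates` call shape visible in this diff (field values are illustrative):

```python
from agent_framework import (
    AgentResponseUpdate,
    FunctionCallContent,
    FunctionResultContent,
    Role,
    TextContent,
    WorkflowAgent,
)

updates = [
    AgentResponseUpdate(
        contents=[FunctionCallContent(call_id="call_1", name="get_weather", arguments="{}")],
        role=Role.ASSISTANT,
        response_id="resp-1",
        message_id="m1",
    ),
    # No response_id: before the fix this sank into global_dangling at the end.
    AgentResponseUpdate(
        contents=[FunctionResultContent(call_id="call_1", result="Sunny, 72F")],
        role=Role.TOOL,
        message_id="m2",
    ),
    AgentResponseUpdate(
        contents=[TextContent(text="It is sunny.")],
        role=Role.ASSISTANT,
        response_id="resp-1",
        message_id="m3",
    ),
]

result = WorkflowAgent.merge_updates(updates, "final-response")
# The tool message must land between the call and the final assistant text,
# otherwise OpenAI rejects the next turn of the conversation.
```
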
+ + OpenAI requires that tool results appear after their calls and before the next + assistant text message, but doesn't require strict interleaving (result_1 immediately + after call_1). The key constraint is: calls -> results -> final_answer. + """ + call_id_1 = "call_weather_001" + call_id_2 = "call_time_002" + + updates = [ + # User question + AgentResponseUpdate( + contents=[TextContent(text="What's the weather and time?")], + role=Role.USER, + response_id="resp-1", + message_id="msg-1", + created_at="2024-01-01T12:00:00Z", + ), + # Assistant with first function call + AgentResponseUpdate( + contents=[FunctionCallContent(call_id=call_id_1, name="get_weather", arguments='{"location": "NYC"}')], + role=Role.ASSISTANT, + response_id="resp-1", + message_id="msg-2", + created_at="2024-01-01T12:00:01Z", + ), + # Assistant with second function call + AgentResponseUpdate( + contents=[FunctionCallContent(call_id=call_id_2, name="get_time", arguments='{"timezone": "EST"}')], + role=Role.ASSISTANT, + response_id="resp-1", + message_id="msg-3", + created_at="2024-01-01T12:00:02Z", + ), + # Second function result arrives first (no response_id) + AgentResponseUpdate( + contents=[FunctionResultContent(call_id=call_id_2, result="3:00 PM EST")], + role=Role.TOOL, + response_id=None, + message_id="msg-4", + created_at="2024-01-01T12:00:03Z", + ), + # First function result arrives second (no response_id) + AgentResponseUpdate( + contents=[FunctionResultContent(call_id=call_id_1, result="Sunny, 72F")], + role=Role.TOOL, + response_id=None, + message_id="msg-5", + created_at="2024-01-01T12:00:04Z", + ), + # Final assistant answer + AgentResponseUpdate( + contents=[TextContent(text="It's sunny (72F) and 3 PM in NYC.")], + role=Role.ASSISTANT, + response_id="resp-1", + message_id="msg-6", + created_at="2024-01-01T12:00:05Z", + ), + ] + + result = WorkflowAgent.merge_updates(updates, "final-response") + + assert len(result.messages) == 6 + + # Build a sequence of (content_type, call_id_if_applicable) + content_sequence = [] + for msg in result.messages: + for content in msg.contents: + if isinstance(content, TextContent): + content_sequence.append(("text", None)) + elif isinstance(content, FunctionCallContent): + content_sequence.append(("function_call", content.call_id)) + elif isinstance(content, FunctionResultContent): + content_sequence.append(("function_result", content.call_id)) + + # Verify all function results appear before the final assistant text + # Find indices + call_indices = [i for i, (t, _) in enumerate(content_sequence) if t == "function_call"] + result_indices = [i for i, (t, _) in enumerate(content_sequence) if t == "function_result"] + final_text_idx = len(content_sequence) - 1 # Last item should be final text + + # All calls should have corresponding results + call_ids_in_calls = {content_sequence[i][1] for i in call_indices} + call_ids_in_results = {content_sequence[i][1] for i in result_indices} + assert call_ids_in_calls == call_ids_in_results, "All function calls should have matching results" + + # All results should appear after all calls and before final text + assert all(r > max(call_indices) for r in result_indices), ( + "All function results should appear after all function calls" + ) + assert all(r < final_text_idx for r in result_indices), ( + "All function results should appear before the final assistant answer" + ) + assert content_sequence[final_text_idx] == ("text", None), "Final message should be assistant text" + + def 
test_merge_updates_function_result_no_matching_call(self): + """Test that FunctionResultContent without matching FunctionCallContent still appears. + + If a FunctionResultContent has a call_id that doesn't match any FunctionCallContent + in the messages, it should be appended at the end (fallback behavior). + """ + updates = [ + AgentResponseUpdate( + contents=[TextContent(text="Hello")], + role=Role.USER, + response_id="resp-1", + message_id="msg-1", + created_at="2024-01-01T12:00:00Z", + ), + # Function result with no matching call + AgentResponseUpdate( + contents=[FunctionResultContent(call_id="orphan_call_id", result="orphan result")], + role=Role.TOOL, + response_id=None, + message_id="msg-2", + created_at="2024-01-01T12:00:01Z", + ), + AgentResponseUpdate( + contents=[TextContent(text="Goodbye")], + role=Role.ASSISTANT, + response_id="resp-1", + message_id="msg-3", + created_at="2024-01-01T12:00:02Z", + ), + ] + + result = WorkflowAgent.merge_updates(updates, "final-response") + + assert len(result.messages) == 3 + + # Orphan function result should be at the end since it can't be matched + content_types = [] + for msg in result.messages: + for content in msg.contents: + if isinstance(content, TextContent): + content_types.append("text") + elif isinstance(content, FunctionResultContent): + content_types.append("function_result") + + # Order: text (user), text (assistant), function_result (orphan at end) + assert content_types == ["text", "text", "function_result"] diff --git a/python/packages/core/tests/workflow/test_workflow_builder.py b/python/packages/core/tests/workflow/test_workflow_builder.py index 91a213e3c2..ef572ba82b 100644 --- a/python/packages/core/tests/workflow/test_workflow_builder.py +++ b/python/packages/core/tests/workflow/test_workflow_builder.py @@ -7,8 +7,8 @@ from agent_framework import ( AgentExecutor, - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, BaseAgent, ChatMessage, @@ -29,11 +29,11 @@ async def run(self, messages=None, *, thread: AgentThread | None = None, **kwarg norm.append(m) elif isinstance(m, str): norm.append(ChatMessage(role=Role.USER, text=m)) - return AgentRunResponse(messages=norm) + return AgentResponse(messages=norm) async def run_stream(self, messages=None, *, thread: AgentThread | None = None, **kwargs): # type: ignore[override] # Minimal async generator - yield AgentRunResponseUpdate() + yield AgentResponseUpdate() def test_builder_accepts_agents_directly(): diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index 5b7637057b..75c34f9d95 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -3,15 +3,17 @@ from collections.abc import AsyncIterable from typing import Annotated, Any +import pytest + from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, BaseAgent, ChatMessage, ConcurrentBuilder, GroupChatBuilder, - GroupChatStateSnapshot, + GroupChatState, HandoffBuilder, Role, SequentialBuilder, @@ -26,11 +28,6 @@ _received_kwargs: list[dict[str, Any]] = [] -def _reset_received_kwargs() -> None: - """Reset the kwargs tracker before each test.""" - _received_kwargs.clear() - - @ai_function def tool_with_kwargs( action: Annotated[str, "The action to perform"], @@ -58,9 +55,9 @@ async def run( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> 
AgentRunResponse: + ) -> AgentResponse: self.captured_kwargs.append(dict(kwargs)) - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.display_name} response")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.name} response")]) async def run_stream( self, @@ -68,31 +65,9 @@ async def run_stream( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: self.captured_kwargs.append(dict(kwargs)) - yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.display_name} response")]) - - -class _EchoAgent(BaseAgent): - """Simple agent that echoes back for workflow completion.""" - - async def run( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AgentRunResponse: - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.display_name} reply")]) - - async def run_stream( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.display_name} reply")]) + yield AgentResponseUpdate(contents=[TextContent(text=f"{self.name} response")]) # region Sequential Builder Tests @@ -200,17 +175,21 @@ async def test_groupchat_kwargs_flow_to_agents() -> None: # Simple selector that takes GroupChatStateSnapshot turn_count = 0 - def simple_selector(state: GroupChatStateSnapshot) -> str | None: + def simple_selector(state: GroupChatState) -> str: nonlocal turn_count turn_count += 1 - if turn_count > 2: # Stop after 2 turns - return None + if turn_count > 2: # Loop after two turns for test + turn_count = 0 # state is a Mapping - access via dict syntax - names = list(state["participants"].keys()) + names = list(state.participants.keys()) return names[(turn_count - 1) % len(names)] workflow = ( - GroupChatBuilder().participants(chat1=agent1, chat2=agent2).set_select_speakers_func(simple_selector).build() + GroupChatBuilder() + .participants([agent1, agent2]) + .with_select_speaker_func(simple_selector) + .with_max_rounds(2) # Limit rounds to prevent infinite loop + .build() ) custom_data = {"session_id": "group123"} @@ -359,6 +338,7 @@ async def test_kwargs_preserved_across_workflow_reruns() -> None: # region Handoff Builder Tests +@pytest.mark.xfail(reason="Handoff workflow does not yet propagate kwargs to agents") async def test_handoff_kwargs_flow_to_agents() -> None: """Test that kwargs flow to agents in a handoff workflow.""" agent1 = _KwargsCapturingAgent(name="coordinator") @@ -367,8 +347,9 @@ async def test_handoff_kwargs_flow_to_agents() -> None: workflow = ( HandoffBuilder() .participants([agent1, agent2]) - .set_coordinator(agent1) - .with_interaction_mode("autonomous") + .with_start_agent(agent1) + .with_autonomous_mode() + .with_termination_condition(lambda conv: len(conv) >= 4) .build() ) @@ -395,8 +376,8 @@ async def test_magentic_kwargs_flow_to_agents() -> None: from agent_framework._workflows._magentic import ( MagenticContext, MagenticManagerBase, - _MagenticProgressLedger, - _MagenticProgressLedgerItem, + MagenticProgressLedger, + MagenticProgressLedgerItem, ) # Create a mock manager that completes after one round @@ -405,29 +386,29 @@ def __init__(self) -> None: super().__init__(max_stall_count=3, max_reset_count=None, 
max_round_count=2) self.task_ledger = None - async def plan(self, context: MagenticContext) -> ChatMessage: + async def plan(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="Plan: Test task", author_name="manager") - async def replan(self, context: MagenticContext) -> ChatMessage: + async def replan(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="Replan: Test task", author_name="manager") - async def create_progress_ledger(self, context: MagenticContext) -> _MagenticProgressLedger: + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: # Return completed on first call - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(answer=True, reason="Done"), - is_progress_being_made=_MagenticProgressLedgerItem(answer=True, reason="Progress"), - is_in_loop=_MagenticProgressLedgerItem(answer=False, reason="Not looping"), - instruction_or_question=_MagenticProgressLedgerItem(answer="Complete", reason="Done"), - next_speaker=_MagenticProgressLedgerItem(answer="agent1", reason="First"), + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(answer=True, reason="Done"), + is_progress_being_made=MagenticProgressLedgerItem(answer=True, reason="Progress"), + is_in_loop=MagenticProgressLedgerItem(answer=False, reason="Not looping"), + instruction_or_question=MagenticProgressLedgerItem(answer="Complete", reason="Done"), + next_speaker=MagenticProgressLedgerItem(answer="agent1", reason="First"), ) - async def prepare_final_answer(self, context: MagenticContext) -> ChatMessage: + async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="Final answer", author_name="manager") agent = _KwargsCapturingAgent(name="agent1") manager = _MockManager() - workflow = MagenticBuilder().participants(agent1=agent).with_standard_manager(manager=manager).build() + workflow = MagenticBuilder().participants([agent]).with_standard_manager(manager=manager).build() custom_data = {"session_id": "magentic123"} @@ -446,8 +427,8 @@ async def test_magentic_kwargs_stored_in_shared_state() -> None: from agent_framework._workflows._magentic import ( MagenticContext, MagenticManagerBase, - _MagenticProgressLedger, - _MagenticProgressLedgerItem, + MagenticProgressLedger, + MagenticProgressLedgerItem, ) class _MockManager(MagenticManagerBase): @@ -455,28 +436,28 @@ def __init__(self) -> None: super().__init__(max_stall_count=3, max_reset_count=None, max_round_count=1) self.task_ledger = None - async def plan(self, context: MagenticContext) -> ChatMessage: + async def plan(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="Plan", author_name="manager") - async def replan(self, context: MagenticContext) -> ChatMessage: + async def replan(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="Replan", author_name="manager") - async def create_progress_ledger(self, context: MagenticContext) -> _MagenticProgressLedger: - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(answer=True, reason="Done"), - is_progress_being_made=_MagenticProgressLedgerItem(answer=True, reason="Progress"), - is_in_loop=_MagenticProgressLedgerItem(answer=False, reason="Not looping"), - instruction_or_question=_MagenticProgressLedgerItem(answer="Done", 
reason="Done"), - next_speaker=_MagenticProgressLedgerItem(answer="agent1", reason="First"), + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(answer=True, reason="Done"), + is_progress_being_made=MagenticProgressLedgerItem(answer=True, reason="Progress"), + is_in_loop=MagenticProgressLedgerItem(answer=False, reason="Not looping"), + instruction_or_question=MagenticProgressLedgerItem(answer="Done", reason="Done"), + next_speaker=MagenticProgressLedgerItem(answer="agent1", reason="First"), ) - async def prepare_final_answer(self, context: MagenticContext) -> ChatMessage: + async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="Final", author_name="manager") agent = _KwargsCapturingAgent(name="agent1") manager = _MockManager() - magentic_workflow = MagenticBuilder().participants(agent1=agent).with_standard_manager(manager=manager).build() + magentic_workflow = MagenticBuilder().participants([agent]).with_standard_manager(manager=manager).build() # Use MagenticWorkflow.run_stream() which goes through the kwargs attachment path custom_data = {"magentic_key": "magentic_value"} diff --git a/python/packages/declarative/agent_framework_declarative/__init__.py b/python/packages/declarative/agent_framework_declarative/__init__.py index bfc1bdffdc..7412a9e529 100644 --- a/python/packages/declarative/agent_framework_declarative/__init__.py +++ b/python/packages/declarative/agent_framework_declarative/__init__.py @@ -3,10 +3,34 @@ from importlib import metadata from ._loader import AgentFactory, DeclarativeLoaderError, ProviderLookupError, ProviderTypeMapping +from ._workflows import ( + AgentExternalInputRequest, + AgentExternalInputResponse, + AgentInvocationError, + DeclarativeWorkflowError, + ExternalInputRequest, + ExternalInputResponse, + WorkflowFactory, + WorkflowState, +) try: __version__ = metadata.version(__name__) except metadata.PackageNotFoundError: __version__ = "0.0.0" # Fallback for development mode -__all__ = ["AgentFactory", "DeclarativeLoaderError", "ProviderLookupError", "ProviderTypeMapping", "__version__"] +__all__ = [ + "AgentExternalInputRequest", + "AgentExternalInputResponse", + "AgentFactory", + "AgentInvocationError", + "DeclarativeLoaderError", + "DeclarativeWorkflowError", + "ExternalInputRequest", + "ExternalInputResponse", + "ProviderLookupError", + "ProviderTypeMapping", + "WorkflowFactory", + "WorkflowState", + "__version__", +] diff --git a/python/packages/declarative/agent_framework_declarative/_loader.py b/python/packages/declarative/agent_framework_declarative/_loader.py index 86a6b94225..30b6bab521 100644 --- a/python/packages/declarative/agent_framework_declarative/_loader.py +++ b/python/packages/declarative/agent_framework_declarative/_loader.py @@ -110,6 +110,53 @@ class ProviderLookupError(DeclarativeLoaderError): class AgentFactory: + """Factory for creating ChatAgent instances from declarative YAML definitions. + + AgentFactory parses YAML agent definitions (PromptAgent kind) and creates + configured ChatAgent instances with the appropriate chat client, tools, + and response format. + + Examples: + .. 
code-block:: python + + from agent_framework_declarative import AgentFactory + + # Create agent from YAML file + factory = AgentFactory() + agent = factory.create_agent_from_yaml_path("agent.yaml") + + # Run the agent + async for event in agent.run_stream("Hello!"): + print(event) + + .. code-block:: python + + from agent_framework.azure import AzureOpenAIChatClient + from agent_framework_declarative import AgentFactory + + # With pre-configured chat client + client = AzureOpenAIChatClient() + factory = AgentFactory(chat_client=client) + agent = factory.create_agent_from_yaml_path("agent.yaml") + + .. code-block:: python + + from agent_framework_declarative import AgentFactory + + # From inline YAML string + yaml_content = ''' + kind: Prompt + name: GreetingAgent + instructions: You are a friendly assistant. + model: + id: gpt-4o + provider: AzureOpenAI + ''' + + factory = AgentFactory() + agent = factory.create_agent_from_yaml(yaml_content) + """ + def __init__( self, *, @@ -123,11 +170,11 @@ def __init__( env_file_path: str | None = None, env_file_encoding: str | None = None, ) -> None: - """Create the agent factory, with bindings. + """Create the agent factory. Args: - chat_client: An optional ChatClientProtocol instance to use as a dependency, - this will be passed to the ChatAgent that get's created. + chat_client: An optional ChatClientProtocol instance to use as a dependency. + This will be passed to the ChatAgent that gets created. If you need to create multiple agents with different chat clients, do not pass this and instead provide the chat client in the YAML definition. bindings: An optional dictionary of bindings to use when creating agents. @@ -136,22 +183,22 @@ def __init__( additional_mappings: An optional dictionary to extend the provider type to object mapping. Should have the structure: - ..code-block:: python - - additional_mappings = { - "Provider.ApiType": { - "package": "package.name", - "name": "ClassName", - "model_id_field": "field_name_in_constructor", - }, - ... - } - - Here, "Provider.ApiType" is the lookup key used when both provider and apiType are specified in the - model, "Provider" is also allowed. - Package refers to which model needs to be imported, Name is the class name of the ChatClientProtocol - implementation, and model_id_field is the name of the field in the constructor - that accepts the model.id value. + .. code-block:: python + + additional_mappings = { + "Provider.ApiType": { + "package": "package.name", + "name": "ClassName", + "model_id_field": "field_name_in_constructor", + }, + ... + } + + Here, "Provider.ApiType" is the lookup key used when both provider and apiType are specified in the + model; "Provider" alone is also allowed. + Package refers to which module needs to be imported, Name is the class name of the ChatClientProtocol + implementation, and model_id_field is the name of the field in the constructor + that accepts the model.id value. default_provider: The default provider used when model.provider is not specified, default is "AzureAIClient". safe_mode: Whether to run in safe mode, default is True. @@ -163,6 +210,41 @@ def __init__( via the AgentFactory constructor. env_file_path: The path to the .env file to load environment variables from. env_file_encoding: The encoding of the .env file, defaults to 'utf-8'. + + Examples: + .. code-block:: python + + from agent_framework_declarative import AgentFactory + + # Minimal initialization + factory = AgentFactory() + + ..
code-block:: python + + from agent_framework.azure import AzureOpenAIChatClient + from agent_framework_declarative import AgentFactory + + # With shared chat client + client = AzureOpenAIChatClient() + factory = AgentFactory( + chat_client=client, + env_file_path=".env", + ) + + .. code-block:: python + + from agent_framework_declarative import AgentFactory + + # With custom provider mappings + factory = AgentFactory( + additional_mappings={ + "CustomProvider.Chat": { + "package": "my_package.clients", + "name": "CustomChatClient", + "model_id_field": "model_name", + }, + }, + ) """ self.chat_client = chat_client self.bindings = bindings @@ -177,14 +259,15 @@ def create_agent_from_yaml_path(self, yaml_path: str | Path) -> ChatAgent: """Create a ChatAgent from a YAML file path. This method does the following things: - 1. Loads the YAML file into a AgentSchema object using open and agent_schema_dispatch. + + 1. Loads the YAML file into an AgentSchema object. 2. Validates that the loaded object is a PromptAgent. 3. Creates the appropriate ChatClient based on the model provider and apiType. 4. Parses the tools, options, and response format from the PromptAgent. 5. Creates and returns a ChatAgent instance with the configured properties. Args: - yaml_path: Path to the YAML file representation of a AgentSchema object + yaml_path: Path to the YAML file representation of a PromptAgent. Returns: The ``ChatAgent`` instance created from the YAML file. @@ -195,6 +278,28 @@ def create_agent_from_yaml_path(self, yaml_path: str | Path) -> ChatAgent: ValueError: If a ReferenceConnection cannot be resolved. ModuleNotFoundError: If the required module for the provider type cannot be imported. AttributeError: If the required class for the provider type cannot be found in the module. + + Examples: + .. code-block:: python + + from agent_framework_declarative import AgentFactory + + factory = AgentFactory() + agent = factory.create_agent_from_yaml_path("agents/support_agent.yaml") + + # Execute the agent + async for event in agent.run_stream("Help me with my order"): + print(event) + + .. code-block:: python + + from pathlib import Path + from agent_framework_declarative import AgentFactory + + # Using Path object for cross-platform compatibility + agent_path = Path(__file__).parent / "agents" / "writer.yaml" + factory = AgentFactory() + agent = factory.create_agent_from_yaml_path(agent_path) """ if not isinstance(yaml_path, Path): yaml_path = Path(yaml_path) @@ -208,14 +313,15 @@ def create_agent_from_yaml(self, yaml_str: str) -> ChatAgent: """Create a ChatAgent from a YAML string. This method does the following things: - 1. Loads the YAML string into a AgentSchema object using agent_schema_dispatch. + + 1. Loads the YAML string into an AgentSchema object. 2. Validates that the loaded object is a PromptAgent. 3. Creates the appropriate ChatClient based on the model provider and apiType. 4. Parses the tools, options, and response format from the PromptAgent. 5. Creates and returns a ChatAgent instance with the configured properties. Args: - yaml_str: YAML string representation of a AgentSchema object + yaml_str: YAML string representation of a PromptAgent. Returns: The ``ChatAgent`` instance created from the YAML string. @@ -226,12 +332,104 @@ def create_agent_from_yaml(self, yaml_str: str) -> ChatAgent: ValueError: If a ReferenceConnection cannot be resolved. ModuleNotFoundError: If the required module for the provider type cannot be imported. 
AttributeError: If the required class for the provider type cannot be found in the module. + + Examples: + .. code-block:: python + + from agent_framework_declarative import AgentFactory + + yaml_content = ''' + kind: Prompt + name: TranslationAgent + description: Translates text between languages + instructions: | + You are a translation assistant. + Translate user input to the requested language. + model: + id: gpt-4o + provider: AzureOpenAI + options: + temperature: 0.3 + ''' + + factory = AgentFactory() + agent = factory.create_agent_from_yaml(yaml_content) + + .. code-block:: python + + from agent_framework_declarative import AgentFactory + from pydantic import BaseModel + + # Agent with structured output + yaml_content = ''' + kind: Prompt + name: SentimentAnalyzer + instructions: Analyze the sentiment of the input text. + model: + id: gpt-4o + outputSchema: + type: object + properties: + sentiment: + type: string + enum: [positive, negative, neutral] + confidence: + type: number + ''' + + factory = AgentFactory() + agent = factory.create_agent_from_yaml(yaml_content) + """ + return self.create_agent_from_dict(yaml.safe_load(yaml_str)) + + def create_agent_from_dict(self, agent_def: dict[str, Any]) -> ChatAgent: + """Create a ChatAgent from a dictionary definition. + + This method does the following things: + + 1. Converts the dictionary into an AgentSchema object. + 2. Validates that the loaded object is a PromptAgent. + 3. Creates the appropriate ChatClient based on the model provider and apiType. + 4. Parses the tools, options, and response format from the PromptAgent. + 5. Creates and returns a ChatAgent instance with the configured properties. + + Args: + agent_def: Dictionary representation of a PromptAgent. + + Returns: + The `ChatAgent` instance created from the dictionary. + + Raises: + DeclarativeLoaderError: If the dictionary does not represent a PromptAgent. + ProviderLookupError: If the provider type is unknown or unsupported. + ValueError: If a ReferenceConnection cannot be resolved. + ModuleNotFoundError: If the required module for the provider type cannot be imported. + AttributeError: If the required class for the provider type cannot be found in the module. + + Examples: + .. 
code-block:: python + + from agent_framework_declarative import AgentFactory + + agent_def = { + "kind": "Prompt", + "name": "TranslationAgent", + "description": "Translates text between languages", + "instructions": "You are a translation assistant.", + "model": { + "id": "gpt-4o", + "provider": "AzureOpenAI", + }, + } + + factory = AgentFactory() + agent = factory.create_agent_from_dict(agent_def) """ # Set safe_mode context before parsing YAML to control PowerFx environment variable access _safe_mode_context.set(self.safe_mode) - prompt_agent = agent_schema_dispatch(yaml.safe_load(yaml_str)) + prompt_agent = agent_schema_dispatch(agent_def) if not isinstance(prompt_agent, PromptAgent): - raise DeclarativeLoaderError("Only yaml definitions for a PromptAgent are supported for agent creation.") + raise DeclarativeLoaderError("Only definitions for a PromptAgent are supported for agent creation.") # Step 1: Create the ChatClient client = self._get_client(prompt_agent) diff --git a/python/packages/declarative/agent_framework_declarative/_models.py b/python/packages/declarative/agent_framework_declarative/_models.py index b3a235a732..0132590a1c 100644 --- a/python/packages/declarative/agent_framework_declarative/_models.py +++ b/python/packages/declarative/agent_framework_declarative/_models.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. import os -import sys from collections.abc import MutableMapping from contextvars import ContextVar from typing import Any, Literal, TypeVar, Union @@ -11,14 +10,13 @@ try: from powerfx import Engine - engine = Engine() -except ImportError: + engine: Engine | None = Engine() +except (ImportError, RuntimeError): + # ImportError: powerfx package not installed + # RuntimeError: .NET runtime not available or misconfigured engine = None -if sys.version_info >= (3, 11): - from typing import overload # pragma: no cover -else: - from typing_extensions import overload # pragma: no cover +from typing import overload logger = get_logger("agent_framework.declarative") @@ -61,9 +59,9 @@ def _try_powerfx_eval(value: str | None, log_value: bool = True) -> str | None: return engine.eval(value[1:], symbols={"Env": dict(os.environ)}) except Exception as exc: if log_value: - logger.debug("PowerFx evaluation failed for value '%s': %s", value, exc) + logger.debug(f"PowerFx evaluation failed for value '{value}': {exc}") else: - logger.debug("PowerFx evaluation failed for value (first five characters shown) '%s': %s", value[:5], exc) + logger.debug(f"PowerFx evaluation failed for value (first five characters shown) '{value[:5]}': {exc}") return value @@ -108,7 +106,7 @@ def from_dict( # Only dispatch if we're being called on the base Property class if cls is not Property: # We're being called on a subclass, use the normal from_dict - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[misc] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] # Filter out 'type' (if it exists) field which is not a Property parameter value.pop("type", None) @@ -118,7 +116,7 @@ def from_dict( if kind == "object": return ObjectProperty.from_dict(value, dependencies=dependencies) # Default to Property for kind="property" or empty - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[misc] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] 
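The `_models.py` hunk above hardens the optional PowerFx import: a missing package raises ImportError, while an installed package without a usable .NET runtime raises RuntimeError at `Engine()` construction. A condensed sketch of the guard-plus-fallback pattern; the leading-character strip mirrors `_try_powerfx_eval`, and treating that character as a formula marker is an assumption:

```python
import os

try:
    from powerfx import Engine  # optional dependency; needs a working .NET runtime

    engine: Engine | None = Engine()
except (ImportError, RuntimeError):
    # ImportError: powerfx not installed; RuntimeError: .NET runtime unavailable.
    engine = None


def try_powerfx_eval(value: str) -> str:
    """Evaluate a PowerFx expression, falling back to the raw value on any failure."""
    if engine is None:
        return value
    try:
        # Strip the assumed formula marker and expose environment variables as Env.
        return engine.eval(value[1:], symbols={"Env": dict(os.environ)})
    except Exception:
        return value
```
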
class ArrayProperty(Property): @@ -161,7 +159,7 @@ def __init__( default: Any | None = None, example: Any | None = None, enum: list[Any] | None = None, - properties: list[Property] | dict[str, Property] | None = None, + properties: list[Property] | dict[str, dict[str, Any]] | None = None, ) -> None: super().__init__( name=name, @@ -193,7 +191,7 @@ def __init__( self, examples: list[dict[str, Any]] | None = None, strict: bool = False, - properties: list[Property] | dict[str, Property] | None = None, + properties: list[Property] | dict[str, dict[str, Any]] | None = None, ) -> None: self.examples = examples or [] self.strict = strict @@ -218,7 +216,7 @@ def from_dict( # Filter out 'kind', 'type', 'name', and 'description' fields that may appear in YAML # but aren't PropertySchema params kwargs = {k: v for k, v in value.items() if k not in ("type", "kind", "name", "description")} - return SerializationMixin.from_dict.__func__(cls, kwargs, dependencies=dependencies) # type: ignore[misc] + return SerializationMixin.from_dict.__func__(cls, kwargs, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] def to_json_schema(self) -> dict[str, Any]: """Get a schema out of this PropertySchema to create pydantic models.""" @@ -260,26 +258,26 @@ def from_dict( # Only dispatch if we're being called on the base Connection class if cls is not Connection: # We're being called on a subclass, use the normal from_dict - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[misc] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] kind = value.get("kind", "").lower() if kind == "reference": - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] ReferenceConnection, value, dependencies=dependencies ) if kind == "remote": - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] RemoteConnection, value, dependencies=dependencies ) if kind in ("key", "apikey"): - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] ApiKeyConnection, value, dependencies=dependencies ) if kind == "anonymous": - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] AnonymousConnection, value, dependencies=dependencies ) - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[misc] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] class ReferenceConnection(Connection): @@ -498,13 +496,13 @@ def from_dict( # Only dispatch if we're being called on the base AgentDefinition class if cls is not AgentDefinition: # We're being called on a subclass, use the normal from_dict - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[misc] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] kind = value.get("kind", "") if kind == "Prompt" or kind == "Agent": return PromptAgent.from_dict(value, dependencies=dependencies) # Default to AgentDefinition - return 
SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[misc] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] TTool = TypeVar("TTool", bound="Tool") @@ -544,39 +542,39 @@ def from_dict( # Only dispatch if we're being called on the base Tool class if cls is not Tool: # We're being called on a subclass, use the normal from_dict - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[misc] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] kind = value.get("kind", "") if kind == "function": - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] FunctionTool, value, dependencies=dependencies ) if kind == "custom": - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] CustomTool, value, dependencies=dependencies ) if kind == "web_search": - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] WebSearchTool, value, dependencies=dependencies ) if kind == "file_search": - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] FileSearchTool, value, dependencies=dependencies ) if kind == "mcp": - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] McpTool, value, dependencies=dependencies ) if kind == "openapi": - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] OpenApiTool, value, dependencies=dependencies ) if kind == "code_interpreter": - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] CodeInterpreterTool, value, dependencies=dependencies ) # Default to base Tool class - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[misc] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] class FunctionTool(Tool): @@ -874,18 +872,18 @@ def from_dict( # Only dispatch if we're being called on the base Resource class if cls is not Resource: # We're being called on a subclass, use the normal from_dict - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[misc] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] kind = value.get("kind", "") if kind == "model": - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] ModelResource, value, dependencies=dependencies ) if kind == "tool": - return SerializationMixin.from_dict.__func__( # type: ignore[misc] + return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] ToolResource, value, dependencies=dependencies ) 
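+        # Any other (or missing) kind falls through to the base Resource below.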
- return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[misc] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] class ModelResource(Resource): diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py b/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py new file mode 100644 index 0000000000..9fb693b18b --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Declarative workflow support for agent-framework. + +This module provides the ability to create executable Workflow objects from YAML definitions, +enabling multi-agent orchestration patterns like Foreach, conditionals, and agent invocations. + +Graph-based execution enables: +- Checkpointing at action boundaries +- Workflow visualization +- Pause/resume capabilities +- Full integration with the workflow runtime +""" + +from ._declarative_base import ( + DECLARATIVE_STATE_KEY, + ActionComplete, + ActionTrigger, + ConversationData, + DeclarativeActionExecutor, + DeclarativeMessage, + DeclarativeStateData, + DeclarativeWorkflowState, + LoopControl, + LoopIterationResult, +) +from ._declarative_builder import ALL_ACTION_EXECUTORS, DeclarativeWorkflowBuilder +from ._executors_agents import ( + AGENT_ACTION_EXECUTORS, + AGENT_REGISTRY_KEY, + TOOL_REGISTRY_KEY, + AgentExternalInputRequest, + AgentExternalInputResponse, + AgentInvocationError, + AgentResult, + ExternalLoopState, + InvokeAzureAgentExecutor, + InvokeToolExecutor, +) +from ._executors_basic import ( + BASIC_ACTION_EXECUTORS, + AppendValueExecutor, + ClearAllVariablesExecutor, + EmitEventExecutor, + ResetVariableExecutor, + SendActivityExecutor, + SetMultipleVariablesExecutor, + SetTextVariableExecutor, + SetValueExecutor, + SetVariableExecutor, +) +from ._executors_control_flow import ( + CONTROL_FLOW_EXECUTORS, + BreakLoopExecutor, + ContinueLoopExecutor, + EndConversationExecutor, + EndWorkflowExecutor, + ForeachInitExecutor, + ForeachNextExecutor, + JoinExecutor, +) +from ._executors_external_input import ( + EXTERNAL_INPUT_EXECUTORS, + ConfirmationExecutor, + ExternalInputRequest, + ExternalInputResponse, + QuestionExecutor, + RequestExternalInputExecutor, + WaitForInputExecutor, +) +from ._factory import DeclarativeWorkflowError, WorkflowFactory +from ._handlers import ActionHandler, action_handler, get_action_handler +from ._human_input import ( + ExternalLoopEvent, + QuestionRequest, + process_external_loop, + validate_input_response, +) +from ._state import WorkflowState + +__all__ = [ + "AGENT_ACTION_EXECUTORS", + "AGENT_REGISTRY_KEY", + "ALL_ACTION_EXECUTORS", + "BASIC_ACTION_EXECUTORS", + "CONTROL_FLOW_EXECUTORS", + "DECLARATIVE_STATE_KEY", + "EXTERNAL_INPUT_EXECUTORS", + "TOOL_REGISTRY_KEY", + "ActionComplete", + "ActionHandler", + "ActionTrigger", + "AgentExternalInputRequest", + "AgentExternalInputResponse", + "AgentInvocationError", + "AgentResult", + "AppendValueExecutor", + "BreakLoopExecutor", + "ClearAllVariablesExecutor", + "ConfirmationExecutor", + "ContinueLoopExecutor", + "ConversationData", + "DeclarativeActionExecutor", + "DeclarativeMessage", + "DeclarativeStateData", + "DeclarativeWorkflowBuilder", + "DeclarativeWorkflowError", + "DeclarativeWorkflowState", + "EmitEventExecutor", + "EndConversationExecutor", + "EndWorkflowExecutor", + 
"ExternalInputRequest", + "ExternalInputResponse", + "ExternalLoopEvent", + "ExternalLoopState", + "ForeachInitExecutor", + "ForeachNextExecutor", + "InvokeAzureAgentExecutor", + "InvokeToolExecutor", + "JoinExecutor", + "LoopControl", + "LoopIterationResult", + "QuestionExecutor", + "QuestionRequest", + "RequestExternalInputExecutor", + "ResetVariableExecutor", + "SendActivityExecutor", + "SetMultipleVariablesExecutor", + "SetTextVariableExecutor", + "SetValueExecutor", + "SetVariableExecutor", + "WaitForInputExecutor", + "WorkflowFactory", + "WorkflowState", + "action_handler", + "get_action_handler", + "process_external_loop", + "validate_input_response", +] diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_actions_agents.py b/python/packages/declarative/agent_framework_declarative/_workflows/_actions_agents.py new file mode 100644 index 0000000000..9d610d057d --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_actions_agents.py @@ -0,0 +1,625 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Agent invocation action handlers for declarative workflows. + +This module implements handlers for: +- InvokeAzureAgent: Invoke a hosted Azure AI agent +- InvokePromptAgent: Invoke a local prompt-based agent +""" + +import json +from collections.abc import AsyncGenerator +from typing import Any, cast + +from agent_framework import get_logger +from agent_framework._types import AgentResponse, ChatMessage + +from ._handlers import ( + ActionContext, + AgentResponseEvent, + AgentStreamingChunkEvent, + WorkflowEvent, + action_handler, +) +from ._human_input import ExternalLoopEvent, QuestionRequest + +logger = get_logger("agent_framework.declarative.workflows.actions") + + +def _extract_json_from_response(text: str) -> Any: + r"""Extract and parse JSON from an agent response. + + Agents often return JSON wrapped in markdown code blocks or with + explanatory text. This function attempts to extract and parse the + JSON content from various formats: + + 1. Pure JSON: {"key": "value"} + 2. Markdown code block: ```json\n{"key": "value"}\n``` + 3. Markdown code block (no language): ```\n{"key": "value"}\n``` + 4. JSON with leading/trailing text: Here's the result: {"key": "value"} + 5. Multiple JSON objects: Returns the LAST valid JSON object + + When multiple JSON objects are present (e.g., streaming agent responses + that emit partial then final results), this returns the last complete + JSON object, which is typically the final/complete result. + + Args: + text: The raw text response from an agent + + Returns: + Parsed JSON as a Python dict/list, or None if parsing fails + + Raises: + json.JSONDecodeError: If no valid JSON can be extracted + """ + import re + + if not text: + return None + + text = text.strip() + + if not text: + return None + + # Try parsing as pure JSON first + try: + return json.loads(text) + except json.JSONDecodeError: + pass + + # Try extracting from markdown code blocks: ```json ... ``` or ``` ... ``` + # Use the last code block if there are multiple + code_block_patterns = [ + r"```json\s*\n?(.*?)\n?```", # ```json ... ``` + r"```\s*\n?(.*?)\n?```", # ``` ... 
``` + ] + for pattern in code_block_patterns: + matches = list(re.finditer(pattern, text, re.DOTALL)) + if matches: + # Try the last match first (most likely to be the final result) + for match in reversed(matches): + try: + return json.loads(match.group(1).strip()) + except json.JSONDecodeError: + continue + + # Find ALL JSON objects {...} or arrays [...] in the text and return the last valid one + # This handles cases where agents stream multiple JSON objects (partial, then final) + all_json_objects: list[Any] = [] + + pos = 0 + while pos < len(text): + # Find next { or [ + json_start = -1 + bracket_char = None + for i in range(pos, len(text)): + if text[i] == "{": + json_start = i + bracket_char = "{" + break + if text[i] == "[": + json_start = i + bracket_char = "[" + break + + if json_start < 0: + break # No more JSON objects + + # Find matching closing bracket + open_bracket = bracket_char + close_bracket = "}" if open_bracket == "{" else "]" + depth = 0 + in_string = False + escape_next = False + found_end = False + + for i in range(json_start, len(text)): + char = text[i] + + if escape_next: + escape_next = False + continue + + if char == "\\": + escape_next = True + continue + + if char == '"' and not escape_next: + in_string = not in_string + continue + + if in_string: + continue + + if char == open_bracket: + depth += 1 + elif char == close_bracket: + depth -= 1 + if depth == 0: + # Found the end + potential_json = text[json_start : i + 1] + try: + parsed = json.loads(potential_json) + all_json_objects.append(parsed) + except json.JSONDecodeError: + pass + pos = i + 1 + found_end = True + break + + if not found_end: + # Malformed JSON, move past the start character + pos = json_start + 1 + + # Return the last valid JSON object (most likely to be the final/complete result) + if all_json_objects: + return all_json_objects[-1] + + # Unable to extract JSON + raise json.JSONDecodeError("No valid JSON found in response", text, 0) + + +def _build_messages_from_state(ctx: ActionContext) -> list[ChatMessage]: + """Build the message list to send to an agent. + + This collects messages from: + 1. Conversation history + 2. Current input (if first agent call) + 3. Additional context from instructions + + Args: + ctx: The action context + + Returns: + List of ChatMessage objects to send to the agent + """ + messages: list[ChatMessage] = [] + + # Get conversation history + history = ctx.state.get("conversation.messages", []) + if history: + messages.extend(history) + + return messages + + +@action_handler("InvokeAzureAgent") +async def handle_invoke_azure_agent(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: + """Invoke a hosted Azure AI agent. 
+ + Supports both Python-style and .NET-style YAML schemas: + + Python-style schema: + kind: InvokeAzureAgent + agent: agentName + input: =expression or literal input + outputPath: Local.response + + .NET-style schema: + kind: InvokeAzureAgent + agent: + name: AgentName + conversationId: =System.ConversationId + input: + arguments: + param1: value1 + messages: =expression + output: + messages: Local.Response + responseObject: Local.StructuredResponse + """ + # Get agent name - support both formats + agent_config: dict[str, Any] | str | None = ctx.action.get("agent") + agent_name: str | None = None + if isinstance(agent_config, dict): + agent_name = str(agent_config.get("name")) if agent_config.get("name") else None + # Support dynamic agent name from expression + if agent_name and isinstance(agent_name, str) and agent_name.startswith("="): + evaluated = ctx.state.eval_if_expression(agent_name) + agent_name = str(evaluated) if evaluated is not None else None + elif isinstance(agent_config, str): + agent_name = agent_config + + if not agent_name: + logger.warning("InvokeAzureAgent action missing 'agent' or 'agent.name' property") + return + + # Get input configuration + input_config: dict[str, Any] | Any = ctx.action.get("input", {}) + input_arguments: dict[str, Any] = {} + input_messages: Any = None + external_loop_when: str | None = None + if isinstance(input_config, dict): + input_config_typed = cast(dict[str, Any], input_config) + input_arguments = cast(dict[str, Any], input_config_typed.get("arguments") or {}) + input_messages = input_config_typed.get("messages") + # Extract external loop configuration + external_loop = input_config_typed.get("externalLoop") + if isinstance(external_loop, dict): + external_loop_typed = cast(dict[str, Any], external_loop) + external_loop_when = str(external_loop_typed.get("when")) if external_loop_typed.get("when") else None + else: + input_messages = input_config # Treat as message directly + + # Get output configuration (.NET style) + output_config: dict[str, Any] | Any = ctx.action.get("output", {}) + output_messages_var: str | None = None + output_response_obj_var: str | None = None + if isinstance(output_config, dict): + output_config_typed = cast(dict[str, Any], output_config) + output_messages_var = str(output_config_typed.get("messages")) if output_config_typed.get("messages") else None + output_response_obj_var = ( + str(output_config_typed.get("responseObject")) if output_config_typed.get("responseObject") else None + ) + # auto_send is defined but not used currently + _auto_send: bool = bool(output_config_typed.get("autoSend", True)) + + # Legacy Python style output path + output_path = ctx.action.get("outputPath") + + # Other properties + conversation_id = ctx.action.get("conversationId") + instructions = ctx.action.get("instructions") + tools_config: list[dict[str, Any]] = ctx.action.get("tools", []) + + # Get the agent from registry + agent = ctx.agents.get(agent_name) + if agent is None: + logger.error(f"InvokeAzureAgent: agent '{agent_name}' not found in registry") + return + + # Evaluate conversation ID + if conversation_id: + evaluated_conv_id = ctx.state.eval_if_expression(conversation_id) + ctx.state.set("System.ConversationId", evaluated_conv_id) + + # Evaluate instructions (unused currently but may be used for prompting) + _ = ctx.state.eval_if_expression(instructions) if instructions else None + + # Build messages + messages = _build_messages_from_state(ctx) + + # Handle input messages from .NET style + if input_messages: + 
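+        # input.messages may be a bare string, a list of strings / ChatMessage
+        # objects / {role, content} dicts, or an expression that evaluates to
+        # one of those shapes; each item is normalized to a ChatMessage below.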
evaluated_input = ctx.state.eval_if_expression(input_messages) + if evaluated_input: + if isinstance(evaluated_input, str): + messages.append(ChatMessage(role="user", text=evaluated_input)) + elif isinstance(evaluated_input, list): + for msg_item in evaluated_input: # type: ignore + if isinstance(msg_item, str): + messages.append(ChatMessage(role="user", text=msg_item)) + elif isinstance(msg_item, ChatMessage): + messages.append(msg_item) + elif isinstance(msg_item, dict) and "content" in msg_item: + item_dict = cast(dict[str, Any], msg_item) + role: str = str(item_dict.get("role", "user")) + content: str = str(item_dict.get("content", "")) + if role == "user": + messages.append(ChatMessage(role="user", text=content)) + elif role == "assistant": + messages.append(ChatMessage(role="assistant", text=content)) + elif role == "system": + messages.append(ChatMessage(role="system", text=content)) + + # Evaluate and include input arguments + evaluated_args: dict[str, Any] = {} + for arg_key, arg_value in input_arguments.items(): + evaluated_args[arg_key] = ctx.state.eval_if_expression(arg_value) + + # Prepare tool bindings + tool_bindings: dict[str, dict[str, Any]] = {} + for tool_config in tools_config: + tool_name: str | None = str(tool_config.get("name")) if tool_config.get("name") else None + bindings: list[dict[str, Any]] = list(tool_config.get("bindings", [])) # type: ignore[arg-type] + if tool_name and bindings: + tool_bindings[tool_name] = { + str(b.get("name")): ctx.state.eval_if_expression(b.get("input")) for b in bindings if b.get("name") + } + + logger.debug(f"InvokeAzureAgent: calling '{agent_name}' with {len(messages)} messages") + + # External loop iteration counter + iteration = 0 + max_iterations = 100 # Safety limit + + # Start external loop if configured + while True: + # Invoke the agent + try: + # Check if agent supports streaming + if hasattr(agent, "run_stream"): + updates: list[Any] = [] + tool_calls: list[Any] = [] + + async for chunk in agent.run_stream(messages): + updates.append(chunk) + + # Yield streaming events for text chunks + if hasattr(chunk, "text") and chunk.text: + yield AgentStreamingChunkEvent( + agent_name=str(agent_name), + chunk=chunk.text, + ) + + # Collect tool calls + if hasattr(chunk, "tool_calls"): + tool_calls.extend(chunk.tool_calls) + + # Build consolidated response from updates + response = AgentResponse.from_agent_run_response_updates(updates) + text = response.text + response_messages = response.messages + + # Update state with result + ctx.state.set_agent_result( + text=text, + messages=response_messages, + tool_calls=tool_calls if tool_calls else None, + ) + + # Add to conversation history + if text: + ctx.state.add_conversation_message(ChatMessage(role="assistant", text=text)) + + # Store in output variables (.NET style) + if output_messages_var: + output_path_mapped = _normalize_variable_path(output_messages_var) + ctx.state.set(output_path_mapped, response_messages if response_messages else text) + + if output_response_obj_var: + output_path_mapped = _normalize_variable_path(output_response_obj_var) + # Try to extract and parse JSON from the response + try: + parsed = _extract_json_from_response(text) if text else None + logger.debug( + f"InvokeAzureAgent (streaming): parsed responseObject for " + f"'{output_path_mapped}': type={type(parsed).__name__}, " + f"value_preview={str(parsed)[:100] if parsed else None}" + ) + ctx.state.set(output_path_mapped, parsed) + except (json.JSONDecodeError, TypeError) as e: + logger.warning( + 
f"InvokeAzureAgent (streaming): failed to parse JSON for " + f"'{output_path_mapped}': {e}, text_preview={text[:100] if text else None}" + ) + ctx.state.set(output_path_mapped, text) + + # Store in output path (Python style) + if output_path: + ctx.state.set(output_path, text) + + yield AgentResponseEvent( + agent_name=str(agent_name), + text=text, + messages=response_messages, + tool_calls=tool_calls if tool_calls else None, + ) + + elif hasattr(agent, "run"): + # Non-streaming invocation + response = await agent.run(messages) + + text = response.text + response_messages = response.messages + response_tool_calls: list[Any] | None = getattr(response, "tool_calls", None) + + # Update state with result + ctx.state.set_agent_result( + text=text, + messages=response_messages, + tool_calls=response_tool_calls, + ) + + # Add to conversation history + if text: + ctx.state.add_conversation_message(ChatMessage(role="assistant", text=text)) + + # Store in output variables (.NET style) + if output_messages_var: + output_path_mapped = _normalize_variable_path(output_messages_var) + ctx.state.set(output_path_mapped, response_messages if response_messages else text) + + if output_response_obj_var: + output_path_mapped = _normalize_variable_path(output_response_obj_var) + try: + parsed = _extract_json_from_response(text) if text else None + logger.debug( + f"InvokeAzureAgent (non-streaming): parsed responseObject for " + f"'{output_path_mapped}': type={type(parsed).__name__}, " + f"value_preview={str(parsed)[:100] if parsed else None}" + ) + ctx.state.set(output_path_mapped, parsed) + except (json.JSONDecodeError, TypeError) as e: + logger.warning( + f"InvokeAzureAgent (non-streaming): failed to parse JSON for " + f"'{output_path_mapped}': {e}, text_preview={text[:100] if text else None}" + ) + ctx.state.set(output_path_mapped, text) + + # Store in output path (Python style) + if output_path: + ctx.state.set(output_path, text) + + yield AgentResponseEvent( + agent_name=str(agent_name), + text=text, + messages=response_messages, + tool_calls=response_tool_calls, + ) + else: + logger.error(f"InvokeAzureAgent: agent '{agent_name}' has no run or run_stream method") + break + + except Exception as e: + logger.error(f"InvokeAzureAgent: error invoking agent '{agent_name}': {e}") + raise + + # Check external loop condition + if external_loop_when: + # Evaluate the loop condition + should_continue = ctx.state.eval(external_loop_when) + should_continue = bool(should_continue) if should_continue is not None else False + + logger.debug( + f"InvokeAzureAgent: external loop condition '{str(external_loop_when)[:50]}' = " + f"{should_continue} (iteration {iteration})" + ) + + if should_continue and iteration < max_iterations: + # Emit event to signal waiting for external input + action_id: str = str(ctx.action.get("id", f"agent_{agent_name}")) + yield ExternalLoopEvent( + action_id=action_id, + iteration=iteration, + condition_expression=str(external_loop_when), + ) + + # The workflow executor should: + # 1. Pause execution + # 2. Wait for external input + # 3. Update state with input + # 4. 
Resume this generator + + # For now, we request input via QuestionRequest + yield QuestionRequest( + request_id=f"{action_id}_input_{iteration}", + prompt="Waiting for user input...", + variable="Local.userInput", + ) + + iteration += 1 + + # Clear messages for next iteration (start fresh with conversation) + messages = _build_messages_from_state(ctx) + continue + elif iteration >= max_iterations: + logger.warning(f"InvokeAzureAgent: external loop exceeded max iterations ({max_iterations})") + + # No external loop or condition is false - exit + break + + +def _normalize_variable_path(variable: str) -> str: + """Normalize variable names to ensure they have a scope prefix. + + Args: + variable: Variable name like 'Local.X' or 'System.ConversationId' + + Returns: + The variable path with a scope prefix (defaults to Local if none provided) + """ + if variable.startswith(("Local.", "System.", "Workflow.", "Agent.", "Conversation.")): + # Already has a proper namespace + return variable + if "." in variable: + # Has some namespace, use as-is + return variable + # Default to Local scope + return "Local." + variable + + +@action_handler("InvokePromptAgent") +async def handle_invoke_prompt_agent(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: + """Invoke a local prompt-based agent (similar to InvokeAzureAgent but for local agents). + + Action schema: + kind: InvokePromptAgent + agent: agentName # name of the agent in the agents registry + input: =expression or literal input + instructions: =expression or literal prompt/instructions + outputPath: Local.response # optional path to store result + """ + # Implementation is similar to InvokeAzureAgent + # The difference is primarily in how the agent is configured + agent_name_raw = ctx.action.get("agent") + if not isinstance(agent_name_raw, str): + logger.warning("InvokePromptAgent action missing 'agent' property") + return + agent_name: str = agent_name_raw + input_expr = ctx.action.get("input") + instructions = ctx.action.get("instructions") + output_path = ctx.action.get("outputPath") + + # Get the agent from registry + agent = ctx.agents.get(agent_name) + if agent is None: + logger.error(f"InvokePromptAgent: agent '{agent_name}' not found in registry") + return + + # Evaluate input + input_value = ctx.state.eval_if_expression(input_expr) if input_expr else None + + # Evaluate instructions (unused currently but may be used for prompting) + _ = ctx.state.eval_if_expression(instructions) if instructions else None + + # Build messages + messages = _build_messages_from_state(ctx) + + # Add input as user message if provided + if input_value: + if isinstance(input_value, str): + messages.append(ChatMessage(role="user", text=input_value)) + elif isinstance(input_value, ChatMessage): + messages.append(input_value) + + logger.debug(f"InvokePromptAgent: calling '{agent_name}' with {len(messages)} messages") + + # Invoke the agent + try: + if hasattr(agent, "run_stream"): + updates: list[Any] = [] + + async for chunk in agent.run_stream(messages): + updates.append(chunk) + + if hasattr(chunk, "text") and chunk.text: + yield AgentStreamingChunkEvent( + agent_name=agent_name, + chunk=chunk.text, + ) + + # Build consolidated response from updates + response = AgentResponse.from_agent_run_response_updates(updates) + text = response.text + response_messages = response.messages + + ctx.state.set_agent_result(text=text, messages=response_messages) + + if text: + ctx.state.add_conversation_message(ChatMessage(role="assistant", text=text)) + + if 
output_path: + ctx.state.set(output_path, text) + + yield AgentResponseEvent( + agent_name=agent_name, + text=text, + messages=response_messages, + ) + + elif hasattr(agent, "run"): + response = await agent.run(messages) + text = response.text + response_messages = response.messages + + ctx.state.set_agent_result(text=text, messages=response_messages) + + if text: + ctx.state.add_conversation_message(ChatMessage(role="assistant", text=text)) + + if output_path: + ctx.state.set(output_path, text) + + yield AgentResponseEvent( + agent_name=agent_name, + text=text, + messages=response_messages, + ) + else: + logger.error(f"InvokePromptAgent: agent '{agent_name}' has no run or run_stream method") + + except Exception as e: + logger.error(f"InvokePromptAgent: error invoking agent '{agent_name}': {e}") + raise diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_actions_basic.py b/python/packages/declarative/agent_framework_declarative/_workflows/_actions_basic.py new file mode 100644 index 0000000000..243fe36e04 --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_actions_basic.py @@ -0,0 +1,571 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Basic action handlers for variable manipulation and output. + +This module implements handlers for: +- SetValue: Set a variable in the workflow state +- AppendValue: Append a value to a list variable +- SendActivity: Send text or attachments to the user +- EmitEvent: Emit a custom workflow event + +Note: All handlers are defined as async generators (AsyncGenerator[WorkflowEvent, None]) +for consistency with the ActionHandler protocol, even when they don't perform async +operations. This uniform interface allows the workflow executor to consume all handlers +the same way, and some handlers (like InvokeAzureAgent) genuinely require async for +network calls. The `return; yield` pattern makes a function an async generator without +actually yielding any events. +""" + +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any, cast + +from agent_framework import get_logger + +from ._handlers import ( + ActionContext, + AttachmentOutputEvent, + CustomEvent, + TextOutputEvent, + WorkflowEvent, + action_handler, +) + +if TYPE_CHECKING: + from ._state import WorkflowState + +logger = get_logger("agent_framework.declarative.workflows.actions") + + +@action_handler("SetValue") +async def handle_set_value(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Set a value in the workflow state. + + Action schema: + kind: SetValue + path: Local.variableName # or Workflow.Outputs.result + value: =expression or literal value + """ + path = ctx.action.get("path") + value = ctx.action.get("value") + + if not path: + logger.warning("SetValue action missing 'path' property") + return + + # Evaluate the value if it's an expression + evaluated_value = ctx.state.eval_if_expression(value) + + logger.debug(f"SetValue: {path} = {evaluated_value}") + ctx.state.set(path, evaluated_value) + + return + yield # Make it a generator + + +@action_handler("SetVariable") +async def handle_set_variable(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Set a variable in the workflow state (.NET workflow format). + + This is an alias for SetValue with 'variable' instead of 'path'. 
+ + Action schema: + kind: SetVariable + variable: Local.variableName + value: =expression or literal value + """ + variable = ctx.action.get("variable") + value = ctx.action.get("value") + + if not variable: + logger.warning("SetVariable action missing 'variable' property") + return + + # Evaluate the value if it's an expression + evaluated_value = ctx.state.eval_if_expression(value) + + # Use .NET-style variable names directly (Local.X, System.X, Workflow.X) + path = _normalize_variable_path(variable) + + logger.debug(f"SetVariable: {variable} ({path}) = {evaluated_value}") + ctx.state.set(path, evaluated_value) + + return + yield # Make it a generator + + +def _normalize_variable_path(variable: str) -> str: + """Normalize variable names to ensure they have a scope prefix. + + Args: + variable: Variable name like 'Local.X' or 'System.ConversationId' + + Returns: + The variable path with a scope prefix (defaults to Local if none provided) + """ + if variable.startswith(("Local.", "System.", "Workflow.", "Agent.", "Conversation.")): + # Already has a proper namespace + return variable + if "." in variable: + # Has some namespace, use as-is + return variable + # Default to Local scope + return "Local." + variable + + +@action_handler("AppendValue") +async def handle_append_value(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Append a value to a list in the workflow state. + + Action schema: + kind: AppendValue + path: Local.results + value: =expression or literal value + """ + path = ctx.action.get("path") + value = ctx.action.get("value") + + if not path: + logger.warning("AppendValue action missing 'path' property") + return + + # Evaluate the value if it's an expression + evaluated_value = ctx.state.eval_if_expression(value) + + logger.debug(f"AppendValue: {path} += {evaluated_value}") + ctx.state.append(path, evaluated_value) + + return + yield # Make it a generator + + +@action_handler("SendActivity") +async def handle_send_activity(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Send text or attachments to the user. + + Action schema (object form): + kind: SendActivity + activity: + text: =expression or literal text + attachments: + - content: ... 
+ contentType: text/plain + + Action schema (simple form): + kind: SendActivity + activity: =expression or literal text + """ + activity = ctx.action.get("activity", {}) + + # Handle simple string form + if isinstance(activity, str): + evaluated_text = ctx.state.eval_if_expression(activity) + if evaluated_text: + logger.debug( + "SendActivity: text = %s", evaluated_text[:100] if len(str(evaluated_text)) > 100 else evaluated_text + ) + yield TextOutputEvent(text=str(evaluated_text)) + return + + # Handle object form - text output + text = activity.get("text") + if text: + evaluated_text = ctx.state.eval_if_expression(text) + if evaluated_text: + logger.debug( + "SendActivity: text = %s", evaluated_text[:100] if len(str(evaluated_text)) > 100 else evaluated_text + ) + yield TextOutputEvent(text=str(evaluated_text)) + + # Handle attachments + attachments = activity.get("attachments", []) + for attachment in attachments: + content = attachment.get("content") + content_type = attachment.get("contentType", "application/octet-stream") + + if content: + evaluated_content = ctx.state.eval_if_expression(content) + logger.debug(f"SendActivity: attachment type={content_type}") + yield AttachmentOutputEvent(content=evaluated_content, content_type=content_type) + + +@action_handler("EmitEvent") +async def handle_emit_event(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Emit a custom workflow event. + + Action schema: + kind: EmitEvent + event: + name: eventName + data: =expression or literal data + """ + event_def = ctx.action.get("event", {}) + name = event_def.get("name") + data = event_def.get("data") + + if not name: + logger.warning("EmitEvent action missing 'event.name' property") + return + + # Evaluate data if it's an expression + evaluated_data = ctx.state.eval_if_expression(data) + + logger.debug(f"EmitEvent: {name} = {evaluated_data}") + yield CustomEvent(name=name, data=evaluated_data) + + +def _evaluate_dict_values(d: dict[str, Any], state: "WorkflowState") -> dict[str, Any]: + """Recursively evaluate PowerFx expressions in a dictionary. + + Args: + d: Dictionary that may contain expression values + state: The workflow state for expression evaluation + + Returns: + Dictionary with all expressions evaluated + """ + result: dict[str, Any] = {} + for key, value in d.items(): + if isinstance(value, str): + result[key] = state.eval_if_expression(value) + elif isinstance(value, dict): + result[key] = _evaluate_dict_values(cast(dict[str, Any], value), state) + elif isinstance(value, list): + evaluated_list: list[Any] = [] + for list_item in value: + if isinstance(list_item, dict): + evaluated_list.append(_evaluate_dict_values(cast(dict[str, Any], list_item), state)) + elif isinstance(list_item, str): + evaluated_list.append(state.eval_if_expression(list_item)) + else: + evaluated_list.append(list_item) + result[key] = evaluated_list + else: + result[key] = value + return result + + +@action_handler("SetTextVariable") +async def handle_set_text_variable(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Set a text variable with string interpolation support. + + This is similar to SetVariable but supports multi-line text with + {Local.Variable} style interpolation. + + Action schema: + kind: SetTextVariable + variable: Local.myText + value: |- + Multi-line text with {Local.Variable} interpolation + and more content here. 
+ """ + variable = ctx.action.get("variable") + value = ctx.action.get("value") + + if not variable: + logger.warning("SetTextVariable action missing 'variable' property") + return + + # Evaluate the value - handle string interpolation + if isinstance(value, str): + evaluated_value = _interpolate_string(value, ctx.state) + else: + evaluated_value = ctx.state.eval_if_expression(value) + + path = _normalize_variable_path(variable) + + logger.debug(f"SetTextVariable: {variable} ({path}) = {str(evaluated_value)[:100]}") + ctx.state.set(path, evaluated_value) + + return + yield # Make it a generator + + +@action_handler("SetMultipleVariables") +async def handle_set_multiple_variables(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Set multiple variables at once. + + Action schema: + kind: SetMultipleVariables + variables: + - variable: Local.var1 + value: value1 + - variable: Local.var2 + value: =expression + """ + variables = ctx.action.get("variables", []) + + for var_def in variables: + variable = var_def.get("variable") + value = var_def.get("value") + + if not variable: + logger.warning("SetMultipleVariables: variable entry missing 'variable' property") + continue + + evaluated_value = ctx.state.eval_if_expression(value) + path = _normalize_variable_path(variable) + + logger.debug(f"SetMultipleVariables: {variable} ({path}) = {evaluated_value}") + ctx.state.set(path, evaluated_value) + + return + yield # Make it a generator + + +@action_handler("ResetVariable") +async def handle_reset_variable(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Reset a variable to its default/blank state. + + Action schema: + kind: ResetVariable + variable: Local.variableName + """ + variable = ctx.action.get("variable") + + if not variable: + logger.warning("ResetVariable action missing 'variable' property") + return + + path = _normalize_variable_path(variable) + + logger.debug(f"ResetVariable: {variable} ({path}) = None") + ctx.state.set(path, None) + + return + yield # Make it a generator + + +@action_handler("ClearAllVariables") +async def handle_clear_all_variables(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Clear all turn-scoped variables. + + Action schema: + kind: ClearAllVariables + """ + logger.debug("ClearAllVariables: clearing turn scope") + ctx.state.reset_local() + + return + yield # Make it a generator + + +@action_handler("CreateConversation") +async def handle_create_conversation(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Create a new conversation context. + + Action schema (.NET style): + kind: CreateConversation + conversationId: Local.myConversationId # Variable to store the generated ID + + The conversationId parameter is the OUTPUT variable where the generated + conversation ID will be stored. 
This matches .NET behavior where: + - A unique conversation ID is always auto-generated + - The conversationId parameter specifies where to store it + """ + import uuid + + conversation_id_var = ctx.action.get("conversationId") + + # Always generate a unique ID (.NET behavior) + generated_id = str(uuid.uuid4()) + + # Store conversation in state + conversations: dict[str, Any] = ctx.state.get("System.conversations") or {} + conversations[generated_id] = { + "id": generated_id, + "messages": [], + "created_at": None, # Could add timestamp + } + ctx.state.set("System.conversations", conversations) + + logger.debug(f"CreateConversation: created {generated_id}") + + # Store the generated ID in the specified variable (.NET style output binding) + if conversation_id_var: + output_path = _normalize_variable_path(conversation_id_var) + ctx.state.set(output_path, generated_id) + logger.debug(f"CreateConversation: bound to {output_path} = {generated_id}") + + # Also handle legacy output binding for backwards compatibility + output = ctx.action.get("output", {}) + output_var = output.get("conversationId") + if output_var: + output_path = _normalize_variable_path(output_var) + ctx.state.set(output_path, generated_id) + logger.debug(f"CreateConversation: legacy output bound to {output_path}") + + return + yield # Make it a generator + + +@action_handler("AddConversationMessage") +async def handle_add_conversation_message(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Add a message to a conversation. + + Action schema: + kind: AddConversationMessage + conversationId: =expression or variable reference + message: + role: user | assistant | system + content: =expression or literal text + """ + conversation_id = ctx.action.get("conversationId") + message_def = ctx.action.get("message", {}) + + if not conversation_id: + logger.warning("AddConversationMessage missing 'conversationId' property") + return + + # Evaluate conversation ID + evaluated_id = ctx.state.eval_if_expression(conversation_id) + + # Evaluate message content + role = message_def.get("role", "user") + content = message_def.get("content", "") + + evaluated_content = ctx.state.eval_if_expression(content) + if isinstance(evaluated_content, str): + evaluated_content = _interpolate_string(evaluated_content, ctx.state) + + # Get or create conversation + conversations: dict[str, Any] = ctx.state.get("System.conversations") or {} + if evaluated_id not in conversations: + conversations[evaluated_id] = {"id": evaluated_id, "messages": []} + + # Add message + message: dict[str, Any] = {"role": role, "content": evaluated_content} + conv_entry: dict[str, Any] = dict(conversations[evaluated_id]) + messages_list: list[Any] = list(conv_entry.get("messages", [])) + messages_list.append(message) + conv_entry["messages"] = messages_list + conversations[evaluated_id] = conv_entry + ctx.state.set("System.conversations", conversations) + + # Also add to global conversation state + ctx.state.add_conversation_message(message) + + logger.debug(f"AddConversationMessage: added {role} message to {evaluated_id}") + + return + yield # Make it a generator + + +@action_handler("CopyConversationMessages") +async def handle_copy_conversation_messages(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Copy messages from one conversation to another. 
+ + Action schema: + kind: CopyConversationMessages + sourceConversationId: =expression + targetConversationId: =expression + count: 10 # optional, number of messages to copy + """ + source_id = ctx.action.get("sourceConversationId") + target_id = ctx.action.get("targetConversationId") + count = ctx.action.get("count") + + if not source_id or not target_id: + logger.warning("CopyConversationMessages missing source or target conversation ID") + return + + # Evaluate IDs + evaluated_source = ctx.state.eval_if_expression(source_id) + evaluated_target = ctx.state.eval_if_expression(target_id) + + # Get conversations + conversations: dict[str, Any] = ctx.state.get("System.conversations") or {} + + source_conv: dict[str, Any] = conversations.get(evaluated_source, {}) + source_messages: list[Any] = source_conv.get("messages", []) + + # Limit messages if count specified + if count is not None: + source_messages = source_messages[-count:] + + # Get or create target conversation + if evaluated_target not in conversations: + conversations[evaluated_target] = {"id": evaluated_target, "messages": []} + + # Copy messages + target_entry: dict[str, Any] = dict(conversations[evaluated_target]) + target_messages: list[Any] = list(target_entry.get("messages", [])) + target_messages.extend(source_messages) + target_entry["messages"] = target_messages + conversations[evaluated_target] = target_entry + ctx.state.set("System.conversations", conversations) + + logger.debug( + "CopyConversationMessages: copied %d messages from %s to %s", + len(source_messages), + evaluated_source, + evaluated_target, + ) + + return + yield # Make it a generator + + +@action_handler("RetrieveConversationMessages") +async def handle_retrieve_conversation_messages(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Retrieve messages from a conversation and store in a variable. + + Action schema: + kind: RetrieveConversationMessages + conversationId: =expression + output: + messages: Local.myMessages + count: 10 # optional + """ + conversation_id = ctx.action.get("conversationId") + output = ctx.action.get("output", {}) + count = ctx.action.get("count") + + if not conversation_id: + logger.warning("RetrieveConversationMessages missing 'conversationId' property") + return + + # Evaluate conversation ID + evaluated_id = ctx.state.eval_if_expression(conversation_id) + + # Get messages + conversations: dict[str, Any] = ctx.state.get("System.conversations") or {} + conv: dict[str, Any] = conversations.get(evaluated_id, {}) + messages: list[Any] = conv.get("messages", []) + + # Limit messages if count specified + if count is not None: + messages = messages[-count:] + + # Handle output binding + output_var = output.get("messages") + if output_var: + output_path = _normalize_variable_path(output_var) + ctx.state.set(output_path, messages) + logger.debug(f"RetrieveConversationMessages: bound {len(messages)} messages to {output_path}") + + return + yield # Make it a generator + + +def _interpolate_string(text: str, state: "WorkflowState") -> str: + """Interpolate {Variable.Path} references in a string. 
+ + Args: + text: Text that may contain {Variable.Path} references + state: The workflow state for variable lookup + + Returns: + Text with variables interpolated + """ + import re + + def replace_var(match: re.Match[str]) -> str: + var_path: str = match.group(1) + # Map .NET style to Python style + path = _normalize_variable_path(var_path) + value = state.get(path) + return str(value) if value is not None else "" + + # Match {Variable.Path} patterns + pattern = r"\{([A-Za-z][A-Za-z0-9_.]*)\}" + return re.sub(pattern, replace_var, text) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_actions_control_flow.py b/python/packages/declarative/agent_framework_declarative/_workflows/_actions_control_flow.py new file mode 100644 index 0000000000..0bd04369f7 --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_actions_control_flow.py @@ -0,0 +1,397 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Control flow action handlers for declarative workflows. + +This module implements handlers for: +- Foreach: Iterate over a collection and execute nested actions +- If: Conditional branching +- Switch: Multi-way branching based on value matching +- RepeatUntil: Loop until a condition is met +- BreakLoop: Exit the current loop +- ContinueLoop: Skip to the next iteration +""" + +from collections.abc import AsyncGenerator + +from agent_framework import get_logger + +from ._handlers import ( + ActionContext, + LoopControlSignal, + WorkflowEvent, + action_handler, +) + +logger = get_logger("agent_framework.declarative.workflows.actions") + + +@action_handler("Foreach") +async def handle_foreach(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: + """Iterate over a collection and execute nested actions for each item. + + Action schema: + kind: Foreach + source: =expression returning a collection + itemName: itemVariable # optional, defaults to 'item' + indexName: indexVariable # optional, defaults to 'index' + actions: + - kind: ... + """ + source_expr = ctx.action.get("source") + item_name = ctx.action.get("itemName", "item") + index_name = ctx.action.get("indexName", "index") + actions = ctx.action.get("actions", []) + + if not source_expr: + logger.warning("Foreach action missing 'source' property") + return + + # Evaluate the source collection + collection = ctx.state.eval_if_expression(source_expr) + + if collection is None: + logger.debug("Foreach: source evaluated to None, skipping") + return + + if not hasattr(collection, "__iter__"): + logger.warning(f"Foreach: source is not iterable: {type(collection).__name__}") + return + + collection_len = len(list(collection)) if hasattr(collection, "__len__") else "?" 
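+    # __len__ is probed only for this log line; unsized iterables (e.g.
+    # generators) are reported as "?" but still iterate normally below.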
+ logger.debug(f"Foreach: iterating over {collection_len} items") + + # Iterate over the collection + for index, item in enumerate(collection): + # Set loop variables in the Local scope + ctx.state.set(f"Local.{item_name}", item) + ctx.state.set(f"Local.{index_name}", index) + + # Execute nested actions + try: + async for event in ctx.execute_actions(actions, ctx.state): + # Check for loop control signals + if isinstance(event, LoopControlSignal): + if event.signal_type == "break": + logger.debug(f"Foreach: break signal received at index {index}") + return + elif event.signal_type == "continue": + logger.debug(f"Foreach: continue signal received at index {index}") + break # Break inner loop to continue outer + else: + yield event + except StopIteration: + # Continue signal was raised + continue + + +@action_handler("If") +async def handle_if(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: + """Conditional branching based on a condition expression. + + Action schema: + kind: If + condition: =boolean expression + then: + - kind: ... # actions if condition is true + else: + - kind: ... # actions if condition is false (optional) + """ + condition_expr = ctx.action.get("condition") + then_actions = ctx.action.get("then", []) + else_actions = ctx.action.get("else", []) + + if condition_expr is None: + logger.warning("If action missing 'condition' property") + return + + # Evaluate the condition + condition_result = ctx.state.eval_if_expression(condition_expr) + + # Coerce to boolean + is_truthy = bool(condition_result) + + logger.debug( + "If: condition '%s' evaluated to %s", + condition_expr[:50] if len(str(condition_expr)) > 50 else condition_expr, + is_truthy, + ) + + # Execute the appropriate branch + actions_to_execute = then_actions if is_truthy else else_actions + + async for event in ctx.execute_actions(actions_to_execute, ctx.state): + yield event + + +@action_handler("Switch") +async def handle_switch(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: + """Multi-way branching based on value matching. + + Action schema: + kind: Switch + value: =expression to match + cases: + - match: value1 + actions: + - kind: ... + - match: value2 + actions: + - kind: ... + default: + - kind: ... # optional default actions + """ + value_expr = ctx.action.get("value") + cases = ctx.action.get("cases", []) + default_actions = ctx.action.get("default", []) + + if not value_expr: + logger.warning("Switch action missing 'value' property") + return + + # Evaluate the switch value + switch_value = ctx.state.eval_if_expression(value_expr) + + logger.debug(f"Switch: value = {switch_value}") + + # Find matching case + matched_actions = None + for case in cases: + match_value = ctx.state.eval_if_expression(case.get("match")) + if switch_value == match_value: + matched_actions = case.get("actions", []) + logger.debug(f"Switch: matched case '{match_value}'") + break + + # Use default if no match found + if matched_actions is None: + matched_actions = default_actions + logger.debug("Switch: using default case") + + # Execute matched actions + async for event in ctx.execute_actions(matched_actions, ctx.state): + yield event + + +@action_handler("RepeatUntil") +async def handle_repeat_until(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: + """Loop until a condition becomes true. + + Action schema: + kind: RepeatUntil + condition: =boolean expression (loop exits when true) + maxIterations: 100 # optional safety limit + actions: + - kind: ... 
+ """ + condition_expr = ctx.action.get("condition") + max_iterations = ctx.action.get("maxIterations", 100) + actions = ctx.action.get("actions", []) + + if condition_expr is None: + logger.warning("RepeatUntil action missing 'condition' property") + return + + iteration = 0 + while iteration < max_iterations: + iteration += 1 + ctx.state.set("Local.iteration", iteration) + + logger.debug(f"RepeatUntil: iteration {iteration}") + + # Execute loop body + should_break = False + async for event in ctx.execute_actions(actions, ctx.state): + if isinstance(event, LoopControlSignal): + if event.signal_type == "break": + logger.debug(f"RepeatUntil: break signal received at iteration {iteration}") + should_break = True + break + elif event.signal_type == "continue": + logger.debug(f"RepeatUntil: continue signal received at iteration {iteration}") + break + else: + yield event + + if should_break: + break + + # Check exit condition + condition_result = ctx.state.eval_if_expression(condition_expr) + if bool(condition_result): + logger.debug(f"RepeatUntil: condition met after {iteration} iterations") + break + + if iteration >= max_iterations: + logger.warning(f"RepeatUntil: reached max iterations ({max_iterations})") + + +@action_handler("BreakLoop") +async def handle_break_loop(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Signal to break out of the current loop. + + Action schema: + kind: BreakLoop + """ + logger.debug("BreakLoop: signaling break") + yield LoopControlSignal(signal_type="break") + + +@action_handler("ContinueLoop") +async def handle_continue_loop(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Signal to continue to the next iteration of the current loop. + + Action schema: + kind: ContinueLoop + """ + logger.debug("ContinueLoop: signaling continue") + yield LoopControlSignal(signal_type="continue") + + +@action_handler("ConditionGroup") +async def handle_condition_group(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: + """Multi-condition branching (like else-if chains). + + Evaluates conditions in order and executes the first matching condition's actions. + If no conditions match and elseActions is provided, executes those. + + Action schema: + kind: ConditionGroup + conditions: + - condition: =boolean expression + actions: + - kind: ... + - condition: =another expression + actions: + - kind: ... + elseActions: + - kind: ... 
# optional, executed if no conditions match + """ + conditions = ctx.action.get("conditions", []) + else_actions = ctx.action.get("elseActions", []) + + matched = False + for condition_def in conditions: + condition_expr = condition_def.get("condition") + actions = condition_def.get("actions", []) + + if condition_expr is None: + logger.warning("ConditionGroup condition missing 'condition' property") + continue + + # Evaluate the condition + condition_result = ctx.state.eval_if_expression(condition_expr) + is_truthy = bool(condition_result) + + logger.debug( + "ConditionGroup: condition '%s' evaluated to %s", + str(condition_expr)[:50] if len(str(condition_expr)) > 50 else condition_expr, + is_truthy, + ) + + if is_truthy: + matched = True + # Execute this condition's actions + async for event in ctx.execute_actions(actions, ctx.state): + yield event + # Only execute the first matching condition + break + + # Execute elseActions if no condition matched + if not matched and else_actions: + logger.debug("ConditionGroup: no conditions matched, executing elseActions") + async for event in ctx.execute_actions(else_actions, ctx.state): + yield event + + +@action_handler("GotoAction") +async def handle_goto_action(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Jump to another action by ID (triggers re-execution from that action). + + Note: GotoAction in the .NET implementation creates a loop by restarting + execution from a specific action. In Python, we emit a GotoSignal that + the top-level executor should handle. + + Action schema: + kind: GotoAction + actionId: target_action_id + """ + action_id = ctx.action.get("actionId") + + if not action_id: + logger.warning("GotoAction missing 'actionId' property") + return + + logger.debug(f"GotoAction: jumping to action '{action_id}'") + + # Emit a goto signal that the executor should handle + yield GotoSignal(target_action_id=action_id) + + +class GotoSignal(WorkflowEvent): + """Signal to jump to a specific action by ID. + + This signal is used by GotoAction to implement control flow jumps. + The top-level executor should handle this signal appropriately. + """ + + def __init__(self, target_action_id: str) -> None: + self.target_action_id = target_action_id + + +class EndWorkflowSignal(WorkflowEvent): + """Signal to end the workflow execution. + + This signal causes the workflow to terminate gracefully. + """ + + def __init__(self, reason: str | None = None) -> None: + self.reason = reason + + +class EndConversationSignal(WorkflowEvent): + """Signal to end the current conversation. + + This signal causes the conversation to terminate while the workflow may continue. + """ + + def __init__(self, conversation_id: str | None = None, reason: str | None = None) -> None: + self.conversation_id = conversation_id + self.reason = reason + + +@action_handler("EndWorkflow") +async def handle_end_workflow(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """End the workflow execution. + + Action schema: + kind: EndWorkflow + reason: Optional reason for ending (for logging) + """ + reason = ctx.action.get("reason") + + logger.debug(f"EndWorkflow: ending workflow{f' (reason: {reason})' if reason else ''}") + + yield EndWorkflowSignal(reason=reason) + + +@action_handler("EndConversation") +async def handle_end_conversation(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """End the current conversation. 
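+
+ If no conversationId is provided, the current System.ConversationId is used.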
+ + Action schema: + kind: EndConversation + conversationId: Optional specific conversation to end + reason: Optional reason for ending + """ + conversation_id = ctx.action.get("conversationId") + reason = ctx.action.get("reason") + + # Evaluate conversation ID if provided + if conversation_id: + evaluated_id = ctx.state.eval_if_expression(conversation_id) + else: + evaluated_id = ctx.state.get("System.ConversationId") + + logger.debug(f"EndConversation: ending conversation {evaluated_id}{f' (reason: {reason})' if reason else ''}") + + yield EndConversationSignal(conversation_id=evaluated_id, reason=reason) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_actions_error.py b/python/packages/declarative/agent_framework_declarative/_workflows/_actions_error.py new file mode 100644 index 0000000000..d59a65e668 --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_actions_error.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Error handling action handlers for declarative workflows. + +This module implements handlers for: +- ThrowException: Raise an error that can be caught by TryCatch +- TryCatch: Try-catch-finally error handling +""" + +from collections.abc import AsyncGenerator +from dataclasses import dataclass + +from agent_framework import get_logger + +from ._handlers import ( + ActionContext, + WorkflowEvent, + action_handler, +) + +logger = get_logger("agent_framework.declarative.workflows.actions") + + +class WorkflowActionError(Exception): + """Exception raised by ThrowException action.""" + + def __init__(self, message: str, code: str | None = None): + super().__init__(message) + self.code = code + + +@dataclass +class ErrorEvent(WorkflowEvent): + """Event emitted when an error occurs.""" + + message: str + """The error message.""" + + code: str | None = None + """Optional error code.""" + + source_action: str | None = None + """The action that caused the error.""" + + +@action_handler("ThrowException") +async def handle_throw_exception(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Raise an exception that can be caught by TryCatch. + + Action schema: + kind: ThrowException + message: =expression or literal error message + code: ERROR_CODE # optional error code + """ + message_expr = ctx.action.get("message", "An error occurred") + code = ctx.action.get("code") + + # Evaluate the message if it's an expression + message = ctx.state.eval_if_expression(message_expr) + + logger.debug(f"ThrowException: {message} (code={code})") + + raise WorkflowActionError(str(message), code) + + # This yield is never reached but makes it a generator + yield ErrorEvent(message=str(message), code=code) # type: ignore[unreachable] + + +@action_handler("TryCatch") +async def handle_try_catch(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: + """Try-catch-finally error handling. + + Action schema: + kind: TryCatch + try: + - kind: ... # actions to try + catch: + - kind: ... # actions to execute on error (optional) + finally: + - kind: ... 
# actions to always execute (optional)
+
+ In the catch block, the following variables are available:
+ Local.error.message: The error message
+ Local.error.code: The error code (if provided)
+ Local.error.type: The error type name
+ """
+ try_actions = ctx.action.get("try", [])
+ catch_actions = ctx.action.get("catch", [])
+ finally_actions = ctx.action.get("finally", [])
+
+ error_occurred = False
+ error_info = None
+
+ # Execute try block
+ try:
+ async for event in ctx.execute_actions(try_actions, ctx.state):
+ yield event
+ except WorkflowActionError as e:
+ error_occurred = True
+ error_info = {
+ "message": str(e),
+ "code": e.code,
+ "type": "WorkflowActionError",
+ }
+ logger.debug(f"TryCatch: caught WorkflowActionError: {e}")
+ except Exception as e:
+ error_occurred = True
+ error_info = {
+ "message": str(e),
+ "code": None,
+ "type": type(e).__name__,
+ }
+ logger.debug(f"TryCatch: caught {type(e).__name__}: {e}")
+
+ # Execute catch block if an error occurred, exposing the error info to
+ # the catch actions via the Local scope
+ if error_occurred and catch_actions:
+ ctx.state.set("Local.error", error_info)
+ async for event in ctx.execute_actions(catch_actions, ctx.state):
+ yield event
+
+ # Execute finally block
+ if finally_actions:
+ async for event in ctx.execute_actions(finally_actions, ctx.state):
+ yield event
diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py
new file mode 100644
index 0000000000..0d881389b2
--- /dev/null
+++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py
@@ -0,0 +1,836 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+"""Base classes for graph-based declarative workflow executors.
+
+This module provides:
+- DeclarativeWorkflowState: Manages workflow variables via SharedState
+- DeclarativeActionExecutor: Base class for action executors
+- Message types for inter-executor communication
+
+PowerFx Expression Evaluation
+-----------------------------
+The .NET version uses RecalcEngine with:
+1. Pre-registered custom functions (UserMessage, AgentMessage, MessageText)
+2. Typed schemas for variables defined at compile time
+3. UpdateVariable() to register mutable state with proper types
+
+The Python `powerfx` library only exposes eval() with runtime symbols, not
+the full RecalcEngine API. We work around this by:
+1. Pre-processing custom functions (UserMessage, MessageText) before PowerFx
+2. Gracefully handling undefined variable errors (returning None)
+3. Converting non-serializable objects to PowerFx-safe types at runtime
+
+See: dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/PowerFx/
+"""
+
+import logging
+from collections.abc import Mapping
+from dataclasses import dataclass
+from decimal import Decimal as _Decimal
+from typing import Any, Literal, TypedDict, cast
+
+from agent_framework._workflows import (
+ Executor,
+ SharedState,
+ WorkflowContext,
+)
+from powerfx import Engine
+
+logger = logging.getLogger(__name__)
+
+
+class ConversationData(TypedDict):
+ """Structure for conversation-related state data.
+
+ Attributes:
+ messages: Active conversation messages for the current agent interaction.
+ This is the primary storage used by InvokeAgent actions.
+ history: Deprecated. Previously used as a separate history buffer, but
+ messages and history are now kept in sync. Use messages instead.
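+
+ A minimal example value (the message shape is illustrative, not a fixed schema):
+ {"messages": [{"role": "user", "text": "Hello"}], "history": []}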
+ """ + + messages: list[Any] + history: list[Any] # Deprecated: use messages instead + + +class DeclarativeStateData(TypedDict, total=False): + """Structure for the declarative workflow state stored in SharedState. + + This TypedDict defines the schema for workflow variables stored + under the DECLARATIVE_STATE_KEY in SharedState. + + Variable Scopes (matching .NET naming conventions): + Inputs: Initial workflow inputs (read-only after initialization). + Outputs: Values to return from the workflow. + Local: Variables persisting within the current workflow turn. + System: System-level variables (ConversationId, LastMessage, etc.). + Agent: Results from the most recent agent invocation. + Conversation: Conversation history and messages. + Custom: User-defined custom variables. + _declarative_loop_state: Internal loop iteration state (managed by ForeachExecutors). + """ + + Inputs: dict[str, Any] + Outputs: dict[str, Any] + Local: dict[str, Any] + System: dict[str, Any] + Agent: dict[str, Any] + Conversation: ConversationData + Custom: dict[str, Any] + _declarative_loop_state: dict[str, Any] + + +# Key used in SharedState to store declarative workflow variables +DECLARATIVE_STATE_KEY = "_declarative_workflow_state" + + +# Types that PowerFx can serialize directly +# Note: Decimal is included because PowerFx returns Decimal for numeric values +_POWERFX_SAFE_TYPES = (str, int, float, bool, type(None), _Decimal) + + +def _make_powerfx_safe(value: Any) -> Any: + """Convert a value to a PowerFx-serializable form. + + PowerFx can only serialize primitive types, dicts, and lists. + Custom objects (like ChatMessage) must be converted to dicts or excluded. + + Args: + value: Any Python value + + Returns: + A PowerFx-safe representation of the value + """ + if value is None or isinstance(value, _POWERFX_SAFE_TYPES): + return value + + if isinstance(value, dict): + return {k: _make_powerfx_safe(v) for k, v in value.items()} + + if isinstance(value, list): + return [_make_powerfx_safe(item) for item in value] + + # Try to convert objects with __dict__ or dataclass-style attributes + if hasattr(value, "__dict__"): + return _make_powerfx_safe(vars(value)) + + # For other objects, try to convert to string representation + return str(value) + + +class DeclarativeWorkflowState: + """Manages workflow variables stored in SharedState. + + This class provides the same interface as the interpreter-based WorkflowState + but stores all data in SharedState for checkpointing support. + + The state is organized into namespaces (matching .NET naming conventions): + - Workflow.Inputs: Initial inputs (read-only) + - Workflow.Outputs: Values to return from workflow + - Local: Variables persisting within the workflow turn + - System: System-level variables (ConversationId, LastMessage, etc.) + - Agent: Results from most recent agent invocation + - Conversation: Conversation history + """ + + def __init__(self, shared_state: SharedState): + """Initialize with a SharedState instance. + + Args: + shared_state: The workflow's shared state for persistence + """ + self._shared_state = shared_state + + async def initialize(self, inputs: "Mapping[str, Any] | None" = None) -> None: + """Initialize the declarative state with inputs. 
+ + Args: + inputs: Initial workflow inputs (become Workflow.Inputs.*) + """ + state_data: DeclarativeStateData = { + "Inputs": dict(inputs) if inputs else {}, + "Outputs": {}, + "Local": {}, + "System": { + "ConversationId": "default", + "LastMessage": {"Text": "", "Id": ""}, + "LastMessageText": "", + "LastMessageId": "", + }, + "Agent": {}, + "Conversation": {"messages": [], "history": []}, + "Custom": {}, + } + await self._shared_state.set(DECLARATIVE_STATE_KEY, state_data) + + async def get_state_data(self) -> DeclarativeStateData: + """Get the full state data dict from shared state.""" + try: + result: DeclarativeStateData = await self._shared_state.get(DECLARATIVE_STATE_KEY) + return result + except KeyError: + # Initialize if not present + await self.initialize() + return cast(DeclarativeStateData, await self._shared_state.get(DECLARATIVE_STATE_KEY)) + + async def set_state_data(self, data: DeclarativeStateData) -> None: + """Set the full state data dict in shared state.""" + await self._shared_state.set(DECLARATIVE_STATE_KEY, data) + + async def get(self, path: str, default: Any = None) -> Any: + """Get a value from the state using a dot-notated path. + + Args: + path: Dot-notated path like 'Local.results' or 'Workflow.Inputs.query' + default: Default value if path doesn't exist + + Returns: + The value at the path, or default if not found + """ + state_data = await self.get_state_data() + parts = path.split(".") + if not parts: + return default + + namespace = parts[0] + remaining = parts[1:] + + # Handle Workflow.Inputs and Workflow.Outputs specially + if namespace == "Workflow" and remaining: + sub_namespace = remaining[0] + remaining = remaining[1:] + if sub_namespace == "Inputs": + obj: Any = state_data.get("Inputs", {}) + elif sub_namespace == "Outputs": + obj = state_data.get("Outputs", {}) + else: + return default + elif namespace == "Local": + obj = state_data.get("Local", {}) + elif namespace == "System": + obj = state_data.get("System", {}) + elif namespace == "Agent": + obj = state_data.get("Agent", {}) + elif namespace == "Conversation": + obj = state_data.get("Conversation", {}) + else: + # Try custom namespace + custom_data: dict[str, Any] = state_data.get("Custom", {}) + obj = custom_data.get(namespace, default) + if obj is default: + return default + + # Navigate the remaining path + for part in remaining: + if isinstance(obj, dict): + obj = obj.get(part, default) # type: ignore[union-attr] + if obj is default: + return default + elif hasattr(obj, part): # type: ignore[arg-type] + obj = getattr(obj, part) # type: ignore[arg-type] + else: + return default + + return obj # type: ignore[return-value] + + async def set(self, path: str, value: Any) -> None: + """Set a value in the state using a dot-notated path. 
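+
+ For example (illustrative), `await state.set("Local.count", 1)` stores the
+ value under the Local scope, creating intermediate dicts as needed.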
+ + Args: + path: Dot-notated path like 'Local.results' or 'Workflow.Outputs.response' + value: The value to set + + Raises: + ValueError: If attempting to set Workflow.Inputs (which is read-only) + """ + state_data = await self.get_state_data() + parts = path.split(".") + if not parts: + return + + namespace = parts[0] + remaining = parts[1:] + + # Determine target dict + if namespace == "Workflow": + if not remaining: + raise ValueError("Cannot set 'Workflow' directly; use 'Workflow.Outputs.*'") + sub_namespace = remaining[0] + remaining = remaining[1:] + if sub_namespace == "Inputs": + raise ValueError("Cannot modify Workflow.Inputs - they are read-only") + if sub_namespace == "Outputs": + target = state_data.setdefault("Outputs", {}) + else: + raise ValueError(f"Unknown Workflow namespace: {sub_namespace}") + elif namespace == "Local": + target = state_data.setdefault("Local", {}) + elif namespace == "System": + target = state_data.setdefault("System", {}) + elif namespace == "Agent": + target = state_data.setdefault("Agent", {}) + elif namespace == "Conversation": + target = cast(dict[str, Any], state_data).setdefault("Conversation", {}) + else: + # Create or use custom namespace + custom = state_data.setdefault("Custom", {}) + if namespace not in custom: + custom[namespace] = {} + target = custom[namespace] + + if not remaining: + raise ValueError(f"Cannot replace entire namespace '{namespace}'") + + # Navigate to parent, creating dicts as needed + for part in remaining[:-1]: + if part not in target: + target[part] = {} + target = target[part] + + # Set the final value + target[remaining[-1]] = value + await self.set_state_data(state_data) + + async def append(self, path: str, value: Any) -> None: + """Append a value to a list at the specified path. + + If the path doesn't exist, creates a new list with the value. + + Note: This operation is not atomic. In concurrent scenarios, use explicit + locking or consider using atomic operations at the storage layer. + + Args: + path: Dot-notated path to a list + value: The value to append + """ + existing = await self.get(path) + if existing is None: + await self.set(path, [value]) + elif isinstance(existing, list): + existing_list: list[Any] = list(existing) # type: ignore[arg-type] + existing_list.append(value) + await self.set(path, existing_list) + else: + raise ValueError(f"Cannot append to non-list at path '{path}'") + + async def eval(self, expression: str) -> Any: + """Evaluate a PowerFx expression with the current state. + + Expressions starting with '=' are evaluated as PowerFx. + Other strings are returned as-is. + + Handles special custom functions not supported by PowerFx: + - UserMessage(text): Creates a user message dict from text + - MessageText(messages): Extracts text from the last message + + Args: + expression: The expression to evaluate + + Returns: + The evaluated result. Returns None if the expression references + undefined variables (matching legacy fallback parser behavior). + + Raises: + ImportError: If the powerfx package is not installed. 
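+
+ Example (illustrative; relies on the custom Concat handling below):
+ await state.set("Local.name", "Ada")
+ greeting = await state.eval('=Concat("Hello ", Local.name)')
+ # greeting == "Hello Ada"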
+ """ + if not expression: + return expression + + if not isinstance(expression, str): + return expression + + if not expression.startswith("="): + return expression + + # Strip the leading '=' for evaluation + formula = expression[1:] + + # Handle custom functions not supported by PowerFx + # First check if the entire formula is a custom function + result = await self._eval_custom_function(formula) + if result is not None: + return result + + # Pre-process nested custom functions (e.g., Upper(MessageText(...))) + # Replace them with their evaluated results before sending to PowerFx + formula = await self._preprocess_custom_functions(formula) + + engine = Engine() + symbols = await self._to_powerfx_symbols() + try: + return engine.eval(formula, symbols=symbols) + except ValueError as e: + error_msg = str(e) + # Handle undefined variable errors gracefully by returning None + # This matches the behavior of the legacy fallback parser + if "isn't recognized" in error_msg or "Name isn't valid" in error_msg: + logger.debug(f"PowerFx: undefined variable in expression '{formula}', returning None") + return None + raise + + async def _eval_custom_function(self, formula: str) -> Any | None: + """Handle custom functions not supported by the Python PowerFx library. + + The standard PowerFx library supports these functions but the Python wrapper + may have limitations. We also handle Copilot Studio-specific dialects. + + Returns None if the formula is not a custom function call. + """ + import re + + # Concat/Concatenate - string concatenation + # In standard PowerFx, Concatenate is for strings, Concat is for tables. + # Copilot Studio uses Concat for strings, so we support both. + match = re.match(r"(?:Concat|Concatenate)\((.+)\)$", formula.strip()) + if match: + args_str = match.group(1) + # Parse comma-separated arguments (handling nested parentheses) + args = self._parse_function_args(args_str) + evaluated_args = [] + for arg in args: + arg = arg.strip() + if arg.startswith('"') and arg.endswith('"'): + # String literal + evaluated_args.append(arg[1:-1]) + elif arg.startswith("'") and arg.endswith("'"): + # Single-quoted string literal + evaluated_args.append(arg[1:-1]) + else: + # Variable reference - evaluate it + result = await self.eval(f"={arg}") + evaluated_args.append(str(result) if result is not None else "") + return "".join(evaluated_args) + + # UserMessage(expr) - creates a user message dict + match = re.match(r"UserMessage\((.+)\)$", formula.strip()) + if match: + inner_expr = match.group(1).strip() + # Evaluate the inner expression + text = await self.eval(f"={inner_expr}") + return {"role": "user", "text": str(text) if text else ""} + + # AgentMessage(expr) - creates an assistant message dict + match = re.match(r"AgentMessage\((.+)\)$", formula.strip()) + if match: + inner_expr = match.group(1).strip() + text = await self.eval(f"={inner_expr}") + return {"role": "assistant", "text": str(text) if text else ""} + + # MessageText(expr) - extracts text from the last message + match = re.match(r"MessageText\((.+)\)$", formula.strip()) + if match: + inner_expr = match.group(1).strip() + # Reuse the helper method for consistent text extraction + return await self._eval_and_replace_message_text(inner_expr) + + return None + + async def _preprocess_custom_functions(self, formula: str) -> str: + """Pre-process custom functions nested inside other PowerFx functions. + + Custom functions like MessageText() are not supported by the PowerFx engine. 
+ When they appear nested inside other functions (e.g., Upper(MessageText(...))), + we need to evaluate them first and replace with the result. + + For long strings (>500 chars), the result is stored in a temporary state variable + to avoid exceeding PowerFx's 1000 character expression limit. This is a limitation + of the Python PowerFx wrapper (powerfx package), which doesn't expose the + MaximumExpressionLength configuration that the .NET PowerFxConfig provides. + The .NET implementation defaults to 10,000 characters, while Python defaults to 1,000. + + Args: + formula: The PowerFx formula to pre-process + + Returns: + The formula with custom function calls replaced by their evaluated results + """ + import re + + # Threshold for storing in state vs embedding as literal. + # The Python PowerFx wrapper defaults to a 1000 char expression limit (vs 10,000 in .NET). + # We use 500 to leave room for the rest of the expression around the replaced value. + MAX_INLINE_LENGTH = 500 + + # Counter for generating unique temp variable names + temp_var_counter = 0 + + # Custom functions that need pre-processing: (regex pattern, handler) + custom_functions = [ + (r"MessageText\(", self._eval_and_replace_message_text), + ] + + for pattern, handler in custom_functions: + # Find all occurrences of the custom function + while True: + match = re.search(pattern, formula) + if not match: + break + + # Find the matching closing parenthesis + start = match.start() + paren_start = match.end() - 1 # Position of opening ( + depth = 1 + pos = paren_start + 1 + in_string = False + escape_next = False + + while pos < len(formula) and depth > 0: + char = formula[pos] + if escape_next: + escape_next = False + pos += 1 + continue + if char == "\\": + escape_next = True + pos += 1 + continue + if char == '"' and not escape_next: + in_string = not in_string + elif not in_string: + if char == "(": + depth += 1 + elif char == ")": + depth -= 1 + pos += 1 + + if depth != 0: + # Malformed expression, skip + break + + # Extract the inner expression (between parentheses) + end = pos + inner_expr = formula[paren_start + 1 : end - 1] + + # Evaluate and get replacement + replacement = await handler(inner_expr) + + # Replace in formula + if isinstance(replacement, str): + if len(replacement) > MAX_INLINE_LENGTH: + # Store long strings in a temp variable to avoid PowerFx expression limit + temp_var_name = f"_TempMessageText{temp_var_counter}" + temp_var_counter += 1 + await self.set(f"Local.{temp_var_name}", replacement) + replacement_str = f"Local.{temp_var_name}" + logger.debug( + f"Stored long MessageText result ({len(replacement)} chars) " + f"in temp variable {temp_var_name}" + ) + else: + # Short strings can be embedded directly + escaped = replacement.replace('"', '""') + replacement_str = f'"{escaped}"' + else: + replacement_str = str(replacement) if replacement is not None else '""' + + formula = formula[:start] + replacement_str + formula[end:] + + return formula + + async def _eval_and_replace_message_text(self, inner_expr: str) -> str: + """Evaluate MessageText() and return the text result. 
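+
+ For example (illustrative), given messages [{"role": "user", "text": "hi"}]
+ this returns "hi".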
+ + Args: + inner_expr: The expression inside MessageText() + + Returns: + The extracted text from the messages + """ + messages: Any = await self.eval(f"={inner_expr}") + if isinstance(messages, list) and messages: + last_msg: Any = messages[-1] + if isinstance(last_msg, dict): + # Try "text" key first (simple dict format) + if "text" in last_msg: + return str(last_msg["text"]) + # Try extracting from "contents" (ChatMessage dict format) + # ChatMessage.text concatenates text from all TextContent items + contents = last_msg.get("contents", []) + if isinstance(contents, list): + text_parts = [] + for content in contents: + if isinstance(content, dict): + # TextContent has a "text" key + if content.get("type") == "text" or "text" in content: + text_parts.append(str(content.get("text", ""))) + elif hasattr(content, "text"): + text_parts.append(str(getattr(content, "text", ""))) + if text_parts: + return " ".join(text_parts) + return "" + if hasattr(last_msg, "text"): + return str(getattr(last_msg, "text", "")) + return "" + + def _parse_function_args(self, args_str: str) -> list[str]: + """Parse comma-separated function arguments, handling nested parentheses and strings.""" + args = [] + current = [] + depth = 0 + in_string = False + string_char = None + + for char in args_str: + if char in ('"', "'") and not in_string: + in_string = True + string_char = char + current.append(char) + elif char == string_char and in_string: + in_string = False + string_char = None + current.append(char) + elif char == "(" and not in_string: + depth += 1 + current.append(char) + elif char == ")" and not in_string: + depth -= 1 + current.append(char) + elif char == "," and depth == 0 and not in_string: + args.append("".join(current).strip()) + current = [] + else: + current.append(char) + + if current: + args.append("".join(current).strip()) + + return args + + async def _to_powerfx_symbols(self) -> dict[str, Any]: + """Convert the current state to a PowerFx symbols dictionary. + + Uses .NET-style PascalCase names (System, Local, Workflow) matching + the .NET declarative workflow implementation. 
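+
+ The resulting shape is roughly (illustrative):
+ {"Workflow": {"Inputs": ..., "Outputs": ...}, "Local": ..., "Agent": ...,
+ "Conversation": ..., "System": ..., "inputs": ...}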
+ """ + state_data = await self.get_state_data() + local_data = state_data.get("Local", {}) + agent_data = state_data.get("Agent", {}) + conversation_data = state_data.get("Conversation", {}) + system_data = state_data.get("System", {}) + inputs_data = state_data.get("Inputs", {}) + outputs_data = state_data.get("Outputs", {}) + + symbols: dict[str, Any] = { + # .NET-style PascalCase names (matching .NET implementation) + "Workflow": { + "Inputs": inputs_data, + "Outputs": outputs_data, + }, + "Local": local_data, + "Agent": agent_data, + "Conversation": conversation_data, + "System": system_data, + # Also expose inputs at top level for backward compatibility with =inputs.X syntax + "inputs": inputs_data, + # Custom namespaces + **state_data.get("Custom", {}), + } + # Debug log the Local symbols to help diagnose type issues + if local_data: + for key, value in local_data.items(): + logger.debug( + f"PowerFx symbol Local.{key}: type={type(value).__name__}, " + f"value_preview={str(value)[:100] if value else None}" + ) + result = _make_powerfx_safe(symbols) + return cast(dict[str, Any], result) + + async def eval_if_expression(self, value: Any) -> Any: + """Evaluate a value if it's a PowerFx expression, otherwise return as-is.""" + if isinstance(value, str): + return await self.eval(value) + if isinstance(value, dict): + value_dict: dict[str, Any] = dict(value) # type: ignore[arg-type] + return {k: await self.eval_if_expression(v) for k, v in value_dict.items()} + if isinstance(value, list): + value_list: list[Any] = list(value) # type: ignore[arg-type] + return [await self.eval_if_expression(item) for item in value_list] + return value + + async def interpolate_string(self, text: str) -> str: + """Interpolate {Variable.Path} references in a string. + + This handles template-style variable substitution like: + - "Created ticket #{Local.TicketParameters.TicketId}" + - "Routing to {Local.RoutingParameters.TeamName}" + + Args: + text: Text that may contain {Variable.Path} references + + Returns: + Text with variables interpolated + """ + import re + + async def replace_var(match: re.Match[str]) -> str: + var_path: str = match.group(1) + value = await self.get(var_path) + return str(value) if value is not None else "" + + # Match {Variable.Path} patterns + pattern = r"\{([A-Za-z][A-Za-z0-9_.]*)\}" + + # re.sub doesn't support async, so we need to do it manually + result = text + for match in re.finditer(pattern, text): + replacement = await replace_var(match) + result = result.replace(match.group(0), replacement, 1) + + return result + + +# Message types for inter-executor communication +# These are defined before DeclarativeActionExecutor since it references them + + +class ActionTrigger: + """Message that triggers a declarative action executor. + + This is sent between executors in the graph to pass control + and any action-specific data. + """ + + def __init__(self, data: Any = None): + """Initialize the action trigger. + + Args: + data: Optional data to pass to the action + """ + self.data = data + + +class ActionComplete: + """Message sent when a declarative action completes. + + This is sent to downstream executors to continue the workflow. + """ + + def __init__(self, result: Any = None): + """Initialize the completion message. + + Args: + result: Optional result from the action + """ + self.result = result + + +@dataclass +class ConditionResult: + """Result of evaluating a condition (If/Switch). 
+ + This message is output by ConditionEvaluatorExecutor and SwitchEvaluatorExecutor + to indicate which branch should be taken. + """ + + matched: bool + branch_index: int # Which branch matched (0 = first, -1 = else/default) + value: Any = None # The evaluated condition value + + +@dataclass +class LoopIterationResult: + """Result of a loop iteration step. + + This message is output by ForeachInitExecutor and ForeachNextExecutor + to indicate whether the loop should continue. + """ + + has_next: bool + current_item: Any = None + current_index: int = 0 + + +@dataclass +class LoopControl: + """Signal for loop control (break/continue). + + This message is output by BreakLoopExecutor and ContinueLoopExecutor. + """ + + action: Literal["break", "continue"] + + +# Union type for any declarative action message - allows executors to accept +# messages from triggers, completions, and control flow results +DeclarativeMessage = ActionTrigger | ActionComplete | ConditionResult | LoopIterationResult | LoopControl + + +class DeclarativeActionExecutor(Executor): + """Base class for declarative action executors. + + Each declarative action (SetValue, SendActivity, etc.) is implemented + as a subclass of this executor. The executor receives an ActionInput + message containing the action definition and state reference. + """ + + def __init__( + self, + action_def: dict[str, Any], + *, + id: str | None = None, + ): + """Initialize the declarative action executor. + + Args: + action_def: The action definition from YAML + id: Optional executor ID (defaults to action id or generated) + """ + action_id = id or action_def.get("id") or f"{action_def.get('kind', 'action')}_{hash(str(action_def)) % 10000}" + super().__init__(id=action_id, defer_discovery=True) + self._action_def = action_def + + # Manually register handlers after initialization + self._handlers = {} + self._handler_specs = [] + self._discover_handlers() + self._discover_response_handlers() + + @property + def action_def(self) -> dict[str, Any]: + """Get the action definition.""" + return self._action_def + + @property + def display_name(self) -> str | None: + """Get the display name for logging.""" + return self._action_def.get("displayName") + + def _get_state(self, shared_state: SharedState) -> DeclarativeWorkflowState: + """Get the declarative workflow state wrapper.""" + return DeclarativeWorkflowState(shared_state) + + async def _ensure_state_initialized( + self, + ctx: "WorkflowContext[Any, Any]", + trigger: Any, + ) -> DeclarativeWorkflowState: + """Ensure declarative state is initialized. 
+ + Follows .NET's DefaultTransform pattern - accepts any input type: + - dict/Mapping: Used directly as workflow.inputs + - str: Converted to {"input": value} + - DeclarativeMessage: Internal message, no initialization needed + - Any other type: Converted via str() to {"input": str(value)} + + Args: + ctx: The workflow context + trigger: The trigger message - can be any type + + Returns: + The initialized DeclarativeWorkflowState + """ + state = self._get_state(ctx.shared_state) + + if isinstance(trigger, dict): + # Structured inputs - use directly + await state.initialize(trigger) # type: ignore + elif isinstance(trigger, str): + # String input - wrap in dict + await state.initialize({"input": trigger}) + elif not isinstance( + trigger, (ActionTrigger, ActionComplete, ConditionResult, LoopIterationResult, LoopControl) + ): + # Any other type - convert to string like .NET's DefaultTransform + await state.initialize({"input": str(trigger)}) + + return state diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py new file mode 100644 index 0000000000..84ecc8ea4e --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py @@ -0,0 +1,973 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Builder that transforms declarative YAML into a workflow graph. + +This module provides the DeclarativeWorkflowBuilder which is analogous to +.NET's WorkflowActionVisitor + WorkflowElementWalker. It walks the YAML +action definitions and creates a proper workflow graph with: +- Executor nodes for each action +- Edges for sequential flow +- Condition evaluator executors for If/Switch that ensure first-match semantics +- Loop edges for foreach +""" + +from typing import Any + +from agent_framework._workflows import ( + Workflow, + WorkflowBuilder, +) + +from ._declarative_base import ( + ConditionResult, + DeclarativeActionExecutor, + LoopIterationResult, +) +from ._executors_agents import AGENT_ACTION_EXECUTORS, InvokeAzureAgentExecutor +from ._executors_basic import BASIC_ACTION_EXECUTORS +from ._executors_control_flow import ( + CONTROL_FLOW_EXECUTORS, + ELSE_BRANCH_INDEX, + ConditionGroupEvaluatorExecutor, + ForeachInitExecutor, + ForeachNextExecutor, + IfConditionEvaluatorExecutor, + JoinExecutor, + SwitchEvaluatorExecutor, +) +from ._executors_external_input import EXTERNAL_INPUT_EXECUTORS + +# Combined mapping of all action kinds to executor classes +ALL_ACTION_EXECUTORS = { + **BASIC_ACTION_EXECUTORS, + **CONTROL_FLOW_EXECUTORS, + **AGENT_ACTION_EXECUTORS, + **EXTERNAL_INPUT_EXECUTORS, +} + +# Action kinds that terminate control flow (no fall-through to successor) +# These actions transfer control elsewhere and should not have sequential edges to the next action +TERMINATOR_ACTIONS = frozenset({"Goto", "GotoAction", "BreakLoop", "ContinueLoop", "EndWorkflow", "EndDialog"}) + +# Required fields for specific action kinds (schema validation) +# Each action needs at least one of the listed fields (checked with alternates) +ACTION_REQUIRED_FIELDS: dict[str, list[str]] = { + "SetValue": ["path"], + "SetVariable": ["variable"], + "AppendValue": ["path", "value"], + "SendActivity": ["activity"], + "InvokeAzureAgent": ["agent"], + "Goto": ["target"], + "GotoAction": ["actionId"], + "Foreach": ["items", "actions"], + "If": ["condition"], + "Switch": ["value"], # Switch can use value/cases or conditions (ConditionGroup 
style) + "ConditionGroup": ["conditions"], + "RequestHumanInput": ["variable"], + "WaitForHumanInput": ["variable"], + "EmitEvent": ["event"], +} + +# Alternate field names that satisfy required field requirements +# Key: "ActionKind.field", Value: list of alternates that satisfy the requirement +ACTION_ALTERNATE_FIELDS: dict[str, list[str]] = { + "SetValue.path": ["variable"], + "Goto.target": ["actionId"], + "GotoAction.actionId": ["target"], + "InvokeAzureAgent.agent": ["agentName"], + "Foreach.items": ["itemsSource", "source"], # source is used in some schemas + "Switch.value": ["conditions"], # Switch can be condition-based instead of value-based +} + + +class DeclarativeWorkflowBuilder: + """Builds a Workflow graph from declarative YAML actions. + + This builder transforms declarative action definitions into a proper + workflow graph with executor nodes and edges. It handles: + - Sequential actions (simple edges) + - Conditional branching (If/Switch with condition edges) + - Loops (Foreach with loop edges) + - Jumps (Goto with target edges) + + Example usage: + yaml_def = { + "actions": [ + {"kind": "SendActivity", "activity": {"text": "Hello"}}, + {"kind": "SetValue", "path": "turn.count", "value": 0}, + ] + } + builder = DeclarativeWorkflowBuilder(yaml_def) + workflow = builder.build() + """ + + def __init__( + self, + yaml_definition: dict[str, Any], + workflow_id: str | None = None, + agents: dict[str, Any] | None = None, + checkpoint_storage: Any | None = None, + validate: bool = True, + ): + """Initialize the builder. + + Args: + yaml_definition: The parsed YAML workflow definition + workflow_id: Optional ID for the workflow (defaults to name from YAML) + agents: Registry of agent instances by name (for InvokeAzureAgent actions) + checkpoint_storage: Optional checkpoint storage for pause/resume support + validate: Whether to validate the workflow definition before building (default: True) + """ + self._yaml_def = yaml_definition + self._workflow_id = workflow_id or yaml_definition.get("name", "declarative_workflow") + self._executors: dict[str, Any] = {} # id -> executor + self._action_index = 0 # Counter for generating unique IDs + self._agents = agents or {} # Agent registry for agent executors + self._checkpoint_storage = checkpoint_storage + self._pending_gotos: list[tuple[Any, str]] = [] # (goto_executor, target_id) + self._validate = validate + self._seen_explicit_ids: set[str] = set() # Track explicit IDs for duplicate detection + + def build(self) -> Workflow: + """Build the workflow graph. + + Returns: + A Workflow instance with all executors wired together + + Raises: + ValueError: If no actions are defined (empty workflow), or validation fails + """ + builder = WorkflowBuilder(name=self._workflow_id) + + # Enable checkpointing if storage is provided + if self._checkpoint_storage: + builder.with_checkpointing(self._checkpoint_storage) + + actions = self._yaml_def.get("actions", []) + if not actions: + # Empty workflow - raise an error since we need at least one executor + raise ValueError("Cannot build workflow with no actions. 
At least one action is required.") + + # Validate workflow definition before building + if self._validate: + self._validate_workflow(actions) + + # First pass: create all executors + entry_executor = self._create_executors_for_actions(actions, builder) + + # Set the entry point + if entry_executor: + # Check if entry is a control flow structure (If/Switch) + if getattr(entry_executor, "_is_if_structure", False) or getattr( + entry_executor, "_is_switch_structure", False + ): + # Create an entry passthrough node and wire to the structure's branches + entry_node = JoinExecutor({"kind": "Entry"}, id="_workflow_entry") + self._executors[entry_node.id] = entry_node + builder.set_start_executor(entry_node) + # Use _add_sequential_edge which knows how to wire to structures + self._add_sequential_edge(builder, entry_node, entry_executor) + else: + builder.set_start_executor(entry_executor) + else: + raise ValueError("Failed to create any executors from actions.") + + # Resolve pending gotos (back-edges for loops, forward-edges for jumps) + self._resolve_pending_gotos(builder) + + return builder.build() + + def _validate_workflow(self, actions: list[dict[str, Any]]) -> None: + """Validate the workflow definition before building. + + Performs: + - Schema validation (required fields for action types) + - Duplicate explicit action ID detection + - Circular goto reference detection + + Args: + actions: List of action definitions to validate + + Raises: + ValueError: If validation fails + """ + seen_ids: set[str] = set() + goto_targets: list[tuple[str, str | None]] = [] # (target_id, source_id) + defined_ids: set[str] = set() + + # Collect all defined IDs and validate each action + self._validate_actions_recursive(actions, seen_ids, goto_targets, defined_ids) + + # Check for circular goto chains (A -> B -> A) + # Build a simple graph of goto targets + self._validate_no_circular_gotos(goto_targets, defined_ids) + + def _validate_actions_recursive( + self, + actions: list[dict[str, Any]], + seen_ids: set[str], + goto_targets: list[tuple[str, str | None]], + defined_ids: set[str], + ) -> None: + """Recursively validate actions and collect metadata. + + Args: + actions: List of action definitions + seen_ids: Set of seen explicit IDs (for duplicate detection) + goto_targets: List of (target_id, source_id) tuples for goto validation + defined_ids: Set of all defined action IDs + """ + for action_def in actions: + kind = action_def.get("kind", "") + + # Check for duplicate explicit IDs + explicit_id = action_def.get("id") + if explicit_id: + if explicit_id in seen_ids: + raise ValueError(f"Duplicate action ID '{explicit_id}'. Action IDs must be unique.") + seen_ids.add(explicit_id) + defined_ids.add(explicit_id) + + # Schema validation: check required fields + required_fields = ACTION_REQUIRED_FIELDS.get(kind, []) + for field in required_fields: + if field not in action_def and not self._has_alternate_field(action_def, kind, field): + raise ValueError(f"Action '{kind}' is missing required field '{field}'. 
Action: {action_def}") + + # Collect goto targets for circular reference detection + if kind in ("Goto", "GotoAction"): + target = action_def.get("target") or action_def.get("actionId") + if target: + goto_targets.append((target, explicit_id)) + + # Recursively validate nested actions + if kind == "If": + then_actions = action_def.get("then", action_def.get("actions", [])) + if then_actions: + self._validate_actions_recursive(then_actions, seen_ids, goto_targets, defined_ids) + else_actions = action_def.get("else", []) + if else_actions: + self._validate_actions_recursive(else_actions, seen_ids, goto_targets, defined_ids) + + elif kind in ("Switch", "ConditionGroup"): + cases = action_def.get("cases", action_def.get("conditions", [])) + for case in cases: + case_actions = case.get("actions", []) + if case_actions: + self._validate_actions_recursive(case_actions, seen_ids, goto_targets, defined_ids) + else_actions = action_def.get("elseActions", action_def.get("else", action_def.get("default", []))) + if else_actions: + self._validate_actions_recursive(else_actions, seen_ids, goto_targets, defined_ids) + + elif kind == "Foreach": + body_actions = action_def.get("actions", []) + if body_actions: + self._validate_actions_recursive(body_actions, seen_ids, goto_targets, defined_ids) + + def _has_alternate_field(self, action_def: dict[str, Any], kind: str, field: str) -> bool: + """Check if an action has an alternate field that satisfies the requirement. + + Some actions support multiple field names for the same purpose. + + Args: + action_def: The action definition + kind: The action kind + field: The required field name + + Returns: + True if an alternate field exists + """ + key = f"{kind}.{field}" + return any(alt in action_def for alt in ACTION_ALTERNATE_FIELDS.get(key, [])) + + def _validate_no_circular_gotos( + self, + goto_targets: list[tuple[str, str | None]], + defined_ids: set[str], + ) -> None: + """Validate that there are no problematic circular goto chains. + + Note: Some circular references are valid (e.g., loop-back patterns). + This checks for direct self-references only as a basic validation. + + Args: + goto_targets: List of (target_id, source_id) tuples + defined_ids: Set of defined action IDs + """ + for target_id, source_id in goto_targets: + # Check for direct self-reference + if source_id and target_id == source_id: + raise ValueError( + f"Action '{source_id}' has a direct self-referencing Goto, which would cause an infinite loop." + ) + + def _resolve_pending_gotos(self, builder: WorkflowBuilder) -> None: + """Resolve pending goto edges after all executors are created. + + Creates edges from goto executors to their target executors. + + Raises: + ValueError: If a goto target references an action ID that does not exist. + """ + for goto_executor, target_id in self._pending_gotos: + target_executor = self._executors.get(target_id) + if target_executor: + # Create edge from goto to target + builder.add_edge(source=goto_executor, target=target_executor) + else: + available_ids = list(self._executors.keys()) + raise ValueError(f"Goto target '{target_id}' not found. Available action IDs: {available_ids}") + + def _create_executors_for_actions( + self, + actions: list[dict[str, Any]], + builder: WorkflowBuilder, + parent_context: dict[str, Any] | None = None, + ) -> Any | None: + """Create executors for a list of actions and wire them together. 
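+
+ For example (illustrative), two sequential SendActivity actions become two
+ executor nodes connected by a single sequential edge.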
+ + Args: + actions: List of action definitions + builder: The workflow builder + parent_context: Context from parent (e.g., loop info) + + Returns: + The first executor in the chain, or None if no actions + """ + if not actions: + return None + + first_executor = None + prev_executor = None + executors_in_chain: list[Any] = [] + + for action_def in actions: + executor = self._create_executor_for_action(action_def, builder, parent_context) + + if executor is None: + continue + + executors_in_chain.append(executor) + + if first_executor is None: + first_executor = executor + + # Wire sequential edge from previous executor + if prev_executor is not None: + self._add_sequential_edge(builder, prev_executor, executor) + + # Check if this action is a terminator (transfers control elsewhere) + # Terminators should not have fall-through edges to subsequent actions + action_kind = action_def.get("kind", "") + # Don't wire terminators to the next action - control flow ends there + prev_executor = None if action_kind in TERMINATOR_ACTIONS else executor + + # Store the chain for later reference + if first_executor is not None: + first_executor._chain_executors = executors_in_chain # type: ignore[attr-defined] + + return first_executor + + def _create_executor_for_action( + self, + action_def: dict[str, Any], + builder: WorkflowBuilder, + parent_context: dict[str, Any] | None = None, + ) -> Any | None: + """Create an executor for a single action. + + Args: + action_def: The action definition from YAML + builder: The workflow builder + parent_context: Context from parent + + Returns: + The created executor, or None if action type not supported + """ + kind = action_def.get("kind", "") + + # Handle special control flow actions + if kind == "If": + return self._create_if_structure(action_def, builder, parent_context) + if kind == "Switch" or kind == "ConditionGroup": + return self._create_switch_structure(action_def, builder, parent_context) + if kind == "Foreach": + return self._create_foreach_structure(action_def, builder, parent_context) + if kind == "Goto" or kind == "GotoAction": + return self._create_goto_reference(action_def, builder, parent_context) + if kind == "BreakLoop": + return self._create_break_executor(action_def, builder, parent_context) + if kind == "ContinueLoop": + return self._create_continue_executor(action_def, builder, parent_context) + + # Get the executor class for this action kind + executor_class = ALL_ACTION_EXECUTORS.get(kind) + + if executor_class is None: + # Unknown action type - skip with warning + # In production, might want to log this + return None + + # Create the executor with ID + # Priority: explicit ID from YAML > index-based ID (matches .NET behavior) + explicit_id = action_def.get("id") + if explicit_id: + action_id = explicit_id + else: + parent_id = (parent_context or {}).get("parent_id") + action_id = f"{parent_id}_{kind}_{self._action_index}" if parent_id else f"{kind}_{self._action_index}" + self._action_index += 1 + + # Pass agents to agent-related executors + executor: Any + if kind in ("InvokeAzureAgent",): + executor = InvokeAzureAgentExecutor(action_def, id=action_id, agents=self._agents) + else: + executor = executor_class(action_def, id=action_id) + self._executors[action_id] = executor + + return executor + + def _create_if_structure( + self, + action_def: dict[str, Any], + builder: WorkflowBuilder, + parent_context: dict[str, Any] | None = None, + ) -> Any: + """Create the graph structure for an If action. 
+ + An If action is implemented with a condition evaluator executor that + outputs a ConditionResult. Edge conditions check the branch_index to + route to either the then or else branch. This ensures first-match + semantics (only one branch executes). + + Args: + action_def: The If action definition + builder: The workflow builder + parent_context: Context from parent + + Returns: + A structure representing the If with evaluator, branch entries and exits + """ + action_id = action_def.get("id") or f"If_{self._action_index}" + self._action_index += 1 + + condition_expr = action_def.get("condition", "true") + # Normalize boolean conditions from YAML to PowerFx-style strings + if condition_expr is True: + condition_expr = "=true" + elif condition_expr is False: + condition_expr = "=false" + elif isinstance(condition_expr, str) and not condition_expr.startswith("="): + # Bare string conditions should be evaluated as expressions + condition_expr = f"={condition_expr}" + + # Pass the If's ID as context for child action naming + branch_context = { + **(parent_context or {}), + "parent_id": action_id, + } + + # Create the condition evaluator executor + evaluator = IfConditionEvaluatorExecutor( + action_def, + condition_expr, + id=f"{action_id}_eval", + ) + self._executors[evaluator.id] = evaluator + + # Create then branch + then_actions = action_def.get("then", action_def.get("actions", [])) + then_entry = self._create_executors_for_actions(then_actions, builder, branch_context) + + # Create else branch + else_actions = action_def.get("else", []) + else_entry = self._create_executors_for_actions(else_actions, builder, branch_context) if else_actions else None + else_passthrough = None + if not else_entry: + # No else branch - create a passthrough for continuation when condition is false + else_passthrough = JoinExecutor({"kind": "ElsePassthrough"}, id=f"{action_id}_else_pass") + self._executors[else_passthrough.id] = else_passthrough + + # Wire evaluator to branches with conditions that check ConditionResult.branch_index + # branch_index=0 means "then" branch, branch_index=-1 (ELSE_BRANCH_INDEX) means "else" + # For nested If/Switch structures, wire to the evaluator (entry point) + if then_entry: + then_target = self._get_structure_entry(then_entry) + builder.add_edge( + source=evaluator, + target=then_target, + condition=lambda msg: isinstance(msg, ConditionResult) and msg.branch_index == 0, + ) + if else_entry: + else_target = self._get_structure_entry(else_entry) + builder.add_edge( + source=evaluator, + target=else_target, + condition=lambda msg: isinstance(msg, ConditionResult) and msg.branch_index == ELSE_BRANCH_INDEX, + ) + elif else_passthrough: + builder.add_edge( + source=evaluator, + target=else_passthrough, + condition=lambda msg: isinstance(msg, ConditionResult) and msg.branch_index == ELSE_BRANCH_INDEX, + ) + + # Get branch exit executors for later wiring to successor + then_exit = self._get_branch_exit(then_entry) + else_exit = self._get_branch_exit(else_entry) if else_entry else else_passthrough + + # Collect all branch exits (for wiring to successor) + branch_exits: list[Any] = [] + if then_exit: + branch_exits.append(then_exit) + if else_exit: + branch_exits.append(else_exit) + + # Create an IfStructure to hold all the info needed for wiring + class IfStructure: + def __init__(self) -> None: + self.id = action_id + self.evaluator = evaluator # The entry point for this structure + self.then_entry = then_entry + self.else_entry = else_entry + self.else_passthrough = 
else_passthrough + self.branch_exits = branch_exits # All exits that need wiring to successor + self._is_if_structure = True + + return IfStructure() + + def _create_switch_structure( + self, + action_def: dict[str, Any], + builder: WorkflowBuilder, + parent_context: dict[str, Any] | None = None, + ) -> Any: + """Create the graph structure for a Switch/ConditionGroup action. + + Supports two schema formats: + 1. ConditionGroup schema (matches .NET): + - conditions: list of {condition: expr, actions: [...]} + - elseActions: default actions + + 2. Switch schema (interpreter style): + - value: expression to match + - cases: list of {match: value, actions: [...]} + - default: default actions + + Both use evaluator executors that output ConditionResult with branch_index + for first-match semantics. + + Args: + action_def: The Switch/ConditionGroup action definition + builder: The workflow builder + parent_context: Context from parent + + Returns: + A SwitchStructure containing branch info for wiring + """ + action_id = action_def.get("id") or f"Switch_{self._action_index}" + self._action_index += 1 + + # Pass the Switch's ID as context for child action naming + branch_context = { + **(parent_context or {}), + "parent_id": action_id, + } + + # Detect schema type: + # - If "cases" present: interpreter Switch schema (value/cases/default) + # - If "conditions" present: ConditionGroup schema (conditions/elseActions) + cases = action_def.get("cases", []) + conditions = action_def.get("conditions", []) + + if cases: + # Interpreter Switch schema: value/cases/default + evaluator: DeclarativeActionExecutor = SwitchEvaluatorExecutor( + action_def, + cases, + id=f"{action_id}_eval", + ) + branch_items = cases + else: + # ConditionGroup schema: conditions/elseActions + evaluator = ConditionGroupEvaluatorExecutor( + action_def, + conditions, + id=f"{action_id}_eval", + ) + branch_items = conditions + + self._executors[evaluator.id] = evaluator + + # Collect branches and create executors for each + branch_entries: list[tuple[int, Any]] = [] # (branch_index, entry_executor) + branch_exits: list[Any] = [] # All exits that need wiring to successor + + for i, item in enumerate(branch_items): + branch_actions = item.get("actions", []) + # Use branch-specific context + case_context = {**branch_context, "parent_id": f"{action_id}_case{i}"} + branch_entry = self._create_executors_for_actions(branch_actions, builder, case_context) + + if branch_entry: + branch_entries.append((i, branch_entry)) + # Track exit for later wiring + branch_exit = self._get_branch_exit(branch_entry) + if branch_exit: + branch_exits.append(branch_exit) + + # Handle else/default branch + # .NET uses "elseActions", interpreter uses "else" or "default" + else_actions = action_def.get("elseActions", action_def.get("else", action_def.get("default", []))) + default_entry = None + default_passthrough = None + if else_actions: + default_context = {**branch_context, "parent_id": f"{action_id}_else"} + default_entry = self._create_executors_for_actions(else_actions, builder, default_context) + if default_entry: + default_exit = self._get_branch_exit(default_entry) + if default_exit: + branch_exits.append(default_exit) + else: + # No else actions - create a passthrough for the "no match" case + # This allows the workflow to continue to the next action when no condition matches + default_passthrough = JoinExecutor({"kind": "DefaultPassthrough"}, id=f"{action_id}_default") + self._executors[default_passthrough.id] = default_passthrough + 
branch_exits.append(default_passthrough) + + # Wire evaluator to branches with conditions that check ConditionResult.branch_index + # For nested If/Switch structures, wire to the evaluator (entry point) + for branch_index, branch_entry in branch_entries: + # Capture branch_index in closure properly using a factory function for type inference + def make_branch_condition(expected: int) -> Any: + return lambda msg: isinstance(msg, ConditionResult) and msg.branch_index == expected # type: ignore + + branch_target = self._get_structure_entry(branch_entry) + builder.add_edge( + source=evaluator, + target=branch_target, + condition=make_branch_condition(branch_index), + ) + + # Wire evaluator to default/else branch + if default_entry: + default_target = self._get_structure_entry(default_entry) + builder.add_edge( + source=evaluator, + target=default_target, + condition=lambda msg: isinstance(msg, ConditionResult) and msg.branch_index == ELSE_BRANCH_INDEX, + ) + elif default_passthrough: + builder.add_edge( + source=evaluator, + target=default_passthrough, + condition=lambda msg: isinstance(msg, ConditionResult) and msg.branch_index == ELSE_BRANCH_INDEX, + ) + + # Create a SwitchStructure to hold all the info needed for wiring + class SwitchStructure: + def __init__(self) -> None: + self.id = action_id + self.evaluator = evaluator # The entry point for this structure + self.branch_entries = branch_entries + self.default_entry = default_entry + self.default_passthrough = default_passthrough + self.branch_exits = branch_exits # All exits that need wiring to successor + self._is_switch_structure = True + + return SwitchStructure() + + def _create_foreach_structure( + self, + action_def: dict[str, Any], + builder: WorkflowBuilder, + parent_context: dict[str, Any] | None = None, + ) -> Any: + """Create the graph structure for a Foreach action. + + A Foreach action becomes: + 1. ForeachInit node that initializes the loop + 2. Loop body actions + 3. ForeachNext node that advances to next item + 4. Back-edge from ForeachNext to loop body (when has_next=True) + 5. 
Exit edge from ForeachNext (when has_next=False) + + Args: + action_def: The Foreach action definition + builder: The workflow builder + parent_context: Context from parent + + Returns: + The foreach init executor (entry point) + """ + action_id = action_def.get("id") or f"Foreach_{self._action_index}" + self._action_index += 1 + + # Create foreach init executor + init_executor = ForeachInitExecutor(action_def, id=f"{action_id}_init") + self._executors[init_executor.id] = init_executor + + # Create foreach next executor (for advancing to next item) + next_executor = ForeachNextExecutor(action_def, init_executor.id, id=f"{action_id}_next") + self._executors[next_executor.id] = next_executor + + # Create join node for loop exit + join_executor = JoinExecutor({"kind": "Join"}, id=f"{action_id}_exit") + self._executors[join_executor.id] = join_executor + + # Create loop body + body_actions = action_def.get("actions", []) + loop_context = { + **(parent_context or {}), + "loop_id": action_id, + "loop_next_executor": next_executor, + } + body_entry = self._create_executors_for_actions(body_actions, builder, loop_context) + + if body_entry: + # For nested If/Switch structures, wire to the evaluator (entry point) + body_target = self._get_structure_entry(body_entry) + + # Init -> body (when has_next=True) + builder.add_edge( + source=init_executor, + target=body_target, + condition=lambda msg: isinstance(msg, LoopIterationResult) and msg.has_next, + ) + + # Body exit -> Next (get all exits from body and wire to next_executor) + body_exits = self._get_source_exits(body_entry) + for body_exit in body_exits: + builder.add_edge(source=body_exit, target=next_executor) + + # Next -> body (when has_next=True, loop back) + builder.add_edge( + source=next_executor, + target=body_target, + condition=lambda msg: isinstance(msg, LoopIterationResult) and msg.has_next, + ) + + # Init -> join (when has_next=False, empty collection) + builder.add_edge( + source=init_executor, + target=join_executor, + condition=lambda msg: isinstance(msg, LoopIterationResult) and not msg.has_next, + ) + + # Next -> join (when has_next=False, loop complete) + builder.add_edge( + source=next_executor, + target=join_executor, + condition=lambda msg: isinstance(msg, LoopIterationResult) and not msg.has_next, + ) + + init_executor._exit_executor = join_executor # type: ignore[attr-defined] + return init_executor + + def _create_goto_reference( + self, + action_def: dict[str, Any], + builder: WorkflowBuilder, + parent_context: dict[str, Any] | None = None, + ) -> Any | None: + """Create a GotoAction executor that jumps to the target action. + + GotoAction creates a back-edge (or forward-edge) in the graph to the target action. + We create a pass-through executor and record the pending edge to be resolved + after all executors are created. 
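+
+        YAML example (hypothetical; this parser reads "target", falling back
+        to "actionId", and the target id shown here is illustrative):
+
+            - kind: GotoAction
+              target: check_inventory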
+ """ + from ._executors_control_flow import JoinExecutor + + target_id = action_def.get("target") or action_def.get("actionId") + + if not target_id: + return None + + # Create a pass-through executor for the goto + action_id = action_def.get("id") or f"goto_{target_id}_{self._action_index}" + self._action_index += 1 + + # Use JoinExecutor as a simple pass-through node + goto_executor = JoinExecutor(action_def, id=action_id) + self._executors[action_id] = goto_executor + + # Record pending goto edge to be resolved after all executors created + self._pending_gotos.append((goto_executor, target_id)) + + return goto_executor + + def _create_break_executor( + self, + action_def: dict[str, Any], + builder: WorkflowBuilder, + parent_context: dict[str, Any] | None = None, + ) -> Any | None: + """Create a break executor for loop control. + + Raises: + ValueError: If BreakLoop is used outside of a loop. + """ + from ._executors_control_flow import BreakLoopExecutor + + if parent_context and "loop_next_executor" in parent_context: + loop_next = parent_context["loop_next_executor"] + action_id = action_def.get("id") or f"Break_{self._action_index}" + self._action_index += 1 + + executor = BreakLoopExecutor(action_def, loop_next.id, id=action_id) + self._executors[action_id] = executor + + # Wire break to loop next + builder.add_edge(source=executor, target=loop_next) + + return executor + + raise ValueError("BreakLoop action can only be used inside a Foreach loop") + + def _create_continue_executor( + self, + action_def: dict[str, Any], + builder: WorkflowBuilder, + parent_context: dict[str, Any] | None = None, + ) -> Any | None: + """Create a continue executor for loop control. + + Raises: + ValueError: If ContinueLoop is used outside of a loop. + """ + from ._executors_control_flow import ContinueLoopExecutor + + if parent_context and "loop_next_executor" in parent_context: + loop_next = parent_context["loop_next_executor"] + action_id = action_def.get("id") or f"Continue_{self._action_index}" + self._action_index += 1 + + executor = ContinueLoopExecutor(action_def, loop_next.id, id=action_id) + self._executors[action_id] = executor + + # Wire continue to loop next + builder.add_edge(source=executor, target=loop_next) + + return executor + + raise ValueError("ContinueLoop action can only be used inside a Foreach loop") + + def _add_sequential_edge( + self, + builder: WorkflowBuilder, + source: Any, + target: Any, + ) -> None: + """Add a sequential edge between two executors. 
+
+        Handles control flow structures:
+        - If source is a structure (If/Switch), wire from all branch exits
+        - If target is a structure (If/Switch), wire with conditional edges to branches
+        """
+        # Get all source exit points
+        source_exits = self._get_source_exits(source)
+
+        # Wire each source exit to target
+        for source_exit in source_exits:
+            self._wire_to_target(builder, source_exit, target)
+
+    def _get_source_exits(self, source: Any) -> list[Any]:
+        """Get all exit executors from a source (handles structures with multiple exits)."""
+        # Check if source is a structure with branch_exits
+        if hasattr(source, "branch_exits"):
+            # Collect all exits, recursively flattening nested structures
+            all_exits: list[Any] = []
+            for exit_item in source.branch_exits:
+                if hasattr(exit_item, "branch_exits"):
+                    # Nested structure - recurse
+                    all_exits.extend(self._collect_all_exits(exit_item))
+                else:
+                    all_exits.append(exit_item)
+            return all_exits
+
+        # Check if source has a single exit executor
+        actual_exit = getattr(source, "_exit_executor", source)
+        return [actual_exit]
+
+    def _wire_to_target(
+        self,
+        builder: WorkflowBuilder,
+        source: Any,
+        target: Any,
+    ) -> None:
+        """Wire a single source executor to a target (which may be a structure).
+
+        For If/Switch structures, wire to the evaluator executor. The evaluator
+        handles condition evaluation and outputs ConditionResult, which is then
+        routed to the appropriate branch by edges created in _create_*_structure.
+        """
+        # Check if target is an IfStructure or SwitchStructure (wire to evaluator)
+        if getattr(target, "_is_if_structure", False) or getattr(target, "_is_switch_structure", False):
+            # Wire from source to the evaluator - the evaluator then routes to branches
+            builder.add_edge(source=source, target=target.evaluator)
+        else:
+            # Normal sequential edge to a regular executor
+            builder.add_edge(source=source, target=target)
+
+    def _get_structure_entry(self, entry: Any) -> Any:
+        """Get the entry point executor for a structure or regular executor.
+
+        For If/Switch structures, returns the evaluator. For regular executors,
+        returns the executor itself.
+
+        Args:
+            entry: An executor or structure
+
+        Returns:
+            The entry point executor
+        """
+        is_structure = getattr(entry, "_is_if_structure", False) or getattr(entry, "_is_switch_structure", False)
+        return entry.evaluator if is_structure else entry
+
+    def _get_branch_exit(self, branch_entry: Any) -> Any | None:
+        """Get the exit executor of a branch.
+
+        For a linear sequence of actions, returns the last executor.
+        For a branch that ends with a nested structure, returns that structure
+        so its branch_exits can be collected by the caller.
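+
+        For example, a branch of [SetVariable, SendActivity] yields the
+        SendActivity executor, while a branch ending in a nested If yields
+        that IfStructure.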
+
+        Args:
+            branch_entry: The first executor of the branch
+
+        Returns:
+            The exit executor, the nested structure itself when the branch ends
+            with one, or None if the branch is empty
+        """
+        if branch_entry is None:
+            return None
+
+        # Get the chain of executors in this branch
+        chain = getattr(branch_entry, "_chain_executors", [branch_entry])
+
+        last_executor = chain[-1]
+
+        # Check if last executor is a structure with branch_exits
+        # In that case, we return the structure so its exits can be collected
+        if hasattr(last_executor, "branch_exits"):
+            return last_executor
+
+        # Regular executor - get its exit point
+        return getattr(last_executor, "_exit_executor", last_executor)
+
+    def _collect_all_exits(self, structure: Any) -> list[Any]:
+        """Recursively collect all exit executors from a structure."""
+        exits: list[Any] = []
+
+        if not hasattr(structure, "branch_exits"):
+            # Not a structure - return the executor itself
+            actual_exit = getattr(structure, "_exit_executor", structure)
+            return [actual_exit]
+
+        for exit_item in structure.branch_exits:
+            if hasattr(exit_item, "branch_exits"):
+                # Nested structure - recurse
+                exits.extend(self._collect_all_exits(exit_item))
+            else:
+                exits.append(exit_item)
+
+        return exits
diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py
new file mode 100644
index 0000000000..669f662a9b
--- /dev/null
+++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py
@@ -0,0 +1,1089 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+"""Agent invocation executors for declarative workflows.
+
+These executors handle invoking Azure AI Foundry agents and other AI agents,
+supporting both streaming responses and human-in-loop patterns.
+
+Aligned with .NET's InvokeAzureAgentExecutor behavior including:
+- Structured input with arguments and messages
+- External loop support for human-in-loop patterns
+- Output with messages and responseObject (JSON parsing)
+- AutoSend behavior control
+"""
+
+import contextlib
+import json
+import logging
+import uuid
+from dataclasses import dataclass, field
+from typing import Any, cast
+
+from agent_framework import (
+    ChatMessage,
+    FunctionCallContent,
+    FunctionResultContent,
+    WorkflowContext,
+    handler,
+    response_handler,
+)
+
+from ._declarative_base import (
+    ActionComplete,
+    DeclarativeActionExecutor,
+    DeclarativeWorkflowState,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _extract_json_from_response(text: str) -> Any:
+    r"""Extract and parse JSON from an agent response.
+
+    Agents often return JSON wrapped in markdown code blocks or with
+    explanatory text. This function attempts to extract and parse the
+    JSON content from various formats:
+
+    1. Pure JSON: {"key": "value"}
+    2. Markdown code block: ```json\n{"key": "value"}\n```
+    3. Markdown code block (no language): ```\n{"key": "value"}\n```
+    4. JSON with leading/trailing text: Here's the result: {"key": "value"}
+    5. Multiple JSON objects: Returns the LAST valid JSON object
+
+    When multiple JSON objects are present (e.g., streaming agent responses
+    that emit partial then final results), this returns the last complete
+    JSON object, which is typically the final/complete result.
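+
+    Examples:
+        Illustrative calls (inputs and outputs are hypothetical):
+
+        .. code-block:: python
+
+            _extract_json_from_response('{"a": 1}')  # -> {"a": 1}
+            _extract_json_from_response('```json\n{"a": 1}\n```')  # -> {"a": 1}
+            _extract_json_from_response('partial {"a": 1} final {"b": 2}')  # -> {"b": 2}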
+ + Args: + text: The raw text response from an agent + + Returns: + Parsed JSON as a Python dict/list, or None if parsing fails + + Raises: + json.JSONDecodeError: If no valid JSON can be extracted + """ + import re + + if not text: + return None + + text = text.strip() + + if not text: + return None + + # Try parsing as pure JSON first + try: + return json.loads(text) + except json.JSONDecodeError: + pass + + # Try extracting from markdown code blocks: ```json ... ``` or ``` ... ``` + # Use the last code block if there are multiple + code_block_patterns = [ + r"```json\s*\n?(.*?)\n?```", # ```json ... ``` + r"```\s*\n?(.*?)\n?```", # ``` ... ``` + ] + for pattern in code_block_patterns: + matches = list(re.finditer(pattern, text, re.DOTALL)) + if matches: + # Try the last match first (most likely to be the final result) + for match in reversed(matches): + try: + return json.loads(match.group(1).strip()) + except json.JSONDecodeError: + continue + + # Find ALL JSON objects {...} or arrays [...] in the text and return the last valid one + # This handles cases where agents stream multiple JSON objects (partial, then final) + all_json_objects: list[Any] = [] + + pos = 0 + while pos < len(text): + # Find next { or [ + json_start = -1 + bracket_char = None + for i in range(pos, len(text)): + if text[i] == "{": + json_start = i + bracket_char = "{" + break + if text[i] == "[": + json_start = i + bracket_char = "[" + break + + if json_start < 0: + break # No more JSON objects + + # Find matching closing bracket + open_bracket = bracket_char + close_bracket = "}" if open_bracket == "{" else "]" + depth = 0 + in_string = False + escape_next = False + found_end = False + + for i in range(json_start, len(text)): + char = text[i] + + if escape_next: + escape_next = False + continue + + if char == "\\": + escape_next = True + continue + + if char == '"' and not escape_next: + in_string = not in_string + continue + + if in_string: + continue + + if char == open_bracket: + depth += 1 + elif char == close_bracket: + depth -= 1 + if depth == 0: + # Found the end + potential_json = text[json_start : i + 1] + try: + parsed = json.loads(potential_json) + all_json_objects.append(parsed) + except json.JSONDecodeError: + pass + pos = i + 1 + found_end = True + break + + if not found_end: + # Malformed JSON, move past the start character + pos = json_start + 1 + + # Return the last valid JSON object (most likely to be the final/complete result) + if all_json_objects: + return all_json_objects[-1] + + # Unable to extract JSON + raise json.JSONDecodeError("No valid JSON found in response", text, 0) + + +def _validate_conversation_history(messages: list[ChatMessage], agent_name: str) -> None: + """Validate that conversation history has matching tool calls and results. + + This helps catch issues where tool call messages are stored without their + corresponding tool result messages, which would cause API errors. + + Args: + messages: The conversation history to validate. + agent_name: Name of the agent for logging purposes. + + Logs a warning if orphaned tool calls are found. 
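+
+    Examples:
+        A minimal sketch of what triggers the warning (the ids and names are
+        illustrative):
+
+        .. code-block:: python
+
+            call = FunctionCallContent(call_id="c1", name="lookup", arguments="{}")
+            message = ChatMessage(role="assistant", contents=[call])
+            _validate_conversation_history([message], "MyAgent")
+            # Logs a warning: call_id "c1" has no matching FunctionResultContent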
+ """ + # Collect all tool call IDs and tool result IDs + tool_call_ids: set[str] = set() + tool_result_ids: set[str] = set() + + for i, msg in enumerate(messages): + if not hasattr(msg, "contents") or msg.contents is None: + continue + for content in msg.contents: + if isinstance(content, FunctionCallContent) and content.call_id: + tool_call_ids.add(content.call_id) + logger.debug( + "Agent '%s': Found tool call '%s' (id=%s) in message %d", + agent_name, + content.name, + content.call_id, + i, + ) + elif isinstance(content, FunctionResultContent) and content.call_id: + tool_result_ids.add(content.call_id) + logger.debug( + "Agent '%s': Found tool result for call_id=%s in message %d", + agent_name, + content.call_id, + i, + ) + + # Find orphaned tool calls (calls without results) + orphaned_calls = tool_call_ids - tool_result_ids + if orphaned_calls: + logger.warning( + "Agent '%s': Conversation history has %d orphaned tool call(s) without results: %s. " + "Total messages: %d, tool calls: %d, tool results: %d", + agent_name, + len(orphaned_calls), + orphaned_calls, + len(messages), + len(tool_call_ids), + len(tool_result_ids), + ) + # Log message structure for debugging + for i, msg in enumerate(messages): + role = getattr(msg, "role", "unknown") + content_types = [] + if hasattr(msg, "contents") and msg.contents: + content_types = [type(c).__name__ for c in msg.contents] + logger.warning( + "Agent '%s': Message %d - role=%s, contents=%s", + agent_name, + i, + role, + content_types, + ) + + +# Keys for agent-related state +AGENT_REGISTRY_KEY = "_agent_registry" +TOOL_REGISTRY_KEY = "_tool_registry" +# Key to store external loop state for resumption +EXTERNAL_LOOP_STATE_KEY = "_external_loop_state" + + +class AgentInvocationError(Exception): + """Raised when an agent invocation fails. + + Attributes: + agent_name: Name of the agent that failed + message: Error description + """ + + def __init__(self, agent_name: str, message: str) -> None: + self.agent_name = agent_name + super().__init__(f"Agent '{agent_name}' invocation failed: {message}") + + +@dataclass +class AgentResult: + """Result from an agent invocation.""" + + success: bool + response: str + agent_name: str + messages: list[ChatMessage] = field(default_factory=lambda: cast(list[ChatMessage], [])) + tool_calls: list[FunctionCallContent] = field(default_factory=lambda: cast(list[FunctionCallContent], [])) + error: str | None = None + + +@dataclass +class AgentExternalInputRequest: + """Request for external input during agent invocation. + + Emitted when externalLoop.when condition evaluates to true, + signaling that the workflow should yield and wait for user input. + + This is the request type used with ctx.request_info() to implement + the Yield/Resume pattern for human-in-loop workflows. + + Examples: + .. 
code-block:: python
+
+            from agent_framework import RequestInfoEvent
+            from agent_framework_declarative import (
+                AgentExternalInputRequest,
+                AgentExternalInputResponse,
+                WorkflowFactory,
+            )
+
+            factory = WorkflowFactory()
+            workflow = factory.create_workflow_from_yaml_path("hitl_workflow.yaml")
+
+
+            async def run_with_hitl():
+                # Collect responses for any external input requests emitted
+                # while the workflow runs, keyed by request_id.
+                responses: dict[str, AgentExternalInputResponse] = {}
+                async for event in workflow.run_stream("Hi, I need help."):
+                    if isinstance(event, RequestInfoEvent) and isinstance(event.data, AgentExternalInputRequest):
+                        print(f"Agent '{event.data.agent_name}' needs input:")
+                        print(f"  Response: {event.data.agent_response}")
+                        user_input = input("Your response: ")
+                        responses[event.request_id] = AgentExternalInputResponse(user_input=user_input)
+
+                # Resume the yielded workflow with the collected responses
+                # (handled by handle_external_input_response below).
+                if responses:
+                    async for event in workflow.send_responses_streaming(responses):
+                        print(event)
+    """
+
+    request_id: str
+    agent_name: str
+    agent_response: str
+    iteration: int = 0
+    messages: list[ChatMessage] = field(default_factory=lambda: cast(list[ChatMessage], []))
+    function_calls: list[FunctionCallContent] = field(default_factory=lambda: cast(list[FunctionCallContent], []))
+
+
+@dataclass
+class AgentExternalInputResponse:
+    """Response to an AgentExternalInputRequest.
+
+    Provided by the caller to resume agent execution with new user input.
+    This is the response type expected by the response_handler.
+
+    Examples:
+        .. code-block:: python
+
+            from agent_framework_declarative import AgentExternalInputResponse
+
+            # Basic response with user text input
+            response = AgentExternalInputResponse(user_input="Yes, please proceed with the order.")
+
+        .. code-block:: python
+
+            from agent_framework_declarative import AgentExternalInputResponse
+
+            # Response with additional message history
+            response = AgentExternalInputResponse(
+                user_input="Approved",
+                messages=[],  # Additional context messages if needed
+            )
+    """
+
+    user_input: str
+    messages: list[ChatMessage] = field(default_factory=lambda: cast(list[ChatMessage], []))
+    function_results: dict[str, FunctionResultContent] = field(
+        default_factory=lambda: cast(dict[str, FunctionResultContent], {})
+    )
+
+
+@dataclass
+class ExternalLoopState:
+    """State saved for external loop resumption.
+
+    Stored in shared_state to allow the response_handler to
+    continue the loop with the same configuration.
+    """
+
+    agent_name: str
+    iteration: int
+    external_loop_when: str
+    messages_var: str | None
+    response_obj_var: str | None
+    result_property: str | None
+    auto_send: bool
+    messages_path: str = "Conversation.messages"
+    max_iterations: int = 100
+
+
+def _normalize_variable_path(variable: str) -> str:
+    """Normalize variable names to ensure they have a scope prefix.
+
+    Args:
+        variable: Variable name like 'Local.X' or 'System.ConversationId'
+
+    Returns:
+        The variable path with a scope prefix (defaults to Local if none provided)
+    """
+    if variable.startswith(("Local.", "System.", "Workflow.", "Agent.", "Conversation.")):
+        # Already has a proper namespace
+        return variable
+    if "." in variable:
+        # Has some namespace, use as-is
+        return variable
+    # Default to Local scope
+    return "Local." + variable
+
+
+class InvokeAzureAgentExecutor(DeclarativeActionExecutor):
+    """Executor that invokes an Azure AI Foundry agent.
+ + This executor supports both Python-style and .NET-style YAML schemas: + + Python-style (simple): + kind: InvokeAzureAgent + agent: MenuAgent + input: =Local.userInput + resultProperty: Local.agentResponse + + .NET-style (full featured): + kind: InvokeAzureAgent + agent: + name: AgentName + conversationId: =System.ConversationId + input: + arguments: + param1: =Local.value1 + param2: literal value + messages: =Conversation.messages + externalLoop: + when: =Local.needsMoreInput + output: + messages: Local.ResponseMessages + responseObject: Local.StructuredResponse + autoSend: true + + Features: + - Structured input with arguments and messages + - External loop support for human-in-loop patterns + - Output with messages and responseObject (JSON parsing) + - AutoSend behavior control for streaming output + """ + + def __init__( + self, + action_def: dict[str, Any], + *, + id: str | None = None, + agents: dict[str, Any] | None = None, + ): + """Initialize the agent executor. + + Args: + action_def: The action definition from YAML + id: Optional executor ID + agents: Registry of agent instances by name + """ + super().__init__(action_def, id=id) + self._agents = agents or {} + + def _get_agent_name(self, state: Any) -> str | None: + """Extract agent name from action definition. + + Supports both simple string and nested object formats. + """ + agent_config = self._action_def.get("agent") + + if isinstance(agent_config, str): + return agent_config + + if isinstance(agent_config, dict): + agent_dict = cast(dict[str, Any], agent_config) + name = agent_dict.get("name") + if name is not None and isinstance(name, str): + # Support dynamic agent name from expression (would need async eval) + return str(name) + + agent_name = self._action_def.get("agentName") + return str(agent_name) if isinstance(agent_name, str) else None + + def _get_input_config(self) -> tuple[dict[str, Any], Any, str | None, int]: + """Parse input configuration. + + Returns: + Tuple of (arguments dict, messages expression, externalLoop.when expression, maxIterations) + """ + input_config = self._action_def.get("input", {}) + + if not isinstance(input_config, dict): + # Simple input - treat as message directly + return {}, input_config, None, 100 + + input_dict = cast(dict[str, Any], input_config) + arguments: dict[str, Any] = cast(dict[str, Any], input_dict.get("arguments", {})) + messages: Any = input_dict.get("messages") + + # Extract external loop configuration + external_loop_when: str | None = None + max_iterations: int = 100 # Default safety limit + external_loop = input_dict.get("externalLoop") + if isinstance(external_loop, dict): + loop_dict = cast(dict[str, Any], external_loop) + when_val = loop_dict.get("when") + external_loop_when = str(when_val) if when_val is not None else None + max_iter_val = loop_dict.get("maxIterations") + if max_iter_val is not None: + max_iterations = int(max_iter_val) + + return arguments, messages, external_loop_when, max_iterations + + def _get_output_config(self) -> tuple[str | None, str | None, str | None, bool]: + """Parse output configuration. 
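+
+        A hypothetical "output" block in the legacy property form (field names
+        match the lookups below; the variable path is illustrative):
+
+            output:
+              property: Local.AgentReply
+              autoSend: false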
+ + Returns: + Tuple of (messages var, responseObject var, resultProperty, autoSend) + """ + output_config = self._action_def.get("output", {}) + + # Legacy Python-style + result_property: str | None = cast(str | None, self._action_def.get("resultProperty")) + + if not isinstance(output_config, dict): + return None, None, result_property, True + + output_dict = cast(dict[str, Any], output_config) + messages_var_val: Any = output_dict.get("messages") + messages_var: str | None = str(messages_var_val) if messages_var_val is not None else None + response_obj_val: Any = output_dict.get("responseObject") + response_obj_var: str | None = str(response_obj_val) if response_obj_val is not None else None + property_val: Any = output_dict.get("property") + property_var: str | None = str(property_val) if property_val is not None else None + auto_send_val: Any = output_dict.get("autoSend", True) + auto_send: bool = bool(auto_send_val) + + return messages_var, response_obj_var, property_var or result_property, auto_send + + def _get_conversation_id(self) -> str | None: + """Get the conversation ID expression from action definition. + + Returns: + The conversationId expression/value, or None if not specified + """ + return self._action_def.get("conversationId") + + async def _get_conversation_messages_path( + self, state: DeclarativeWorkflowState, conversation_id_expr: str | None + ) -> str: + """Get the state path for conversation messages. + + Args: + state: Workflow state for expression evaluation + conversation_id_expr: The conversationId expression from action definition + + Returns: + State path for messages (e.g., "Conversation.messages" or "System.conversations.{id}.messages") + """ + if not conversation_id_expr: + return "Conversation.messages" + + # Evaluate the conversation ID expression + evaluated_id = await state.eval_if_expression(conversation_id_expr) + if not evaluated_id: + return "Conversation.messages" + + # Use conversation-specific messages path + return f"System.conversations.{evaluated_id}.messages" + + async def _build_input_text(self, state: Any, arguments: dict[str, Any], messages_expr: Any) -> str: + """Build input text from arguments and messages. + + Args: + state: Workflow state for expression evaluation + arguments: Input arguments to evaluate + messages_expr: Messages expression or direct input + + Returns: + Input text for the agent + """ + # Evaluate arguments + evaluated_args: dict[str, Any] = {} + for key, value in arguments.items(): + evaluated_args[key] = await state.eval_if_expression(value) + + # Evaluate messages/input + if messages_expr: + evaluated_input: Any = await state.eval_if_expression(messages_expr) + if isinstance(evaluated_input, str): + return evaluated_input + if isinstance(evaluated_input, list) and evaluated_input: + # Extract text from last message + last: Any = evaluated_input[-1] # type: ignore + if isinstance(last, str): + return last + if isinstance(last, dict): + last_dict = cast(dict[str, Any], last) + content_val: Any = last_dict.get("content", last_dict.get("text", "")) + return str(content_val) if content_val else "" + if last is not None and hasattr(last, "text"): # type: ignore + return str(getattr(last, "text", "")) # type: ignore + if evaluated_input: + return str(cast(Any, evaluated_input)) + return "" + + # Fallback chain for implicit input (like .NET conversationId pattern): + # 1. Local.input / Local.userInput (explicit turn state) + # 2. System.LastMessage.Text (previous agent's response) + # 3. 
Workflow.Inputs (first agent gets workflow inputs) + input_text: str = str(await state.get("Local.input") or await state.get("Local.userInput") or "") + if not input_text: + # Try System.LastMessage.Text (used by external loop and agent chaining) + last_message: Any = await state.get("System.LastMessage") + if isinstance(last_message, dict): + last_msg_dict = cast(dict[str, Any], last_message) + text_val: Any = last_msg_dict.get("Text", "") + input_text = str(text_val) if text_val else "" + if not input_text: + # Fall back to workflow inputs (for first agent in chain) + inputs: Any = await state.get("Workflow.Inputs") + if isinstance(inputs, dict): + inputs_dict = cast(dict[str, Any], inputs) + # If single input, use its value directly + if len(inputs_dict) == 1: + input_text = str(next(iter(inputs_dict.values()))) + else: + # Multiple inputs - format as key: value pairs + input_text = "\n".join(f"{k}: {v}" for k, v in inputs_dict.items()) + return input_text if input_text else "" + + def _get_agent(self, agent_name: str, ctx: WorkflowContext[Any, Any]) -> Any: + """Get agent from registry (sync helper for response handler).""" + return self._agents.get(agent_name) if self._agents else None + + async def _invoke_agent_and_store_results( + self, + agent: Any, + agent_name: str, + input_text: str, + state: DeclarativeWorkflowState, + ctx: WorkflowContext[ActionComplete, str], + messages_var: str | None, + response_obj_var: str | None, + result_property: str | None, + auto_send: bool, + messages_path: str = "Conversation.messages", + ) -> tuple[str, list[Any], list[Any]]: + """Invoke agent and store results in state. + + Args: + agent: The agent instance to invoke + agent_name: Name of the agent for logging + input_text: User input text + state: Workflow state + ctx: Workflow context + messages_var: Output variable for messages + response_obj_var: Output variable for parsed response object + result_property: Output property for result + auto_send: Whether to auto-send output to context + messages_path: State path for conversation messages (default: "Conversation.messages") + + Returns: + Tuple of (accumulated_response, all_messages, tool_calls) + """ + accumulated_response = "" + all_messages: list[ChatMessage] = [] + tool_calls: list[FunctionCallContent] = [] + + # Add user input to conversation history first (via state.append only) + if input_text: + user_message = ChatMessage(role="user", text=input_text) + await state.append(messages_path, user_message) + + # Get conversation history from state AFTER adding user message + # Note: We get a fresh copy to avoid mutation issues + conversation_history: list[ChatMessage] = await state.get(messages_path) or [] + + # Build messages list for agent (use history if available, otherwise just input) + messages_for_agent: list[ChatMessage] | str = conversation_history if conversation_history else input_text + + # Validate conversation history before invoking agent + if isinstance(messages_for_agent, list) and messages_for_agent: + _validate_conversation_history(messages_for_agent, agent_name) + + # Use run() method to get properly structured messages (including tool calls and results) + # This is critical for multi-turn conversations where tool calls must be followed + # by their results in the message history + if hasattr(agent, "run"): + result: Any = await agent.run(messages_for_agent) + if hasattr(result, "text") and result.text: + accumulated_response = str(result.text) + if auto_send: + await ctx.yield_output(str(result.text)) + elif 
isinstance(result, str): + accumulated_response = result + if auto_send: + await ctx.yield_output(result) + + if not isinstance(result, str): + result_messages: Any = getattr(result, "messages", None) + if result_messages is not None: + all_messages = list(cast(list[ChatMessage], result_messages)) + result_tool_calls: Any = getattr(result, "tool_calls", None) + if result_tool_calls is not None: + tool_calls = list(cast(list[FunctionCallContent], result_tool_calls)) + + else: + raise RuntimeError(f"Agent '{agent_name}' has no run or run_stream method") + + # Add messages to conversation history + # We need to include ALL messages from the agent run (including tool calls and tool results) + # to maintain proper conversation state for the next agent invocation + if all_messages: + # Agent returned full message history - use it + logger.debug( + "Agent '%s': Storing %d messages to conversation history at '%s'", + agent_name, + len(all_messages), + messages_path, + ) + for i, msg in enumerate(all_messages): + role = getattr(msg, "role", "unknown") + content_types = [] + if hasattr(msg, "contents") and msg.contents: + content_types = [type(c).__name__ for c in msg.contents] + logger.debug( + "Agent '%s': Storing message %d - role=%s, contents=%s", + agent_name, + i, + role, + content_types, + ) + await state.append(messages_path, msg) + elif accumulated_response: + # No messages returned, create a simple assistant message + logger.debug( + "Agent '%s': No messages in response, creating simple assistant message", + agent_name, + ) + assistant_message = ChatMessage(role="assistant", text=accumulated_response) + await state.append(messages_path, assistant_message) + + # Store results in state - support both schema formats: + # - Graph mode: Agent.response, Agent.name + # - Interpreter mode: Agent.text, Agent.messages, Agent.toolCalls + await state.set("Agent.response", accumulated_response) + await state.set("Agent.name", agent_name) + await state.set("Agent.text", accumulated_response) + await state.set("Agent.messages", all_messages if all_messages else []) + await state.set("Agent.toolCalls", tool_calls if tool_calls else []) + + # Store System.LastMessage for externalLoop.when condition evaluation + await state.set("System.LastMessage", {"Text": accumulated_response}) + + # Store in output variables (.NET style) + if messages_var: + output_path = _normalize_variable_path(messages_var) + await state.set(output_path, all_messages if all_messages else accumulated_response) + + if response_obj_var: + output_path = _normalize_variable_path(response_obj_var) + # Try to extract and parse JSON from the response + try: + parsed = _extract_json_from_response(accumulated_response) if accumulated_response else None + logger.debug(f"InvokeAzureAgent: parsed responseObject for '{output_path}': type={type(parsed)}") + await state.set(output_path, parsed) + except (json.JSONDecodeError, TypeError) as e: + logger.warning(f"InvokeAzureAgent: failed to parse JSON for '{output_path}': {e}, storing as string") + await state.set(output_path, accumulated_response) + + # Store in result property (Python style) + if result_property: + await state.set(result_property, accumulated_response) + + return accumulated_response, all_messages, tool_calls + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete, str], + ) -> None: + """Handle the agent invocation with full .NET feature parity. 
+ + When externalLoop.when is configured and evaluates to true after agent response, + this method emits an ExternalInputRequest via ctx.request_info() and returns. + The workflow will yield, and when the caller provides a response via + send_responses_streaming(), the handle_external_input_response handler + will continue the loop. + """ + state = await self._ensure_state_initialized(ctx, trigger) + + # Parse configuration + agent_name = self._get_agent_name(state) + if not agent_name: + logger.warning("InvokeAzureAgent action missing 'agent' or 'agent.name' property") + await ctx.send_message(ActionComplete()) + return + + logger.debug("handle_action: starting agent '%s'", agent_name) + + arguments, messages_expr, external_loop_when, max_iterations = self._get_input_config() + messages_var, response_obj_var, result_property, auto_send = self._get_output_config() + + # Get conversation-specific messages path if conversationId is specified + conversation_id_expr = self._get_conversation_id() + messages_path = await self._get_conversation_messages_path(state, conversation_id_expr) + logger.debug("handle_action: agent='%s', messages_path='%s'", agent_name, messages_path) + + # Build input + input_text = await self._build_input_text(state, arguments, messages_expr) + + # Get agent from registry + agent: Any = self._agents.get(agent_name) if self._agents else None + if agent is None: + try: + agent_registry: dict[str, Any] | None = await ctx.shared_state.get(AGENT_REGISTRY_KEY) + except KeyError: + agent_registry = {} + agent = agent_registry.get(agent_name) if agent_registry else None + + if agent is None: + error_msg = f"Agent '{agent_name}' not found in registry" + logger.error(f"InvokeAzureAgent: {error_msg}") + await state.set("Agent.error", error_msg) + if result_property: + await state.set(result_property, {"error": error_msg}) + raise AgentInvocationError(agent_name, "not found in registry") + + iteration = 0 + + try: + accumulated_response, all_messages, tool_calls = await self._invoke_agent_and_store_results( + agent=agent, + agent_name=agent_name, + input_text=input_text, + state=state, + ctx=ctx, + messages_var=messages_var, + response_obj_var=response_obj_var, + result_property=result_property, + auto_send=auto_send, + messages_path=messages_path, + ) + except AgentInvocationError: + raise # Re-raise our own errors + except Exception as e: + logger.error(f"InvokeAzureAgent: error invoking agent '{agent_name}': {e}") + await state.set("Agent.error", str(e)) + if result_property: + await state.set(result_property, {"error": str(e)}) + raise AgentInvocationError(agent_name, str(e)) from e + + # Check external loop condition + if external_loop_when: + should_continue = await state.eval(external_loop_when) + should_continue = bool(should_continue) if should_continue is not None else False + + logger.debug( + f"InvokeAzureAgent: external loop condition '{str(external_loop_when)[:50]}' = " + f"{should_continue} (iteration {iteration})" + ) + + if should_continue: + # Save loop state for resumption + loop_state = ExternalLoopState( + agent_name=agent_name, + iteration=iteration + 1, + external_loop_when=external_loop_when, + messages_var=messages_var, + response_obj_var=response_obj_var, + result_property=result_property, + auto_send=auto_send, + messages_path=messages_path, + max_iterations=max_iterations, + ) + await ctx.shared_state.set(EXTERNAL_LOOP_STATE_KEY, loop_state) + + # Emit request for external input - workflow will yield here + request = AgentExternalInputRequest( + 
request_id=str(uuid.uuid4()), + agent_name=agent_name, + agent_response=accumulated_response, + iteration=iteration, + messages=all_messages, + function_calls=tool_calls, + ) + logger.info(f"InvokeAzureAgent: yielding for external input (iteration {iteration})") + await ctx.request_info(request, AgentExternalInputResponse) + # Return without sending ActionComplete - workflow yields + return + + # No external loop or condition is false - complete the action + await ctx.send_message(ActionComplete()) + + @response_handler + async def handle_external_input_response( + self, + original_request: AgentExternalInputRequest, + response: AgentExternalInputResponse, + ctx: WorkflowContext[ActionComplete, str], + ) -> None: + """Handle response to an ExternalInputRequest and continue the loop. + + This is called when the workflow resumes after yielding for external input. + It continues the agent invocation loop with the user's new input. + """ + logger.debug( + "handle_external_input_response: resuming with user_input='%s'", + response.user_input[:100] if response.user_input else None, + ) + state = self._get_state(ctx.shared_state) + + # Retrieve saved loop state + try: + loop_state: ExternalLoopState = await ctx.shared_state.get(EXTERNAL_LOOP_STATE_KEY) + except KeyError: + logger.error("InvokeAzureAgent: external loop state not found, cannot resume") + await ctx.send_message(ActionComplete()) + return + + agent_name = loop_state.agent_name + iteration = loop_state.iteration + external_loop_when = loop_state.external_loop_when + max_iterations = loop_state.max_iterations + messages_path = loop_state.messages_path + + logger.debug( + "handle_external_input_response: agent='%s', iteration=%d, messages_path='%s'", + agent_name, + iteration, + messages_path, + ) + + # Get the user's new input + input_text = response.user_input + + # Store the user input in state for condition evaluation + await state.set("Local.userInput", input_text) + await state.set("System.LastMessage", {"Text": input_text}) + + # Check if we should continue BEFORE invoking the agent + # This matches .NET behavior where the condition checks the user's input + should_continue = await state.eval(external_loop_when) + should_continue = bool(should_continue) if should_continue is not None else False + + logger.debug( + f"InvokeAzureAgent: external loop condition '{str(external_loop_when)[:50]}' = " + f"{should_continue} (iteration {iteration}) for input '{input_text[:30]}...'" + ) + + if not should_continue: + # User input caused loop to exit - clean up and complete + with contextlib.suppress(KeyError): + await ctx.shared_state.delete(EXTERNAL_LOOP_STATE_KEY) + await ctx.send_message(ActionComplete()) + return + + # Get agent from registry + agent: Any = self._agents.get(agent_name) if self._agents else None + if agent is None: + try: + agent_registry: dict[str, Any] | None = await ctx.shared_state.get(AGENT_REGISTRY_KEY) + except KeyError: + agent_registry = {} + agent = agent_registry.get(agent_name) if agent_registry else None + + if agent is None: + logger.error(f"InvokeAzureAgent: agent '{agent_name}' not found during loop resumption") + raise AgentInvocationError(agent_name, "not found during loop resumption") + + try: + accumulated_response, all_messages, tool_calls = await self._invoke_agent_and_store_results( + agent=agent, + agent_name=agent_name, + input_text=input_text, + state=state, + ctx=ctx, + messages_var=loop_state.messages_var, + response_obj_var=loop_state.response_obj_var, + 
result_property=loop_state.result_property, + auto_send=loop_state.auto_send, + messages_path=loop_state.messages_path, + ) + except AgentInvocationError: + raise # Re-raise our own errors + except Exception as e: + logger.error(f"InvokeAzureAgent: error invoking agent '{agent_name}' during loop: {e}") + await state.set("Agent.error", str(e)) + raise AgentInvocationError(agent_name, str(e)) from e + + # Re-evaluate the condition AFTER the agent responds + # This is critical: the agent's response may have set NeedsTicket=true or IsResolved=true + should_continue = await state.eval(external_loop_when) + should_continue = bool(should_continue) if should_continue is not None else False + + logger.debug( + f"InvokeAzureAgent: external loop condition after response '{str(external_loop_when)[:50]}' = " + f"{should_continue} (iteration {iteration})" + ) + + if not should_continue: + # Agent response caused loop to exit (e.g., NeedsTicket=true or IsResolved=true) + logger.info( + "InvokeAzureAgent: external loop exited due to condition=false " + "(sending ActionComplete to continue workflow)" + ) + with contextlib.suppress(KeyError): + await ctx.shared_state.delete(EXTERNAL_LOOP_STATE_KEY) + await ctx.send_message(ActionComplete()) + return + + # Continue the loop - condition still true + if iteration < max_iterations: + # Update loop state for next iteration + loop_state.iteration = iteration + 1 + await ctx.shared_state.set(EXTERNAL_LOOP_STATE_KEY, loop_state) + + # Emit another request for external input + request = AgentExternalInputRequest( + request_id=str(uuid.uuid4()), + agent_name=agent_name, + agent_response=accumulated_response, + iteration=iteration, + messages=all_messages, + function_calls=tool_calls, + ) + logger.info(f"InvokeAzureAgent: yielding for external input (iteration {iteration})") + await ctx.request_info(request, AgentExternalInputResponse) + return + + logger.warning(f"InvokeAzureAgent: external loop exceeded max iterations ({max_iterations})") + + # Loop complete - clean up and send completion + with contextlib.suppress(KeyError): + await ctx.shared_state.delete(EXTERNAL_LOOP_STATE_KEY) + + await ctx.send_message(ActionComplete()) + + +class InvokeToolExecutor(DeclarativeActionExecutor): + """Executor that invokes a registered tool/function. + + Tools are simpler than agents - they take input, perform an action, + and return a result synchronously (or with a simple async call). 
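+
+    YAML example (hypothetical; the field names are taken from this executor's
+    parser, and the tool name and variable paths are illustrative):
+
+        - kind: InvokeTool
+          tool: lookupOrder
+          parameters:
+            orderId: =Local.OrderId
+          output:
+            property: Local.OrderDetails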
+ """ + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the tool invocation.""" + state = await self._ensure_state_initialized(ctx, trigger) + + tool_name = self._action_def.get("tool") or self._action_def.get("toolName", "") + input_expr = self._action_def.get("input") + output_property = self._action_def.get("output", {}).get("property") or self._action_def.get("resultProperty") + parameters = self._action_def.get("parameters", {}) + + # Get tools registry + try: + tool_registry: dict[str, Any] | None = await ctx.shared_state.get(TOOL_REGISTRY_KEY) + except KeyError: + tool_registry = {} + + tool: Any = tool_registry.get(tool_name) if tool_registry else None + + if tool is None: + error_msg = f"Tool '{tool_name}' not found in registry" + if output_property: + await state.set(output_property, {"error": error_msg}) + await ctx.send_message(ActionComplete()) + return + + # Build parameters + params: dict[str, Any] = {} + for param_name, param_expression in parameters.items(): + params[param_name] = await state.eval_if_expression(param_expression) + + # Add main input if specified + if input_expr: + params["input"] = await state.eval_if_expression(input_expr) + + try: + # Invoke the tool + if callable(tool): + from inspect import isawaitable + + result = tool(**params) + if isawaitable(result): + result = await result + + # Store result + if output_property: + await state.set(output_property, result) + + except Exception as e: + if output_property: + await state.set(output_property, {"error": str(e)}) + await ctx.send_message(ActionComplete()) + return + + await ctx.send_message(ActionComplete()) + + +# Mapping of agent action kinds to executor classes +AGENT_ACTION_EXECUTORS: dict[str, type[DeclarativeActionExecutor]] = { + "InvokeAzureAgent": InvokeAzureAgentExecutor, + "InvokeTool": InvokeToolExecutor, +} diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py new file mode 100644 index 0000000000..6603357478 --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py @@ -0,0 +1,575 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Basic action executors for the graph-based declarative workflow system. + +These executors handle simple actions like SetValue, SendActivity, etc. +Each action becomes a node in the workflow graph. +""" + +from typing import Any + +from agent_framework._workflows import ( + WorkflowContext, + handler, +) + +from ._declarative_base import ( + ActionComplete, + DeclarativeActionExecutor, +) + + +def _get_variable_path(action_def: dict[str, Any], key: str = "variable") -> str | None: + """Extract variable path from action definition. + + Supports .NET style (variable: Local.VarName) and nested object style (variable: {path: ...}). + """ + variable = action_def.get(key) + if isinstance(variable, str): + return variable + if isinstance(variable, dict): + return variable.get("path") + return action_def.get("path") + + +class SetValueExecutor(DeclarativeActionExecutor): + """Executor for the SetValue action. + + Sets a value in the workflow state at a specified path. 
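+
+    YAML example (hypothetical; "path" and "value" are the fields this
+    executor reads, and the values are illustrative):
+
+        - kind: SetValue
+          path: Local.Greeting
+          value: =Local.Name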
+ """ + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the SetValue action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + path = self._action_def.get("path") + value = self._action_def.get("value") + + if path: + # Evaluate value if it's an expression + evaluated_value = await state.eval_if_expression(value) + await state.set(path, evaluated_value) + + await ctx.send_message(ActionComplete()) + + +class SetVariableExecutor(DeclarativeActionExecutor): + """Executor for the SetVariable action (.NET style naming).""" + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the SetVariable action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + path = _get_variable_path(self._action_def) + value = self._action_def.get("value") + + if path: + evaluated_value = await state.eval_if_expression(value) + await state.set(path, evaluated_value) + + await ctx.send_message(ActionComplete()) + + +class SetTextVariableExecutor(DeclarativeActionExecutor): + """Executor for the SetTextVariable action.""" + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the SetTextVariable action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + path = _get_variable_path(self._action_def) + text = self._action_def.get("text", "") + + if path: + evaluated_text = await state.eval_if_expression(text) + await state.set(path, str(evaluated_text) if evaluated_text is not None else "") + + await ctx.send_message(ActionComplete()) + + +class SetMultipleVariablesExecutor(DeclarativeActionExecutor): + """Executor for the SetMultipleVariables action.""" + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the SetMultipleVariables action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + assignments = self._action_def.get("assignments", []) + for assignment in assignments: + variable = assignment.get("variable") + path: str | None + if isinstance(variable, str): + path = variable + elif isinstance(variable, dict): + path = variable.get("path") + else: + path = assignment.get("path") + value = assignment.get("value") + if path: + evaluated_value = await state.eval_if_expression(value) + await state.set(path, evaluated_value) + + await ctx.send_message(ActionComplete()) + + +class AppendValueExecutor(DeclarativeActionExecutor): + """Executor for the AppendValue action.""" + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the AppendValue action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + path = self._action_def.get("path") + value = self._action_def.get("value") + + if path: + evaluated_value = await state.eval_if_expression(value) + await state.append(path, evaluated_value) + + await ctx.send_message(ActionComplete()) + + +class ResetVariableExecutor(DeclarativeActionExecutor): + """Executor for the ResetVariable action.""" + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the ResetVariable action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + path = _get_variable_path(self._action_def) + + if path: + # Reset to None/empty + await state.set(path, None) + + 
await ctx.send_message(ActionComplete()) + + +class ClearAllVariablesExecutor(DeclarativeActionExecutor): + """Executor for the ClearAllVariables action.""" + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the ClearAllVariables action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + # Get state data and clear Local variables + state_data = await state.get_state_data() + state_data["Local"] = {} + await state.set_state_data(state_data) + + await ctx.send_message(ActionComplete()) + + +class SendActivityExecutor(DeclarativeActionExecutor): + """Executor for the SendActivity action. + + Sends a text message or activity as workflow output. + """ + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete, str], + ) -> None: + """Handle the SendActivity action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + activity = self._action_def.get("activity", "") + + # Activity can be a string directly or a dict with a "text" field + text = activity.get("text", "") if isinstance(activity, dict) else activity + + if isinstance(text, str): + # First evaluate any =expression syntax + text = await state.eval_if_expression(text) + # Then interpolate any {Variable.Path} template syntax + if isinstance(text, str): + text = await state.interpolate_string(text) + + # Yield the text as workflow output + if text: + await ctx.yield_output(str(text)) + + await ctx.send_message(ActionComplete()) + + +class EmitEventExecutor(DeclarativeActionExecutor): + """Executor for the EmitEvent action. + + Emits a custom event to the workflow event stream. + + Supports two schema formats: + 1. Graph mode: eventName, eventValue + 2. Interpreter mode: event.name, event.data + """ + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete, dict[str, Any]], + ) -> None: + """Handle the EmitEvent action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + # Support both schema formats: + # - Graph mode: eventName, eventValue + # - Interpreter mode: event.name, event.data + event_def = self._action_def.get("event", {}) + event_name = self._action_def.get("eventName") or event_def.get("name", "") + event_value = self._action_def.get("eventValue") + if event_value is None: + event_value = event_def.get("data") + + if event_name: + evaluated_name = await state.eval_if_expression(event_name) + evaluated_value = await state.eval_if_expression(event_value) + + event_data = { + "eventName": evaluated_name, + "eventValue": evaluated_value, + } + await ctx.yield_output(event_data) + + await ctx.send_message(ActionComplete()) + + +class EditTableExecutor(DeclarativeActionExecutor): + """Executor for the EditTable action. + + Performs operations on a table (list) variable such as add, remove, or clear. + This is equivalent to the .NET EditTable action. 
+ + YAML example: + - kind: EditTable + table: Local.Items + operation: add # add, remove, clear + value: =Local.NewItem + index: 0 # optional, for insert at position + """ + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the EditTable action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + table_path = self._action_def.get("table") or _get_variable_path(self._action_def, "variable") + operation = self._action_def.get("operation", "add").lower() + value = self._action_def.get("value") + index = self._action_def.get("index") + + if table_path: + # Get current table value + current_table = await state.get(table_path) + if current_table is None: + current_table = [] + elif not isinstance(current_table, list): + current_table = [current_table] + + if operation == "add" or operation == "insert": + evaluated_value = await state.eval_if_expression(value) + if index is not None: + evaluated_index = await state.eval_if_expression(index) + idx = int(evaluated_index) if evaluated_index is not None else len(current_table) + current_table.insert(idx, evaluated_value) + else: + current_table.append(evaluated_value) + + elif operation == "remove": + if value is not None: + # Remove by value + evaluated_value = await state.eval_if_expression(value) + if evaluated_value in current_table: + current_table.remove(evaluated_value) + elif index is not None: + # Remove by index + evaluated_index = await state.eval_if_expression(index) + idx = int(evaluated_index) if evaluated_index is not None else -1 + if 0 <= idx < len(current_table): + current_table.pop(idx) + + elif operation == "clear": + current_table = [] + + elif operation == "set" or operation == "update": + # Update item at index + if index is not None: + evaluated_value = await state.eval_if_expression(value) + evaluated_index = await state.eval_if_expression(index) + idx = int(evaluated_index) if evaluated_index is not None else 0 + if 0 <= idx < len(current_table): + current_table[idx] = evaluated_value + + await state.set(table_path, current_table) + + await ctx.send_message(ActionComplete()) + + +class EditTableV2Executor(DeclarativeActionExecutor): + """Executor for the EditTableV2 action. + + Enhanced table editing with more operations and better record support. + This is equivalent to the .NET EditTableV2 action. 
+ + YAML example: + - kind: EditTableV2 + table: Local.Records + operation: addOrUpdate # add, remove, clear, addOrUpdate, filter + item: =Local.NewRecord + key: id # for addOrUpdate, the field to match on + condition: =item.status = "active" # for filter operation + """ + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the EditTableV2 action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + table_path = self._action_def.get("table") or _get_variable_path(self._action_def, "variable") + operation = self._action_def.get("operation", "add").lower() + item = self._action_def.get("item") or self._action_def.get("value") + key_field = self._action_def.get("key") + index = self._action_def.get("index") + + if table_path: + # Get current table value + current_table = await state.get(table_path) + if current_table is None: + current_table = [] + elif not isinstance(current_table, list): + current_table = [current_table] + + if operation == "add": + evaluated_item = await state.eval_if_expression(item) + if index is not None: + evaluated_index = await state.eval_if_expression(index) + idx = int(evaluated_index) if evaluated_index is not None else len(current_table) + current_table.insert(idx, evaluated_item) + else: + current_table.append(evaluated_item) + + elif operation == "remove": + if item is not None: + evaluated_item = await state.eval_if_expression(item) + if key_field and isinstance(evaluated_item, dict): + # Remove by key match + key_value = evaluated_item.get(key_field) + current_table = [ + r for r in current_table if not (isinstance(r, dict) and r.get(key_field) == key_value) + ] + elif evaluated_item in current_table: + current_table.remove(evaluated_item) + elif index is not None: + evaluated_index = await state.eval_if_expression(index) + idx = int(evaluated_index) if evaluated_index is not None else -1 + if 0 <= idx < len(current_table): + current_table.pop(idx) + + elif operation == "clear": + current_table = [] + + elif operation == "addorupdate": + evaluated_item = await state.eval_if_expression(item) + if key_field and isinstance(evaluated_item, dict): + key_value = evaluated_item.get(key_field) + # Find existing item with same key + found_idx = -1 + for i, r in enumerate(current_table): + if isinstance(r, dict) and r.get(key_field) == key_value: + found_idx = i + break + if found_idx >= 0: + # Update existing + current_table[found_idx] = evaluated_item + else: + # Add new + current_table.append(evaluated_item) + else: + # No key field - just add + current_table.append(evaluated_item) + + elif operation == "update": + evaluated_item = await state.eval_if_expression(item) + if index is not None: + evaluated_index = await state.eval_if_expression(index) + idx = int(evaluated_index) if evaluated_index is not None else 0 + if 0 <= idx < len(current_table): + current_table[idx] = evaluated_item + elif key_field and isinstance(evaluated_item, dict): + key_value = evaluated_item.get(key_field) + for i, r in enumerate(current_table): + if isinstance(r, dict) and r.get(key_field) == key_value: + current_table[i] = evaluated_item + break + + await state.set(table_path, current_table) + + await ctx.send_message(ActionComplete()) + + +class ParseValueExecutor(DeclarativeActionExecutor): + """Executor for the ParseValue action. + + Parses a value expression and optionally converts it to a target type. + This is equivalent to the .NET ParseValue action. 
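+
+    Type conversion is performed by _convert_to_type below: strings are parsed
+    into numbers, booleans, or JSON objects/arrays where possible, with safe
+    fallbacks (e.g., a non-JSON string becomes {"value": ...} or [value]).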
+ + YAML example: + - kind: ParseValue + variable: Local.ParsedData + value: =System.LastMessage.Text + valueType: object # optional: string, number, boolean, object, array + """ + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the ParseValue action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + path = _get_variable_path(self._action_def) + value = self._action_def.get("value") + value_type = self._action_def.get("valueType") + + if path and value is not None: + # Evaluate the value expression + evaluated_value = await state.eval_if_expression(value) + + # Convert to target type if specified + if value_type: + evaluated_value = self._convert_to_type(evaluated_value, value_type) + + await state.set(path, evaluated_value) + + await ctx.send_message(ActionComplete()) + + def _convert_to_type(self, value: Any, target_type: str) -> Any: + """Convert a value to the specified target type. + + Args: + value: The value to convert + target_type: Target type (string, number, boolean, object, array) + + Returns: + The converted value + """ + import json + + target_type = target_type.lower() + + if target_type == "string": + if value is None: + return "" + return str(value) + + if target_type in ("number", "int", "integer", "float", "decimal"): + if value is None: + return 0 + if isinstance(value, str): + # Try to parse as number + try: + if "." in value: + return float(value) + return int(value) + except ValueError: + return 0 + return float(value) if isinstance(value, (int, float)) else 0 + + if target_type in ("boolean", "bool"): + if value is None: + return False + if isinstance(value, str): + return value.lower() in ("true", "yes", "1", "on") + return bool(value) + + if target_type in ("object", "record"): + if value is None: + return {} + if isinstance(value, dict): + return value + if isinstance(value, str): + try: + parsed = json.loads(value) + return parsed if isinstance(parsed, dict) else {"value": parsed} + except json.JSONDecodeError: + return {"value": value} + return {"value": value} + + if target_type in ("array", "table", "list"): + if value is None: + return [] + if isinstance(value, list): + return value + if isinstance(value, str): + try: + parsed = json.loads(value) + return parsed if isinstance(parsed, list) else [parsed] + except json.JSONDecodeError: + return [value] + return [value] + + # Unknown type - return as-is + return value + + +# Mapping of action kinds to executor classes +BASIC_ACTION_EXECUTORS: dict[str, type[DeclarativeActionExecutor]] = { + "SetValue": SetValueExecutor, + "SetVariable": SetVariableExecutor, + "SetTextVariable": SetTextVariableExecutor, + "SetMultipleVariables": SetMultipleVariablesExecutor, + "AppendValue": AppendValueExecutor, + "ResetVariable": ResetVariableExecutor, + "ClearAllVariables": ClearAllVariablesExecutor, + "SendActivity": SendActivityExecutor, + "EmitEvent": EmitEventExecutor, + "ParseValue": ParseValueExecutor, + "EditTable": EditTableExecutor, + "EditTableV2": EditTableV2Executor, +} diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_control_flow.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_control_flow.py new file mode 100644 index 0000000000..48aeabb58b --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_control_flow.py @@ -0,0 +1,548 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Control flow executors for the graph-based declarative workflow system. + +Control flow in the graph-based system is handled differently than the interpreter: +- If/Switch: Condition evaluation happens in a dedicated evaluator executor that + returns a ConditionResult with the first-matching branch index. Edge conditions + then check the branch_index to route to the correct branch. This ensures only + one branch executes (first-match semantics), matching the interpreter behavior. +- Foreach: Loop iteration state managed in SharedState + loop edges +- Goto: Edge to target action (handled by builder) +- Break/Continue: Special signals for loop control + +The key insight is that control flow becomes GRAPH STRUCTURE, not executor logic. +""" + +from typing import Any, cast + +from agent_framework._workflows import ( + WorkflowContext, + handler, +) + +from ._declarative_base import ( + ActionComplete, + ActionTrigger, + ConditionResult, + DeclarativeActionExecutor, + LoopControl, + LoopIterationResult, +) + +# Keys for loop state in SharedState +LOOP_STATE_KEY = "_declarative_loop_state" + +# Index value indicating the else/default branch +ELSE_BRANCH_INDEX = -1 + + +class ConditionGroupEvaluatorExecutor(DeclarativeActionExecutor): + """Evaluates conditions for ConditionGroup/Switch and outputs the first-matching branch. + + This executor implements first-match semantics by evaluating conditions sequentially + and outputting a ConditionResult with the index of the first matching branch. + Edge conditions downstream check this index to route to the correct branch. + + This mirrors .NET's ConditionGroupExecutor.ExecuteAsync which returns the step ID + of the first matching condition. + """ + + def __init__( + self, + action_def: dict[str, Any], + conditions: list[dict[str, Any]], + *, + id: str | None = None, + ): + """Initialize the condition evaluator. + + Args: + action_def: The ConditionGroup/Switch action definition + conditions: List of condition items, each with 'condition' and optional 'id' + id: Optional executor ID + """ + super().__init__(action_def, id=id) + self._conditions = conditions + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ConditionResult], + ) -> None: + """Evaluate conditions and output the first matching branch index.""" + state = await self._ensure_state_initialized(ctx, trigger) + + # Evaluate conditions sequentially - first match wins + for index, cond_item in enumerate(self._conditions): + condition_expr = cond_item.get("condition") + if condition_expr is None: + continue + + # Normalize boolean conditions + if condition_expr is True: + condition_expr = "=true" + elif condition_expr is False: + condition_expr = "=false" + elif isinstance(condition_expr, str) and not condition_expr.startswith("="): + condition_expr = f"={condition_expr}" + + result = await state.eval(condition_expr) + if bool(result): + # First matching condition found + await ctx.send_message(ConditionResult(matched=True, branch_index=index, value=result)) + return + + # No condition matched - use else/default branch + await ctx.send_message(ConditionResult(matched=False, branch_index=ELSE_BRANCH_INDEX)) + + +class SwitchEvaluatorExecutor(DeclarativeActionExecutor): + """Evaluates a Switch action by matching a value against cases. 
+ + The Switch action uses a different schema than ConditionGroup: + - value: expression to evaluate once + - cases: list of {match: value_to_match, actions: [...]} + - default: default actions if no case matches + + This evaluator evaluates the value expression once, then compares it + against each case's match value sequentially. First match wins. + """ + + def __init__( + self, + action_def: dict[str, Any], + cases: list[dict[str, Any]], + *, + id: str | None = None, + ): + """Initialize the switch evaluator. + + Args: + action_def: The Switch action definition (contains 'value' expression) + cases: List of case items, each with 'match' and optional 'actions' + id: Optional executor ID + """ + super().__init__(action_def, id=id) + self._cases = cases + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ConditionResult], + ) -> None: + """Evaluate the switch value and find the first matching case.""" + state = await self._ensure_state_initialized(ctx, trigger) + + value_expr = self._action_def.get("value") + if not value_expr: + # No value to switch on - use default + await ctx.send_message(ConditionResult(matched=False, branch_index=ELSE_BRANCH_INDEX)) + return + + # Evaluate the switch value once + switch_value = await state.eval_if_expression(value_expr) + + # Compare against each case's match value + for index, case_item in enumerate(self._cases): + match_expr = case_item.get("match") + if match_expr is None: + continue + + # Evaluate the match value + match_value = await state.eval_if_expression(match_expr) + + if switch_value == match_value: + # Found matching case + await ctx.send_message(ConditionResult(matched=True, branch_index=index, value=switch_value)) + return + + # No case matched - use default branch + await ctx.send_message(ConditionResult(matched=False, branch_index=ELSE_BRANCH_INDEX)) + + +class IfConditionEvaluatorExecutor(DeclarativeActionExecutor): + """Evaluates a single If condition and outputs a ConditionResult. + + This is simpler than ConditionGroupEvaluator - just evaluates one condition + and outputs branch_index=0 (then) or branch_index=-1 (else). + """ + + def __init__( + self, + action_def: dict[str, Any], + condition_expr: str, + *, + id: str | None = None, + ): + """Initialize the if condition evaluator. + + Args: + action_def: The If action definition + condition_expr: The condition expression to evaluate + id: Optional executor ID + """ + super().__init__(action_def, id=id) + self._condition_expr = condition_expr + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ConditionResult], + ) -> None: + """Evaluate the condition and output the result.""" + state = await self._ensure_state_initialized(ctx, trigger) + + result = await state.eval(self._condition_expr) + is_truthy = bool(result) + + if is_truthy: + await ctx.send_message(ConditionResult(matched=True, branch_index=0, value=result)) + else: + await ctx.send_message(ConditionResult(matched=False, branch_index=ELSE_BRANCH_INDEX, value=result)) + + +class ForeachInitExecutor(DeclarativeActionExecutor): + """Initializes a foreach loop. + + Sets up the loop state in SharedState and determines if there are items. 
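+
+    YAML example (illustrative; both schema spellings handled below are accepted):
+
+    - kind: Foreach
+      itemsSource: =Local.Items      # graph mode; interpreter mode uses "source"
+      iteratorVariable: Local.item   # interpreter mode: itemName (defaults to "item")
+      indexVariable: Local.index     # interpreter mode: indexName (defaults to "index")
+      actions:
+        - kind: SendActivity
+          activity: =Local.item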
+    """
+
+    @handler
+    async def handle_action(
+        self,
+        trigger: Any,
+        ctx: WorkflowContext[LoopIterationResult],
+    ) -> None:
+        """Initialize the loop and check for first item."""
+        state = await self._ensure_state_initialized(ctx, trigger)
+
+        # Support multiple schema formats:
+        # - Graph mode: itemsSource, items
+        # - Interpreter mode: source
+        items_expr = (
+            self._action_def.get("itemsSource") or self._action_def.get("items") or self._action_def.get("source")
+        )
+        items_raw: Any = await state.eval_if_expression(items_expr) or []
+
+        # Normalize to a list: accept lists/tuples directly, materialize other
+        # iterables, and wrap strings or scalars as a single item.
+        items: list[Any]
+        if isinstance(items_raw, (list, tuple)):
+            items = list(items_raw)
+        elif not items_raw:
+            items = []
+        elif isinstance(items_raw, str):
+            # A bare string is a single item, not a sequence of characters
+            items = [items_raw]
+        else:
+            try:
+                items = list(items_raw)
+            except TypeError:
+                items = [items_raw]
+
+        loop_id = self.id
+
+        # Store loop state
+        state_data = await state.get_state_data()
+        loop_states: dict[str, Any] = cast(dict[str, Any], state_data).setdefault(LOOP_STATE_KEY, {})
+        loop_states[loop_id] = {
+            "items": items,
+            "index": 0,
+            "length": len(items),
+        }
+        await state.set_state_data(state_data)
+
+        # Check if we have items
+        if items:
+            # Set the iteration variable
+            # Support multiple schema formats:
+            # - Graph mode: iteratorVariable, item (default "Local.item")
+            # - Interpreter mode: itemName (default "item", stored in Local scope)
+            item_var = self._action_def.get("iteratorVariable") or self._action_def.get("item")
+            if not item_var:
+                # Interpreter mode: itemName defaults to "item", store in Local scope
+                item_name = self._action_def.get("itemName", "item")
+                item_var = f"Local.{item_name}"
+
+            # Support multiple schema formats for index:
+            # - Graph mode: indexVariable, index
+            # - Interpreter mode: indexName (default "index", stored in Local scope)
+            index_var = self._action_def.get("indexVariable") or self._action_def.get("index")
+            if not index_var and "indexName" in self._action_def:
+                index_name = self._action_def.get("indexName", "index")
+                index_var = f"Local.{index_name}"
+
+            await state.set(item_var, items[0])
+            if index_var:
+                await state.set(index_var, 0)
+
+            await ctx.send_message(LoopIterationResult(has_next=True, current_item=items[0], current_index=0))
+        else:
+            await ctx.send_message(LoopIterationResult(has_next=False))
+
+
+class ForeachNextExecutor(DeclarativeActionExecutor):
+    """Advances to the next item in a foreach loop.
+
+    This executor is triggered after the loop body completes.
+    """
+
+    def __init__(
+        self,
+        action_def: dict[str, Any],
+        init_executor_id: str,
+        *,
+        id: str | None = None,
+    ):
+        """Initialize with reference to the init executor.
+ + Args: + action_def: The Foreach action definition + init_executor_id: ID of the corresponding ForeachInitExecutor + id: Optional executor ID + """ + super().__init__(action_def, id=id) + self._init_executor_id = init_executor_id + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[LoopIterationResult], + ) -> None: + """Advance to next item and send result.""" + state = await self._ensure_state_initialized(ctx, trigger) + + loop_id = self._init_executor_id + + # Get loop state + state_data = await state.get_state_data() + loop_states: dict[str, Any] = cast(dict[str, Any], state_data).get(LOOP_STATE_KEY, {}) + loop_state = loop_states.get(loop_id) + + if not loop_state: + # No loop state - shouldn't happen but handle gracefully + await ctx.send_message(LoopIterationResult(has_next=False)) + return + + items = loop_state["items"] + current_index = loop_state["index"] + 1 + + if current_index < len(items): + # Update loop state + loop_state["index"] = current_index + await state.set_state_data(state_data) + + # Set the iteration variable + # Support multiple schema formats: + # - Graph mode: iteratorVariable, item (default "Local.item") + # - Interpreter mode: itemName (default "item", stored in Local scope) + item_var = self._action_def.get("iteratorVariable") or self._action_def.get("item") + if not item_var: + # Interpreter mode: itemName defaults to "item", store in Local scope + item_name = self._action_def.get("itemName", "item") + item_var = f"Local.{item_name}" + + # Support multiple schema formats for index: + # - Graph mode: indexVariable, index + # - Interpreter mode: indexName (default "index", stored in Local scope) + index_var = self._action_def.get("indexVariable") or self._action_def.get("index") + if not index_var and "indexName" in self._action_def: + index_name = self._action_def.get("indexName", "index") + index_var = f"Local.{index_name}" + + await state.set(item_var, items[current_index]) + if index_var: + await state.set(index_var, current_index) + + await ctx.send_message( + LoopIterationResult(has_next=True, current_item=items[current_index], current_index=current_index) + ) + else: + # Loop complete - clean up + loop_states_dict = cast(dict[str, Any], state_data).get(LOOP_STATE_KEY, {}) + if loop_id in loop_states_dict: + del loop_states_dict[loop_id] + await state.set_state_data(state_data) + + await ctx.send_message(LoopIterationResult(has_next=False)) + + @handler + async def handle_loop_control( + self, + control: LoopControl, + ctx: WorkflowContext[LoopIterationResult], + ) -> None: + """Handle break/continue signals.""" + state = self._get_state(ctx.shared_state) + + if control.action == "break": + # Clean up loop state and signal done + state_data = await state.get_state_data() + loop_states: dict[str, Any] = cast(dict[str, Any], state_data).get(LOOP_STATE_KEY, {}) + if self._init_executor_id in loop_states: + del loop_states[self._init_executor_id] + await state.set_state_data(state_data) + + await ctx.send_message(LoopIterationResult(has_next=False)) + + elif control.action == "continue": + # Just advance to next iteration + await self.handle_action(ActionTrigger(), ctx) + + +class BreakLoopExecutor(DeclarativeActionExecutor): + """Executor for BreakLoop action. + + Sends a LoopControl signal to break out of the enclosing loop. + """ + + def __init__( + self, + action_def: dict[str, Any], + loop_next_executor_id: str, + *, + id: str | None = None, + ): + """Initialize with reference to the loop's next executor. 
+ + Args: + action_def: The action definition + loop_next_executor_id: ID of the ForeachNextExecutor to signal + id: Optional executor ID + """ + super().__init__(action_def, id=id) + self._loop_next_executor_id = loop_next_executor_id + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[LoopControl], + ) -> None: + """Send break signal to the loop.""" + await ctx.send_message(LoopControl(action="break")) + + +class ContinueLoopExecutor(DeclarativeActionExecutor): + """Executor for ContinueLoop action. + + Sends a LoopControl signal to continue to next iteration. + """ + + def __init__( + self, + action_def: dict[str, Any], + loop_next_executor_id: str, + *, + id: str | None = None, + ): + """Initialize with reference to the loop's next executor. + + Args: + action_def: The action definition + loop_next_executor_id: ID of the ForeachNextExecutor to signal + id: Optional executor ID + """ + super().__init__(action_def, id=id) + self._loop_next_executor_id = loop_next_executor_id + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[LoopControl], + ) -> None: + """Send continue signal to the loop.""" + await ctx.send_message(LoopControl(action="continue")) + + +class EndWorkflowExecutor(DeclarativeActionExecutor): + """Executor for EndWorkflow/EndDialog action. + + This executor simply doesn't send any message, causing the workflow + to terminate at this point. + """ + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """End the workflow by not sending any continuation message.""" + # Don't send ActionComplete - workflow ends here + pass + + +class EndConversationExecutor(DeclarativeActionExecutor): + """Executor for EndConversation action.""" + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """End the conversation.""" + # For now, just don't continue + # In a full implementation, this would signal to close the conversation + pass + + +# Passthrough executor for joining control flow branches +class JoinExecutor(DeclarativeActionExecutor): + """Executor that joins multiple branches back together. + + Used after If/Switch to merge control flow back to a single path. + Also used as passthrough nodes for else/default branches. + """ + + @handler + async def handle_action( + self, + trigger: dict[str, Any] | str | ActionTrigger | ActionComplete | ConditionResult | LoopIterationResult, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Simply pass through to continue the workflow.""" + await ctx.send_message(ActionComplete()) + + +class CancelDialogExecutor(DeclarativeActionExecutor): + """Executor for CancelDialog action. + + Cancels the current dialog/workflow, equivalent to .NET CancelDialog. + This terminates execution similarly to EndWorkflow. + """ + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Cancel the current dialog/workflow.""" + # CancelDialog terminates execution without continuing + # Similar to EndWorkflow but semantically different (cancellation vs completion) + pass + + +class CancelAllDialogsExecutor(DeclarativeActionExecutor): + """Executor for CancelAllDialogs action. + + Cancels all dialogs in the execution stack, equivalent to .NET CancelAllDialogs. + This terminates the entire workflow execution. 
+ """ + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Cancel all dialogs/workflows.""" + # CancelAllDialogs terminates all execution + pass + + +# Mapping of control flow action kinds to executor classes +# Note: Most control flow is handled by the builder creating graph structure, +# these are the executors that are part of that structure +CONTROL_FLOW_EXECUTORS: dict[str, type[DeclarativeActionExecutor]] = { + "EndWorkflow": EndWorkflowExecutor, + "EndDialog": EndWorkflowExecutor, + "EndConversation": EndConversationExecutor, + "CancelDialog": CancelDialogExecutor, + "CancelAllDialogs": CancelAllDialogsExecutor, +} diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_external_input.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_external_input.py new file mode 100644 index 0000000000..c499f133ea --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_external_input.py @@ -0,0 +1,344 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""External input executors for declarative workflows. + +These executors handle interactions that require external input (user questions, +confirmations, etc.), using the request_info pattern to pause the workflow and +wait for responses. +""" + +import uuid +from dataclasses import dataclass, field +from typing import Any + +from agent_framework._workflows import ( + WorkflowContext, + handler, + response_handler, +) + +from ._declarative_base import ( + ActionComplete, + DeclarativeActionExecutor, +) + + +@dataclass +class ExternalInputRequest: + """Request for external input (triggers workflow pause). + + Aligns with .NET ExternalInputRequest pattern. Used by Question, Confirmation, + WaitForInput, and RequestExternalInput executors to signal that user input is + needed. The workflow will pause via request_info and wait for an ExternalInputResponse. + + Attributes: + request_id: Unique identifier for this request. + message: The prompt or question to display to the user. + request_type: Type of input requested (question, confirmation, user_input, external). + metadata: Additional context (choices, output_property, timeout, etc.). + """ + + request_id: str + message: str + request_type: str = "external" + metadata: dict[str, Any] = field(default_factory=dict) # type: ignore + + +@dataclass +class ExternalInputResponse: + """Response to an ExternalInputRequest. + + Provided by the caller to resume workflow execution with user input. + + Attributes: + user_input: The user's text response. + value: Optional typed value (e.g., bool for confirmations, selected choice). + """ + + user_input: str + value: Any = None + + +class QuestionExecutor(DeclarativeActionExecutor): + """Executor that asks the user a question and waits for a response. + + Uses the request_info pattern to pause execution until the user provides an answer. + The response is stored in workflow state at the configured output property. 
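+
+    YAML example (illustrative; fields mirror those read by the handler below):
+
+    - kind: Question
+      text: Which option do you prefer?
+      property: Local.answer
+      choices:
+        - Option A
+        - Option B
+      allowFreeText: true
+      defaultValue: Option A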
+ """ + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Ask the question and wait for a response.""" + state = await self._ensure_state_initialized(ctx, trigger) + + question_text = self._action_def.get("text") or self._action_def.get("question", "") + output_property = self._action_def.get("output", {}).get("property") or self._action_def.get( + "property", "Local.answer" + ) + choices = self._action_def.get("choices", []) + default_value = self._action_def.get("defaultValue") + allow_free_text = self._action_def.get("allowFreeText", True) + + # Evaluate the question text if it's an expression + evaluated_question = await state.eval_if_expression(question_text) + + # Build choices metadata + choices_data: list[dict[str, str]] | None = None + if choices: + choices_data = [] + for c in choices: + if isinstance(c, dict): + c_dict: dict[str, Any] = dict(c) # type: ignore[arg-type] + choices_data.append({ + "value": c_dict.get("value", ""), + "label": c_dict.get("label") or c_dict.get("value", ""), + }) + else: + choices_data.append({"value": str(c), "label": str(c)}) + + # Store output property in shared state for response handler + await ctx.shared_state.set("_question_output_property", output_property) + await ctx.shared_state.set("_question_default_value", default_value) + + # Request external input - workflow pauses here + await ctx.request_info( + ExternalInputRequest( + request_id=str(uuid.uuid4()), + message=str(evaluated_question), + request_type="question", + metadata={ + "output_property": output_property, + "choices": choices_data, + "allow_free_text": allow_free_text, + "default_value": default_value, + }, + ), + ExternalInputResponse, + ) + + @response_handler + async def handle_response( + self, + original_request: ExternalInputRequest, + response: ExternalInputResponse, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the user's response to the question.""" + state = self._get_state(ctx.shared_state) + + output_property = original_request.metadata.get("output_property", "Local.answer") + answer = response.value if response.value is not None else response.user_input + + if output_property: + await state.set(output_property, answer) + + await ctx.send_message(ActionComplete()) + + +class ConfirmationExecutor(DeclarativeActionExecutor): + """Executor that asks for a yes/no confirmation. + + A specialized version of Question that expects a boolean response. 
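+
+    YAML example (illustrative; fields mirror those read by the handler below):
+
+    - kind: Confirmation
+      text: Do you want to proceed?
+      property: Local.confirmed
+      yesLabel: Proceed
+      noLabel: Cancel
+      defaultValue: false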
+ """ + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Ask for confirmation.""" + state = await self._ensure_state_initialized(ctx, trigger) + + message = self._action_def.get("text") or self._action_def.get("message", "") + output_property = self._action_def.get("output", {}).get("property") or self._action_def.get( + "property", "Local.confirmed" + ) + yes_label = self._action_def.get("yesLabel", "Yes") + no_label = self._action_def.get("noLabel", "No") + default_value = self._action_def.get("defaultValue", False) + + # Evaluate the message if it's an expression + evaluated_message = await state.eval_if_expression(message) + + # Request confirmation - workflow pauses here + await ctx.request_info( + ExternalInputRequest( + request_id=str(uuid.uuid4()), + message=str(evaluated_message), + request_type="confirmation", + metadata={ + "output_property": output_property, + "yes_label": yes_label, + "no_label": no_label, + "default_value": default_value, + }, + ), + ExternalInputResponse, + ) + + @response_handler + async def handle_response( + self, + original_request: ExternalInputRequest, + response: ExternalInputResponse, + ctx: WorkflowContext[ActionComplete], + ) -> None: + """Handle the user's confirmation response.""" + state = self._get_state(ctx.shared_state) + + output_property = original_request.metadata.get("output_property", "Local.confirmed") + + # Convert response to boolean + if response.value is not None: + confirmed = bool(response.value) + else: + # Interpret common affirmative responses + user_input_lower = response.user_input.lower().strip() + confirmed = user_input_lower in ("yes", "y", "true", "1", "confirm", "ok") + + if output_property: + await state.set(output_property, confirmed) + + await ctx.send_message(ActionComplete()) + + +class WaitForInputExecutor(DeclarativeActionExecutor): + """Executor that waits for user input during a conversation. + + Used when the workflow needs to pause and wait for the next user message + in a conversational flow. 
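+
+    YAML example (illustrative; fields mirror those read by the handler below):
+
+    - kind: WaitForInput
+      prompt: Please enter your next request.
+      property: Local.input
+      timeout: 60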
+    """
+
+    @handler
+    async def handle_action(
+        self,
+        trigger: Any,
+        ctx: WorkflowContext[ActionComplete, str],
+    ) -> None:
+        """Wait for user input."""
+        state = await self._ensure_state_initialized(ctx, trigger)
+
+        prompt = self._action_def.get("prompt")
+        output_property = self._action_def.get("output", {}).get("property") or self._action_def.get(
+            "property", "Local.input"
+        )
+        timeout_seconds = self._action_def.get("timeout")
+
+        # Emit the evaluated prompt if specified
+        evaluated_prompt: Any = None
+        if prompt:
+            evaluated_prompt = await state.eval_if_expression(prompt)
+            await ctx.yield_output(str(evaluated_prompt))
+
+        # Request user input - workflow pauses here. Use the evaluated prompt
+        # so any =expression syntax is resolved in the request message.
+        await ctx.request_info(
+            ExternalInputRequest(
+                request_id=str(uuid.uuid4()),
+                message=str(evaluated_prompt) if evaluated_prompt is not None else "Waiting for input...",
+                request_type="user_input",
+                metadata={
+                    "output_property": output_property,
+                    "timeout_seconds": timeout_seconds,
+                },
+            ),
+            ExternalInputResponse,
+        )
+
+    @response_handler
+    async def handle_response(
+        self,
+        original_request: ExternalInputRequest,
+        response: ExternalInputResponse,
+        ctx: WorkflowContext[ActionComplete, str],
+    ) -> None:
+        """Handle the user's input."""
+        state = self._get_state(ctx.shared_state)
+
+        output_property = original_request.metadata.get("output_property", "Local.input")
+
+        if output_property:
+            await state.set(output_property, response.user_input)
+
+        await ctx.send_message(ActionComplete())
+
+
+class RequestExternalInputExecutor(DeclarativeActionExecutor):
+    """Executor that requests external input/approval.
+
+    Used for complex external integrations beyond simple questions,
+    such as approval workflows, document uploads, or external system integrations.
+    """
+
+    @handler
+    async def handle_action(
+        self,
+        trigger: Any,
+        ctx: WorkflowContext[ActionComplete],
+    ) -> None:
+        """Request external input."""
+        state = await self._ensure_state_initialized(ctx, trigger)
+
+        request_type = self._action_def.get("requestType", "external")
+        message = self._action_def.get("message", "")
+        output_property = self._action_def.get("output", {}).get("property") or self._action_def.get(
+            "property", "Local.externalInput"
+        )
+        timeout_seconds = self._action_def.get("timeout")
+        required_fields = self._action_def.get("requiredFields", [])
+        metadata = self._action_def.get("metadata", {})
+
+        # Evaluate the message if it's an expression
+        evaluated_message = await state.eval_if_expression(message)
+
+        # Build request metadata
+        request_metadata: dict[str, Any] = {
+            **metadata,
+            "output_property": output_property,
+            "required_fields": required_fields,
+        }
+
+        if timeout_seconds:
+            request_metadata["timeout_seconds"] = timeout_seconds
+
+        # Request external input - workflow pauses here
+        await ctx.request_info(
+            ExternalInputRequest(
+                request_id=str(uuid.uuid4()),
+                message=str(evaluated_message),
+                request_type=request_type,
+                metadata=request_metadata,
+            ),
+            ExternalInputResponse,
+        )
+
+    @response_handler
+    async def handle_response(
+        self,
+        original_request: ExternalInputRequest,
+        response: ExternalInputResponse,
+        ctx: WorkflowContext[ActionComplete],
+    ) -> None:
+        """Handle the external input response."""
+        state = self._get_state(ctx.shared_state)
+
+        output_property = original_request.metadata.get("output_property", "Local.externalInput")
+
+        # Store the response value or user_input
+        result = response.value if response.value is not None else response.user_input
+        if output_property:
+            await state.set(output_property, result)
+
+        await ctx.send_message(ActionComplete())
+
+
+# 
Mapping of external input action kinds to executor classes +EXTERNAL_INPUT_EXECUTORS: dict[str, type[DeclarativeActionExecutor]] = { + "Question": QuestionExecutor, + "Confirmation": ConfirmationExecutor, + "WaitForInput": WaitForInputExecutor, + "RequestExternalInput": RequestExternalInputExecutor, +} diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_factory.py b/python/packages/declarative/agent_framework_declarative/_workflows/_factory.py new file mode 100644 index 0000000000..812f256828 --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_factory.py @@ -0,0 +1,676 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""WorkflowFactory creates executable Workflow objects from YAML definitions. + +This module provides the main entry point for declarative workflow support, +parsing YAML workflow definitions and creating Workflow objects that can be +executed using the core workflow runtime. + +Each YAML action becomes a real Executor node in the workflow graph, +enabling checkpointing, visualization, and pause/resume capabilities. +""" + +from collections.abc import Mapping +from pathlib import Path +from typing import Any, cast + +import yaml +from agent_framework import ( + AgentExecutor, + AgentProtocol, + CheckpointStorage, + Workflow, + get_logger, +) + +from .._loader import AgentFactory +from ._declarative_builder import DeclarativeWorkflowBuilder + +logger = get_logger("agent_framework.declarative.workflows") + + +class DeclarativeWorkflowError(Exception): + """Exception raised for errors in declarative workflow processing.""" + + pass + + +class WorkflowFactory: + """Factory for creating executable Workflow objects from YAML definitions. + + WorkflowFactory parses declarative workflow YAML files and creates + Workflow objects that can be executed using the core workflow runtime. + Each YAML action becomes a real Executor node in the workflow graph, + enabling checkpointing at action boundaries, visualization, and pause/resume. + + Examples: + .. code-block:: python + + from agent_framework.declarative import WorkflowFactory + + # Basic usage: create workflow from YAML file + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml_path("workflow.yaml") + + async for event in workflow.run_stream({"query": "Hello"}): + print(event) + + .. code-block:: python + + from agent_framework.declarative import WorkflowFactory + from agent_framework import FileCheckpointStorage + + # With checkpointing for pause/resume support + storage = FileCheckpointStorage(path="./checkpoints") + factory = WorkflowFactory(checkpoint_storage=storage) + workflow = factory.create_workflow_from_yaml_path("workflow.yaml") + + .. 
code-block:: python + + from agent_framework.azure import AzureOpenAIChatClient + from agent_framework.declarative import WorkflowFactory + + # Pre-register agents for InvokeAzureAgent actions + chat_client = AzureOpenAIChatClient() + agent = chat_client.create_agent(name="MyAgent", instructions="You are helpful.") + + factory = WorkflowFactory(agents={"MyAgent": agent}) + workflow = factory.create_workflow_from_yaml_path("workflow.yaml") + """ + + _agents: dict[str, AgentProtocol | AgentExecutor] + + def __init__( + self, + *, + agent_factory: AgentFactory | None = None, + agents: Mapping[str, AgentProtocol | AgentExecutor] | None = None, + bindings: Mapping[str, Any] | None = None, + env_file: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, + ) -> None: + """Initialize the workflow factory. + + Args: + agent_factory: Optional AgentFactory for creating agents from inline YAML definitions. + agents: Optional pre-created agents by name. These are looked up when processing + InvokeAzureAgent actions in the workflow YAML. + bindings: Optional function bindings for tool calls within workflow actions. + env_file: Optional path to .env file for environment variables used in agent creation. + checkpoint_storage: Optional checkpoint storage enabling pause/resume functionality. + + Examples: + .. code-block:: python + + from agent_framework.declarative import WorkflowFactory + + # Minimal initialization + factory = WorkflowFactory() + + .. code-block:: python + + from agent_framework.azure import AzureOpenAIChatClient + from agent_framework.declarative import WorkflowFactory + + # With pre-registered agents + client = AzureOpenAIChatClient() + agents = { + "WriterAgent": client.create_agent(name="Writer", instructions="Write content."), + "ReviewerAgent": client.create_agent(name="Reviewer", instructions="Review content."), + } + factory = WorkflowFactory(agents=agents) + + .. code-block:: python + + from agent_framework import FileCheckpointStorage + from agent_framework.declarative import WorkflowFactory + + # With checkpoint storage for pause/resume + factory = WorkflowFactory( + checkpoint_storage=FileCheckpointStorage("./checkpoints"), + env_file=".env", + ) + """ + self._agent_factory = agent_factory or AgentFactory(env_file_path=env_file) + self._agents: dict[str, AgentProtocol | AgentExecutor] = dict(agents) if agents else {} + self._bindings: dict[str, Any] = dict(bindings) if bindings else {} + self._checkpoint_storage = checkpoint_storage + + def create_workflow_from_yaml_path( + self, + yaml_path: str | Path, + ) -> Workflow: + """Create a Workflow from a YAML file path. + + Args: + yaml_path: Path to the YAML workflow definition file. + + Returns: + An executable Workflow object with action nodes for each YAML action. + + Raises: + DeclarativeWorkflowError: If the YAML is invalid or cannot be parsed. + FileNotFoundError: If the YAML file doesn't exist. + + Examples: + .. code-block:: python + + from agent_framework.declarative import WorkflowFactory + + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml_path("workflow.yaml") + + # Execute the workflow + async for event in workflow.run_stream({"input": "Hello"}): + print(event) + + .. 
code-block:: python + + from pathlib import Path + from agent_framework.declarative import WorkflowFactory + + # Using Path object + workflow_path = Path(__file__).parent / "workflows" / "customer_support.yaml" + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml_path(workflow_path) + """ + if not isinstance(yaml_path, Path): + yaml_path = Path(yaml_path) + + if not yaml_path.exists(): + raise FileNotFoundError(f"Workflow YAML file not found: {yaml_path}") + + with open(yaml_path) as f: + yaml_content = f.read() + + return self.create_workflow_from_yaml(yaml_content, base_path=yaml_path.parent) + + def create_workflow_from_yaml( + self, + yaml_content: str, + base_path: Path | None = None, + ) -> Workflow: + """Create a Workflow from a YAML string. + + Args: + yaml_content: The YAML workflow definition as a string. + base_path: Optional base path for resolving relative file references + in agent definitions. + + Returns: + An executable Workflow object with action nodes for each YAML action. + + Raises: + DeclarativeWorkflowError: If the YAML is invalid or cannot be parsed. + + Examples: + .. code-block:: python + + from agent_framework.declarative import WorkflowFactory + + yaml_content = ''' + kind: Workflow + trigger: + kind: OnConversationStart + id: greeting_workflow + actions: + - kind: SetVariable + id: set_greeting + variable: Local.Greeting + value: "Hello, World!" + - kind: SendActivity + id: send_greeting + activity: =Local.Greeting + ''' + + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml(yaml_content) + + .. code-block:: python + + from pathlib import Path + from agent_framework.declarative import WorkflowFactory + + # With base_path for resolving relative agent file references + yaml_content = ''' + kind: Workflow + agents: + MyAgent: + file: ./agents/my_agent.yaml + trigger: + actions: + - kind: InvokeAzureAgent + agent: + name: MyAgent + ''' + + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml( + yaml_content, + base_path=Path("./workflows"), + ) + """ + try: + workflow_def = yaml.safe_load(yaml_content) + except yaml.YAMLError as e: + raise DeclarativeWorkflowError(f"Invalid YAML: {e}") from e + + return self.create_workflow_from_definition(workflow_def, base_path=base_path) + + def create_workflow_from_definition( + self, + workflow_def: dict[str, Any], + base_path: Path | None = None, + ) -> Workflow: + """Create a Workflow from a parsed workflow definition dictionary. + + This is the lowest-level creation method, useful when you already have + a parsed dictionary (e.g., from programmatic construction or custom parsing). + + Args: + workflow_def: The parsed workflow definition dictionary containing + 'kind', 'trigger', 'actions', and optionally 'agents' keys. + base_path: Optional base path for resolving relative file references + in agent definitions. + + Returns: + An executable Workflow object with action nodes for each YAML action. + + Raises: + DeclarativeWorkflowError: If the definition is invalid or missing required fields. + + Examples: + .. 
code-block:: python + + from agent_framework.declarative import WorkflowFactory + + # Programmatically construct a workflow definition + workflow_def = { + "kind": "Workflow", + "name": "my_workflow", + "trigger": { + "kind": "OnConversationStart", + "id": "main_trigger", + "actions": [ + { + "kind": "SetVariable", + "id": "init", + "variable": "Local.Counter", + "value": 0, + }, + { + "kind": "SendActivity", + "id": "output", + "activity": "Counter initialized", + }, + ], + }, + } + + factory = WorkflowFactory() + workflow = factory.create_workflow_from_definition(workflow_def) + """ + # Validate the workflow definition + self._validate_workflow_def(workflow_def) + + # Extract workflow metadata + # Support both "name" field and trigger.id for workflow name + name: str = workflow_def.get("name", "") + if not name: + trigger: dict[str, Any] = workflow_def.get("trigger", {}) + trigger_id = trigger.get("id", "declarative_workflow") + name = str(trigger_id) if trigger_id else "declarative_workflow" + description = workflow_def.get("description") + + # Create agents from definitions + agents: dict[str, AgentProtocol | AgentExecutor] = dict(self._agents) + agent_defs = workflow_def.get("agents", {}) + + for agent_name, agent_def in agent_defs.items(): + if agent_name in agents: + # Already have this agent + continue + + # Create agent using AgentFactory + try: + agent = self._create_agent_from_def(agent_def, base_path) + agents[agent_name] = agent + logger.debug(f"Created agent '{agent_name}' from definition") + except Exception as e: + logger.error(f"Failed to create agent '{agent_name}': {e}") + raise DeclarativeWorkflowError(f"Failed to create agent '{agent_name}': {e}") from e + + return self._create_workflow(workflow_def, name, description, agents) + + def _create_workflow( + self, + workflow_def: dict[str, Any], + name: str, + description: str | None, + agents: dict[str, AgentProtocol | AgentExecutor], + ) -> Workflow: + """Create workflow from definition. + + Each YAML action becomes a real Executor node in the workflow graph. + This enables checkpointing at action boundaries. 
+ + Args: + workflow_def: The workflow definition + name: Workflow name + description: Workflow description + agents: Registry of agent instances + + Returns: + Workflow with individual action executors as nodes + """ + # Normalize workflow definition to have actions at top level + normalized_def = self._normalize_workflow_def(workflow_def) + normalized_def["name"] = name + if description: + normalized_def["description"] = description + + # Build the graph-based workflow, passing agents for InvokeAzureAgent executors + try: + graph_builder = DeclarativeWorkflowBuilder( + normalized_def, + workflow_id=name, + agents=agents, + checkpoint_storage=self._checkpoint_storage, + ) + workflow = graph_builder.build() + except ValueError as e: + raise DeclarativeWorkflowError(f"Failed to build graph-based workflow: {e}") from e + + # Store agents and bindings for reference (executors already have them) + workflow._declarative_agents = agents # type: ignore[attr-defined] + workflow._declarative_bindings = self._bindings # type: ignore[attr-defined] + + # Store input schema if defined in workflow definition + # This allows DevUI to generate proper input forms + if "inputs" in workflow_def: + workflow.input_schema = self._convert_inputs_to_json_schema(workflow_def["inputs"]) # type: ignore[attr-defined] + + logger.debug( + "Created graph-based workflow '%s' with %d executors", + name, + len(graph_builder._executors), # type: ignore[reportPrivateUsage] + ) + + return workflow + + def _normalize_workflow_def(self, workflow_def: dict[str, Any]) -> dict[str, Any]: + """Normalize workflow definition to have actions at top level. + + Args: + workflow_def: The workflow definition + + Returns: + Normalized definition with actions at top level + """ + actions = self._get_actions_from_def(workflow_def) + return { + **workflow_def, + "actions": actions, + } + + def _validate_workflow_def(self, workflow_def: dict[str, Any]) -> None: + """Validate a workflow definition. + + Args: + workflow_def: The workflow definition to validate + + Raises: + DeclarativeWorkflowError: If the definition is invalid + """ + if not isinstance(workflow_def, dict): + raise DeclarativeWorkflowError("Workflow definition must be a dictionary") + + # Handle both formats: + # 1. Direct actions list: {"actions": [...]} + # 2. Trigger-based: {"kind": "Workflow", "trigger": {"actions": [...]}} + actions = self._get_actions_from_def(workflow_def) + + if not isinstance(actions, list): + raise DeclarativeWorkflowError("Workflow 'actions' must be a list") + + # Validate each action has a kind + for i, action in enumerate(actions): + if not isinstance(action, dict): + raise DeclarativeWorkflowError(f"Action at index {i} must be a dictionary") + if "kind" not in action: + raise DeclarativeWorkflowError(f"Action at index {i} missing 'kind' field") + + def _get_actions_from_def(self, workflow_def: dict[str, Any]) -> list[dict[str, Any]]: + """Extract actions from a workflow definition. + + Handles both direct actions format and trigger-based format. 
+ + Args: + workflow_def: The workflow definition + + Returns: + List of action definitions + + Raises: + DeclarativeWorkflowError: If no actions can be found + """ + # Try direct actions first + if "actions" in workflow_def: + actions: list[dict[str, Any]] = workflow_def["actions"] + return actions + + # Try trigger-based format + if "trigger" in workflow_def: + trigger = workflow_def["trigger"] + if isinstance(trigger, dict) and "actions" in trigger: + trigger_actions: list[dict[str, Any]] = list(trigger["actions"]) # type: ignore[arg-type] + return trigger_actions + + raise DeclarativeWorkflowError("Workflow definition must have 'actions' field or 'trigger.actions' field") + + def _create_agent_from_def( + self, + agent_def: dict[str, Any], + base_path: Path | None = None, + ) -> Any: + """Create an agent from a definition. + + Args: + agent_def: The agent definition dictionary + base_path: Optional base path for resolving relative file references + + Returns: + An agent instance + """ + # Check if it's a reference to an external file + if "file" in agent_def: + file_path = agent_def["file"] + if base_path and not Path(file_path).is_absolute(): + file_path = base_path / file_path + return self._agent_factory.create_agent_from_yaml_path(file_path) + + # Check if it's an inline agent definition + if "kind" in agent_def: + return self._agent_factory.create_agent_from_dict(agent_def) + + # Handle connection-based agent (like Azure AI agents) + if "connection" in agent_def: + # This would create a hosted agent client + # For now, we'll need the user to provide pre-created agents + raise DeclarativeWorkflowError( + "Connection-based agents must be provided via the 'agents' parameter. " + "Create the agent using the appropriate client and pass it to WorkflowFactory." + ) + + raise DeclarativeWorkflowError( + f"Invalid agent definition. Expected 'file', 'kind', or 'connection': {agent_def}" + ) + + def register_agent(self, name: str, agent: AgentProtocol | AgentExecutor) -> "WorkflowFactory": + """Register an agent instance with the factory for use in workflows. + + Registered agents are available to InvokeAzureAgent actions by name. + This method supports fluent chaining. + + Args: + name: The name to register the agent under. Must match the agent name + referenced in InvokeAzureAgent actions. + agent: The agent instance (typically a ChatAgent or similar). + + Returns: + Self for method chaining. + + Examples: + .. code-block:: python + + from agent_framework.azure import AzureOpenAIChatClient + from agent_framework.declarative import WorkflowFactory + + client = AzureOpenAIChatClient() + + # Method chaining to register multiple agents + factory = ( + WorkflowFactory() + .register_agent( + "Writer", + client.create_agent( + name="Writer", + instructions="Write content.", + ), + ) + .register_agent( + "Reviewer", + client.create_agent( + name="Reviewer", + instructions="Review content.", + ), + ) + ) + + workflow = factory.create_workflow_from_yaml_path("workflow.yaml") + """ + self._agents[name] = agent + return self + + def register_binding(self, name: str, func: Any) -> "WorkflowFactory": + """Register a function binding with the factory for use in workflow actions. + + Bindings allow workflow actions to invoke Python functions by name. + This method supports fluent chaining. + + Args: + name: The name to register the function under. + func: The function to bind. + + Returns: + Self for method chaining. + + Examples: + .. 
code-block:: python + + from agent_framework.declarative import WorkflowFactory + + + def get_weather(location: str) -> str: + return f"Weather in {location}: Sunny, 72F" + + + def send_email(to: str, subject: str, body: str) -> bool: + # Send email logic + return True + + + # Register functions for use in workflow + factory = ( + WorkflowFactory() + .register_binding("get_weather", get_weather) + .register_binding("send_email", send_email) + ) + + workflow = factory.create_workflow_from_yaml_path("workflow.yaml") + """ + self._bindings[name] = func + return self + + def _convert_inputs_to_json_schema(self, inputs_def: dict[str, Any]) -> dict[str, Any]: + """Convert a declarative inputs definition to JSON Schema. + + The inputs definition uses a simplified format: + inputs: + age: + type: integer + description: The user's age + name: + type: string + + This is converted to standard JSON Schema format. + + Args: + inputs_def: The inputs definition from the workflow YAML + + Returns: + A JSON Schema object + """ + properties: dict[str, Any] = {} + required: list[str] = [] + + for field_name, field_def in inputs_def.items(): + if isinstance(field_def, dict): + # Field has type and possibly other attributes + prop: dict[str, Any] = {} + field_def_dict: dict[str, Any] = cast(dict[str, Any], field_def) + field_type: str = str(field_def_dict.get("type", "string")) + + # Map declarative types to JSON Schema types + type_mapping: dict[str, str] = { + "string": "string", + "str": "string", + "integer": "integer", + "int": "integer", + "number": "number", + "float": "number", + "boolean": "boolean", + "bool": "boolean", + "array": "array", + "list": "array", + "object": "object", + "dict": "object", + } + prop["type"] = type_mapping.get(field_type, field_type) + + # Copy other attributes + if "description" in field_def_dict: + prop["description"] = field_def_dict["description"] + if "default" in field_def_dict: + prop["default"] = field_def_dict["default"] + if "enum" in field_def_dict: + prop["enum"] = field_def_dict["enum"] + + # Check if required (default: true unless explicitly false) + if field_def_dict.get("required", True): + required.append(field_name) + + properties[field_name] = prop + else: + # Simple type definition (e.g., "age: integer") + type_mapping_simple: dict[str, str] = { + "string": "string", + "str": "string", + "integer": "integer", + "int": "integer", + "number": "number", + "float": "number", + "boolean": "boolean", + "bool": "boolean", + } + properties[field_name] = {"type": type_mapping_simple.get(str(field_def), "string")} + required.append(field_name) + + schema: dict[str, Any] = { + "type": "object", + "properties": properties, + } + if required: + schema["required"] = required + + return schema diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_handlers.py b/python/packages/declarative/agent_framework_declarative/_workflows/_handlers.py new file mode 100644 index 0000000000..64db7f43f6 --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_handlers.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Action handlers for declarative workflow execution. + +This module provides the ActionHandler protocol and registry for executing +workflow actions defined in YAML. Each action type (InvokeAzureAgent, Foreach, etc.) +has a corresponding handler registered via the @action_handler decorator. 
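+
+Example (a minimal sketch; "LogMessage" is a hypothetical action kind used
+purely for illustration):
+
+    @action_handler("LogMessage")
+    async def handle_log_message(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]:
+        # Evaluate the action's text and surface it as user-visible output.
+        text = ctx.state.eval_if_expression(ctx.action.get("text", ""))
+        yield TextOutputEvent(text=str(text))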
+""" + +from collections.abc import AsyncGenerator, Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +from agent_framework import get_logger + +if TYPE_CHECKING: + from ._state import WorkflowState + +logger = get_logger("agent_framework.declarative.workflows") + + +@dataclass +class ActionContext: + """Context passed to action handlers during execution. + + Provides access to workflow state, the action definition, and methods + for executing nested actions (for control flow constructs like Foreach). + """ + + state: "WorkflowState" + """The current workflow state with variables and agent results.""" + + action: dict[str, Any] + """The action definition from the YAML.""" + + execute_actions: "ExecuteActionsFn" + """Function to execute a list of nested actions (for Foreach, If, etc.).""" + + agents: dict[str, Any] + """Registry of agent instances by name.""" + + bindings: dict[str, Any] + """Function bindings for tool calls.""" + + @property + def action_id(self) -> str | None: + """Get the action's unique identifier.""" + return self.action.get("id") + + @property + def display_name(self) -> str | None: + """Get the action's human-readable display name for debugging/logging.""" + return self.action.get("displayName") + + @property + def action_kind(self) -> str | None: + """Get the action's type/kind.""" + return self.action.get("kind") + + +# Type alias for the nested action executor function +ExecuteActionsFn = Callable[ + [list[dict[str, Any]], "WorkflowState"], + AsyncGenerator["WorkflowEvent", None], +] + + +@dataclass +class WorkflowEvent: + """Base class for events emitted during workflow execution.""" + + pass + + +@dataclass +class TextOutputEvent(WorkflowEvent): + """Event emitted when text should be sent to the user.""" + + text: str + """The text content to output.""" + + +@dataclass +class AttachmentOutputEvent(WorkflowEvent): + """Event emitted when an attachment should be sent to the user.""" + + content: Any + """The attachment content.""" + + content_type: str = "application/octet-stream" + """The MIME type of the attachment.""" + + +@dataclass +class AgentResponseEvent(WorkflowEvent): + """Event emitted when an agent produces a response.""" + + agent_name: str + """The name of the agent that produced the response.""" + + text: str | None + """The text content of the response, if any.""" + + messages: list[Any] + """The messages from the agent response.""" + + tool_calls: list[Any] | None = None + """Any tool calls made by the agent.""" + + +@dataclass +class AgentStreamingChunkEvent(WorkflowEvent): + """Event emitted for streaming chunks from an agent.""" + + agent_name: str + """The name of the agent producing the chunk.""" + + chunk: str + """The streaming chunk content.""" + + +@dataclass +class CustomEvent(WorkflowEvent): + """Custom event emitted via EmitEvent action.""" + + name: str + """The event name.""" + + data: Any + """The event data.""" + + +@dataclass +class LoopControlSignal(WorkflowEvent): + """Signal for loop control (break/continue).""" + + signal_type: str + """Either 'break' or 'continue'.""" + + +@runtime_checkable +class ActionHandler(Protocol): + """Protocol for action handlers. + + Action handlers are async generators that execute a single action type + and yield events as they process. They receive an ActionContext with + the current state, action definition, and utilities for nested execution. 
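+
+    Example (illustrative dispatch for a single action definition):
+
+        handler = get_action_handler(action["kind"])
+        if handler is not None:
+            async for event in handler(ctx):
+                ...  # forward events to the caller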
+ """ + + def __call__( + self, + ctx: ActionContext, + ) -> AsyncGenerator[WorkflowEvent, None]: + """Execute the action and yield events. + + Args: + ctx: The action context containing state, action definition, and utilities + + Yields: + WorkflowEvent instances as the action executes + """ + ... + + +# Global registry of action handlers +_ACTION_HANDLERS: dict[str, ActionHandler] = {} + + +def action_handler(action_kind: str) -> Callable[[ActionHandler], ActionHandler]: + """Decorator to register an action handler for a specific action type. + + Args: + action_kind: The action type this handler processes (e.g., 'InvokeAzureAgent') + + Example: + @action_handler("SetValue") + async def handle_set_value(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: + path = ctx.action.get("path") + value = ctx.state.eval_if_expression(ctx.action.get("value")) + ctx.state.set(path, value) + return + yield # Make it a generator + """ + + def decorator(func: ActionHandler) -> ActionHandler: + _ACTION_HANDLERS[action_kind] = func + logger.debug(f"Registered action handler for '{action_kind}'") + return func + + return decorator + + +def get_action_handler(action_kind: str) -> ActionHandler | None: + """Get the registered handler for an action type. + + Args: + action_kind: The action type to look up + + Returns: + The registered ActionHandler, or None if not found + """ + return _ACTION_HANDLERS.get(action_kind) + + +def list_action_handlers() -> list[str]: + """List all registered action handler types. + + Returns: + A list of registered action type names + """ + return list(_ACTION_HANDLERS.keys()) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_human_input.py b/python/packages/declarative/agent_framework_declarative/_workflows/_human_input.py new file mode 100644 index 0000000000..97259807e7 --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_human_input.py @@ -0,0 +1,320 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Human-in-the-loop action handlers for declarative workflows. + +This module implements handlers for human input patterns: +- Question: Request human input with validation +- RequestExternalInput: Request input from external system +- ExternalLoop processing: Loop while waiting for external input +""" + +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +from agent_framework import get_logger + +from ._handlers import ( + ActionContext, + WorkflowEvent, + action_handler, +) + +if TYPE_CHECKING: + from ._state import WorkflowState + +logger = get_logger("agent_framework.declarative.workflows.human_input") + + +@dataclass +class QuestionRequest(WorkflowEvent): + """Event emitted when the workflow needs user input via Question action. + + When this event is yielded, the workflow execution should pause + and wait for user input to be provided via workflow.send_response(). + + This is used by the Question, RequestExternalInput, and WaitForInput + action handlers in the non-graph workflow path. 
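+
+    Sketch of a runner reacting to this event (``run_stream`` and
+    ``send_response`` are assumed runner APIs, shown for illustration only):
+
+    .. code-block:: python
+
+        async for event in workflow.run_stream(inputs):
+            if isinstance(event, QuestionRequest):
+                answer = input(event.prompt or "> ")
+                workflow.send_response(event.request_id, answer)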
+ """ + + request_id: str + """Unique identifier for this request.""" + + prompt: str | None + """The prompt/question to display to the user.""" + + variable: str + """The variable where the response should be stored.""" + + validation: dict[str, Any] | None = None + """Optional validation rules for the input.""" + + choices: list[str] | None = None + """Optional list of valid choices.""" + + default_value: Any = None + """Default value if no input is provided.""" + + +@dataclass +class ExternalLoopEvent(WorkflowEvent): + """Event emitted when entering an external input loop. + + This event signals that the action is waiting for external input + in a loop pattern (e.g., input.externalLoop.when condition). + """ + + action_id: str + """The ID of the action that requires external input.""" + + iteration: int + """The current iteration number (0-based).""" + + condition_expression: str + """The PowerFx condition that must become false to exit the loop.""" + + +@action_handler("Question") +async def handle_question(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Handle Question action - request human input with optional validation. + + Action schema: + kind: Question + id: ask_name + variable: Local.userName + prompt: What is your name? + validation: + required: true + minLength: 1 + maxLength: 100 + choices: # optional - present as multiple choice + - Option A + - Option B + default: Option A # optional default value + + The handler emits a QuestionRequest and expects the workflow runner + to capture and provide the response before continuing. + """ + question_id = ctx.action.get("id", "question") + variable = ctx.action.get("variable") + prompt = ctx.action.get("prompt") + question: dict[str, Any] | Any = ctx.action.get("question", {}) + validation = ctx.action.get("validation", {}) + choices = ctx.action.get("choices") + default_value = ctx.action.get("default") + + if not variable: + logger.warning("Question action missing 'variable' property") + return + + # Evaluate prompt if it's an expression (support both 'prompt' and 'question.text') + prompt_text: Any | None = None + if isinstance(question, dict): + question_dict: dict[str, Any] = cast(dict[str, Any], question) + prompt_text = prompt or question_dict.get("text") + else: + prompt_text = prompt + evaluated_prompt = ctx.state.eval_if_expression(prompt_text) if prompt_text else None + + # Evaluate choices if they're expressions + evaluated_choices = None + if choices: + evaluated_choices = [ctx.state.eval_if_expression(c) if isinstance(c, str) else c for c in choices] + + logger.debug(f"Question: requesting input for {variable}") + + # Emit the request event + yield QuestionRequest( + request_id=question_id, + prompt=str(evaluated_prompt) if evaluated_prompt else None, + variable=variable, + validation=validation, + choices=evaluated_choices, + default_value=default_value, + ) + + # Apply default value if specified (for non-interactive scenarios) + if default_value is not None: + evaluated_default = ctx.state.eval_if_expression(default_value) + ctx.state.set(variable, evaluated_default) + + +@action_handler("RequestExternalInput") +async def handle_request_external_input(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Handle RequestExternalInput action - request input from external system. 
+ + Action schema: + kind: RequestExternalInput + id: get_approval + variable: Local.approval + prompt: Please approve or reject the request + timeout: 300 # seconds + default: "No feedback provided" # optional default value + output: + response: Local.approvalResponse + timestamp: Local.approvalTime + + Similar to Question but designed for external system integration + rather than direct human input. + """ + request_id = ctx.action.get("id", "external_input") + variable = ctx.action.get("variable") + prompt = ctx.action.get("prompt") + timeout = ctx.action.get("timeout") # seconds + default_value = ctx.action.get("default") + _output = ctx.action.get("output", {}) # Reserved for future use + + if not variable: + logger.warning("RequestExternalInput action missing 'variable' property") + return + + # Extract prompt text (support both 'prompt' string and 'prompt.text' object) + prompt_text: Any | None = None + if isinstance(prompt, dict): + prompt_dict: dict[str, Any] = cast(dict[str, Any], prompt) + prompt_text = prompt_dict.get("text") + else: + prompt_text = prompt + + # Evaluate prompt if it's an expression + evaluated_prompt = ctx.state.eval_if_expression(prompt_text) if prompt_text else None + + logger.debug(f"RequestExternalInput: requesting input for {variable}") + + # Emit the request event + yield QuestionRequest( + request_id=request_id, + prompt=str(evaluated_prompt) if evaluated_prompt else None, + variable=variable, + validation={"timeout": timeout} if timeout else None, + default_value=default_value, + ) + + # Apply default value if specified (for non-interactive scenarios) + if default_value is not None: + evaluated_default = ctx.state.eval_if_expression(default_value) + ctx.state.set(variable, evaluated_default) + + +@action_handler("WaitForInput") +async def handle_wait_for_input(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029 + """Handle WaitForInput action - pause and wait for external input. + + Action schema: + kind: WaitForInput + id: wait_for_response + variable: Local.response + message: Waiting for user response... + + This is a simpler form of RequestExternalInput that just pauses + execution until input is provided. + """ + wait_id = ctx.action.get("id", "wait") + variable = ctx.action.get("variable") + message = ctx.action.get("message") + + if not variable: + logger.warning("WaitForInput action missing 'variable' property") + return + + # Evaluate message if it's an expression + evaluated_message = ctx.state.eval_if_expression(message) if message else None + + logger.debug(f"WaitForInput: waiting for {variable}") + + yield QuestionRequest( + request_id=wait_id, + prompt=str(evaluated_message) if evaluated_message else None, + variable=variable, + ) + + +def process_external_loop( + input_config: dict[str, Any], + state: "WorkflowState", +) -> tuple[bool, str | None]: + """Process the externalLoop.when pattern from action input. + + This function evaluates the externalLoop.when condition to determine + if the action should continue looping for external input. 
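+
+    An illustrative input configuration using this pattern (loops until
+    Local.userReply is set):
+        input:
+          externalLoop:
+            when: =IsBlank(Local.userReply)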
+ + Args: + input_config: The input configuration containing externalLoop + state: The workflow state for expression evaluation + + Returns: + Tuple of (should_continue_loop, condition_expression) + - should_continue_loop: True if the loop should continue + - condition_expression: The original condition expression for diagnostics + """ + external_loop = input_config.get("externalLoop", {}) + when_condition = external_loop.get("when") + + if not when_condition: + return (False, None) + + # Evaluate the condition + result = state.eval(when_condition) + + # The loop continues while the condition is True + should_continue = bool(result) if result is not None else False + + logger.debug(f"ExternalLoop condition '{when_condition[:50]}' evaluated to {should_continue}") + + return (should_continue, when_condition) + + +def validate_input_response( + value: Any, + validation: dict[str, Any] | None, +) -> tuple[bool, str | None]: + """Validate input response against validation rules. + + Args: + value: The input value to validate + validation: Validation rules from the Question action + + Returns: + Tuple of (is_valid, error_message) + """ + if not validation: + return (True, None) + + # Check required + if validation.get("required") and (value is None or value == ""): + return (False, "This field is required") + + if value is None: + return (True, None) + + # Check string length + if isinstance(value, str): + min_length = validation.get("minLength") + max_length = validation.get("maxLength") + + if min_length is not None and len(value) < min_length: + return (False, f"Minimum length is {min_length}") + + if max_length is not None and len(value) > max_length: + return (False, f"Maximum length is {max_length}") + + # Check numeric range + if isinstance(value, (int, float)): + min_value = validation.get("min") + max_value = validation.get("max") + + if min_value is not None and value < min_value: + return (False, f"Minimum value is {min_value}") + + if max_value is not None and value > max_value: + return (False, f"Maximum value is {max_value}") + + # Check pattern (regex) + pattern = validation.get("pattern") + if pattern and isinstance(value, str): + import re + + if not re.match(pattern, value): + return (False, f"Value does not match pattern: {pattern}") + + return (True, None) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py b/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py new file mode 100644 index 0000000000..1cc8ce2cfb --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py @@ -0,0 +1,494 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Custom PowerFx-like functions for declarative workflows. + +This module provides Python implementations of custom PowerFx functions +that are used in declarative workflows but may not be available in the +standard PowerFx Python package. + +These functions can be used as fallbacks when PowerFx is not available, +or registered with the PowerFx engine when it is available. +""" + +from typing import Any, cast + + +def message_text(messages: Any) -> str: + """Extract text content from a message or list of messages. + + This is equivalent to the .NET MessageText() function. + + Args: + messages: A message object, list of messages, or string + + Returns: + The concatenated text content of all messages + + Examples: + .. 
code-block:: python + + message_text([{"role": "assistant", "content": "Hello"}]) + # Returns: 'Hello' + """ + if messages is None: + return "" + + if isinstance(messages, str): + return messages + + if isinstance(messages, dict): + # Single message object + messages_dict = cast(dict[str, Any], messages) + content: Any = messages_dict.get("content", "") + if isinstance(content, str): + return content + if hasattr(content, "text"): + return str(content.text) + return str(content) if content else "" + + if isinstance(messages, list): + # List of messages - concatenate all text + texts: list[str] = [] + for msg in messages: + if isinstance(msg, str): + texts.append(msg) + elif isinstance(msg, dict): + msg_dict = cast(dict[str, Any], msg) + msg_content: Any = msg_dict.get("content", "") + if isinstance(msg_content, str): + texts.append(msg_content) + elif msg_content: + texts.append(str(msg_content)) + elif hasattr(msg, "content"): + msg_obj_content: Any = msg.content + if isinstance(msg_obj_content, str): + texts.append(msg_obj_content) + elif hasattr(msg_obj_content, "text"): + texts.append(str(msg_obj_content.text)) + elif msg_obj_content: + texts.append(str(msg_obj_content)) + return " ".join(texts) + + # Try to get text attribute + if hasattr(messages, "text"): + return str(messages.text) + if hasattr(messages, "content"): + content_attr: Any = messages.content + if isinstance(content_attr, str): + return content_attr + return str(content_attr) if content_attr else "" + + return str(messages) if messages else "" + + +def user_message(text: str) -> dict[str, str]: + """Create a user message object. + + This is equivalent to the .NET UserMessage() function. + + Args: + text: The text content of the message + + Returns: + A message dictionary with role 'user' + + Examples: + .. code-block:: python + + user_message("Hello") + # Returns: {'role': 'user', 'content': 'Hello'} + """ + return {"role": "user", "content": str(text) if text else ""} + + +def assistant_message(text: str) -> dict[str, str]: + """Create an assistant message object. + + Args: + text: The text content of the message + + Returns: + A message dictionary with role 'assistant' + + Examples: + .. code-block:: python + + assistant_message("Hello") + # Returns: {'role': 'assistant', 'content': 'Hello'} + """ + return {"role": "assistant", "content": str(text) if text else ""} + + +def agent_message(text: str) -> dict[str, str]: + """Create an agent/assistant message object. + + This is equivalent to the .NET AgentMessage() function. + It's an alias for assistant_message() for .NET compatibility. + + Args: + text: The text content of the message + + Returns: + A message dictionary with role 'assistant' + + Examples: + .. code-block:: python + + agent_message("Hello") + # Returns: {'role': 'assistant', 'content': 'Hello'} + """ + return {"role": "assistant", "content": str(text) if text else ""} + + +def system_message(text: str) -> dict[str, str]: + """Create a system message object. + + Args: + text: The text content of the message + + Returns: + A message dictionary with role 'system' + + Examples: + .. code-block:: python + + system_message("You are a helpful assistant") + # Returns: {'role': 'system', 'content': 'You are a helpful assistant'} + """ + return {"role": "system", "content": str(text) if text else ""} + + +def if_func(condition: Any, true_value: Any, false_value: Any = None) -> Any: + """Conditional expression - returns one value or another based on a condition. + + This is equivalent to the PowerFx If() function. 
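+
+    For example (results follow directly from the definition below):
+
+    .. code-block:: python
+
+        if_func(2 > 1, "yes", "no")  # Returns: 'yes'
+        if_func(None, "yes")  # Returns: None (false_value defaults to None)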
+ + Args: + condition: The condition to evaluate (truthy/falsy) + true_value: Value to return if condition is truthy + false_value: Value to return if condition is falsy (defaults to None) + + Returns: + true_value if condition is truthy, otherwise false_value + """ + return true_value if condition else false_value + + +def is_blank(value: Any) -> bool: + """Check if a value is blank (None, empty string, empty list, etc.). + + This is equivalent to the PowerFx IsBlank() function. + + Args: + value: The value to check + + Returns: + True if the value is considered blank + """ + if value is None: + return True + if isinstance(value, str) and not value.strip(): + return True + if isinstance(value, list): + return len(value) == 0 + if isinstance(value, dict): + return len(value) == 0 + return False + + +def or_func(*args: Any) -> bool: + """Logical OR - returns True if any argument is truthy. + + This is equivalent to the PowerFx Or() function. + + Args: + *args: Variable number of values to check + + Returns: + True if any argument is truthy + """ + return any(bool(arg) for arg in args) + + +def and_func(*args: Any) -> bool: + """Logical AND - returns True if all arguments are truthy. + + This is equivalent to the PowerFx And() function. + + Args: + *args: Variable number of values to check + + Returns: + True if all arguments are truthy + """ + return all(bool(arg) for arg in args) + + +def not_func(value: Any) -> bool: + """Logical NOT - returns the opposite boolean value. + + This is equivalent to the PowerFx Not() function. + + Args: + value: The value to negate + + Returns: + True if value is falsy, False if truthy + """ + return not bool(value) + + +def count_rows(table: Any) -> int: + """Count the number of rows/items in a table/list. + + This is equivalent to the PowerFx CountRows() function. + + Args: + table: A list or table-like object + + Returns: + The number of rows/items + """ + if table is None: + return 0 + if isinstance(table, (list, tuple)): + return len(cast(list[Any], table)) + if isinstance(table, dict): + return len(cast(dict[str, Any], table)) + return 0 + + +def first(table: Any) -> Any: + """Get the first item from a table/list. + + This is equivalent to the PowerFx First() function. + + Args: + table: A list or table-like object + + Returns: + The first item, or None if empty + """ + if table is None: + return None + if isinstance(table, (list, tuple)): + table_list = cast(list[Any], table) + if len(table_list) > 0: + return table_list[0] + return None + + +def last(table: Any) -> Any: + """Get the last item from a table/list. + + This is equivalent to the PowerFx Last() function. + + Args: + table: A list or table-like object + + Returns: + The last item, or None if empty + """ + if table is None: + return None + if isinstance(table, (list, tuple)): + table_list = cast(list[Any], table) + if len(table_list) > 0: + return table_list[-1] + return None + + +def find(substring: str | None, text: str | None) -> int | None: + """Find the position of a substring within text. + + This is equivalent to the PowerFx Find() function. + Returns None (Blank) if not found, otherwise 1-based index. 
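+
+    For example (indices are 1-based, matching PowerFx):
+
+    .. code-block:: python
+
+        find("lo", "hello")  # Returns: 4
+        find("zz", "hello")  # Returns: None (Blank)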
+ + Args: + substring: The substring to find + text: The text to search in + + Returns: + 1-based index if found, None (Blank) if not found + """ + if substring is None or text is None: + return None + try: + index = str(text).find(str(substring)) + return index + 1 if index >= 0 else None + except (TypeError, ValueError): + return None + + +def upper(text: str | None) -> str: + """Convert text to uppercase. + + This is equivalent to the PowerFx Upper() function. + + Args: + text: The text to convert + + Returns: + Uppercase text + """ + if text is None: + return "" + return str(text).upper() + + +def lower(text: str | None) -> str: + """Convert text to lowercase. + + This is equivalent to the PowerFx Lower() function. + + Args: + text: The text to convert + + Returns: + Lowercase text + """ + if text is None: + return "" + return str(text).lower() + + +def concat_strings(*args: Any) -> str: + """Concatenate multiple string arguments. + + This is equivalent to the PowerFx Concat() function for string concatenation. + + Args: + *args: Variable number of values to concatenate + + Returns: + Concatenated string + """ + return "".join(str(arg) if arg is not None else "" for arg in args) + + +def concat_text(table: Any, field: str | None = None, separator: str = "") -> str: + """Concatenate values from a table/list. + + This is equivalent to the PowerFx Concat() function. + + Args: + table: A list of items + field: Optional field name to extract from each item + separator: Separator between values + + Returns: + Concatenated string + """ + if table is None: + return "" + if not isinstance(table, (list, tuple)): + return str(table) + + values: list[str] = [] + for item in cast(list[Any], table): + value: Any = None + if field and isinstance(item, dict): + item_dict = cast(dict[str, Any], item) + value = item_dict.get(field, "") + elif field and hasattr(item, field): + value = getattr(item, field, "") + else: + value = item + values.append(str(value) if value is not None else "") + + return separator.join(values) + + +def for_all(table: Any, expression: str, field_mapping: dict[str, str] | None = None) -> list[Any]: + """Apply an expression to each row of a table. + + This is equivalent to the PowerFx ForAll() function. + + Args: + table: A list of records + expression: A string expression that references item fields + field_mapping: Optional dict mapping placeholder names to field names + + Returns: + List of results from applying expression to each row + + Note: + The expression can use field names directly from the record. + For example: ForAll(items, "$" & name & ": " & description) + """ + if table is None or not isinstance(table, (list, tuple)): + return [] + + results: list[Any] = [] + for item in cast(list[Any], table): + # If item is a dict, we can directly substitute field values + if isinstance(item, dict): + item_dict = cast(dict[str, Any], item) + # The expression is typically already evaluated by the expression parser + # This function primarily handles table iteration + # Return the item itself for further processing + results.append(item_dict) + else: + results.append(item) + + return results + + +def search_table(table: Any, value: Any, column: str) -> list[Any]: + """Search for rows in a table where a column matches a value. + + This is equivalent to the PowerFx Search() function. 
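+
+    For example (matching is a case-insensitive substring test):
+
+    .. code-block:: python
+
+        search_table([{"name": "Alice"}, {"name": "Bob"}], "ali", "name")
+        # Returns: [{'name': 'Alice'}]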
+ + Args: + table: A list of records + value: The value to search for + column: The column name to search in + + Returns: + List of matching records + """ + if table is None or not isinstance(table, (list, tuple)): + return [] + + results: list[Any] = [] + search_value = str(value).lower() if value else "" + + for item in cast(list[Any], table): + item_value: Any = None + if isinstance(item, dict): + item_dict = cast(dict[str, Any], item) + item_value = item_dict.get(column, "") + elif hasattr(item, column): + item_value = getattr(item, column, "") + else: + continue + + # Case-insensitive contains search + if search_value in str(item_value).lower(): + results.append(item) + + return results + + +# Registry of custom functions +CUSTOM_FUNCTIONS: dict[str, Any] = { + "MessageText": message_text, + "UserMessage": user_message, + "AssistantMessage": assistant_message, + "AgentMessage": agent_message, # .NET compatibility alias for AssistantMessage + "SystemMessage": system_message, + "If": if_func, + "IsBlank": is_blank, + "Or": or_func, + "And": and_func, + "Not": not_func, + "CountRows": count_rows, + "First": first, + "Last": last, + "Find": find, + "Upper": upper, + "Lower": lower, + "Concat": concat_strings, + "Search": search_table, + "ForAll": for_all, +} diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_state.py b/python/packages/declarative/agent_framework_declarative/_workflows/_state.py new file mode 100644 index 0000000000..9fe57b83f5 --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_state.py @@ -0,0 +1,643 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""WorkflowState manages PowerFx variables during declarative workflow execution. + +This module provides state management for declarative workflows, handling: +- Workflow inputs (read-only) +- Turn-scoped variables +- Workflow outputs +- Agent results and context +""" + +from collections.abc import Mapping +from typing import Any, cast + +from agent_framework import get_logger + +try: + from powerfx import Engine + + _powerfx_engine: Engine | None = Engine() +except (ImportError, RuntimeError): + # ImportError: powerfx package not installed + # RuntimeError: .NET runtime not available or misconfigured + _powerfx_engine = None + +logger = get_logger("agent_framework.declarative.workflows") + + +class WorkflowState: + """Manages variables and state during declarative workflow execution. + + WorkflowState provides a unified interface for: + + - Reading workflow inputs (immutable after initialization) + - Managing Local-scoped variables that persist across actions + - Storing agent results and making them available to subsequent actions + - Evaluating PowerFx expressions with the current state as context + + The state is organized into namespaces that mirror the .NET implementation: + + - Workflow.Inputs: Initial inputs to the workflow + - Workflow.Outputs: Values to be returned from the workflow + - Local: Variables that persist within the current workflow turn + - System: System-level variables (ConversationId, LastMessage, etc.) + - Agent: Results from the most recent agent invocation + - Conversation: Conversation history and messages + + Examples: + .. 
code-block:: python + + from agent_framework_declarative import WorkflowState + + # Initialize with inputs + state = WorkflowState(inputs={"query": "Hello", "user_id": "123"}) + + # Access inputs (read-only) + query = state.get("Workflow.Inputs.query") # "Hello" + + # Set Local-scoped variables + state.set("Local.results", []) + state.append("Local.results", "item1") + state.append("Local.results", "item2") + + # Set workflow outputs + state.set("Workflow.Outputs.response", "Completed") + + .. code-block:: python + + from agent_framework_declarative import WorkflowState + + # PowerFx expression evaluation + state = WorkflowState(inputs={"name": "World"}) + result = state.eval("=Concat('Hello ', Workflow.Inputs.name)") + # result: "Hello World" + + # Non-PowerFx strings are returned as-is + plain = state.eval("Hello World") + # plain: "Hello World" + + .. code-block:: python + + from agent_framework_declarative import WorkflowState + + # Working with agent results + state = WorkflowState() + state.set_agent_result( + text="The answer is 42.", + messages=[], + tool_calls=[], + ) + + # Access agent result in subsequent actions + response = state.get("Agent.text") # "The answer is 42." + """ + + def __init__( + self, + inputs: Mapping[str, Any] | None = None, + ) -> None: + """Initialize workflow state with optional inputs. + + Args: + inputs: Initial inputs to the workflow. These become available + as Workflow.Inputs.* and are immutable after initialization. + """ + self._inputs: dict[str, Any] = dict(inputs) if inputs else {} + self._local: dict[str, Any] = {} + self._outputs: dict[str, Any] = {} + self._system: dict[str, Any] = { + "ConversationId": "default", + "LastMessage": {"Text": "", "Id": ""}, + "LastMessageText": "", + "LastMessageId": "", + } + self._agent: dict[str, Any] = {} + self._conversation: dict[str, Any] = { + "messages": [], + "history": [], + } + self._custom: dict[str, Any] = {} + + @property + def inputs(self) -> Mapping[str, Any]: + """Get the workflow inputs (read-only).""" + return self._inputs + + @property + def outputs(self) -> dict[str, Any]: + """Get the workflow outputs.""" + return self._outputs + + @property + def local(self) -> dict[str, Any]: + """Get the Local-scoped variables.""" + return self._local + + @property + def system(self) -> dict[str, Any]: + """Get the System-scoped variables.""" + return self._system + + @property + def agent(self) -> dict[str, Any]: + """Get the most recent agent result.""" + return self._agent + + @property + def conversation(self) -> dict[str, Any]: + """Get the conversation state.""" + return self._conversation + + def get(self, path: str, default: Any = None) -> Any: + """Get a value from the state using a dot-notated path. 
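+
+        For example:
+
+        .. code-block:: python
+
+            state.set("Local.user", {"name": "Ada"})
+            state.get("Local.user.name")  # Returns: 'Ada'
+            state.get("Local.missing", default=0)  # Returns: 0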
+ + Args: + path: Dot-notated path like 'Local.results' or 'Workflow.Inputs.query' + default: Default value if path doesn't exist + + Returns: + The value at the path, or default if not found + """ + parts = path.split(".") + if not parts: + return default + + namespace = parts[0] + remaining = parts[1:] + + # Handle Workflow.Inputs and Workflow.Outputs specially + if namespace == "Workflow" and remaining: + sub_namespace = remaining[0] + remaining = remaining[1:] + if sub_namespace == "Inputs": + obj: Any = self._inputs + elif sub_namespace == "Outputs": + obj = self._outputs + else: + return default + elif namespace == "Local": + obj = self._local + elif namespace == "System": + obj = self._system + elif namespace == "Agent": + obj = self._agent + elif namespace == "Conversation": + obj = self._conversation + else: + # Try custom namespace + obj = self._custom.get(namespace, default) + if obj is default: + return default + + # Navigate the remaining path + for part in remaining: + if isinstance(obj, dict): + obj_dict: dict[str, Any] = cast(dict[str, Any], obj) + obj = obj_dict.get(part, default) + if obj is default: + return default + elif hasattr(obj, part): + obj = getattr(obj, part) + else: + return default + + return obj + + def set(self, path: str, value: Any) -> None: + """Set a value in the state using a dot-notated path. + + Args: + path: Dot-notated path like 'Local.results' or 'Workflow.Outputs.response' + value: The value to set + + Raises: + ValueError: If attempting to set Workflow.Inputs (which is read-only) + """ + parts = path.split(".") + if not parts: + return + + namespace = parts[0] + remaining = parts[1:] + + # Handle Workflow.Inputs and Workflow.Outputs specially + if namespace == "Workflow": + if not remaining: + raise ValueError("Cannot set 'Workflow' directly; use 'Workflow.Outputs.*'") + sub_namespace = remaining[0] + remaining = remaining[1:] + if sub_namespace == "Inputs": + raise ValueError("Cannot modify Workflow.Inputs - they are read-only") + if sub_namespace == "Outputs": + target = self._outputs + else: + raise ValueError(f"Unknown Workflow namespace: {sub_namespace}") + elif namespace == "Local": + target = self._local + elif namespace == "System": + target = self._system + elif namespace == "Agent": + target = self._agent + elif namespace == "Conversation": + target = self._conversation + else: + # Create or use custom namespace + if namespace not in self._custom: + self._custom[namespace] = {} + target = self._custom[namespace] + + # Navigate to the parent and set the value + if not remaining: + # Setting the namespace root itself - this shouldn't happen normally + raise ValueError(f"Cannot replace entire namespace '{namespace}'") + + # Navigate to parent, creating dicts as needed + for part in remaining[:-1]: + if part not in target: + target[part] = {} + target = target[part] + + # Set the final value + target[remaining[-1]] = value + + def append(self, path: str, value: Any) -> None: + """Append a value to a list at the specified path. + + If the path doesn't exist, creates a new list with the value. + If the path exists but isn't a list, raises ValueError. 
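+
+        For example:
+
+        .. code-block:: python
+
+            state.append("Local.items", "a")  # Creates ['a']
+            state.append("Local.items", "b")  # Extends to ['a', 'b']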
+ + Args: + path: Dot-notated path to a list + value: The value to append + + Raises: + ValueError: If the existing value is not a list + """ + existing = self.get(path) + if existing is None: + self.set(path, [value]) + elif isinstance(existing, list): + existing.append(value) + self.set(path, existing) + else: + raise ValueError(f"Cannot append to non-list at path '{path}'") + + def set_agent_result( + self, + text: str | None = None, + messages: list[Any] | None = None, + tool_calls: list[Any] | None = None, + **kwargs: Any, + ) -> None: + """Set the result from the most recent agent invocation. + + This updates the 'agent' namespace with the agent's response, + making it available to subsequent actions via agent.text, agent.messages, etc. + + Args: + text: The text content of the agent's response + messages: The messages from the agent + tool_calls: Any tool calls made by the agent + **kwargs: Additional result data + """ + self._agent = { + "text": text, + "messages": messages or [], + "toolCalls": tool_calls or [], + **kwargs, + } + + def add_conversation_message(self, message: Any) -> None: + """Add a message to the conversation history. + + Args: + message: The message to add (typically a ChatMessage or similar) + """ + self._conversation["messages"].append(message) + self._conversation["history"].append(message) + + def to_powerfx_symbols(self) -> dict[str, Any]: + """Convert the current state to a PowerFx symbols dictionary. + + Returns: + A dictionary suitable for passing to PowerFx Engine.eval() + """ + symbols = { + "Workflow": { + "Inputs": dict(self._inputs), + "Outputs": dict(self._outputs), + }, + "Local": dict(self._local), + "System": dict(self._system), + "Agent": dict(self._agent), + "Conversation": dict(self._conversation), + # Also expose inputs at top level for backward compatibility with =inputs.X syntax + "inputs": dict(self._inputs), + **self._custom, + } + # Debug log the Local symbols to help diagnose type issues + if self._local: + for key, value in self._local.items(): + logger.debug( + f"PowerFx symbol Local.{key}: type={type(value).__name__}, " + f"value_preview={str(value)[:100] if value else None}" + ) + return symbols + + def eval(self, expression: str) -> Any: + """Evaluate a PowerFx expression with the current state. + + Expressions starting with '=' are evaluated as PowerFx. + Other strings are returned as-is (after variable interpolation if applicable). + + Args: + expression: The expression to evaluate + + Returns: + The evaluated result, or the original expression if not a PowerFx expression + """ + if not expression: + return expression + + if not expression.startswith("="): + return expression + + # Strip the leading '=' for evaluation + formula = expression[1:] + + if _powerfx_engine is not None: + # Try PowerFx evaluation first + try: + symbols = self.to_powerfx_symbols() + return _powerfx_engine.eval(formula, symbols=symbols) + except Exception as exc: + logger.warning(f"PowerFx evaluation failed for '{expression[:50]}': {exc}") + # Fall through to simple evaluation + + # Fallback: Simple expression evaluation using custom functions + return self._eval_simple(formula) + + def _eval_simple(self, formula: str) -> Any: + """Simple expression evaluation when PowerFx is not available. + + Supports: + - Variable references: Local.X, System.X, Workflow.Inputs.X + - Simple function calls: IsBlank(x), Find(a, b), etc. + - Simple comparisons: x < 4, x = "value" + - Logical operators: And, Or, Not, ||, ! 
+        - Negation: !expression
+
+        Args:
+            formula: The formula to evaluate (without leading '=')
+
+        Returns:
+            The evaluated result
+        """
+        from ._powerfx_functions import CUSTOM_FUNCTIONS
+
+        formula = formula.strip()
+
+        # Handle negation prefix
+        if formula.startswith("!"):
+            inner = formula[1:].strip()
+            result = self._eval_simple(inner)
+            return not bool(result)
+
+        # Handle Not() function. The balance check prevents mis-parsing
+        # operator expressions like "Not(a) || Not(b)" as a single call.
+        if formula.startswith("Not(") and formula.endswith(")") and self._is_balanced(formula[4:-1]):
+            inner = formula[4:-1].strip()
+            result = self._eval_simple(inner)
+            return not bool(result)
+
+        # Handle function calls. The startswith/endswith test alone would
+        # mis-parse operator expressions such as "IsBlank(a) || IsBlank(b)",
+        # which also start with a function name and end with ')', so we
+        # additionally require the argument string to be parenthesis-balanced
+        # before treating the formula as a single call.
+        for func_name, func in CUSTOM_FUNCTIONS.items():
+            if formula.startswith(f"{func_name}(") and formula.endswith(")"):
+                args_str = formula[len(func_name) + 1 : -1]
+                if not self._is_balanced(args_str):
+                    continue
+                # Simple argument parsing (doesn't handle nested calls well)
+                args = self._parse_function_args(args_str)
+                evaluated_args = [self._eval_simple(arg) if isinstance(arg, str) else arg for arg in args]
+                try:
+                    return func(*evaluated_args)
+                except Exception as e:
+                    logger.warning(f"Function {func_name} failed: {e}")
+                    return formula
+
+        # Handle And operator
+        if " And " in formula:
+            parts = formula.split(" And ", 1)
+            left = self._eval_simple(parts[0])
+            right = self._eval_simple(parts[1])
+            return bool(left) and bool(right)
+
+        # Handle Or operator (||)
+        if " || " in formula or " Or " in formula:
+            parts = formula.split(" || ", 1) if " || " in formula else formula.split(" Or ", 1)
+            left = self._eval_simple(parts[0])
+            right = self._eval_simple(parts[1])
+            return bool(left) or bool(right)
+
+        # Handle comparison operators
+        for op in [" < ", " > ", " <= ", " >= ", " <> ", " = "]:
+            if op in formula:
+                parts = formula.split(op, 1)
+                left = self._eval_simple(parts[0].strip())
+                right = self._eval_simple(parts[1].strip())
+                if op == " < ":
+                    return left < right
+                if op == " > ":
+                    return left > right
+                if op == " <= ":
+                    return left <= right
+                if op == " >= ":
+                    return left >= right
+                if op == " <> ":
+                    return left != right
+                if op == " = ":
+                    return left == right
+
+        # Handle arithmetic operators
+        if " + " in formula:
+            parts = formula.split(" + ", 1)
+            left = self._eval_simple(parts[0].strip())
+            right = self._eval_simple(parts[1].strip())
+            # Treat None as 0 for arithmetic (PowerFx behavior)
+            if left is None:
+                left = 0
+            if right is None:
+                right = 0
+            # Try numeric addition first, fall back to string concat
+            try:
+                return float(left) + float(right)
+            except (ValueError, TypeError):
+                return str(left) + str(right)
+
+        if " - " in formula:
+            parts = formula.split(" - ", 1)
+            left = self._eval_simple(parts[0].strip())
+            right = self._eval_simple(parts[1].strip())
+            # Treat None as 0 for arithmetic (PowerFx behavior)
+            if left is None:
+                left = 0
+            if right is None:
+                right = 0
+            try:
+                return float(left) - float(right)
+            except (ValueError, TypeError):
+                return formula
+
+        # Handle multiplication
+        if " * " in formula:
+            parts = formula.split(" * ", 1)
+            left = self._eval_simple(parts[0].strip())
+            right = self._eval_simple(parts[1].strip())
+            # Treat None as 0 for arithmetic (PowerFx behavior)
+            if left is None:
+                left = 0
+            if right is None:
+                right = 0
+            try:
+                return float(left) * float(right)
+            except (ValueError, TypeError):
+                return formula
+
+        # Handle division with div-by-zero protection
+        if " / " in formula:
+            parts = formula.split(" / ", 1)
+            left = self._eval_simple(parts[0].strip())
+            right = self._eval_simple(parts[1].strip())
+            # Treat None as 0 for arithmetic (PowerFx behavior)
+            if left is None:
+                left = 0
+            if right is None:
+                right = 0
+            try:
+                right_float = float(right)
+                if right_float == 0:
+                    # PowerFx returns Error for division by zero; we return None (Blank)
+                    logger.warning(f"Division by zero in expression: {formula}")
+                    return None
+                return float(left) / right_float
+            except (ValueError, TypeError):
+                return formula
+
+        # Handle string literals
+        if (formula.startswith('"') and formula.endswith('"')) or (formula.startswith("'") and formula.endswith("'")):
+            return formula[1:-1]
+
+        # Handle numeric literals
+        try:
+            if "." in formula:
+                return float(formula)
+            return int(formula)
+        except ValueError:
+            pass
+
+        # Handle boolean literals
+        if formula.lower() == "true":
+            return True
+        if formula.lower() == "false":
+            return False
+
+        # Handle variable references
+        if "." in formula:
+            # For known namespaces, return None if not found (PowerFx semantics)
+            # rather than the formula string
+            if formula.startswith(("Local.", "Workflow.", "Agent.", "Conversation.", "System.")):
+                return self.get(formula)
+            not_found = object()
+            value = self.get(formula, default=not_found)
+            if value is not not_found:
+                return value
+
+        # Return the formula as-is if we can't evaluate it
+        return formula
+
+    @staticmethod
+    def _is_balanced(args_str: str) -> bool:
+        """Check whether parentheses in an argument string are balanced.
+
+        Distinguishes a single call like Find(a, Concat(b, c)) from an
+        operator expression like IsBlank(a) || IsBlank(b): in the latter,
+        slicing off the outer parentheses leaves a string whose parenthesis
+        depth dips below zero.
+
+        Args:
+            args_str: The candidate argument string (without outer parentheses)
+
+        Returns:
+            True if parentheses are balanced and never close below depth zero
+        """
+        depth = 0
+        in_string = False
+        string_char = None
+        for char in args_str:
+            if in_string:
+                if char == string_char:
+                    in_string = False
+                    string_char = None
+            elif char in ('"', "'"):
+                in_string = True
+                string_char = char
+            elif char == "(":
+                depth += 1
+            elif char == ")":
+                depth -= 1
+                if depth < 0:
+                    return False
+        return depth == 0
+
+    def _parse_function_args(self, args_str: str) -> list[str]:
+        """Parse function arguments, handling nested parentheses and strings.
+
+        Args:
+            args_str: The argument string (without outer parentheses)
+
+        Returns:
+            List of argument strings
+        """
+        args: list[str] = []
+        current = ""
+        depth = 0
+        in_string = False
+        string_char = None
+
+        for char in args_str:
+            if char in ('"', "'") and not in_string:
+                in_string = True
+                string_char = char
+                current += char
+            elif char == string_char and in_string:
+                in_string = False
+                string_char = None
+                current += char
+            elif char == "(" and not in_string:
+                depth += 1
+                current += char
+            elif char == ")" and not in_string:
+                depth -= 1
+                current += char
+            elif char == "," and depth == 0 and not in_string:
+                args.append(current.strip())
+                current = ""
+            else:
+                current += char
+
+        if current.strip():
+            args.append(current.strip())
+
+        return args
+
+    def eval_if_expression(self, value: Any) -> Any:
+        """Evaluate a value if it's a PowerFx expression, otherwise return as-is.
+
+        This is a convenience method that handles both expressions and literals.
+
+        Args:
+            value: A value that may or may not be a PowerFx expression
+
+        Returns:
+            The evaluated result if it's an expression, or the original value
+        """
+        if isinstance(value, str):
+            return self.eval(value)
+        if isinstance(value, dict):
+            return {str(k): self.eval_if_expression(v) for k, v in value.items()}
+        if isinstance(value, list):
+            return [self.eval_if_expression(item) for item in value]
+        return value
+
+    def reset_local(self) -> None:
+        """Reset Local-scoped variables for a new turn.
+
+        This clears the Local namespace while preserving other state.
+        """
+        self._local.clear()
+
+    def reset_agent(self) -> None:
+        """Reset the agent result for a new agent invocation."""
+        self._agent.clear()
+
+    def clone(self) -> "WorkflowState":
+        """Create a shallow copy of the state.
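+
+        Note that this is a shallow copy: each namespace dict is copied with
+        copy.copy, so nested mutable values (lists, dicts) remain shared with
+        the original state.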
+ + Returns: + A new WorkflowState with copied data + """ + import copy + + new_state = WorkflowState() + new_state._inputs = copy.copy(self._inputs) + new_state._local = copy.copy(self._local) + new_state._system = copy.copy(self._system) + new_state._outputs = copy.copy(self._outputs) + new_state._agent = copy.copy(self._agent) + new_state._conversation = copy.copy(self._conversation) + new_state._custom = copy.copy(self._custom) + return new_state diff --git a/python/packages/declarative/pyproject.toml b/python/packages/declarative/pyproject.toml index 3b2e2586a9..052d4a60ac 100644 --- a/python/packages/declarative/pyproject.toml +++ b/python/packages/declarative/pyproject.toml @@ -4,7 +4,7 @@ description = "Declarative specification support for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/declarative/tests/conftest.py b/python/packages/declarative/tests/conftest.py new file mode 100644 index 0000000000..083de75660 --- /dev/null +++ b/python/packages/declarative/tests/conftest.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Pytest configuration for declarative tests.""" + +import sys + +import pytest + +# Skip all tests in this directory on Python 3.14+ because powerfx doesn't support it yet +if sys.version_info >= (3, 14): + collect_ignore_glob = ["test_*.py"] + + +def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: + """Skip all declarative tests on Python 3.14+ due to powerfx incompatibility.""" + if sys.version_info >= (3, 14): + skip_marker = pytest.mark.skip(reason="powerfx does not support Python 3.14+") + for item in items: + if "declarative" in str(item.fspath): + item.add_marker(skip_marker) diff --git a/python/packages/declarative/tests/test_additional_handlers.py b/python/packages/declarative/tests/test_additional_handlers.py new file mode 100644 index 0000000000..8eb5e40ee7 --- /dev/null +++ b/python/packages/declarative/tests/test_additional_handlers.py @@ -0,0 +1,348 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Tests for additional action handlers (conversation, variables, etc.).""" + +import pytest + +import agent_framework_declarative._workflows._actions_basic # noqa: F401 +import agent_framework_declarative._workflows._actions_control_flow # noqa: F401 +from agent_framework_declarative._workflows._handlers import get_action_handler +from agent_framework_declarative._workflows._state import WorkflowState + + +def create_action_context(action: dict, state: WorkflowState | None = None): + """Create a minimal action context for testing.""" + from agent_framework_declarative._workflows._handlers import ActionContext + + if state is None: + state = WorkflowState() + + async def execute_actions(actions, state): + for act in actions: + handler = get_action_handler(act.get("kind")) + if handler: + async for event in handler( + ActionContext( + state=state, + action=act, + execute_actions=execute_actions, + agents={}, + bindings={}, + ) + ): + yield event + + return ActionContext( + state=state, + action=action, + execute_actions=execute_actions, + agents={}, + bindings={}, + ) + + +class TestSetTextVariableHandler: + """Tests for SetTextVariable action handler.""" + + @pytest.mark.asyncio + async def test_set_text_variable_simple(self): + """Test setting a simple text variable.""" + ctx = create_action_context({ + "kind": "SetTextVariable", + "variable": "Local.greeting", + "value": "Hello, World!", + }) + + handler = get_action_handler("SetTextVariable") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.greeting") == "Hello, World!" + + @pytest.mark.asyncio + async def test_set_text_variable_with_interpolation(self): + """Test setting text with variable interpolation.""" + state = WorkflowState() + state.set("Local.name", "Alice") + + ctx = create_action_context( + { + "kind": "SetTextVariable", + "variable": "Local.message", + "value": "Hello, {Local.name}!", + }, + state=state, + ) + + handler = get_action_handler("SetTextVariable") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.message") == "Hello, Alice!" 
+ + +class TestResetVariableHandler: + """Tests for ResetVariable action handler.""" + + @pytest.mark.asyncio + async def test_reset_variable(self): + """Test resetting a variable to None.""" + state = WorkflowState() + state.set("Local.counter", 5) + + ctx = create_action_context( + { + "kind": "ResetVariable", + "variable": "Local.counter", + }, + state=state, + ) + + handler = get_action_handler("ResetVariable") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.counter") is None + + +class TestSetMultipleVariablesHandler: + """Tests for SetMultipleVariables action handler.""" + + @pytest.mark.asyncio + async def test_set_multiple_variables(self): + """Test setting multiple variables at once.""" + ctx = create_action_context({ + "kind": "SetMultipleVariables", + "variables": [ + {"variable": "Local.a", "value": 1}, + {"variable": "Local.b", "value": 2}, + {"variable": "Local.c", "value": "three"}, + ], + }) + + handler = get_action_handler("SetMultipleVariables") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.a") == 1 + assert ctx.state.get("Local.b") == 2 + assert ctx.state.get("Local.c") == "three" + + +class TestClearAllVariablesHandler: + """Tests for ClearAllVariables action handler.""" + + @pytest.mark.asyncio + async def test_clear_all_variables(self): + """Test clearing all turn-scoped variables.""" + state = WorkflowState() + state.set("Local.a", 1) + state.set("Local.b", 2) + state.set("Workflow.Outputs.result", "kept") + + ctx = create_action_context( + { + "kind": "ClearAllVariables", + }, + state=state, + ) + + handler = get_action_handler("ClearAllVariables") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.a") is None + assert ctx.state.get("Local.b") is None + # Workflow outputs should be preserved + assert ctx.state.get("Workflow.Outputs.result") == "kept" + + +class TestCreateConversationHandler: + """Tests for CreateConversation action handler.""" + + @pytest.mark.asyncio + async def test_create_conversation_with_output_binding(self): + """Test creating a new conversation with output variable binding. + + The conversationId field specifies the OUTPUT variable where the + auto-generated conversation ID is stored. 
+ """ + ctx = create_action_context({ + "kind": "CreateConversation", + "conversationId": "Local.myConvId", # Output variable + }) + + handler = get_action_handler("CreateConversation") + _events = [e async for e in handler(ctx)] # noqa: F841 + + # Check conversation was created with auto-generated ID + conversations = ctx.state.get("System.conversations") + assert conversations is not None + assert len(conversations) == 1 + + # Get the generated ID + generated_id = list(conversations.keys())[0] + assert conversations[generated_id]["messages"] == [] + + # Check output binding - the ID should be stored in the specified variable + assert ctx.state.get("Local.myConvId") == generated_id + + @pytest.mark.asyncio + async def test_create_conversation_legacy_output(self): + """Test creating a conversation with legacy output binding.""" + ctx = create_action_context({ + "kind": "CreateConversation", + "output": { + "conversationId": "Local.myConvId", + }, + }) + + handler = get_action_handler("CreateConversation") + _events = [e async for e in handler(ctx)] # noqa: F841 + + # Check conversation was created + conversations = ctx.state.get("System.conversations") + assert conversations is not None + assert len(conversations) == 1 + + # Get the generated ID + generated_id = list(conversations.keys())[0] + + # Check legacy output binding + assert ctx.state.get("Local.myConvId") == generated_id + + @pytest.mark.asyncio + async def test_create_conversation_auto_id(self): + """Test creating a conversation with auto-generated ID.""" + ctx = create_action_context({ + "kind": "CreateConversation", + }) + + handler = get_action_handler("CreateConversation") + _events = [e async for e in handler(ctx)] # noqa: F841 + + # Check conversation was created with some ID + conversations = ctx.state.get("System.conversations") + assert conversations is not None + assert len(conversations) == 1 + + +class TestAddConversationMessageHandler: + """Tests for AddConversationMessage action handler.""" + + @pytest.mark.asyncio + async def test_add_conversation_message(self): + """Test adding a message to a conversation.""" + state = WorkflowState() + state.set( + "System.conversations", + { + "conv-123": {"id": "conv-123", "messages": []}, + }, + ) + + ctx = create_action_context( + { + "kind": "AddConversationMessage", + "conversationId": "conv-123", + "message": { + "role": "user", + "content": "Hello!", + }, + }, + state=state, + ) + + handler = get_action_handler("AddConversationMessage") + _events = [e async for e in handler(ctx)] # noqa: F841 + + conversations = ctx.state.get("System.conversations") + assert len(conversations["conv-123"]["messages"]) == 1 + assert conversations["conv-123"]["messages"][0]["content"] == "Hello!" 
+ + +class TestEndWorkflowHandler: + """Tests for EndWorkflow action handler.""" + + @pytest.mark.asyncio + async def test_end_workflow_signal(self): + """Test that EndWorkflow emits correct signal.""" + from agent_framework_declarative._workflows._actions_control_flow import EndWorkflowSignal + + ctx = create_action_context({ + "kind": "EndWorkflow", + "reason": "Completed successfully", + }) + + handler = get_action_handler("EndWorkflow") + events = [e async for e in handler(ctx)] + + assert len(events) == 1 + assert isinstance(events[0], EndWorkflowSignal) + assert events[0].reason == "Completed successfully" + + +class TestEndConversationHandler: + """Tests for EndConversation action handler.""" + + @pytest.mark.asyncio + async def test_end_conversation_signal(self): + """Test that EndConversation emits correct signal.""" + from agent_framework_declarative._workflows._actions_control_flow import EndConversationSignal + + ctx = create_action_context({ + "kind": "EndConversation", + "conversationId": "conv-123", + }) + + handler = get_action_handler("EndConversation") + events = [e async for e in handler(ctx)] + + assert len(events) == 1 + assert isinstance(events[0], EndConversationSignal) + assert events[0].conversation_id == "conv-123" + + +class TestConditionGroupWithElseActions: + """Tests for ConditionGroup with elseActions.""" + + @pytest.mark.asyncio + async def test_condition_group_else_actions(self): + """Test that elseActions execute when no condition matches.""" + ctx = create_action_context({ + "kind": "ConditionGroup", + "conditions": [ + { + "condition": False, + "actions": [ + {"kind": "SetValue", "path": "Local.result", "value": "matched"}, + ], + }, + ], + "elseActions": [ + {"kind": "SetValue", "path": "Local.result", "value": "else"}, + ], + }) + + handler = get_action_handler("ConditionGroup") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.result") == "else" + + @pytest.mark.asyncio + async def test_condition_group_match_skips_else(self): + """Test that elseActions don't execute when a condition matches.""" + ctx = create_action_context({ + "kind": "ConditionGroup", + "conditions": [ + { + "condition": True, + "actions": [ + {"kind": "SetValue", "path": "Local.result", "value": "matched"}, + ], + }, + ], + "elseActions": [ + {"kind": "SetValue", "path": "Local.result", "value": "else"}, + ], + }) + + handler = get_action_handler("ConditionGroup") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.result") == "matched" diff --git a/python/packages/declarative/tests/test_declarative_loader.py b/python/packages/declarative/tests/test_declarative_loader.py index 338671b212..b0afb7d5cc 100644 --- a/python/packages/declarative/tests/test_declarative_loader.py +++ b/python/packages/declarative/tests/test_declarative_loader.py @@ -39,6 +39,13 @@ pytestmark = pytest.mark.skipif(sys.version_info >= (3, 14), reason="Skipping on Python 3.14+") +try: + import powerfx # noqa: F401 + + _powerfx_available = True +except (ImportError, RuntimeError): + _powerfx_available = False + @pytest.mark.parametrize( "yaml_content,expected_type,expected_attributes", @@ -456,6 +463,98 @@ def test_agent_schema_dispatch_agent_samples(yaml_file: Path, agent_samples_dir: assert result is not None, f"agent_schema_dispatch returned None for {yaml_file.relative_to(agent_samples_dir)}" +class TestAgentFactoryCreateFromDict: + """Tests for AgentFactory.create_agent_from_dict method.""" + + def 
test_create_agent_from_dict_parses_prompt_agent(self): + """Test that create_agent_from_dict correctly parses a PromptAgent definition.""" + from unittest.mock import MagicMock + + from agent_framework_declarative import AgentFactory + + agent_def = { + "kind": "Prompt", + "name": "TestAgent", + "description": "A test agent", + "instructions": "You are a helpful assistant.", + } + + # Use a pre-configured chat client to avoid needing model + mock_client = MagicMock() + mock_client.create_agent.return_value = MagicMock() + + factory = AgentFactory(chat_client=mock_client) + agent = factory.create_agent_from_dict(agent_def) + + assert agent is not None + + def test_create_agent_from_dict_matches_yaml(self): + """Test that create_agent_from_dict produces same result as create_agent_from_yaml.""" + from unittest.mock import MagicMock + + from agent_framework_declarative import AgentFactory + + yaml_content = """ +kind: Prompt +name: TestAgent +description: A test agent +instructions: You are a helpful assistant. +""" + + agent_def = { + "kind": "Prompt", + "name": "TestAgent", + "description": "A test agent", + "instructions": "You are a helpful assistant.", + } + + # Use a pre-configured chat client to avoid needing model + mock_client = MagicMock() + mock_client.create_agent.return_value = MagicMock() + + factory = AgentFactory(chat_client=mock_client) + + # Create from YAML string + agent_from_yaml = factory.create_agent_from_yaml(yaml_content) + + # Create from dict + agent_from_dict = factory.create_agent_from_dict(agent_def) + + # Both should produce agents with same name + assert agent_from_yaml.name == agent_from_dict.name + assert agent_from_yaml.description == agent_from_dict.description + + def test_create_agent_from_dict_invalid_kind_raises(self): + """Test that non-PromptAgent kind raises DeclarativeLoaderError.""" + from agent_framework_declarative import AgentFactory + from agent_framework_declarative._loader import DeclarativeLoaderError + + # Resource kind (not PromptAgent) + agent_def = { + "kind": "Resource", + "name": "TestResource", + } + + factory = AgentFactory() + with pytest.raises(DeclarativeLoaderError, match="Only definitions for a PromptAgent are supported"): + factory.create_agent_from_dict(agent_def) + + def test_create_agent_from_dict_without_model_or_client_raises(self): + """Test that missing both model and chat_client raises DeclarativeLoaderError.""" + from agent_framework_declarative import AgentFactory + from agent_framework_declarative._loader import DeclarativeLoaderError + + agent_def = { + "kind": "Prompt", + "name": "TestAgent", + "instructions": "You are helpful.", + } + + factory = AgentFactory() + with pytest.raises(DeclarativeLoaderError, match="ChatClient must be provided"): + factory.create_agent_from_dict(agent_def) + + class TestAgentFactorySafeMode: """Tests for AgentFactory safe_mode parameter.""" @@ -499,6 +598,7 @@ def test_agent_factory_safe_mode_blocks_env_in_yaml(self, monkeypatch): # The description should NOT be resolved from env (PowerFx fails, returns original) assert agent.description == "=Env.TEST_DESCRIPTION" + @pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") def test_agent_factory_safe_mode_false_allows_env_in_yaml(self, monkeypatch): """Test that safe_mode=False allows environment variable access in YAML parsing.""" from unittest.mock import MagicMock @@ -558,6 +658,7 @@ def test_agent_factory_safe_mode_with_api_key_connection(self, monkeypatch): finally: _safe_mode_context.reset(token) + 
@pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") def test_agent_factory_safe_mode_false_resolves_api_key(self, monkeypatch): """Test safe_mode=False resolves API key from environment.""" from agent_framework_declarative._models import _safe_mode_context diff --git a/python/packages/declarative/tests/test_declarative_models.py b/python/packages/declarative/tests/test_declarative_models.py index 7f7357eda1..f7a56f2c96 100644 --- a/python/packages/declarative/tests/test_declarative_models.py +++ b/python/packages/declarative/tests/test_declarative_models.py @@ -839,6 +839,16 @@ def test_environment_variable_from_dict(self): assert env_var.value == "secret123" +# Check if PowerFx is available +try: + from powerfx import Engine as _PfxEngine + + _PfxEngine() + _powerfx_available = True +except (ImportError, RuntimeError): + _powerfx_available = False + + class TestTryPowerfxEval: """Tests for _try_powerfx_eval function.""" @@ -856,6 +866,7 @@ def test_empty_string_returns_empty(self): """Test that empty strings are returned as empty.""" assert _try_powerfx_eval("") == "" + @pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") def test_simple_powerfx_expressions(self): """Test simple PowerFx expressions.""" from decimal import Decimal @@ -868,6 +879,7 @@ def test_simple_powerfx_expressions(self): assert _try_powerfx_eval('="hello"') == "hello" assert _try_powerfx_eval('="test value"') == "test value" + @pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") def test_env_variable_access(self, monkeypatch): """Test accessing environment variables using =Env. pattern.""" # Set up test environment variables @@ -885,6 +897,7 @@ def test_env_variable_access(self, monkeypatch): finally: _safe_mode_context.reset(token) + @pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") def test_env_variable_with_string_concatenation(self, monkeypatch): """Test env variables with string concatenation operator.""" monkeypatch.setenv("BASE_URL", "https://api.example.com") @@ -903,6 +916,7 @@ def test_env_variable_with_string_concatenation(self, monkeypatch): finally: _safe_mode_context.reset(token) + @pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") def test_string_comparison_operators(self, monkeypatch): """Test PowerFx string comparison operators.""" monkeypatch.setenv("ENV_MODE", "production") @@ -920,6 +934,7 @@ def test_string_comparison_operators(self, monkeypatch): finally: _safe_mode_context.reset(token) + @pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") def test_string_in_operator(self): """Test PowerFx 'in' operator for substring testing (case-insensitive).""" # Substring test - case insensitive - returns bool @@ -927,6 +942,7 @@ def test_string_in_operator(self): assert _try_powerfx_eval('="THE" in "The keyboard and the monitor"') is True assert _try_powerfx_eval('="xyz" in "The keyboard and the monitor"') is False + @pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") def test_string_exactin_operator(self): """Test PowerFx 'exactin' operator for substring testing (case-sensitive).""" # Substring test - case sensitive - returns bool @@ -934,6 +950,7 @@ def test_string_exactin_operator(self): assert _try_powerfx_eval('="windows" exactin "To display windows in the Windows operating system"') is True assert _try_powerfx_eval('="WINDOWS" exactin "To display windows in the Windows operating system"') 
is False + @pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") def test_logical_operators_with_strings(self): """Test PowerFx logical operators (And, Or, Not) with string comparisons.""" # And operator - returns bool @@ -957,6 +974,7 @@ def test_logical_operators_with_strings(self): # ! operator (alternative syntax) - returns bool assert _try_powerfx_eval('=!("a" = "b")') is True + @pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") def test_parentheses_for_precedence(self): """Test using parentheses to control operator precedence.""" from decimal import Decimal @@ -969,6 +987,7 @@ def test_parentheses_for_precedence(self): result = _try_powerfx_eval('=("a" = "a" Or "b" = "c") And "d" = "d"') assert result is True + @pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") def test_env_with_special_characters(self, monkeypatch): """Test env variables containing special characters in values.""" monkeypatch.setenv("URL_WITH_QUERY", "https://example.com?param=value") @@ -999,6 +1018,7 @@ def test_safe_mode_blocks_env_access(self, monkeypatch): finally: _safe_mode_context.reset(token) + @pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") def test_safe_mode_context_isolation(self, monkeypatch): """Test that safe_mode context variable properly isolates env access.""" monkeypatch.setenv("TEST_VAR", "test_value") diff --git a/python/packages/declarative/tests/test_external_input.py b/python/packages/declarative/tests/test_external_input.py new file mode 100644 index 0000000000..bbe55fd174 --- /dev/null +++ b/python/packages/declarative/tests/test_external_input.py @@ -0,0 +1,286 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for human-in-the-loop action handlers.""" + +import pytest + +from agent_framework_declarative._workflows._handlers import ActionContext, get_action_handler +from agent_framework_declarative._workflows._human_input import ( + QuestionRequest, + process_external_loop, + validate_input_response, +) +from agent_framework_declarative._workflows._state import WorkflowState + + +def create_action_context(action: dict, state: WorkflowState | None = None): + """Create a minimal action context for testing.""" + if state is None: + state = WorkflowState() + + async def execute_actions(actions, state): + for act in actions: + handler = get_action_handler(act.get("kind")) + if handler: + async for event in handler( + ActionContext( + state=state, + action=act, + execute_actions=execute_actions, + agents={}, + bindings={}, + ) + ): + yield event + + return ActionContext( + state=state, + action=action, + execute_actions=execute_actions, + agents={}, + bindings={}, + ) + + +class TestQuestionHandler: + """Tests for Question action handler.""" + + @pytest.mark.asyncio + async def test_question_emits_request_info_event(self): + """Test that Question handler emits QuestionRequest.""" + ctx = create_action_context({ + "kind": "Question", + "id": "ask_name", + "variable": "Local.userName", + "prompt": "What is your name?", + }) + + handler = get_action_handler("Question") + events = [e async for e in handler(ctx)] + + assert len(events) == 1 + assert isinstance(events[0], QuestionRequest) + assert events[0].request_id == "ask_name" + assert events[0].prompt == "What is your name?" 
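+        # `variable` mirrors the action's target path, i.e. where the host is
+        # expected to store the eventual answer.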
+ assert events[0].variable == "Local.userName" + + @pytest.mark.asyncio + async def test_question_with_choices(self): + """Test Question with multiple choice options.""" + ctx = create_action_context({ + "kind": "Question", + "id": "ask_choice", + "variable": "Local.selection", + "prompt": "Select an option:", + "choices": ["Option A", "Option B", "Option C"], + "default": "Option A", + }) + + handler = get_action_handler("Question") + events = [e async for e in handler(ctx)] + + assert len(events) == 1 + event = events[0] + assert isinstance(event, QuestionRequest) + assert event.choices == ["Option A", "Option B", "Option C"] + assert event.default_value == "Option A" + + @pytest.mark.asyncio + async def test_question_with_validation(self): + """Test Question with validation rules.""" + ctx = create_action_context({ + "kind": "Question", + "id": "ask_email", + "variable": "Local.email", + "prompt": "Enter your email:", + "validation": { + "required": True, + "pattern": r"^[\w\.-]+@[\w\.-]+\.\w+$", + }, + }) + + handler = get_action_handler("Question") + events = [e async for e in handler(ctx)] + + assert len(events) == 1 + event = events[0] + assert event.validation == { + "required": True, + "pattern": r"^[\w\.-]+@[\w\.-]+\.\w+$", + } + + +class TestRequestExternalInputHandler: + """Tests for RequestExternalInput action handler.""" + + @pytest.mark.asyncio + async def test_request_external_input(self): + """Test RequestExternalInput handler emits event.""" + ctx = create_action_context({ + "kind": "RequestExternalInput", + "id": "get_approval", + "variable": "Local.approval", + "prompt": "Please approve or reject", + "timeout": 300, + }) + + handler = get_action_handler("RequestExternalInput") + events = [e async for e in handler(ctx)] + + assert len(events) == 1 + event = events[0] + assert isinstance(event, QuestionRequest) + assert event.request_id == "get_approval" + assert event.variable == "Local.approval" + assert event.validation == {"timeout": 300} + + +class TestWaitForInputHandler: + """Tests for WaitForInput action handler.""" + + @pytest.mark.asyncio + async def test_wait_for_input(self): + """Test WaitForInput handler.""" + ctx = create_action_context({ + "kind": "WaitForInput", + "id": "wait", + "variable": "Local.response", + "message": "Waiting...", + }) + + handler = get_action_handler("WaitForInput") + events = [e async for e in handler(ctx)] + + assert len(events) == 1 + event = events[0] + assert isinstance(event, QuestionRequest) + assert event.request_id == "wait" + assert event.prompt == "Waiting..." 
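+
+
+# Illustrative sketch only, not one of the suite's tests: one plausible way a
+# host could combine the helpers exercised above. `_example_first_valid_answer`
+# and its `answers` argument are hypothetical stand-ins for however real user
+# input actually arrives.
+def _example_first_valid_answer(answers):
+    """Return the first answer that passes a sample validation spec, else None."""
+    validation = {"required": True, "minLength": 3}
+    for answer in answers:
+        is_valid, _error = validate_input_response(answer, validation)
+        if is_valid:
+            return answer
+    return None
+
+
+# Example: _example_first_valid_answer(["", "hi", "hello"]) == "hello"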
+ + +class TestProcessExternalLoop: + """Tests for process_external_loop helper function.""" + + def test_no_external_loop(self): + """Test when no external loop is configured.""" + state = WorkflowState() + result, expr = process_external_loop({}, state) + + assert result is False + assert expr is None + + def test_external_loop_true_condition(self): + """Test when external loop condition evaluates to true.""" + state = WorkflowState() + state.set("Local.isComplete", False) + + input_config = { + "externalLoop": { + "when": "=!Local.isComplete", + }, + } + + result, expr = process_external_loop(input_config, state) + + # !False = True, so loop should continue + assert result is True + assert expr == "=!Local.isComplete" + + def test_external_loop_false_condition(self): + """Test when external loop condition evaluates to false.""" + state = WorkflowState() + state.set("Local.isComplete", True) + + input_config = { + "externalLoop": { + "when": "=!Local.isComplete", + }, + } + + result, expr = process_external_loop(input_config, state) + + # !True = False, so loop should stop + assert result is False + + +class TestValidateInputResponse: + """Tests for validate_input_response helper function.""" + + def test_no_validation(self): + """Test with no validation rules.""" + is_valid, error = validate_input_response("any value", None) + assert is_valid is True + assert error is None + + def test_required_valid(self): + """Test required validation with valid value.""" + is_valid, error = validate_input_response("value", {"required": True}) + assert is_valid is True + assert error is None + + def test_required_empty_string(self): + """Test required validation with empty string.""" + is_valid, error = validate_input_response("", {"required": True}) + assert is_valid is False + assert "required" in error.lower() + + def test_required_none(self): + """Test required validation with None.""" + is_valid, error = validate_input_response(None, {"required": True}) + assert is_valid is False + assert "required" in error.lower() + + def test_min_length_valid(self): + """Test minLength validation with valid value.""" + is_valid, error = validate_input_response("hello", {"minLength": 3}) + assert is_valid is True + + def test_min_length_invalid(self): + """Test minLength validation with too short value.""" + is_valid, error = validate_input_response("hi", {"minLength": 3}) + assert is_valid is False + assert "minimum length" in error.lower() + + def test_max_length_valid(self): + """Test maxLength validation with valid value.""" + is_valid, error = validate_input_response("hello", {"maxLength": 10}) + assert is_valid is True + + def test_max_length_invalid(self): + """Test maxLength validation with too long value.""" + is_valid, error = validate_input_response("hello world", {"maxLength": 5}) + assert is_valid is False + assert "maximum length" in error.lower() + + def test_min_value_valid(self): + """Test min validation for numbers.""" + is_valid, error = validate_input_response(10, {"min": 5}) + assert is_valid is True + + def test_min_value_invalid(self): + """Test min validation with too small number.""" + is_valid, error = validate_input_response(3, {"min": 5}) + assert is_valid is False + assert "minimum value" in error.lower() + + def test_max_value_valid(self): + """Test max validation for numbers.""" + is_valid, error = validate_input_response(5, {"max": 10}) + assert is_valid is True + + def test_max_value_invalid(self): + """Test max validation with too large number.""" + is_valid, error = 
validate_input_response(15, {"max": 10}) + assert is_valid is False + assert "maximum value" in error.lower() + + def test_pattern_valid(self): + """Test pattern validation with matching value.""" + is_valid, error = validate_input_response("test@example.com", {"pattern": r"^[\w\.-]+@[\w\.-]+\.\w+$"}) + assert is_valid is True + + def test_pattern_invalid(self): + """Test pattern validation with non-matching value.""" + is_valid, error = validate_input_response("not-an-email", {"pattern": r"^[\w\.-]+@[\w\.-]+\.\w+$"}) + assert is_valid is False + assert "pattern" in error.lower() diff --git a/python/packages/declarative/tests/test_graph_coverage.py b/python/packages/declarative/tests/test_graph_coverage.py new file mode 100644 index 0000000000..8f9211e850 --- /dev/null +++ b/python/packages/declarative/tests/test_graph_coverage.py @@ -0,0 +1,2682 @@ +# Copyright (c) Microsoft. All rights reserved. +# pyright: reportUnknownParameterType=false, reportUnknownArgumentType=false +# pyright: reportMissingParameterType=false, reportUnknownMemberType=false +# pyright: reportPrivateUsage=false, reportUnknownVariableType=false +# pyright: reportGeneralTypeIssues=false + +from dataclasses import dataclass +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agent_framework_declarative._workflows import ( + ActionComplete, + ActionTrigger, + DeclarativeWorkflowState, +) +from agent_framework_declarative._workflows._declarative_base import ( + ConditionResult, + LoopControl, + LoopIterationResult, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_shared_state() -> MagicMock: + """Create a mock shared state with async get/set/delete methods.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key: str) -> Any: + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key: str, value: Any) -> None: + shared_state._data[key] = value + + async def mock_delete(key: str) -> None: + if key in shared_state._data: + del shared_state._data[key] + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + shared_state.delete = AsyncMock(side_effect=mock_delete) + + return shared_state + + +@pytest.fixture +def mock_context(mock_shared_state: MagicMock) -> MagicMock: + """Create a mock workflow context.""" + ctx = MagicMock() + ctx.shared_state = mock_shared_state + ctx.send_message = AsyncMock() + ctx.yield_output = AsyncMock() + ctx.request_info = AsyncMock() + return ctx + + +# --------------------------------------------------------------------------- +# DeclarativeWorkflowState Tests - Covering _base.py gaps +# --------------------------------------------------------------------------- + + +class TestDeclarativeWorkflowStateExtended: + """Extended tests for DeclarativeWorkflowState covering uncovered code paths.""" + + async def test_get_with_local_namespace(self, mock_shared_state): + """Test Local. namespace mapping.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.myVar", "value123") + + # Access via Local. namespace + result = await state.get("Local.myVar") + assert result == "value123" + + async def test_get_with_system_namespace(self, mock_shared_state): + """Test System. 
namespace mapping."""
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize()
+        await state.set("System.ConversationId", "conv-123")
+
+        result = await state.get("System.ConversationId")
+        assert result == "conv-123"
+
+    async def test_get_with_workflow_namespace(self, mock_shared_state):
+        """Test Workflow. namespace mapping."""
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize({"query": "test"})
+
+        result = await state.get("Workflow.Inputs.query")
+        assert result == "test"
+
+    async def test_get_with_inputs_shorthand(self, mock_shared_state):
+        """Test that initialize-time inputs are readable under Workflow.Inputs."""
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize({"query": "test"})
+
+        result = await state.get("Workflow.Inputs.query")
+        assert result == "test"
+
+    async def test_get_agent_namespace(self, mock_shared_state):
+        """Test agent namespace access."""
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize()
+        await state.set("Agent.response", "Hello!")
+
+        result = await state.get("Agent.response")
+        assert result == "Hello!"
+
+    async def test_get_conversation_namespace(self, mock_shared_state):
+        """Test conversation namespace access."""
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize()
+        await state.set("Conversation.messages", [{"role": "user", "text": "hi"}])
+
+        result = await state.get("Conversation.messages")
+        assert result == [{"role": "user", "text": "hi"}]
+
+    async def test_get_custom_namespace(self, mock_shared_state):
+        """Test custom namespace access."""
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize()
+
+        # Set via direct state data manipulation to create custom namespace
+        state_data = await state.get_state_data()
+        state_data["Custom"] = {"myns": {"value": 42}}
+        await state.set_state_data(state_data)
+
+        result = await state.get("myns.value")
+        assert result == 42
+
+    async def test_get_object_attribute_access(self, mock_shared_state):
+        """Test accessing object attributes via hasattr/getattr path."""
+
+        @dataclass
+        class MockObj:
+            name: str
+            value: int
+
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize()
+        await state.set("Local.obj", MockObj(name="test", value=99))
+
+        result = await state.get("Local.obj.name")
+        assert result == "test"
+
+    async def test_set_with_local_namespace(self, mock_shared_state):
+        """Test Local. namespace mapping for set."""
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize()
+
+        await state.set("Local.myVar", "value123")
+        result = await state.get("Local.myVar")
+        assert result == "value123"
+
+    async def test_set_with_system_namespace(self, mock_shared_state):
+        """Test System. 
namespace mapping for set.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + await state.set("System.ConversationId", "conv-456") + result = await state.get("System.ConversationId") + assert result == "conv-456" + + async def test_set_workflow_outputs(self, mock_shared_state): + """Test setting workflow outputs.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + await state.set("Workflow.Outputs.result", "done") + outputs = await state.get("Workflow.Outputs") + assert outputs.get("result") == "done" + + async def test_set_workflow_inputs_raises_error(self, mock_shared_state): + """Test that setting Workflow.Inputs raises an error (read-only).""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize({"query": "test"}) + + with pytest.raises(ValueError, match="Cannot modify Workflow.Inputs"): + await state.set("Workflow.Inputs.query", "modified") + + async def test_set_workflow_directly_raises_error(self, mock_shared_state): + """Test that setting 'Workflow' directly raises an error.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + with pytest.raises(ValueError, match="Cannot set 'Workflow' directly"): + await state.set("Workflow", {}) + + async def test_set_unknown_workflow_subnamespace_raises_error(self, mock_shared_state): + """Test unknown workflow sub-namespace raises error.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + with pytest.raises(ValueError, match="Unknown Workflow namespace"): + await state.set("Workflow.unknown.field", "value") + + async def test_set_creates_custom_namespace(self, mock_shared_state): + """Test setting value in custom namespace creates it.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + await state.set("myns.field.nested", "value") + result = await state.get("myns.field.nested") + assert result == "value" + + async def test_set_cannot_replace_entire_namespace(self, mock_shared_state): + """Test that replacing entire namespace raises error.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + with pytest.raises(ValueError, match="Cannot replace entire namespace"): + await state.set("turn", {}) + + async def test_append_to_nonlist_raises_error(self, mock_shared_state): + """Test appending to non-list raises error.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.scalar", "string value") + + with pytest.raises(ValueError, match="Cannot append to non-list"): + await state.append("Local.scalar", "new item") + + async def test_eval_empty_string(self, mock_shared_state): + """Test evaluating empty string returns as-is.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + result = await state.eval("") + assert result == "" + + async def test_eval_non_string_returns_as_is(self, mock_shared_state): + """Test evaluating non-string returns as-is.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Cast to Any to test the runtime behavior with non-string inputs + result = await state.eval(42) # type: ignore[arg-type] + assert result == 42 + + result = await state.eval([1, 2, 3]) # type: ignore[arg-type] + assert result == [1, 2, 3] + + async def test_eval_simple_and_operator(self, mock_shared_state): + """Test simple And operator evaluation.""" + state = 
DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.a", True) + await state.set("Local.b", False) + + result = await state.eval("=Local.a And Local.b") + assert result is False + + await state.set("Local.b", True) + result = await state.eval("=Local.a And Local.b") + assert result is True + + async def test_eval_simple_or_operator(self, mock_shared_state): + """Test simple Or operator evaluation.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.a", True) + await state.set("Local.b", False) + + result = await state.eval("=Local.a Or Local.b") + assert result is True + + await state.set("Local.a", False) + result = await state.eval("=Local.a Or Local.b") + assert result is False + + async def test_eval_negation(self, mock_shared_state): + """Test negation (!) evaluation.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.flag", True) + + result = await state.eval("=!Local.flag") + assert result is False + + async def test_eval_not_function(self, mock_shared_state): + """Test Not() function evaluation.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.flag", True) + + result = await state.eval("=Not(Local.flag)") + assert result is False + + async def test_eval_comparison_operators(self, mock_shared_state): + """Test comparison operators.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.x", 5) + await state.set("Local.y", 10) + + assert await state.eval("=Local.x < Local.y") is True + assert await state.eval("=Local.x > Local.y") is False + assert await state.eval("=Local.x <= 5") is True + assert await state.eval("=Local.x >= 5") is True + assert await state.eval("=Local.x <> Local.y") is True + assert await state.eval("=Local.x = 5") is True + + async def test_eval_arithmetic_operators(self, mock_shared_state): + """Test arithmetic operators.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.x", 10) + await state.set("Local.y", 3) + + assert await state.eval("=Local.x + Local.y") == 13 + assert await state.eval("=Local.x - Local.y") == 7 + assert await state.eval("=Local.x * Local.y") == 30 + assert await state.eval("=Local.x / Local.y") == pytest.approx(3.333, rel=0.01) + + async def test_eval_string_literal(self, mock_shared_state): + """Test string literal evaluation.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + result = await state.eval('="hello world"') + assert result == "hello world" + + async def test_eval_float_literal(self, mock_shared_state): + """Test float literal evaluation.""" + from decimal import Decimal + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + result = await state.eval("=3.14") + # Accepts both float (Python fallback) and Decimal (pythonnet/PowerFx) + assert result == 3.14 or result == Decimal("3.14") + + async def test_eval_variable_reference_with_namespace_mappings(self, mock_shared_state): + """Test variable reference with PowerFx symbols.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize({"query": "test"}) + await state.set("Local.myVar", "localValue") + + # Test Local namespace (PowerFx symbol) + result = await state.eval("=Local.myVar") + assert result == "localValue" + + # Test Workflow.Inputs (PowerFx symbol) + 
result = await state.eval("=Workflow.Inputs.query") + assert result == "test" + + async def test_eval_if_expression_with_dict(self, mock_shared_state): + """Test eval_if_expression recursively evaluates dicts.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.name", "Alice") + + result = await state.eval_if_expression({"greeting": "=Local.name", "static": "hello"}) + assert result == {"greeting": "Alice", "static": "hello"} + + async def test_eval_if_expression_with_list(self, mock_shared_state): + """Test eval_if_expression recursively evaluates lists.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.x", 10) + + result = await state.eval_if_expression(["=Local.x", "static", "=5"]) + assert result == [10, "static", 5] + + async def test_interpolate_string_with_local_vars(self, mock_shared_state): + """Test string interpolation with Local. variables.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.TicketId", "TKT-001") + await state.set("Local.TeamName", "Support") + + result = await state.interpolate_string("Created ticket #{Local.TicketId} for team {Local.TeamName}") + assert result == "Created ticket #TKT-001 for team Support" + + async def test_interpolate_string_with_system_vars(self, mock_shared_state): + """Test string interpolation with System. variables.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("System.ConversationId", "conv-789") + + result = await state.interpolate_string("Conversation: {System.ConversationId}") + assert result == "Conversation: conv-789" + + async def test_interpolate_string_with_none_value(self, mock_shared_state): + """Test string interpolation with None value returns empty string.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + result = await state.interpolate_string("Value: {Local.Missing}") + assert result == "Value: " + + +# --------------------------------------------------------------------------- +# Basic Executors Tests - Covering _executors_basic.py gaps +# --------------------------------------------------------------------------- + + +class TestBasicExecutorsCoverage: + """Tests for basic executors covering uncovered code paths.""" + + async def test_set_variable_executor(self, mock_context, mock_shared_state): + """Test SetVariableExecutor (distinct from SetValueExecutor).""" + from agent_framework_declarative._workflows._executors_basic import ( + SetVariableExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "SetVariable", + "variable": "Local.result", + "value": "test value", + } + executor = SetVariableExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.result") + assert result == "test value" + + async def test_set_variable_executor_with_nested_variable(self, mock_context, mock_shared_state): + """Test SetVariableExecutor with nested variable object.""" + from agent_framework_declarative._workflows._executors_basic import ( + SetVariableExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "SetVariable", + "variable": {"path": "Local.nested"}, + "value": 42, + } + executor = SetVariableExecutor(action_def) + await executor.handle_action(ActionTrigger(), 
mock_context) + + result = await state.get("Local.nested") + assert result == 42 + + async def test_set_text_variable_executor(self, mock_context, mock_shared_state): + """Test SetTextVariableExecutor.""" + from agent_framework_declarative._workflows._executors_basic import ( + SetTextVariableExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.name", "World") + + action_def = { + "kind": "SetTextVariable", + "variable": "Local.greeting", + "text": "=Local.name", + } + executor = SetTextVariableExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.greeting") + assert result == "World" + + async def test_set_multiple_variables_executor(self, mock_context, mock_shared_state): + """Test SetMultipleVariablesExecutor.""" + from agent_framework_declarative._workflows._executors_basic import ( + SetMultipleVariablesExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "SetMultipleVariables", + "assignments": [ + {"variable": "Local.a", "value": 1}, + {"variable": {"path": "Local.b"}, "value": 2}, + {"path": "Local.c", "value": 3}, + ], + } + executor = SetMultipleVariablesExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + assert await state.get("Local.a") == 1 + assert await state.get("Local.b") == 2 + assert await state.get("Local.c") == 3 + + async def test_append_value_executor(self, mock_context, mock_shared_state): + """Test AppendValueExecutor.""" + from agent_framework_declarative._workflows._executors_basic import ( + AppendValueExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.items", ["a"]) + + action_def = { + "kind": "AppendValue", + "path": "Local.items", + "value": "b", + } + executor = AppendValueExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.items") + assert result == ["a", "b"] + + async def test_reset_variable_executor(self, mock_context, mock_shared_state): + """Test ResetVariableExecutor.""" + from agent_framework_declarative._workflows._executors_basic import ( + ResetVariableExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.myVar", "some value") + + action_def = { + "kind": "ResetVariable", + "variable": "Local.myVar", + } + executor = ResetVariableExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.myVar") + assert result is None + + async def test_clear_all_variables_executor(self, mock_context, mock_shared_state): + """Test ClearAllVariablesExecutor.""" + from agent_framework_declarative._workflows._executors_basic import ( + ClearAllVariablesExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.a", 1) + await state.set("Local.b", 2) + + action_def = {"kind": "ClearAllVariables"} + executor = ClearAllVariablesExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + # Turn namespace should be cleared + assert await state.get("Local.a") is None + assert await state.get("Local.b") is None + + async def test_send_activity_with_dict_activity(self, mock_context, mock_shared_state): + """Test SendActivityExecutor with dict activity containing text field.""" + from 
agent_framework_declarative._workflows._executors_basic import ( + SendActivityExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.name", "Alice") + + action_def = { + "kind": "SendActivity", + "activity": {"text": "Hello, {Local.name}!"}, + } + executor = SendActivityExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.yield_output.assert_called_once_with("Hello, Alice!") + + async def test_send_activity_with_string_activity(self, mock_context, mock_shared_state): + """Test SendActivityExecutor with string activity.""" + from agent_framework_declarative._workflows._executors_basic import ( + SendActivityExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "SendActivity", + "activity": "Plain text message", + } + executor = SendActivityExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.yield_output.assert_called_once_with("Plain text message") + + async def test_send_activity_with_expression(self, mock_context, mock_shared_state): + """Test SendActivityExecutor evaluates expressions.""" + from agent_framework_declarative._workflows._executors_basic import ( + SendActivityExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.msg", "Dynamic message") + + action_def = { + "kind": "SendActivity", + "activity": "=Local.msg", + } + executor = SendActivityExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.yield_output.assert_called_once_with("Dynamic message") + + async def test_emit_event_executor_graph_mode(self, mock_context, mock_shared_state): + """Test EmitEventExecutor with graph-mode schema (eventName/eventValue).""" + from agent_framework_declarative._workflows._executors_basic import ( + EmitEventExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "EmitEvent", + "eventName": "myEvent", + "eventValue": {"key": "value"}, + } + executor = EmitEventExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.yield_output.assert_called_once() + event_data = mock_context.yield_output.call_args[0][0] + assert event_data["eventName"] == "myEvent" + assert event_data["eventValue"] == {"key": "value"} + + async def test_emit_event_executor_interpreter_mode(self, mock_context, mock_shared_state): + """Test EmitEventExecutor with interpreter-mode schema (event.name/event.data).""" + from agent_framework_declarative._workflows._executors_basic import ( + EmitEventExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "EmitEvent", + "event": { + "name": "interpreterEvent", + "data": {"payload": "test"}, + }, + } + executor = EmitEventExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.yield_output.assert_called_once() + event_data = mock_context.yield_output.call_args[0][0] + assert event_data["eventName"] == "interpreterEvent" + assert event_data["eventValue"] == {"payload": "test"} + + +# --------------------------------------------------------------------------- +# Agent Executors Tests - Covering _executors_agents.py gaps +# --------------------------------------------------------------------------- + + +class 
TestAgentExecutorsCoverage: + """Tests for agent executors covering uncovered code paths.""" + + async def test_normalize_variable_path_all_cases(self): + """Test _normalize_variable_path with all namespace prefixes.""" + from agent_framework_declarative._workflows._executors_agents import ( + _normalize_variable_path, + ) + + # Local. -> Local. (unchanged) + assert _normalize_variable_path("Local.MyVar") == "Local.MyVar" + + # System. -> System. (unchanged) + assert _normalize_variable_path("System.ConvId") == "System.ConvId" + + # Workflow. -> Workflow. (unchanged) + assert _normalize_variable_path("Workflow.Outputs.result") == "Workflow.Outputs.result" + + # Already has a namespace with dots - pass through + assert _normalize_variable_path("custom.existing") == "custom.existing" + + # No namespace - default to Local. + assert _normalize_variable_path("simpleVar") == "Local.simpleVar" + + async def test_agent_executor_get_agent_name_string(self, mock_context, mock_shared_state): + """Test agent name extraction from simple string config.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + action_def = { + "kind": "InvokeAzureAgent", + "agent": "MyAgent", + } + executor = InvokeAzureAgentExecutor(action_def) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + name = executor._get_agent_name(state) + assert name == "MyAgent" + + async def test_agent_executor_get_agent_name_dict(self, mock_context, mock_shared_state): + """Test agent name extraction from nested dict config.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + action_def = { + "kind": "InvokeAzureAgent", + "agent": {"name": "NestedAgent"}, + } + executor = InvokeAzureAgentExecutor(action_def) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + name = executor._get_agent_name(state) + assert name == "NestedAgent" + + async def test_agent_executor_get_agent_name_legacy(self, mock_context, mock_shared_state): + """Test agent name extraction from agentName (legacy).""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + action_def = { + "kind": "InvokeAzureAgent", + "agentName": "LegacyAgent", + } + executor = InvokeAzureAgentExecutor(action_def) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + name = executor._get_agent_name(state) + assert name == "LegacyAgent" + + async def test_agent_executor_get_input_config_simple(self, mock_context, mock_shared_state): + """Test input config parsing with simple non-dict input.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + action_def = { + "kind": "InvokeAzureAgent", + "agent": "TestAgent", + "input": "simple string input", + } + executor = InvokeAzureAgentExecutor(action_def) + + args, messages, external_loop, max_iterations = executor._get_input_config() + assert args == {} + assert messages == "simple string input" + assert external_loop is None + assert max_iterations == 100 # Default + + async def test_agent_executor_get_input_config_full(self, mock_context, mock_shared_state): + """Test input config parsing with full structured input.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + action_def = { + "kind": "InvokeAzureAgent", + "agent": "TestAgent", + "input": { + "arguments": 
{"param1": "=Local.value"}, + "messages": "=conversation.messages", + "externalLoop": {"when": "=Local.needsMore", "maxIterations": 50}, + }, + } + executor = InvokeAzureAgentExecutor(action_def) + + args, messages, external_loop, max_iterations = executor._get_input_config() + assert args == {"param1": "=Local.value"} + assert messages == "=conversation.messages" + assert external_loop == "=Local.needsMore" + assert max_iterations == 50 + + async def test_agent_executor_get_output_config_simple(self, mock_context, mock_shared_state): + """Test output config parsing with simple resultProperty.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + action_def = { + "kind": "InvokeAzureAgent", + "agent": "TestAgent", + "resultProperty": "Local.result", + } + executor = InvokeAzureAgentExecutor(action_def) + + messages_var, response_obj, result_prop, auto_send = executor._get_output_config() + assert messages_var is None + assert response_obj is None + assert result_prop == "Local.result" + assert auto_send is True + + async def test_agent_executor_get_output_config_full(self, mock_context, mock_shared_state): + """Test output config parsing with full structured output.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + action_def = { + "kind": "InvokeAzureAgent", + "agent": "TestAgent", + "output": { + "messages": "Local.ResponseMessages", + "responseObject": "Local.ParsedResponse", + "property": "Local.result", + "autoSend": False, + }, + } + executor = InvokeAzureAgentExecutor(action_def) + + messages_var, response_obj, result_prop, auto_send = executor._get_output_config() + assert messages_var == "Local.ResponseMessages" + assert response_obj == "Local.ParsedResponse" + assert result_prop == "Local.result" + assert auto_send is False + + async def test_agent_executor_build_input_text_from_string_messages(self, mock_context, mock_shared_state): + """Test _build_input_text with string messages expression.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.userInput", "Hello agent!") + + action_def = {"kind": "InvokeAzureAgent", "agent": "Test"} + executor = InvokeAzureAgentExecutor(action_def) + + input_text = await executor._build_input_text(state, {}, "=Local.userInput") + assert input_text == "Hello agent!" 
+ + async def test_agent_executor_build_input_text_from_message_list(self, mock_context, mock_shared_state): + """Test _build_input_text extracts text from message list.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set( + "Conversation.messages", + [ + {"role": "user", "content": "First"}, + {"role": "assistant", "content": "Response"}, + {"role": "user", "content": "Last message"}, + ], + ) + + action_def = {"kind": "InvokeAzureAgent", "agent": "Test"} + executor = InvokeAzureAgentExecutor(action_def) + + input_text = await executor._build_input_text(state, {}, "=Conversation.messages") + assert input_text == "Last message" + + async def test_agent_executor_build_input_text_from_message_with_text_attr(self, mock_context, mock_shared_state): + """Test _build_input_text extracts text from message with text attribute.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.messages", [{"text": "From attribute"}]) + + action_def = {"kind": "InvokeAzureAgent", "agent": "Test"} + executor = InvokeAzureAgentExecutor(action_def) + + input_text = await executor._build_input_text(state, {}, "=Local.messages") + assert input_text == "From attribute" + + async def test_agent_executor_build_input_text_fallback_chain(self, mock_context, mock_shared_state): + """Test _build_input_text fallback chain when no messages expression.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize({"query": "workflow input"}) + + action_def = {"kind": "InvokeAzureAgent", "agent": "Test"} + executor = InvokeAzureAgentExecutor(action_def) + + # No messages_expr, so falls back to workflow.inputs + input_text = await executor._build_input_text(state, {}, None) + assert input_text == "workflow input" + + async def test_agent_executor_build_input_text_from_system_last_message(self, mock_context, mock_shared_state): + """Test _build_input_text falls back to system.LastMessage.Text.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("System.LastMessage", {"Text": "From last message"}) + + action_def = {"kind": "InvokeAzureAgent", "agent": "Test"} + executor = InvokeAzureAgentExecutor(action_def) + + input_text = await executor._build_input_text(state, {}, None) + assert input_text == "From last message" + + async def test_agent_executor_missing_agent_name(self, mock_context, mock_shared_state): + """Test agent executor with missing agent name logs warning.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = {"kind": "InvokeAzureAgent"} # No agent specified + executor = InvokeAzureAgentExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + # Should complete without error + mock_context.send_message.assert_called_once() + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, ActionComplete) + + async def 
test_agent_executor_with_working_agent(self, mock_context, mock_shared_state): + """Test agent executor with a working mock agent.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + # Create mock agent + @dataclass + class MockResult: + text: str + messages: list[Any] + + mock_agent = MagicMock() + mock_agent.run = AsyncMock(return_value=MockResult(text="Agent response", messages=[])) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.input", "User query") + + action_def = { + "kind": "InvokeAzureAgent", + "agent": "TestAgent", + "resultProperty": "Local.result", + } + executor = InvokeAzureAgentExecutor(action_def, agents={"TestAgent": mock_agent}) + + await executor.handle_action(ActionTrigger(), mock_context) + + # Verify agent was called + mock_agent.run.assert_called_once() + + # Verify result was stored + result = await state.get("Local.result") + assert result == "Agent response" + + # Verify agent state was set + assert await state.get("Agent.response") == "Agent response" + assert await state.get("Agent.name") == "TestAgent" + assert await state.get("Agent.text") == "Agent response" + + async def test_agent_executor_with_agent_from_registry(self, mock_context, mock_shared_state): + """Test agent executor retrieves agent from shared state registry.""" + from agent_framework_declarative._workflows._executors_agents import ( + AGENT_REGISTRY_KEY, + InvokeAzureAgentExecutor, + ) + + # Create mock agent + @dataclass + class MockResult: + text: str + messages: list[Any] + + mock_agent = MagicMock() + mock_agent.run = AsyncMock(return_value=MockResult(text="Registry agent", messages=[])) + + # Store in registry + mock_shared_state._data[AGENT_REGISTRY_KEY] = {"RegistryAgent": mock_agent} + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.input", "Query") + + action_def = { + "kind": "InvokeAzureAgent", + "agent": "RegistryAgent", + } + executor = InvokeAzureAgentExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + mock_agent.run.assert_called_once() + + async def test_agent_executor_parses_json_response(self, mock_context, mock_shared_state): + """Test agent executor parses JSON response into responseObject.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + @dataclass + class MockResult: + text: str + messages: list[Any] + + mock_agent = MagicMock() + mock_agent.run = AsyncMock(return_value=MockResult(text='{"status": "ok", "count": 42}', messages=[])) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.input", "Query") + + action_def = { + "kind": "InvokeAzureAgent", + "agent": "TestAgent", + "output": { + "responseObject": "Local.Parsed", + }, + } + executor = InvokeAzureAgentExecutor(action_def, agents={"TestAgent": mock_agent}) + + await executor.handle_action(ActionTrigger(), mock_context) + + parsed = await state.get("Local.Parsed") + assert parsed == {"status": "ok", "count": 42} + + async def test_invoke_tool_executor_not_found(self, mock_context, mock_shared_state): + """Test InvokeToolExecutor when tool not found.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeToolExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "InvokeTool", + "tool": 
"MissingTool", + "resultProperty": "Local.result", + } + executor = InvokeToolExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.result") + assert result == {"error": "Tool 'MissingTool' not found in registry"} + + async def test_invoke_tool_executor_sync_tool(self, mock_context, mock_shared_state): + """Test InvokeToolExecutor with synchronous tool.""" + from agent_framework_declarative._workflows._executors_agents import ( + TOOL_REGISTRY_KEY, + InvokeToolExecutor, + ) + + def my_tool(x: int, y: int) -> int: + return x + y + + mock_shared_state._data[TOOL_REGISTRY_KEY] = {"add": my_tool} + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "InvokeTool", + "tool": "add", + "parameters": {"x": 5, "y": 3}, + "resultProperty": "Local.result", + } + executor = InvokeToolExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.result") + assert result == 8 + + async def test_invoke_tool_executor_async_tool(self, mock_context, mock_shared_state): + """Test InvokeToolExecutor with asynchronous tool.""" + from agent_framework_declarative._workflows._executors_agents import ( + TOOL_REGISTRY_KEY, + InvokeToolExecutor, + ) + + async def my_async_tool(input: str) -> str: + return f"Processed: {input}" + + mock_shared_state._data[TOOL_REGISTRY_KEY] = {"process": my_async_tool} + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "InvokeTool", + "tool": "process", + "input": "test data", + "resultProperty": "Local.result", + } + executor = InvokeToolExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.result") + assert result == "Processed: test data" + + +# --------------------------------------------------------------------------- +# Control Flow Executors Tests - Additional coverage +# --------------------------------------------------------------------------- + + +class TestControlFlowCoverage: + """Tests for control flow executors covering uncovered code paths.""" + + async def test_foreach_with_source_alias(self, mock_context, mock_shared_state): + """Test ForeachInitExecutor with 'source' alias (interpreter mode).""" + from agent_framework_declarative._workflows._executors_control_flow import ( + ForeachInitExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.data", [10, 20, 30]) + + action_def = { + "kind": "Foreach", + "source": "=Local.data", + "itemName": "item", + "indexName": "idx", + } + executor = ForeachInitExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, LoopIterationResult) + assert msg.has_next is True + assert msg.current_item == 10 + assert msg.current_index == 0 + + async def test_foreach_next_continues_iteration(self, mock_context, mock_shared_state): + """Test ForeachNextExecutor continues to next item.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + LOOP_STATE_KEY, + ForeachNextExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.data", ["a", "b", "c"]) + + # Set up loop state as ForeachInitExecutor would + state_data = await state.get_state_data() + state_data[LOOP_STATE_KEY] = { 
+ "foreach_init": { + "items": ["a", "b", "c"], + "index": 0, + "length": 3, + } + } + await state.set_state_data(state_data) + + action_def = { + "kind": "Foreach", + "itemsSource": "=Local.data", + "iteratorVariable": "Local.item", + } + executor = ForeachNextExecutor(action_def, init_executor_id="foreach_init") + + await executor.handle_action(LoopIterationResult(has_next=True), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, LoopIterationResult) + assert msg.current_index == 1 + assert msg.current_item == "b" + + async def test_switch_evaluator_with_value_cases(self, mock_context, mock_shared_state): + """Test SwitchEvaluatorExecutor with value/cases schema.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + SwitchEvaluatorExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.status", "pending") + + action_def = { + "kind": "Switch", + "value": "=Local.status", + } + cases = [ + {"match": "active"}, + {"match": "pending"}, + ] + executor = SwitchEvaluatorExecutor(action_def, cases=cases) + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, ConditionResult) + assert msg.matched is True + assert msg.branch_index == 1 # Second case matched + + async def test_switch_evaluator_default_case(self, mock_context, mock_shared_state): + """Test SwitchEvaluatorExecutor falls through to default.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + SwitchEvaluatorExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.status", "unknown") + + action_def = { + "kind": "Switch", + "value": "=Local.status", + } + cases = [ + {"match": "active"}, + {"match": "pending"}, + ] + executor = SwitchEvaluatorExecutor(action_def, cases=cases) + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, ConditionResult) + assert msg.matched is False + assert msg.branch_index == -1 # Default case + + async def test_switch_evaluator_no_value(self, mock_context, mock_shared_state): + """Test SwitchEvaluatorExecutor with no value defaults to else.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + SwitchEvaluatorExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = {"kind": "Switch"} # No value + cases = [{"match": "x"}] + executor = SwitchEvaluatorExecutor(action_def, cases=cases) + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, ConditionResult) + assert msg.branch_index == -1 + + async def test_join_executor_accepts_condition_result(self, mock_context, mock_shared_state): + """Test JoinExecutor accepts ConditionResult as trigger.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + JoinExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = {"kind": "_Join"} + executor = JoinExecutor(action_def) + + # Trigger with ConditionResult + await executor.handle_action(ConditionResult(matched=True, branch_index=0), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, ActionComplete) + + async def 
test_break_loop_executor(self, mock_context, mock_shared_state): + """Test BreakLoopExecutor emits LoopControl.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + BreakLoopExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = {"kind": "BreakLoop"} + executor = BreakLoopExecutor(action_def, loop_next_executor_id="loop_next") + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, LoopControl) + assert msg.action == "break" + + async def test_continue_loop_executor(self, mock_context, mock_shared_state): + """Test ContinueLoopExecutor emits LoopControl.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + ContinueLoopExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = {"kind": "ContinueLoop"} + executor = ContinueLoopExecutor(action_def, loop_next_executor_id="loop_next") + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, LoopControl) + assert msg.action == "continue" + + async def test_foreach_next_no_loop_state(self, mock_context, mock_shared_state): + """Test ForeachNextExecutor with missing loop state.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + ForeachNextExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "Foreach", + "itemsSource": "=Local.data", + "iteratorVariable": "Local.item", + } + executor = ForeachNextExecutor(action_def, init_executor_id="missing_loop") + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, LoopIterationResult) + assert msg.has_next is False + + async def test_foreach_next_loop_complete(self, mock_context, mock_shared_state): + """Test ForeachNextExecutor when loop is complete.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + LOOP_STATE_KEY, + ForeachNextExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Set up loop state at last item + state_data = await state.get_state_data() + state_data[LOOP_STATE_KEY] = { + "loop_id": { + "items": ["a", "b"], + "index": 1, # Already at last item + "length": 2, + } + } + await state.set_state_data(state_data) + + action_def = { + "kind": "Foreach", + "itemsSource": "=Local.data", + "iteratorVariable": "Local.item", + } + executor = ForeachNextExecutor(action_def, init_executor_id="loop_id") + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, LoopIterationResult) + assert msg.has_next is False + + async def test_foreach_next_handle_break_control(self, mock_context, mock_shared_state): + """Test ForeachNextExecutor handles break LoopControl.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + LOOP_STATE_KEY, + ForeachNextExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Set up loop state + state_data = await state.get_state_data() + state_data[LOOP_STATE_KEY] = { + "loop_id": { + "items": ["a", "b", "c"], + "index": 0, + "length": 3, + } + } + await state.set_state_data(state_data) + + action_def = { + "kind": 
"Foreach", + "itemsSource": "=Local.data", + "iteratorVariable": "Local.item", + } + executor = ForeachNextExecutor(action_def, init_executor_id="loop_id") + + await executor.handle_loop_control(LoopControl(action="break"), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, LoopIterationResult) + assert msg.has_next is False + + async def test_foreach_next_handle_continue_control(self, mock_context, mock_shared_state): + """Test ForeachNextExecutor handles continue LoopControl.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + LOOP_STATE_KEY, + ForeachNextExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Set up loop state + state_data = await state.get_state_data() + state_data[LOOP_STATE_KEY] = { + "loop_id": { + "items": ["a", "b", "c"], + "index": 0, + "length": 3, + } + } + await state.set_state_data(state_data) + + action_def = { + "kind": "Foreach", + "itemsSource": "=Local.data", + "iteratorVariable": "Local.item", + } + executor = ForeachNextExecutor(action_def, init_executor_id="loop_id") + + await executor.handle_loop_control(LoopControl(action="continue"), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, LoopIterationResult) + assert msg.has_next is True + assert msg.current_index == 1 + + async def test_end_workflow_executor(self, mock_context, mock_shared_state): + """Test EndWorkflowExecutor does not send continuation.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + EndWorkflowExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = {"kind": "EndWorkflow"} + executor = EndWorkflowExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + # Should NOT send any message + mock_context.send_message.assert_not_called() + + async def test_end_conversation_executor(self, mock_context, mock_shared_state): + """Test EndConversationExecutor does not send continuation.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + EndConversationExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = {"kind": "EndConversation"} + executor = EndConversationExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + # Should NOT send any message + mock_context.send_message.assert_not_called() + + async def test_condition_group_evaluator_first_match(self, mock_context, mock_shared_state): + """Test ConditionGroupEvaluatorExecutor returns first match.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + ConditionGroupEvaluatorExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.x", 10) + + action_def = {"kind": "ConditionGroup"} + conditions = [ + {"condition": "=Local.x > 20"}, + {"condition": "=Local.x > 5"}, + {"condition": "=Local.x > 0"}, + ] + executor = ConditionGroupEvaluatorExecutor(action_def, conditions=conditions) + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, ConditionResult) + assert msg.matched is True + assert msg.branch_index == 1 # Second condition (x > 5) is first match + + async def test_condition_group_evaluator_no_match(self, mock_context, mock_shared_state): + """Test 
ConditionGroupEvaluatorExecutor with no matches.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + ConditionGroupEvaluatorExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.x", 0) + + action_def = {"kind": "ConditionGroup"} + conditions = [ + {"condition": "=Local.x > 10"}, + {"condition": "=Local.x > 5"}, + ] + executor = ConditionGroupEvaluatorExecutor(action_def, conditions=conditions) + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, ConditionResult) + assert msg.matched is False + assert msg.branch_index == -1 + + async def test_condition_group_evaluator_boolean_true_condition(self, mock_context, mock_shared_state): + """Test ConditionGroupEvaluatorExecutor with boolean True condition.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + ConditionGroupEvaluatorExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = {"kind": "ConditionGroup"} + conditions = [ + {"condition": False}, # Should skip + {"condition": True}, # Should match + ] + executor = ConditionGroupEvaluatorExecutor(action_def, conditions=conditions) + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, ConditionResult) + assert msg.matched is True + assert msg.branch_index == 1 + + async def test_if_condition_evaluator_true(self, mock_context, mock_shared_state): + """Test IfConditionEvaluatorExecutor with true condition.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + IfConditionEvaluatorExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.flag", True) + + action_def = {"kind": "If"} + executor = IfConditionEvaluatorExecutor(action_def, condition_expr="=Local.flag") + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, ConditionResult) + assert msg.matched is True + assert msg.branch_index == 0 # Then branch + + async def test_if_condition_evaluator_false(self, mock_context, mock_shared_state): + """Test IfConditionEvaluatorExecutor with false condition.""" + from agent_framework_declarative._workflows._executors_control_flow import ( + IfConditionEvaluatorExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.flag", False) + + action_def = {"kind": "If"} + executor = IfConditionEvaluatorExecutor(action_def, condition_expr="=Local.flag") + + await executor.handle_action(ActionTrigger(), mock_context) + + msg = mock_context.send_message.call_args[0][0] + assert isinstance(msg, ConditionResult) + assert msg.matched is False + assert msg.branch_index == -1 # Else branch + + +# --------------------------------------------------------------------------- +# Declarative Action Executor Base Tests +# --------------------------------------------------------------------------- + + +class TestDeclarativeActionExecutorBase: + """Tests for DeclarativeActionExecutor base class.""" + + async def test_ensure_state_initialized_with_dict_input(self, mock_context, mock_shared_state): + """Test _ensure_state_initialized with dict input.""" + from agent_framework_declarative._workflows._executors_basic import ( + 
SetValueExecutor, + ) + + action_def = {"kind": "SetValue", "path": "Local.x", "value": 1} + executor = SetValueExecutor(action_def) + + # Trigger with dict - should initialize state with it + await executor.handle_action({"custom": "input"}, mock_context) + + # State should have been initialized with the dict + state = DeclarativeWorkflowState(mock_shared_state) + inputs = await state.get("Workflow.Inputs") + assert inputs == {"custom": "input"} + + async def test_ensure_state_initialized_with_string_input(self, mock_context, mock_shared_state): + """Test _ensure_state_initialized with string input.""" + from agent_framework_declarative._workflows._executors_basic import ( + SetValueExecutor, + ) + + action_def = {"kind": "SetValue", "path": "Local.x", "value": 1} + executor = SetValueExecutor(action_def) + + # Trigger with string - should wrap in {"input": ...} + await executor.handle_action("string trigger", mock_context) + + state = DeclarativeWorkflowState(mock_shared_state) + inputs = await state.get("Workflow.Inputs") + assert inputs == {"input": "string trigger"} + + async def test_ensure_state_initialized_with_custom_object(self, mock_context, mock_shared_state): + """Test _ensure_state_initialized with custom object converts to string.""" + from agent_framework_declarative._workflows._executors_basic import ( + SetValueExecutor, + ) + + class CustomObj: + def __str__(self): + return "custom string" + + action_def = {"kind": "SetValue", "path": "Local.x", "value": 1} + executor = SetValueExecutor(action_def) + + await executor.handle_action(CustomObj(), mock_context) + + state = DeclarativeWorkflowState(mock_shared_state) + inputs = await state.get("Workflow.Inputs") + assert inputs == {"input": "custom string"} + + async def test_executor_display_name_property(self, mock_context, mock_shared_state): + """Test executor display_name property.""" + from agent_framework_declarative._workflows._executors_basic import ( + SetValueExecutor, + ) + + action_def = { + "kind": "SetValue", + "displayName": "My Custom Action", + "path": "Local.x", + "value": 1, + } + executor = SetValueExecutor(action_def) + + assert executor.display_name == "My Custom Action" + + async def test_executor_action_def_property(self, mock_context, mock_shared_state): + """Test executor action_def property.""" + from agent_framework_declarative._workflows._executors_basic import ( + SetValueExecutor, + ) + + action_def = {"kind": "SetValue", "path": "Local.x", "value": 1} + executor = SetValueExecutor(action_def) + + assert executor.action_def == action_def + + +# --------------------------------------------------------------------------- +# Human Input Executors Tests - Covering _executors_external_input.py gaps +# --------------------------------------------------------------------------- + + +class TestHumanInputExecutorsCoverage: + """Tests for human input executors covering uncovered code paths.""" + + async def test_wait_for_input_executor_with_prompt(self, mock_context, mock_shared_state): + """Test WaitForInputExecutor with prompt.""" + from agent_framework_declarative._workflows._executors_external_input import ( + ExternalInputRequest, + WaitForInputExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "WaitForInput", + "prompt": "Please enter your name:", + "property": "Local.userName", + "timeout": 30, + } + executor = WaitForInputExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + # Should 
yield prompt first, then call request_info + assert mock_context.yield_output.call_count == 1 + assert mock_context.yield_output.call_args_list[0][0][0] == "Please enter your name:" + # request_info call for ExternalInputRequest + mock_context.request_info.assert_called_once() + request = mock_context.request_info.call_args[0][0] + assert isinstance(request, ExternalInputRequest) + assert request.request_type == "user_input" + + async def test_wait_for_input_executor_no_prompt(self, mock_context, mock_shared_state): + """Test WaitForInputExecutor without prompt.""" + from agent_framework_declarative._workflows._executors_external_input import ( + ExternalInputRequest, + WaitForInputExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "WaitForInput", + "property": "Local.input", + } + executor = WaitForInputExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + # Should not yield output (no prompt), just call request_info + assert mock_context.yield_output.call_count == 0 + mock_context.request_info.assert_called_once() + request = mock_context.request_info.call_args[0][0] + assert isinstance(request, ExternalInputRequest) + assert request.request_type == "user_input" + + async def test_request_external_input_executor(self, mock_context, mock_shared_state): + """Test RequestExternalInputExecutor.""" + from agent_framework_declarative._workflows._executors_external_input import ( + ExternalInputRequest, + RequestExternalInputExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "RequestExternalInput", + "requestType": "approval", + "message": "Please approve this request", + "property": "Local.approvalResult", + "timeout": 3600, + "requiredFields": ["approver", "notes"], + "metadata": {"priority": "high"}, + } + executor = RequestExternalInputExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + request = mock_context.request_info.call_args[0][0] + assert isinstance(request, ExternalInputRequest) + assert request.request_type == "approval" + assert request.message == "Please approve this request" + assert request.metadata["priority"] == "high" + assert request.metadata["required_fields"] == ["approver", "notes"] + assert request.metadata["timeout_seconds"] == 3600 + + async def test_question_executor_with_choices(self, mock_context, mock_shared_state): + """Test QuestionExecutor with choices as dicts and strings.""" + from agent_framework_declarative._workflows._executors_external_input import ( + ExternalInputRequest, + QuestionExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "Question", + "question": "Select an option:", + "property": "Local.selection", + "choices": [ + {"value": "a", "label": "Option A"}, + {"value": "b"}, # No label, should use value + "c", # String choice + ], + "allowFreeText": False, + } + executor = QuestionExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + request = mock_context.request_info.call_args[0][0] + assert isinstance(request, ExternalInputRequest) + assert request.request_type == "question" + choices = request.metadata["choices"] + assert len(choices) == 3 + assert choices[0] == {"value": "a", "label": "Option A"} + assert choices[1] == 
{"value": "b", "label": "b"} + assert choices[2] == {"value": "c", "label": "c"} + assert request.metadata["allow_free_text"] is False + + +# --------------------------------------------------------------------------- +# Additional Agent Executor Tests - External Loop Coverage +# --------------------------------------------------------------------------- + + +class TestAgentExternalLoopCoverage: + """Tests for agent executor external loop handling.""" + + async def test_agent_executor_with_external_loop(self, mock_context, mock_shared_state): + """Test agent executor with external loop that triggers.""" + from unittest.mock import patch + + from agent_framework_declarative._workflows._executors_agents import ( + AgentExternalInputRequest, + InvokeAzureAgentExecutor, + ) + + mock_agent = MagicMock() + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.input", "User query") + await state.set("Local.needsMore", True) # Loop condition will be true + + action_def = { + "kind": "InvokeAzureAgent", + "agent": "TestAgent", + "input": { + "externalLoop": {"when": "=Local.needsMore"}, + }, + } + executor = InvokeAzureAgentExecutor(action_def, agents={"TestAgent": mock_agent}) + + # Mock the internal method to avoid storing ChatMessage objects in state + # (PowerFx cannot serialize ChatMessage) + with patch.object( + executor, + "_invoke_agent_and_store_results", + new=AsyncMock(return_value=("Need more info", [], [])), + ): + await executor.handle_action(ActionTrigger(), mock_context) + + # Should request external input via request_info + mock_context.request_info.assert_called_once() + request = mock_context.request_info.call_args[0][0] + assert isinstance(request, AgentExternalInputRequest) + assert request.agent_name == "TestAgent" + + async def test_agent_executor_agent_error_handling(self, mock_context, mock_shared_state): + """Test agent executor raises AgentInvocationError on failure.""" + from agent_framework_declarative._workflows._executors_agents import ( + AgentInvocationError, + InvokeAzureAgentExecutor, + ) + + mock_agent = MagicMock() + mock_agent.run = AsyncMock(side_effect=RuntimeError("Agent failed")) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.input", "Query") + + action_def = { + "kind": "InvokeAzureAgent", + "agent": "TestAgent", + "resultProperty": "Local.result", + } + executor = InvokeAzureAgentExecutor(action_def, agents={"TestAgent": mock_agent}) + + with pytest.raises(AgentInvocationError) as exc_info: + await executor.handle_action(ActionTrigger(), mock_context) + + assert "TestAgent" in str(exc_info.value) + assert "Agent failed" in str(exc_info.value) + + # Should still store error in state before raising + error = await state.get("Agent.error") + assert "Agent failed" in error + result = await state.get("Local.result") + assert result == {"error": "Agent failed"} + + async def test_agent_executor_string_result(self, mock_context, mock_shared_state): + """Test agent executor with agent that returns string directly.""" + from agent_framework_declarative._workflows._executors_agents import ( + InvokeAzureAgentExecutor, + ) + + mock_agent = MagicMock() + mock_agent.run = AsyncMock(return_value="Direct string response") + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.input", "Query") + + action_def = { + "kind": "InvokeAzureAgent", + "agent": "TestAgent", + "resultProperty": "Local.result", + 
"output": {"autoSend": True}, + } + executor = InvokeAzureAgentExecutor(action_def, agents={"TestAgent": mock_agent}) + + await executor.handle_action(ActionTrigger(), mock_context) + + # Should auto-send output + mock_context.yield_output.assert_called_with("Direct string response") + result = await state.get("Local.result") + assert result == "Direct string response" + + async def test_invoke_tool_with_error(self, mock_context, mock_shared_state): + """Test InvokeToolExecutor handles tool errors.""" + from agent_framework_declarative._workflows._executors_agents import ( + TOOL_REGISTRY_KEY, + InvokeToolExecutor, + ) + + def failing_tool(**kwargs): + raise ValueError("Tool error") + + mock_shared_state._data[TOOL_REGISTRY_KEY] = {"bad_tool": failing_tool} + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "InvokeTool", + "tool": "bad_tool", + "resultProperty": "Local.result", + } + executor = InvokeToolExecutor(action_def) + + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.result") + assert result == {"error": "Tool error"} + + +# --------------------------------------------------------------------------- +# PowerFx Functions Coverage +# --------------------------------------------------------------------------- + + +class TestPowerFxFunctionsCoverage: + """Tests for PowerFx function evaluation coverage.""" + + async def test_eval_lower_upper_functions(self, mock_shared_state): + """Test Lower and Upper functions.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.text", "Hello World") + + result = await state.eval("=Lower(Local.text)") + assert result == "hello world" + + result = await state.eval("=Upper(Local.text)") + assert result == "HELLO WORLD" + + async def test_eval_if_function(self, mock_shared_state): + """Test If function.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.flag", True) + + result = await state.eval('=If(Local.flag, "yes", "no")') + assert result == "yes" + + await state.set("Local.flag", False) + result = await state.eval('=If(Local.flag, "yes", "no")') + assert result == "no" + + async def test_eval_not_function(self, mock_shared_state): + """Test Not function.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.flag", True) + + result = await state.eval("=Not(Local.flag)") + assert result is False + + async def test_eval_and_or_functions(self, mock_shared_state): + """Test And and Or functions.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.a", True) + await state.set("Local.b", False) + + result = await state.eval("=And(Local.a, Local.b)") + assert result is False + + result = await state.eval("=Or(Local.a, Local.b)") + assert result is True + + +# --------------------------------------------------------------------------- +# Builder control flow tests - Covering Goto/Break/Continue creation +# --------------------------------------------------------------------------- + + +class TestBuilderControlFlowCreation: + """Tests for Goto, Break, Continue executor creation in builder.""" + + def test_create_goto_reference(self): + """Test creating a goto reference executor.""" + from agent_framework import WorkflowBuilder + + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + 
+
+        # Create builder with a minimal YAML definition
+        yaml_def = {"name": "test_workflow", "actions": []}
+        graph_builder = DeclarativeWorkflowBuilder(yaml_def)
+        wb = WorkflowBuilder()
+
+        action_def = {
+            "kind": "GotoAction",
+            "target": "some_target_action",
+            "id": "goto_test",
+        }
+
+        executor = graph_builder._create_goto_reference(action_def, wb, None)
+
+        assert executor is not None
+        assert executor.id == "goto_test"
+        # Verify pending goto was recorded
+        assert len(graph_builder._pending_gotos) == 1
+        assert graph_builder._pending_gotos[0][1] == "some_target_action"
+
+    def test_create_goto_reference_auto_id(self):
+        """Test creating a goto with an auto-generated ID."""
+        from agent_framework import WorkflowBuilder
+
+        from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder
+
+        yaml_def = {"name": "test_workflow", "actions": []}
+        graph_builder = DeclarativeWorkflowBuilder(yaml_def)
+        wb = WorkflowBuilder()
+
+        action_def = {
+            "kind": "GotoAction",
+            "target": "target_action",
+        }
+
+        executor = graph_builder._create_goto_reference(action_def, wb, None)
+
+        assert executor is not None
+        assert "goto_target_action" in executor.id
+
+    def test_create_goto_reference_no_target(self):
+        """Test that creating a goto with no target returns None."""
+        from agent_framework import WorkflowBuilder
+
+        from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder
+
+        yaml_def = {"name": "test_workflow", "actions": []}
+        graph_builder = DeclarativeWorkflowBuilder(yaml_def)
+        wb = WorkflowBuilder()
+
+        action_def = {
+            "kind": "GotoAction",
+            # No target specified
+        }
+
+        executor = graph_builder._create_goto_reference(action_def, wb, None)
+        assert executor is None
+
+    def test_goto_invalid_target_raises_error(self):
+        """Test that a goto to a non-existent target raises ValueError."""
+        from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder
+
+        yaml_def = {
+            "name": "test_workflow",
+            "actions": [
+                {"kind": "SendActivity", "id": "action1", "activity": {"text": "Hello"}},
+                {"kind": "GotoAction", "target": "non_existent_action"},
+            ],
+        }
+        builder = DeclarativeWorkflowBuilder(yaml_def)
+
+        with pytest.raises(ValueError) as exc_info:
+            builder.build()
+
+        assert "non_existent_action" in str(exc_info.value)
+        assert "not found" in str(exc_info.value)
+
+    def test_create_break_executor(self):
+        """Test creating a break executor within a loop context."""
+        from agent_framework import WorkflowBuilder
+
+        from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder
+        from agent_framework_declarative._workflows._executors_control_flow import ForeachNextExecutor
+
+        yaml_def = {"name": "test_workflow", "actions": []}
+        graph_builder = DeclarativeWorkflowBuilder(yaml_def)
+        wb = WorkflowBuilder()
+
+        # Create a mock loop_next executor
+        loop_next = ForeachNextExecutor(
+            {"kind": "Foreach", "itemsProperty": "items"},
+            init_executor_id="foreach_init",
+            id="foreach_next",
+        )
+        wb._add_executor(loop_next)
+
+        parent_context = {"loop_next_executor": loop_next}
+
+        action_def = {
+            "kind": "BreakLoop",
+            "id": "break_test",
+        }
+
+        executor = graph_builder._create_break_executor(action_def, wb, parent_context)
+
+        assert executor is not None
+        assert executor.id == "break_test"
+
+    def test_create_break_executor_no_loop_context(self):
+        """Test that creating a break executor without a loop context raises ValueError."""
+        from agent_framework import WorkflowBuilder
+ + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + yaml_def = {"name": "test_workflow", "actions": []} + graph_builder = DeclarativeWorkflowBuilder(yaml_def) + wb = WorkflowBuilder() + + action_def = { + "kind": "BreakLoop", + } + + # No parent_context should raise ValueError + with pytest.raises(ValueError) as exc_info: + graph_builder._create_break_executor(action_def, wb, None) + assert "BreakLoop action can only be used inside a Foreach loop" in str(exc_info.value) + + # Empty context should also raise ValueError + with pytest.raises(ValueError) as exc_info: + graph_builder._create_break_executor(action_def, wb, {}) + assert "BreakLoop action can only be used inside a Foreach loop" in str(exc_info.value) + + def test_create_continue_executor(self): + """Test creating a continue executor within a loop context.""" + from agent_framework import WorkflowBuilder + + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + from agent_framework_declarative._workflows._executors_control_flow import ForeachNextExecutor + + yaml_def = {"name": "test_workflow", "actions": []} + graph_builder = DeclarativeWorkflowBuilder(yaml_def) + wb = WorkflowBuilder() + + # Create a mock loop_next executor + loop_next = ForeachNextExecutor( + {"kind": "Foreach", "itemsProperty": "items"}, + init_executor_id="foreach_init", + id="foreach_next", + ) + wb._add_executor(loop_next) + + parent_context = {"loop_next_executor": loop_next} + + action_def = { + "kind": "ContinueLoop", + "id": "continue_test", + } + + executor = graph_builder._create_continue_executor(action_def, wb, parent_context) + + assert executor is not None + assert executor.id == "continue_test" + + def test_create_continue_executor_no_loop_context(self): + """Test creating a continue executor without loop context raises ValueError.""" + from agent_framework import WorkflowBuilder + + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + yaml_def = {"name": "test_workflow", "actions": []} + graph_builder = DeclarativeWorkflowBuilder(yaml_def) + wb = WorkflowBuilder() + + action_def = { + "kind": "ContinueLoop", + } + + # No parent_context should raise ValueError + with pytest.raises(ValueError) as exc_info: + graph_builder._create_continue_executor(action_def, wb, None) + assert "ContinueLoop action can only be used inside a Foreach loop" in str(exc_info.value) + + +class TestBuilderEdgeWiring: + """Tests for builder edge wiring methods.""" + + def test_wire_to_target_with_if_structure(self): + """Test wiring to an If structure routes to evaluator.""" + from agent_framework import WorkflowBuilder + + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + from agent_framework_declarative._workflows._executors_basic import SendActivityExecutor + + yaml_def = {"name": "test_workflow", "actions": []} + graph_builder = DeclarativeWorkflowBuilder(yaml_def) + wb = WorkflowBuilder() + + # Create a mock source executor + source = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "test"}}, id="source") + wb._add_executor(source) + + # Create a mock If structure with evaluator + class MockIfStructure: + _is_if_structure = True + + def __init__(self): + self.evaluator = SendActivityExecutor( + {"kind": "SendActivity", "activity": {"text": "evaluator"}}, id="evaluator" + ) + + target = MockIfStructure() + wb._add_executor(target.evaluator) + + # Wire 
should add edge to evaluator + graph_builder._wire_to_target(wb, source, target) + + # Verify edge was added (would need to inspect workflow internals) + # For now, just verify no exception was raised + + def test_wire_to_target_normal_executor(self): + """Test wiring to a normal executor adds direct edge.""" + from agent_framework import WorkflowBuilder + + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + from agent_framework_declarative._workflows._executors_basic import SendActivityExecutor + + yaml_def = {"name": "test_workflow", "actions": []} + graph_builder = DeclarativeWorkflowBuilder(yaml_def) + wb = WorkflowBuilder() + + source = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "source"}}, id="source") + target = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "target"}}, id="target") + + wb._add_executor(source) + wb._add_executor(target) + + graph_builder._wire_to_target(wb, source, target) + # Verify edge creation (no exception = success) + + def test_collect_all_exits_for_nested_structure(self): + """Test collecting all exits from nested structures.""" + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + from agent_framework_declarative._workflows._executors_basic import SendActivityExecutor + + yaml_def = {"name": "test_workflow", "actions": []} + graph_builder = DeclarativeWorkflowBuilder(yaml_def) + + # Create mock nested structure + exit1 = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "exit1"}}, id="exit1") + exit2 = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "exit2"}}, id="exit2") + + class InnerStructure: + def __init__(self): + self.branch_exits = [exit1, exit2] + + class OuterStructure: + def __init__(self): + self.branch_exits = [InnerStructure()] + + outer = OuterStructure() + exits = graph_builder._collect_all_exits(outer) + + assert len(exits) == 2 + assert exit1 in exits + assert exit2 in exits + + def test_collect_all_exits_for_simple_executor(self): + """Test collecting exits from a simple executor.""" + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + from agent_framework_declarative._workflows._executors_basic import SendActivityExecutor + + yaml_def = {"name": "test_workflow", "actions": []} + graph_builder = DeclarativeWorkflowBuilder(yaml_def) + + executor = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "test"}}, id="test") + + exits = graph_builder._collect_all_exits(executor) + + assert len(exits) == 1 + assert executor in exits + + def test_get_branch_exit_with_chain(self): + """Test getting branch exit from a chain of executors.""" + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + from agent_framework_declarative._workflows._executors_basic import SendActivityExecutor + + yaml_def = {"name": "test_workflow", "actions": []} + graph_builder = DeclarativeWorkflowBuilder(yaml_def) + + exec1 = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "1"}}, id="e1") + exec2 = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "2"}}, id="e2") + exec3 = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "3"}}, id="e3") + + # Simulate a chain by dynamically setting attribute + exec1._chain_executors = [exec1, exec2, exec3] # type: ignore[attr-defined] + + exit_exec = graph_builder._get_branch_exit(exec1) + + assert 
exit_exec == exec3 + + def test_get_branch_exit_none(self): + """Test getting branch exit from None.""" + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + yaml_def = {"name": "test_workflow", "actions": []} + graph_builder = DeclarativeWorkflowBuilder(yaml_def) + + exit_exec = graph_builder._get_branch_exit(None) + assert exit_exec is None + + +# --------------------------------------------------------------------------- +# Agent executor external loop response handler tests +# --------------------------------------------------------------------------- + + +class TestAgentExecutorExternalLoop: + """Tests for InvokeAzureAgentExecutor external loop response handling.""" + + async def test_handle_external_input_response_no_state(self, mock_context, mock_shared_state): + """Test handling external input response when loop state not found.""" + from agent_framework_declarative._workflows._executors_agents import ( + AgentExternalInputRequest, + AgentExternalInputResponse, + InvokeAzureAgentExecutor, + ) + + executor = InvokeAzureAgentExecutor({"kind": "InvokeAzureAgent", "agent": "TestAgent"}) + + # No external loop state in shared_state + original_request = AgentExternalInputRequest( + request_id="req-1", + agent_name="TestAgent", + agent_response="Hello", + iteration=1, + ) + response = AgentExternalInputResponse(user_input="hi there") + + await executor.handle_external_input_response(original_request, response, mock_context) + + # Should send ActionComplete due to missing state + mock_context.send_message.assert_called() + call_args = mock_context.send_message.call_args[0][0] + from agent_framework_declarative._workflows import ActionComplete + + assert isinstance(call_args, ActionComplete) + + async def test_handle_external_input_response_agent_not_found(self, mock_context, mock_shared_state): + """Test handling external input raises error when agent not found during resumption.""" + from agent_framework_declarative._workflows._executors_agents import ( + EXTERNAL_LOOP_STATE_KEY, + AgentExternalInputRequest, + AgentExternalInputResponse, + AgentInvocationError, + ExternalLoopState, + InvokeAzureAgentExecutor, + ) + + # Set up loop state with always true condition (literal) + loop_state = ExternalLoopState( + agent_name="NonExistentAgent", + iteration=1, + external_loop_when="true", # Literal true + messages_var=None, + response_obj_var=None, + result_property=None, + auto_send=True, + messages_path="Conversation.messages", + ) + mock_shared_state._data[EXTERNAL_LOOP_STATE_KEY] = loop_state + + # Initialize declarative state with simple value + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + executor = InvokeAzureAgentExecutor({"kind": "InvokeAzureAgent", "agent": "NonExistentAgent"}) + + original_request = AgentExternalInputRequest( + request_id="req-1", + agent_name="NonExistentAgent", + agent_response="Hello", + iteration=1, + ) + response = AgentExternalInputResponse(user_input="continue") + + with pytest.raises(AgentInvocationError) as exc_info: + await executor.handle_external_input_response(original_request, response, mock_context) + + assert "NonExistentAgent" in str(exc_info.value) + assert "not found during loop resumption" in str(exc_info.value) + + +class TestBuilderValidation: + """Tests for builder validation features (P1 fixes).""" + + def test_duplicate_explicit_action_id_raises_error(self): + """Test that duplicate explicit action IDs are detected.""" + from 
agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + yaml_def = { + "name": "test_workflow", + "actions": [ + {"id": "my_action", "kind": "SendActivity", "activity": {"text": "First"}}, + {"id": "my_action", "kind": "SendActivity", "activity": {"text": "Second"}}, + ], + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + with pytest.raises(ValueError) as exc_info: + builder.build() + + assert "Duplicate action ID 'my_action'" in str(exc_info.value) + + def test_duplicate_id_in_nested_actions(self): + """Test duplicate ID detection in nested If/Switch branches.""" + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + yaml_def = { + "name": "test_workflow", + "actions": [ + { + "kind": "If", + "condition": "=true", + "then": [{"id": "shared_id", "kind": "SendActivity", "activity": {"text": "Then"}}], + "else": [{"id": "shared_id", "kind": "SendActivity", "activity": {"text": "Else"}}], + } + ], + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + with pytest.raises(ValueError) as exc_info: + builder.build() + + assert "Duplicate action ID 'shared_id'" in str(exc_info.value) + + def test_missing_required_field_sendactivity(self): + """Test that missing required fields are detected.""" + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + yaml_def = { + "name": "test_workflow", + "actions": [{"kind": "SendActivity"}], # Missing 'activity' field + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + with pytest.raises(ValueError) as exc_info: + builder.build() + + assert "SendActivity" in str(exc_info.value) + assert "missing required field" in str(exc_info.value) + assert "activity" in str(exc_info.value) + + def test_missing_required_field_setvalue(self): + """Test SetValue without path raises error.""" + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + yaml_def = { + "name": "test_workflow", + "actions": [{"kind": "SetValue", "value": "test"}], # Missing 'path' field + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + with pytest.raises(ValueError) as exc_info: + builder.build() + + assert "SetValue" in str(exc_info.value) + assert "path" in str(exc_info.value) + + def test_setvalue_accepts_alternate_variable_field(self): + """Test SetValue accepts 'variable' as alternate to 'path'.""" + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + yaml_def = { + "name": "test_workflow", + "actions": [{"kind": "SetValue", "variable": {"path": "Local.x"}, "value": "test"}], + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + # Should not raise - 'variable' is accepted as alternate + workflow = builder.build() + assert workflow is not None + + def test_missing_required_field_foreach(self): + """Test Foreach without items raises error.""" + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + yaml_def = { + "name": "test_workflow", + "actions": [{"kind": "Foreach", "actions": [{"kind": "SendActivity", "activity": {"text": "Hi"}}]}], + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + with pytest.raises(ValueError) as exc_info: + builder.build() + + assert "Foreach" in str(exc_info.value) + assert "items" in str(exc_info.value) + + def test_self_referencing_goto_raises_error(self): + """Test that a goto referencing itself is detected.""" + from 
agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + yaml_def = { + "name": "test_workflow", + "actions": [{"id": "loop", "kind": "Goto", "target": "loop"}], + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + with pytest.raises(ValueError) as exc_info: + builder.build() + + assert "loop" in str(exc_info.value) + assert "self-referencing" in str(exc_info.value) + + def test_validation_can_be_disabled(self): + """Test that validation can be disabled for early schema/duplicate checks. + + Note: Even with validation disabled, the underlying WorkflowBuilder may + still catch duplicates during graph construction. This flag disables + our upfront validation pass but not runtime checks. + """ + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + # Test with missing required field - validation disabled should skip our check + yaml_def = { + "name": "test_workflow", + "actions": [{"kind": "SendActivity"}], # Missing 'activity' - normally caught by validation + } + + # With validation disabled, our upfront check is skipped + builder = DeclarativeWorkflowBuilder(yaml_def, validate=False) + # The workflow may still fail for other reasons, but our validation pass is skipped + # In this case, it should succeed because SendActivityExecutor handles missing fields gracefully + workflow = builder.build() + assert workflow is not None + + def test_validation_in_switch_branches(self): + """Test validation catches issues in Switch branches.""" + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + yaml_def = { + "name": "test_workflow", + "actions": [ + { + "kind": "Switch", + "value": "=Local.choice", + "cases": [ + { + "match": "a", + "actions": [{"id": "dup", "kind": "SendActivity", "activity": {"text": "A"}}], + }, + { + "match": "b", + "actions": [{"id": "dup", "kind": "SendActivity", "activity": {"text": "B"}}], + }, + ], + } + ], + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + with pytest.raises(ValueError) as exc_info: + builder.build() + + assert "Duplicate action ID 'dup'" in str(exc_info.value) + + def test_validation_in_foreach_body(self): + """Test validation catches issues in Foreach body.""" + from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder + + yaml_def = { + "name": "test_workflow", + "actions": [ + { + "kind": "Foreach", + "items": "=Local.items", + "actions": [{"kind": "SendActivity"}], # Missing 'activity' + } + ], + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + with pytest.raises(ValueError) as exc_info: + builder.build() + + assert "SendActivity" in str(exc_info.value) + assert "activity" in str(exc_info.value) + + +class TestExpressionEdgeCases: + """Tests for expression evaluation edge cases.""" + + async def test_division_with_valid_values(self, mock_shared_state): + """Test normal division works correctly.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.x", 10) + await state.set("Local.y", 4) + + result = await state.eval("=Local.x / Local.y") + assert result == 2.5 + + async def test_multiplication_normal(self, mock_shared_state): + """Test normal multiplication.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.x", 6) + await state.set("Local.y", 7) + + result = await state.eval("=Local.x * Local.y") + assert result == 42 + + +class 
TestLongMessageTextHandling: + """Tests for handling long MessageText results that exceed PowerFx limits.""" + + async def test_short_message_text_embedded_inline(self, mock_shared_state): + """Test that short MessageText results are embedded inline.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Store a short message + short_text = "Hello world" + await state.set("Local.Messages", [{"text": short_text, "contents": [{"type": "text", "text": short_text}]}]) + + # Evaluate a formula with MessageText - should embed inline + result = await state.eval("=Upper(MessageText(Local.Messages))") + assert result == "HELLO WORLD" + + # No temp variable should be created for short strings + temp_var = await state.get("Local._TempMessageText0") + assert temp_var is None + + async def test_long_message_text_stored_in_temp_variable(self, mock_shared_state): + """Test that long MessageText results are stored in temp variables.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Create a message longer than 500 characters + long_text = "A" * 600 # 600 characters exceeds the 500 char threshold + await state.set("Local.Messages", [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}]) + + # Evaluate a formula with MessageText + result = await state.eval("=Upper(MessageText(Local.Messages))") + assert result == "A" * 600 # Upper on 'A' is still 'A' + + # A temp variable should have been created + temp_var = await state.get("Local._TempMessageText0") + assert temp_var == long_text + + async def test_find_with_long_message_text(self, mock_shared_state): + """Test Find function works with long MessageText stored in temp variable.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Create a long message with a keyword to find + long_text = "X" * 550 + "CONGRATULATIONS" + "Y" * 50 + await state.set("Local.Messages", [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}]) + + # Test the pattern used in student_teacher workflow + result = await state.eval('=!IsBlank(Find("CONGRATULATIONS", Upper(MessageText(Local.Messages))))') + assert result is True + + async def test_find_without_keyword_in_long_text(self, mock_shared_state): + """Test Find returns blank when keyword not found in long text.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Long text without the keyword + long_text = "X" * 600 + await state.set("Local.Messages", [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}]) + + result = await state.eval('=!IsBlank(Find("CONGRATULATIONS", Upper(MessageText(Local.Messages))))') + assert result is False diff --git a/python/packages/declarative/tests/test_graph_executors.py b/python/packages/declarative/tests/test_graph_executors.py new file mode 100644 index 0000000000..e03895b4ac --- /dev/null +++ b/python/packages/declarative/tests/test_graph_executors.py @@ -0,0 +1,1313 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Tests for the graph-based declarative workflow executors.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agent_framework_declarative._workflows import ( + ALL_ACTION_EXECUTORS, + DECLARATIVE_STATE_KEY, + ActionComplete, + ActionTrigger, + DeclarativeWorkflowBuilder, + DeclarativeWorkflowState, + ForeachInitExecutor, + LoopIterationResult, + SendActivityExecutor, + SetValueExecutor, +) + + +class TestDeclarativeWorkflowState: + """Tests for DeclarativeWorkflowState.""" + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state with async get/set methods.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + + return shared_state + + @pytest.mark.asyncio + async def test_initialize_state(self, mock_shared_state): + """Test initializing the workflow state.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize({"query": "test"}) + + # Verify state was set + mock_shared_state.set.assert_called_once() + call_args = mock_shared_state.set.call_args + assert call_args[0][0] == DECLARATIVE_STATE_KEY + state_data = call_args[0][1] + assert state_data["Inputs"] == {"query": "test"} + assert state_data["Outputs"] == {} + assert state_data["Local"] == {} + + @pytest.mark.asyncio + async def test_get_and_set_values(self, mock_shared_state): + """Test getting and setting values.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Set a turn value + await state.set("Local.counter", 5) + + # Get the value + result = await state.get("Local.counter") + assert result == 5 + + @pytest.mark.asyncio + async def test_get_inputs(self, mock_shared_state): + """Test getting workflow inputs.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize({"name": "Alice", "age": 30}) + + # Get via path + name = await state.get("Workflow.Inputs.name") + assert name == "Alice" + + # Get all inputs + inputs = await state.get("Workflow.Inputs") + assert inputs == {"name": "Alice", "age": 30} + + @pytest.mark.asyncio + async def test_append_value(self, mock_shared_state): + """Test appending values to a list.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Append to non-existent list creates it + await state.append("Local.items", "first") + result = await state.get("Local.items") + assert result == ["first"] + + # Append to existing list + await state.append("Local.items", "second") + result = await state.get("Local.items") + assert result == ["first", "second"] + + @pytest.mark.asyncio + async def test_eval_expression(self, mock_shared_state): + """Test evaluating expressions.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Non-expression returns as-is + result = await state.eval("plain text") + assert result == "plain text" + + # Boolean literals + result = await state.eval("=true") + assert result is True + + result = await state.eval("=false") + assert result is False + + # String literals + result = await state.eval('="hello"') + assert result == "hello" + + # Numeric literals + result = await state.eval("=42") + assert result == 42 + + +class TestDeclarativeActionExecutor: + """Tests for 
DeclarativeActionExecutor subclasses."""
+
+    @pytest.fixture
+    def mock_context(self, mock_shared_state):
+        """Create a mock workflow context."""
+        ctx = MagicMock()
+        ctx.shared_state = mock_shared_state
+        ctx.send_message = AsyncMock()
+        ctx.yield_output = AsyncMock()
+        return ctx
+
+    @pytest.fixture
+    def mock_shared_state(self):
+        """Create a mock shared state."""
+        shared_state = MagicMock()
+        shared_state._data = {}
+
+        async def mock_get(key):
+            if key not in shared_state._data:
+                raise KeyError(key)
+            return shared_state._data[key]
+
+        async def mock_set(key, value):
+            shared_state._data[key] = value
+
+        shared_state.get = AsyncMock(side_effect=mock_get)
+        shared_state.set = AsyncMock(side_effect=mock_set)
+
+        return shared_state
+
+    @pytest.mark.asyncio
+    async def test_set_value_executor(self, mock_context, mock_shared_state):
+        """Test SetValueExecutor."""
+        # Initialize state
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize()
+
+        action_def = {
+            "kind": "SetValue",
+            "path": "Local.result",
+            "value": "test value",
+        }
+        executor = SetValueExecutor(action_def)
+
+        # Execute
+        await executor.handle_action(ActionTrigger(), mock_context)
+
+        # Verify action complete was sent
+        mock_context.send_message.assert_called_once()
+        message = mock_context.send_message.call_args[0][0]
+        assert isinstance(message, ActionComplete)
+
+    @pytest.mark.asyncio
+    async def test_send_activity_executor(self, mock_context, mock_shared_state):
+        """Test SendActivityExecutor."""
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize()
+
+        action_def = {
+            "kind": "SendActivity",
+            "activity": {"text": "Hello, world!"},
+        }
+        executor = SendActivityExecutor(action_def)
+
+        # Execute
+        await executor.handle_action(ActionTrigger(), mock_context)
+
+        # Verify output was yielded
+        mock_context.yield_output.assert_called_once_with("Hello, world!")
+
+    # Note: ConditionEvaluatorExecutor tests removed - conditions are now evaluated on edges
+
+    @pytest.mark.asyncio
+    async def test_foreach_init_with_items(self, mock_context, mock_shared_state):
+        """Test ForeachInitExecutor with items."""
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize()
+        await state.set("Local.items", ["a", "b", "c"])
+
+        action_def = {
+            "kind": "Foreach",
+            "itemsSource": "=Local.items",
+            "iteratorVariable": "Local.item",
+        }
+        executor = ForeachInitExecutor(action_def)
+
+        # Execute
+        await executor.handle_action(ActionTrigger(), mock_context)
+
+        # Verify result
+        mock_context.send_message.assert_called_once()
+        message = mock_context.send_message.call_args[0][0]
+        assert isinstance(message, LoopIterationResult)
+        assert message.has_next is True
+        assert message.current_index == 0
+        assert message.current_item == "a"
+
+    @pytest.mark.asyncio
+    async def test_foreach_init_empty(self, mock_context, mock_shared_state):
+        """Test ForeachInitExecutor with an empty items list."""
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize()
+
+        # Use a literal empty list - no expression evaluation needed
+        action_def = {
+            "kind": "Foreach",
+            "itemsSource": [],  # Direct empty list, not an expression
+            "iteratorVariable": "Local.item",
+        }
+        executor = ForeachInitExecutor(action_def)
+
+        # Execute
+        await executor.handle_action(ActionTrigger(), mock_context)
+
+        # Verify result
+        mock_context.send_message.assert_called_once()
+        message = mock_context.send_message.call_args[0][0]
+        assert isinstance(message, LoopIterationResult)
+        assert message.has_next is False
+
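+
+# A quick reference for the builder tests below: DeclarativeWorkflowBuilder takes a
+# plain-dict definition and build() returns a workflow, raising ValueError for an
+# invalid definition (e.g. one with no actions). A minimal shape, inferred from the
+# cases in the next class rather than from a published schema, looks like:
+#
+#     yaml_def = {
+#         "name": "example_workflow",
+#         "actions": [
+#             {"kind": "SendActivity", "id": "greet", "activity": {"text": "Hi!"}},
+#         ],
+#     }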
+ +class TestDeclarativeWorkflowBuilder: + """Tests for DeclarativeWorkflowBuilder.""" + + def test_all_action_executors_available(self): + """Test that all expected action types have executors.""" + expected_actions = [ + "SetValue", + "SetVariable", + "SendActivity", + "EmitEvent", + "EndWorkflow", + "InvokeAzureAgent", + "Question", + ] + + for action in expected_actions: + assert action in ALL_ACTION_EXECUTORS, f"Missing executor for {action}" + + def test_build_empty_workflow(self): + """Test building a workflow with no actions raises an error.""" + yaml_def = {"name": "empty_workflow", "actions": []} + builder = DeclarativeWorkflowBuilder(yaml_def) + + with pytest.raises(ValueError, match="Cannot build workflow with no actions"): + builder.build() + + def test_build_simple_workflow(self): + """Test building a workflow with simple sequential actions.""" + yaml_def = { + "name": "simple_workflow", + "actions": [ + {"kind": "SendActivity", "id": "greet", "activity": {"text": "Hello!"}}, + {"kind": "SetValue", "id": "set_count", "path": "Local.count", "value": 1}, + ], + } + builder = DeclarativeWorkflowBuilder(yaml_def) + workflow = builder.build() + + assert workflow is not None + # Verify executors were created + assert "greet" in builder._executors + assert "set_count" in builder._executors + + def test_build_workflow_with_if(self): + """Test building a workflow with If control flow.""" + yaml_def = { + "name": "conditional_workflow", + "actions": [ + { + "kind": "If", + "id": "check_flag", + "condition": "=Local.flag", + "then": [ + {"kind": "SendActivity", "id": "say_yes", "activity": {"text": "Yes!"}}, + ], + "else": [ + {"kind": "SendActivity", "id": "say_no", "activity": {"text": "No!"}}, + ], + }, + ], + } + builder = DeclarativeWorkflowBuilder(yaml_def) + workflow = builder.build() + + assert workflow is not None + # Verify branch executors were created + # Note: No join executors - branches wire directly to successor + assert "say_yes" in builder._executors + assert "say_no" in builder._executors + # Entry node is created when If is first action + assert "_workflow_entry" in builder._executors + + def test_build_workflow_with_foreach(self): + """Test building a workflow with Foreach loop.""" + yaml_def = { + "name": "loop_workflow", + "actions": [ + { + "kind": "Foreach", + "id": "process_items", + "itemsSource": "=Local.items", + "iteratorVariable": "Local.item", + "actions": [ + {"kind": "SendActivity", "id": "show_item", "activity": {"text": "=Local.item"}}, + ], + }, + ], + } + builder = DeclarativeWorkflowBuilder(yaml_def) + workflow = builder.build() + + assert workflow is not None + # Verify loop executors were created + assert "process_items_init" in builder._executors + assert "process_items_next" in builder._executors + assert "process_items_exit" in builder._executors + assert "show_item" in builder._executors + + def test_build_workflow_with_switch(self): + """Test building a workflow with Switch control flow.""" + yaml_def = { + "name": "switch_workflow", + "actions": [ + { + "kind": "Switch", + "id": "check_status", + "conditions": [ + { + "condition": '=Local.status = "active"', + "actions": [ + {"kind": "SendActivity", "id": "say_active", "activity": {"text": "Active"}}, + ], + }, + { + "condition": '=Local.status = "pending"', + "actions": [ + {"kind": "SendActivity", "id": "say_pending", "activity": {"text": "Pending"}}, + ], + }, + ], + "else": [ + {"kind": "SendActivity", "id": "say_unknown", "activity": {"text": "Unknown"}}, + ], + }, + ], + } + builder = 
DeclarativeWorkflowBuilder(yaml_def) + workflow = builder.build() + + assert workflow is not None + # Verify switch executors were created + # Note: No join executors - branches wire directly to successor + assert "say_active" in builder._executors + assert "say_pending" in builder._executors + assert "say_unknown" in builder._executors + # Entry node is created when Switch is first action + assert "_workflow_entry" in builder._executors + + +class TestAgentExecutors: + """Tests for agent-related executors.""" + + @pytest.fixture + def mock_context(self, mock_shared_state): + """Create a mock workflow context.""" + ctx = MagicMock() + ctx.shared_state = mock_shared_state + ctx.send_message = AsyncMock() + ctx.yield_output = AsyncMock() + return ctx + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + + return shared_state + + @pytest.mark.asyncio + async def test_invoke_agent_not_found(self, mock_context, mock_shared_state): + """Test InvokeAzureAgentExecutor raises error when agent not found.""" + from agent_framework_declarative._workflows import ( + AgentInvocationError, + InvokeAzureAgentExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "InvokeAzureAgent", + "agent": "non_existent_agent", + "input": "test input", + } + executor = InvokeAzureAgentExecutor(action_def) + + # Execute - should raise AgentInvocationError + with pytest.raises(AgentInvocationError) as exc_info: + await executor.handle_action(ActionTrigger(), mock_context) + + assert "non_existent_agent" in str(exc_info.value) + assert "not found in registry" in str(exc_info.value) + + +class TestHumanInputExecutors: + """Tests for human input executors.""" + + @pytest.fixture + def mock_context(self, mock_shared_state): + """Create a mock workflow context.""" + ctx = MagicMock() + ctx.shared_state = mock_shared_state + ctx.send_message = AsyncMock() + ctx.yield_output = AsyncMock() + ctx.request_info = AsyncMock() + return ctx + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + + return shared_state + + @pytest.mark.asyncio + async def test_question_executor(self, mock_context, mock_shared_state): + """Test QuestionExecutor.""" + from agent_framework_declarative._workflows import ( + ExternalInputRequest, + QuestionExecutor, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "Question", + "text": "What is your name?", + "property": "Local.name", + "defaultValue": "Anonymous", + } + executor = QuestionExecutor(action_def) + + # Execute + await executor.handle_action(ActionTrigger(), mock_context) + + # Verify request_info was called with ExternalInputRequest + mock_context.request_info.assert_called_once() + request = 
mock_context.request_info.call_args[0][0] + assert isinstance(request, ExternalInputRequest) + assert request.request_type == "question" + assert "What is your name?" in request.message + + @pytest.mark.asyncio + async def test_confirmation_executor(self, mock_context, mock_shared_state): + """Test ConfirmationExecutor.""" + from agent_framework_declarative._workflows import ( + ConfirmationExecutor, + ExternalInputRequest, + ) + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "Confirmation", + "text": "Do you want to continue?", + "property": "Local.confirmed", + "yesLabel": "Yes, continue", + "noLabel": "No, stop", + } + executor = ConfirmationExecutor(action_def) + + # Execute + await executor.handle_action(ActionTrigger(), mock_context) + + # Verify request_info was called with ExternalInputRequest + mock_context.request_info.assert_called_once() + request = mock_context.request_info.call_args[0][0] + assert isinstance(request, ExternalInputRequest) + assert request.request_type == "confirmation" + assert "continue" in request.message.lower() + + +class TestParseValueExecutor: + """Tests for the ParseValue action executor.""" + + @pytest.fixture + def mock_context(self, mock_shared_state): + """Create a mock workflow context.""" + ctx = MagicMock() + ctx.shared_state = mock_shared_state + ctx.send_message = AsyncMock() + ctx.yield_output = AsyncMock() + return ctx + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + + return shared_state + + @pytest.mark.asyncio + async def test_parse_value_string(self, mock_context, mock_shared_state): + """Test ParseValue with string type.""" + from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.rawValue", "hello world") + + action_def = { + "kind": "ParseValue", + "variable": "Local.parsedValue", + "value": "=Local.rawValue", + "valueType": "string", + } + executor = ParseValueExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.parsedValue") + assert result == "hello world" + + @pytest.mark.asyncio + async def test_parse_value_number(self, mock_context, mock_shared_state): + """Test ParseValue with number type.""" + from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.rawValue", "123") + + action_def = { + "kind": "ParseValue", + "variable": "Local.parsedValue", + "value": "=Local.rawValue", + "valueType": "number", + } + executor = ParseValueExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.parsedValue") + assert result == 123 + + @pytest.mark.asyncio + async def test_parse_value_float(self, mock_context, mock_shared_state): + """Test ParseValue with float number.""" + from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor + + state = 
DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.rawValue", "3.14") + + action_def = { + "kind": "ParseValue", + "variable": "Local.parsedValue", + "value": "=Local.rawValue", + "valueType": "number", + } + executor = ParseValueExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.parsedValue") + assert result == 3.14 + + @pytest.mark.asyncio + async def test_parse_value_boolean_true(self, mock_context, mock_shared_state): + """Test ParseValue with boolean type (true).""" + from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.rawValue", "true") + + action_def = { + "kind": "ParseValue", + "variable": "Local.parsedValue", + "value": "=Local.rawValue", + "valueType": "boolean", + } + executor = ParseValueExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.parsedValue") + assert result is True + + @pytest.mark.asyncio + async def test_parse_value_boolean_false(self, mock_context, mock_shared_state): + """Test ParseValue with boolean type (false).""" + from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.rawValue", "no") + + action_def = { + "kind": "ParseValue", + "variable": "Local.parsedValue", + "value": "=Local.rawValue", + "valueType": "boolean", + } + executor = ParseValueExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.parsedValue") + assert result is False + + @pytest.mark.asyncio + async def test_parse_value_object_from_json(self, mock_context, mock_shared_state): + """Test ParseValue with object type from JSON string.""" + from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.rawValue", '{"name": "Alice", "age": 30}') + + action_def = { + "kind": "ParseValue", + "variable": "Local.parsedValue", + "value": "=Local.rawValue", + "valueType": "object", + } + executor = ParseValueExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.parsedValue") + assert result == {"name": "Alice", "age": 30} + + @pytest.mark.asyncio + async def test_parse_value_array_from_json(self, mock_context, mock_shared_state): + """Test ParseValue with array type from JSON string.""" + from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.rawValue", '["a", "b", "c"]') + + action_def = { + "kind": "ParseValue", + "variable": "Local.parsedValue", + "value": "=Local.rawValue", + "valueType": "array", + } + executor = ParseValueExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.parsedValue") + assert result == ["a", "b", "c"] + + @pytest.mark.asyncio + async def test_parse_value_no_type_conversion(self, mock_context, mock_shared_state): + """Test ParseValue without type conversion.""" + from agent_framework_declarative._workflows._executors_basic import ParseValueExecutor + + state = 
DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.rawValue", {"status": "active"}) + + action_def = { + "kind": "ParseValue", + "variable": "Local.parsedValue", + "value": "=Local.rawValue", + } + executor = ParseValueExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.parsedValue") + assert result == {"status": "active"} + + +class TestEditTableExecutor: + """Tests for the EditTable action executor.""" + + @pytest.fixture + def mock_context(self, mock_shared_state): + """Create a mock workflow context.""" + ctx = MagicMock() + ctx.shared_state = mock_shared_state + ctx.send_message = AsyncMock() + ctx.yield_output = AsyncMock() + return ctx + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + + return shared_state + + @pytest.mark.asyncio + async def test_edit_table_add(self, mock_context, mock_shared_state): + """Test EditTable with add operation.""" + from agent_framework_declarative._workflows._executors_basic import EditTableExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.items", ["a", "b"]) + + action_def = { + "kind": "EditTable", + "table": "Local.items", + "operation": "add", + "value": "c", + } + executor = EditTableExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.items") + assert result == ["a", "b", "c"] + + @pytest.mark.asyncio + async def test_edit_table_insert_at_index(self, mock_context, mock_shared_state): + """Test EditTable with insert at specific index.""" + from agent_framework_declarative._workflows._executors_basic import EditTableExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.items", ["a", "c"]) + + action_def = { + "kind": "EditTable", + "table": "Local.items", + "operation": "add", + "value": "b", + "index": 1, + } + executor = EditTableExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.items") + assert result == ["a", "b", "c"] + + @pytest.mark.asyncio + async def test_edit_table_remove_by_value(self, mock_context, mock_shared_state): + """Test EditTable with remove by value.""" + from agent_framework_declarative._workflows._executors_basic import EditTableExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.items", ["a", "b", "c"]) + + action_def = { + "kind": "EditTable", + "table": "Local.items", + "operation": "remove", + "value": "b", + } + executor = EditTableExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.items") + assert result == ["a", "c"] + + @pytest.mark.asyncio + async def test_edit_table_remove_by_index(self, mock_context, mock_shared_state): + """Test EditTable with remove by index.""" + from agent_framework_declarative._workflows._executors_basic import EditTableExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await 
state.initialize() + await state.set("Local.items", ["a", "b", "c"]) + + action_def = { + "kind": "EditTable", + "table": "Local.items", + "operation": "remove", + "index": 1, + } + executor = EditTableExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.items") + assert result == ["a", "c"] + + @pytest.mark.asyncio + async def test_edit_table_clear(self, mock_context, mock_shared_state): + """Test EditTable with clear operation.""" + from agent_framework_declarative._workflows._executors_basic import EditTableExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.items", ["a", "b", "c"]) + + action_def = { + "kind": "EditTable", + "table": "Local.items", + "operation": "clear", + } + executor = EditTableExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.items") + assert result == [] + + @pytest.mark.asyncio + async def test_edit_table_update_at_index(self, mock_context, mock_shared_state): + """Test EditTable with update at index.""" + from agent_framework_declarative._workflows._executors_basic import EditTableExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.items", ["a", "b", "c"]) + + action_def = { + "kind": "EditTable", + "table": "Local.items", + "operation": "update", + "value": "B", + "index": 1, + } + executor = EditTableExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.items") + assert result == ["a", "B", "c"] + + @pytest.mark.asyncio + async def test_edit_table_creates_new_list(self, mock_context, mock_shared_state): + """Test EditTable creates new list if not exists.""" + from agent_framework_declarative._workflows._executors_basic import EditTableExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "EditTable", + "table": "Local.newItems", + "operation": "add", + "value": "first", + } + executor = EditTableExecutor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.newItems") + assert result == ["first"] + + +class TestEditTableV2Executor: + """Tests for the EditTableV2 action executor.""" + + @pytest.fixture + def mock_context(self, mock_shared_state): + """Create a mock workflow context.""" + ctx = MagicMock() + ctx.shared_state = mock_shared_state + ctx.send_message = AsyncMock() + ctx.yield_output = AsyncMock() + return ctx + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + + return shared_state + + @pytest.mark.asyncio + async def test_edit_table_v2_add(self, mock_context, mock_shared_state): + """Test EditTableV2 with add operation.""" + from agent_framework_declarative._workflows._executors_basic import EditTableV2Executor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.records", [{"id": 1, "name": "Alice"}]) + + action_def = { + "kind": "EditTableV2", + "table": 
"Local.records", + "operation": "add", + "item": {"id": 2, "name": "Bob"}, + } + executor = EditTableV2Executor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.records") + assert result == [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}] + + @pytest.mark.asyncio + async def test_edit_table_v2_add_or_update_new(self, mock_context, mock_shared_state): + """Test EditTableV2 with addOrUpdate - adding new record.""" + from agent_framework_declarative._workflows._executors_basic import EditTableV2Executor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.records", [{"id": 1, "name": "Alice"}]) + + action_def = { + "kind": "EditTableV2", + "table": "Local.records", + "operation": "addOrUpdate", + "item": {"id": 2, "name": "Bob"}, + "key": "id", + } + executor = EditTableV2Executor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.records") + assert result == [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}] + + @pytest.mark.asyncio + async def test_edit_table_v2_add_or_update_existing(self, mock_context, mock_shared_state): + """Test EditTableV2 with addOrUpdate - updating existing record.""" + from agent_framework_declarative._workflows._executors_basic import EditTableV2Executor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.records", [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]) + + action_def = { + "kind": "EditTableV2", + "table": "Local.records", + "operation": "addOrUpdate", + "item": {"id": 1, "name": "Alice Updated"}, + "key": "id", + } + executor = EditTableV2Executor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.records") + assert result == [{"id": 1, "name": "Alice Updated"}, {"id": 2, "name": "Bob"}] + + @pytest.mark.asyncio + async def test_edit_table_v2_remove_by_key(self, mock_context, mock_shared_state): + """Test EditTableV2 with remove by key.""" + from agent_framework_declarative._workflows._executors_basic import EditTableV2Executor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.records", [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]) + + action_def = { + "kind": "EditTableV2", + "table": "Local.records", + "operation": "remove", + "item": {"id": 1}, + "key": "id", + } + executor = EditTableV2Executor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.records") + assert result == [{"id": 2, "name": "Bob"}] + + @pytest.mark.asyncio + async def test_edit_table_v2_clear(self, mock_context, mock_shared_state): + """Test EditTableV2 with clear operation.""" + from agent_framework_declarative._workflows._executors_basic import EditTableV2Executor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.records", [{"id": 1}, {"id": 2}]) + + action_def = { + "kind": "EditTableV2", + "table": "Local.records", + "operation": "clear", + } + executor = EditTableV2Executor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.records") + assert result == [] + + @pytest.mark.asyncio + async def test_edit_table_v2_update_by_key(self, mock_context, mock_shared_state): + """Test EditTableV2 with update by key.""" + from 
agent_framework_declarative._workflows._executors_basic import EditTableV2Executor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + await state.set("Local.records", [{"id": 1, "status": "pending"}, {"id": 2, "status": "pending"}]) + + action_def = { + "kind": "EditTableV2", + "table": "Local.records", + "operation": "update", + "item": {"id": 1, "status": "complete"}, + "key": "id", + } + executor = EditTableV2Executor(action_def) + await executor.handle_action(ActionTrigger(), mock_context) + + result = await state.get("Local.records") + assert result == [{"id": 1, "status": "complete"}, {"id": 2, "status": "pending"}] + + +class TestCancelDialogExecutors: + """Tests for CancelDialog and CancelAllDialogs executors.""" + + @pytest.fixture + def mock_context(self, mock_shared_state): + """Create a mock workflow context.""" + ctx = MagicMock() + ctx.shared_state = mock_shared_state + ctx.send_message = AsyncMock() + ctx.yield_output = AsyncMock() + return ctx + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + + return shared_state + + @pytest.mark.asyncio + async def test_cancel_dialog_executor(self, mock_context, mock_shared_state): + """Test CancelDialogExecutor completes without error.""" + from agent_framework_declarative._workflows._executors_control_flow import CancelDialogExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "CancelDialog", + } + executor = CancelDialogExecutor(action_def) + # Should complete without raising + await executor.handle_action(ActionTrigger(), mock_context) + # CancelDialog is a no-op that signals termination + # No assertions needed - just verify it doesn't raise + + @pytest.mark.asyncio + async def test_cancel_all_dialogs_executor(self, mock_context, mock_shared_state): + """Test CancelAllDialogsExecutor completes without error.""" + from agent_framework_declarative._workflows._executors_control_flow import CancelAllDialogsExecutor + + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + action_def = { + "kind": "CancelAllDialogs", + } + executor = CancelAllDialogsExecutor(action_def) + # Should complete without raising + await executor.handle_action(ActionTrigger(), mock_context) + # CancelAllDialogs is a no-op that signals termination + # No assertions needed - just verify it doesn't raise + + +class TestExtractJsonFromResponse: + """Tests for the _extract_json_from_response helper function.""" + + def test_pure_json_object(self): + """Test parsing pure JSON object.""" + from agent_framework_declarative._workflows._executors_agents import ( + _extract_json_from_response, + ) + + text = '{"TicketId": "123", "Status": "pending"}' + result = _extract_json_from_response(text) + assert result == {"TicketId": "123", "Status": "pending"} + + def test_pure_json_array(self): + """Test parsing pure JSON array.""" + from agent_framework_declarative._workflows._executors_agents import ( + _extract_json_from_response, + ) + + text = '["item1", "item2", "item3"]' + result = _extract_json_from_response(text) + assert result == ["item1", "item2", 
"item3"] + + def test_json_in_markdown_code_block(self): + """Test extracting JSON from markdown code block.""" + from agent_framework_declarative._workflows._executors_agents import ( + _extract_json_from_response, + ) + + text = """Here's the response: +```json +{"TicketId": "456", "Summary": "Test ticket"} +``` +""" + result = _extract_json_from_response(text) + assert result == {"TicketId": "456", "Summary": "Test ticket"} + + def test_json_in_plain_code_block(self): + """Test extracting JSON from plain markdown code block.""" + from agent_framework_declarative._workflows._executors_agents import ( + _extract_json_from_response, + ) + + text = """The result: +``` +{"Status": "complete"} +``` +""" + result = _extract_json_from_response(text) + assert result == {"Status": "complete"} + + def test_json_with_leading_text(self): + """Test extracting JSON with leading text.""" + from agent_framework_declarative._workflows._executors_agents import ( + _extract_json_from_response, + ) + + text = 'Here is the ticket information: {"TicketId": "789", "Priority": "high"}' + result = _extract_json_from_response(text) + assert result == {"TicketId": "789", "Priority": "high"} + + def test_json_with_trailing_text(self): + """Test extracting JSON with trailing text.""" + from agent_framework_declarative._workflows._executors_agents import ( + _extract_json_from_response, + ) + + text = '{"IsResolved": true, "NeedsTicket": false} That is the status.' + result = _extract_json_from_response(text) + assert result == {"IsResolved": True, "NeedsTicket": False} + + def test_nested_json_object(self): + """Test extracting nested JSON object.""" + from agent_framework_declarative._workflows._executors_agents import ( + _extract_json_from_response, + ) + + text = 'Result: {"outer": {"inner": {"value": 42}}}' + result = _extract_json_from_response(text) + assert result == {"outer": {"inner": {"value": 42}}} + + def test_json_with_array_inside(self): + """Test extracting JSON with arrays inside.""" + from agent_framework_declarative._workflows._executors_agents import ( + _extract_json_from_response, + ) + + text = 'Data: {"items": ["a", "b", "c"], "count": 3}' + result = _extract_json_from_response(text) + assert result == {"items": ["a", "b", "c"], "count": 3} + + def test_json_with_escaped_quotes(self): + """Test extracting JSON with escaped quotes in strings.""" + from agent_framework_declarative._workflows._executors_agents import ( + _extract_json_from_response, + ) + + text = r'Response: {"message": "He said \"hello\"", "valid": true}' + result = _extract_json_from_response(text) + assert result == {"message": 'He said "hello"', "valid": True} + + def test_empty_string_returns_none(self): + """Test that empty string returns None.""" + from agent_framework_declarative._workflows._executors_agents import ( + _extract_json_from_response, + ) + + result = _extract_json_from_response("") + assert result is None + + def test_whitespace_only_returns_none(self): + """Test that whitespace-only string returns None.""" + from agent_framework_declarative._workflows._executors_agents import ( + _extract_json_from_response, + ) + + result = _extract_json_from_response(" \n\t ") + assert result is None + + def test_no_json_raises_error(self): + """Test that text without JSON raises JSONDecodeError.""" + import json + + from agent_framework_declarative._workflows._executors_agents import ( + _extract_json_from_response, + ) + + with pytest.raises(json.JSONDecodeError): + _extract_json_from_response("This is just plain text 
with no JSON")
+
+    def test_json_with_braces_in_string(self):
+        """Test JSON with braces inside string values."""
+        from agent_framework_declarative._workflows._executors_agents import (
+            _extract_json_from_response,
+        )
+
+        text = 'Info: {"template": "Hello {name}, your id is {id}"}'
+        result = _extract_json_from_response(text)
+        assert result == {"template": "Hello {name}, your id is {id}"}
+
+    def test_multiple_json_objects_returns_last(self):
+        """Test that multiple JSON objects returns the last one (final result)."""
+        from agent_framework_declarative._workflows._executors_agents import (
+            _extract_json_from_response,
+        )
+
+        # Simulates streaming agent output with partial then final result
+        text = '{"TicketId":"TBD","TicketSummary":"partial"}{"TicketId":"75178c95","TicketSummary":"final result"}'
+        result = _extract_json_from_response(text)
+        assert result == {"TicketId": "75178c95", "TicketSummary": "final result"}
+
+    def test_multiple_json_objects_with_different_schemas(self):
+        """Test multiple JSON objects with different structures returns the last."""
+        from agent_framework_declarative._workflows._executors_agents import (
+            _extract_json_from_response,
+        )
+
+        # First object is from one agent, second is from another
+        text = '{"IsResolved":false,"NeedsTicket":true}{"TicketId":"abc123","Summary":"Issue logged"}'
+        result = _extract_json_from_response(text)
+        assert result == {"TicketId": "abc123", "Summary": "Issue logged"}
+
+    def test_multiple_json_objects_with_text_between(self):
+        """Test multiple JSON objects separated by text."""
+        from agent_framework_declarative._workflows._executors_agents import (
+            _extract_json_from_response,
+        )
+
+        text = 'First: {"status": "pending"} then later: {"status": "complete", "id": 42}'
+        result = _extract_json_from_response(text)
+        assert result == {"status": "complete", "id": 42}
diff --git a/python/packages/declarative/tests/test_graph_workflow_integration.py b/python/packages/declarative/tests/test_graph_workflow_integration.py
new file mode 100644
index 0000000000..00ee3a154f
--- /dev/null
+++ b/python/packages/declarative/tests/test_graph_workflow_integration.py
@@ -0,0 +1,332 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+"""Integration tests for declarative workflows.
+
+These tests verify:
+- End-to-end workflow execution
+- Checkpointing prerequisites (one executor node per action boundary)
+- WorkflowFactory creating graph-based workflows
+- The per-action graph structure that enables pause/resume
+"""
+
+import pytest
+
+from agent_framework_declarative._workflows import (
+    ActionTrigger,
+    DeclarativeWorkflowBuilder,
+)
+from agent_framework_declarative._workflows._factory import WorkflowFactory
+
+
+class TestGraphBasedWorkflowExecution:
+    """Integration tests for graph-based workflow execution."""
+
+    @pytest.mark.asyncio
+    async def test_simple_sequential_workflow(self):
+        """Test a simple sequential workflow with SendActivity actions."""
+        yaml_def = {
+            "name": "simple_workflow",
+            "actions": [
+                {"kind": "SendActivity", "id": "greet", "activity": {"text": "Hello!"}},
+                {"kind": "SetValue", "id": "set_count", "path": "Local.count", "value": 1},
+                {"kind": "SendActivity", "id": "done", "activity": {"text": "Done!"}},
+            ],
+        }
+
+        builder = DeclarativeWorkflowBuilder(yaml_def)
+        workflow = builder.build()
+
+        # Run the workflow
+        events = await workflow.run(ActionTrigger())
+
+        # Verify outputs were produced
+        outputs = events.get_outputs()
+        assert "Hello!" in outputs
+        assert "Done!" 
in outputs + + @pytest.mark.asyncio + async def test_workflow_with_conditional(self): + """Test workflow with If conditional branching.""" + yaml_def = { + "name": "conditional_workflow", + "actions": [ + {"kind": "SetValue", "id": "set_flag", "path": "Local.flag", "value": True}, + { + "kind": "If", + "id": "check_flag", + "condition": "=Local.flag", + "then": [ + {"kind": "SendActivity", "id": "say_yes", "activity": {"text": "Flag is true!"}}, + ], + "else": [ + {"kind": "SendActivity", "id": "say_no", "activity": {"text": "Flag is false!"}}, + ], + }, + ], + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + workflow = builder.build() + + # Run the workflow + events = await workflow.run(ActionTrigger()) + outputs = events.get_outputs() + + # Should take the "then" branch since flag is True + assert "Flag is true!" in outputs + assert "Flag is false!" not in outputs + + @pytest.mark.asyncio + async def test_workflow_with_foreach_loop(self): + """Test workflow with Foreach loop.""" + yaml_def = { + "name": "loop_workflow", + "actions": [ + {"kind": "SetValue", "id": "set_items", "path": "Local.items", "value": ["a", "b", "c"]}, + { + "kind": "Foreach", + "id": "process_items", + "itemsSource": "=Local.items", + "iteratorVariable": "Local.item", + "actions": [ + {"kind": "SendActivity", "id": "show_item", "activity": {"text": "=Local.item"}}, + ], + }, + ], + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + workflow = builder.build() + + # Run the workflow + events = await workflow.run(ActionTrigger()) + outputs = events.get_outputs() + + # Should output each item + assert "a" in outputs + assert "b" in outputs + assert "c" in outputs + + @pytest.mark.asyncio + async def test_workflow_with_switch(self): + """Test workflow with Switch/ConditionGroup.""" + yaml_def = { + "name": "switch_workflow", + "actions": [ + {"kind": "SetValue", "id": "set_level", "path": "Local.level", "value": 2}, + { + "kind": "Switch", + "id": "check_level", + "conditions": [ + { + "condition": "=Local.level = 1", + "actions": [ + {"kind": "SendActivity", "id": "level_1", "activity": {"text": "Level 1"}}, + ], + }, + { + "condition": "=Local.level = 2", + "actions": [ + {"kind": "SendActivity", "id": "level_2", "activity": {"text": "Level 2"}}, + ], + }, + ], + "else": [ + {"kind": "SendActivity", "id": "default", "activity": {"text": "Other level"}}, + ], + }, + ], + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + workflow = builder.build() + + # Run the workflow + events = await workflow.run(ActionTrigger()) + outputs = events.get_outputs() + + # Should take the level 2 branch + assert "Level 2" in outputs + assert "Level 1" not in outputs + assert "Other level" not in outputs + + +class TestWorkflowFactory: + """Tests for WorkflowFactory.""" + + def test_factory_creates_workflow(self): + """Test creating workflow.""" + factory = WorkflowFactory() + + yaml_content = """ +name: test_workflow +actions: + - kind: SendActivity + id: greet + activity: + text: "Hello from graph mode!" 
+  - kind: SetValue
+    id: set_val
+    path: Local.result
+    value: 42
+"""
+        workflow = factory.create_workflow_from_yaml(yaml_content)
+
+        assert workflow is not None
+        assert hasattr(workflow, "_declarative_agents")
+
+    @pytest.mark.asyncio
+    async def test_workflow_execution(self):
+        """Test executing a factory-created workflow end to end."""
+        factory = WorkflowFactory()
+
+        yaml_content = """
+name: graph_execution_test
+actions:
+  - kind: SendActivity
+    id: start
+    activity:
+      text: "Starting workflow"
+  - kind: SetValue
+    id: set_message
+    path: Local.message
+    value: "Hello World"
+  - kind: SendActivity
+    id: end
+    activity:
+      text: "Workflow complete"
+"""
+        workflow = factory.create_workflow_from_yaml(yaml_content)
+
+        # Execute the workflow
+        events = await workflow.run(ActionTrigger())
+        outputs = events.get_outputs()
+
+        assert "Starting workflow" in outputs
+        assert "Workflow complete" in outputs
+
+
+class TestGraphWorkflowCheckpointing:
+    """Tests for the checkpointing prerequisites of graph-based workflows (one executor per action)."""
+
+    def test_workflow_has_multiple_executors(self):
+        """Test that graph-based workflow creates multiple executor nodes."""
+        yaml_def = {
+            "name": "multi_executor_workflow",
+            "actions": [
+                {"kind": "SetValue", "id": "step1", "path": "Local.a", "value": 1},
+                {"kind": "SetValue", "id": "step2", "path": "Local.b", "value": 2},
+                {"kind": "SetValue", "id": "step3", "path": "Local.c", "value": 3},
+            ],
+        }
+
+        builder = DeclarativeWorkflowBuilder(yaml_def)
+        _workflow = builder.build()  # noqa: F841
+
+        # Verify multiple executors were created
+        assert "step1" in builder._executors
+        assert "step2" in builder._executors
+        assert "step3" in builder._executors
+        assert len(builder._executors) == 3
+
+    def test_workflow_executor_connectivity(self):
+        """Test that executors are properly connected in sequence."""
+        yaml_def = {
+            "name": "connected_workflow",
+            "actions": [
+                {"kind": "SendActivity", "id": "a", "activity": {"text": "A"}},
+                {"kind": "SendActivity", "id": "b", "activity": {"text": "B"}},
+                {"kind": "SendActivity", "id": "c", "activity": {"text": "C"}},
+            ],
+        }
+
+        builder = DeclarativeWorkflowBuilder(yaml_def)
+        workflow = builder.build()
+
+        # Verify all executors exist
+        assert len(builder._executors) == 3
+
+        # Verify the workflow can be inspected
+        assert workflow is not None
+
+
+class TestGraphWorkflowVisualization:
+    """Tests for workflow visualization capabilities."""
+
+    def test_workflow_can_be_built(self):
+        """Test that complex workflows can be built successfully."""
+        yaml_def = {
+            "name": "complex_workflow",
+            "actions": [
+                {"kind": "SendActivity", "id": "intro", "activity": {"text": "Starting"}},
+                {
+                    "kind": "If",
+                    "id": "branch",
+                    "condition": "=true",
+                    "then": [
+                        {"kind": "SendActivity", "id": "then_msg", "activity": {"text": "Then branch"}},
+                    ],
+                    "else": [
+                        {"kind": "SendActivity", "id": "else_msg", "activity": {"text": "Else branch"}},
+                    ],
+                },
+                {"kind": "SendActivity", "id": "outro", "activity": {"text": "Done"}},
+            ],
+        }
+
+        builder = DeclarativeWorkflowBuilder(yaml_def)
+        workflow = builder.build()
+
+        # Verify the workflow was built
+        assert workflow is not None
+
+        # Verify expected executors exist
+        # intro, the If condition node, then_msg, else_msg, outro (no join executors - branches wire directly to successor)
+        assert "intro" in builder._executors
+        assert "then_msg" in builder._executors
+        assert "else_msg" in builder._executors
+        assert "outro" in builder._executors
+
+
+class TestGraphWorkflowStateManagement:
+    """Tests for state management across graph executor nodes."""
+
+    
@pytest.mark.asyncio + async def test_state_persists_across_executors(self): + """Test that state set in one executor is available in the next.""" + yaml_def = { + "name": "state_test", + "actions": [ + {"kind": "SetValue", "id": "set", "path": "Local.value", "value": "test_data"}, + {"kind": "SendActivity", "id": "send", "activity": {"text": "=Local.value"}}, + ], + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + workflow = builder.build() + + events = await workflow.run(ActionTrigger()) + outputs = events.get_outputs() + + # The SendActivity should have access to the value set by SetValue + assert "test_data" in outputs + + @pytest.mark.asyncio + async def test_multiple_variables(self): + """Test setting and using multiple variables.""" + yaml_def = { + "name": "multi_var_test", + "actions": [ + {"kind": "SetValue", "id": "set_a", "path": "Local.a", "value": "Hello"}, + {"kind": "SetValue", "id": "set_b", "path": "Local.b", "value": "World"}, + {"kind": "SendActivity", "id": "send", "activity": {"text": "=Local.a"}}, + ], + } + + builder = DeclarativeWorkflowBuilder(yaml_def) + workflow = builder.build() + + events = await workflow.run(ActionTrigger()) + outputs = events.get_outputs() + + assert "Hello" in outputs diff --git a/python/packages/declarative/tests/test_powerfx_functions.py b/python/packages/declarative/tests/test_powerfx_functions.py new file mode 100644 index 0000000000..050fa96786 --- /dev/null +++ b/python/packages/declarative/tests/test_powerfx_functions.py @@ -0,0 +1,242 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for custom PowerFx-like functions.""" + +from agent_framework_declarative._workflows._powerfx_functions import ( + CUSTOM_FUNCTIONS, + assistant_message, + concat_text, + count_rows, + find, + first, + is_blank, + last, + lower, + message_text, + search_table, + system_message, + upper, + user_message, +) + + +class TestMessageText: + """Tests for MessageText function.""" + + def test_message_text_from_string(self): + """Test extracting text from a plain string.""" + assert message_text("Hello") == "Hello" + + def test_message_text_from_single_dict(self): + """Test extracting text from a single message dict.""" + msg = {"role": "assistant", "content": "Hello world"} + assert message_text(msg) == "Hello world" + + def test_message_text_from_list(self): + """Test extracting text from a list of messages.""" + msgs = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ] + assert message_text(msgs) == "Hi Hello" + + def test_message_text_from_none(self): + """Test that None returns empty string.""" + assert message_text(None) == "" + + def test_message_text_empty_list(self): + """Test that empty list returns empty string.""" + assert message_text([]) == "" + + +class TestUserMessage: + """Tests for UserMessage function.""" + + def test_user_message_creates_dict(self): + """Test that UserMessage creates correct dict.""" + msg = user_message("Hello") + assert msg == {"role": "user", "content": "Hello"} + + def test_user_message_with_none(self): + """Test UserMessage with None.""" + msg = user_message(None) + assert msg == {"role": "user", "content": ""} + + +class TestAssistantMessage: + """Tests for AssistantMessage function.""" + + def test_assistant_message_creates_dict(self): + """Test that AssistantMessage creates correct dict.""" + msg = assistant_message("Hello") + assert msg == {"role": "assistant", "content": "Hello"} + + +class TestSystemMessage: + """Tests for SystemMessage function.""" + + def 
test_system_message_creates_dict(self): + """Test that SystemMessage creates correct dict.""" + msg = system_message("You are helpful") + assert msg == {"role": "system", "content": "You are helpful"} + + +class TestIsBlank: + """Tests for IsBlank function.""" + + def test_is_blank_none(self): + """Test that None is blank.""" + assert is_blank(None) is True + + def test_is_blank_empty_string(self): + """Test that empty string is blank.""" + assert is_blank("") is True + + def test_is_blank_whitespace(self): + """Test that whitespace-only string is blank.""" + assert is_blank(" ") is True + + def test_is_blank_empty_list(self): + """Test that empty list is blank.""" + assert is_blank([]) is True + + def test_is_blank_non_empty(self): + """Test that non-empty values are not blank.""" + assert is_blank("hello") is False + assert is_blank([1, 2, 3]) is False + assert is_blank(0) is False + + +class TestCountRows: + """Tests for CountRows function.""" + + def test_count_rows_list(self): + """Test counting list items.""" + assert count_rows([1, 2, 3]) == 3 + + def test_count_rows_empty(self): + """Test counting empty list.""" + assert count_rows([]) == 0 + + def test_count_rows_none(self): + """Test counting None.""" + assert count_rows(None) == 0 + + +class TestFirstLast: + """Tests for First and Last functions.""" + + def test_first_returns_first_item(self): + """Test that First returns first item.""" + assert first([1, 2, 3]) == 1 + + def test_last_returns_last_item(self): + """Test that Last returns last item.""" + assert last([1, 2, 3]) == 3 + + def test_first_empty_returns_none(self): + """Test that First returns None for empty list.""" + assert first([]) is None + + def test_last_empty_returns_none(self): + """Test that Last returns None for empty list.""" + assert last([]) is None + + +class TestFind: + """Tests for Find function.""" + + def test_find_substring(self): + """Test finding a substring.""" + result = find("world", "Hello world") + assert result == 7 # 1-based index + + def test_find_not_found(self): + """Test when substring not found - returns Blank (None) per PowerFx semantics.""" + result = find("xyz", "Hello world") + assert result is None + + def test_find_at_start(self): + """Test finding at start of string.""" + result = find("Hello", "Hello world") + assert result == 1 + + +class TestUpperLower: + """Tests for Upper and Lower functions.""" + + def test_upper(self): + """Test uppercase conversion.""" + assert upper("hello") == "HELLO" + + def test_lower(self): + """Test lowercase conversion.""" + assert lower("HELLO") == "hello" + + def test_upper_none(self): + """Test upper with None.""" + assert upper(None) == "" + + +class TestConcatText: + """Tests for Concat function.""" + + def test_concat_simple_list(self): + """Test concatenating simple list.""" + assert concat_text(["a", "b", "c"], separator=", ") == "a, b, c" + + def test_concat_with_field(self): + """Test concatenating with field extraction.""" + items = [{"name": "Alice"}, {"name": "Bob"}] + assert concat_text(items, field="name", separator=", ") == "Alice, Bob" + + +class TestSearchTable: + """Tests for Search function.""" + + def test_search_finds_matching(self): + """Test search finds matching items.""" + items = [ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25}, + {"name": "Charlie", "age": 35}, + ] + result = search_table(items, "Bob", "name") + assert len(result) == 1 + assert result[0]["name"] == "Bob" + + def test_search_case_insensitive(self): + """Test search is case insensitive.""" + 
items = [{"name": "Alice"}] + result = search_table(items, "alice", "name") + assert len(result) == 1 + + def test_search_partial_match(self): + """Test search finds partial matches.""" + items = [{"name": "Alice Smith"}, {"name": "Bob Jones"}] + result = search_table(items, "Smith", "name") + assert len(result) == 1 + + +class TestCustomFunctionsRegistry: + """Tests for the CUSTOM_FUNCTIONS registry.""" + + def test_all_functions_registered(self): + """Test that all functions are in the registry.""" + expected = [ + "MessageText", + "UserMessage", + "AssistantMessage", + "SystemMessage", + "IsBlank", + "CountRows", + "First", + "Last", + "Find", + "Upper", + "Lower", + "Concat", + "Search", + ] + for name in expected: + assert name in CUSTOM_FUNCTIONS diff --git a/python/packages/declarative/tests/test_powerfx_yaml_compatibility.py b/python/packages/declarative/tests/test_powerfx_yaml_compatibility.py new file mode 100644 index 0000000000..91cf378578 --- /dev/null +++ b/python/packages/declarative/tests/test_powerfx_yaml_compatibility.py @@ -0,0 +1,581 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests to ensure PowerFx evaluation supports all expressions used in declarative YAML workflows. + +This test suite validates that all PowerFx expressions found in the sample YAML workflows +under samples/getting_started/workflows/declarative/ work correctly with our implementation. + +Coverage includes: +- Built-in PowerFx functions: Concat, If, IsBlank, Not, Or, Upper, Find +- Custom functions: UserMessage, MessageText +- System variables: System.ConversationId, System.LastMessage.Text +- Local/turn variables with nested access +- Comparison operators: <, >, <=, >=, <>, = +- Logical operators: And, Or, Not, ! +- Arithmetic operators: +, -, *, / +- String interpolation: {Variable.Path} +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agent_framework_declarative._workflows._declarative_base import ( + DeclarativeWorkflowState, +) + + +class TestPowerFxBuiltinFunctions: + """Test PowerFx built-in functions used in YAML workflows.""" + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state with async get/set methods.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + return shared_state + + async def test_concat_simple(self, mock_shared_state): + """Test Concat function with simple strings.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =Concat("Nice to meet you, ", Local.userName, "!") + await state.set("Local.userName", "Alice") + result = await state.eval('=Concat("Nice to meet you, ", Local.userName, "!")') + assert result == "Nice to meet you, Alice!" + + async def test_concat_multiple_args(self, mock_shared_state): + """Test Concat with multiple arguments.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =Concat(Local.greeting, ", ", Local.name, "!") + await state.set("Local.greeting", "Hello") + await state.set("Local.name", "World") + result = await state.eval('=Concat(Local.greeting, ", ", Local.name, "!")') + assert result == "Hello, World!" 
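+
+    # Illustrative sketch added for clarity (not taken from the sample YAML
+    # workflows): composes Upper over Concat, assuming nested function calls
+    # evaluate the same way as the IsBlank/Find/Upper composition exercised
+    # later in this class.
+    async def test_upper_over_concat(self, mock_shared_state):
+        """Sketch: Upper composed over Concat (hypothetical expression)."""
+        state = DeclarativeWorkflowState(mock_shared_state)
+        await state.initialize()
+
+        await state.set("Local.name", "alice")
+        result = await state.eval('=Upper(Concat("hello, ", Local.name))')
+        assert result == "HELLO, ALICE"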
+ + async def test_concat_with_local_namespace(self, mock_shared_state): + """Test Concat using Local.* namespace (maps to Local.*).""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =Concat("Starting math coaching session for: ", Local.Problem) + await state.set("Local.Problem", "2 + 2") + result = await state.eval('=Concat("Starting math coaching session for: ", Local.Problem)') + assert result == "Starting math coaching session for: 2 + 2" + + async def test_if_with_isblank(self, mock_shared_state): + """Test If function with IsBlank.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize({"name": ""}) + + # From YAML: =If(IsBlank(inputs.name), "World", inputs.name) + # When input is blank + result = await state.eval('=If(IsBlank(Workflow.Inputs.name), "World", Workflow.Inputs.name)') + assert result == "World" + + # When input is provided + await state.initialize({"name": "Alice"}) + result = await state.eval('=If(IsBlank(Workflow.Inputs.name), "World", Workflow.Inputs.name)') + assert result == "Alice" + + async def test_not_function(self, mock_shared_state): + """Test Not function.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =Not(Local.EscalationParameters.IsComplete) + await state.set("Local.EscalationParameters", {"IsComplete": False}) + result = await state.eval("=Not(Local.EscalationParameters.IsComplete)") + assert result is True + + await state.set("Local.EscalationParameters", {"IsComplete": True}) + result = await state.eval("=Not(Local.EscalationParameters.IsComplete)") + assert result is False + + async def test_or_function(self, mock_shared_state): + """Test Or function.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =Or(Local.feeling = "great", Local.feeling = "good") + await state.set("Local.feeling", "great") + result = await state.eval('=Or(Local.feeling = "great", Local.feeling = "good")') + assert result is True + + await state.set("Local.feeling", "good") + result = await state.eval('=Or(Local.feeling = "great", Local.feeling = "good")') + assert result is True + + await state.set("Local.feeling", "bad") + result = await state.eval('=Or(Local.feeling = "great", Local.feeling = "good")') + assert result is False + + async def test_upper_function(self, mock_shared_state): + """Test Upper function.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =Upper(System.LastMessage.Text) + await state.set("System.LastMessage", {"Text": "hello world"}) + result = await state.eval("=Upper(System.LastMessage.Text)") + assert result == "HELLO WORLD" + + async def test_find_function(self, mock_shared_state): + """Test Find function.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =!IsBlank(Find("CONGRATULATIONS", Upper(Local.TeacherResponse))) + await state.set("Local.TeacherResponse", "CONGRATULATIONS! 
You solved it!") + result = await state.eval('=Not(IsBlank(Find("CONGRATULATIONS", Upper(Local.TeacherResponse))))') + assert result is True + + await state.set("Local.TeacherResponse", "Try again") + result = await state.eval('=Not(IsBlank(Find("CONGRATULATIONS", Upper(Local.TeacherResponse))))') + assert result is False + + +class TestPowerFxSystemVariables: + """Test System.* variable access.""" + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + return shared_state + + async def test_system_conversation_id(self, mock_shared_state): + """Test System.ConversationId access.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: conversationId: =System.ConversationId + await state.set("System.ConversationId", "conv-12345") + result = await state.eval("=System.ConversationId") + assert result == "conv-12345" + + async def test_system_last_message_text(self, mock_shared_state): + """Test System.LastMessage.Text access.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =Upper(System.LastMessage.Text) <> "EXIT" + await state.set("System.LastMessage", {"Text": "Hello"}) + result = await state.eval("=System.LastMessage.Text") + assert result == "Hello" + + async def test_system_last_message_exit_check(self, mock_shared_state): + """Test the exit check pattern from YAML.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: when: =Upper(System.LastMessage.Text) <> "EXIT" + await state.set("System.LastMessage", {"Text": "hello"}) + result = await state.eval('=Upper(System.LastMessage.Text) <> "EXIT"') + assert result is True + + await state.set("System.LastMessage", {"Text": "exit"}) + result = await state.eval('=Upper(System.LastMessage.Text) <> "EXIT"') + assert result is False + + +class TestPowerFxComparisonOperators: + """Test comparison operators used in YAML workflows.""" + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + return shared_state + + async def test_less_than(self, mock_shared_state): + """Test < operator.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: condition: =Local.age < 65 + await state.set("Local.age", 30) + assert await state.eval("=Local.age < 65") is True + + await state.set("Local.age", 70) + assert await state.eval("=Local.age < 65") is False + + async def test_less_than_with_local(self, mock_shared_state): + """Test < with Local namespace.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: condition: =Local.TurnCount < 4 + await state.set("Local.TurnCount", 2) + assert await state.eval("=Local.TurnCount < 4") is True + + await state.set("Local.TurnCount", 5) 
+ assert await state.eval("=Local.TurnCount < 4") is False + + async def test_equality(self, mock_shared_state): + """Test = equality operator.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =Local.feeling = "great" + await state.set("Local.feeling", "great") + assert await state.eval('=Local.feeling = "great"') is True + + await state.set("Local.feeling", "bad") + assert await state.eval('=Local.feeling = "great"') is False + + async def test_inequality(self, mock_shared_state): + """Test <> inequality operator.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =Upper(System.LastMessage.Text) <> "EXIT" + await state.set("Local.status", "active") + assert await state.eval('=Local.status <> "done"') is True + assert await state.eval('=Local.status <> "active"') is False + + +class TestPowerFxArithmetic: + """Test arithmetic operations.""" + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + return shared_state + + async def test_addition(self, mock_shared_state): + """Test + operator.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: value: =Local.TurnCount + 1 + await state.set("Local.TurnCount", 3) + result = await state.eval("=Local.TurnCount + 1") + assert result == 4 + + +class TestPowerFxCustomFunctions: + """Test custom functions (UserMessage, MessageText, AgentMessage).""" + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + return shared_state + + @pytest.mark.asyncio + async def test_agent_message_function(self, mock_shared_state): + """Test AgentMessage function (.NET compatibility alias for AssistantMessage).""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From .NET YAML: messages: =AgentMessage(Local.Response) + await state.set("Local.Response", "Here is the analysis result") + result = await state.eval("=AgentMessage(Local.Response)") + + assert isinstance(result, dict) + assert result["role"] == "assistant" + assert result["text"] == "Here is the analysis result" + + @pytest.mark.asyncio + async def test_agent_message_with_empty_string(self, mock_shared_state): + """Test AgentMessage with empty string.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + await state.set("Local.Response", "") + result = await state.eval("=AgentMessage(Local.Response)") + + assert result["role"] == "assistant" + assert result["text"] == "" + + @pytest.mark.asyncio + async def test_user_message_with_variable(self, mock_shared_state): + """Test UserMessage function with variable reference.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: messages: 
=UserMessage(Local.ServiceParameters.IssueDescription) + await state.set("Local.ServiceParameters", {"IssueDescription": "My computer won't boot"}) + result = await state.eval("=UserMessage(Local.ServiceParameters.IssueDescription)") + + assert isinstance(result, dict) + assert result["role"] == "user" + assert result["text"] == "My computer won't boot" + + async def test_user_message_with_simple_variable(self, mock_shared_state): + """Test UserMessage with simple variable.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: messages: =Local.Problem + await state.set("Local.Problem", "What is 2+2?") + result = await state.eval("=UserMessage(Local.Problem)") + + assert result["role"] == "user" + assert result["text"] == "What is 2+2?" + + async def test_message_text_with_list(self, mock_shared_state): + """Test MessageText extracts text from message list.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + await state.set( + "Local.messages", + [ + {"role": "user", "text": "Hello"}, + {"role": "assistant", "text": "Hi there!"}, + ], + ) + result = await state.eval("=MessageText(Local.messages)") + assert result == "Hi there!" + + async def test_message_text_empty_list(self, mock_shared_state): + """Test MessageText with empty list returns empty string.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + await state.set("Local.messages", []) + result = await state.eval("=MessageText(Local.messages)") + assert result == "" + + +class TestPowerFxNestedVariables: + """Test nested variable access patterns from YAML.""" + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + return shared_state + + async def test_nested_local_variable(self, mock_shared_state): + """Test nested Local.* variable access.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =Local.ServiceParameters.IssueDescription + await state.set("Local.ServiceParameters", {"IssueDescription": "Screen is black"}) + result = await state.eval("=Local.ServiceParameters.IssueDescription") + assert result == "Screen is black" + + async def test_nested_routing_parameters(self, mock_shared_state): + """Test RoutingParameters access.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =Local.RoutingParameters.TeamName + await state.set("Local.RoutingParameters", {"TeamName": "Windows Support"}) + result = await state.eval("=Local.RoutingParameters.TeamName") + assert result == "Windows Support" + + async def test_nested_ticket_parameters(self, mock_shared_state): + """Test TicketParameters access.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: =Local.TicketParameters.TicketId + await state.set("Local.TicketParameters", {"TicketId": "TKT-12345"}) + result = await state.eval("=Local.TicketParameters.TicketId") + assert result == "TKT-12345" + + +class TestPowerFxUndefinedVariables: + """Test graceful handling of undefined variables.""" + + @pytest.fixture + def mock_shared_state(self): + 
"""Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + return shared_state + + async def test_undefined_local_variable_returns_none(self, mock_shared_state): + """Test that undefined Local.* variables return None.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Variable not set - should return None (not raise) + result = await state.eval("=Local.UndefinedVariable") + assert result is None + + async def test_undefined_nested_variable_returns_none(self, mock_shared_state): + """Test that undefined nested variables return None.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # Nested undefined variable + result = await state.eval("=Local.Something.Nested.Deep") + assert result is None + + +class TestStringInterpolation: + """Test string interpolation patterns.""" + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + return shared_state + + async def test_interpolate_local_variable(self, mock_shared_state): + """Test {Local.Variable} interpolation.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: activity: "Created ticket #{Local.TicketParameters.TicketId}" + await state.set("Local.TicketParameters", {"TicketId": "TKT-999"}) + result = await state.interpolate_string("Created ticket #{Local.TicketParameters.TicketId}") + assert result == "Created ticket #TKT-999" + + async def test_interpolate_routing_team(self, mock_shared_state): + """Test routing team interpolation.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize() + + # From YAML: activity: Routing to {Local.RoutingParameters.TeamName} + await state.set("Local.RoutingParameters", {"TeamName": "Linux Support"}) + result = await state.interpolate_string("Routing to {Local.RoutingParameters.TeamName}") + assert result == "Routing to Linux Support" + + +class TestWorkflowInputsAccess: + """Test Workflow.Inputs access patterns.""" + + @pytest.fixture + def mock_shared_state(self): + """Create a mock shared state.""" + shared_state = MagicMock() + shared_state._data = {} + + async def mock_get(key): + if key not in shared_state._data: + raise KeyError(key) + return shared_state._data[key] + + async def mock_set(key, value): + shared_state._data[key] = value + + shared_state.get = AsyncMock(side_effect=mock_get) + shared_state.set = AsyncMock(side_effect=mock_set) + return shared_state + + async def test_inputs_name(self, mock_shared_state): + """Test inputs.name access.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize({"name": "Alice", "age": 25}) + + # .NET style (standard) + result = await state.eval("=Workflow.Inputs.name") + assert result == "Alice" + + # Also test inputs.name shorthand + result = await state.eval("=inputs.name") + 
assert result == "Alice" + + async def test_inputs_problem(self, mock_shared_state): + """Test inputs.problem access.""" + state = DeclarativeWorkflowState(mock_shared_state) + await state.initialize({"problem": "What is 5 * 6?"}) + + # .NET style (standard) + result = await state.eval("=Workflow.Inputs.problem") + assert result == "What is 5 * 6?" diff --git a/python/packages/declarative/tests/test_workflow_factory.py b/python/packages/declarative/tests/test_workflow_factory.py new file mode 100644 index 0000000000..8bad4651f0 --- /dev/null +++ b/python/packages/declarative/tests/test_workflow_factory.py @@ -0,0 +1,279 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for WorkflowFactory.""" + +import pytest + +from agent_framework_declarative._workflows._factory import ( + DeclarativeWorkflowError, + WorkflowFactory, +) + + +class TestWorkflowFactoryValidation: + """Tests for workflow definition validation.""" + + def test_missing_actions_raises(self): + """Test that missing 'actions' field raises an error.""" + factory = WorkflowFactory() + with pytest.raises(DeclarativeWorkflowError, match="must have 'actions' field"): + factory.create_workflow_from_yaml(""" +name: test-workflow +description: A test +# Missing 'actions' field +""") + + def test_actions_not_list_raises(self): + """Test that non-list 'actions' field raises an error.""" + factory = WorkflowFactory() + with pytest.raises(DeclarativeWorkflowError, match="'actions' must be a list"): + factory.create_workflow_from_yaml(""" +name: test-workflow +actions: "not a list" +""") + + def test_action_missing_kind_raises(self): + """Test that actions without 'kind' field raise an error.""" + factory = WorkflowFactory() + with pytest.raises(DeclarativeWorkflowError, match="missing 'kind' field"): + factory.create_workflow_from_yaml(""" +name: test-workflow +actions: + - path: Local.value + value: test +""") + + def test_valid_minimal_workflow(self): + """Test creating a valid minimal workflow.""" + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml(""" +name: minimal-workflow +actions: + - kind: SetValue + path: Local.result + value: done +""") + + assert workflow is not None + assert workflow.name == "minimal-workflow" + + +class TestWorkflowFactoryExecution: + """Tests for workflow execution.""" + + @pytest.mark.asyncio + async def test_execute_set_value_workflow(self): + """Test executing a simple SetValue workflow.""" + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml(""" +name: set-value-test +actions: + - kind: SetValue + path: Local.greeting + value: Hello + - kind: SendActivity + activity: + text: Done +""") + + result = await workflow.run({"input": "test"}) + outputs = result.get_outputs() + + # The workflow should produce output from SendActivity + assert len(outputs) > 0 + + @pytest.mark.asyncio + async def test_execute_send_activity_workflow(self): + """Test executing a workflow that sends activities.""" + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml(""" +name: send-activity-test +actions: + - kind: SendActivity + activity: + text: Hello, world! 
+""") + + result = await workflow.run({"input": "test"}) + outputs = result.get_outputs() + + # Should have a TextOutputEvent + assert len(outputs) >= 1 + + @pytest.mark.asyncio + async def test_execute_foreach_workflow(self): + """Test executing a workflow with foreach.""" + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml(""" +name: foreach-test +actions: + - kind: Foreach + source: + - apple + - banana + - cherry + itemName: fruit + actions: + - kind: AppendValue + path: Local.fruits + value: processed +""") + + _result = await workflow.run({}) # noqa: F841 + # The foreach should have processed 3 items + # We can check this by examining the workflow outputs + + @pytest.mark.asyncio + async def test_execute_if_workflow(self): + """Test executing a workflow with conditional branching.""" + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml(""" +name: if-test +actions: + - kind: If + condition: true + then: + - kind: SendActivity + activity: + text: Condition was true + else: + - kind: SendActivity + activity: + text: Condition was false +""") + + result = await workflow.run({}) + outputs = result.get_outputs() + + # Check for the expected text in WorkflowOutputEvent + _text_outputs = [str(o) for o in outputs if isinstance(o, str) or hasattr(o, "data")] # noqa: F841 + assert any("Condition was true" in str(o) for o in outputs) + + +class TestWorkflowFactoryAgentRegistration: + """Tests for agent registration.""" + + def test_register_agent(self): + """Test registering an agent with the factory.""" + + class MockAgent: + name = "mock-agent" + + factory = WorkflowFactory() + factory.register_agent("myAgent", MockAgent()) + + assert "myAgent" in factory._agents + + def test_register_binding(self): + """Test registering a binding with the factory.""" + + def my_function(x): + return x * 2 + + factory = WorkflowFactory() + factory.register_binding("double", my_function) + + assert "double" in factory._bindings + assert factory._bindings["double"](5) == 10 + + +class TestWorkflowFactoryFromPath: + """Tests for loading workflows from file paths.""" + + def test_nonexistent_file_raises(self, tmp_path): + """Test that loading from a nonexistent file raises FileNotFoundError.""" + factory = WorkflowFactory() + with pytest.raises(FileNotFoundError): + factory.create_workflow_from_yaml_path(tmp_path / "nonexistent.yaml") + + def test_load_from_file(self, tmp_path): + """Test loading a workflow from a file.""" + workflow_file = tmp_path / "Workflow.yaml" + workflow_file.write_text(""" +name: file-workflow +actions: + - kind: SetValue + path: Local.loaded + value: true +""") + + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml_path(workflow_file) + + assert workflow is not None + assert workflow.name == "file-workflow" + + +class TestDisplayNameMetadata: + """Tests for displayName metadata support.""" + + @pytest.mark.asyncio + async def test_action_with_display_name(self): + """Test executing an action with displayName metadata.""" + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml(""" +name: display-name-test +actions: + - kind: SetValue + id: set_greeting + displayName: Set the greeting message + path: Local.greeting + value: Hello + - kind: SendActivity + id: send_greeting + displayName: Send greeting to user + activity: + text: Hello, world! 
+""") + + result = await workflow.run({"input": "test"}) + outputs = result.get_outputs() + + # Should execute successfully with displayName metadata + assert len(outputs) >= 1 + + def test_action_context_display_name_property(self): + """Test that ActionContext provides displayName property.""" + from agent_framework_declarative._workflows._handlers import ActionContext + from agent_framework_declarative._workflows._state import WorkflowState + + state = WorkflowState() + ctx = ActionContext( + state=state, + action={ + "kind": "SetValue", + "id": "test_action", + "displayName": "Test Action Display Name", + "path": "Local.value", + "value": "test", + }, + execute_actions=lambda a, s: None, + agents={}, + bindings={}, + ) + + assert ctx.action_id == "test_action" + assert ctx.display_name == "Test Action Display Name" + assert ctx.action_kind == "SetValue" + + def test_action_context_without_display_name(self): + """Test ActionContext when displayName is not provided.""" + from agent_framework_declarative._workflows._handlers import ActionContext + from agent_framework_declarative._workflows._state import WorkflowState + + state = WorkflowState() + ctx = ActionContext( + state=state, + action={ + "kind": "SetValue", + "path": "Local.value", + "value": "test", + }, + execute_actions=lambda a, s: None, + agents={}, + bindings={}, + ) + + assert ctx.action_id is None + assert ctx.display_name is None + assert ctx.action_kind == "SetValue" diff --git a/python/packages/declarative/tests/test_workflow_handlers.py b/python/packages/declarative/tests/test_workflow_handlers.py new file mode 100644 index 0000000000..88aa565c9b --- /dev/null +++ b/python/packages/declarative/tests/test_workflow_handlers.py @@ -0,0 +1,424 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Unit tests for action handlers.""" + +from collections.abc import AsyncGenerator +from typing import Any + +import pytest + +# Import handlers to register them +from agent_framework_declarative._workflows import ( + _actions_basic, # noqa: F401 + _actions_control_flow, # noqa: F401 + _actions_error, # noqa: F401 +) +from agent_framework_declarative._workflows._handlers import ( + ActionContext, + CustomEvent, + TextOutputEvent, + WorkflowEvent, + get_action_handler, + list_action_handlers, +) +from agent_framework_declarative._workflows._state import WorkflowState + + +def create_action_context( + action: dict[str, Any], + inputs: dict[str, Any] | None = None, + agents: dict[str, Any] | None = None, + bindings: dict[str, Any] | None = None, +) -> ActionContext: + """Helper to create an ActionContext for testing.""" + state = WorkflowState(inputs=inputs or {}) + + async def execute_actions( + actions: list[dict[str, Any]], state: WorkflowState + ) -> AsyncGenerator[WorkflowEvent, None]: + """Mock execute_actions that runs handlers for nested actions.""" + for nested_action in actions: + action_kind = nested_action.get("kind") + handler = get_action_handler(action_kind) + if handler: + ctx = ActionContext( + state=state, + action=nested_action, + execute_actions=execute_actions, + agents=agents or {}, + bindings=bindings or {}, + ) + async for event in handler(ctx): + yield event + + return ActionContext( + state=state, + action=action, + execute_actions=execute_actions, + agents=agents or {}, + bindings=bindings or {}, + ) + + +class TestActionHandlerRegistry: + """Tests for action handler registration.""" + + def test_basic_handlers_registered(self): + """Test that basic handlers are registered.""" + handlers = list_action_handlers() + assert "SetValue" in handlers + assert "AppendValue" in handlers + assert "SendActivity" in handlers + assert "EmitEvent" in handlers + + def test_control_flow_handlers_registered(self): + """Test that control flow handlers are registered.""" + handlers = list_action_handlers() + assert "Foreach" in handlers + assert "If" in handlers + assert "Switch" in handlers + assert "RepeatUntil" in handlers + assert "BreakLoop" in handlers + assert "ContinueLoop" in handlers + + def test_error_handlers_registered(self): + """Test that error handlers are registered.""" + handlers = list_action_handlers() + assert "ThrowException" in handlers + assert "TryCatch" in handlers + + def test_get_unknown_handler_returns_none(self): + """Test that getting an unknown handler returns None.""" + assert get_action_handler("UnknownAction") is None + + +class TestSetValueHandler: + """Tests for SetValue action handler.""" + + @pytest.mark.asyncio + async def test_set_simple_value(self): + """Test setting a simple value.""" + ctx = create_action_context({ + "kind": "SetValue", + "path": "Local.result", + "value": "test value", + }) + + handler = get_action_handler("SetValue") + events = [e async for e in handler(ctx)] + + assert len(events) == 0 # SetValue doesn't emit events + assert ctx.state.get("Local.result") == "test value" + + @pytest.mark.asyncio + async def test_set_value_from_input(self): + """Test setting a value from workflow inputs.""" + ctx = create_action_context( + { + "kind": "SetValue", + "path": "Local.copy", + "value": "literal", + }, + inputs={"original": "from input"}, + ) + + handler = get_action_handler("SetValue") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.copy") == "literal" + + +class 
TestAppendValueHandler: + """Tests for AppendValue action handler.""" + + @pytest.mark.asyncio + async def test_append_to_new_list(self): + """Test appending to a non-existent list creates it.""" + ctx = create_action_context({ + "kind": "AppendValue", + "path": "Local.results", + "value": "item1", + }) + + handler = get_action_handler("AppendValue") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.results") == ["item1"] + + @pytest.mark.asyncio + async def test_append_to_existing_list(self): + """Test appending to an existing list.""" + ctx = create_action_context({ + "kind": "AppendValue", + "path": "Local.results", + "value": "item2", + }) + ctx.state.set("Local.results", ["item1"]) + + handler = get_action_handler("AppendValue") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.results") == ["item1", "item2"] + + +class TestSendActivityHandler: + """Tests for SendActivity action handler.""" + + @pytest.mark.asyncio + async def test_send_text_activity(self): + """Test sending a text activity.""" + ctx = create_action_context({ + "kind": "SendActivity", + "activity": { + "text": "Hello, world!", + }, + }) + + handler = get_action_handler("SendActivity") + events = [e async for e in handler(ctx)] + + assert len(events) == 1 + assert isinstance(events[0], TextOutputEvent) + assert events[0].text == "Hello, world!" + + +class TestEmitEventHandler: + """Tests for EmitEvent action handler.""" + + @pytest.mark.asyncio + async def test_emit_custom_event(self): + """Test emitting a custom event.""" + ctx = create_action_context({ + "kind": "EmitEvent", + "event": { + "name": "myEvent", + "data": {"key": "value"}, + }, + }) + + handler = get_action_handler("EmitEvent") + events = [e async for e in handler(ctx)] + + assert len(events) == 1 + assert isinstance(events[0], CustomEvent) + assert events[0].name == "myEvent" + assert events[0].data == {"key": "value"} + + +class TestForeachHandler: + """Tests for Foreach action handler.""" + + @pytest.mark.asyncio + async def test_foreach_basic_iteration(self): + """Test basic foreach iteration.""" + ctx = create_action_context({ + "kind": "Foreach", + "source": ["a", "b", "c"], + "itemName": "letter", + "actions": [ + { + "kind": "AppendValue", + "path": "Local.results", + "value": "processed", + } + ], + }) + + handler = get_action_handler("Foreach") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.results") == ["processed", "processed", "processed"] + + @pytest.mark.asyncio + async def test_foreach_sets_item_and_index(self): + """Test that foreach sets item and index variables.""" + ctx = create_action_context({ + "kind": "Foreach", + "source": ["x", "y"], + "itemName": "item", + "indexName": "idx", + "actions": [], + }) + + # We'll check the last values after iteration + handler = get_action_handler("Foreach") + _events = [e async for e in handler(ctx)] # noqa: F841 + + # After iteration, the last item/index should be set + assert ctx.state.get("Local.item") == "y" + assert ctx.state.get("Local.idx") == 1 + + +class TestIfHandler: + """Tests for If action handler.""" + + @pytest.mark.asyncio + async def test_if_true_branch(self): + """Test that the 'then' branch executes when condition is true.""" + ctx = create_action_context({ + "kind": "If", + "condition": True, + "then": [ + {"kind": "SetValue", "path": "Local.branch", "value": "then"}, + ], + "else": [ + {"kind": "SetValue", "path": "Local.branch", "value": "else"}, + ], + }) + 
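+        # Both branches write Local.branch, so the stored value identifies which branch executed.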
+ handler = get_action_handler("If") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.branch") == "then" + + @pytest.mark.asyncio + async def test_if_false_branch(self): + """Test that the 'else' branch executes when condition is false.""" + ctx = create_action_context({ + "kind": "If", + "condition": False, + "then": [ + {"kind": "SetValue", "path": "Local.branch", "value": "then"}, + ], + "else": [ + {"kind": "SetValue", "path": "Local.branch", "value": "else"}, + ], + }) + + handler = get_action_handler("If") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.branch") == "else" + + +class TestSwitchHandler: + """Tests for Switch action handler.""" + + @pytest.mark.asyncio + async def test_switch_matching_case(self): + """Test switch with a matching case.""" + ctx = create_action_context({ + "kind": "Switch", + "value": "option2", + "cases": [ + { + "match": "option1", + "actions": [{"kind": "SetValue", "path": "Local.result", "value": "one"}], + }, + { + "match": "option2", + "actions": [{"kind": "SetValue", "path": "Local.result", "value": "two"}], + }, + ], + "default": [{"kind": "SetValue", "path": "Local.result", "value": "default"}], + }) + + handler = get_action_handler("Switch") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.result") == "two" + + @pytest.mark.asyncio + async def test_switch_default_case(self): + """Test switch falls through to default.""" + ctx = create_action_context({ + "kind": "Switch", + "value": "unknown", + "cases": [ + { + "match": "option1", + "actions": [{"kind": "SetValue", "path": "Local.result", "value": "one"}], + }, + ], + "default": [{"kind": "SetValue", "path": "Local.result", "value": "default"}], + }) + + handler = get_action_handler("Switch") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.result") == "default" + + +class TestRepeatUntilHandler: + """Tests for RepeatUntil action handler.""" + + @pytest.mark.asyncio + async def test_repeat_until_condition_met(self): + """Test repeat until condition becomes true.""" + ctx = create_action_context({ + "kind": "RepeatUntil", + "condition": False, # Will be evaluated each iteration + "maxIterations": 3, + "actions": [ + {"kind": "SetValue", "path": "Local.count", "value": 1}, + ], + }) + # Set up a counter that will cause the loop to exit + ctx.state.set("Local.count", 0) + + handler = get_action_handler("RepeatUntil") + _events = [e async for e in handler(ctx)] # noqa: F841 + + # With condition=False (literal), it will run maxIterations times + assert ctx.state.get("Local.iteration") == 3 + + +class TestTryCatchHandler: + """Tests for TryCatch action handler.""" + + @pytest.mark.asyncio + async def test_try_without_error(self): + """Test try block without errors.""" + ctx = create_action_context({ + "kind": "TryCatch", + "try": [ + {"kind": "SetValue", "path": "Local.result", "value": "success"}, + ], + "catch": [ + {"kind": "SetValue", "path": "Local.result", "value": "caught"}, + ], + }) + + handler = get_action_handler("TryCatch") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.result") == "success" + + @pytest.mark.asyncio + async def test_try_with_throw_exception(self): + """Test catching a thrown exception.""" + ctx = create_action_context({ + "kind": "TryCatch", + "try": [ + {"kind": "ThrowException", "message": "Test error", "code": "ERR001"}, + ], + "catch": [ + {"kind": "SetValue", "path": 
"Local.result", "value": "caught"}, + ], + }) + + handler = get_action_handler("TryCatch") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.result") == "caught" + assert ctx.state.get("Local.error.message") == "Test error" + assert ctx.state.get("Local.error.code") == "ERR001" + + @pytest.mark.asyncio + async def test_finally_always_executes(self): + """Test that finally block always executes.""" + ctx = create_action_context({ + "kind": "TryCatch", + "try": [ + {"kind": "SetValue", "path": "Local.try", "value": "ran"}, + ], + "finally": [ + {"kind": "SetValue", "path": "Local.finally", "value": "ran"}, + ], + }) + + handler = get_action_handler("TryCatch") + _events = [e async for e in handler(ctx)] # noqa: F841 + + assert ctx.state.get("Local.try") == "ran" + assert ctx.state.get("Local.finally") == "ran" diff --git a/python/packages/declarative/tests/test_workflow_samples_integration.py b/python/packages/declarative/tests/test_workflow_samples_integration.py new file mode 100644 index 0000000000..fc0ece9ac5 --- /dev/null +++ b/python/packages/declarative/tests/test_workflow_samples_integration.py @@ -0,0 +1,268 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Integration tests for workflow samples. + +These tests verify that the workflow samples from workflow-samples/ directory +can be parsed and validated by the WorkflowFactory. +""" + +from pathlib import Path + +import pytest +import yaml + +# Path to workflow samples - navigate from tests dir up to repo root +# tests/test_*.py -> packages/declarative/tests/ -> packages/declarative/ -> packages/ -> python/ -> repo root +WORKFLOW_SAMPLES_DIR = Path(__file__).parent.parent.parent.parent.parent / "workflow-samples" + + +def get_workflow_sample_files(): + """Get all .yaml files from the workflow-samples directory.""" + if not WORKFLOW_SAMPLES_DIR.exists(): + return [] + return list(WORKFLOW_SAMPLES_DIR.glob("*.yaml")) + + +class TestWorkflowSampleParsing: + """Tests that verify workflow samples can be parsed correctly.""" + + @pytest.fixture + def sample_files(self): + """Get list of sample files.""" + return get_workflow_sample_files() + + def test_samples_directory_exists(self): + """Verify the workflow-samples directory exists.""" + assert WORKFLOW_SAMPLES_DIR.exists(), f"Workflow samples directory not found at {WORKFLOW_SAMPLES_DIR}" + + def test_samples_exist(self, sample_files): + """Verify there are workflow sample files.""" + assert len(sample_files) > 0, "No workflow sample files found" + + @pytest.mark.parametrize("yaml_file", get_workflow_sample_files(), ids=lambda f: f.name) + def test_sample_yaml_is_valid(self, yaml_file): + """Test that each sample YAML file can be parsed.""" + with open(yaml_file) as f: + data = yaml.safe_load(f) + + assert data is not None, f"Failed to parse {yaml_file.name}" + assert "kind" in data, f"Missing 'kind' field in {yaml_file.name}" + assert data["kind"] == "Workflow", f"Expected kind: Workflow in {yaml_file.name}" + + @pytest.mark.parametrize("yaml_file", get_workflow_sample_files(), ids=lambda f: f.name) + def test_sample_has_trigger(self, yaml_file): + """Test that each sample has a trigger defined.""" + with open(yaml_file) as f: + data = yaml.safe_load(f) + + assert "trigger" in data, f"Missing 'trigger' field in {yaml_file.name}" + trigger = data["trigger"] + assert trigger is not None, f"Trigger is empty in {yaml_file.name}" + + @pytest.mark.parametrize("yaml_file", get_workflow_sample_files(), ids=lambda f: f.name) + def 
test_sample_has_actions(self, yaml_file): + """Test that each sample has actions defined.""" + with open(yaml_file) as f: + data = yaml.safe_load(f) + + trigger = data.get("trigger", {}) + actions = trigger.get("actions", []) + assert len(actions) > 0, f"No actions defined in {yaml_file.name}" + + @pytest.mark.parametrize("yaml_file", get_workflow_sample_files(), ids=lambda f: f.name) + def test_sample_actions_have_kind(self, yaml_file): + """Test that each action has a 'kind' field.""" + with open(yaml_file) as f: + data = yaml.safe_load(f) + + def check_actions(actions, path=""): + for i, action in enumerate(actions): + action_path = f"{path}[{i}]" + assert "kind" in action, f"Action missing 'kind' at {action_path} in {yaml_file.name}" + + # Check nested actions + for nested_key in ["actions", "elseActions", "thenActions"]: + if nested_key in action: + check_actions(action[nested_key], f"{action_path}.{nested_key}") + + # Check conditions + if "conditions" in action: + for j, cond in enumerate(action["conditions"]): + if "actions" in cond: + check_actions(cond["actions"], f"{action_path}.conditions[{j}].actions") + + # Check cases + if "cases" in action: + for j, case in enumerate(action["cases"]): + if "actions" in case: + check_actions(case["actions"], f"{action_path}.cases[{j}].actions") + + trigger = data.get("trigger", {}) + actions = trigger.get("actions", []) + check_actions(actions, "trigger.actions") + + +class TestWorkflowDefinitionParsing: + """Tests for parsing workflow definitions into structured objects.""" + + @pytest.mark.parametrize("yaml_file", get_workflow_sample_files(), ids=lambda f: f.name) + def test_extract_actions_from_sample(self, yaml_file): + """Test extracting all actions from a workflow sample.""" + with open(yaml_file) as f: + data = yaml.safe_load(f) + + # Collect all action kinds used + action_kinds: set[str] = set() + + def collect_actions(actions): + for action in actions: + action_kinds.add(action.get("kind", "Unknown")) + + # Collect from nested actions + for nested_key in ["actions", "elseActions", "thenActions"]: + if nested_key in action: + collect_actions(action[nested_key]) + + if "conditions" in action: + for cond in action["conditions"]: + if "actions" in cond: + collect_actions(cond["actions"]) + + if "cases" in action: + for case in action["cases"]: + if "actions" in case: + collect_actions(case["actions"]) + + trigger = data.get("trigger", {}) + actions = trigger.get("actions", []) + collect_actions(actions) + + # Verify we found some actions + assert len(action_kinds) > 0, f"No action kinds found in {yaml_file.name}" + + @pytest.mark.parametrize("yaml_file", get_workflow_sample_files(), ids=lambda f: f.name) + def test_extract_agent_names_from_sample(self, yaml_file): + """Test extracting agent names referenced in a workflow sample.""" + with open(yaml_file) as f: + data = yaml.safe_load(f) + + agent_names: set[str] = set() + + def collect_agents(actions): + for action in actions: + kind = action.get("kind", "") + + if kind in ("InvokeAzureAgent", "InvokePromptAgent"): + agent_config = action.get("agent", {}) + name = agent_config.get("name") if isinstance(agent_config, dict) else agent_config + if name and not str(name).startswith("="): + agent_names.add(name) + + # Collect from nested actions + for nested_key in ["actions", "elseActions", "thenActions"]: + if nested_key in action: + collect_agents(action[nested_key]) + + if "conditions" in action: + for cond in action["conditions"]: + if "actions" in cond: + 
collect_agents(cond["actions"]) + + if "cases" in action: + for case in action["cases"]: + if "actions" in case: + collect_agents(case["actions"]) + + trigger = data.get("trigger", {}) + actions = trigger.get("actions", []) + collect_agents(actions) + + # Log the agents found (some workflows may not use agents) + # Agent names: {agent_names} + + +class TestHandlerCoverage: + """Tests to verify handler coverage for workflow actions.""" + + @pytest.fixture + def all_action_kinds(self): + """Collect all action kinds used across all samples.""" + action_kinds: set[str] = set() + + def collect_actions(actions): + for action in actions: + action_kinds.add(action.get("kind", "Unknown")) + + for nested_key in ["actions", "elseActions", "thenActions"]: + if nested_key in action: + collect_actions(action[nested_key]) + + if "conditions" in action: + for cond in action["conditions"]: + if "actions" in cond: + collect_actions(cond["actions"]) + + if "cases" in action: + for case in action["cases"]: + if "actions" in case: + collect_actions(case["actions"]) + + for yaml_file in get_workflow_sample_files(): + with open(yaml_file) as f: + data = yaml.safe_load(f) + trigger = data.get("trigger", {}) + actions = trigger.get("actions", []) + collect_actions(actions) + + return action_kinds + + def test_handlers_exist_for_sample_actions(self, all_action_kinds): + """Test that handlers exist for all action kinds in samples.""" + from agent_framework_declarative._workflows._handlers import list_action_handlers + + registered_handlers = set(list_action_handlers()) + + # Handlers we expect but may not be in samples + expected_handlers = { + "SetValue", + "SetVariable", + "SetTextVariable", + "SetMultipleVariables", + "ResetVariable", + "ClearAllVariables", + "AppendValue", + "SendActivity", + "EmitEvent", + "Foreach", + "If", + "Switch", + "ConditionGroup", + "GotoAction", + "BreakLoop", + "ContinueLoop", + "RepeatUntil", + "TryCatch", + "ThrowException", + "EndWorkflow", + "EndConversation", + "InvokeAzureAgent", + "InvokePromptAgent", + "CreateConversation", + "AddConversationMessage", + "CopyConversationMessages", + "RetrieveConversationMessages", + "Question", + "RequestExternalInput", + "WaitForInput", + } + + # Check that sample action kinds have handlers + missing_handlers = all_action_kinds - registered_handlers - {"OnConversationStart"} # Trigger kind, not action + + if missing_handlers: + # Informational, not a failure, as some actions may be future work + pass + + # Check that we have handlers for the expected core set + core_handlers = registered_handlers & expected_handlers + assert len(core_handlers) > 10, "Expected more core handlers to be registered" diff --git a/python/packages/declarative/tests/test_workflow_state.py b/python/packages/declarative/tests/test_workflow_state.py new file mode 100644 index 0000000000..957466806d --- /dev/null +++ b/python/packages/declarative/tests/test_workflow_state.py @@ -0,0 +1,225 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Unit tests for WorkflowState class.""" + +import pytest + +from agent_framework_declarative._workflows._state import WorkflowState + + +class TestWorkflowStateInitialization: + """Tests for WorkflowState initialization.""" + + def test_empty_initialization(self): + """Test creating a WorkflowState with no inputs.""" + state = WorkflowState() + assert state.inputs == {} + assert state.outputs == {} + assert state.local == {} + assert state.agent == {} + + def test_initialization_with_inputs(self): + """Test creating a WorkflowState with inputs.""" + state = WorkflowState(inputs={"query": "Hello", "count": 5}) + assert state.inputs == {"query": "Hello", "count": 5} + assert state.outputs == {} + + def test_inputs_are_immutable(self): + """Test that inputs cannot be modified through set().""" + state = WorkflowState(inputs={"query": "Hello"}) + with pytest.raises(ValueError, match="Cannot modify Workflow.Inputs"): + state.set("Workflow.Inputs.query", "Modified") + + +class TestWorkflowStateGetSet: + """Tests for get and set operations.""" + + def test_set_and_get_turn_variable(self): + """Test setting and getting a turn variable.""" + state = WorkflowState() + state.set("Local.counter", 10) + assert state.get("Local.counter") == 10 + + def test_set_and_get_nested_turn_variable(self): + """Test setting and getting a nested turn variable.""" + state = WorkflowState() + state.set("Local.data.nested.value", "test") + assert state.get("Local.data.nested.value") == "test" + + def test_set_and_get_workflow_output(self): + """Test setting and getting workflow output.""" + state = WorkflowState() + state.set("Workflow.Outputs.result", "success") + assert state.get("Workflow.Outputs.result") == "success" + assert state.outputs["result"] == "success" + + def test_get_with_default(self): + """Test get with default value.""" + state = WorkflowState() + assert state.get("Local.nonexistent") is None + assert state.get("Local.nonexistent", "default") == "default" + + def test_get_workflow_inputs(self): + """Test getting workflow inputs.""" + state = WorkflowState(inputs={"query": "test"}) + assert state.get("Workflow.Inputs.query") == "test" + + def test_set_custom_namespace(self): + """Test setting a custom namespace variable.""" + state = WorkflowState() + state.set("custom.myvar", "value") + assert state.get("custom.myvar") == "value" + + +class TestWorkflowStateAppend: + """Tests for append operation.""" + + def test_append_to_nonexistent_list(self): + """Test appending to a path that doesn't exist yet.""" + state = WorkflowState() + state.append("Local.results", "item1") + assert state.get("Local.results") == ["item1"] + + def test_append_to_existing_list(self): + """Test appending to an existing list.""" + state = WorkflowState() + state.set("Local.results", ["item1"]) + state.append("Local.results", "item2") + assert state.get("Local.results") == ["item1", "item2"] + + def test_append_to_non_list_raises(self): + """Test that appending to a non-list raises ValueError.""" + state = WorkflowState() + state.set("Local.value", "not a list") + with pytest.raises(ValueError, match="Cannot append to non-list"): + state.append("Local.value", "item") + + +class TestWorkflowStateAgentResult: + """Tests for agent result management.""" + + def test_set_agent_result(self): + """Test setting agent result.""" + state = WorkflowState() + state.set_agent_result( + text="Agent response", + messages=[{"role": "assistant", "content": "Hello"}], + tool_calls=[{"name": "tool1"}], + ) + assert state.agent["text"] == 
"Agent response" + assert len(state.agent["messages"]) == 1 + assert len(state.agent["toolCalls"]) == 1 + + def test_get_agent_result_via_path(self): + """Test getting agent result via path.""" + state = WorkflowState() + state.set_agent_result(text="Response") + assert state.get("Agent.text") == "Response" + + def test_reset_agent(self): + """Test resetting agent result.""" + state = WorkflowState() + state.set_agent_result(text="Response") + state.reset_agent() + assert state.agent == {} + + +class TestWorkflowStateConversation: + """Tests for conversation management.""" + + def test_add_conversation_message(self): + """Test adding a conversation message.""" + state = WorkflowState() + message = {"role": "user", "content": "Hello"} + state.add_conversation_message(message) + assert len(state.conversation["messages"]) == 1 + assert state.conversation["messages"][0] == message + + def test_get_conversation_history(self): + """Test getting conversation history.""" + state = WorkflowState() + state.add_conversation_message({"role": "user", "content": "Hi"}) + state.add_conversation_message({"role": "assistant", "content": "Hello"}) + assert len(state.get("Conversation.history")) == 2 + + +class TestWorkflowStatePowerFx: + """Tests for PowerFx expression evaluation.""" + + def test_eval_non_expression(self): + """Test that non-expressions are returned as-is.""" + state = WorkflowState() + assert state.eval("plain text") == "plain text" + + def test_eval_if_expression_with_literal(self): + """Test eval_if_expression with a literal value.""" + state = WorkflowState() + assert state.eval_if_expression(42) == 42 + assert state.eval_if_expression(["a", "b"]) == ["a", "b"] + + def test_eval_if_expression_with_non_expression_string(self): + """Test eval_if_expression with a non-expression string.""" + state = WorkflowState() + assert state.eval_if_expression("plain text") == "plain text" + + def test_to_powerfx_symbols(self): + """Test converting state to PowerFx symbols.""" + state = WorkflowState(inputs={"query": "test"}) + state.set("Local.counter", 5) + state.set("Workflow.Outputs.result", "done") + + symbols = state.to_powerfx_symbols() + assert symbols["Workflow"]["Inputs"]["query"] == "test" + assert symbols["Workflow"]["Outputs"]["result"] == "done" + assert symbols["Local"]["counter"] == 5 + + +class TestWorkflowStateClone: + """Tests for state cloning.""" + + def test_clone_creates_copy(self): + """Test that clone creates a copy of the state.""" + state = WorkflowState(inputs={"query": "test"}) + state.set("Local.counter", 5) + + cloned = state.clone() + assert cloned.get("Workflow.Inputs.query") == "test" + assert cloned.get("Local.counter") == 5 + + def test_clone_is_independent(self): + """Test that modifications to clone don't affect original.""" + state = WorkflowState() + state.set("Local.value", "original") + + cloned = state.clone() + cloned.set("Local.value", "modified") + + assert state.get("Local.value") == "original" + assert cloned.get("Local.value") == "modified" + + +class TestWorkflowStateResetTurn: + """Tests for turn reset.""" + + def test_reset_local_clears_turn_variables(self): + """Test that reset_local clears turn variables.""" + state = WorkflowState() + state.set("Local.var1", "value1") + state.set("Local.var2", "value2") + + state.reset_local() + + assert state.get("Local.var1") is None + assert state.get("Local.var2") is None + assert state.local == {} + + def test_reset_local_preserves_other_state(self): + """Test that reset_local preserves other state.""" + 
state = WorkflowState(inputs={"query": "test"}) + state.set("Workflow.Outputs.result", "done") + state.set("Local.temp", "will be cleared") + + state.reset_local() + + assert state.get("Workflow.Inputs.query") == "test" + assert state.get("Workflow.Outputs.result") == "done" diff --git a/python/packages/devui/agent_framework_devui/_discovery.py b/python/packages/devui/agent_framework_devui/_discovery.py index 9f8fcf0542..ed60a402e1 100644 --- a/python/packages/devui/agent_framework_devui/_discovery.py +++ b/python/packages/devui/agent_framework_devui/_discovery.py @@ -346,8 +346,8 @@ async def create_entity_info_from_object( instructions = None model = None chat_client_type = None - context_providers_list = None - middleware_list = None + context_provider_list = None + middlewares_list = None if entity_type == "agent": from ._utils import extract_agent_metadata @@ -356,8 +356,8 @@ async def create_entity_info_from_object( instructions = agent_meta["instructions"] model = agent_meta["model"] chat_client_type = agent_meta["chat_client_type"] - context_providers_list = agent_meta["context_providers"] - middleware_list = agent_meta["middleware"] + context_provider_list = agent_meta["context_provider"] + middlewares_list = agent_meta["middleware"] # Log helpful info about agent capabilities (before creating EntityInfo) if entity_type == "agent": @@ -395,8 +395,8 @@ async def create_entity_info_from_object( instructions=instructions, model_id=model, chat_client_type=chat_client_type, - context_providers=context_providers_list, - middleware=middleware_list, + context_provider=context_provider_list, + middleware=middlewares_list, executors=tools_list if entity_type == "workflow" else [], input_schema={"type": "string"}, # Default schema start_executor_id=tools_list[0] if tools_list and entity_type == "workflow" else None, @@ -829,8 +829,8 @@ async def _register_entity_from_object( instructions = None model = None chat_client_type = None - context_providers_list = None - middleware_list = None + context_provider_list = None + middlewares_list = None if obj_type == "agent": from ._utils import extract_agent_metadata @@ -839,8 +839,8 @@ async def _register_entity_from_object( instructions = agent_meta["instructions"] model = agent_meta["model"] chat_client_type = agent_meta["chat_client_type"] - context_providers_list = agent_meta["context_providers"] - middleware_list = agent_meta["middleware"] + context_provider_list = agent_meta["context_provider"] + middlewares_list = agent_meta["middleware"] entity_info = EntityInfo( id=entity_id, @@ -852,8 +852,8 @@ async def _register_entity_from_object( instructions=instructions, model_id=model, chat_client_type=chat_client_type, - context_providers=context_providers_list, - middleware=middleware_list, + context_provider=context_provider_list, + middleware=middlewares_list, metadata={ "module_path": module_path, "entity_type": obj_type, @@ -883,10 +883,14 @@ async def _extract_tools_from_object(self, obj: Any, obj_type: str) -> list[str] try: if obj_type == "agent": - # For agents, check chat_options.tools first - chat_options = getattr(obj, "chat_options", None) - if chat_options and hasattr(chat_options, "tools"): - for tool in chat_options.tools: + # For agents, check default_options.get("tools") + chat_options = getattr(obj, "default_options", None) + chat_options_tools = None + if chat_options: + chat_options_tools = chat_options.get("tools") + + if chat_options_tools: + for tool in chat_options_tools: if hasattr(tool, "__name__"): 
tools.append(tool.__name__) elif hasattr(tool, "name"): diff --git a/python/packages/devui/agent_framework_devui/_executor.py b/python/packages/devui/agent_framework_devui/_executor.py index e63dd014fe..585036bef9 100644 --- a/python/packages/devui/agent_framework_devui/_executor.py +++ b/python/packages/devui/agent_framework_devui/_executor.py @@ -114,10 +114,10 @@ async def _ensure_mcp_connections(self, agent: Any) -> None: Args: agent: Agent object that may have MCP tools """ - if not hasattr(agent, "_local_mcp_tools"): + if not hasattr(agent, "mcp_tools"): return - for mcp_tool in agent._local_mcp_tools: + for mcp_tool in agent.mcp_tools: if not getattr(mcp_tool, "is_connected", False): continue diff --git a/python/packages/devui/agent_framework_devui/_mapper.py b/python/packages/devui/agent_framework_devui/_mapper.py index 021a4a4549..71bbda9b85 100644 --- a/python/packages/devui/agent_framework_devui/_mapper.py +++ b/python/packages/devui/agent_framework_devui/_mapper.py @@ -145,7 +145,7 @@ async def convert_event(self, raw_event: Any, request: AgentFrameworkRequest) -> """Convert a single Agent Framework event to OpenAI events. Args: - raw_event: Agent Framework event (AgentRunResponseUpdate, WorkflowEvent, etc.) + raw_event: Agent Framework event (AgentResponseUpdate, WorkflowEvent, etc.) request: Original request for context Returns: @@ -178,26 +178,26 @@ async def convert_event(self, raw_event: Any, request: AgentFrameworkRequest) -> # Import Agent Framework types for proper isinstance checks try: - from agent_framework import AgentRunResponse, AgentRunResponseUpdate, WorkflowEvent + from agent_framework import AgentResponse, AgentResponseUpdate, WorkflowEvent from agent_framework._workflows._events import AgentRunUpdateEvent - # Handle AgentRunUpdateEvent - workflow event wrapping AgentRunResponseUpdate + # Handle AgentRunUpdateEvent - workflow event wrapping AgentResponseUpdate # This must be checked BEFORE generic WorkflowEvent check if isinstance(raw_event, AgentRunUpdateEvent): - # Extract the AgentRunResponseUpdate from the event's data attribute - if raw_event.data and isinstance(raw_event.data, AgentRunResponseUpdate): + # Extract the AgentResponseUpdate from the event's data attribute + if raw_event.data and isinstance(raw_event.data, AgentResponseUpdate): # Preserve executor_id in context for proper output routing context["current_executor_id"] = raw_event.executor_id return await self._convert_agent_update(raw_event.data, context) # If no data, treat as generic workflow event return await self._convert_workflow_event(raw_event, context) - # Handle complete agent response (AgentRunResponse) - for non-streaming agent execution - if isinstance(raw_event, AgentRunResponse): + # Handle complete agent response (AgentResponse) - for non-streaming agent execution + if isinstance(raw_event, AgentResponse): return await self._convert_agent_response(raw_event, context) - # Handle agent updates (AgentRunResponseUpdate) - for direct agent execution - if isinstance(raw_event, AgentRunResponseUpdate): + # Handle agent updates (AgentResponseUpdate) - for direct agent execution + if isinstance(raw_event, AgentResponseUpdate): return await self._convert_agent_update(raw_event, context) # Handle workflow events (any class that inherits from WorkflowEvent) @@ -686,13 +686,13 @@ async def _convert_agent_update(self, update: Any, context: dict[str, Any]) -> S return events async def _convert_agent_response(self, response: Any, context: dict[str, Any]) -> Sequence[Any]: - """Convert complete 
AgentRunResponse to OpenAI events. + """Convert complete AgentResponse to OpenAI events. This handles non-streaming agent execution where agent.run() returns - a complete AgentRunResponse instead of streaming AgentRunResponseUpdate objects. + a complete AgentResponse instead of streaming AgentResponseUpdate objects. Args: - response: Agent run response (AgentRunResponse) + response: Agent run response (AgentResponse) context: Conversion context Returns: @@ -881,7 +881,7 @@ async def _convert_workflow_event(self, event: Any, context: dict[str, Any]) -> # Handle WorkflowOutputEvent separately to preserve output data if event_class == "WorkflowOutputEvent": output_data = getattr(event, "data", None) - source_executor_id = getattr(event, "source_executor_id", "unknown") + executor_id = getattr(event, "executor_id", "unknown") if output_data is not None: # Import required types @@ -942,7 +942,7 @@ async def _convert_workflow_event(self, event: Any, context: dict[str, Any]) -> # Emit output_item.added for each yield_output logger.debug( f"WorkflowOutputEvent converted to output_item.added " - f"(executor: {source_executor_id}, length: {len(text)})" + f"(executor: {executor_id}, length: {len(text)})" ) return [ ResponseOutputItemAddedEvent( @@ -1047,7 +1047,7 @@ async def _convert_workflow_event(self, event: Any, context: dict[str, Any]) -> # Create ExecutorActionItem with completed status # ExecutorCompletedEvent uses 'data' field, not 'result' # Serialize the result data to ensure it's JSON-serializable - # (AgentExecutorResponse contains AgentRunResponse/ChatMessage which are SerializationMixin) + # (AgentExecutorResponse contains AgentResponse/ChatMessage which are SerializationMixin) raw_result = getattr(event, "data", None) serialized_result = self._serialize_value(raw_result) if raw_result is not None else None executor_item = ExecutorActionItem( diff --git a/python/packages/devui/agent_framework_devui/_server.py b/python/packages/devui/agent_framework_devui/_server.py index 146db9b33d..6393f23b4a 100644 --- a/python/packages/devui/agent_framework_devui/_server.py +++ b/python/packages/devui/agent_framework_devui/_server.py @@ -248,9 +248,9 @@ async def _cleanup_entities(self) -> None: except Exception as e: logger.warning(f"Error closing credential for {entity_info.id}: {e}") - # Close MCP tools (framework tracks them in _local_mcp_tools) - if entity_obj and hasattr(entity_obj, "_local_mcp_tools"): - for mcp_tool in entity_obj._local_mcp_tools: + # Close MCP tools (framework tracks them in mcp_tools) + if entity_obj and hasattr(entity_obj, "mcp_tools"): + for mcp_tool in entity_obj.mcp_tools: if hasattr(mcp_tool, "close") and callable(mcp_tool.close): try: if inspect.iscoroutinefunction(mcp_tool.close): diff --git a/python/packages/devui/agent_framework_devui/_utils.py b/python/packages/devui/agent_framework_devui/_utils.py index 3c17c072f7..24cdc9c073 100644 --- a/python/packages/devui/agent_framework_devui/_utils.py +++ b/python/packages/devui/agent_framework_devui/_utils.py @@ -32,22 +32,32 @@ def extract_agent_metadata(entity_object: Any) -> dict[str, Any]: "instructions": None, "model": None, "chat_client_type": None, - "context_providers": None, + "context_provider": None, "middleware": None, } # Try to get instructions - if hasattr(entity_object, "chat_options") and hasattr(entity_object.chat_options, "instructions"): - metadata["instructions"] = entity_object.chat_options.instructions - - # Try to get model - check both chat_options and chat_client + if hasattr(entity_object, 
"default_options"): + chat_opts = entity_object.default_options + if isinstance(chat_opts, dict): + if "instructions" in chat_opts: + metadata["instructions"] = chat_opts.get("instructions") + elif hasattr(chat_opts, "instructions"): + metadata["instructions"] = chat_opts.instructions + + # Try to get model - check both default_options and chat_client + if hasattr(entity_object, "default_options"): + chat_opts = entity_object.default_options + if isinstance(chat_opts, dict): + if chat_opts.get("model_id"): + metadata["model"] = chat_opts.get("model_id") + elif hasattr(chat_opts, "model_id") and chat_opts.model_id: + metadata["model"] = chat_opts.model_id if ( - hasattr(entity_object, "chat_options") - and hasattr(entity_object.chat_options, "model_id") - and entity_object.chat_options.model_id + metadata["model"] is None + and hasattr(entity_object, "chat_client") + and hasattr(entity_object.chat_client, "model_id") ): - metadata["model"] = entity_object.chat_options.model_id - elif hasattr(entity_object, "chat_client") and hasattr(entity_object.chat_client, "model_id"): metadata["model"] = entity_object.chat_client.model_id # Try to get chat client type @@ -60,20 +70,20 @@ def extract_agent_metadata(entity_object: Any) -> dict[str, Any]: and entity_object.context_provider and hasattr(entity_object.context_provider, "__class__") ): - metadata["context_providers"] = [entity_object.context_provider.__class__.__name__] # type: ignore + metadata["context_provider"] = [entity_object.context_provider.__class__.__name__] # type: ignore # Try to get middleware if hasattr(entity_object, "middleware") and entity_object.middleware: - middleware_list: list[str] = [] + middlewares_list: list[str] = [] for m in entity_object.middleware: # Try multiple ways to get a good name for middleware if hasattr(m, "__name__"): # Function or callable - middleware_list.append(m.__name__) + middlewares_list.append(m.__name__) elif hasattr(m, "__class__"): # Class instance - middleware_list.append(m.__class__.__name__) + middlewares_list.append(m.__class__.__name__) else: - middleware_list.append(str(m)) - metadata["middleware"] = middleware_list # type: ignore + middlewares_list.append(str(m)) + metadata["middleware"] = middlewares_list # type: ignore return metadata diff --git a/python/packages/devui/agent_framework_devui/models/_discovery_models.py b/python/packages/devui/agent_framework_devui/models/_discovery_models.py index 382639b277..ff217a48d2 100644 --- a/python/packages/devui/agent_framework_devui/models/_discovery_models.py +++ b/python/packages/devui/agent_framework_devui/models/_discovery_models.py @@ -43,7 +43,7 @@ class EntityInfo(BaseModel): instructions: str | None = None model_id: str | None = None chat_client_type: str | None = None - context_providers: list[str] | None = None + context_provider: list[str] | None = None middleware: list[str] | None = None # Workflow-specific fields (populated only for detailed info requests) diff --git a/python/packages/devui/frontend/src/components/features/agent/agent-details-modal.tsx b/python/packages/devui/frontend/src/components/features/agent/agent-details-modal.tsx index 990231552b..f9fa4480a0 100644 --- a/python/packages/devui/frontend/src/components/features/agent/agent-details-modal.tsx +++ b/python/packages/devui/frontend/src/components/features/agent/agent-details-modal.tsx @@ -179,10 +179,10 @@ export function AgentDetailsModal({ )} - {/* Middleware */} + {/* Middlewares */} {agent.middleware && agent.middleware.length > 0 && ( } >
@@ -195,20 +195,16 @@ export function AgentDetailsModal({
         )}
 
-        {/* Context Providers */}
-        {agent.context_providers && agent.context_providers.length > 0 && (
+        {/* Context Provider */}
+        {agent.context_provider && (
            className={!agent.middleware || agent.middleware.length === 0 ? "md:col-start-2" : ""}
          >
-            {agent.context_providers.map((cp, index) => (
-              • {cp}
-            ))}
+            {agent.context_provider}
    )} diff --git a/python/packages/devui/frontend/src/services/api.ts b/python/packages/devui/frontend/src/services/api.ts index 227cb6f7e1..7c442597ee 100644 --- a/python/packages/devui/frontend/src/services/api.ts +++ b/python/packages/devui/frontend/src/services/api.ts @@ -42,7 +42,7 @@ interface BackendEntityInfo { instructions?: string; model_id?: string; chat_client_type?: string; - context_providers?: string[]; + context_provider?: string[]; middleware?: string[]; // Workflow-specific fields (present when type === "workflow") executors?: string[]; @@ -77,7 +77,7 @@ const MAX_RETRY_ATTEMPTS = 10; // Max 10 retries (~30 seconds with exponential b function getBackendUrl(): string { const stored = localStorage.getItem("devui_backend_url"); if (stored) return stored; - + return DEFAULT_API_BASE_URL; } @@ -221,13 +221,13 @@ class ApiClient { instructions: entity.instructions, model_id: entity.model_id, chat_client_type: entity.chat_client_type, - context_providers: entity.context_providers, + context_provider: entity.context_provider, middleware: entity.middleware, }; } else { // Workflow - prefer executors field, fall back to tools for backward compatibility const executorList = entity.executors || entity.tools || []; - + // Determine start_executor_id: use entity value, or first executor if it's a string let startExecutorId = entity.start_executor_id || ""; if (!startExecutorId && executorList.length > 0) { @@ -236,7 +236,7 @@ class ApiClient { startExecutorId = firstExecutor; } } - + return { id: entity.id, name: entity.name, @@ -493,10 +493,10 @@ class ApiClient { if (!resumeResponseId) { currentResponseId = storedState.responseId; } - + lastSequenceNumber = storedState.lastSequenceNumber; lastMessageId = storedState.lastMessageId; - + // Replay stored events only if we're not explicitly resuming // (explicit resume means the caller already has the events) if (!resumeResponseId) { diff --git a/python/packages/devui/frontend/src/types/agent-framework.ts b/python/packages/devui/frontend/src/types/agent-framework.ts index 6d12a4cc8a..a41b0ce9a1 100644 --- a/python/packages/devui/frontend/src/types/agent-framework.ts +++ b/python/packages/devui/frontend/src/types/agent-framework.ts @@ -208,7 +208,7 @@ export interface UsageDetails { } // Agent run response update (streaming) -export interface AgentRunResponseUpdate { +export interface AgentResponseUpdate { contents: Contents[]; role?: Role; author_name?: string; @@ -222,7 +222,7 @@ export interface AgentRunResponseUpdate { } // Agent run response (final) -export interface AgentRunResponse { +export interface AgentResponse { messages: ChatMessage[]; response_id?: string; created_at?: CreatedAtT; @@ -269,8 +269,7 @@ export interface AgentThread { export interface WorkflowEvent { type?: string; // Event class name like "WorkflowOutputEvent", "WorkflowCompletedEvent", "ExecutorInvokedEvent", etc. 
data?: unknown; - executor_id?: string; // Present for executor-related events - source_executor_id?: string; // Present for WorkflowOutputEvent + executor_id?: string; // Present for executor-related events and WorkflowOutputEvent } export interface WorkflowStartedEvent extends WorkflowEvent { @@ -286,7 +285,7 @@ export interface WorkflowCompletedEvent extends WorkflowEvent { export interface WorkflowOutputEvent extends WorkflowEvent { // Event-specific data for workflow output (new) readonly event_type: "workflow_output"; - source_executor_id: string; // ID of executor that yielded the output + executor_id: string; // ID of executor that yielded the output } export interface WorkflowWarningEvent extends WorkflowEvent { @@ -302,11 +301,11 @@ export interface ExecutorEvent extends WorkflowEvent { } export interface AgentRunUpdateEvent extends ExecutorEvent { - data?: AgentRunResponseUpdate; + data?: AgentResponseUpdate; } export interface AgentRunEvent extends ExecutorEvent { - data?: AgentRunResponse; + data?: AgentResponse; } // Span event structure (from OpenTelemetry) diff --git a/python/packages/devui/frontend/src/types/index.ts b/python/packages/devui/frontend/src/types/index.ts index 411d8a550b..2ea89f1381 100644 --- a/python/packages/devui/frontend/src/types/index.ts +++ b/python/packages/devui/frontend/src/types/index.ts @@ -39,8 +39,8 @@ export interface AgentInfo { instructions?: string; model_id?: string; chat_client_type?: string; - context_providers?: string[]; - middleware?: string[]; + context_provider?: string | undefined; + middleware?: string[] | undefined; } // JSON Schema types for workflow input diff --git a/python/packages/devui/pyproject.toml b/python/packages/devui/pyproject.toml index c48754bd09..af6c03be31 100644 --- a/python/packages/devui/pyproject.toml +++ b/python/packages/devui/pyproject.toml @@ -4,7 +4,7 @@ description = "Debug UI for Microsoft Agent Framework with OpenAI-compatible API authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://github.com/microsoft/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/devui/tests/__init__.py b/python/packages/devui/tests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/packages/devui/tests/test_cleanup_hooks.py b/python/packages/devui/tests/test_cleanup_hooks.py index f065c1e0c6..2d6fd7b614 100644 --- a/python/packages/devui/tests/test_cleanup_hooks.py +++ b/python/packages/devui/tests/test_cleanup_hooks.py @@ -7,7 +7,7 @@ from pathlib import Path import pytest -from agent_framework import AgentRunResponse, ChatMessage, Role, TextContent +from agent_framework import AgentResponse, ChatMessage, Role, TextContent from agent_framework_devui import register_cleanup from agent_framework_devui._discovery import EntityDiscovery @@ -35,7 +35,7 @@ def __init__(self, name: str = "TestAgent"): async def run_stream(self, messages=None, *, thread=None, **kwargs): """Mock streaming run method.""" - yield AgentRunResponse( + yield AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Test response")])], ) @@ -259,7 +259,7 @@ async def test_cleanup_with_file_based_discovery(): # Write agent module with cleanup registration agent_file = agent_dir / "__init__.py" agent_file.write_text(""" -from agent_framework import 
AgentRunResponse, ChatMessage, Role, TextContent +from agent_framework import AgentResponse, ChatMessage, Role, TextContent from agent_framework_devui import register_cleanup class MockCredential: @@ -278,7 +278,7 @@ class TestAgent: description = "Test agent with cleanup" async def run_stream(self, messages=None, *, thread=None, **kwargs): - yield AgentRunResponse( + yield AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, content=[TextContent(text="Test")])], inner_messages=[], ) diff --git a/python/packages/devui/tests/test_discovery.py b/python/packages/devui/tests/test_discovery.py index dbd1ce1074..72e534b012 100644 --- a/python/packages/devui/tests/test_discovery.py +++ b/python/packages/devui/tests/test_discovery.py @@ -84,19 +84,15 @@ async def test_discovery_accepts_agents_with_only_run(): init_file = agent_dir / "__init__.py" init_file.write_text(""" -from agent_framework import AgentRunResponse, AgentThread, ChatMessage, Role, TextContent +from agent_framework import AgentResponse, AgentThread, ChatMessage, Role, TextContent class NonStreamingAgent: id = "non_streaming" name = "Non-Streaming Agent" description = "Agent without run_stream" - @property - def display_name(self): - return self.name - async def run(self, messages=None, *, thread=None, **kwargs): - return AgentRunResponse( + return AgentResponse( messages=[ChatMessage( role=Role.ASSISTANT, contents=[TextContent(text="response")] @@ -207,13 +203,13 @@ def test_func(input: str) -> str: agent_dir = temp_path / "my_agent" agent_dir.mkdir() (agent_dir / "agent.py").write_text(""" -from agent_framework import AgentRunResponse, AgentThread, ChatMessage, Role, TextContent +from agent_framework import AgentResponse, AgentThread, ChatMessage, Role, TextContent class TestAgent: name = "Test Agent" async def run(self, messages=None, *, thread=None, **kwargs): - return AgentRunResponse( + return AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="test")])], response_id="test" ) diff --git a/python/packages/devui/tests/test_execution.py b/python/packages/devui/tests/test_execution.py index 5d276e79d1..15cb9bf4df 100644 --- a/python/packages/devui/tests/test_execution.py +++ b/python/packages/devui/tests/test_execution.py @@ -287,7 +287,7 @@ async def test_full_pipeline_agent_events_are_json_serializable(executor_with_re 2. Each event is converted by the mapper 3. Server calls model_dump_json() on each event for SSE - If any event contains non-serializable objects (like AgentRunResponse), + If any event contains non-serializable objects (like AgentResponse), this test will fail - catching the bug before it hits production. """ executor, entity_id, mock_client = executor_with_real_agent @@ -327,7 +327,7 @@ async def test_full_pipeline_workflow_events_are_json_serializable(): This is particularly important for workflows with AgentExecutor because: - AgentExecutor produces ExecutorCompletedEvent with AgentExecutorResponse - - AgentExecutorResponse contains AgentRunResponse and ChatMessage objects + - AgentExecutorResponse contains AgentResponse and ChatMessage objects - These are SerializationMixin objects, not Pydantic, which caused the original bug This test ensures the ENTIRE streaming pipeline works end-to-end. 
@@ -566,7 +566,7 @@ def test_extract_workflow_hil_responses_handles_stringified_json(): async def test_executor_handles_non_streaming_agent(): """Test executor can handle agents with only run() method (no run_stream).""" - from agent_framework import AgentRunResponse, AgentThread, ChatMessage, Role, TextContent + from agent_framework import AgentResponse, AgentThread, ChatMessage, Role, TextContent class NonStreamingAgent: """Agent with only run() method - does NOT satisfy full AgentProtocol.""" @@ -575,12 +575,8 @@ class NonStreamingAgent: name = "Non-Streaming Test Agent" description = "Test agent without run_stream()" - @property - def display_name(self): - return self.name - async def run(self, messages=None, *, thread=None, **kwargs): - return AgentRunResponse( + return AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=f"Processed: {messages}")])], response_id="test_123", ) diff --git a/python/packages/devui/tests/test_helpers.py b/python/packages/devui/tests/test_helpers.py index ebb03c4c15..1385dc867d 100644 --- a/python/packages/devui/tests/test_helpers.py +++ b/python/packages/devui/tests/test_helpers.py @@ -13,18 +13,18 @@ to avoid pytest plugin conflicts when running tests across packages. """ +import sys from collections.abc import AsyncIterable, MutableSequence -from typing import Any +from typing import Any, Generic from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, AgentThread, BaseAgent, BaseChatClient, ChatAgent, ChatMessage, - ChatOptions, ChatResponse, ChatResponseUpdate, ConcurrentBuilder, @@ -35,8 +35,14 @@ TextContent, use_chat_middleware, ) +from agent_framework._clients import TOptions_co from agent_framework._workflows._agent_executor import AgentExecutorResponse +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + # Import real workflow event classes - NOT mocks! from agent_framework._workflows._events import ( ExecutorCompletedEvent, @@ -91,7 +97,7 @@ async def get_streaming_response( @use_chat_middleware -class MockBaseChatClient(BaseChatClient): +class MockBaseChatClient(BaseChatClient[TOptions_co], Generic[TOptions_co]): """Full BaseChatClient mock with middleware support. Use this when testing features that require the full BaseChatClient interface. 
@@ -106,11 +112,12 @@ def __init__(self, **kwargs: Any): self.call_count: int = 0 self.received_messages: list[list[ChatMessage]] = [] + @override async def _inner_get_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> ChatResponse: self.call_count += 1 @@ -119,11 +126,12 @@ async def _inner_get_response( return self.run_responses.pop(0) return ChatResponse(messages=ChatMessage(role="assistant", text="Mock response from ChatAgent")) + @override async def _inner_get_streaming_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: self.call_count += 1 @@ -164,9 +172,9 @@ async def run( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: self.call_count += 1 - return AgentRunResponse( + return AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=self.response_text)])] ) @@ -176,10 +184,10 @@ async def run_stream( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: self.call_count += 1 for chunk in self.streaming_chunks: - yield AgentRunResponseUpdate(contents=[TextContent(text=chunk)], role=Role.ASSISTANT) + yield AgentResponseUpdate(contents=[TextContent(text=chunk)], role=Role.ASSISTANT) class MockToolCallingAgent(BaseAgent): @@ -195,9 +203,9 @@ async def run( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: + ) -> AgentResponse: self.call_count += 1 - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="done")]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="done")]) async def run_stream( self, @@ -205,15 +213,15 @@ async def run_stream( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: + ) -> AsyncIterable[AgentResponseUpdate]: self.call_count += 1 # First: text - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=[TextContent(text="Let me search for that...")], role=Role.ASSISTANT, ) # Second: tool call - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=[ FunctionCallContent( call_id="call_123", @@ -224,7 +232,7 @@ async def run_stream( role=Role.ASSISTANT, ) # Third: tool result - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=[ FunctionResultContent( call_id="call_123", @@ -234,7 +242,7 @@ async def run_stream( role=Role.TOOL, ) # Fourth: final text - yield AgentRunResponseUpdate( + yield AgentResponseUpdate( contents=[TextContent(text="The weather is sunny, 72°F.")], role=Role.ASSISTANT, ) @@ -287,9 +295,9 @@ def create_mock_tool_agent(id: str = "tool_agent", name: str = "ToolAgent") -> M return MockToolCallingAgent(id=id, name=name) -def create_agent_run_response(text: str = "Test response") -> AgentRunResponse: - """Create an AgentRunResponse with the given text.""" - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=text)])]) +def create_agent_run_response(text: str = "Test response") -> AgentResponse: + """Create an AgentResponse with the given text.""" + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=text)])]) def create_agent_executor_response( @@ -300,7 +308,7 @@ def create_agent_executor_response( agent_response = 
create_agent_run_response(response_text) return AgentExecutorResponse( executor_id=executor_id, - agent_run_response=agent_response, + agent_response=agent_response, full_conversation=[ ChatMessage(role=Role.USER, contents=[TextContent(text="User input")]), ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=response_text)]), @@ -316,7 +324,7 @@ def create_executor_completed_event( This creates the exact data structure that caused the serialization bug: ExecutorCompletedEvent.data contains AgentExecutorResponse which contains - AgentRunResponse and ChatMessage objects (SerializationMixin, not Pydantic). + AgentResponse and ChatMessage objects (SerializationMixin, not Pydantic). """ data = create_agent_executor_response(executor_id) if with_agent_response else {"simple": "dict"} return ExecutorCompletedEvent(executor_id=executor_id, data=data) diff --git a/python/packages/devui/tests/test_mapper.py b/python/packages/devui/tests/test_mapper.py index 2de788257b..faf0c831d8 100644 --- a/python/packages/devui/tests/test_mapper.py +++ b/python/packages/devui/tests/test_mapper.py @@ -13,7 +13,7 @@ # Import Agent Framework types from agent_framework._types import ( - AgentRunResponseUpdate, + AgentResponseUpdate, ErrorContent, FunctionCallContent, FunctionResultContent, @@ -83,11 +83,9 @@ def create_test_content(content_type: str, **kwargs: Any) -> Any: raise ValueError(f"Unknown content type: {content_type}") -def create_test_agent_update(contents: list[Any]) -> AgentRunResponseUpdate: - """Create test AgentRunResponseUpdate.""" - return AgentRunResponseUpdate( - contents=contents, role=Role.ASSISTANT, message_id="test_msg", response_id="test_resp" - ) +def create_test_agent_update(contents: list[Any]) -> AgentResponseUpdate: + """Create test AgentResponseUpdate.""" + return AgentResponseUpdate(contents=contents, role=Role.ASSISTANT, message_id="test_msg", response_id="test_resp") # ============================================================================= @@ -105,7 +103,7 @@ async def test_critical_isinstance_bug_detection(mapper: MessageMapper, test_req assert not hasattr(update, "response") # Fake attribute should not exist # Test isinstance works with real types - assert isinstance(update, AgentRunResponseUpdate) + assert isinstance(update, AgentResponseUpdate) # Test mapper conversion - should NOT produce "Unknown event" events = await mapper.convert_event(update, test_request) @@ -264,7 +262,7 @@ async def test_agent_lifecycle_events(mapper: MessageMapper, test_request: Agent async def test_agent_run_response_mapping(mapper: MessageMapper, test_request: AgentFrameworkRequest) -> None: - """Test that mapper handles complete AgentRunResponse (non-streaming).""" + """Test that mapper handles complete AgentResponse (non-streaming).""" response = create_agent_run_response("Complete response from run()") events = await mapper.convert_event(response, test_request) @@ -325,14 +323,14 @@ async def test_executor_completed_event_with_agent_response( This is a REGRESSION TEST for the serialization bug where ExecutorCompletedEvent.data contained AgentExecutorResponse with nested - AgentRunResponse and ChatMessage objects (SerializationMixin) that + AgentResponse and ChatMessage objects (SerializationMixin) that Pydantic couldn't serialize. 
""" # Create event with realistic nested data - the exact structure that caused the bug event = create_executor_completed_event(executor_id="exec_agent", with_agent_response=True) # Verify the data has the problematic structure - assert hasattr(event.data, "agent_run_response") + assert hasattr(event.data, "agent_response") assert hasattr(event.data, "full_conversation") # First invoke the executor @@ -380,7 +378,7 @@ async def test_executor_completed_event_serialization_to_json( done_event = events[0] # This is the critical test - model_dump_json() should NOT raise - # "Unable to serialize unknown type: " + # "Unable to serialize unknown type: " try: json_str = done_event.model_dump_json() assert json_str is not None @@ -453,11 +451,11 @@ async def test_magentic_agent_run_update_event_with_agent_delta_metadata( This tests the ACTUAL event format Magentic emits - not a fake MagenticAgentDeltaEvent class. Magentic uses AgentRunUpdateEvent with additional_properties containing magentic_event_type. """ - from agent_framework._types import AgentRunResponseUpdate, Role, TextContent + from agent_framework._types import AgentResponseUpdate, Role, TextContent from agent_framework._workflows._events import AgentRunUpdateEvent # Create the REAL event format that Magentic emits - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[TextContent(text="Hello from agent")], role=Role.ASSISTANT, author_name="Writer", @@ -484,11 +482,11 @@ async def test_magentic_orchestrator_message_event(mapper: MessageMapper, test_r Magentic emits orchestrator planning/instruction messages using AgentRunUpdateEvent with additional_properties containing magentic_event_type='orchestrator_message'. """ - from agent_framework._types import AgentRunResponseUpdate, Role, TextContent + from agent_framework._types import AgentResponseUpdate, Role, TextContent from agent_framework._workflows._events import AgentRunUpdateEvent # Create orchestrator message event (REAL format from Magentic) - update = AgentRunResponseUpdate( + update = AgentResponseUpdate( contents=[TextContent(text="Planning: First, the writer will create content...")], role=Role.ASSISTANT, author_name="Orchestrator", @@ -520,19 +518,19 @@ async def test_magentic_events_use_same_event_class_as_other_workflows( additional_properties. Any mapper code checking for 'MagenticAgentDeltaEvent' class names is dead code. """ - from agent_framework._types import AgentRunResponseUpdate, Role, TextContent + from agent_framework._types import AgentResponseUpdate, Role, TextContent from agent_framework._workflows._events import AgentRunUpdateEvent # Create events the way different workflows do it # 1. Regular workflow (no additional_properties) - regular_update = AgentRunResponseUpdate( + regular_update = AgentResponseUpdate( contents=[TextContent(text="Regular workflow response")], role=Role.ASSISTANT, ) regular_event = AgentRunUpdateEvent(executor_id="regular_executor", data=regular_update) # 2. 
Magentic workflow (with additional_properties) - magentic_update = AgentRunResponseUpdate( + magentic_update = AgentResponseUpdate( contents=[TextContent(text="Magentic workflow response")], role=Role.ASSISTANT, additional_properties={"magentic_event_type": "agent_delta"}, @@ -587,7 +585,7 @@ async def test_workflow_output_event(mapper: MessageMapper, test_request: AgentF """Test WorkflowOutputEvent is converted to output_item.added.""" from agent_framework._workflows._events import WorkflowOutputEvent - event = WorkflowOutputEvent(data="Final workflow output", source_executor_id="final_executor") + event = WorkflowOutputEvent(data="Final workflow output", executor_id="final_executor") events = await mapper.convert_event(event, test_request) # WorkflowOutputEvent should emit output_item.added @@ -609,7 +607,7 @@ async def test_workflow_output_event_with_list_data(mapper: MessageMapper, test_ ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]), ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="World")]), ] - event = WorkflowOutputEvent(data=messages, source_executor_id="complete") + event = WorkflowOutputEvent(data=messages, executor_id="complete") events = await mapper.convert_event(event, test_request) assert len(events) == 1 diff --git a/python/packages/durabletask/agent_framework_durabletask/_callbacks.py b/python/packages/durabletask/agent_framework_durabletask/_callbacks.py index 3e38cdb6ec..53c4c2d71a 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_callbacks.py +++ b/python/packages/durabletask/agent_framework_durabletask/_callbacks.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from typing import Protocol -from agent_framework import AgentRunResponse, AgentRunResponseUpdate +from agent_framework import AgentResponse, AgentResponseUpdate @dataclass(frozen=True) @@ -27,14 +27,14 @@ class AgentResponseCallbackProtocol(Protocol): async def on_streaming_response_update( self, - update: AgentRunResponseUpdate, + update: AgentResponseUpdate, context: AgentCallbackContext, ) -> None: """Handle a streaming response update emitted by the agent.""" async def on_agent_response( self, - response: AgentRunResponse, + response: AgentResponse, context: AgentCallbackContext, ) -> None: """Handle the final agent response.""" diff --git a/python/packages/durabletask/agent_framework_durabletask/_client.py b/python/packages/durabletask/agent_framework_durabletask/_client.py index 5e70a7ba1f..5cf85f7a2b 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_client.py +++ b/python/packages/durabletask/agent_framework_durabletask/_client.py @@ -8,7 +8,7 @@ from __future__ import annotations -from agent_framework import AgentRunResponse, get_logger +from agent_framework import AgentResponse, get_logger from durabletask.client import TaskHubGrpcClient from ._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS @@ -18,7 +18,7 @@ logger = get_logger("agent_framework.durabletask.client") -class DurableAIAgentClient(DurableAgentProvider[AgentRunResponse]): +class DurableAIAgentClient(DurableAgentProvider[AgentResponse]): """Client wrapper for interacting with durable agents externally. 
This class wraps a durabletask TaskHubGrpcClient and provides a convenient @@ -68,7 +68,7 @@ def __init__( self._executor = ClientAgentExecutor(self._client, self.max_poll_retries, self.poll_interval_seconds) logger.debug("[DurableAIAgentClient] Initialized with client type: %s", type(client).__name__) - def get_agent(self, agent_name: str) -> DurableAIAgent[AgentRunResponse]: + def get_agent(self, agent_name: str) -> DurableAIAgent[AgentResponse]: """Retrieve a DurableAIAgent shim for the specified agent. This method returns a proxy object that can be used to execute the agent. diff --git a/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py b/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py index 453d180612..cbc098f848 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py +++ b/python/packages/durabletask/agent_framework_durabletask/_durable_agent_state.py @@ -35,7 +35,7 @@ from typing import Any, cast from agent_framework import ( - AgentRunResponse, + AgentResponse, BaseContent, ChatMessage, DataContent, @@ -452,7 +452,7 @@ def message_count(self) -> int: """Get the count of conversation entries (requests + responses).""" return len(self.data.conversation_history) - def try_get_agent_response(self, correlation_id: str) -> AgentRunResponse | None: + def try_get_agent_response(self, correlation_id: str) -> AgentResponse | None: """Try to get an agent response by correlation ID. This method searches the conversation history for a response entry matching the given @@ -690,8 +690,8 @@ def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateResponse: ) @staticmethod - def from_run_response(correlation_id: str, response: AgentRunResponse) -> DurableAgentStateResponse: - """Creates a DurableAgentStateResponse from an AgentRunResponse.""" + def from_run_response(correlation_id: str, response: AgentResponse) -> DurableAgentStateResponse: + """Creates a DurableAgentStateResponse from an AgentResponse.""" return DurableAgentStateResponse( correlation_id=correlation_id, created_at=_parse_created_at(response.created_at), @@ -702,13 +702,13 @@ def from_run_response(correlation_id: str, response: AgentRunResponse) -> Durabl @staticmethod def to_run_response( response_entry: DurableAgentStateResponse, - ) -> AgentRunResponse: - """Converts a DurableAgentStateResponse back to an AgentRunResponse.""" + ) -> AgentResponse: + """Converts a DurableAgentStateResponse back to an AgentResponse.""" messages = [m.to_chat_message() for m in response_entry.messages] usage_details = response_entry.usage.to_usage_details() if response_entry.usage is not None else UsageDetails() - return AgentRunResponse( + return AgentResponse( created_at=response_entry.created_at.isoformat(), messages=messages, usage_details=usage_details, diff --git a/python/packages/durabletask/agent_framework_durabletask/_entities.py b/python/packages/durabletask/agent_framework_durabletask/_entities.py index 7c99a4edc3..38b912bc64 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_entities.py +++ b/python/packages/durabletask/agent_framework_durabletask/_entities.py @@ -10,8 +10,8 @@ from agent_framework import ( AgentProtocol, - AgentRunResponse, - AgentRunResponseUpdate, + AgentResponse, + AgentResponseUpdate, ChatMessage, ErrorContent, Role, @@ -125,7 +125,7 @@ def _is_error_response(self, entry: DurableAgentStateEntry) -> bool: async def run( self, request: RunRequest | dict[str, Any] | str, - ) -> AgentRunResponse: + ) 
-> AgentResponse: """Execute the agent with a message.""" if isinstance(request, str): run_request = RunRequest.from_json(request) @@ -139,8 +139,10 @@ async def run( correlation_id = run_request.correlation_id if not thread_id: raise ValueError("Entity State Provider must provide a thread_id") - response_format = run_request.response_format - enable_tool_calls = run_request.enable_tool_calls + options: dict[str, Any] = dict(run_request.options) + options.setdefault("response_format", run_request.response_format) + if not run_request.enable_tool_calls: + options.setdefault("tools", None) logger.debug("[AgentEntity.run] Received ThreadId %s Message: %s", thread_id, run_request) @@ -155,13 +157,9 @@ async def run( for m in entry.messages ] - run_kwargs: dict[str, Any] = {"messages": chat_messages} - if not enable_tool_calls: - run_kwargs["tools"] = None - if response_format: - run_kwargs["response_format"] = response_format + run_kwargs: dict[str, Any] = {"messages": chat_messages, "options": options} - agent_run_response: AgentRunResponse = await self._invoke_agent( + agent_run_response: AgentResponse = await self._invoke_agent( run_kwargs=run_kwargs, correlation_id=correlation_id, thread_id=thread_id, @@ -180,7 +178,7 @@ async def run( error_message = ChatMessage( role=Role.ASSISTANT, contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)] ) - error_response = AgentRunResponse(messages=[error_message]) + error_response = AgentResponse(messages=[error_message]) error_state_response = DurableAgentStateResponse.from_run_response(correlation_id, error_response) error_state_response.is_error = True @@ -195,7 +193,7 @@ async def _invoke_agent( correlation_id: str, thread_id: str, request_message: str, - ) -> AgentRunResponse: + ) -> AgentResponse: """Execute the agent, preferring streaming when available.""" callback_context: AgentCallbackContext | None = None if self.callback is not None: @@ -213,7 +211,7 @@ async def _invoke_agent( stream_candidate = await stream_candidate return await self._consume_stream( - stream=cast(AsyncIterable[AgentRunResponseUpdate], stream_candidate), + stream=cast(AsyncIterable[AgentResponseUpdate], stream_candidate), callback_context=callback_context, ) except TypeError as type_error: @@ -238,26 +236,26 @@ async def _invoke_agent( async def _consume_stream( self, - stream: AsyncIterable[AgentRunResponseUpdate], + stream: AsyncIterable[AgentResponseUpdate], callback_context: AgentCallbackContext | None = None, - ) -> AgentRunResponse: - """Consume streaming responses and build the final AgentRunResponse.""" - updates: list[AgentRunResponseUpdate] = [] + ) -> AgentResponse: + """Consume streaming responses and build the final AgentResponse.""" + updates: list[AgentResponseUpdate] = [] async for update in stream: updates.append(update) await self._notify_stream_update(update, callback_context) if updates: - response = AgentRunResponse.from_agent_run_response_updates(updates) + response = AgentResponse.from_agent_run_response_updates(updates) else: logger.debug("[AgentEntity] No streaming updates received; creating empty response") - response = AgentRunResponse(messages=[]) + response = AgentResponse(messages=[]) await self._notify_final_response(response, callback_context) return response - async def _invoke_non_stream(self, run_kwargs: dict[str, Any]) -> AgentRunResponse: + async def _invoke_non_stream(self, run_kwargs: dict[str, Any]) -> AgentResponse: """Invoke the agent without streaming support.""" run_callable = getattr(self.agent, "run", None) 
if run_callable is None or not callable(run_callable): @@ -267,14 +265,14 @@ async def _invoke_non_stream(self, run_kwargs: dict[str, Any]) -> AgentRunRespon if inspect.isawaitable(result): result = await result - if not isinstance(result, AgentRunResponse): - raise TypeError(f"Agent run() must return an AgentRunResponse instance; received {type(result).__name__}") + if not isinstance(result, AgentResponse): + raise TypeError(f"Agent run() must return an AgentResponse instance; received {type(result).__name__}") return result async def _notify_stream_update( self, - update: AgentRunResponseUpdate, + update: AgentResponseUpdate, context: AgentCallbackContext | None, ) -> None: """Invoke the streaming callback if one is registered.""" @@ -294,7 +292,7 @@ async def _notify_stream_update( async def _notify_final_response( self, - response: AgentRunResponse, + response: AgentResponse, context: AgentCallbackContext | None, ) -> None: """Invoke the final response callback if one is registered.""" diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index d670a09a32..e3e3e4cfc6 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -16,7 +16,7 @@ from datetime import datetime, timezone from typing import Any, Generic, TypeVar -from agent_framework import AgentRunResponse, AgentThread, ChatMessage, ErrorContent, Role, TextContent, get_logger +from agent_framework import AgentResponse, AgentThread, ChatMessage, ErrorContent, Role, TextContent, get_logger from durabletask.client import TaskHubGrpcClient from durabletask.entities import EntityInstanceId from durabletask.task import CompletableTask, CompositeTask, OrchestrationContext, Task @@ -33,14 +33,14 @@ TaskT = TypeVar("TaskT") -class DurableAgentTask(CompositeTask[AgentRunResponse], CompletableTask[AgentRunResponse]): - """A custom Task that wraps entity calls and provides typed AgentRunResponse results. +class DurableAgentTask(CompositeTask[AgentResponse], CompletableTask[AgentResponse]): + """A custom Task that wraps entity calls and provides typed AgentResponse results. This task wraps the underlying entity call task and intercepts its completion - to convert the raw result into a typed AgentRunResponse object. + to convert the raw result into a typed AgentResponse object. 
- When yielded in an orchestration, this task returns an AgentRunResponse: - response: AgentRunResponse = yield durable_agent_task + When yielded in an orchestration, this task returns an AgentResponse: + response: AgentResponse = yield durable_agent_task """ def __init__( @@ -93,7 +93,7 @@ def on_child_completed(self, task: Task[Any]) -> None: response, ) - # Set the typed AgentRunResponse as this task's result + # Set the typed AgentResponse as this task's result self.complete(response) except Exception as ex: @@ -147,28 +147,37 @@ def generate_unique_id(self) -> str: def get_run_request( self, message: str, - response_format: type[BaseModel] | None, - enable_tool_calls: bool, - wait_for_response: bool = True, + *, + options: dict[str, Any] | None = None, ) -> RunRequest: - """Create a RunRequest for the given parameters.""" + """Create a RunRequest from message and options.""" correlation_id = self.generate_unique_id() + + # Create a copy to avoid modifying the caller's dict + opts = dict(options) if options else {} + + # Extract and REMOVE known keys from options copy + response_format = opts.pop("response_format", None) + enable_tool_calls = opts.pop("enable_tool_calls", True) + wait_for_response = opts.pop("wait_for_response", True) + return RunRequest( message=message, response_format=response_format, enable_tool_calls=enable_tool_calls, wait_for_response=wait_for_response, correlation_id=correlation_id, + options=opts, ) - def _create_acceptance_response(self, correlation_id: str) -> AgentRunResponse: + def _create_acceptance_response(self, correlation_id: str) -> AgentResponse: """Create an acceptance response for fire-and-forget mode. Args: correlation_id: Correlation ID for tracking the request Returns: - AgentRunResponse: Acceptance response with correlation ID + AgentResponse: Acceptance response with correlation ID """ acceptance_message = ChatMessage( role=Role.SYSTEM, @@ -180,16 +189,16 @@ def _create_acceptance_response(self, correlation_id: str) -> AgentRunResponse: ) ], ) - return AgentRunResponse( + return AgentResponse( messages=[acceptance_message], created_at=datetime.now(timezone.utc).isoformat(), ) -class ClientAgentExecutor(DurableAgentExecutor[AgentRunResponse]): +class ClientAgentExecutor(DurableAgentExecutor[AgentResponse]): """Execution strategy for external clients. - Note: Returns AgentRunResponse directly since the execution + Note: Returns AgentResponse directly since the execution is blocking until response is available via polling as per the design of TaskHubGrpcClient. """ @@ -209,7 +218,7 @@ def run_durable_agent( agent_name: str, run_request: RunRequest, thread: AgentThread | None = None, - ) -> AgentRunResponse: + ) -> AgentResponse: """Execute the agent via the durabletask client. Signals the agent entity with a message request, then polls the entity @@ -225,7 +234,7 @@ def run_durable_agent( thread: Optional conversation thread (creates new if not provided) Returns: - AgentRunResponse: The agent's response after execution completes, or an immediate + AgentResponse: The agent's response after execution completes, or an immediate acknowledgement if wait_for_response is False """ # Signal the entity with the request @@ -284,7 +293,7 @@ def _poll_for_agent_response( self, entity_id: EntityInstanceId, correlation_id: str, - ) -> AgentRunResponse | None: + ) -> AgentResponse | None: """Poll the entity for a response with retries. 
Args: @@ -319,10 +328,10 @@ def _poll_for_agent_response( def _handle_agent_response( self, - agent_response: AgentRunResponse | None, + agent_response: AgentResponse | None, response_format: type[BaseModel] | None, correlation_id: str, - ) -> AgentRunResponse: + ) -> AgentResponse: """Handle the agent response or create an error response. Args: @@ -331,7 +340,7 @@ def _handle_agent_response( correlation_id: Correlation ID for logging Returns: - AgentRunResponse with either the agent's response or an error message + AgentResponse with either the agent's response or an error message """ if agent_response is not None: try: @@ -375,7 +384,7 @@ def _handle_agent_response( ], ) - return AgentRunResponse( + return AgentResponse( messages=[error_message], created_at=datetime.now(timezone.utc).isoformat(), ) @@ -384,7 +393,7 @@ def _poll_entity_for_response( self, entity_id: EntityInstanceId, correlation_id: str, - ) -> AgentRunResponse | None: + ) -> AgentResponse | None: """Poll the entity state for a response matching the correlation ID. Args: @@ -392,7 +401,7 @@ def _poll_entity_for_response( correlation_id: Correlation ID to search for Returns: - Response AgentRunResponse, None otherwise + Response AgentResponse, None otherwise """ try: entity_metadata = self._client.get_entity(entity_id, include_state=True) @@ -466,7 +475,7 @@ def run_durable_agent( thread: Optional conversation thread (creates new if not provided) Returns: - DurableAgentTask: A task wrapping the entity call that yields AgentRunResponse + DurableAgentTask: A task wrapping the entity call that yields AgentResponse """ # Resolve session session_id = self._create_session_id(agent_name, thread) @@ -495,7 +504,7 @@ def run_durable_agent( # Create a pre-completed task with acceptance response acceptance_response = self._create_acceptance_response(run_request.correlation_id) - entity_task: CompletableTask[AgentRunResponse] = CompletableTask() + entity_task: CompletableTask[AgentResponse] = CompletableTask() entity_task.complete(acceptance_response) else: # Blocking mode: call entity and wait for response diff --git a/python/packages/durabletask/agent_framework_durabletask/_models.py b/python/packages/durabletask/agent_framework_durabletask/_models.py index 169417ebae..1b9edda5fa 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_models.py +++ b/python/packages/durabletask/agent_framework_durabletask/_models.py @@ -11,7 +11,7 @@ import json import uuid from collections.abc import MutableMapping -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timezone from importlib import import_module from typing import TYPE_CHECKING, Any, cast @@ -109,6 +109,7 @@ class RunRequest: correlation_id: Correlation ID for tracking the response to this specific request created_at: Optional timestamp when the request was created orchestration_id: Optional ID of the orchestration that initiated this request + options: Optional options dictionary forwarded to the agent """ message: str @@ -120,6 +121,7 @@ class RunRequest: wait_for_response: bool = True created_at: datetime | None = None orchestration_id: str | None = None + options: dict[str, Any] = field(default_factory=lambda: {}) def __init__( self, @@ -132,6 +134,7 @@ def __init__( wait_for_response: bool = True, created_at: datetime | None = None, orchestration_id: str | None = None, + options: dict[str, Any] | None = None, ) -> None: self.message = message self.correlation_id = correlation_id @@ -142,6 +145,7 @@ def 
__init__( self.wait_for_response = wait_for_response self.created_at = created_at if created_at is not None else datetime.now(tz=timezone.utc) self.orchestration_id = orchestration_id + self.options = options if options is not None else {} @staticmethod def coerce_role(value: Role | str | None) -> Role: @@ -164,6 +168,7 @@ def to_dict(self) -> dict[str, Any]: "role": self.role.value, "request_response_format": self.request_response_format, "correlationId": self.correlation_id, + "options": self.options, } if self.response_format: result["response_format"] = serialize_response_format(self.response_format) @@ -171,7 +176,6 @@ def to_dict(self) -> dict[str, Any]: result["created_at"] = self.created_at.isoformat() if self.orchestration_id: result["orchestrationId"] = self.orchestration_id - return result @classmethod @@ -198,6 +202,8 @@ def from_dict(cls, data: dict[str, Any]) -> RunRequest: if not correlation_id: raise ValueError("correlationId is required in RunRequest data") + options = data.get("options") + return cls( message=data.get("message", ""), correlation_id=correlation_id, @@ -208,6 +214,7 @@ def from_dict(cls, data: dict[str, Any]) -> RunRequest: enable_tool_calls=data.get("enable_tool_calls", True), created_at=created_at, orchestration_id=data.get("orchestrationId"), + options=cast(dict[str, Any], options) if isinstance(options, dict) else {}, ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_response_utils.py b/python/packages/durabletask/agent_framework_durabletask/_response_utils.py index 123de7b0cf..fd622d9b35 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_response_utils.py +++ b/python/packages/durabletask/agent_framework_durabletask/_response_utils.py @@ -1,23 +1,23 @@ # Copyright (c) Microsoft. All rights reserved. -"""Shared utilities for handling AgentRunResponse parsing and validation.""" +"""Shared utilities for handling AgentResponse parsing and validation.""" from typing import Any -from agent_framework import AgentRunResponse, get_logger +from agent_framework import AgentResponse, get_logger from pydantic import BaseModel logger = get_logger("agent_framework.durabletask.response_utils") -def load_agent_response(agent_response: AgentRunResponse | dict[str, Any] | None) -> AgentRunResponse: - """Convert raw payloads into AgentRunResponse instance. +def load_agent_response(agent_response: AgentResponse | dict[str, Any] | None) -> AgentResponse: + """Convert raw payloads into AgentResponse instance. 
Args: - agent_response: The response to convert, can be an AgentRunResponse, dict, or None + agent_response: The response to convert, can be an AgentResponse, dict, or None Returns: - AgentRunResponse: The converted response object + AgentResponse: The converted response object Raises: ValueError: If agent_response is None @@ -28,11 +28,11 @@ def load_agent_response(agent_response: AgentRunResponse | dict[str, Any] | None logger.debug("[load_agent_response] Loading agent response of type: %s", type(agent_response)) - if isinstance(agent_response, AgentRunResponse): + if isinstance(agent_response, AgentResponse): return agent_response if isinstance(agent_response, dict): - logger.debug("[load_agent_response] Converting dict payload using AgentRunResponse.from_dict") - return AgentRunResponse.from_dict(agent_response) + logger.debug("[load_agent_response] Converting dict payload using AgentResponse.from_dict") + return AgentResponse.from_dict(agent_response) raise TypeError(f"Unsupported type for agent_response: {type(agent_response)}") @@ -40,9 +40,9 @@ def load_agent_response(agent_response: AgentRunResponse | dict[str, Any] | None def ensure_response_format( response_format: type[BaseModel] | None, correlation_id: str, - response: AgentRunResponse, + response: AgentResponse, ) -> None: - """Ensure the AgentRunResponse value is parsed into the expected response_format. + """Ensure the AgentResponse value is parsed into the expected response_format. This function modifies the response in-place by parsing its value attribute into the specified Pydantic model format. @@ -50,7 +50,7 @@ def ensure_response_format( Args: response_format: Optional Pydantic model class to parse the response value into correlation_id: Correlation ID for logging purposes - response: The AgentRunResponse object to validate and parse + response: The AgentResponse object to validate and parse Raises: ValueError: If response_format is specified but response.value cannot be parsed @@ -66,7 +66,7 @@ def ensure_response_format( ) logger.debug( - "[ensure_response_format] Loaded AgentRunResponse.value for correlation_id %s with type: %s", + "[ensure_response_format] Loaded AgentResponse.value for correlation_id %s with type: %s", correlation_id, type(response.value).__name__, ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index 5bfde22e74..7fa69c68f1 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -13,8 +13,7 @@ from collections.abc import AsyncIterator from typing import Any, Generic, TypeVar -from agent_framework import AgentProtocol, AgentRunResponseUpdate, AgentThread, ChatMessage -from pydantic import BaseModel +from agent_framework import AgentProtocol, AgentResponseUpdate, AgentThread, ChatMessage from ._executors import DurableAgentExecutor from ._models import DurableAgentThread @@ -55,7 +54,7 @@ class DurableAIAgent(AgentProtocol, Generic[TaskT]): This class implements AgentProtocol but with one critical difference: - AgentProtocol.run() returns a Coroutine (async, must await) - DurableAIAgent.run() returns TaskT (sync Task object - must yield - or the AgentRunResponse directly in the case of TaskHubGrpcClient) + or the AgentResponse directly in the case of TaskHubGrpcClient) This represents fundamentally different execution models but maintains the same interface contract for all other properties and methods. 
@@ -64,7 +63,7 @@ class DurableAIAgent(AgentProtocol, Generic[TaskT]): and what type of Task object is returned. Type Parameters: - TaskT: The task type returned by this agent (e.g., AgentRunResponse, DurableAgentTask, AgentTask) + TaskT: The task type returned by this agent (e.g., AgentResponse, DurableAgentTask, AgentTask) """ def __init__(self, executor: DurableAgentExecutor[TaskT], name: str, *, agent_id: str | None = None): @@ -106,25 +105,20 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: AgentThread | None = None, - response_format: type[BaseModel] | None = None, - enable_tool_calls: bool = True, - wait_for_response: bool = True, + options: dict[str, Any] | None = None, ) -> TaskT: """Execute the agent via the injected provider. Args: messages: The message(s) to send to the agent thread: Optional agent thread for conversation context - response_format: Optional Pydantic model for structured response - enable_tool_calls: Whether to enable tool calls for this request - wait_for_response: If True (default), waits for agent response. - If False, returns immediately (fire-and-forget mode). - - **Only supported for DurableAIAgentClient contexts.** + options: Optional options dictionary. Supported keys include + ``response_format``, ``enable_tool_calls``, and ``wait_for_response``. + Additional keys are forwarded to the agent execution. Note: This method overrides AgentProtocol.run() with a different return type: - - AgentProtocol.run() returns Coroutine[Any, Any, AgentRunResponse] (async) + - AgentProtocol.run() returns Coroutine[Any, Any, AgentResponse] (async) - DurableAIAgent.run() returns TaskT (Task object for yielding) This is intentional to support orchestration contexts that use yield patterns @@ -140,9 +134,7 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] run_request = self._executor.get_run_request( message=message_str, - response_format=response_format, - enable_tool_calls=enable_tool_calls, - wait_for_response=wait_for_response, + options=options, ) return self._executor.run_durable_agent( @@ -157,7 +149,7 @@ def run_stream( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterator[AgentRunResponseUpdate]: + ) -> AsyncIterator[AgentResponseUpdate]: """Run the agent with streaming (not supported for durable agents). 
Args: diff --git a/python/packages/durabletask/agent_framework_durabletask/_worker.py b/python/packages/durabletask/agent_framework_durabletask/_worker.py index fea4b8ba7c..cf7f519ce4 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_worker.py +++ b/python/packages/durabletask/agent_framework_durabletask/_worker.py @@ -181,7 +181,7 @@ def run(self, request: Any) -> Any: request: RunRequest as dict or string Returns: - AgentRunResponse as dict + AgentResponse as dict """ logger.debug("[ConfiguredAgentEntity.run] Executing agent: %s", agent_name) # Get or create event loop for async execution diff --git a/python/packages/durabletask/tests/test_entities.py b/python/packages/durabletask/tests/test_entities.py index 386357bb4c..12baffa3bf 100644 --- a/python/packages/durabletask/tests/test_entities.py +++ b/python/packages/durabletask/tests/test_entities.py @@ -11,7 +11,7 @@ from unittest.mock import AsyncMock, Mock import pytest -from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, ErrorContent, Role +from agent_framework import AgentResponse, AgentResponseUpdate, ChatMessage, ErrorContent, Role from pydantic import BaseModel from agent_framework_durabletask import ( @@ -79,12 +79,12 @@ def _role_value(chat_message: DurableAgentStateMessage) -> str: return str(role_value) -def _agent_response(text: str | None) -> AgentRunResponse: - """Create an AgentRunResponse with a single assistant message.""" +def _agent_response(text: str | None) -> AgentResponse: + """Create an AgentResponse with a single assistant message.""" message = ( ChatMessage(role="assistant", text=text) if text is not None else ChatMessage(role="assistant", contents=[]) ) - return AgentRunResponse(messages=[message]) + return AgentResponse(messages=[message]) class RecordingCallback: @@ -96,12 +96,12 @@ def __init__(self): async def on_streaming_response_update( self, - update: AgentRunResponseUpdate, + update: AgentResponseUpdate, context: Any, ) -> None: await self.stream_mock(update, context) - async def on_agent_response(self, response: AgentRunResponse, context: Any) -> None: + async def on_agent_response(self, response: AgentResponse, context: Any) -> None: await self.response_mock(response, context) @@ -216,17 +216,17 @@ async def test_run_executes_agent(self) -> None: assert getattr(sent_message.role, "value", sent_message.role) == "user" # Verify result - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert result.text == "Test response" async def test_run_agent_streaming_callbacks_invoked(self) -> None: """Ensure streaming updates trigger callbacks and run() is not used.""" updates = [ - AgentRunResponseUpdate(text="Hello"), - AgentRunResponseUpdate(text=" world"), + AgentResponseUpdate(text="Hello"), + AgentResponseUpdate(text=" world"), ] - async def update_generator() -> AsyncIterator[AgentRunResponseUpdate]: + async def update_generator() -> AsyncIterator[AgentResponseUpdate]: for update in updates: yield update @@ -245,7 +245,7 @@ async def update_generator() -> AsyncIterator[AgentRunResponseUpdate]: }, ) - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert "Hello" in result.text assert callback.stream_mock.await_count == len(updates) assert callback.response_mock.await_count == 1 @@ -288,7 +288,7 @@ async def test_run_agent_final_callback_without_streaming(self) -> None: }, ) - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert 
result.text == "Final response" assert callback.stream_mock.await_count == 0 assert callback.response_mock.await_count == 1 @@ -453,7 +453,7 @@ async def test_run_agent_handles_agent_exception(self) -> None: result = await entity.run({"message": "Message", "correlationId": "corr-entity-error-1"}) - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert len(result.messages) == 1 content = result.messages[0].contents[0] assert isinstance(content, ErrorContent) @@ -469,7 +469,7 @@ async def test_run_agent_handles_value_error(self) -> None: result = await entity.run({"message": "Message", "correlationId": "corr-entity-error-2"}) - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert len(result.messages) == 1 content = result.messages[0].contents[0] assert isinstance(content, ErrorContent) @@ -485,7 +485,7 @@ async def test_run_agent_handles_timeout_error(self) -> None: result = await entity.run({"message": "Message", "correlationId": "corr-entity-error-3"}) - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert len(result.messages) == 1 content = result.messages[0].contents[0] assert isinstance(content, ErrorContent) @@ -503,7 +503,7 @@ async def test_run_agent_preserves_message_on_error(self) -> None: ) # Even on error, message info should be preserved - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert len(result.messages) == 1 content = result.messages[0].contents[0] assert isinstance(content, ErrorContent) @@ -602,7 +602,7 @@ async def test_run_agent_with_run_request_object(self) -> None: result = await entity.run(request) - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert result.text == "Response" async def test_run_agent_with_dict_request(self) -> None: @@ -621,7 +621,7 @@ async def test_run_agent_with_dict_request(self) -> None: result = await entity.run(request_dict) - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert result.text == "Response" async def test_run_agent_with_string_raises_without_correlation(self) -> None: @@ -671,7 +671,7 @@ async def test_run_agent_with_response_format(self) -> None: result = await entity.run(request) - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert result.text == '{"answer": 42}' assert result.value is None @@ -686,7 +686,7 @@ async def test_run_agent_disable_tool_calls(self) -> None: result = await entity.run(request) - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) # Agent should have been called (tool disabling is framework-dependent) mock_agent.run.assert_called_once() diff --git a/python/packages/durabletask/tests/test_executors.py b/python/packages/durabletask/tests/test_executors.py index ac8bd729a4..46fe8bbdbc 100644 --- a/python/packages/durabletask/tests/test_executors.py +++ b/python/packages/durabletask/tests/test_executors.py @@ -11,7 +11,7 @@ from unittest.mock import Mock import pytest -from agent_framework import AgentRunResponse, Role +from agent_framework import AgentResponse, Role from durabletask.entities import EntityInstanceId from durabletask.task import Task from pydantic import BaseModel @@ -150,11 +150,11 @@ class TestClientAgentExecutorRun: def test_client_executor_run_returns_response( self, client_executor: ClientAgentExecutor, sample_run_request: RunRequest ) -> None: - """Verify 
ClientAgentExecutor.run_durable_agent returns AgentRunResponse (synchronous).""" + """Verify ClientAgentExecutor.run_durable_agent returns AgentResponse (synchronous).""" result = client_executor.run_durable_agent("test_agent", sample_run_request) - # Verify it returns an AgentRunResponse (synchronous, not a coroutine) - assert isinstance(result, AgentRunResponse) + # Verify it returns an AgentResponse (synchronous, not a coroutine) + assert isinstance(result, AgentResponse) assert result is not None @@ -183,8 +183,8 @@ def test_executor_respects_custom_max_poll_retries(self, mock_client: Mock, samp # Run the agent result = executor.run_durable_agent("test_agent", sample_run_request) - # Verify it returns AgentRunResponse (should timeout after 2 attempts) - assert isinstance(result, AgentRunResponse) + # Verify it returns AgentResponse (should timeout after 2 attempts) + assert isinstance(result, AgentResponse) # Verify get_entity was called 2 times (max_poll_retries) assert mock_client.get_entity.call_count == 2 @@ -202,7 +202,7 @@ def test_executor_respects_custom_poll_interval(self, mock_client: Mock, sample_ # Should take roughly 3 * 0.01 = 0.03 seconds (plus overhead) # Be generous with timing to avoid flakiness assert elapsed < 0.2 # Should be quick with 0.01 interval - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) class TestClientAgentExecutorFireAndForget: @@ -223,8 +223,8 @@ def test_fire_and_forget_returns_immediately(self, mock_client: Mock) -> None: # Should return immediately without polling (elapsed time should be very small) assert elapsed < 0.1 # Much faster than any polling would take - # Should return an AgentRunResponse - assert isinstance(result, AgentRunResponse) + # Should return an AgentResponse + assert isinstance(result, AgentResponse) # Should have signaled the entity but not polled assert mock_client.signal_entity.call_count == 1 @@ -239,7 +239,7 @@ def test_fire_and_forget_returns_empty_response(self, mock_client: Mock) -> None result = executor.run_durable_agent("test_agent", request) # Verify it contains an acceptance message - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert len(result.messages) == 1 assert result.messages[0].role == Role.SYSTEM # Check message contains key information @@ -292,7 +292,7 @@ def test_orchestration_fire_and_forget_returns_acceptance_response(self, mock_or # Get the result response = result.get_result() - assert isinstance(response, AgentRunResponse) + assert isinstance(response, AgentResponse) assert len(response.messages) == 1 assert response.messages[0].role == Role.SYSTEM assert "test-789" in response.messages[0].text @@ -380,7 +380,7 @@ class TestDurableAgentTask: def test_durable_agent_task_transforms_successful_result( self, configure_successful_entity_task: Any, successful_agent_response: dict[str, Any] ) -> None: - """Verify DurableAgentTask converts successful entity result to AgentRunResponse.""" + """Verify DurableAgentTask converts successful entity result to AgentResponse.""" mock_entity_task = configure_successful_entity_task(successful_agent_response) task = DurableAgentTask(entity_task=mock_entity_task, response_format=None, correlation_id="test-123") @@ -390,7 +390,7 @@ def test_durable_agent_task_transforms_successful_result( assert task.is_complete result = task.get_result() - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert len(result.messages) == 1 assert result.messages[0].role 
== Role.ASSISTANT @@ -427,7 +427,7 @@ class TestResponse(BaseModel): assert task.is_complete result = task.get_result() - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) def test_durable_agent_task_ignores_duplicate_completion( self, configure_successful_entity_task: Any, successful_agent_response: dict[str, Any] @@ -450,7 +450,7 @@ def test_durable_agent_task_ignores_duplicate_completion( def test_durable_agent_task_fails_on_malformed_response(self, configure_successful_entity_task: Any) -> None: """Verify DurableAgentTask fails when entity returns malformed response data.""" - # Use data that will cause AgentRunResponse.from_dict to fail + # Use data that will cause AgentResponse.from_dict to fail # Using a list instead of dict, or other invalid structure mock_entity_task = configure_successful_entity_task("invalid string response") @@ -496,7 +496,7 @@ def test_durable_agent_task_handles_empty_response(self, configure_successful_en assert task.is_complete result = task.get_result() - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert len(result.messages) == 0 def test_durable_agent_task_handles_multiple_messages(self, configure_successful_entity_task: Any) -> None: @@ -517,7 +517,7 @@ def test_durable_agent_task_handles_multiple_messages(self, configure_successful assert task.is_complete result = task.get_result() - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) assert len(result.messages) == 2 assert result.messages[0].role == Role.ASSISTANT assert result.messages[1].role == Role.ASSISTANT @@ -564,7 +564,7 @@ class ComplexResponse(BaseModel): assert task.is_complete assert not task.is_failed result = task.get_result() - assert isinstance(result, AgentRunResponse) + assert isinstance(result, AgentResponse) if __name__ == "__main__": diff --git a/python/packages/durabletask/tests/test_models.py b/python/packages/durabletask/tests/test_models.py index 5a93d74e22..0f6a24293d 100644 --- a/python/packages/durabletask/tests/test_models.py +++ b/python/packages/durabletask/tests/test_models.py @@ -183,6 +183,28 @@ def test_round_trip_with_pydantic_response_format(self) -> None: restored = RunRequest.from_dict(data) assert restored.response_format is ModuleStructuredResponse + def test_round_trip_with_options(self) -> None: + """Ensure options are preserved and response_format is deserialized.""" + original = RunRequest( + message="Test", + correlation_id="corr-opts-1", + response_format=ModuleStructuredResponse, + enable_tool_calls=False, + options={ + "response_format": ModuleStructuredResponse, + "enable_tool_calls": False, + "custom": "value", + }, + ) + + data = original.to_dict() + assert data["options"]["custom"] == "value" + + restored = RunRequest.from_dict(data) + assert restored.options is not None + assert restored.options["custom"] == "value" + assert restored.options["response_format"] is ModuleStructuredResponse + def test_init_with_correlationId(self) -> None: """Test RunRequest initialization with correlationId.""" request = RunRequest(message="Test message", correlation_id="corr-123") diff --git a/python/packages/durabletask/tests/test_shim.py b/python/packages/durabletask/tests/test_shim.py index 4172de6403..26988edca4 100644 --- a/python/packages/durabletask/tests/test_shim.py +++ b/python/packages/durabletask/tests/test_shim.py @@ -35,18 +35,21 @@ def mock_executor() -> Mock: # Mock get_run_request to create actual RunRequest objects def 
create_run_request( message: str, - response_format: type[BaseModel] | None = None, - enable_tool_calls: bool = True, - wait_for_response: bool = True, + options: dict[str, Any] | None = None, ) -> RunRequest: import uuid + opts = dict(options) if options else {} + response_format = opts.pop("response_format", None) + enable_tool_calls = opts.pop("enable_tool_calls", True) + wait_for_response = opts.pop("wait_for_response", True) return RunRequest( message=message, correlation_id=str(uuid.uuid4()), response_format=response_format, enable_tool_calls=enable_tool_calls, wait_for_response=wait_for_response, + options=opts, ) mock.get_run_request = Mock(side_effect=create_run_request) @@ -132,7 +135,7 @@ def test_run_forwards_thread_parameter(self, test_agent: DurableAIAgent[Any], mo def test_run_forwards_response_format(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify run forwards response_format parameter to executor.""" - test_agent.run("message", response_format=ResponseFormatModel) + test_agent.run("message", options={"response_format": ResponseFormatModel}) mock_executor.run_durable_agent.assert_called_once() _, kwargs = mock_executor.run_durable_agent.call_args diff --git a/python/packages/foundry_local/agent_framework_foundry_local/__init__.py b/python/packages/foundry_local/agent_framework_foundry_local/__init__.py index dbea932348..d271839cd7 100644 --- a/python/packages/foundry_local/agent_framework_foundry_local/__init__.py +++ b/python/packages/foundry_local/agent_framework_foundry_local/__init__.py @@ -2,7 +2,7 @@ import importlib.metadata -from ._foundry_local_client import FoundryLocalClient +from ._foundry_local_client import FoundryLocalChatOptions, FoundryLocalClient, FoundryLocalSettings try: __version__ = importlib.metadata.version(__name__) @@ -10,6 +10,8 @@ __version__ = "0.0.0" # Fallback for development mode __all__ = [ + "FoundryLocalChatOptions", "FoundryLocalClient", + "FoundryLocalSettings", "__version__", ] diff --git a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py index c2b7bd34ab..1cbfde6f38 100644 --- a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py +++ b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Any, ClassVar +import sys +from typing import Any, ClassVar, Generic, TypedDict -from agent_framework import use_chat_middleware, use_function_invocation +from agent_framework import ChatOptions, use_chat_middleware, use_function_invocation from agent_framework._pydantic import AFBaseSettings from agent_framework.exceptions import ServiceInitializationError from agent_framework.observability import use_instrumentation @@ -11,11 +12,93 @@ from foundry_local.models import DeviceType from openai import AsyncOpenAI +if sys.version_info >= (3, 13): + from typing import TypeVar # type: ignore # pragma: no cover +else: + from typing_extensions import TypeVar # type: ignore # pragma: no cover + + __all__ = [ + "FoundryLocalChatOptions", "FoundryLocalClient", + "FoundryLocalSettings", ] +# region Foundry Local Chat Options TypedDict + + +class FoundryLocalChatOptions(ChatOptions, total=False): + """Azure Foundry Local (local model deployment) chat options dict. + + Extends base ChatOptions for local model inference via Foundry Local. 
+ Foundry Local provides an OpenAI-compatible API, so most standard + OpenAI chat completion options are supported. + + See: https://github.com/Azure/azure-ai-foundry-model-inference + + Keys: + # Inherited from ChatOptions (supported via OpenAI-compatible API): + model_id: The model identifier or alias (e.g., 'phi-4-mini'). + temperature: Sampling temperature (0-2). + top_p: Nucleus sampling parameter. + max_tokens: Maximum tokens to generate. + stop: Stop sequences. + tools: List of tools available to the model. + tool_choice: How the model should use tools. + frequency_penalty: Frequency penalty (-2.0 to 2.0). + presence_penalty: Presence penalty (-2.0 to 2.0). + seed: Random seed for reproducibility. + + # Options with limited support (depends on the model): + response_format: Response format specification. + Not all local models support JSON mode. + logit_bias: Token bias dictionary. + May not be supported by all models. + + # Options not supported in Foundry Local: + user: Not used locally. + store: Not applicable for local inference. + metadata: Not applicable for local inference. + + # Foundry Local-specific options: + extra_body: Additional request body parameters to pass to the model. + Can be used for model-specific options not covered by standard API. + + Note: + The actual options supported depend on the specific model being used. + Some models (like Phi-4) may not support all OpenAI API features. + Options not supported by the model will typically be ignored. + """ + + # Foundry Local-specific options + extra_body: dict[str, Any] + """Additional request body parameters for model-specific options.""" + + # ChatOptions fields not applicable for local inference + user: None # type: ignore[misc] + """Not used for local model inference.""" + + store: None # type: ignore[misc] + """Not applicable for local inference.""" + + +FOUNDRY_LOCAL_OPTION_TRANSLATIONS: dict[str, str] = { + "model_id": "model", +} +"""Maps ChatOptions keys to OpenAI API parameter names (for compatibility).""" + +TFoundryLocalChatOptions = TypeVar( + "TFoundryLocalChatOptions", + bound=TypedDict, # type: ignore[valid-type] + default="FoundryLocalChatOptions", + covariant=True, +) + + +# endregion + + class FoundryLocalSettings(AFBaseSettings): """Foundry local model settings. @@ -40,7 +123,7 @@ class FoundryLocalSettings(AFBaseSettings): @use_function_invocation @use_instrumentation @use_chat_middleware -class FoundryLocalClient(OpenAIBaseChatClient): +class FoundryLocalClient(OpenAIBaseChatClient[TFoundryLocalChatOptions], Generic[TFoundryLocalChatOptions]): """Foundry Local Chat completion class.""" def __init__( @@ -125,6 +208,16 @@ def __init__( # You can also use the CLI: `foundry model load phi-4-mini --device Auto` + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework_foundry_local import FoundryLocalChatOptions + + class MyOptions(FoundryLocalChatOptions, total=False): + my_custom_option: str + + client: FoundryLocalClient[MyOptions] = FoundryLocalClient(model_id="phi-4-mini") + response = await client.get_response("Hello", options={"my_custom_option": "value"}) + Raises: ServiceInitializationError: If the specified model ID or alias is not found. 
Sometimes a model might be available but if you have specified a device diff --git a/python/packages/foundry_local/pyproject.toml b/python/packages/foundry_local/pyproject.toml index 15e61efb56..5c1a39f5f1 100644 --- a/python/packages/foundry_local/pyproject.toml +++ b/python/packages/foundry_local/pyproject.toml @@ -4,7 +4,7 @@ description = "Foundry Local integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/lab/lightning/samples/train_math_agent.py b/python/packages/lab/lightning/samples/train_math_agent.py index 7cb4947e88..2c6937446e 100644 --- a/python/packages/lab/lightning/samples/train_math_agent.py +++ b/python/packages/lab/lightning/samples/train_math_agent.py @@ -18,7 +18,7 @@ from typing import TypedDict, cast import sympy # type: ignore[import-untyped,reportMissingImports] -from agent_framework import AgentRunResponse, ChatAgent, MCPStdioTool +from agent_framework import AgentResponse, ChatAgent, MCPStdioTool from agent_framework.lab.lightning import AgentFrameworkTracer from agent_framework.openai import OpenAIChatClient from agentlightning import LLM, Dataset, Trainer, rollout @@ -102,7 +102,7 @@ def _is_result_correct(prediction: str, ground_truth: str) -> float: return float(_scalar_are_results_same(prediction, ground_truth, 1e-2)) -def evaluate(result: AgentRunResponse, ground_truth: str) -> float: +def evaluate(result: AgentResponse, ground_truth: str) -> float: """Main evaluation function that extracts the agent's answer and compares with ground truth. 
This function: diff --git a/python/packages/lab/pyproject.toml b/python/packages/lab/pyproject.toml index f081fed620..bffeeb4f0a 100644 --- a/python/packages/lab/pyproject.toml +++ b/python/packages/lab/pyproject.toml @@ -4,7 +4,7 @@ description = "Experimental modules for Microsoft Agent Framework" authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py index ae01082ca3..b514ebd60c 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py @@ -3,18 +3,20 @@ import uuid from typing import cast -from agent_framework._agents import ChatAgent -from agent_framework._types import AgentRunResponse, ChatMessage, Role -from agent_framework._workflows import ( +from agent_framework import ( AgentExecutor, AgentExecutorRequest, AgentExecutorResponse, + AgentResponse, + ChatAgent, + ChatClientProtocol, + ChatMessage, FunctionExecutor, + Role, Workflow, WorkflowBuilder, WorkflowContext, ) -from agent_framework.openai import OpenAIChatClient from loguru import logger from tau2.data_model.simulation import SimulationRun, TerminationReason # type: ignore[import-untyped] from tau2.data_model.tasks import Task # type: ignore[import-untyped] @@ -122,7 +124,7 @@ def should_not_stop(self, response: AgentExecutorResponse) -> bool: f"{'assistant' if is_from_agent else 'user'}, " f"routing to {'user' if is_from_agent else 'assistant'}:" ) - log_messages(response.agent_run_response.messages) + log_messages(response.agent_response.messages) if self.step_count >= self.max_steps: logger.info(f"Max steps ({self.max_steps}) reached - terminating conversation") @@ -130,7 +132,7 @@ def should_not_stop(self, response: AgentExecutorResponse) -> bool: # Terminate the workflow return False - response_text = response.agent_run_response.text + response_text = response.agent_response.text if is_from_agent and self._is_agent_stop(response_text): logger.info("Agent requested stop - terminating conversation") self.termination_reason = TerminationReason.AGENT_STOP @@ -142,7 +144,7 @@ def should_not_stop(self, response: AgentExecutorResponse) -> bool: # The final user message won't appear in the assistant's message store, # because it will never arrive there. # We need to store it because it's needed for evaluation. - self._final_user_message = flip_messages(response.agent_run_response.messages) + self._final_user_message = flip_messages(response.agent_response.messages) return False return True @@ -156,7 +158,7 @@ def _is_user_stop(self, text: str) -> bool: """Check if user wants to stop the conversation.""" return STOP in text or TRANSFER in text or OUT_OF_SCOPE in text - def assistant_agent(self, assistant_chat_client: OpenAIChatClient) -> ChatAgent: + def assistant_agent(self, assistant_chat_client: ChatClientProtocol) -> ChatAgent: """Create an assistant agent. Users can override this method to provide a custom assistant agent. 
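Reviewer note on the rename threaded through this file: every `agent_run_response` attribute access becomes `agent_response`, and the payload type `AgentRunResponse` becomes `AgentResponse`. A minimal sketch of the renamed surface, mirroring the `initial_greeting` construction later in this file's diff (the greeting text and executor id here are illustrative, not from the sample):

```python
from agent_framework import AgentExecutorResponse, AgentResponse, ChatMessage, Role

# Build an executor response the way runner.py seeds its first assistant turn.
first_message = ChatMessage(Role.ASSISTANT, text="Hi! How can I help you today?")  # illustrative text
response = AgentExecutorResponse(
    executor_id="assistant_agent",  # illustrative id
    agent_response=AgentResponse(messages=[first_message]),  # was: agent_run_response=AgentRunResponse(...)
    full_conversation=[first_message],
)

# Downstream checks such as should_not_stop() now read `.agent_response`:
assert response.agent_response.text == first_message.text
```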
@@ -205,7 +207,7 @@ def assistant_agent(self, assistant_chat_client: OpenAIChatClient) -> ChatAgent: ), ) - def user_simulator(self, user_simuator_chat_client: OpenAIChatClient, task: Task) -> ChatAgent: + def user_simulator(self, user_simuator_chat_client: ChatClientProtocol, task: Task) -> ChatAgent: """Create a user simulator agent. Users can override this method to provide a custom user simulator agent. @@ -253,7 +255,7 @@ async def conversation_orchestrator( """ # Flip message roles for proper conversation flow # Assistant messages become user messages and vice versa - flipped = flip_messages(response.agent_run_response.messages) + flipped = flip_messages(response.agent_response.messages) # Determine source to route to correct target is_from_agent = response.executor_id == ASSISTANT_AGENT_ID @@ -301,8 +303,8 @@ def build_conversation_workflow(self, assistant_agent: ChatAgent, user_simulator async def run( self, task: Task, - assistant_chat_client: OpenAIChatClient, - user_simuator_chat_client: OpenAIChatClient, + assistant_chat_client: ChatClientProtocol, + user_simulator_chat_client: ChatClientProtocol, ) -> list[ChatMessage]: """Run a tau2 task using workflow-based agent orchestration. @@ -317,18 +319,18 @@ async def run( Args: task: Tau2 task containing scenario, policy, and evaluation criteria assistant_chat_client: LLM client for the assistant agent - user_simuator_chat_client: LLM client for the user simulator + user_simulator_chat_client: LLM client for the user simulator Returns: Complete conversation history as ChatMessage list for evaluation """ logger.info(f"Starting workflow agent for task {task.id}: {task.description.purpose}") # type: ignore[unused-ignore] logger.info(f"Assistant chat client: {assistant_chat_client}") - logger.info(f"User simulator chat client: {user_simuator_chat_client}") + logger.info(f"User simulator chat client: {user_simulator_chat_client}") # STEP 1: Create agents assistant_agent = self.assistant_agent(assistant_chat_client) - user_simulator_agent = self.user_simulator(user_simuator_chat_client, task) + user_simulator_agent = self.user_simulator(user_simulator_chat_client, task) # STEP 2: Create the conversation workflow workflow = self.build_conversation_workflow(assistant_agent, user_simulator_agent) @@ -340,7 +342,7 @@ async def run( first_message = ChatMessage(Role.ASSISTANT, text=DEFAULT_FIRST_AGENT_MESSAGE) initial_greeting = AgentExecutorResponse( executor_id=ASSISTANT_AGENT_ID, - agent_run_response=AgentRunResponse(messages=[first_message]), + agent_response=AgentResponse(messages=[first_message]), full_conversation=[ChatMessage(Role.ASSISTANT, text=DEFAULT_FIRST_AGENT_MESSAGE)], ) diff --git a/python/packages/mem0/agent_framework_mem0/_provider.py b/python/packages/mem0/agent_framework_mem0/_provider.py index 48e508f411..e34c2cf435 100644 --- a/python/packages/mem0/agent_framework_mem0/_provider.py +++ b/python/packages/mem0/agent_framework_mem0/_provider.py @@ -3,22 +3,22 @@ import sys from collections.abc import MutableSequence, Sequence from contextlib import AbstractAsyncContextManager -from typing import Any +from typing import Any, TypedDict from agent_framework import ChatMessage, Context, ContextProvider from agent_framework.exceptions import ServiceInitializationError from mem0 import AsyncMemory, AsyncMemoryClient -if sys.version_info >= (3, 11): - from typing import NotRequired, Self, TypedDict # pragma: no cover -else: - from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover - if sys.version_info >= (3, 
12): from typing import override # type: ignore # pragma: no cover else: from typing_extensions import override # type: ignore[import] # pragma: no cover +if sys.version_info >= (3, 11): + from typing import NotRequired, Self # pragma: no cover +else: + from typing_extensions import NotRequired, Self # pragma: no cover + # Type aliases for Mem0 search response formats (v1.1 and v2; v1 is deprecated, but matches the type definition for v2) class MemorySearchResponse_v1_1(TypedDict): diff --git a/python/packages/mem0/pyproject.toml b/python/packages/mem0/pyproject.toml index bc7d8330f4..4dd7b7940d 100644 --- a/python/packages/mem0/pyproject.toml +++ b/python/packages/mem0/pyproject.toml @@ -4,7 +4,7 @@ description = "Mem0 integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/ollama/agent_framework_ollama/__init__.py b/python/packages/ollama/agent_framework_ollama/__init__.py index 969b607623..d1bd699e1a 100644 --- a/python/packages/ollama/agent_framework_ollama/__init__.py +++ b/python/packages/ollama/agent_framework_ollama/__init__.py @@ -2,7 +2,7 @@ import importlib.metadata -from ._chat_client import OllamaChatClient, OllamaSettings +from ._chat_client import OllamaChatClient, OllamaChatOptions, OllamaSettings try: __version__ = importlib.metadata.version(__name__) @@ -11,6 +11,7 @@ __all__ = [ "OllamaChatClient", + "OllamaChatOptions", "OllamaSettings", "__version__", ] diff --git a/python/packages/ollama/agent_framework_ollama/_chat_client.py b/python/packages/ollama/agent_framework_ollama/_chat_client.py index f047a5d4b3..825ee47bec 100644 --- a/python/packages/ollama/agent_framework_ollama/_chat_client.py +++ b/python/packages/ollama/agent_framework_ollama/_chat_client.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import json +import sys from collections.abc import ( AsyncIterable, Callable, @@ -10,7 +11,7 @@ Sequence, ) from itertools import chain -from typing import Any, ClassVar +from typing import Any, ClassVar, Generic, TypedDict from agent_framework import ( AIFunction, @@ -46,6 +47,229 @@ from ollama._types import Message as OllamaMessage from pydantic import ValidationError +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + + +__all__ = ["OllamaChatClient", "OllamaChatOptions"] + + +# region Ollama Chat Options TypedDict + + +class OllamaChatOptions(ChatOptions, total=False): + """Ollama-specific chat options dict. + + Extends base ChatOptions with Ollama-specific parameters. + Ollama passes model parameters through the `options` field. + + See: https://github.com/ollama/ollama/blob/main/docs/api.md + + Keys: + # Inherited from ChatOptions (mapped to Ollama options): + model_id: The model name, translates to ``model`` in Ollama API. + temperature: Sampling temperature, translates to ``options.temperature``. + top_p: Nucleus sampling, translates to ``options.top_p``. + max_tokens: Maximum tokens to generate, translates to ``options.num_predict``. 
+ stop: Stop sequences, translates to ``options.stop``. + seed: Random seed for reproducibility, translates to ``options.seed``. + frequency_penalty: Frequency penalty, translates to ``options.frequency_penalty``. + presence_penalty: Presence penalty, translates to ``options.presence_penalty``. + tools: List of function tools. + response_format: Output format, translates to ``format``. + Use 'json' for JSON mode or a JSON schema dict for structured output. + + # Options not supported in Ollama: + tool_choice: Ollama only supports auto tool choice. + allow_multiple_tool_calls: Not configurable. + user: Not supported. + store: Not supported. + logit_bias: Not supported. + metadata: Not supported. + + # Ollama model-level options (placed in `options` dict): + # See: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx#valid-parameters-and-values + num_predict: Maximum number of tokens to predict (alternative to max_tokens). + top_k: Top-k sampling: limits tokens to k most likely. Higher = more diverse. + min_p: Minimum probability threshold for token selection. + typical_p: Locally typical sampling parameter (0.0-1.0). + repeat_penalty: Penalty for repeating tokens. Higher = less repetition. + repeat_last_n: Number of tokens to consider for repeat penalty. + penalize_newline: Whether to penalize newline characters. + num_ctx: Context window size (number of tokens). + num_batch: Batch size for prompt processing. + num_keep: Number of tokens to keep from initial prompt. + num_gpu: Number of layers to offload to GPU. + main_gpu: Main GPU for computation. + use_mmap: Whether to use memory-mapped files. + num_thread: Number of threads for CPU computation. + numa: Enable NUMA optimization. + + # Ollama-specific top-level options: + keep_alive: How long to keep model loaded (default: '5m'). + think: Whether thinking models should think before responding. + + Examples: + .. code-block:: python + + from agent_framework_ollama import OllamaChatOptions + + # Basic usage - standard options automatically mapped + options: OllamaChatOptions = { + "temperature": 0.7, + "max_tokens": 1000, + "seed": 42, + } + + # With Ollama-specific model options + options: OllamaChatOptions = { + "top_k": 40, + "num_ctx": 4096, + "keep_alive": "10m", + } + + # With JSON output format + options: OllamaChatOptions = { + "response_format": "json", + } + + # With structured output (JSON schema) + options: OllamaChatOptions = { + "response_format": { + "type": "object", + "properties": {"answer": {"type": "string"}}, + "required": ["answer"], + }, + } + """ + + # Ollama model-level options (will be placed in `options` dict) + num_predict: int + """Maximum number of tokens to predict (equivalent to max_tokens).""" + + top_k: int + """Top-k sampling: limits tokens to k most likely. Higher = more diverse.""" + + min_p: float + """Minimum probability threshold for token selection.""" + + typical_p: float + """Locally typical sampling parameter (0.0-1.0).""" + + repeat_penalty: float + """Penalty for repeating tokens. 
Higher = less repetition.""" + + repeat_last_n: int + """Number of tokens to consider for repeat penalty.""" + + penalize_newline: bool + """Whether to penalize newline characters.""" + + num_ctx: int + """Context window size (number of tokens).""" + + num_batch: int + """Batch size for prompt processing.""" + + num_keep: int + """Number of tokens to keep from initial prompt.""" + + num_gpu: int + """Number of layers to offload to GPU.""" + + main_gpu: int + """Main GPU for computation.""" + + use_mmap: bool + """Whether to use memory-mapped files.""" + + num_thread: int + """Number of threads for CPU computation.""" + + numa: bool + """Enable NUMA optimization.""" + + # Ollama-specific top-level options + keep_alive: str | int + """How long to keep the model loaded in memory after request. + Can be duration string (e.g., '5m', '1h') or seconds as int. + Set to 0 to unload immediately after request.""" + + think: bool + """For thinking models: whether the model should think before responding.""" + + # ChatOptions fields not supported in Ollama + tool_choice: None # type: ignore[misc] + """Not supported. Ollama only supports auto tool choice.""" + + allow_multiple_tool_calls: None # type: ignore[misc] + """Not supported. Not configurable in Ollama.""" + + user: None # type: ignore[misc] + """Not supported in Ollama.""" + + store: None # type: ignore[misc] + """Not supported in Ollama.""" + + logit_bias: None # type: ignore[misc] + """Not supported in Ollama.""" + + metadata: None # type: ignore[misc] + """Not supported in Ollama.""" + + +OLLAMA_OPTION_TRANSLATIONS: dict[str, str] = { + "model_id": "model", + "response_format": "format", +} +"""Maps ChatOptions keys to Ollama API parameter names.""" + +# Keys that should be placed in the nested `options` dict for the Ollama API +OLLAMA_MODEL_OPTIONS: set[str] = { + # From ChatOptions (mapped to options.*) + "temperature", + "top_p", + "max_tokens", # -> num_predict + "stop", + "seed", + "frequency_penalty", + "presence_penalty", + # Ollama-specific model options + "num_predict", + "top_k", + "min_p", + "typical_p", + "repeat_penalty", + "repeat_last_n", + "penalize_newline", + "num_ctx", + "num_batch", + "num_keep", + "num_gpu", + "main_gpu", + "use_mmap", + "num_thread", + "numa", +} + +# Translations for options that go into the nested `options` dict +OLLAMA_MODEL_OPTION_TRANSLATIONS: dict[str, str] = { + "max_tokens": "num_predict", +} +"""Maps ChatOptions keys to Ollama model option parameter names.""" + +TOllamaChatOptions = TypeVar("TOllamaChatOptions", bound=TypedDict, default="OllamaChatOptions", covariant=True) # type: ignore[valid-type] + + +# endregion + class OllamaSettings(AFBaseSettings): """Ollama settings.""" @@ -62,7 +286,7 @@ class OllamaSettings(AFBaseSettings): @use_function_invocation @use_instrumentation @use_chat_middleware -class OllamaChatClient(BaseChatClient): +class OllamaChatClient(BaseChatClient[TOllamaChatOptions], Generic[TOllamaChatOptions]): """Ollama Chat completion class.""" OTEL_PROVIDER_NAME: ClassVar[str] = "ollama" @@ -110,15 +334,16 @@ def __init__( super().__init__(**kwargs) + @override async def _inner_get_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> ChatResponse: # prepare - options_dict = self._prepare_options(messages, chat_options) + options_dict = self._prepare_options(messages, options) try: # execute @@ -133,15 +358,16 @@ async def _inner_get_response( # process return 
self._parse_response_from_ollama(response) + @override async def _inner_get_streaming_response( self, *, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions, + options: dict[str, Any], **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: # prepare - options_dict = self._prepare_options(messages, chat_options) + options_dict = self._prepare_options(messages, options) try: # execute @@ -157,19 +383,37 @@ async def _inner_get_streaming_response( async for part in response_object: yield self._parse_streaming_response_from_ollama(part) - def _prepare_options(self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions) -> dict[str, Any]: - # tool choice - Currently Ollama only supports auto tool choice - if chat_options.tool_choice == "required": - raise ServiceInvalidRequestError("Ollama does not support required tool choice.") - - run_options = chat_options.to_dict( - exclude={ - "type", - "instructions", - "tool_choice", # Ollama does not support tool_choice configuration - "additional_properties", # handled separately - } - ) + def _prepare_options(self, messages: MutableSequence[ChatMessage], options: dict[str, Any]) -> dict[str, Any]: + # Handle instructions by prepending to messages as system message + instructions = options.get("instructions") + if instructions: + from agent_framework._types import prepend_instructions_to_messages + + messages = prepend_instructions_to_messages(list(messages), instructions, role="system") + + # Keys to exclude from processing + exclude_keys = {"instructions", "tool_choice"} + + # Build run_options and model_options separately + run_options: dict[str, Any] = {} + model_options: dict[str, Any] = {} + + for key, value in options.items(): + if key in exclude_keys or value is None: + continue + + if key in OLLAMA_MODEL_OPTIONS: + # Apply model option translations (e.g., max_tokens -> num_predict) + translated_key = OLLAMA_MODEL_OPTION_TRANSLATIONS.get(key, key) + model_options[translated_key] = value + else: + # Apply top-level translations (e.g., model_id -> model) + translated_key = OLLAMA_OPTION_TRANSLATIONS.get(key, key) + run_options[translated_key] = value + + # Add model options to run_options if any + if model_options: + run_options["options"] = model_options # messages if messages and "messages" not in run_options: @@ -177,12 +421,6 @@ def _prepare_options(self, messages: MutableSequence[ChatMessage], chat_options: if "messages" not in run_options: raise ServiceInvalidRequestError("Messages are required for chat completions") - # translations between ChatOptions and Ollama API - translations = {"model_id": "model"} - for old_key, new_key in translations.items(): - if old_key in run_options and old_key != new_key: - run_options[new_key] = run_options.pop(old_key) - # model id if not run_options.get("model"): if not self.model_id: @@ -190,15 +428,9 @@ def _prepare_options(self, messages: MutableSequence[ChatMessage], chat_options: run_options["model"] = self.model_id # tools - if chat_options.tools and (tools := self._prepare_tools_for_ollama(chat_options.tools)): - run_options["tools"] = tools - - # additional properties - additional_options = { - key: value for key, value in chat_options.additional_properties.items() if value is not None - } - if additional_options: - run_options.update(additional_options) + tools = options.get("tools") + if tools and (prepared_tools := self._prepare_tools_for_ollama(tools)): + run_options["tools"] = prepared_tools return run_options diff --git a/python/packages/ollama/pyproject.toml 
b/python/packages/ollama/pyproject.toml index 76cd25db74..f232346450 100644 --- a/python/packages/ollama/pyproject.toml +++ b/python/packages/ollama/pyproject.toml @@ -4,7 +4,7 @@ description = "Ollama integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://learn.microsoft.com/en-us/agent-framework/" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/ollama/tests/test_ollama_chat_client.py b/python/packages/ollama/tests/test_ollama_chat_client.py index fbb88695c1..e2aebb2a6a 100644 --- a/python/packages/ollama/tests/test_ollama_chat_client.py +++ b/python/packages/ollama/tests/test_ollama_chat_client.py @@ -16,6 +16,7 @@ TextContent, TextReasoningContent, UriContent, + ai_function, chat_middleware, ) from agent_framework.exceptions import ( @@ -113,6 +114,7 @@ def mock_chat_completion_tool_call() -> OllamaChatResponse: ) +@ai_function def hello_world(arg1: str) -> str: return "Hello World" @@ -199,19 +201,6 @@ async def test_empty_messages() -> None: await ollama_chat_client.get_response(messages=[]) -async def test_function_choice_required_argument() -> None: - ollama_chat_client = OllamaChatClient( - host="http://localhost:12345", - model_id="test-model", - ) - with pytest.raises(ServiceInvalidRequestError): - await ollama_chat_client.get_response( - messages=[ChatMessage(text="hello world", role="user")], - tool_choice="required", - tools=[hello_world], - ) - - @patch.object(AsyncClient, "chat", new_callable=AsyncMock) async def test_cmc( mock_chat: AsyncMock, @@ -337,7 +326,7 @@ async def test_cmc_streaming_with_tool_call( chat_history.append(ChatMessage(text="hello world", role="user")) ollama_client = OllamaChatClient() - result = ollama_client.get_streaming_response(messages=chat_history, tools=[hello_world]) + result = ollama_client.get_streaming_response(messages=chat_history, options={"tools": [hello_world]}) chunks: list[ChatResponseUpdate] = [] async for chunk in result: @@ -373,7 +362,9 @@ async def test_cmc_with_hosted_tool_call( ollama_client = OllamaChatClient() await ollama_client.get_response( messages=chat_history, - tools=[HostedWebSearchTool(additional_properties=additional_properties)], + options={ + "tools": HostedWebSearchTool(additional_properties=additional_properties), + }, ) @@ -450,7 +441,7 @@ async def test_cmc_integration_with_tool_call( chat_history.append(ChatMessage(text="Call the hello world function and repeat what it says", role="user")) ollama_client = OllamaChatClient() - result = await ollama_client.get_response(messages=chat_history, tools=[hello_world]) + result = await ollama_client.get_response(messages=chat_history, options={"tools": [hello_world]}) assert "hello" in result.text.lower() and "world" in result.text.lower() assert isinstance(result.messages[-2].contents[0], FunctionResultContent) @@ -478,7 +469,7 @@ async def test_cmc_streaming_integration_with_tool_call( ollama_client = OllamaChatClient() result: AsyncIterable[ChatResponseUpdate] = ollama_client.get_streaming_response( - messages=chat_history, tools=[hello_world] + messages=chat_history, options={"tools": [hello_world]} ) chunks: list[ChatResponseUpdate] = [] diff --git a/python/packages/purview/README.md b/python/packages/purview/README.md index 4e6690bf31..ad3d1867d0 100644 --- a/python/packages/purview/README.md +++ 
b/python/packages/purview/README.md @@ -135,11 +135,11 @@ class MyCustomCache(CacheProvider): async def get(self, key: str) -> Any | None: # Your implementation pass - + async def set(self, key: str, value: Any, ttl_seconds: int | None = None) -> None: # Your implementation pass - + async def remove(self, key: str) -> None: # Your implementation pass @@ -295,9 +295,9 @@ All exceptions inherit from `PurviewServiceError`. You can catch specific except ```python from agent_framework.microsoft import ( PurviewPaymentRequiredError, - PurviewAuthenticationError, + PurviewAuthenticationError, PurviewRateLimitError, - PurviewRequestError, + PurviewRequestError, PurviewServiceError ) @@ -321,5 +321,3 @@ except (PurviewAuthenticationError, PurviewRateLimitError, PurviewRequestError, - **Error Handling**: Use `ignore_exceptions` and `ignore_payment_required` settings for graceful degradation. When enabled, errors are logged but don't fail the request. - **Caching**: Protection scopes responses and 402 errors are cached by default with a 4-hour TTL. Cache is automatically invalidated when protection scope state changes. - **Background Processing**: Content Activities and offline Process Content requests are handled asynchronously using background tasks to avoid blocking the main execution flow. - - diff --git a/python/packages/purview/agent_framework_purview/_middleware.py b/python/packages/purview/agent_framework_purview/_middleware.py index 108cf40410..7839f2f968 100644 --- a/python/packages/purview/agent_framework_purview/_middleware.py +++ b/python/packages/purview/agent_framework_purview/_middleware.py @@ -57,9 +57,9 @@ async def process( context.messages, Activity.UPLOAD_TEXT ) if should_block_prompt: - from agent_framework import AgentRunResponse, ChatMessage, Role + from agent_framework import AgentResponse, ChatMessage, Role - context.result = AgentRunResponse( + context.result = AgentResponse( messages=[ChatMessage(role=Role.SYSTEM, text=self._settings.blocked_prompt_message)] ) context.terminate = True @@ -76,7 +76,7 @@ async def process( await next(context) try: - # Post (response) check only if we have a normal AgentRunResponse + # Post (response) check only if we have a normal AgentResponse # Use the same user_id from the request for the response evaluation if context.result and not context.is_streaming: should_block_response, _ = await self._processor.process_messages( @@ -85,9 +85,9 @@ async def process( user_id=resolved_user_id, ) if should_block_response: - from agent_framework import AgentRunResponse, ChatMessage, Role + from agent_framework import AgentResponse, ChatMessage, Role - context.result = AgentRunResponse( + context.result = AgentResponse( messages=[ChatMessage(role=Role.SYSTEM, text=self._settings.blocked_response_message)] ) else: diff --git a/python/packages/purview/agent_framework_purview/_models.py b/python/packages/purview/agent_framework_purview/_models.py index 0ee502da1a..e4c27496a9 100644 --- a/python/packages/purview/agent_framework_purview/_models.py +++ b/python/packages/purview/agent_framework_purview/_models.py @@ -1,9 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. 
-"""Unified Purview model definitions and public export surface.""" - -from __future__ import annotations - from collections.abc import Mapping, MutableMapping, Sequence from datetime import datetime from enum import Enum, Flag, auto @@ -179,6 +175,8 @@ def translate_activity(activity: Activity) -> ProtectionScopeActivities: # Simple value models # -------------------------------------------------------------------------------------- +TAliasSerializable = TypeVar("TAliasSerializable", bound="_AliasSerializable") + class _AliasSerializable(SerializationMixin): """Base class adding alias mapping + pydantic-compat helpers. @@ -232,7 +230,7 @@ def model_dump_json(self, *, by_alias: bool = True, exclude_none: bool = True, * return json.dumps(self.model_dump(by_alias=by_alias, exclude_none=exclude_none, **kwargs)) @classmethod - def model_validate(cls, value: MutableMapping[str, Any]) -> _AliasSerializable: # type: ignore[name-defined] + def model_validate(cls: type[TAliasSerializable], value: MutableMapping[str, Any]) -> TAliasSerializable: # type: ignore[name-defined] return cls(**value) # ------------------------------------------------------------------ diff --git a/python/packages/purview/pyproject.toml b/python/packages/purview/pyproject.toml index 5b3bfb1e06..f141c5c898 100644 --- a/python/packages/purview/pyproject.toml +++ b/python/packages/purview/pyproject.toml @@ -4,7 +4,7 @@ description = "Microsoft Purview (Graph dataSecurityAndGovernance) integration f authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://github.com/microsoft/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/purview/tests/test_chat_middleware.py b/python/packages/purview/tests/test_chat_middleware.py index 5633488a7e..8d414babb9 100644 --- a/python/packages/purview/tests/test_chat_middleware.py +++ b/python/packages/purview/tests/test_chat_middleware.py @@ -37,7 +37,7 @@ def chat_context(self) -> ChatContext: chat_options = MagicMock() chat_options.model = "test-model" return ChatContext( - chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], chat_options=chat_options + chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], options=chat_options ) async def test_initialization(self, middleware: PurviewChatPolicyMiddleware) -> None: @@ -110,7 +110,7 @@ async def test_streaming_skips_post_check(self, middleware: PurviewChatPolicyMid streaming_context = ChatContext( chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], - chat_options=chat_options, + options=chat_options, is_streaming=True, ) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: @@ -189,7 +189,7 @@ async def test_chat_middleware_handles_payment_required_pre_check(self, mock_cre chat_options = MagicMock() chat_options.model = "test-model" context = ChatContext( - chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], chat_options=chat_options + chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], options=chat_options ) async def mock_process_messages(*args, **kwargs): @@ -215,7 +215,7 @@ async def test_chat_middleware_ignores_payment_required_when_configured(self, mo chat_options = MagicMock() chat_options.model = "test-model" context = 
ChatContext( - chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], chat_options=chat_options + chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], options=chat_options ) async def mock_process_messages(*args, **kwargs): @@ -257,7 +257,7 @@ async def test_chat_middleware_with_ignore_exceptions(self, mock_credential: Asy chat_options = MagicMock() chat_options.model = "test-model" context = ChatContext( - chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], chat_options=chat_options + chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], options=chat_options ) async def mock_process_messages(*args, **kwargs): diff --git a/python/packages/purview/tests/test_middleware.py b/python/packages/purview/tests/test_middleware.py index 0769efab27..9426bc66af 100644 --- a/python/packages/purview/tests/test_middleware.py +++ b/python/packages/purview/tests/test_middleware.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from agent_framework import AgentRunContext, AgentRunResponse, ChatMessage, Role +from agent_framework import AgentResponse, AgentRunContext, ChatMessage, Role from azure.core.credentials import AccessToken from agent_framework_purview import PurviewPolicyMiddleware, PurviewSettings @@ -57,7 +57,7 @@ async def test_middleware_allows_clean_prompt( async def mock_next(ctx: AgentRunContext) -> None: nonlocal next_called next_called = True - ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="I'm good, thanks!")]) + ctx.result = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="I'm good, thanks!")]) await middleware.process(context, mock_next) @@ -104,7 +104,7 @@ async def mock_process_messages(messages, activity, user_id=None): with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): async def mock_next(ctx: AgentRunContext) -> None: - ctx.result = AgentRunResponse( + ctx.result = AgentResponse( messages=[ChatMessage(role=Role.ASSISTANT, text="Here's some sensitive information")] ) @@ -145,7 +145,7 @@ async def test_middleware_processor_receives_correct_activity( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_process: async def mock_next(ctx: AgentRunContext) -> None: - ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")]) + ctx.result = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")]) await middleware.process(context, mock_next) @@ -167,7 +167,7 @@ async def test_middleware_handles_pre_check_exception( ) as mock_process: async def mock_next(ctx: AgentRunContext) -> None: - ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")]) + ctx.result = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")]) await middleware.process(context, mock_next) @@ -199,7 +199,7 @@ async def mock_process_messages(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): async def mock_next(ctx: AgentRunContext) -> None: - ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")]) + ctx.result = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")]) await middleware.process(context, mock_next) @@ -225,7 +225,7 @@ async def mock_process_messages(*args, **kwargs): with patch.object(middleware._processor, 
"process_messages", side_effect=mock_process_messages): async def mock_next(ctx): - ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")]) + ctx.result = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")]) # Should not raise, just log await middleware.process(context, mock_next) diff --git a/python/packages/redis/pyproject.toml b/python/packages/redis/pyproject.toml index 26f55558d3..7f299aa398 100644 --- a/python/packages/redis/pyproject.toml +++ b/python/packages/redis/pyproject.toml @@ -4,7 +4,7 @@ description = "Redis integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/pyproject.toml b/python/pyproject.toml index 9f1ef06968..4436c9696c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ description = "Microsoft Agent Framework for building AI Agents with Python. Thi authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b260107" +version = "1.0.0b260114" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" @@ -23,7 +23,7 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ - "agent-framework-core[all]==1.0.0b260107", + "agent-framework-core[all]==1.0.0b260114", ] [dependency-groups] diff --git a/python/samples/README.md b/python/samples/README.md index 20798e4ad0..6c5b9dc701 100644 --- a/python/samples/README.md +++ b/python/samples/README.md @@ -178,6 +178,7 @@ The recommended way to use Ollama is via the native `OllamaChatClient` from the | File | Description | |------|-------------| | [`getting_started/context_providers/simple_context_provider.py`](./getting_started/context_providers/simple_context_provider.py) | Simple context provider implementation example | +| [`getting_started/context_providers/aggregate_context_provider.py`](./getting_started/context_providers/aggregate_context_provider.py) | Shows how to combine multiple context providers using an AggregateContextProvider | ## DevUI diff --git a/python/samples/demos/chatkit-integration/README.md b/python/samples/demos/chatkit-integration/README.md index 3d47925bdd..688d24aebf 100644 --- a/python/samples/demos/chatkit-integration/README.md +++ b/python/samples/demos/chatkit-integration/README.md @@ -62,7 +62,7 @@ graph TB AttStore -.->|save metadata| SQLite Converter -->|ChatMessage array| Agent - Agent -->|AgentRunResponseUpdate| Streamer + Agent -->|AgentResponseUpdate| Streamer Streamer -->|ThreadStreamEvent| ChatKit ChatKit --> Widgets diff --git a/python/samples/demos/chatkit-integration/app.py b/python/samples/demos/chatkit-integration/app.py index 95d66b78c7..c215b64290 100644 --- a/python/samples/demos/chatkit-integration/app.py +++ b/python/samples/demos/chatkit-integration/app.py @@ -18,7 +18,7 @@ import uvicorn # Agent Framework imports -from agent_framework import AgentRunResponseUpdate, ChatAgent, ChatMessage, FunctionResultContent, Role +from agent_framework import AgentResponseUpdate, ChatAgent, ChatMessage, FunctionResultContent, Role from agent_framework.azure import AzureOpenAIChatClient # Agent Framework 
ChatKit integration @@ -289,8 +289,10 @@ async def _update_thread_title( # Use the chat client directly for a quick, lightweight call response = await self.weather_agent.chat_client.get_response( messages=title_prompt, - temperature=0.3, - max_tokens=20, + options={ + "temperature": 0.3, + "max_tokens": 20, + }, ) if response.messages and response.messages[-1].text: @@ -363,7 +365,7 @@ async def respond( agent_stream = self.weather_agent.run_stream(agent_messages) # Create an intercepting stream that extracts function results while passing through updates - async def intercept_stream() -> AsyncIterator[AgentRunResponseUpdate]: + async def intercept_stream() -> AsyncIterator[AgentResponseUpdate]: nonlocal weather_data, show_city_selector async for update in agent_stream: # Check for function results in the update @@ -460,7 +462,7 @@ async def action( agent_stream = self.weather_agent.run_stream(agent_messages) # Create an intercepting stream that extracts function results while passing through updates - async def intercept_stream() -> AsyncIterator[AgentRunResponseUpdate]: + async def intercept_stream() -> AsyncIterator[AgentResponseUpdate]: nonlocal weather_data async for update in agent_stream: # Check for function results in the update diff --git a/python/samples/demos/hosted_agents/agent_with_text_search_rag/main.py b/python/samples/demos/hosted_agents/agent_with_text_search_rag/main.py index 143a3226b6..2974a88b8b 100644 --- a/python/samples/demos/hosted_agents/agent_with_text_search_rag/main.py +++ b/python/samples/demos/hosted_agents/agent_with_text_search_rag/main.py @@ -99,7 +99,7 @@ def main(): "You are a helpful support specialist for Contoso Outdoors. " "Answer questions using the provided context and cite the source document when available." ), - context_providers=TextSearchContextProvider(), + context_provider=TextSearchContextProvider(), ) # Run the agent as a hosted agent diff --git a/python/samples/demos/m365-agent/m365_agent_demo/app.py b/python/samples/demos/m365-agent/m365_agent_demo/app.py index 7580e5ecce..1870513b32 100644 --- a/python/samples/demos/m365-agent/m365_agent_demo/app.py +++ b/python/samples/demos/m365-agent/m365_agent_demo/app.py @@ -189,7 +189,7 @@ def create_app(config: AppConfig) -> web.Application: web.Application: Fully initialized web application. 
""" middleware_fn = build_anonymous_claims_middleware(config.use_anonymous_mode) - app = web.Application(middlewares=[middleware_fn]) + app = web.Application(middleware=[middleware_fn]) storage = MemoryStorage() agent = build_agent() diff --git a/python/samples/demos/workflow_evaluation/create_workflow.py b/python/samples/demos/workflow_evaluation/create_workflow.py index a150b7274f..6fb4b874c6 100644 --- a/python/samples/demos/workflow_evaluation/create_workflow.py +++ b/python/samples/demos/workflow_evaluation/create_workflow.py @@ -47,7 +47,7 @@ ) from agent_framework import ( AgentExecutorResponse, - AgentRunResponseUpdate, + AgentResponseUpdate, AgentRunUpdateEvent, ChatMessage, Executor, @@ -133,8 +133,8 @@ def _extract_agent_findings(self, responses: list[AgentExecutorResponse]) -> lis for response in responses: findings = [] - if response.agent_run_response and response.agent_run_response.messages: - for msg in response.agent_run_response.messages: + if response.agent_response and response.agent_response.messages: + for msg in response.agent_response.messages: if msg.role == Role.ASSISTANT and msg.text and msg.text.strip(): findings.append(msg.text.strip()) @@ -373,7 +373,7 @@ async def _process_workflow_events(events, conversation_ids, response_ids): def _track_agent_ids(event, agent, response_ids, conversation_ids): """Track agent response and conversation IDs - supporting multiple responses per agent.""" - if isinstance(event.data, AgentRunResponseUpdate): + if isinstance(event.data, AgentResponseUpdate): # Check for conversation_id and response_id from raw_representation # V2 API stores conversation_id directly on raw_representation (ChatResponseUpdate) if hasattr(event.data, "raw_representation") and event.data.raw_representation: diff --git a/python/samples/getting_started/agents/anthropic/anthropic_advanced.py b/python/samples/getting_started/agents/anthropic/anthropic_advanced.py index a7f4ae2656..2b727efa8b 100644 --- a/python/samples/getting_started/agents/anthropic/anthropic_advanced.py +++ b/python/samples/getting_started/agents/anthropic/anthropic_advanced.py @@ -3,7 +3,7 @@ import asyncio from agent_framework import HostedMCPTool, HostedWebSearchTool, TextReasoningContent, UsageContent -from agent_framework.anthropic import AnthropicClient +from agent_framework.anthropic import AnthropicChatOptions, AnthropicClient """ Anthropic Chat Agent Example @@ -15,9 +15,9 @@ """ -async def streaming_example() -> None: +async def main() -> None: """Example of streaming response (get results as they are generated).""" - agent = AnthropicClient().create_agent( + agent = AnthropicClient[AnthropicChatOptions]().create_agent( name="DocsAgent", instructions="You are a helpful agent for both Microsoft docs questions and general questions.", tools=[ @@ -27,10 +27,12 @@ async def streaming_example() -> None: ), HostedWebSearchTool(), ], - # anthropic needs a value for the max_tokens parameter - # we set it to 1024, but you can override like this: - max_tokens=20000, - additional_chat_options={"thinking": {"type": "enabled", "budget_tokens": 10000}}, + default_options={ + # anthropic needs a value for the max_tokens parameter + # we set it to 1024, but you can override like this: + "max_tokens": 20000, + "thinking": {"type": "enabled", "budget_tokens": 10000}, + }, ) query = "Can you compare Python decorators with C# attributes?" 
@@ -48,11 +50,5 @@ async def streaming_example() -> None:
     print("\n")
 
 
-async def main() -> None:
-    print("=== Anthropic Example ===")
-
-    await streaming_example()
-
-
 if __name__ == "__main__":
     asyncio.run(main())
diff --git a/python/samples/getting_started/agents/anthropic/anthropic_foundry.py b/python/samples/getting_started/agents/anthropic/anthropic_foundry.py
index cb1b690d54..2e04dfebaa 100644
--- a/python/samples/getting_started/agents/anthropic/anthropic_foundry.py
+++ b/python/samples/getting_started/agents/anthropic/anthropic_foundry.py
@@ -38,10 +38,12 @@ async def main() -> None:
             ),
             HostedWebSearchTool(),
         ],
-        # anthropic needs a value for the max_tokens parameter
-        # we set it to 1024, but you can override like this:
-        max_tokens=20000,
-        additional_chat_options={"thinking": {"type": "enabled", "budget_tokens": 10000}},
+        default_options={
+            # Anthropic requires a value for the max_tokens parameter.
+            # The framework defaults it to 1024, but you can override it like this:
+            "max_tokens": 20000,
+            "thinking": {"type": "enabled", "budget_tokens": 10000},
+        },
     )
 
     query = "Can you compare Python decorators with C# attributes?"
diff --git a/python/samples/getting_started/agents/anthropic/anthropic_skills.py b/python/samples/getting_started/agents/anthropic/anthropic_skills.py
index 331b6405fb..2624a9742b 100644
--- a/python/samples/getting_started/agents/anthropic/anthropic_skills.py
+++ b/python/samples/getting_started/agents/anthropic/anthropic_skills.py
@@ -5,7 +5,7 @@
 from pathlib import Path
 
 from agent_framework import HostedCodeInterpreterTool, HostedFileContent
-from agent_framework.anthropic import AnthropicClient
+from agent_framework.anthropic import AnthropicChatOptions, AnthropicClient
 
 logger = logging.getLogger(__name__)
 """
@@ -22,7 +22,7 @@
 
 async def main() -> None:
     """Example of streaming response (get results as they are generated)."""
-    client = AnthropicClient(additional_beta_flags=["skills-2025-10-02"])
+    client = AnthropicClient[AnthropicChatOptions](additional_beta_flags=["skills-2025-10-02"])
 
     # List Anthropic-managed Skills
     skills = await client.anthropic_client.beta.skills.list(source="anthropic", betas=["skills-2025-10-02"])
@@ -35,8 +35,8 @@ async def main() -> None:
         name="DocsAgent",
         instructions="You are a helpful agent for creating powerpoint presentations.",
         tools=HostedCodeInterpreterTool(),
-        max_tokens=20000,
-        additional_chat_options={
+        default_options={
+            "max_tokens": 20000,
             "thinking": {"type": "enabled", "budget_tokens": 10000},
             "container": {"skills": [{"type": "anthropic", "skill_id": "pptx", "version": "latest"}]},
         },
diff --git a/python/samples/getting_started/agents/azure_ai/README.md b/python/samples/getting_started/agents/azure_ai/README.md
index 8ed95ad091..f60b64cf18 100644
--- a/python/samples/getting_started/agents/azure_ai/README.md
+++ b/python/samples/getting_started/agents/azure_ai/README.md
@@ -6,8 +6,9 @@ This folder contains examples demonstrating different ways to create and use age
 
 | File | Description |
 |------|-------------|
-| [`azure_ai_basic.py`](azure_ai_basic.py) | The simplest way to create an agent using `AzureAIClient`. Demonstrates both streaming and non-streaming responses with function tools. Shows automatic agent creation and basic weather functionality. |
-| [`azure_ai_use_latest_version.py`](azure_ai_use_latest_version.py) | Demonstrates how to reuse the latest version of an existing agent instead of creating a new agent version on each instantiation using the `use_latest_version=True` parameter.
| +| [`azure_ai_basic.py`](azure_ai_basic.py) | The simplest way to create an agent using `AzureAIProjectAgentProvider`. Demonstrates both streaming and non-streaming responses with function tools. Shows automatic agent creation and basic weather functionality. | +| [`azure_ai_provider_methods.py`](azure_ai_provider_methods.py) | Comprehensive guide to `AzureAIProjectAgentProvider` methods: `create_agent()` for creating new agents, `get_agent()` for retrieving existing agents (by name, reference, or details), and `as_agent()` for wrapping SDK objects without HTTP calls. | +| [`azure_ai_use_latest_version.py`](azure_ai_use_latest_version.py) | Demonstrates how to reuse the latest version of an existing agent instead of creating a new agent version on each instantiation by using `provider.get_agent()` to retrieve the latest version. | | [`azure_ai_with_agent_to_agent.py`](azure_ai_with_agent_to_agent.py) | Shows how to use Agent-to-Agent (A2A) capabilities with Azure AI agents to enable communication with other agents using the A2A protocol. Requires an A2A connection configured in your Azure AI project. | | [`azure_ai_with_azure_ai_search.py`](azure_ai_with_azure_ai_search.py) | Shows how to use Azure AI Search with Azure AI agents to search through indexed data and answer user questions with proper citations. Requires an Azure AI Search connection and index configured in your Azure AI project. | | [`azure_ai_with_bing_grounding.py`](azure_ai_with_bing_grounding.py) | Shows how to use Bing Grounding search with Azure AI agents to search the web for current information and provide grounded responses with citations. Requires a Bing connection configured in your Azure AI project. | @@ -15,6 +16,7 @@ This folder contains examples demonstrating different ways to create and use age | [`azure_ai_with_browser_automation.py`](azure_ai_with_browser_automation.py) | Shows how to use Browser Automation with Azure AI agents to perform automated web browsing tasks and provide responses based on web interactions. Requires a Browser Automation connection configured in your Azure AI project. | | [`azure_ai_with_code_interpreter.py`](azure_ai_with_code_interpreter.py) | Shows how to use the `HostedCodeInterpreterTool` with Azure AI agents to write and execute Python code for mathematical problem solving and data analysis. | | [`azure_ai_with_code_interpreter_file_generation.py`](azure_ai_with_code_interpreter_file_generation.py) | Shows how to retrieve file IDs from code interpreter generated files using both streaming and non-streaming approaches. | +| [`azure_ai_with_code_interpreter_file_download.py`](azure_ai_with_code_interpreter_file_download.py) | Shows how to download files generated by code interpreter using the OpenAI containers API. | | [`azure_ai_with_existing_agent.py`](azure_ai_with_existing_agent.py) | Shows how to work with a pre-existing agent by providing the agent name and version to the Azure AI client. Demonstrates agent reuse patterns for production scenarios. | | [`azure_ai_with_existing_conversation.py`](azure_ai_with_existing_conversation.py) | Demonstrates how to use an existing conversation created on the service side with Azure AI agents. Shows two approaches: specifying conversation ID at the client level and using AgentThread with an existing conversation ID. | | [`azure_ai_with_application_endpoint.py`](azure_ai_with_application_endpoint.py) | Demonstrates calling the Azure AI application-scoped endpoint. 
| diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_basic.py b/python/samples/getting_started/agents/azure_ai/azure_ai_basic.py index 86aa603892..6cf5144bb7 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_basic.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_basic.py @@ -4,14 +4,14 @@ from random import randint from typing import Annotated -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential from pydantic import Field """ Azure AI Agent Basic Example -This sample demonstrates basic usage of AzureAIClient. +This sample demonstrates basic usage of AzureAIProjectAgentProvider. Shows both streaming and non-streaming responses with function tools. """ @@ -28,17 +28,18 @@ async def non_streaming_example() -> None: """Example of non-streaming response (get the complete result at once).""" print("=== Non-streaming Response Example ===") - # Since no Agent ID is provided, the agent will be automatically created. # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="BasicWeatherAgent", instructions="You are a helpful weather agent.", tools=get_weather, - ) as agent, - ): + ) + query = "What's the weather like in Seattle?" print(f"User: {query}") result = await agent.run(query) @@ -49,17 +50,18 @@ async def streaming_example() -> None: """Example of streaming response (get results as they are generated).""" print("=== Streaming Response Example ===") - # Since no Agent ID is provided, the agent will be automatically created. # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="BasicWeatherAgent", instructions="You are a helpful weather agent.", tools=get_weather, - ) as agent, - ): + ) + query = "What's the weather like in Tokyo?" print(f"User: {query}") print("Agent: ", end="", flush=True) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_provider_methods.py b/python/samples/getting_started/agents/azure_ai/azure_ai_provider_methods.py new file mode 100644 index 0000000000..0bf413c5d9 --- /dev/null +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_provider_methods.py @@ -0,0 +1,293 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import os +from random import randint +from typing import Annotated + +from agent_framework.azure import AzureAIProjectAgentProvider +from azure.ai.projects.aio import AIProjectClient +from azure.ai.projects.models import AgentReference, PromptAgentDefinition +from azure.identity.aio import AzureCliCredential +from pydantic import Field + +""" +Azure AI Project Agent Provider Methods Example + +This sample demonstrates the three main methods of AzureAIProjectAgentProvider: +1. create_agent() - Create a new agent on the Azure AI service +2. get_agent() - Retrieve an existing agent from the service +3. 
as_agent() - Wrap an SDK agent version object without making HTTP calls + +It also shows how to use a single provider instance to spawn multiple agents +with different configurations, which is efficient for multi-agent scenarios. + +Each method returns a ChatAgent that can be used for conversations. +""" + + +def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}C." + + +async def create_agent_example() -> None: + """Example of using provider.create_agent() to create a new agent. + + This method creates a new agent version on the Azure AI service and returns + a ChatAgent. Use this when you want to create a fresh agent with + specific configuration. + """ + print("=== provider.create_agent() Example ===") + + async with ( + AzureCliCredential() as credential, + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + # Create a new agent with custom configuration + agent = await provider.create_agent( + name="WeatherAssistant", + instructions="You are a helpful weather assistant. Always be concise.", + description="An agent that provides weather information.", + tools=get_weather, + ) + + print(f"Created agent: {agent.name}") + print(f"Agent ID: {agent.id}") + + query = "What's the weather in Paris?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}\n") + + +async def get_agent_by_name_example() -> None: + """Example of using provider.get_agent(name=...) to retrieve an agent by name. + + This method fetches the latest version of an existing agent from the service. + Use this when you know the agent name and want to use the most recent version. + """ + print("=== provider.get_agent(name=...) Example ===") + + async with ( + AzureCliCredential() as credential, + AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as project_client, + ): + # First, create an agent using the SDK directly + created_agent = await project_client.agents.create_version( + agent_name="TestAgentByName", + description="Test agent for get_agent by name example.", + definition=PromptAgentDefinition( + model=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], + instructions="You are a helpful assistant. End each response with '- Your Assistant'.", + ), + ) + + try: + # Get the agent using the provider by name (fetches latest version) + provider = AzureAIProjectAgentProvider(project_client=project_client) + agent = await provider.get_agent(name=created_agent.name) + + print(f"Retrieved agent: {agent.name}") + + query = "Hello!" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}\n") + finally: + # Clean up the agent + await project_client.agents.delete_version( + agent_name=created_agent.name, agent_version=created_agent.version + ) + + +async def get_agent_by_reference_example() -> None: + """Example of using provider.get_agent(reference=...) to retrieve a specific agent version. + + This method fetches a specific version of an agent using an AgentReference. + Use this when you need to use a particular version of an agent. + """ + print("=== provider.get_agent(reference=...) 
Example ===") + + async with ( + AzureCliCredential() as credential, + AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as project_client, + ): + # First, create an agent using the SDK directly + created_agent = await project_client.agents.create_version( + agent_name="TestAgentByReference", + description="Test agent for get_agent by reference example.", + definition=PromptAgentDefinition( + model=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], + instructions="You are a helpful assistant. Always respond in uppercase.", + ), + ) + + try: + # Get the agent using an AgentReference with specific version + provider = AzureAIProjectAgentProvider(project_client=project_client) + reference = AgentReference(name=created_agent.name, version=created_agent.version) + agent = await provider.get_agent(reference=reference) + + print(f"Retrieved agent: {agent.name} (version via reference)") + + query = "Say hello" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}\n") + finally: + # Clean up the agent + await project_client.agents.delete_version( + agent_name=created_agent.name, agent_version=created_agent.version + ) + + +async def get_agent_by_details_example() -> None: + """Example of using provider.get_agent(details=...) with pre-fetched AgentDetails. + + This method uses pre-fetched AgentDetails to get the latest version. + Use this when you already have AgentDetails from a previous API call. + """ + print("=== provider.get_agent(details=...) Example ===") + + async with ( + AzureCliCredential() as credential, + AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as project_client, + ): + # First, create an agent using the SDK directly + created_agent = await project_client.agents.create_version( + agent_name="TestAgentByDetails", + description="Test agent for get_agent by details example.", + definition=PromptAgentDefinition( + model=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], + instructions="You are a helpful assistant. Always include an emoji in your response.", + ), + ) + + try: + # Fetch AgentDetails separately (simulating a previous API call) + agent_details = await project_client.agents.get(agent_name=created_agent.name) + + # Get the agent using the pre-fetched details (sync - no HTTP call) + provider = AzureAIProjectAgentProvider(project_client=project_client) + agent = provider.as_agent(agent_details.versions.latest) + + print(f"Retrieved agent: {agent.name} (from pre-fetched details)") + + query = "How are you today?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}\n") + finally: + # Clean up the agent + await project_client.agents.delete_version( + agent_name=created_agent.name, agent_version=created_agent.version + ) + + +async def multiple_agents_example() -> None: + """Example of using a single provider to spawn multiple agents. + + A single provider instance can create multiple agents with different + configurations. + """ + print("=== Multiple Agents from Single Provider Example ===") + + async with ( + AzureCliCredential() as credential, + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + # Create multiple specialized agents from the same provider + weather_agent = await provider.create_agent( + name="WeatherExpert", + instructions="You are a weather expert. 
Provide brief weather information.", + tools=get_weather, + ) + + translator_agent = await provider.create_agent( + name="Translator", + instructions="You are a translator. Translate any text to French. Only output the translation.", + ) + + poet_agent = await provider.create_agent( + name="Poet", + instructions="You are a poet. Respond to everything with a short haiku.", + ) + + print(f"Created agents: {weather_agent.name}, {translator_agent.name}, {poet_agent.name}\n") + + # Use each agent for its specialty + weather_query = "What's the weather in London?" + print(f"User to WeatherExpert: {weather_query}") + weather_result = await weather_agent.run(weather_query) + print(f"WeatherExpert: {weather_result}\n") + + translate_query = "Hello, how are you today?" + print(f"User to Translator: {translate_query}") + translate_result = await translator_agent.run(translate_query) + print(f"Translator: {translate_result}\n") + + poet_query = "Tell me about the morning sun" + print(f"User to Poet: {poet_query}") + poet_result = await poet_agent.run(poet_query) + print(f"Poet: {poet_result}\n") + + +async def as_agent_example() -> None: + """Example of using provider.as_agent() to wrap an SDK object without HTTP calls. + + This method wraps an existing AgentVersionDetails into a ChatAgent without + making additional HTTP calls. Use this when you already have the full + AgentVersionDetails from a previous SDK operation. + """ + print("=== provider.as_agent() Example ===") + + async with ( + AzureCliCredential() as credential, + AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as project_client, + ): + # Create an agent using the SDK directly - this returns AgentVersionDetails + agent_version_details = await project_client.agents.create_version( + agent_name="TestAgentAsAgent", + description="Test agent for as_agent example.", + definition=PromptAgentDefinition( + model=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], + instructions="You are a helpful assistant. Keep responses under 20 words.", + ), + ) + + try: + # Wrap the SDK object directly without any HTTP calls + provider = AzureAIProjectAgentProvider(project_client=project_client) + agent = provider.as_agent(agent_version_details) + + print(f"Wrapped agent: {agent.name} (no HTTP call needed)") + print(f"Agent version: {agent_version_details.version}") + + query = "What can you do?" 
+ print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}\n") + finally: + # Clean up the agent + await project_client.agents.delete_version( + agent_name=agent_version_details.name, agent_version=agent_version_details.version + ) + + +async def main() -> None: + print("=== Azure AI Project Agent Provider Methods Example ===\n") + + await create_agent_example() + await get_agent_by_name_example() + await get_agent_by_reference_example() + await get_agent_by_details_example() + await as_agent_example() + await multiple_agents_example() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_use_latest_version.py b/python/samples/getting_started/agents/azure_ai/azure_ai_use_latest_version.py index 1a2a152821..025e78813e 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_use_latest_version.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_use_latest_version.py @@ -4,7 +4,7 @@ from random import randint from typing import Annotated -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential from pydantic import Field @@ -13,7 +13,7 @@ This sample demonstrates how to reuse the latest version of an existing agent instead of creating a new agent version on each instantiation. The first call creates a new agent, -while subsequent calls with `use_latest_version=True` reuse the latest agent version. +while subsequent calls with `get_agent()` reuse the latest agent version. """ @@ -28,39 +28,36 @@ def get_weather( async def main() -> None: # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. - async with AzureCliCredential() as credential: - async with ( - AzureAIClient( - credential=credential, - ).create_agent( - name="MyWeatherAgent", - instructions="You are a helpful weather agent.", - tools=get_weather, - ) as agent, - ): - # First query will create a new agent - query = "What's the weather like in Seattle?" - print(f"User: {query}") - result = await agent.run(query) - print(f"Agent: {result}\n") + async with ( + AzureCliCredential() as credential, + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + # First call creates a new agent + agent = await provider.create_agent( + name="MyWeatherAgent", + instructions="You are a helpful weather agent.", + tools=get_weather, + ) - # Create a new agent instance - async with ( - AzureAIClient( - credential=credential, - # This parameter will allow to re-use latest agent version - # instead of creating a new one - use_latest_version=True, - ).create_agent( - name="MyWeatherAgent", - instructions="You are a helpful weather agent.", - tools=get_weather, - ) as agent, - ): - query = "What's the weather like in Tokyo?" - print(f"User: {query}") - result = await agent.run(query) - print(f"Agent: {result}\n") + query = "What's the weather like in Seattle?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}\n") + + # Second call retrieves the existing agent (latest version) instead of creating a new one + # This is useful when you want to reuse an agent that was created earlier + agent2 = await provider.get_agent( + name="MyWeatherAgent", + tools=get_weather, # Tools must be provided for function tools + ) + + query = "What's the weather like in Tokyo?" 
+    print(f"User: {query}")
+    result = await agent2.run(query)
+    print(f"Agent: {result}\n")
+
+    print(f"First agent ID with version: {agent.id}")
+    print(f"Second agent ID with version: {agent2.id}")
 
 
 if __name__ == "__main__":
diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_agent_to_agent.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_agent_to_agent.py
index 93bd445e28..d1dce0b220 100644
--- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_agent_to_agent.py
+++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_agent_to_agent.py
@@ -2,36 +2,47 @@
 import asyncio
 import os
 
-from agent_framework.azure import AzureAIClient
+from agent_framework.azure import AzureAIProjectAgentProvider
 from azure.identity.aio import AzureCliCredential
 
 """
 Azure AI Agent with Agent-to-Agent (A2A) Example
 
-This sample demonstrates usage of AzureAIClient with Agent-to-Agent (A2A) capabilities
+This sample demonstrates usage of AzureAIProjectAgentProvider with Agent-to-Agent (A2A) capabilities
 to enable communication with other agents using the A2A protocol.
 
 Prerequisites:
 1. Set AZURE_AI_PROJECT_ENDPOINT and AZURE_AI_MODEL_DEPLOYMENT_NAME environment variables.
 2. Ensure you have an A2A connection configured in your Azure AI project
-   and set A2A_PROJECT_CONNECTION_ID environment variable.
+   and set A2A_PROJECT_CONNECTION_ID environment variable.
+3. (Optional) A2A_ENDPOINT - If the connection is missing a target (e.g., "Custom keys" type),
+   set the A2A endpoint URL directly.
"""
 
 
 async def main() -> None:
+    # Configure A2A tool with connection ID
+    a2a_tool = {
+        "type": "a2a_preview",
+        "project_connection_id": os.environ["A2A_PROJECT_CONNECTION_ID"],
+    }
+
+    # If the connection is missing a target, we need to set the A2A endpoint URL
+    if os.environ.get("A2A_ENDPOINT"):
+        a2a_tool["base_url"] = os.environ["A2A_ENDPOINT"]
+
     async with (
         AzureCliCredential() as credential,
-        AzureAIClient(credential=credential).create_agent(
+        AzureAIProjectAgentProvider(credential=credential) as provider,
+    ):
+        agent = await provider.create_agent(
             name="MyA2AAgent",
             instructions="""You are a helpful assistant that can communicate with other agents.
         Use the A2A tool when you need to interact with other agents to complete tasks
         or gather information from specialized agents.""",
-            tools={
-                "type": "a2a_preview",
-                "project_connection_id": os.environ["A2A_PROJECT_CONNECTION_ID"],
-            },
-        ) as agent,
-    ):
+            tools=a2a_tool,
+        )
+
         query = "What can the secondary agent do?"
         print(f"User: {query}")
         result = await agent.run(query)
diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_azure_ai_search.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_azure_ai_search.py
index 057c2b5ff7..c4ee686d87 100644
--- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_azure_ai_search.py
+++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_azure_ai_search.py
@@ -2,13 +2,13 @@
 import asyncio
 import os
 
-from agent_framework.azure import AzureAIClient
+from agent_framework.azure import AzureAIProjectAgentProvider
 from azure.identity.aio import AzureCliCredential
 
 """
 Azure AI Agent with Azure AI Search Example
 
-This sample demonstrates usage of AzureAIClient with Azure AI Search
+This sample demonstrates usage of AzureAIProjectAgentProvider with Azure AI Search
 to search through indexed data and answer user questions about it.
Prerequisites: @@ -21,7 +21,9 @@ async def main() -> None: async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="MySearchAgent", instructions="""You are a helpful assistant. You must always provide citations for answers using the tool and render them as: `[message_idx:search_idx†source]`.""", @@ -38,8 +40,8 @@ async def main() -> None: ] }, }, - ) as agent, - ): + ) + query = "Tell me about insurance options" print(f"User: {query}") result = await agent.run(query) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_bing_custom_search.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_bing_custom_search.py index 682e2fc38e..2a2db762f4 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_bing_custom_search.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_bing_custom_search.py @@ -2,13 +2,13 @@ import asyncio import os -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential """ Azure AI Agent with Bing Custom Search Example -This sample demonstrates usage of AzureAIClient with Bing Custom Search +This sample demonstrates usage of AzureAIProjectAgentProvider with Bing Custom Search to search custom search instances and provide responses with relevant results. Prerequisites: @@ -21,7 +21,9 @@ async def main() -> None: async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="MyCustomSearchAgent", instructions="""You are a helpful agent that can use Bing Custom Search tools to assist users. Use the available Bing Custom Search tools to answer questions and perform tasks.""", @@ -36,8 +38,8 @@ async def main() -> None: ] }, }, - ) as agent, - ): + ) + query = "Tell me more about foundry agent service" print(f"User: {query}") result = await agent.run(query) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_bing_grounding.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_bing_grounding.py index 810962ab24..92c00dddc9 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_bing_grounding.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_bing_grounding.py @@ -2,13 +2,13 @@ import asyncio import os -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential """ Azure AI Agent with Bing Grounding Example -This sample demonstrates usage of AzureAIClient with Bing Grounding +This sample demonstrates usage of AzureAIProjectAgentProvider with Bing Grounding to search the web for current information and provide grounded responses. Prerequisites: @@ -27,7 +27,9 @@ async def main() -> None: async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="MyBingGroundingAgent", instructions="""You are a helpful assistant that can search the web for current information. Use the Bing search tool to find up-to-date information and provide accurate, well-sourced answers. 
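The Azure AI sample diffs above and below all share the same migrated skeleton: open a credential and a provider in one `async with`, call `create_agent()` with a hosted-tool payload, then run a query. As a standalone reference, the pattern looks roughly like this (a sketch; the tool dict reuses the A2A shape shown earlier in this diff, and the agent name and query are illustrative):

```python
import asyncio
import os

from agent_framework.azure import AzureAIProjectAgentProvider
from azure.identity.aio import AzureCliCredential


async def main() -> None:
    # Hosted tools are passed to create_agent as plain dicts that the service
    # validates; this payload mirrors the A2A example earlier in this diff.
    a2a_tool = {
        "type": "a2a_preview",
        "project_connection_id": os.environ["A2A_PROJECT_CONNECTION_ID"],
    }

    async with (
        AzureCliCredential() as credential,
        AzureAIProjectAgentProvider(credential=credential) as provider,
    ):
        agent = await provider.create_agent(
            name="HostedToolAgent",  # illustrative name
            instructions="You are a helpful assistant.",
            tools=a2a_tool,
        )
        result = await agent.run("What can the connected agent do?")
        print(f"Agent: {result}")


if __name__ == "__main__":
    asyncio.run(main())
```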
@@ -42,8 +44,8 @@ async def main() -> None: ] }, }, - ) as agent, - ): + ) + query = "What is today's date and weather in Seattle?" print(f"User: {query}") result = await agent.run(query) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_browser_automation.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_browser_automation.py index 72ee2cd5b0..21a180530c 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_browser_automation.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_browser_automation.py @@ -2,13 +2,13 @@ import asyncio import os -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential """ Azure AI Agent with Browser Automation Example -This sample demonstrates usage of AzureAIClient with Browser Automation +This sample demonstrates usage of AzureAIProjectAgentProvider with Browser Automation to perform automated web browsing tasks and provide responses based on web interactions. Prerequisites: @@ -21,7 +21,9 @@ async def main() -> None: async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="MyBrowserAutomationAgent", instructions="""You are an Agent helping with browser automation tasks. You can answer questions, provide information, and assist with various tasks @@ -34,8 +36,8 @@ async def main() -> None: } }, }, - ) as agent, - ): + ) + query = """Your goal is to report the percent of Microsoft year-to-date stock price change. To do that, go to the website finance.yahoo.com. At the top of the page, you will find a search bar. diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter.py index 2622e273e8..ad43e21e9c 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter.py @@ -3,7 +3,7 @@ import asyncio from agent_framework import ChatResponse, HostedCodeInterpreterTool -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential from openai.types.responses.response import Response as OpenAIResponse from openai.types.responses.response_code_interpreter_tool_call import ResponseCodeInterpreterToolCall @@ -11,22 +11,24 @@ """ Azure AI Agent Code Interpreter Example -This sample demonstrates using HostedCodeInterpreterTool with AzureAIClient +This sample demonstrates using HostedCodeInterpreterTool with AzureAIProjectAgentProvider for Python code execution and mathematical problem solving. 
""" async def main() -> None: - """Example showing how to use the HostedCodeInterpreterTool with AzureAIClient.""" + """Example showing how to use the HostedCodeInterpreterTool with AzureAIProjectAgentProvider.""" async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="MyCodeInterpreterAgent", instructions="You are a helpful assistant that can write and execute Python code to solve problems.", tools=HostedCodeInterpreterTool(), - ) as agent, - ): + ) + query = "Use code to get the factorial of 100?" print(f"User: {query}") result = await agent.run(query) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter_file_download.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter_file_download.py new file mode 100644 index 0000000000..50ce0037a4 --- /dev/null +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter_file_download.py @@ -0,0 +1,219 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import tempfile +from pathlib import Path + +from agent_framework import ( + AgentResponseUpdate, + ChatAgent, + CitationAnnotation, + HostedCodeInterpreterTool, + HostedFileContent, + TextContent, +) +from agent_framework.azure import AzureAIProjectAgentProvider +from azure.identity.aio import AzureCliCredential + +""" +Azure AI V2 Code Interpreter File Download Sample + +This sample demonstrates how the AzureAIProjectAgentProvider handles file annotations +when code interpreter generates text files. It shows: +1. How to extract file IDs and container IDs from annotations +2. How to download container files using the OpenAI containers API +3. How to save downloaded files locally + +Note: Code interpreter generates files in containers, which require both +file_id and container_id to download via client.containers.files.content.retrieve(). +""" + +QUERY = ( + "Write a simple Python script that creates a text file called 'sample.txt' containing " + "'Hello from the code interpreter!' and save it to disk." +) + + +async def download_container_files( + file_contents: list[CitationAnnotation | HostedFileContent], agent: ChatAgent +) -> list[Path]: + """Download container files using the OpenAI containers API. + + Code interpreter generates files in containers, which require both file_id + and container_id to download. The container_id is stored in additional_properties. + + This function works for both streaming (HostedFileContent) and non-streaming + (CitationAnnotation) responses. + + Args: + file_contents: List of CitationAnnotation or HostedFileContent objects + containing file_id and container_id. + agent: The ChatAgent instance with access to the AzureAIClient. + + Returns: + List of Path objects for successfully downloaded files. 
+ """ + if not file_contents: + return [] + + # Create output directory in system temp folder + temp_dir = Path(tempfile.gettempdir()) + output_dir = temp_dir / "agent_framework_downloads" + output_dir.mkdir(exist_ok=True) + + print(f"\nDownloading {len(file_contents)} container file(s) to {output_dir.absolute()}...") + + # Access the OpenAI client from AzureAIClient + openai_client = agent.chat_client.client + + downloaded_files: list[Path] = [] + + for content in file_contents: + file_id = content.file_id + + # Extract container_id from additional_properties + if not content.additional_properties or "container_id" not in content.additional_properties: + print(f" File {file_id}: ✗ Missing container_id") + continue + + container_id = content.additional_properties["container_id"] + + # Extract filename based on content type + if isinstance(content, CitationAnnotation): + filename = content.url or f"{file_id}.txt" + # Extract filename from sandbox URL if present (e.g., sandbox:/mnt/data/sample.txt) + if filename.startswith("sandbox:"): + filename = filename.split("/")[-1] + else: # HostedFileContent + filename = content.additional_properties.get("filename") or f"{file_id}.txt" + + output_path = output_dir / filename + + try: + # Download using containers API + print(f" Downloading {filename}...", end="", flush=True) + file_content = await openai_client.containers.files.content.retrieve( + file_id=file_id, + container_id=container_id, + ) + + # file_content is HttpxBinaryResponseContent, read it + content_bytes = file_content.read() + + # Save to disk + output_path.write_bytes(content_bytes) + file_size = output_path.stat().st_size + print(f"({file_size} bytes)") + + downloaded_files.append(output_path) + + except Exception as e: + print(f"Failed: {e}") + + return downloaded_files + + +async def non_streaming_example() -> None: + """Example of downloading files from non-streaming response using CitationAnnotation.""" + print("=== Non-Streaming Response Example ===") + + async with ( + AzureCliCredential() as credential, + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( + name="V2CodeInterpreterFileAgent", + instructions="You are a helpful assistant that can write and execute Python code to create files.", + tools=HostedCodeInterpreterTool(), + ) + + print(f"User: {QUERY}\n") + + result = await agent.run(QUERY) + print(f"Agent: {result.text}\n") + + # Check for annotations in the response + annotations_found: list[CitationAnnotation] = [] + # AgentResponse has messages property, which contains ChatMessage objects + for message in result.messages: + for content in message.contents: + if isinstance(content, TextContent) and content.annotations: + for annotation in content.annotations: + if isinstance(annotation, CitationAnnotation) and annotation.file_id: + annotations_found.append(annotation) + print(f"Found file annotation: file_id={annotation.file_id}") + if annotation.additional_properties and "container_id" in annotation.additional_properties: + print(f" container_id={annotation.additional_properties['container_id']}") + + if annotations_found: + print(f"SUCCESS: Found {len(annotations_found)} file annotation(s)") + + # Download the container files + downloaded_paths = await download_container_files(annotations_found, agent) + + if downloaded_paths: + print("\nDownloaded files available at:") + for path in downloaded_paths: + print(f" - {path.absolute()}") + else: + print("WARNING: No file annotations found in non-streaming 
response") + + +async def streaming_example() -> None: + """Example of downloading files from streaming response using HostedFileContent.""" + print("\n=== Streaming Response Example ===") + + async with ( + AzureCliCredential() as credential, + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( + name="V2CodeInterpreterFileAgentStreaming", + instructions="You are a helpful assistant that can write and execute Python code to create files.", + tools=HostedCodeInterpreterTool(), + ) + + print(f"User: {QUERY}\n") + file_contents_found: list[HostedFileContent] = [] + text_chunks: list[str] = [] + + async for update in agent.run_stream(QUERY): + if isinstance(update, AgentResponseUpdate): + for content in update.contents: + if isinstance(content, TextContent): + if content.text: + text_chunks.append(content.text) + if content.annotations: + for annotation in content.annotations: + if isinstance(annotation, CitationAnnotation) and annotation.file_id: + print(f"Found streaming CitationAnnotation: file_id={annotation.file_id}") + elif isinstance(content, HostedFileContent): + file_contents_found.append(content) + print(f"Found streaming HostedFileContent: file_id={content.file_id}") + if content.additional_properties and "container_id" in content.additional_properties: + print(f" container_id={content.additional_properties['container_id']}") + + print(f"\nAgent response: {''.join(text_chunks)[:200]}...") + + if file_contents_found: + print(f"SUCCESS: Found {len(file_contents_found)} file reference(s) in streaming") + + # Download the container files + downloaded_paths = await download_container_files(file_contents_found, agent) + + if downloaded_paths: + print("\n✓ Downloaded files available at:") + for path in downloaded_paths: + print(f" - {path.absolute()}") + else: + print("WARNING: No file annotations found in streaming response") + + +async def main() -> None: + print("AzureAIClient Code Interpreter File Download Sample\n") + await non_streaming_example() + await streaming_example() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter_file_generation.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter_file_generation.py index 76758d1b61..4118b50f9d 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter_file_generation.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter_file_generation.py @@ -3,19 +3,19 @@ import asyncio from agent_framework import ( + AgentResponseUpdate, CitationAnnotation, HostedCodeInterpreterTool, HostedFileContent, TextContent, ) -from agent_framework._agents import AgentRunResponseUpdate -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential """ Azure AI V2 Code Interpreter File Generation Sample -This sample demonstrates how the V2 AzureAIClient handles file annotations +This sample demonstrates how the AzureAIProjectAgentProvider handles file annotations when code interpreter generates text files. It shows both non-streaming and streaming approaches to verify file ID extraction. 
""" @@ -26,18 +26,20 @@ ) -async def test_non_streaming() -> None: - """Test non-streaming response - should have annotations on TextContent.""" - print("=== Testing Non-Streaming Response ===") +async def non_streaming_example() -> None: + """Example of extracting file annotations from non-streaming response.""" + print("=== Non-Streaming Response Example ===") async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="V2CodeInterpreterFileAgent", instructions="You are a helpful assistant that can write and execute Python code to create files.", tools=HostedCodeInterpreterTool(), - ) as agent, - ): + ) + print(f"User: {QUERY}\n") result = await agent.run(QUERY) @@ -45,7 +47,7 @@ async def test_non_streaming() -> None: # Check for annotations in the response annotations_found: list[str] = [] - # AgentRunResponse has messages property, which contains ChatMessage objects + # AgentResponse has messages property, which contains ChatMessage objects for message in result.messages: for content in message.contents: if isinstance(content, TextContent) and content.annotations: @@ -60,25 +62,27 @@ async def test_non_streaming() -> None: print("WARNING: No file annotations found in non-streaming response") -async def test_streaming() -> None: - """Test streaming response - check if file content is captured via HostedFileContent.""" - print("\n=== Testing Streaming Response ===") +async def streaming_example() -> None: + """Example of extracting file annotations from streaming response.""" + print("\n=== Streaming Response Example ===") async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="V2CodeInterpreterFileAgentStreaming", instructions="You are a helpful assistant that can write and execute Python code to create files.", tools=HostedCodeInterpreterTool(), - ) as agent, - ): + ) + print(f"User: {QUERY}\n") annotations_found: list[str] = [] text_chunks: list[str] = [] file_ids_found: list[str] = [] async for update in agent.run_stream(QUERY): - if isinstance(update, AgentRunResponseUpdate): + if isinstance(update, AgentResponseUpdate): for content in update.contents: if isinstance(content, TextContent): if content.text: @@ -102,9 +106,9 @@ async def test_streaming() -> None: async def main() -> None: - print("AzureAIClient Code Interpreter File Generation Test\n") - await test_non_streaming() - await test_streaming() + print("AzureAIClient Code Interpreter File Generation Sample\n") + await non_streaming_example() + await streaming_example() if __name__ == "__main__": diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_existing_agent.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_existing_agent.py index 7486b19ec7..7341068f10 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_existing_agent.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_existing_agent.py @@ -3,8 +3,7 @@ import asyncio import os -from agent_framework import ChatAgent -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.ai.projects.aio import AIProjectClient from azure.ai.projects.models import PromptAgentDefinition from azure.identity.aio import AzureCliCredential 
@@ -12,19 +11,23 @@
 """
 Azure AI Agent with Existing Agent Example
 
-This sample demonstrates working with pre-existing Azure AI Agents by providing
-agent name and version, showing agent reuse patterns for production scenarios.
+This sample demonstrates working with pre-existing Azure AI Agents by using the provider.get_agent() method,
+showing agent reuse patterns for production scenarios.
 """
 
 
-async def main() -> None:
+async def using_provider_get_agent() -> None:
+    print("=== Get existing Azure AI agent with provider.get_agent() ===")
+
     # Create the client
     async with (
         AzureCliCredential() as credential,
         AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as project_client,
     ):
+        # Create the remote agent using the SDK directly
         azure_ai_agent = await project_client.agents.create_version(
             agent_name="MyNewTestAgent",
+            description="Agent for testing purposes.",
             definition=PromptAgentDefinition(
                 model=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
                 # Setting specific requirements to verify that this agent is used.
@@ -32,27 +35,22 @@ async def main() -> None:
             ),
         )
 
-        chat_client = AzureAIClient(
-            project_client=project_client,
-            agent_name=azure_ai_agent.name,
-            # Property agent_version is required for existing agents.
-            # If this property is not configured, the client will try to create a new agent using
-            # provided agent_name.
-            # It's also possible to leave agent_version empty but set use_latest_version=True.
-            # This will pull latest available agent version and use that version for operations.
-            agent_version=azure_ai_agent.version,
-        )
-        try:
-            async with ChatAgent(
-                chat_client=chat_client,
-            ) as agent:
-                query = "How are you?"
-                print(f"User: {query}")
-                result = await agent.run(query)
-                # Response that indicates that previously created agent was used:
-                # "I'm here and ready to help you! How can I assist you today? [END]"
-                print(f"Agent: {result}\n")
+        try:
+            # Get the newly created agent as a ChatAgent using provider.get_agent()
+            provider = AzureAIProjectAgentProvider(project_client=project_client)
+            agent = await provider.get_agent(name=azure_ai_agent.name)
+
+            # Verify agent properties
+            print(f"Agent ID: {agent.id}")
+            print(f"Agent name: {agent.name}")
+            print(f"Agent description: {agent.description}")
+
+            query = "How are you?"
+            print(f"User: {query}")
+            result = await agent.run(query)
+            # Response indicating that the previously created agent was used:
+            # "I'm here and ready to help you! How can I assist you today?
[END]" + print(f"Agent: {result}\n") finally: # Clean up the agent manually await project_client.agents.delete_version( @@ -60,5 +58,9 @@ async def main() -> None: ) +async def main() -> None: + await using_provider_get_agent() + + if __name__ == "__main__": asyncio.run(main()) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_existing_conversation.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_existing_conversation.py index 43019d050c..099c5ad5aa 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_existing_conversation.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_existing_conversation.py @@ -4,7 +4,7 @@ from random import randint from typing import Annotated -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.ai.projects.aio import AIProjectClient from azure.identity.aio import AzureCliCredential from pydantic import Field @@ -12,7 +12,7 @@ """ Azure AI Agent Existing Conversation Example -This sample demonstrates usage of AzureAIClient with existing conversation created on service side. +This sample demonstrates usage of AzureAIProjectAgentProvider with existing conversation created on service side. """ @@ -24,9 +24,9 @@ def get_weather( return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." -async def example_with_client() -> None: - """Example shows how to specify existing conversation ID when initializing Azure AI Client.""" - print("=== Azure AI Agent With Existing Conversation and Client ===") +async def example_with_conversation_id() -> None: + """Example shows how to use existing conversation ID with the provider.""" + print("=== Azure AI Agent With Existing Conversation ===") async with ( AzureCliCredential() as credential, AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as project_client, @@ -37,24 +37,23 @@ async def example_with_client() -> None: conversation_id = conversation.id print(f"Conversation ID: {conversation_id}") - async with AzureAIClient( - project_client=project_client, - # Specify conversation ID on client level - conversation_id=conversation_id, - ).create_agent( + provider = AzureAIProjectAgentProvider(project_client=project_client) + agent = await provider.create_agent( name="BasicAgent", instructions="You are a helpful agent.", tools=get_weather, - ) as agent: - query = "What's the weather like in Seattle?" - print(f"User: {query}") - result = await agent.run(query) - print(f"Agent: {result.text}\n") + ) - query = "What was my last question?" - print(f"User: {query}") - result = await agent.run(query) - print(f"Agent: {result.text}\n") + # Pass conversation_id at run level + query = "What's the weather like in Seattle?" + print(f"User: {query}") + result = await agent.run(query, conversation_id=conversation_id) + print(f"Agent: {result.text}\n") + + query = "What was my last question?" 
+ print(f"User: {query}") + result = await agent.run(query, conversation_id=conversation_id) + print(f"Agent: {result.text}\n") async def example_with_thread() -> None: @@ -63,12 +62,14 @@ async def example_with_thread() -> None: async with ( AzureCliCredential() as credential, AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as project_client, - AzureAIClient(project_client=project_client).create_agent( + ): + provider = AzureAIProjectAgentProvider(project_client=project_client) + agent = await provider.create_agent( name="BasicAgent", instructions="You are a helpful agent.", tools=get_weather, - ) as agent, - ): + ) + # Create a conversation using OpenAI client openai_client = project_client.get_openai_client() conversation = await openai_client.conversations.create() @@ -90,7 +91,7 @@ async def example_with_thread() -> None: async def main() -> None: - await example_with_client() + await example_with_conversation_id() await example_with_thread() diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_explicit_settings.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_explicit_settings.py index d5860a64f2..a3e3e24fe1 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_explicit_settings.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_explicit_settings.py @@ -5,8 +5,7 @@ from random import randint from typing import Annotated -from agent_framework import ChatAgent -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential from pydantic import Field @@ -27,22 +26,22 @@ def get_weather( async def main() -> None: - # Since no Agent ID is provided, the agent will be automatically created. # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. async with ( AzureCliCredential() as credential, - ChatAgent( - chat_client=AzureAIClient( - project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], - model_deployment_name=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], - credential=credential, - agent_name="WeatherAgent", - ), + AzureAIProjectAgentProvider( + project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], + model=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], + credential=credential, + ) as provider, + ): + agent = await provider.create_agent( + name="WeatherAgent", instructions="You are a helpful weather agent.", tools=get_weather, - ) as agent, - ): + ) + query = "What's the weather like in New York?" 
print(f"User: {query}") result = await agent.run(query) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_file_search.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_file_search.py index de8c3b22b1..9558546093 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_file_search.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_file_search.py @@ -4,8 +4,8 @@ import os from pathlib import Path -from agent_framework import ChatAgent, HostedFileSearchTool, HostedVectorStoreContent -from agent_framework.azure import AzureAIClient +from agent_framework import HostedFileSearchTool, HostedVectorStoreContent +from agent_framework.azure import AzureAIProjectAgentProvider from azure.ai.agents.aio import AgentsClient from azure.ai.agents.models import FileInfo, VectorStore from azure.identity.aio import AzureCliCredential @@ -32,7 +32,7 @@ async def main() -> None: async with ( AzureCliCredential() as credential, AgentsClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as agents_client, - AzureAIClient(credential=credential) as client, + AzureAIProjectAgentProvider(credential=credential) as provider, ): try: # 1. Upload file and create vector store @@ -48,22 +48,21 @@ async def main() -> None: # 2. Create file search tool with uploaded resources file_search_tool = HostedFileSearchTool(inputs=[HostedVectorStoreContent(vector_store_id=vector_store.id)]) - # 3. Create an agent with file search capabilities - # The tool_resources are automatically extracted from HostedFileSearchTool - async with ChatAgent( - chat_client=client, + # 3. Create an agent with file search capabilities using the provider + agent = await provider.create_agent( name="EmployeeSearchAgent", instructions=( "You are a helpful assistant that can search through uploaded employee files " "to answer questions about employees." ), tools=file_search_tool, - ) as agent: - # 4. Simulate conversation with the agent - for user_input in USER_INPUTS: - print(f"# User: '{user_input}'") - response = await agent.run(user_input) - print(f"# Agent: {response.text}") + ) + + # 4. Simulate conversation with the agent + for user_input in USER_INPUTS: + print(f"# User: '{user_input}'") + response = await agent.run(user_input) + print(f"# Agent: {response.text}") finally: # 5. Cleanup: Delete the vector store and file in case of earlier failure to prevent orphaned resources. 
if vector_store: diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_hosted_mcp.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_hosted_mcp.py index dd72108c05..8b120f703d 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_hosted_mcp.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_hosted_mcp.py @@ -3,8 +3,8 @@ import asyncio from typing import Any -from agent_framework import AgentProtocol, AgentRunResponse, AgentThread, ChatMessage, HostedMCPTool -from agent_framework.azure import AzureAIClient +from agent_framework import AgentProtocol, AgentResponse, AgentThread, ChatMessage, HostedMCPTool +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential """ @@ -14,7 +14,7 @@ """ -async def handle_approvals_without_thread(query: str, agent: "AgentProtocol") -> AgentRunResponse: +async def handle_approvals_without_thread(query: str, agent: "AgentProtocol") -> AgentResponse: """When we don't have a thread, we need to ensure we return with the input, approval request and approval.""" result = await agent.run(query, store=False) @@ -35,7 +35,7 @@ async def handle_approvals_without_thread(query: str, agent: "AgentProtocol") -> return result -async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread") -> AgentRunResponse: +async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread") -> AgentResponse: """Here we let the thread deal with the previous responses, and we just rerun with the approval.""" result = await agent.run(query, thread=thread) @@ -59,12 +59,13 @@ async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", threa async def run_hosted_mcp_without_approval() -> None: """Example showing MCP Tools without approval.""" - # Since no Agent ID is provided, the agent will be automatically created. # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="MyLearnDocsAgent", instructions="You are a helpful assistant that can help with Microsoft documentation questions.", tools=HostedMCPTool( @@ -72,8 +73,8 @@ async def run_hosted_mcp_without_approval() -> None: url="https://learn.microsoft.com/api/mcp", approval_mode="never_require", ), - ) as agent, - ): + ) + query = "How to create an Azure storage account using az cli?" print(f"User: {query}") result = await handle_approvals_without_thread(query, agent) @@ -84,12 +85,13 @@ async def run_hosted_mcp_with_approval_and_thread() -> None: """Example showing MCP Tools with approvals using a thread.""" print("=== MCP with approvals and with thread ===") - # Since no Agent ID is provided, the agent will be automatically created. # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. 
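+    # approval_mode="always_require" below pauses the run before every MCP tool
+    # call; handle_approvals_with_thread sends the approvals back via the thread.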
async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="MyApiSpecsAgent", instructions="You are a helpful agent that can use MCP tools to assist users.", tools=HostedMCPTool( @@ -97,8 +99,8 @@ async def run_hosted_mcp_with_approval_and_thread() -> None: url="https://gitmcp.io/Azure/azure-rest-api-specs", approval_mode="always_require", ), - ) as agent, - ): + ) + thread = agent.get_new_thread() query = "Please summarize the Azure REST API specifications Readme" print(f"User: {query}") diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_image_generation.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_image_generation.py index 8274c43ab0..21166911ac 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_image_generation.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_image_generation.py @@ -4,13 +4,13 @@ import aiofiles from agent_framework import DataContent, HostedImageGenerationTool -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential """ Azure AI Agent with Image Generation Example -This sample demonstrates basic usage of AzureAIClient to create an agent +This sample demonstrates basic usage of AzureAIProjectAgentProvider to create an agent that can generate images based on user requirements. Pre-requisites: @@ -20,12 +20,13 @@ async def main() -> None: - # Since no Agent ID is provided, the agent will be automatically created. # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="ImageGenAgent", instructions="Generate images based on user requirements.", tools=[ @@ -37,14 +38,14 @@ async def main() -> None: } ) ], - ) as agent, - ): + ) + query = "Generate an image of Microsoft logo." 
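+        # The generated image is expected to arrive as DataContent in the response
+        # (hence the DataContent and aiofiles imports above) so it can be saved to disk.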
print(f"User: {query}") result = await agent.run( query, # These additional options are required for image generation - additional_chat_options={ + options={ "extra_headers": {"x-ms-oai-image-generation-deployment": "gpt-image-1-mini"}, }, ) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_local_mcp.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_local_mcp.py index 5f97116707..91b6228b71 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_local_mcp.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_local_mcp.py @@ -3,7 +3,7 @@ import asyncio from agent_framework import MCPStreamableHTTPTool -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential """ @@ -19,20 +19,22 @@ async def main() -> None: - """Example showing use of Local MCP Tool with AzureAIClient.""" + """Example showing use of Local MCP Tool with AzureAIProjectAgentProvider.""" print("=== Azure AI Agent with Local MCP Tools Example ===\n") async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="DocsAgent", instructions="You are a helpful assistant that can help with Microsoft documentation questions.", tools=MCPStreamableHTTPTool( name="Microsoft Learn MCP", url="https://learn.microsoft.com/api/mcp", ), - ) as agent, - ): + ) + # First query first_query = "How to create an Azure storage account using az cli?" print(f"User: {first_query}") diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_memory_search.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_memory_search.py index 2996840489..72b9ea1a01 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_memory_search.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_memory_search.py @@ -3,7 +3,7 @@ import os import uuid -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.ai.projects.aio import AIProjectClient from azure.ai.projects.models import MemoryStoreDefaultDefinition, MemoryStoreDefaultOptions from azure.identity.aio import AzureCliCredential @@ -11,7 +11,7 @@ """ Azure AI Agent with Memory Search Example -This sample demonstrates usage of AzureAIClient with memory search capabilities +This sample demonstrates usage of AzureAIProjectAgentProvider with memory search capabilities to retrieve relevant past user messages and maintain conversation context across sessions. It shows explicit memory store creation using Azure AI Projects client and agent creation using the Agent Framework. @@ -46,18 +46,20 @@ async def main() -> None: ) print(f"Created memory store: {memory_store.name} ({memory_store.id}): {memory_store.description}") - # Then, create the agent using Agent Framework - async with AzureAIClient(credential=credential).create_agent( - name="MyMemoryAgent", - instructions="""You are a helpful assistant that remembers past conversations. 
- Use the memory search tool to recall relevant information from previous interactions.""", - tools={ - "type": "memory_search", - "memory_store_name": memory_store.name, - "scope": "user_123", - "update_delay": 1, # Wait 1 second before updating memories (use higher value in production) - }, - ) as agent: + # Then, create the agent using Agent Framework provider + async with AzureAIProjectAgentProvider(credential=credential) as provider: + agent = await provider.create_agent( + name="MyMemoryAgent", + instructions="""You are a helpful assistant that remembers past conversations. + Use the memory search tool to recall relevant information from previous interactions.""", + tools={ + "type": "memory_search", + "memory_store_name": memory_store.name, + "scope": "user_123", + "update_delay": 1, # Wait 1 second before updating memories (use higher value in production) + }, + ) + # First interaction - establish some preferences print("=== First conversation ===") query1 = "I prefer dark roast coffee" diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_microsoft_fabric.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_microsoft_fabric.py index e19837f99b..0f3b39d192 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_microsoft_fabric.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_microsoft_fabric.py @@ -2,13 +2,13 @@ import asyncio import os -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential """ Azure AI Agent with Microsoft Fabric Example -This sample demonstrates usage of AzureAIClient with Microsoft Fabric +This sample demonstrates usage of AzureAIProjectAgentProvider with Microsoft Fabric to query Fabric data sources and provide responses based on data analysis. Prerequisites: @@ -21,7 +21,9 @@ async def main() -> None: async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="MyFabricAgent", instructions="You are a helpful assistant.", tools={ @@ -34,8 +36,8 @@ async def main() -> None: ] }, }, - ) as agent, - ): + ) + query = "Tell me about sales records" print(f"User: {query}") result = await agent.run(query) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_openapi.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_openapi.py index 8824106656..17a6d78f91 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_openapi.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_openapi.py @@ -4,13 +4,13 @@ from pathlib import Path import aiofiles -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential """ Azure AI Agent with OpenAPI Tool Example -This sample demonstrates usage of AzureAIClient with OpenAPI tools +This sample demonstrates usage of AzureAIProjectAgentProvider with OpenAPI tools to call external APIs defined by OpenAPI specifications. 
Prerequisites: @@ -29,7 +29,9 @@ async def main() -> None: async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="MyOpenAPIAgent", instructions="""You are a helpful assistant that can use country APIs to provide information. Use the available OpenAPI tools to answer questions about countries, currencies, and demographics.""", @@ -42,8 +44,8 @@ async def main() -> None: "auth": {"type": "anonymous"}, }, }, - ) as agent, - ): + ) + query = "What is the name and population of the country that uses currency with abbreviation THB?" print(f"User: {query}") result = await agent.run(query) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_response_format.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_response_format.py index dfb4ce6a21..f446b02c67 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_response_format.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_response_format.py @@ -2,14 +2,14 @@ import asyncio -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential from pydantic import BaseModel, ConfigDict """ Azure AI Agent Response Format Example -This sample demonstrates basic usage of AzureAIClient with response format, +This sample demonstrates basic usage of AzureAIProjectAgentProvider with response format, also known as structured outputs. """ @@ -24,23 +24,22 @@ class ReleaseBrief(BaseModel): async def main() -> None: """Example of using response_format property.""" - # Since no Agent ID is provided, the agent will be automatically created. # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="ProductMarketerAgent", instructions="Return launch briefs as structured JSON.", - ) as agent, - ): + # Specify Pydantic model for structured output via default_options + default_options={"response_format": ReleaseBrief}, + ) + query = "Draft a launch brief for the Contoso Note app." 
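+        # response_format was set via default_options at creation time, so the
+        # service returns structured JSON and result.value is the parsed ReleaseBrief.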
print(f"User: {query}") - result = await agent.run( - query, - # Specify type to use as response - response_format=ReleaseBrief, - ) + result = await agent.run(query) if isinstance(result.value, ReleaseBrief): release_brief = result.value diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_runtime_json_schema.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_runtime_json_schema.py index 17bf359afe..21f67a0984 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_runtime_json_schema.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_runtime_json_schema.py @@ -2,13 +2,13 @@ import asyncio -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential """ Azure AI Agent Response Format Example with Runtime JSON Schema -This sample demonstrates basic usage of AzureAIClient with response format, +This sample demonstrates basic usage of AzureAIProjectAgentProvider with response format, also known as structured outputs. """ @@ -29,24 +29,19 @@ async def main() -> None: - """Example of using response_format property.""" + """Example of using response_format property with a runtime JSON schema.""" - # Since no Agent ID is provided, the agent will be automatically created. # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( - name="ProductMarketerAgent", - instructions="Return launch briefs as structured JSON.", - ) as agent, + AzureAIProjectAgentProvider(credential=credential) as provider, ): - query = "Draft a launch brief for the Contoso Note app." - print(f"User: {query}") - result = await agent.run( - query, - # Specify type to use as response - additional_chat_options={ + # Pass response_format via default_options using dict schema format + agent = await provider.create_agent( + name="WeatherDigestAgent", + instructions="Return sample weather digest as structured JSON.", + default_options={ "response_format": { "type": "json_schema", "json_schema": { @@ -54,10 +49,14 @@ async def main() -> None: "strict": True, "schema": runtime_schema, }, - }, + } }, ) + query = "Draft a sample weather digest." + print(f"User: {query}") + result = await agent.run(query) + print(result.text) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_sharepoint.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_sharepoint.py index a58de50e84..cd7765741e 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_sharepoint.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_sharepoint.py @@ -2,13 +2,13 @@ import asyncio import os -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential """ Azure AI Agent with SharePoint Example -This sample demonstrates usage of AzureAIClient with SharePoint +This sample demonstrates usage of AzureAIProjectAgentProvider with SharePoint to search through SharePoint content and answer user questions about it. 
Prerequisites: @@ -21,7 +21,9 @@ async def main() -> None: async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="MySharePointAgent", instructions="""You are a helpful agent that can use SharePoint tools to assist users. Use the available SharePoint tools to answer questions and perform tasks.""", @@ -35,8 +37,8 @@ async def main() -> None: ] }, }, - ) as agent, - ): + ) + query = "What is Contoso whistleblower policy?" print(f"User: {query}") result = await agent.run(query) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_thread.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_thread.py index fe8c7f5370..f4e69e02ca 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_thread.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_thread.py @@ -4,7 +4,7 @@ from random import randint from typing import Annotated -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential from pydantic import Field @@ -30,12 +30,14 @@ async def example_with_automatic_thread_creation() -> None: async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="BasicWeatherAgent", instructions="You are a helpful weather agent.", tools=get_weather, - ) as agent, - ): + ) + # First conversation - no thread provided, will be created automatically query1 = "What's the weather like in Seattle?" print(f"User: {query1}") @@ -59,12 +61,14 @@ async def example_with_thread_persistence_in_memory() -> None: async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="BasicWeatherAgent", instructions="You are a helpful weather agent.", tools=get_weather, - ) as agent, - ): + ) + # Create a new thread that will be reused thread = agent.get_new_thread() @@ -100,12 +104,14 @@ async def example_with_existing_thread_id() -> None: async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="BasicWeatherAgent", instructions="You are a helpful weather agent.", tools=get_weather, - ) as agent, - ): + ) + # Start a conversation and get the thread ID thread = agent.get_new_thread() @@ -121,21 +127,21 @@ async def example_with_existing_thread_id() -> None: if existing_thread_id: print("\n--- Continuing with the same thread ID in a new agent instance ---") - async with ( - AzureAIClient(credential=credential).create_agent( - name="BasicWeatherAgent", - instructions="You are a helpful weather agent.", - tools=get_weather, - ) as agent, - ): - # Create a thread with the existing ID - thread = agent.get_new_thread(service_thread_id=existing_thread_id) - - query2 = "What was the last city I asked about?" 
- print(f"User: {query2}") - result2 = await agent.run(query2, thread=thread) - print(f"Agent: {result2.text}") - print("Note: The agent continues the conversation from the previous thread by using thread ID.\n") + # Create a new agent instance from the same provider + agent2 = await provider.create_agent( + name="BasicWeatherAgent", + instructions="You are a helpful weather agent.", + tools=get_weather, + ) + + # Create a thread with the existing ID + thread = agent2.get_new_thread(service_thread_id=existing_thread_id) + + query2 = "What was the last city I asked about?" + print(f"User: {query2}") + result2 = await agent2.run(query2, thread=thread) + print(f"Agent: {result2.text}") + print("Note: The agent continues the conversation from the previous thread by using thread ID.\n") async def main() -> None: diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_web_search.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_web_search.py index ef788e4f5e..9ecb416f8d 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_web_search.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_web_search.py @@ -3,13 +3,13 @@ import asyncio from agent_framework import HostedWebSearchTool -from agent_framework.azure import AzureAIClient +from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential """ Azure AI Agent With Web Search -This sample demonstrates basic usage of AzureAIClient to create an agent +This sample demonstrates basic usage of AzureAIProjectAgentProvider to create an agent that can perform web searches using the HostedWebSearchTool. Pre-requisites: @@ -19,17 +19,18 @@ async def main() -> None: - # Since no Agent ID is provided, the agent will be automatically created. # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. async with ( AzureCliCredential() as credential, - AzureAIClient(credential=credential).create_agent( + AzureAIProjectAgentProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="WebsearchAgent", instructions="You are a helpful assistant that can search the web", tools=[HostedWebSearchTool()], - ) as agent, - ): + ) + query = "What's the weather today in Seattle?" print(f"User: {query}") result = await agent.run(query) diff --git a/python/samples/getting_started/agents/azure_ai_agent/README.md b/python/samples/getting_started/agents/azure_ai_agent/README.md index 84ed7eeba3..5440b2d3ba 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/README.md +++ b/python/samples/getting_started/agents/azure_ai_agent/README.md @@ -1,27 +1,53 @@ # Azure AI Agent Examples -This folder contains examples demonstrating different ways to create and use agents with the Azure AI chat client from the `agent_framework.azure` package. These examples use the `AzureAIAgentClient` with the `azure-ai-agents` 1.x (V1) API surface. For updated V2 (`azure-ai-projects` 2.x) samples, see the [Azure AI V2 examples folder](../azure_ai/). +This folder contains examples demonstrating different ways to create and use agents with Azure AI using the `AzureAIAgentsProvider` from the `agent_framework.azure` package. These examples use the `azure-ai-agents` 1.x (V1) API surface. For updated V2 (`azure-ai-projects` 2.x) samples, see the [Azure AI V2 examples folder](../azure_ai/). 
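+
+Beyond creating new agents, a provider can also attach to agents that already exist on the service. Here is a minimal sketch (assuming the same environment variables as the samples below and a placeholder agent ID; the full method list follows in the next section):
+
+```python
+from agent_framework.azure import AzureAIAgentsProvider
+from azure.identity.aio import AzureCliCredential
+
+async with (
+    AzureCliCredential() as credential,
+    AzureAIAgentsProvider(credential=credential) as provider,
+):
+    # Attach to an existing agent by its ID instead of creating a new one.
+    agent = await provider.get_agent("<existing-agent-id>")
+    result = await agent.run("Hello!")
+```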
+ +## Provider Pattern + +All examples in this folder use the `AzureAIAgentsProvider` class which provides a high-level interface for agent operations: + +- **`create_agent()`** - Create a new agent on the Azure AI service +- **`get_agent()`** - Retrieve an existing agent by ID or from a pre-fetched Agent object +- **`as_agent()`** - Wrap an SDK Agent object as a ChatAgent without HTTP calls + +```python +from agent_framework.azure import AzureAIAgentsProvider +from azure.identity.aio import AzureCliCredential + +async with ( + AzureCliCredential() as credential, + AzureAIAgentsProvider(credential=credential) as provider, +): + agent = await provider.create_agent( + name="MyAgent", + instructions="You are a helpful assistant.", + tools=my_function, + ) + result = await agent.run("Hello!") +``` ## Examples | File | Description | |------|-------------| -| [`azure_ai_basic.py`](azure_ai_basic.py) | The simplest way to create an agent using `ChatAgent` with `AzureAIAgentClient`. It automatically handles all configuration using environment variables. | +| [`azure_ai_provider_methods.py`](azure_ai_provider_methods.py) | Comprehensive example demonstrating all `AzureAIAgentsProvider` methods: `create_agent()`, `get_agent()`, `as_agent()`, and managing multiple agents from a single provider. | +| [`azure_ai_basic.py`](azure_ai_basic.py) | The simplest way to create an agent using `AzureAIAgentsProvider`. It automatically handles all configuration using environment variables. Shows both streaming and non-streaming responses. | | [`azure_ai_with_bing_custom_search.py`](azure_ai_with_bing_custom_search.py) | Shows how to use Bing Custom Search with Azure AI agents to find real-time information from the web using custom search configurations. Demonstrates how to set up and use HostedWebSearchTool with custom search instances. | | [`azure_ai_with_bing_grounding.py`](azure_ai_with_bing_grounding.py) | Shows how to use Bing Grounding search with Azure AI agents to find real-time information from the web. Demonstrates web search capabilities with proper source citations and comprehensive error handling. | | [`azure_ai_with_bing_grounding_citations.py`](azure_ai_with_bing_grounding_citations.py) | Demonstrates how to extract and display citations from Bing Grounding search responses. Shows how to collect citation annotations (title, URL, snippet) during streaming responses, enabling users to verify sources and access referenced content. | | [`azure_ai_with_code_interpreter_file_generation.py`](azure_ai_with_code_interpreter_file_generation.py) | Shows how to retrieve file IDs from code interpreter generated files using both streaming and non-streaming approaches. | | [`azure_ai_with_code_interpreter.py`](azure_ai_with_code_interpreter.py) | Shows how to use the HostedCodeInterpreterTool with Azure AI agents to write and execute Python code. Includes helper methods for accessing code interpreter data from response chunks. | -| [`azure_ai_with_existing_agent.py`](azure_ai_with_existing_agent.py) | Shows how to work with a pre-existing agent by providing the agent ID to the Azure AI chat client. This example also demonstrates proper cleanup of manually created agents. | -| [`azure_ai_with_existing_thread.py`](azure_ai_with_existing_thread.py) | Shows how to work with a pre-existing thread by providing the thread ID to the Azure AI chat client. This example also demonstrates proper cleanup of manually created threads. 
|
-| [`azure_ai_with_explicit_settings.py`](azure_ai_with_explicit_settings.py) | Shows how to create an agent with explicitly configured `AzureAIAgentClient` settings, including project endpoint, model deployment, credentials, and agent name. |
-| [`azure_ai_with_azure_ai_search.py`](azure_ai_with_azure_ai_search.py) | Demonstrates how to use Azure AI Search with Azure AI agents to search through indexed data. Shows how to configure search parameters, query types, and integrate with existing search indexes. |
-| [`azure_ai_with_file_search.py`](azure_ai_with_file_search.py) | Demonstrates how to use the HostedFileSearchTool with Azure AI agents to search through uploaded documents. Shows file upload, vector store creation, and querying document content. Includes both streaming and non-streaming examples. |
+| [`azure_ai_with_existing_agent.py`](azure_ai_with_existing_agent.py) | Shows how to work with an existing SDK Agent object using `provider.as_agent()`. This wraps the agent without making HTTP calls. |
+| [`azure_ai_with_existing_thread.py`](azure_ai_with_existing_thread.py) | Shows how to work with a pre-existing thread by providing the thread ID. Demonstrates proper cleanup of manually created threads. |
+| [`azure_ai_with_explicit_settings.py`](azure_ai_with_explicit_settings.py) | Shows how to create an agent with explicitly configured provider settings, including project endpoint and model deployment name. |
+| [`azure_ai_with_azure_ai_search.py`](azure_ai_with_azure_ai_search.py) | Demonstrates how to use Azure AI Search with Azure AI agents. Shows how to create an agent with search tools using the SDK directly and wrap it with `provider.as_agent()`. |
+| [`azure_ai_with_file_search.py`](azure_ai_with_file_search.py) | Demonstrates how to use the HostedFileSearchTool with Azure AI agents to search through uploaded documents. Shows file upload, vector store creation, and querying document content. |
 | [`azure_ai_with_function_tools.py`](azure_ai_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). |
 | [`azure_ai_with_hosted_mcp.py`](azure_ai_with_hosted_mcp.py) | Shows how to integrate Azure AI agents with hosted Model Context Protocol (MCP) servers for enhanced functionality and tool integration. Demonstrates remote MCP server connections and tool discovery. |
 | [`azure_ai_with_local_mcp.py`](azure_ai_with_local_mcp.py) | Shows how to integrate Azure AI agents with local Model Context Protocol (MCP) servers for enhanced functionality and tool integration. Demonstrates both agent-level and run-level tool configuration. |
 | [`azure_ai_with_multiple_tools.py`](azure_ai_with_multiple_tools.py) | Demonstrates how to use multiple tools together with Azure AI agents, including web search, MCP servers, and function tools. Shows coordinated multi-tool interactions and approval workflows. |
-| [`azure_ai_with_openapi_tools.py`](azure_ai_with_openapi_tools.py) | Demonstrates how to use OpenAPI tools with Azure AI agents to integrate external REST APIs. Shows OpenAPI specification loading, anonymous authentication, thread context management, and coordinated multi-API conversations using weather and countries APIs. |
+| [`azure_ai_with_openapi_tools.py`](azure_ai_with_openapi_tools.py) | Demonstrates how to use OpenAPI tools with Azure AI agents to integrate external REST APIs. 
Shows OpenAPI specification loading, anonymous authentication, thread context management, and coordinated multi-API conversations. | +| [`azure_ai_with_response_format.py`](azure_ai_with_response_format.py) | Demonstrates how to use structured outputs with Azure AI agents using Pydantic models. | | [`azure_ai_with_thread.py`](azure_ai_with_thread.py) | Demonstrates thread management with Azure AI agents, including automatic thread creation for stateless conversations and explicit thread management for maintaining conversation context across multiple interactions. | ## Environment Variables diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_basic.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_basic.py index 216425cc40..64f0996184 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_basic.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_basic.py @@ -4,14 +4,14 @@ from random import randint from typing import Annotated -from agent_framework.azure import AzureAIAgentClient +from agent_framework.azure import AzureAIAgentsProvider from azure.identity.aio import AzureCliCredential from pydantic import Field """ Azure AI Agent Basic Example -This sample demonstrates basic usage of AzureAIAgentClient to create agents with automatic +This sample demonstrates basic usage of AzureAIAgentsProvider to create agents with automatic lifecycle management. Shows both streaming and non-streaming responses with function tools. """ @@ -28,18 +28,17 @@ async def non_streaming_example() -> None: """Example of non-streaming response (get the complete result at once).""" print("=== Non-streaming Response Example ===") - # Since no Agent ID is provided, the agent will be automatically created - # and deleted after getting a response # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. async with ( AzureCliCredential() as credential, - AzureAIAgentClient(credential=credential).create_agent( + AzureAIAgentsProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="WeatherAgent", instructions="You are a helpful weather agent.", tools=get_weather, - ) as agent, - ): + ) query = "What's the weather like in Seattle?" print(f"User: {query}") result = await agent.run(query) @@ -50,18 +49,17 @@ async def streaming_example() -> None: """Example of streaming response (get results as they are generated).""" print("=== Streaming Response Example ===") - # Since no Agent ID is provided, the agent will be automatically created - # and deleted after getting a response # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. async with ( AzureCliCredential() as credential, - AzureAIAgentClient(credential=credential).create_agent( + AzureAIAgentsProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="WeatherAgent", instructions="You are a helpful weather agent.", tools=get_weather, - ) as agent, - ): + ) query = "What's the weather like in Portland?" 
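+        # The reply below is streamed, printing text chunks as they are generated.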
print(f"User: {query}") print("Agent: ", end="", flush=True) diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_provider_methods.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_provider_methods.py new file mode 100644 index 0000000000..0a07cc5c35 --- /dev/null +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_provider_methods.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import os +from random import randint +from typing import Annotated + +from agent_framework.azure import AzureAIAgentsProvider +from azure.ai.agents.aio import AgentsClient +from azure.identity.aio import AzureCliCredential +from pydantic import Field + +""" +Azure AI Agent Provider Methods Example + +This sample demonstrates the methods available on the AzureAIAgentsProvider class: +- create_agent(): Create a new agent on the service +- get_agent(): Retrieve an existing agent by ID +- as_agent(): Wrap an SDK Agent object without making HTTP calls +""" + + +def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + + +async def create_agent_example() -> None: + """Create a new agent using provider.create_agent().""" + print("\n--- create_agent() ---") + + async with ( + AzureCliCredential() as credential, + AzureAIAgentsProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( + name="WeatherAgent", + instructions="You are a helpful weather assistant.", + tools=get_weather, + ) + + print(f"Created: {agent.name} (ID: {agent.id})") + result = await agent.run("What's the weather in Seattle?") + print(f"Response: {result}") + + +async def get_agent_example() -> None: + """Retrieve an existing agent by ID using provider.get_agent().""" + print("\n--- get_agent() ---") + + async with ( + AzureCliCredential() as credential, + AgentsClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as agents_client, + AzureAIAgentsProvider(agents_client=agents_client) as provider, + ): + # Create an agent directly with SDK (simulating pre-existing agent) + sdk_agent = await agents_client.create_agent( + model=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], + name="ExistingAgent", + instructions="You always respond with 'Hello!'", + ) + + try: + # Retrieve using provider + agent = await provider.get_agent(sdk_agent.id) + print(f"Retrieved: {agent.name} (ID: {agent.id})") + + result = await agent.run("Hi there!") + print(f"Response: {result}") + finally: + await agents_client.delete_agent(sdk_agent.id) + + +async def as_agent_example() -> None: + """Wrap an SDK Agent object using provider.as_agent().""" + print("\n--- as_agent() ---") + + async with ( + AzureCliCredential() as credential, + AgentsClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as agents_client, + AzureAIAgentsProvider(agents_client=agents_client) as provider, + ): + # Create agent using SDK + sdk_agent = await agents_client.create_agent( + model=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], + name="WrappedAgent", + instructions="You respond with poetry.", + ) + + try: + # Wrap synchronously (no HTTP call) + agent = provider.as_agent(sdk_agent) + print(f"Wrapped: {agent.name} (ID: {agent.id})") + + result = await agent.run("Tell me 
about the sunset.") + print(f"Response: {result}") + finally: + await agents_client.delete_agent(sdk_agent.id) + + +async def multiple_agents_example() -> None: + """Create and manage multiple agents with a single provider.""" + print("\n--- Multiple Agents ---") + + async with ( + AzureCliCredential() as credential, + AzureAIAgentsProvider(credential=credential) as provider, + ): + weather_agent = await provider.create_agent( + name="WeatherSpecialist", + instructions="You are a weather specialist.", + tools=get_weather, + ) + + greeter_agent = await provider.create_agent( + name="GreeterAgent", + instructions="You are a friendly greeter.", + ) + + print(f"Created: {weather_agent.name}, {greeter_agent.name}") + + greeting = await greeter_agent.run("Hello!") + print(f"Greeter: {greeting}") + + weather = await weather_agent.run("What's the weather in Tokyo?") + print(f"Weather: {weather}") + + +async def main() -> None: + print("Azure AI Agent Provider Methods") + + await create_agent_example() + await get_agent_example() + await as_agent_example() + await multiple_agents_example() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_azure_ai_search.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_azure_ai_search.py index 34d4913651..8f36d5ebec 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_azure_ai_search.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_azure_ai_search.py @@ -3,8 +3,8 @@ import asyncio import os -from agent_framework import ChatAgent, CitationAnnotation -from agent_framework.azure import AzureAIAgentClient +from agent_framework import CitationAnnotation +from agent_framework.azure import AzureAIAgentsProvider from azure.ai.agents.aio import AgentsClient from azure.ai.projects.aio import AIProjectClient from azure.ai.projects.models import ConnectionType @@ -41,6 +41,7 @@ async def main() -> None: AzureCliCredential() as credential, AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as project_client, AgentsClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as agents_client, + AzureAIAgentsProvider(agents_client=agents_client) as provider, ): ai_search_conn_id = "" async for connection in project_client.connections.list(): @@ -48,7 +49,8 @@ async def main() -> None: ai_search_conn_id = connection.id break - # 1. Create Azure AI agent with the search tool + # 1. Create Azure AI agent with the search tool using SDK directly + # (Azure AI Search tool requires special tool_resources configuration) azure_ai_agent = await agents_client.create_agent( model=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], name="HotelSearchAgent", @@ -70,47 +72,42 @@ async def main() -> None: }, ) - # 2. Create chat client with the existing agent - chat_client = AzureAIAgentClient(agents_client=agents_client, agent_id=azure_ai_agent.id) - try: - async with ChatAgent( - chat_client=chat_client, - # Additional instructions for this specific conversation - instructions=("You are a helpful agent that uses the search tool and index to find hotel information."), - ) as agent: - print("This agent uses raw Azure AI Search tool to search hotel data.\n") - - # 3. Simulate conversation with the agent - user_input = ( - "Use Azure AI search knowledge tool to find detailed information about a winter hotel." - " Use the search tool and index." 
# You can modify prompt to force tool usage - ) - print(f"User: {user_input}") - print("Agent: ", end="", flush=True) - - # Stream the response and collect citations - citations: list[CitationAnnotation] = [] - async for chunk in agent.run_stream(user_input): - if chunk.text: - print(chunk.text, end="", flush=True) - - # Collect citations from Azure AI Search responses - for content in getattr(chunk, "contents", []): - annotations = getattr(content, "annotations", []) - if annotations: - citations.extend(annotations) - - print() - - # Display collected citation - if citations: - print("\n\nCitation:") - for i, citation in enumerate(citations, 1): - print(f"[{i}] {citation.url}") - - print("\n" + "=" * 50 + "\n") - print("Hotel search conversation completed!") + # 2. Use provider.as_agent() to wrap the existing agent + agent = provider.as_agent(agent=azure_ai_agent) + + print("This agent uses raw Azure AI Search tool to search hotel data.\n") + + # 3. Simulate conversation with the agent + user_input = ( + "Use Azure AI search knowledge tool to find detailed information about a winter hotel." + " Use the search tool and index." # You can modify prompt to force tool usage + ) + print(f"User: {user_input}") + print("Agent: ", end="", flush=True) + + # Stream the response and collect citations + citations: list[CitationAnnotation] = [] + async for chunk in agent.run_stream(user_input): + if chunk.text: + print(chunk.text, end="", flush=True) + + # Collect citations from Azure AI Search responses + for content in getattr(chunk, "contents", []): + annotations = getattr(content, "annotations", []) + if annotations: + citations.extend(annotations) + + print() + + # Display collected citation + if citations: + print("\n\nCitation:") + for i, citation in enumerate(citations, 1): + print(f"[{i}] {citation.url}") + + print("\n" + "=" * 50 + "\n") + print("Hotel search conversation completed!") finally: # Clean up the agent manually diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_custom_search.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_custom_search.py index 1ef8d6bcb1..ef41cf7c35 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_custom_search.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_custom_search.py @@ -2,8 +2,8 @@ import asyncio -from agent_framework import ChatAgent, HostedWebSearchTool -from agent_framework.azure import AzureAIAgentClient +from agent_framework import HostedWebSearchTool +from agent_framework.azure import AzureAIAgentsProvider from azure.identity.aio import AzureCliCredential """ @@ -37,19 +37,20 @@ async def main() -> None: description="Search the web for current information using Bing Custom Search", ) - # 2. Use AzureAIAgentClient as async context manager for automatic cleanup + # 2. Use AzureAIAgentsProvider for agent creation and management async with ( - AzureAIAgentClient(credential=AzureCliCredential()) as client, - ChatAgent( - chat_client=client, + AzureCliCredential() as credential, + AzureAIAgentsProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="BingSearchAgent", instructions=( "You are a helpful agent that can use Bing Custom Search tools to assist users. " "Use the available Bing Custom Search tools to answer questions and perform tasks." ), tools=bing_search_tool, - ) as agent, - ): + ) + # 3. 
Demonstrate agent capabilities with bing custom search print("=== Azure AI Agent with Bing Custom Search ===\n") diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_grounding.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_grounding.py index a83f5bb1f4..016c6ddeb8 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_grounding.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_grounding.py @@ -2,8 +2,8 @@ import asyncio -from agent_framework import ChatAgent, HostedWebSearchTool -from agent_framework_azure_ai import AzureAIAgentClient +from agent_framework import HostedWebSearchTool +from agent_framework.azure import AzureAIAgentsProvider from azure.identity.aio import AzureCliCredential """ @@ -32,11 +32,12 @@ async def main() -> None: description="Search the web for current information using Bing", ) - # 2. Use AzureAIAgentClient as async context manager for automatic cleanup + # 2. Use AzureAIAgentsProvider for agent creation and management async with ( - AzureAIAgentClient(credential=AzureCliCredential()) as client, - ChatAgent( - chat_client=client, + AzureCliCredential() as credential, + AzureAIAgentsProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="BingSearchAgent", instructions=( "You are a helpful assistant that can search the web for current information. " @@ -44,9 +45,9 @@ async def main() -> None: "well-sourced answers. Always cite your sources when possible." ), tools=bing_search_tool, - ) as agent, - ): - # 4. Demonstrate agent capabilities with web search + ) + + # 3. Demonstrate agent capabilities with web search print("=== Azure AI Agent with Bing Grounding Search ===\n") user_input = "What is the most popular programming language?" diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_grounding_citations.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_grounding_citations.py index 63245a4d12..752d7e5a54 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_grounding_citations.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_grounding_citations.py @@ -2,8 +2,8 @@ import asyncio -from agent_framework import ChatAgent, CitationAnnotation, HostedWebSearchTool -from agent_framework.azure import AzureAIAgentClient +from agent_framework import CitationAnnotation, HostedWebSearchTool +from agent_framework.azure import AzureAIAgentsProvider from azure.identity.aio import AzureCliCredential """ @@ -34,11 +34,12 @@ async def main() -> None: description="Search the web for current information using Bing", ) - # 2. Use AzureAIAgentClient as async context manager for automatic cleanup + # 2. Use AzureAIAgentsProvider for agent creation and management async with ( - AzureAIAgentClient(credential=AzureCliCredential()) as client, - ChatAgent( - chat_client=client, + AzureCliCredential() as credential, + AzureAIAgentsProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="BingSearchAgent", instructions=( "You are a helpful assistant that can search the web for current information. " @@ -46,8 +47,8 @@ async def main() -> None: "well-sourced answers. Always cite your sources when possible." ), tools=bing_search_tool, - ) as agent, - ): + ) + # 3. 
Demonstrate agent capabilities with web search print("=== Azure AI Agent with Bing Grounding Search ===\n") diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_code_interpreter.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_code_interpreter.py index 0136512373..a40ee17258 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_code_interpreter.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_code_interpreter.py @@ -2,8 +2,8 @@ import asyncio -from agent_framework import AgentRunResponse, ChatResponseUpdate, HostedCodeInterpreterTool -from agent_framework.azure import AzureAIAgentClient +from agent_framework import AgentResponse, ChatResponseUpdate, HostedCodeInterpreterTool +from agent_framework.azure import AzureAIAgentsProvider from azure.ai.agents.models import ( RunStepDeltaCodeInterpreterDetailItemObject, ) @@ -17,7 +17,7 @@ """ -def print_code_interpreter_inputs(response: AgentRunResponse) -> None: +def print_code_interpreter_inputs(response: AgentResponse) -> None: """Helper method to access code interpreter data.""" print("\nCode Interpreter Inputs during the run:") @@ -39,16 +39,16 @@ async def main() -> None: # authentication option. async with ( AzureCliCredential() as credential, - AzureAIAgentClient(credential=credential) as chat_client, + AzureAIAgentsProvider(credential=credential) as provider, ): - agent = chat_client.create_agent( + agent = await provider.create_agent( name="CodingAgent", instructions=("You are a helpful assistant that can write and execute Python code to solve problems."), tools=HostedCodeInterpreterTool(), ) query = "Generate the factorial of 100 using python code, show the code and execute it." print(f"User: {query}") - response = await AgentRunResponse.from_agent_response_generator(agent.run_stream(query)) + response = await agent.run(query) print(f"Agent: {response}") # To review the code interpreter outputs, you can access # them from the response raw_representations, just uncomment the next line: diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_code_interpreter_file_generation.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_code_interpreter_file_generation.py index cbd64bc5a7..665c707adc 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_code_interpreter_file_generation.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_code_interpreter_file_generation.py @@ -1,15 +1,21 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio - -from agent_framework import AgentRunResponseUpdate, ChatAgent, HostedCodeInterpreterTool, HostedFileContent -from agent_framework.azure import AzureAIAgentClient +import os + +from agent_framework import ( + AgentResponseUpdate, + HostedCodeInterpreterTool, + HostedFileContent, +) +from agent_framework.azure import AzureAIAgentsProvider +from azure.ai.agents.aio import AgentsClient from azure.identity.aio import AzureCliCredential """ Azure AI Agent Code Interpreter File Generation Example -This sample demonstrates using HostedCodeInterpreterTool with AzureAIAgentClient +This sample demonstrates using HostedCodeInterpreterTool with AzureAIAgentsProvider to generate a text file and then retrieve it. 
The test flow: @@ -23,79 +29,77 @@ async def main() -> None: """Test file generation and retrieval with code interpreter.""" - async with AzureCliCredential() as credential: - client = AzureAIAgentClient(credential=credential) - - try: - async with ChatAgent( - chat_client=client, - instructions=( - "You are a Python code execution assistant. " - "ALWAYS use the code interpreter tool to execute Python code when asked to create files. " - "Write actual Python code to create files, do not just describe what you would do." - ), - tools=[HostedCodeInterpreterTool()], - ) as agent: - # Be very explicit about wanting code execution and a download link - query = ( - "Use the code interpreter to execute this Python code and then provide me " - "with a download link for the generated file:\n" - "```python\n" - "with open('/mnt/data/sample.txt', 'w') as f:\n" - " f.write('Hello, World! This is a test file.')\n" - "'/mnt/data/sample.txt'\n" # Return the path so it becomes downloadable - "```" - ) - print(f"User: {query}\n") - print("=" * 60) - - # Collect file_ids from the response - file_ids: list[str] = [] - - async for chunk in agent.run_stream(query): - if not isinstance(chunk, AgentRunResponseUpdate): - continue - - for content in chunk.contents: - if content.type == "text": - print(content.text, end="", flush=True) - elif content.type == "hosted_file": - if isinstance(content, HostedFileContent): - file_ids.append(content.file_id) - print(f"\n[File generated: {content.file_id}]") - - print("\n" + "=" * 60) - - # Attempt to retrieve discovered files - if file_ids: - print(f"\nAttempting to retrieve {len(file_ids)} file(s):") - for file_id in file_ids: - try: - file_info = await client.agents_client.files.get(file_id) - print(f" File {file_id}: Retrieved successfully") - print(f" Filename: {file_info.filename}") - print(f" Purpose: {file_info.purpose}") - print(f" Bytes: {file_info.bytes}") - except Exception as e: - print(f" File {file_id}: FAILED to retrieve - {e}") - else: - print("No file IDs were captured from the response.") - - # List all files to see if any exist - print("\nListing all files in the agent service:") + async with ( + AzureCliCredential() as credential, + AgentsClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as agents_client, + AzureAIAgentsProvider(agents_client=agents_client) as provider, + ): + agent = await provider.create_agent( + name="CodeInterpreterAgent", + instructions=( + "You are a Python code execution assistant. " + "ALWAYS use the code interpreter tool to execute Python code when asked to create files. " + "Write actual Python code to create files, do not just describe what you would do." + ), + tools=[HostedCodeInterpreterTool()], + ) + + # Be very explicit about wanting code execution and a download link + query = ( + "Use the code interpreter to execute this Python code and then provide me " + "with a download link for the generated file:\n" + "```python\n" + "with open('/mnt/data/sample.txt', 'w') as f:\n" + " f.write('Hello, World! 
This is a test file.')\n" + "'/mnt/data/sample.txt'\n" # Return the path so it becomes downloadable + "```" + ) + print(f"User: {query}\n") + print("=" * 60) + + # Collect file_ids from the response + file_ids: list[str] = [] + + async for chunk in agent.run_stream(query): + if not isinstance(chunk, AgentResponseUpdate): + continue + + for content in chunk.contents: + if content.type == "text": + print(content.text, end="", flush=True) + elif content.type == "hosted_file" and isinstance(content, HostedFileContent): + file_ids.append(content.file_id) + print(f"\n[File generated: {content.file_id}]") + + print("\n" + "=" * 60) + + # Attempt to retrieve discovered files + if file_ids: + print(f"\nAttempting to retrieve {len(file_ids)} file(s):") + for file_id in file_ids: try: - files_list = await client.agents_client.files.list() - count = 0 - for file_info in files_list.data: - count += 1 - print(f" - {file_info.id}: {file_info.filename} ({file_info.purpose})") - if count == 0: - print(" No files found.") + file_info = await agents_client.files.get(file_id) + print(f" File {file_id}: Retrieved successfully") + print(f" Filename: {file_info.filename}") + print(f" Purpose: {file_info.purpose}") + print(f" Bytes: {file_info.bytes}") except Exception as e: - print(f" Failed to list files: {e}") + print(f" File {file_id}: FAILED to retrieve - {e}") + else: + print("No file IDs were captured from the response.") - finally: - await client.close() + # List all files to see if any exist + print("\nListing all files in the agent service:") + try: + files_list = await agents_client.files.list() + count = 0 + for file_info in files_list.data: + count += 1 + print(f" - {file_info.id}: {file_info.filename} ({file_info.purpose})") + if count == 0: + print(" No files found.") + except Exception as e: + print(f" Failed to list files: {e}") if __name__ == "__main__": diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_existing_agent.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_existing_agent.py index f35ac2412a..9518498098 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_existing_agent.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_existing_agent.py @@ -3,8 +3,7 @@ import asyncio import os -from agent_framework import ChatAgent -from agent_framework.azure import AzureAIAgentClient +from agent_framework.azure import AzureAIAgentsProvider from azure.ai.agents.aio import AgentsClient from azure.identity.aio import AzureCliCredential @@ -17,37 +16,29 @@ async def main() -> None: - print("=== Azure AI Chat Client with Existing Agent ===") + print("=== Azure AI Agent with Existing Agent ===") - # Create the client + # Create the client and provider async with ( AzureCliCredential() as credential, AgentsClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as agents_client, + AzureAIAgentsProvider(agents_client=agents_client) as provider, ): + # Create an agent on the service with default instructions + # These instructions will persist on created agent for every run. azure_ai_agent = await agents_client.create_agent( model=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], - # Create remote agent with default instructions - # These instructions will persist on created agent for every run. 
instructions="End each response with [END].", ) - chat_client = AzureAIAgentClient(agents_client=agents_client, agent_id=azure_ai_agent.id) - try: - async with ChatAgent( - chat_client=chat_client, - # Instructions here are applicable only to this ChatAgent instance - # These instructions will be combined with instructions on existing remote agent. - # The final instructions during the execution will look like: - # "'End each response with [END]. Respond with 'Hello World' only'" - instructions="Respond with 'Hello World' only", - ) as agent: - query = "How are you?" - print(f"User: {query}") - result = await agent.run(query) - # Based on local and remote instructions, the result will be - # 'Hello World [END]'. - print(f"Agent: {result}\n") + # Wrap existing agent instance using provider.as_agent() + agent = provider.as_agent(azure_ai_agent) + + query = "How are you?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}\n") finally: # Clean up the agent manually await agents_client.delete_agent(azure_ai_agent.id) diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_existing_thread.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_existing_thread.py index b96b6e5686..a05aca5eba 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_existing_thread.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_existing_thread.py @@ -5,8 +5,7 @@ from random import randint from typing import Annotated -from agent_framework import ChatAgent -from agent_framework.azure import AzureAIAgentClient +from agent_framework.azure import AzureAIAgentsProvider from azure.ai.agents.aio import AgentsClient from azure.identity.aio import AzureCliCredential from pydantic import Field @@ -28,28 +27,29 @@ def get_weather( async def main() -> None: - print("=== Azure AI Chat Client with Existing Thread ===") + print("=== Azure AI Agent with Existing Thread ===") - # Create the client + # Create the client and provider async with ( AzureCliCredential() as credential, AgentsClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as agents_client, + AzureAIAgentsProvider(agents_client=agents_client) as provider, ): - # Create an thread that will persist + # Create a thread that will persist created_thread = await agents_client.threads.create() try: - async with ChatAgent( - # passing in the client is optional here, so if you take the agent_id from the portal - # you can use it directly without the two lines above. 
- chat_client=AzureAIAgentClient(agents_client=agents_client), + # Create agent using provider + agent = await provider.create_agent( + name="WeatherAgent", instructions="You are a helpful weather agent.", tools=get_weather, - ) as agent: - thread = agent.get_new_thread(service_thread_id=created_thread.id) - assert thread.is_initialized - result = await agent.run("What's the weather like in Tokyo?", thread=thread) - print(f"Result: {result}\n") + ) + + thread = agent.get_new_thread(service_thread_id=created_thread.id) + assert thread.is_initialized + result = await agent.run("What's the weather like in Tokyo?", thread=thread) + print(f"Result: {result}\n") finally: # Clean up the thread manually await agents_client.threads.delete(created_thread.id) diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_explicit_settings.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_explicit_settings.py index 14bb063149..bb0405cd6f 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_explicit_settings.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_explicit_settings.py @@ -5,8 +5,7 @@ from random import randint from typing import Annotated -from agent_framework import ChatAgent -from agent_framework.azure import AzureAIAgentClient +from agent_framework.azure import AzureAIAgentsProvider from azure.identity.aio import AzureCliCredential from pydantic import Field @@ -27,26 +26,23 @@ def get_weather( async def main() -> None: - print("=== Azure AI Chat Client with Explicit Settings ===") + print("=== Azure AI Agent with Explicit Settings ===") - # Since no Agent ID is provided, the agent will be automatically created - # and deleted after getting a response # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. async with ( AzureCliCredential() as credential, - ChatAgent( - chat_client=AzureAIAgentClient( - project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], - model_deployment_name=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], - credential=credential, - agent_name="WeatherAgent", - should_cleanup_agent=True, # Set to False if you want to disable automatic agent cleanup - ), + AzureAIAgentsProvider( + project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], + credential=credential, + ) as provider, + ): + agent = await provider.create_agent( + name="WeatherAgent", + model=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], instructions="You are a helpful weather agent.", tools=get_weather, - ) as agent, - ): + ) result = await agent.run("What's the weather like in New York?") print(f"Result: {result}\n") diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_file_search.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_file_search.py index 8be9b79423..63845b215b 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_file_search.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_file_search.py @@ -1,10 +1,12 @@ # Copyright (c) Microsoft. All rights reserved. 
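The existing-thread sample above is the canonical shape for attaching an agent to a pre-existing service thread. Reduced to its essentials — a sketch that assumes `agents_client` and `provider` are already open, exactly as in the sample's `async with` block:

```python
# Sketch: reuse a service-side thread with a provider-created agent.
created_thread = await agents_client.threads.create()
try:
    agent = await provider.create_agent(
        name="WeatherAgent",
        instructions="You are a helpful weather agent.",
    )
    # Attach to the existing thread by ID instead of creating a new one
    thread = agent.get_new_thread(service_thread_id=created_thread.id)
    assert thread.is_initialized
    result = await agent.run("What's the weather like in Tokyo?", thread=thread)
    print(result)
finally:
    # The thread is owned by the caller, so delete it explicitly
    await agents_client.threads.delete(created_thread.id)
```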
diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_file_search.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_file_search.py
index 8be9b79423..63845b215b 100644
--- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_file_search.py
+++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_file_search.py
@@ -1,10 +1,12 @@
 # Copyright (c) Microsoft. All rights reserved.
 
 import asyncio
+import os
 from pathlib import Path
 
-from agent_framework import ChatAgent, HostedFileSearchTool, HostedVectorStoreContent
-from agent_framework.azure import AzureAIAgentClient
+from agent_framework import HostedFileSearchTool, HostedVectorStoreContent
+from agent_framework.azure import AzureAIAgentsProvider
+from azure.ai.agents.aio import AgentsClient
 from azure.ai.agents.models import FileInfo, VectorStore
 from azure.identity.aio import AzureCliCredential
 
@@ -24,67 +26,54 @@ async def main() -> None:
     """Main function demonstrating Azure AI agent with file search capabilities."""
-    client = AzureAIAgentClient(credential=AzureCliCredential())
     file: FileInfo | None = None
     vector_store: VectorStore | None = None
 
-    try:
-        # 1. Upload file and create vector store
-        pdf_file_path = Path(__file__).parent.parent / "resources" / "employees.pdf"
-        print(f"Uploading file from: {pdf_file_path}")
-
-        file = await client.agents_client.files.upload_and_poll(file_path=str(pdf_file_path), purpose="assistants")
-        print(f"Uploaded file, file ID: {file.id}")
-
-        vector_store = await client.agents_client.vector_stores.create_and_poll(
-            file_ids=[file.id], name="my_vectorstore"
-        )
-        print(f"Created vector store, vector store ID: {vector_store.id}")
-
-        # 2. Create file search tool with uploaded resources
-        file_search_tool = HostedFileSearchTool(inputs=[HostedVectorStoreContent(vector_store_id=vector_store.id)])
-
-        # 3. Create an agent with file search capabilities
-        # The tool_resources are automatically extracted from HostedFileSearchTool
-        async with ChatAgent(
-            chat_client=client,
-            name="EmployeeSearchAgent",
-            instructions=(
-                "You are a helpful assistant that can search through uploaded employee files "
-                "to answer questions about employees."
-            ),
-            tools=file_search_tool,
-        ) as agent:
+    async with (
+        AzureCliCredential() as credential,
+        AgentsClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as agents_client,
+        AzureAIAgentsProvider(agents_client=agents_client) as provider,
+    ):
+        try:
+            # 1. Upload file and create vector store
+            pdf_file_path = Path(__file__).parent.parent / "resources" / "employees.pdf"
+            print(f"Uploading file from: {pdf_file_path}")
+
+            file = await agents_client.files.upload_and_poll(file_path=str(pdf_file_path), purpose="assistants")
+            print(f"Uploaded file, file ID: {file.id}")
+
+            vector_store = await agents_client.vector_stores.create_and_poll(file_ids=[file.id], name="my_vectorstore")
+            print(f"Created vector store, vector store ID: {vector_store.id}")
+
+            # 2. Create file search tool with uploaded resources
+            file_search_tool = HostedFileSearchTool(inputs=[HostedVectorStoreContent(vector_store_id=vector_store.id)])
+
+            # 3. Create an agent with file search capabilities
+            agent = await provider.create_agent(
+                name="EmployeeSearchAgent",
+                instructions=(
+                    "You are a helpful assistant that can search through uploaded employee files "
+                    "to answer questions about employees."
+                ),
+                tools=file_search_tool,
+            )
+
             # 4. Simulate conversation with the agent
             for user_input in USER_INPUTS:
                 print(f"# User: '{user_input}'")
                 response = await agent.run(user_input)
                 print(f"# Agent: {response.text}")
 
+        finally:
             # 5. Cleanup: Delete the vector store and file
             try:
                 if vector_store:
-                    await client.agents_client.vector_stores.delete(vector_store.id)
+                    await agents_client.vector_stores.delete(vector_store.id)
                 if file:
-                    await client.agents_client.files.delete(file.id)
+                    await agents_client.files.delete(file.id)
             except Exception:
                 # Ignore cleanup errors to avoid masking issues
                 pass
 
-    finally:
-        # 6. Cleanup: Delete the vector store and file in case of earlier failure to prevent orphaned resources.
-
-        # Refreshing the client is required since chat agent closes it
-        client = AzureAIAgentClient(credential=AzureCliCredential())
-        try:
-            if vector_store:
-                await client.agents_client.vector_stores.delete(vector_store.id)
-            if file:
-                await client.agents_client.files.delete(file.id)
-        except Exception:
-            # Ignore cleanup errors to avoid masking issues
-            pass
-        finally:
-            await client.close()
 
 
 if __name__ == "__main__":
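Stripped of logging and cleanup, the file-search wiring above reduces to three calls. A sketch assuming the same open `agents_client`/`provider` pair; the query is illustrative:

```python
# Sketch: upload -> vector store -> HostedFileSearchTool, per the sample above.
from agent_framework import HostedFileSearchTool, HostedVectorStoreContent

file = await agents_client.files.upload_and_poll(file_path="employees.pdf", purpose="assistants")
vector_store = await agents_client.vector_stores.create_and_poll(file_ids=[file.id], name="my_vectorstore")

agent = await provider.create_agent(
    name="EmployeeSearchAgent",
    instructions="Answer questions using the uploaded employee files.",
    tools=HostedFileSearchTool(inputs=[HostedVectorStoreContent(vector_store_id=vector_store.id)]),
)
response = await agent.run("What can you tell me about the employees?")
```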
print(f"User: {query}") diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_hosted_mcp.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_hosted_mcp.py index 10a5a68031..71ab02b279 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_hosted_mcp.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_hosted_mcp.py @@ -3,8 +3,8 @@ import asyncio from typing import Any -from agent_framework import AgentProtocol, AgentRunResponse, AgentThread, HostedMCPTool -from agent_framework.azure import AzureAIAgentClient +from agent_framework import AgentProtocol, AgentResponse, AgentThread, HostedMCPTool +from agent_framework.azure import AzureAIAgentsProvider from azure.identity.aio import AzureCliCredential """ @@ -15,7 +15,7 @@ """ -async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread") -> AgentRunResponse: +async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread") -> AgentResponse: """Here we let the thread deal with the previous responses, and we just rerun with the approval.""" from agent_framework import ChatMessage @@ -42,9 +42,9 @@ async def main() -> None: """Example showing Hosted MCP tools for a Azure AI Agent.""" async with ( AzureCliCredential() as credential, - AzureAIAgentClient(credential=credential) as chat_client, + AzureAIAgentsProvider(credential=credential) as provider, ): - agent = chat_client.create_agent( + agent = await provider.create_agent( name="DocsAgent", instructions="You are a helpful assistant that can help with microsoft documentation questions.", tools=HostedMCPTool( diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_local_mcp.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_local_mcp.py index fb4f49e47e..0586ffb78e 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_local_mcp.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_local_mcp.py @@ -2,8 +2,8 @@ import asyncio -from agent_framework import ChatAgent, MCPStreamableHTTPTool -from agent_framework.azure import AzureAIAgentClient +from agent_framework import MCPStreamableHTTPTool +from agent_framework.azure import AzureAIAgentsProvider from azure.identity.aio import AzureCliCredential """ @@ -27,12 +27,12 @@ async def mcp_tools_on_run_level() -> None: name="Microsoft Learn MCP", url="https://learn.microsoft.com/api/mcp", ) as mcp_server, - ChatAgent( - chat_client=AzureAIAgentClient(credential=credential), + AzureAIAgentsProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="DocsAgent", instructions="You are a helpful assistant that can help with microsoft documentation questions.", - ) as agent, - ): + ) # First query query1 = "How to create an Azure storage account using az cli?" print(f"User: {query1}") @@ -47,34 +47,37 @@ async def mcp_tools_on_run_level() -> None: async def mcp_tools_on_agent_level() -> None: - """Example showing tools defined when creating the agent.""" + """Example showing local MCP tools passed when creating the agent.""" print("=== Tools Defined on Agent Level ===") # Tools are provided when creating the agent - # The agent can use these tools for any query during its lifetime - # The agent will connect to the MCP server through its context manager. 
+ # The ChatAgent will connect to the MCP server through its context manager + # and discover tools at runtime async with ( AzureCliCredential() as credential, - AzureAIAgentClient(credential=credential).create_agent( + AzureAIAgentsProvider(credential=credential) as provider, + ): + agent = await provider.create_agent( name="DocsAgent", instructions="You are a helpful assistant that can help with microsoft documentation questions.", - tools=MCPStreamableHTTPTool( # Tools defined at agent creation + tools=MCPStreamableHTTPTool( name="Microsoft Learn MCP", url="https://learn.microsoft.com/api/mcp", ), - ) as agent, - ): - # First query - query1 = "How to create an Azure storage account using az cli?" - print(f"User: {query1}") - result1 = await agent.run(query1) - print(f"{agent.name}: {result1}\n") - print("\n=======================================\n") - # Second query - query2 = "What is Microsoft Agent Framework?" - print(f"User: {query2}") - result2 = await agent.run(query2) - print(f"{agent.name}: {result2}\n") + ) + # Use agent as context manager to connect MCP tools + async with agent: + # First query + query1 = "How to create an Azure storage account using az cli?" + print(f"User: {query1}") + result1 = await agent.run(query1) + print(f"{agent.name}: {result1}\n") + print("\n=======================================\n") + # Second query + query2 = "What is Microsoft Agent Framework?" + print(f"User: {query2}") + result2 = await agent.run(query2) + print(f"{agent.name}: {result2}\n") async def main() -> None: diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_multiple_tools.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_multiple_tools.py index ab29d85971..e3c28118be 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_multiple_tools.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_multiple_tools.py @@ -10,7 +10,7 @@ HostedMCPTool, HostedWebSearchTool, ) -from agent_framework.azure import AzureAIAgentClient +from agent_framework.azure import AzureAIAgentsProvider from azure.identity.aio import AzureCliCredential """ @@ -67,9 +67,9 @@ async def main() -> None: """Example showing Hosted MCP tools for a Azure AI Agent.""" async with ( AzureCliCredential() as credential, - AzureAIAgentClient(credential=credential) as chat_client, + AzureAIAgentsProvider(credential=credential) as provider, ): - agent = chat_client.create_agent( + agent = await provider.create_agent( name="DocsAgent", instructions="You are a helpful assistant that can help with microsoft documentation questions.", tools=[ diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_openapi_tools.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_openapi_tools.py index 4b4db76f6b..24fd8eba9a 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_openapi_tools.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_openapi_tools.py @@ -5,8 +5,7 @@ from pathlib import Path from typing import Any -from agent_framework import ChatAgent -from agent_framework_azure_ai import AzureAIAgentClient +from agent_framework.azure import AzureAIAgentsProvider from azure.ai.agents.models import OpenApiAnonymousAuthDetails, OpenApiTool from azure.identity.aio import AzureCliCredential @@ -40,8 +39,11 @@ async def main() -> None: # 1. 
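The two MCP samples above differ only in where the MCP server connection lives. A side-by-side sketch, assuming `provider` is an open `AzureAIAgentsProvider`; note that the `HostedMCPTool` constructor arguments here are an assumption inferred from the `MCPStreamableHTTPTool` usage in this diff:

```python
# Sketch: hosted vs. local MCP tools, per the two samples above.
from agent_framework import HostedMCPTool, MCPStreamableHTTPTool

# Hosted: the service connects to the MCP server; no local connection needed.
hosted_agent = await provider.create_agent(
    name="DocsAgent",
    instructions="You are a helpful assistant that can help with microsoft documentation questions.",
    tools=HostedMCPTool(name="Microsoft Learn MCP", url="https://learn.microsoft.com/api/mcp"),
)

# Local: the agent process connects, so enter the agent as a context manager.
local_agent = await provider.create_agent(
    name="DocsAgent",
    instructions="You are a helpful assistant that can help with microsoft documentation questions.",
    tools=MCPStreamableHTTPTool(name="Microsoft Learn MCP", url="https://learn.microsoft.com/api/mcp"),
)
async with local_agent:  # connects the local MCP tool
    result = await local_agent.run("What is Microsoft Agent Framework?")
```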
diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_openapi_tools.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_openapi_tools.py
index 4b4db76f6b..24fd8eba9a 100644
--- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_openapi_tools.py
+++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_openapi_tools.py
@@ -5,8 +5,7 @@
 from pathlib import Path
 from typing import Any
 
-from agent_framework import ChatAgent
-from agent_framework_azure_ai import AzureAIAgentClient
+from agent_framework.azure import AzureAIAgentsProvider
 from azure.ai.agents.models import OpenApiAnonymousAuthDetails, OpenApiTool
 from azure.identity.aio import AzureCliCredential
 
@@ -40,8 +39,11 @@ async def main() -> None:
     # 1. Load OpenAPI specifications (synchronous operation)
     weather_openapi_spec, countries_openapi_spec = load_openapi_specs()
 
-    # 2. Use AzureAIAgentClient as async context manager for automatic cleanup
-    async with AzureAIAgentClient(credential=AzureCliCredential()) as client:
+    # 2. Use AzureAIAgentsProvider for agent creation and management
+    async with (
+        AzureCliCredential() as credential,
+        AzureAIAgentsProvider(credential=credential) as provider,
+    ):
         # 3. Create OpenAPI tools using Azure AI's OpenApiTool
         auth = OpenApiAnonymousAuthDetails()
 
@@ -62,8 +64,7 @@ async def main() -> None:
         # 4. Create an agent with OpenAPI tools
         # Note: We need to pass the Azure AI native OpenApiTool definitions directly
         # since the agent framework doesn't have a HostedOpenApiTool wrapper yet
-        async with ChatAgent(
-            chat_client=client,
+        agent = await provider.create_agent(
             name="OpenAPIAgent",
             instructions=(
                 "You are a helpful assistant that can search for country information "
@@ -73,18 +74,19 @@ async def main() -> None:
             ),
             # Pass the raw tool definitions from Azure AI's OpenApiTool
             tools=[*openapi_countries.definitions, *openapi_weather.definitions],
-        ) as agent:
-            # 5. Simulate conversation with the agent maintaining thread context
-            print("=== Azure AI Agent with OpenAPI Tools ===\n")
-
-            # Create a thread to maintain conversation context across multiple runs
-            thread = agent.get_new_thread()
-
-            for user_input in USER_INPUTS:
-                print(f"User: {user_input}")
-                # Pass the thread to maintain context across multiple agent.run() calls
-                response = await agent.run(user_input, thread=thread)
-                print(f"Agent: {response.text}\n")
+        )
+
+        # 5. Simulate conversation with the agent maintaining thread context
+        print("=== Azure AI Agent with OpenAPI Tools ===\n")
+
+        # Create a thread to maintain conversation context across multiple runs
+        thread = agent.get_new_thread()
+
+        for user_input in USER_INPUTS:
+            print(f"User: {user_input}")
+            # Pass the thread to maintain context across multiple agent.run() calls
+            response = await agent.run(user_input, thread=thread)
+            print(f"Agent: {response.text}\n")
 
 
 if __name__ == "__main__":
diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_response_format.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_response_format.py
new file mode 100644
index 0000000000..639ce6ac82
--- /dev/null
+++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_response_format.py
@@ -0,0 +1,83 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+import asyncio
+
+from agent_framework.azure import AzureAIAgentsProvider
+from azure.identity.aio import AzureCliCredential
+from pydantic import BaseModel, ConfigDict
+
+"""
+Azure AI Agent Provider Response Format Example
+
+This sample demonstrates using AzureAIAgentsProvider with response_format
+for structured outputs in two ways:
+1. Setting default response_format at agent creation time (default_options)
+2. Overriding response_format at runtime (options parameter in agent.run)
+"""
+
+
+class WeatherInfo(BaseModel):
+    """Structured weather information."""
+
+    location: str
+    temperature: int
+    conditions: str
+    recommendation: str
+    model_config = ConfigDict(extra="forbid")
+
+
+class CityInfo(BaseModel):
+    """Structured city information."""
+
+    city_name: str
+    population: int
+    country: str
+    model_config = ConfigDict(extra="forbid")
+
+
+async def main() -> None:
+    """Example of using response_format at creation time and runtime."""
+
+    async with (
+        AzureCliCredential() as credential,
+        AzureAIAgentsProvider(credential=credential) as provider,
+    ):
+        # Create agent with default response_format (WeatherInfo)
+        agent = await provider.create_agent(
+            name="StructuredReporter",
+            instructions="Return structured JSON based on the requested format.",
+            default_options={"response_format": WeatherInfo},
+        )
+
+        # Request 1: Uses default response_format from agent creation
+        print("--- Request 1: Using default response_format (WeatherInfo) ---")
+        query1 = "What's the weather like in Paris today?"
+        print(f"User: {query1}")
+
+        result1 = await agent.run(query1)
+
+        if isinstance(result1.value, WeatherInfo):
+            weather = result1.value
+            print("Agent:")
+            print(f"  Location: {weather.location}")
+            print(f"  Temperature: {weather.temperature}")
+            print(f"  Conditions: {weather.conditions}")
+            print(f"  Recommendation: {weather.recommendation}")
+
+        # Request 2: Override response_format at runtime with CityInfo
+        print("\n--- Request 2: Runtime override with CityInfo ---")
+        query2 = "Tell me about Tokyo."
+        print(f"User: {query2}")
+
+        result2 = await agent.run(query2, options={"response_format": CityInfo})
+
+        if isinstance(result2.value, CityInfo):
+            city = result2.value
+            print("Agent:")
+            print(f"  City: {city.city_name}")
+            print(f"  Population: {city.population}")
+            print(f"  Country: {city.country}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
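The new response-format sample boils down to two hooks. A fragment-level sketch, assuming the `WeatherInfo`/`CityInfo` models and open `provider` from the sample above:

```python
# Sketch: default vs. per-run response_format, per the new sample above.
agent = await provider.create_agent(
    name="StructuredReporter",
    instructions="Return structured JSON based on the requested format.",
    default_options={"response_format": WeatherInfo},  # applies to every run
)

r1 = await agent.run("What's the weather like in Paris today?")
assert isinstance(r1.value, WeatherInfo)  # parsed via the default format

r2 = await agent.run("Tell me about Tokyo.", options={"response_format": CityInfo})
assert isinstance(r2.value, CityInfo)  # per-run override wins
```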
diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_thread.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_thread.py
index fbc34e52df..db1911fcad 100644
--- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_thread.py
+++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_thread.py
@@ -4,8 +4,8 @@
 from random import randint
 from typing import Annotated
 
-from agent_framework import AgentThread, ChatAgent
-from agent_framework.azure import AzureAIAgentClient
+from agent_framework import AgentThread
+from agent_framework.azure import AzureAIAgentsProvider
 from azure.identity.aio import AzureCliCredential
 from pydantic import Field
 
@@ -33,12 +33,14 @@ async def example_with_automatic_thread_creation() -> None:
     # authentication option.
     async with (
         AzureCliCredential() as credential,
-        ChatAgent(
-            chat_client=AzureAIAgentClient(credential=credential),
+        AzureAIAgentsProvider(credential=credential) as provider,
+    ):
+        agent = await provider.create_agent(
+            name="WeatherAgent",
             instructions="You are a helpful weather agent.",
             tools=get_weather,
-        ) as agent,
-    ):
+        )
+
         # First conversation - no thread provided, will be created automatically
         first_query = "What's the weather like in Seattle?"
         print(f"User: {first_query}")
@@ -62,12 +64,14 @@ async def example_with_thread_persistence() -> None:
     # authentication option.
     async with (
         AzureCliCredential() as credential,
-        ChatAgent(
-            chat_client=AzureAIAgentClient(credential=credential),
+        AzureAIAgentsProvider(credential=credential) as provider,
+    ):
+        agent = await provider.create_agent(
+            name="WeatherAgent",
             instructions="You are a helpful weather agent.",
             tools=get_weather,
-        ) as agent,
-    ):
+        )
+
         # Create a new thread that will be reused
         thread = agent.get_new_thread()
 
@@ -103,12 +107,14 @@ async def example_with_existing_thread_id() -> None:
     # authentication option.
     async with (
         AzureCliCredential() as credential,
-        ChatAgent(
-            chat_client=AzureAIAgentClient(credential=credential),
+        AzureAIAgentsProvider(credential=credential) as provider,
+    ):
+        agent = await provider.create_agent(
+            name="WeatherAgent",
             instructions="You are a helpful weather agent.",
             tools=get_weather,
-        ) as agent,
-    ):
+        )
+
         # Start a conversation and get the thread ID
         thread = agent.get_new_thread()
         first_query = "What's the weather in Paris?"
@@ -123,15 +129,17 @@ async def example_with_existing_thread_id() -> None:
 
     if existing_thread_id:
         print("\n--- Continuing with the same thread ID in a new agent instance ---")
-        # Create a new agent instance but use the existing thread ID
+        # Create a new provider and agent but use the existing thread ID
         async with (
             AzureCliCredential() as credential,
-            ChatAgent(
-                chat_client=AzureAIAgentClient(thread_id=existing_thread_id, credential=credential),
+            AzureAIAgentsProvider(credential=credential) as provider,
+        ):
+            agent = await provider.create_agent(
+                name="WeatherAgent",
                 instructions="You are a helpful weather agent.",
                 tools=get_weather,
-            ) as agent,
-        ):
+            )
+
             # Create a thread with the existing ID
             thread = AgentThread(service_thread_id=existing_thread_id)
diff --git a/python/samples/getting_started/agents/azure_openai/azure_assistants_with_code_interpreter.py b/python/samples/getting_started/agents/azure_openai/azure_assistants_with_code_interpreter.py
index af07cabd75..b37af8f8de 100644
--- a/python/samples/getting_started/agents/azure_openai/azure_assistants_with_code_interpreter.py
+++ b/python/samples/getting_started/agents/azure_openai/azure_assistants_with_code_interpreter.py
@@ -2,7 +2,7 @@
 
 import asyncio
 
-from agent_framework import AgentRunResponseUpdate, ChatAgent, ChatResponseUpdate, HostedCodeInterpreterTool
+from agent_framework import AgentResponseUpdate, ChatAgent, ChatResponseUpdate, HostedCodeInterpreterTool
 from agent_framework.azure import AzureOpenAIAssistantsClient
 from azure.identity import AzureCliCredential
 from openai.types.beta.threads.runs import (
@@ -21,7 +21,7 @@
 """
 
 
-def get_code_interpreter_chunk(chunk: AgentRunResponseUpdate) -> str | None:
+def get_code_interpreter_chunk(chunk: AgentResponseUpdate) -> str | None:
     """Helper method to access code interpreter data."""
     if (
         isinstance(chunk.raw_representation, ChatResponseUpdate)
diff --git a/python/samples/getting_started/agents/custom/custom_agent.py b/python/samples/getting_started/agents/custom/custom_agent.py
index 9c6e790513..5dc050a1b5 100644
--- a/python/samples/getting_started/agents/custom/custom_agent.py
+++ b/python/samples/getting_started/agents/custom/custom_agent.py
@@ -5,8 +5,8 @@
 from typing import Any
 
 from agent_framework import (
-    AgentRunResponse,
-    AgentRunResponseUpdate,
+    AgentResponse,
+    AgentResponseUpdate,
     AgentThread,
     BaseAgent,
     ChatMessage,
@@ -60,7 +60,7 @@ async def run(
         *,
         thread: AgentThread | None = None,
         **kwargs: Any,
-    ) -> AgentRunResponse:
+    ) -> AgentResponse:
         """Execute the agent and return a complete response.
 
         Args:
@@ -69,7 +69,7 @@ async def run(
             **kwargs: Additional keyword arguments.
 
         Returns:
-            An AgentRunResponse containing the agent's reply.
+            An AgentResponse containing the agent's reply.
         """
         # Normalize input messages to a list
         normalized_messages = self._normalize_messages(messages)
@@ -93,7 +93,7 @@ async def run(
         if thread is not None:
             await self._notify_thread_of_new_messages(thread, normalized_messages, response_message)
 
-        return AgentRunResponse(messages=[response_message])
+        return AgentResponse(messages=[response_message])
 
     async def run_stream(
         self,
@@ -101,7 +101,7 @@ async def run_stream(
         *,
         thread: AgentThread | None = None,
         **kwargs: Any,
-    ) -> AsyncIterable[AgentRunResponseUpdate]:
+    ) -> AsyncIterable[AgentResponseUpdate]:
         """Execute the agent and yield streaming response updates.
 
         Args:
@@ -110,7 +110,7 @@ async def run_stream(
             **kwargs: Additional keyword arguments.
 
         Yields:
-            AgentRunResponseUpdate objects containing chunks of the response.
+            AgentResponseUpdate objects containing chunks of the response.
         """
         # Normalize input messages to a list
         normalized_messages = self._normalize_messages(messages)
@@ -131,7 +131,7 @@ async def run_stream(
             # Add space before word except for the first one
             chunk_text = f" {word}" if i > 0 else word
 
-            yield AgentRunResponseUpdate(
+            yield AgentResponseUpdate(
                 contents=[TextContent(text=chunk_text)],
                 role=Role.ASSISTANT,
             )
@@ -158,7 +158,6 @@ async def main() -> None:
     # Test non-streaming
     print(f"Agent Name: {echo_agent.name}")
     print(f"Agent ID: {echo_agent.id}")
-    print(f"Display Name: {echo_agent.display_name}")
 
     query = "Hello, custom agent!"
     print(f"\nUser: {query}")
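For custom agents, the rename above is mechanical: every `AgentRunResponse*` type becomes `AgentResponse*`. A minimal skeleton under that renaming — simplified from the sample (message normalization, thread notification, and exact base-class signatures are omitted or assumed), so it is a sketch rather than a drop-in implementation:

```python
# Skeleton of a custom agent after the AgentRunResponse -> AgentResponse rename.
from collections.abc import AsyncIterable
from typing import Any

from agent_framework import (
    AgentResponse,
    AgentResponseUpdate,
    AgentThread,
    BaseAgent,
    ChatMessage,
    Role,
    TextContent,
)


class ShoutAgent(BaseAgent):
    """Replies with a fixed, uppercased greeting."""

    async def run(self, messages=None, *, thread: AgentThread | None = None, **kwargs: Any) -> AgentResponse:
        reply = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="HELLO!")])
        return AgentResponse(messages=[reply])

    async def run_stream(
        self, messages=None, *, thread: AgentThread | None = None, **kwargs: Any
    ) -> AsyncIterable[AgentResponseUpdate]:
        yield AgentResponseUpdate(contents=[TextContent(text="HELLO!")], role=Role.ASSISTANT)
```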
diff --git a/python/samples/getting_started/agents/custom/custom_chat_client.py b/python/samples/getting_started/agents/custom/custom_chat_client.py
index 5cad52c755..f604571470 100644
--- a/python/samples/getting_started/agents/custom/custom_chat_client.py
+++ b/python/samples/getting_started/agents/custom/custom_chat_client.py
@@ -2,13 +2,13 @@
 
 import asyncio
 import random
+import sys
 from collections.abc import AsyncIterable, MutableSequence
-from typing import Any, ClassVar
+from typing import Any, ClassVar, Generic
 
 from agent_framework import (
     BaseChatClient,
     ChatMessage,
-    ChatOptions,
     ChatResponse,
     ChatResponseUpdate,
     Role,
@@ -16,6 +16,12 @@
     use_chat_middleware,
     use_function_invocation,
 )
+from agent_framework._clients import TOptions_co
+
+if sys.version_info >= (3, 12):
+    from typing import override  # type: ignore # pragma: no cover
+else:
+    from typing_extensions import override  # type: ignore[import] # pragma: no cover
 
 """
 Custom Chat Client Implementation Example
@@ -27,7 +33,7 @@
 
 @use_function_invocation
 @use_chat_middleware
-class EchoingChatClient(BaseChatClient):
+class EchoingChatClient(BaseChatClient[TOptions_co], Generic[TOptions_co]):
     """A custom chat client that echoes messages back with modifications.
 
     This demonstrates how to implement a custom chat client by extending BaseChatClient
@@ -46,11 +52,12 @@ def __init__(self, *, prefix: str = "Echo:", **kwargs: Any) -> None:
         super().__init__(**kwargs)
         self.prefix = prefix
 
+    @override
     async def _inner_get_response(
         self,
         *,
         messages: MutableSequence[ChatMessage],
-        chat_options: ChatOptions,
+        options: dict[str, Any],
         **kwargs: Any,
     ) -> ChatResponse:
         """Echo back the user's message with a prefix."""
@@ -77,16 +84,17 @@ async def _inner_get_response(
             response_id=f"echo-resp-{random.randint(1000, 9999)}",
         )
 
+    @override
     async def _inner_get_streaming_response(
         self,
         *,
         messages: MutableSequence[ChatMessage],
-        chat_options: ChatOptions,
+        options: dict[str, Any],
         **kwargs: Any,
     ) -> AsyncIterable[ChatResponseUpdate]:
         """Stream back the echoed message character by character."""
         # Get the complete response first
-        response = await self._inner_get_response(messages=messages, chat_options=chat_options, **kwargs)
+        response = await self._inner_get_response(messages=messages, options=options, **kwargs)
 
         if response.messages:
             response_text = response.messages[0].text or ""
@@ -123,7 +131,6 @@ async def main() -> None:
     )
 
     print(f"\nAgent Name: {echo_agent.name}")
-    print(f"Agent Display Name: {echo_agent.display_name}")
 
     # Test non-streaming with agent
     query = "This is a test message"
diff --git a/python/samples/getting_started/agents/ollama/README.md b/python/samples/getting_started/agents/ollama/README.md
index ac4b2cb3d0..2a10ae2f57 100644
--- a/python/samples/getting_started/agents/ollama/README.md
+++ b/python/samples/getting_started/agents/ollama/README.md
@@ -40,8 +40,8 @@ Set the following environment variables:
 - `OLLAMA_HOST`: The base URL for your Ollama server (optional, defaults to `http://localhost:11434`)
   - Example: `export OLLAMA_HOST="http://localhost:11434"`
 
-- `OLLAMA_CHAT_MODEL_ID`: The model name to use
-  - Example: `export OLLAMA_CHAT_MODEL_ID="qwen2.5:8b"`
+- `OLLAMA_MODEL_ID`: The model name to use
+  - Example: `export OLLAMA_MODEL_ID="qwen2.5:8b"`
   - Must be a model you have pulled with Ollama
 
 ### For OpenAI Client with Ollama (`ollama_with_openai_chat_client.py`)
diff --git a/python/samples/getting_started/agents/ollama/ollama_agent_basic.py b/python/samples/getting_started/agents/ollama/ollama_agent_basic.py
index 4d2a69b56b..a0c49acea4 100644
--- a/python/samples/getting_started/agents/ollama/ollama_agent_basic.py
+++ b/python/samples/getting_started/agents/ollama/ollama_agent_basic.py
@@ -12,7 +12,7 @@
 
 Ensure to install Ollama and have a model running locally before running the sample
 Not all Models support function calling, to test function calling try llama3.2 or qwen3:4b
-Set the model to use via the OLLAMA_CHAT_MODEL_ID environment variable or modify the code below.
+Set the model to use via the OLLAMA_MODEL_ID environment variable or modify the code below.
 https://ollama.com/
 """
diff --git a/python/samples/getting_started/agents/ollama/ollama_agent_reasoning.py b/python/samples/getting_started/agents/ollama/ollama_agent_reasoning.py
index e0ce24bb85..21deddd857 100644
--- a/python/samples/getting_started/agents/ollama/ollama_agent_reasoning.py
+++ b/python/samples/getting_started/agents/ollama/ollama_agent_reasoning.py
@@ -12,7 +12,7 @@
 
 Ensure to install Ollama and have a model running locally before running the sample
 Not all Models support reasoning, to test reasoning try qwen3:8b
-Set the model to use via the OLLAMA_CHAT_MODEL_ID environment variable or modify the code below.
+Set the model to use via the OLLAMA_MODEL_ID environment variable or modify the code below.
 https://ollama.com/
 """
 
@@ -24,7 +24,7 @@ async def reasoning_example() -> None:
     agent = OllamaChatClient().create_agent(
         name="TimeAgent",
         instructions="You are a helpful agent answer in one sentence.",
-        additional_chat_options={"think": True},  # Enable Reasoning on agent level
+        default_options={"think": True},  # Enable Reasoning on agent level
     )
     query = "Hey what is 3+4? Can you explain how you got to that answer?"
     print(f"User: {query}")
diff --git a/python/samples/getting_started/agents/ollama/ollama_chat_client.py b/python/samples/getting_started/agents/ollama/ollama_chat_client.py
index 5d7197d8f5..336a79c721 100644
--- a/python/samples/getting_started/agents/ollama/ollama_chat_client.py
+++ b/python/samples/getting_started/agents/ollama/ollama_chat_client.py
@@ -12,7 +12,7 @@
 
 Ensure to install Ollama and have a model running locally before running the sample.
 Not all Models support function calling, to test function calling try llama3.2
-Set the model to use via the OLLAMA_CHAT_MODEL_ID environment variable or modify the code below.
+Set the model to use via the OLLAMA_MODEL_ID environment variable or modify the code below.
 https://ollama.com/
 """
diff --git a/python/samples/getting_started/agents/ollama/ollama_chat_multimodal.py b/python/samples/getting_started/agents/ollama/ollama_chat_multimodal.py
index 724cecbe72..1b830c2692 100644
--- a/python/samples/getting_started/agents/ollama/ollama_chat_multimodal.py
+++ b/python/samples/getting_started/agents/ollama/ollama_chat_multimodal.py
@@ -12,7 +12,7 @@
 
 Ensure to install Ollama and have a model running locally before running the sample
 Not all Models support multimodal input, to test multimodal input try gemma3:4b
-Set the model to use via the OLLAMA_CHAT_MODEL_ID environment variable or modify the code below.
+Set the model to use via the OLLAMA_MODEL_ID environment variable or modify the code below.
 https://ollama.com/
 """
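For the Ollama samples above, the observable changes are the `OLLAMA_MODEL_ID` variable and `default_options` replacing `additional_chat_options`. A sketch combining both; the `agent_framework.ollama` import path is an assumption (the diff only shows `OllamaChatClient()` configured through the environment):

```python
# Sketch: Ollama agent with reasoning enabled via default_options.
import os

from agent_framework.ollama import OllamaChatClient  # import path assumed

os.environ.setdefault("OLLAMA_MODEL_ID", "qwen3:8b")  # model picked up from env

agent = OllamaChatClient().create_agent(
    name="TimeAgent",
    instructions="You are a helpful agent, answer in one sentence.",
    default_options={"think": True},  # enables reasoning for every run
)
```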
| [`openai_assistants_with_file_search.py`](openai_assistants_with_file_search.py) | Demonstrates how to use file search capabilities with OpenAI agents, allowing the agent to search through uploaded files to answer questions. |
-| [`openai_assistants_with_function_tools.py`](openai_assistants_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). |
-| [`openai_assistants_with_thread.py`](openai_assistants_with_thread.py) | Demonstrates thread management with OpenAI agents, including automatic thread creation for stateless conversations and explicit thread management for maintaining conversation context across multiple interactions. |
+| [`openai_assistants_basic.py`](openai_assistants_basic.py) | Basic usage of `OpenAIAssistantProvider` with streaming and non-streaming responses. |
+| [`openai_assistants_provider_methods.py`](openai_assistants_provider_methods.py) | Demonstrates all `OpenAIAssistantProvider` methods: `create_agent()`, `get_agent()`, and `as_agent()`. |
+| [`openai_assistants_with_code_interpreter.py`](openai_assistants_with_code_interpreter.py) | Using `HostedCodeInterpreterTool` with `OpenAIAssistantProvider` to execute Python code. |
+| [`openai_assistants_with_existing_assistant.py`](openai_assistants_with_existing_assistant.py) | Working with pre-existing assistants using `get_agent()` and `as_agent()` methods. |
+| [`openai_assistants_with_explicit_settings.py`](openai_assistants_with_explicit_settings.py) | Configuring `OpenAIAssistantProvider` with explicit settings including API key and model ID. |
+| [`openai_assistants_with_file_search.py`](openai_assistants_with_file_search.py) | Using `HostedFileSearchTool` with `OpenAIAssistantProvider` for file search capabilities. |
+| [`openai_assistants_with_function_tools.py`](openai_assistants_with_function_tools.py) | Function tools with `OpenAIAssistantProvider` at both agent-level and query-level. |
+| [`openai_assistants_with_response_format.py`](openai_assistants_with_response_format.py) | Structured outputs with `OpenAIAssistantProvider` using Pydantic models. |
+| [`openai_assistants_with_thread.py`](openai_assistants_with_thread.py) | Thread management with `OpenAIAssistantProvider` for conversation context persistence. |
 | [`openai_chat_client_basic.py`](openai_chat_client_basic.py) | The simplest way to create an agent using `ChatAgent` with `OpenAIChatClient`. Shows both streaming and non-streaming responses for chat-based interactions with OpenAI models. |
 | [`openai_chat_client_with_explicit_settings.py`](openai_chat_client_with_explicit_settings.py) | Shows how to initialize an agent with a specific chat client, configuring settings explicitly including API key and model ID. |
 | [`openai_chat_client_with_function_tools.py`](openai_chat_client_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). |
diff --git a/python/samples/getting_started/agents/openai/openai_assistants_basic.py b/python/samples/getting_started/agents/openai/openai_assistants_basic.py
index 63ff7dd39b..4dee6f4672 100644
--- a/python/samples/getting_started/agents/openai/openai_assistants_basic.py
+++ b/python/samples/getting_started/agents/openai/openai_assistants_basic.py
@@ -1,16 +1,18 @@
 # Copyright (c) Microsoft. All rights reserved.
 
 import asyncio
+import os
 from random import randint
 from typing import Annotated
 
-from agent_framework.openai import OpenAIAssistantsClient
+from agent_framework.openai import OpenAIAssistantProvider
+from openai import AsyncOpenAI
 from pydantic import Field
 
 """
 OpenAI Assistants Basic Example
 
-This sample demonstrates basic usage of OpenAIAssistantsClient with automatic
+This sample demonstrates basic usage of OpenAIAssistantProvider with manual
 assistant lifecycle management, showing both streaming and non-streaming responses.
 """
 
@@ -20,35 +22,50 @@ def get_weather(
 ) -> str:
     """Get the weather for a given location."""
     conditions = ["sunny", "cloudy", "rainy", "stormy"]
-    return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
+    return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}C."
 
 
 async def non_streaming_example() -> None:
     """Example of non-streaming response (get the complete result at once)."""
     print("=== Non-streaming Response Example ===")
 
-    # Since no assistant ID is provided, the assistant will be automatically created
-    # and deleted after getting a response
-    async with OpenAIAssistantsClient().create_agent(
+    client = AsyncOpenAI()
+    provider = OpenAIAssistantProvider(client)
+
+    # Create a new assistant via the provider
+    agent = await provider.create_agent(
+        name="WeatherAssistant",
+        model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
         instructions="You are a helpful weather agent.",
-        tools=get_weather,
-    ) as agent:
+        tools=[get_weather],
+    )
+
+    try:
         query = "What's the weather like in Seattle?"
         print(f"User: {query}")
         result = await agent.run(query)
         print(f"Agent: {result}\n")
+    finally:
+        # Clean up the assistant from OpenAI
+        await client.beta.assistants.delete(agent.id)
 
 
 async def streaming_example() -> None:
     """Example of streaming response (get results as they are generated)."""
     print("=== Streaming Response Example ===")
 
-    # Since no assistant ID is provided, the assistant will be automatically created
-    # and deleted after getting a response
-    async with OpenAIAssistantsClient().create_agent(
+    client = AsyncOpenAI()
+    provider = OpenAIAssistantProvider(client)
+
+    # Create a new assistant via the provider
+    agent = await provider.create_agent(
+        name="WeatherAssistant",
+        model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
         instructions="You are a helpful weather agent.",
-        tools=get_weather,
-    ) as agent:
+        tools=[get_weather],
+    )
+
+    try:
         query = "What's the weather like in Portland?"
         print(f"User: {query}")
         print("Agent: ", end="", flush=True)
@@ -56,10 +73,13 @@ async def streaming_example() -> None:
             if chunk.text:
                 print(chunk.text, end="", flush=True)
         print("\n")
+    finally:
+        # Clean up the assistant from OpenAI
+        await client.beta.assistants.delete(agent.id)
 
 
 async def main() -> None:
-    print("=== Basic OpenAI Assistants Chat Client Agent Example ===")
+    print("=== Basic OpenAI Assistants Provider Example ===")
 
     await non_streaming_example()
     await streaming_example()
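A recurring consequence of the provider switch in the OpenAI samples: assistants created through `create_agent()` are no longer deleted automatically, so every sample pairs creation with an explicit delete. The shared skeleton, using only APIs that appear in this diff (the name and query are illustrative):

```python
# Sketch: create-use-delete skeleton shared by the OpenAI Assistants samples.
import asyncio
import os

from agent_framework.openai import OpenAIAssistantProvider
from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI()
    provider = OpenAIAssistantProvider(client)

    agent = await provider.create_agent(
        name="SketchAssistant",
        model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
        instructions="You are a helpful assistant.",
    )
    try:
        print(await agent.run("Hello!"))
    finally:
        # The provider does not clean up after itself; delete explicitly.
        await client.beta.assistants.delete(agent.id)


if __name__ == "__main__":
    asyncio.run(main())
```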
diff --git a/python/samples/getting_started/agents/openai/openai_assistants_provider_methods.py b/python/samples/getting_started/agents/openai/openai_assistants_provider_methods.py
new file mode 100644
index 0000000000..ca7133cc3d
--- /dev/null
+++ b/python/samples/getting_started/agents/openai/openai_assistants_provider_methods.py
@@ -0,0 +1,149 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+import asyncio
+import os
+from random import randint
+from typing import Annotated
+
+from agent_framework.openai import OpenAIAssistantProvider
+from openai import AsyncOpenAI
+from pydantic import Field
+
+"""
+OpenAI Assistant Provider Methods Example
+
+This sample demonstrates the methods available on the OpenAIAssistantProvider class:
+- create_agent(): Create a new assistant on the service
+- get_agent(): Retrieve an existing assistant by ID
+- as_agent(): Wrap an SDK Assistant object without making HTTP calls
+"""
+
+
+def get_weather(
+    location: Annotated[str, Field(description="The location to get the weather for.")],
+) -> str:
+    """Get the weather for a given location."""
+    conditions = ["sunny", "cloudy", "rainy", "stormy"]
+    return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}C."
+
+
+async def create_agent_example() -> None:
+    """Create a new assistant using provider.create_agent()."""
+    print("\n--- create_agent() ---")
+
+    async with (
+        AsyncOpenAI() as client,
+        OpenAIAssistantProvider(client) as provider,
+    ):
+        agent = await provider.create_agent(
+            name="WeatherAssistant",
+            model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
+            instructions="You are a helpful weather assistant.",
+            tools=[get_weather],
+        )
+
+        try:
+            print(f"Created: {agent.name} (ID: {agent.id})")
+            result = await agent.run("What's the weather in Seattle?")
+            print(f"Response: {result}")
+        finally:
+            await client.beta.assistants.delete(agent.id)
+
+
+async def get_agent_example() -> None:
+    """Retrieve an existing assistant by ID using provider.get_agent()."""
+    print("\n--- get_agent() ---")
+
+    async with (
+        AsyncOpenAI() as client,
+        OpenAIAssistantProvider(client) as provider,
+    ):
+        # Create an assistant directly with SDK (simulating pre-existing assistant)
+        sdk_assistant = await client.beta.assistants.create(
+            model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
+            name="ExistingAssistant",
+            instructions="You always respond with 'Hello!'",
+        )
+
+        try:
+            # Retrieve using provider
+            agent = await provider.get_agent(sdk_assistant.id)
+            print(f"Retrieved: {agent.name} (ID: {agent.id})")
+
+            result = await agent.run("Hi there!")
+            print(f"Response: {result}")
+        finally:
+            await client.beta.assistants.delete(sdk_assistant.id)
+
+
+async def as_agent_example() -> None:
+    """Wrap an SDK Assistant object using provider.as_agent()."""
+    print("\n--- as_agent() ---")
+
+    async with (
+        AsyncOpenAI() as client,
+        OpenAIAssistantProvider(client) as provider,
+    ):
+        # Create assistant using SDK
+        sdk_assistant = await client.beta.assistants.create(
+            model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
+            name="WrappedAssistant",
+            instructions="You respond with poetry.",
+        )
+
+        try:
+            # Wrap synchronously (no HTTP call)
+            agent = provider.as_agent(sdk_assistant)
+            print(f"Wrapped: {agent.name} (ID: {agent.id})")
+
+            result = await agent.run("Tell me about the sunset.")
+            print(f"Response: {result}")
+        finally:
+            await client.beta.assistants.delete(sdk_assistant.id)
+
+
+async def multiple_agents_example() -> None:
+    """Create and manage multiple assistants with a single provider."""
+    print("\n--- Multiple Agents ---")
+
+    async with (
+        AsyncOpenAI() as client,
+        OpenAIAssistantProvider(client) as provider,
+    ):
+        weather_agent = await provider.create_agent(
+            name="WeatherSpecialist",
+            model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
+            instructions="You are a weather specialist.",
+            tools=[get_weather],
+        )
+
+        greeter_agent = await provider.create_agent(
+            name="GreeterAgent",
+            model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
+            instructions="You are a friendly greeter.",
+        )
+
+        try:
+            print(f"Created: {weather_agent.name}, {greeter_agent.name}")
+
+            greeting = await greeter_agent.run("Hello!")
+            print(f"Greeter: {greeting}")
+
+            weather = await weather_agent.run("What's the weather in Tokyo?")
+            print(f"Weather: {weather}")
+        finally:
+            await client.beta.assistants.delete(weather_agent.id)
+            await client.beta.assistants.delete(greeter_agent.id)
+
+
+async def main() -> None:
+    print("OpenAI Assistant Provider Methods")
+
+    await create_agent_example()
+    await get_agent_example()
+    await as_agent_example()
+    await multiple_agents_example()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
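The practical difference between the last two provider methods above: `get_agent()` performs an HTTP lookup by ID, while `as_agent()` wraps an SDK object you already hold, with no network call. A compact contrast, assuming the open `client`/`provider` pair from the sample:

```python
# Sketch: get_agent() vs. as_agent(), per the provider-methods sample above.
sdk_assistant = await client.beta.assistants.create(
    model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
    name="DemoAssistant",
)
try:
    fetched = await provider.get_agent(sdk_assistant.id)  # HTTP round trip
    wrapped = provider.as_agent(sdk_assistant)            # no HTTP call
    print(fetched.id, wrapped.id)  # both refer to the same assistant
finally:
    await client.beta.assistants.delete(sdk_assistant.id)
```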
print(f"User: {query}") print("Agent: ", end="", flush=True) @@ -60,6 +68,8 @@ async def main() -> None: generated_code += code_interpreter_chunk print(f"\nGenerated code:\n{generated_code}") + finally: + await client.beta.assistants.delete(agent.id) if __name__ == "__main__": diff --git a/python/samples/getting_started/agents/openai/openai_assistants_with_existing_assistant.py b/python/samples/getting_started/agents/openai/openai_assistants_with_existing_assistant.py index dd63cdc8b8..a0e9497d3e 100644 --- a/python/samples/getting_started/agents/openai/openai_assistants_with_existing_assistant.py +++ b/python/samples/getting_started/agents/openai/openai_assistants_with_existing_assistant.py @@ -5,8 +5,7 @@ from random import randint from typing import Annotated -from agent_framework import ChatAgent -from agent_framework.openai import OpenAIAssistantsClient +from agent_framework.openai import OpenAIAssistantProvider from openai import AsyncOpenAI from pydantic import Field @@ -14,7 +13,7 @@ OpenAI Assistants with Existing Assistant Example This sample demonstrates working with pre-existing OpenAI Assistants -using existing assistant IDs rather than creating new ones. +using the provider's get_agent() and as_agent() methods. """ @@ -23,31 +22,86 @@ def get_weather( ) -> str: """Get the weather for a given location.""" conditions = ["sunny", "cloudy", "rainy", "stormy"] - return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}C." -async def main() -> None: - print("=== OpenAI Assistants Chat Client with Existing Assistant ===") +async def example_get_agent_by_id() -> None: + """Example: Using get_agent() to retrieve an existing assistant by ID.""" + print("=== Get Existing Assistant by ID ===") - # Create the client client = AsyncOpenAI() + provider = OpenAIAssistantProvider(client) - # Create an assistant that will persist + # Create an assistant via SDK (simulating an existing assistant) created_assistant = await client.beta.assistants.create( - model=os.environ["OPENAI_CHAT_MODEL_ID"], name="WeatherAssistant" + model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"), + name="WeatherAssistant", + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a given location.", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string", "description": "The location"}}, + "required": ["location"], + }, + }, + } + ], ) + print(f"Created assistant: {created_assistant.id}") try: - async with ChatAgent( - chat_client=OpenAIAssistantsClient(async_client=client, assistant_id=created_assistant.id), + # Use get_agent() to retrieve the existing assistant + agent = await provider.get_agent( + assistant_id=created_assistant.id, + tools=[get_weather], # Required: implementation for function tools instructions="You are a helpful weather agent.", - tools=get_weather, - ) as agent: - result = await agent.run("What's the weather like in Tokyo?") - print(f"Result: {result}\n") + ) + + result = await agent.run("What's the weather like in Tokyo?") + print(f"Agent: {result}\n") + finally: + await client.beta.assistants.delete(created_assistant.id) + print("Assistant deleted.\n") + + +async def example_as_agent_wrap_sdk_object() -> None: + """Example: Using as_agent() to wrap an existing SDK Assistant object.""" + print("=== Wrap Existing SDK Assistant Object ===") + + client = AsyncOpenAI() 
+ provider = OpenAIAssistantProvider(client) + + # Create and fetch an assistant via SDK + created_assistant = await client.beta.assistants.create( + model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"), + name="SimpleAssistant", + instructions="You are a friendly assistant.", + ) + print(f"Created assistant: {created_assistant.id}") + + try: + # Use as_agent() to wrap the SDK object + agent = provider.as_agent( + created_assistant, + instructions="You are an extremely helpful assistant. Be enthusiastic!", + ) + + result = await agent.run("Hello! What can you help me with?") + print(f"Agent: {result}\n") finally: - # Clean up the assistant manually await client.beta.assistants.delete(created_assistant.id) + print("Assistant deleted.\n") + + +async def main() -> None: + print("=== OpenAI Assistants Provider with Existing Assistant Examples ===\n") + + await example_get_agent_by_id() + await example_as_agent_wrap_sdk_object() if __name__ == "__main__": diff --git a/python/samples/getting_started/agents/openai/openai_assistants_with_explicit_settings.py b/python/samples/getting_started/agents/openai/openai_assistants_with_explicit_settings.py index 8fc9d8802d..af99a0a8f9 100644 --- a/python/samples/getting_started/agents/openai/openai_assistants_with_explicit_settings.py +++ b/python/samples/getting_started/agents/openai/openai_assistants_with_explicit_settings.py @@ -5,7 +5,8 @@ from random import randint from typing import Annotated -from agent_framework.openai import OpenAIAssistantsClient +from agent_framework.openai import OpenAIAssistantProvider +from openai import AsyncOpenAI from pydantic import Field """ @@ -21,21 +22,28 @@ def get_weather( ) -> str: """Get the weather for a given location.""" conditions = ["sunny", "cloudy", "rainy", "stormy"] - return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}C." async def main() -> None: - print("=== OpenAI Assistants Client with Explicit Settings ===") + print("=== OpenAI Assistants Provider with Explicit Settings ===") - async with OpenAIAssistantsClient( - model_id=os.environ["OPENAI_CHAT_MODEL_ID"], - api_key=os.environ["OPENAI_API_KEY"], - ).create_agent( + # Create client with explicit API key + client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"]) + provider = OpenAIAssistantProvider(client) + + agent = await provider.create_agent( + name="WeatherAssistant", + model=os.environ["OPENAI_CHAT_MODEL_ID"], instructions="You are a helpful weather agent.", - tools=get_weather, - ) as agent: + tools=[get_weather], + ) + + try: result = await agent.run("What's the weather like in New York?") print(f"Result: {result}\n") + finally: + await client.beta.assistants.delete(agent.id) if __name__ == "__main__": diff --git a/python/samples/getting_started/agents/openai/openai_assistants_with_file_search.py b/python/samples/getting_started/agents/openai/openai_assistants_with_file_search.py index 4d50ee5f02..035b6e88f2 100644 --- a/python/samples/getting_started/agents/openai/openai_assistants_with_file_search.py +++ b/python/samples/getting_started/agents/openai/openai_assistants_with_file_search.py @@ -1,9 +1,11 @@ # Copyright (c) Microsoft. All rights reserved. 
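The file-search sample that follows wraps its vector-store bookkeeping in helpers; the bare lifecycle against a raw `AsyncOpenAI` client looks like this, using only calls that appear in the diff (a fragment, assuming an already-constructed client):

```python
# Sketch: vector-store lifecycle behind the file-search helpers below.
file = await client.files.create(
    file=("todays_weather.txt", b"The weather today is sunny with a high of 75F."),
    purpose="user_data",
)
vector_store = await client.vector_stores.create(
    name="knowledge_base",
    expires_after={"anchor": "last_active_at", "days": 1},
)
result = await client.vector_stores.files.create_and_poll(vector_store_id=vector_store.id, file_id=file.id)
if result.last_error is not None:
    raise Exception(f"Vector store file processing failed with status: {result.last_error.message}")

# ... hand HostedVectorStoreContent(vector_store_id=vector_store.id) to the agent ...

# Clean up in reverse order
await client.vector_stores.delete(vector_store_id=vector_store.id)
await client.files.delete(file_id=file.id)
```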
import asyncio +import os -from agent_framework import ChatAgent, HostedFileSearchTool, HostedVectorStoreContent -from agent_framework.openai import OpenAIAssistantsClient +from agent_framework import HostedFileSearchTool, HostedVectorStoreContent +from agent_framework.openai import OpenAIAssistantProvider +from openai import AsyncOpenAI """ OpenAI Assistants with File Search Example @@ -12,41 +14,43 @@ for document-based question answering and information retrieval. """ -# Helper functions - -async def create_vector_store(client: OpenAIAssistantsClient) -> tuple[str, HostedVectorStoreContent]: +async def create_vector_store(client: AsyncOpenAI) -> tuple[str, HostedVectorStoreContent]: """Create a vector store with sample documents.""" - file = await client.client.files.create( + file = await client.files.create( file=("todays_weather.txt", b"The weather today is sunny with a high of 75F."), purpose="user_data" ) - vector_store = await client.client.vector_stores.create( + vector_store = await client.vector_stores.create( name="knowledge_base", expires_after={"anchor": "last_active_at", "days": 1}, ) - result = await client.client.vector_stores.files.create_and_poll(vector_store_id=vector_store.id, file_id=file.id) + result = await client.vector_stores.files.create_and_poll(vector_store_id=vector_store.id, file_id=file.id) if result.last_error is not None: raise Exception(f"Vector store file processing failed with status: {result.last_error.message}") return file.id, HostedVectorStoreContent(vector_store_id=vector_store.id) -async def delete_vector_store(client: OpenAIAssistantsClient, file_id: str, vector_store_id: str) -> None: +async def delete_vector_store(client: AsyncOpenAI, file_id: str, vector_store_id: str) -> None: """Delete the vector store after using it.""" - - await client.client.vector_stores.delete(vector_store_id=vector_store_id) - await client.client.files.delete(file_id=file_id) + await client.vector_stores.delete(vector_store_id=vector_store_id) + await client.files.delete(file_id=file_id) async def main() -> None: - print("=== OpenAI Assistants Client Agent with File Search Example ===\n") + print("=== OpenAI Assistants Provider with File Search Example ===\n") + + client = AsyncOpenAI() + provider = OpenAIAssistantProvider(client) - client = OpenAIAssistantsClient() - async with ChatAgent( - chat_client=client, + agent = await provider.create_agent( + name="SearchAssistant", + model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"), instructions="You are a helpful assistant that searches files in a knowledge base.", - tools=HostedFileSearchTool(), - ) as agent: + tools=[HostedFileSearchTool()], + ) + + try: query = "What is the weather today? Do a file search to find the answer." file_id, vector_store = await create_vector_store(client) @@ -57,7 +61,10 @@ async def main() -> None: ): if chunk.text: print(chunk.text, end="", flush=True) + await delete_vector_store(client, file_id, vector_store.vector_store_id) + finally: + await client.beta.assistants.delete(agent.id) if __name__ == "__main__": diff --git a/python/samples/getting_started/agents/openai/openai_assistants_with_function_tools.py b/python/samples/getting_started/agents/openai/openai_assistants_with_function_tools.py index 6d3c3fccef..2e3e3f0b07 100644 --- a/python/samples/getting_started/agents/openai/openai_assistants_with_function_tools.py +++ b/python/samples/getting_started/agents/openai/openai_assistants_with_function_tools.py @@ -1,12 +1,13 @@ # Copyright (c) Microsoft. All rights reserved. 
import asyncio +import os from datetime import datetime, timezone from random import randint from typing import Annotated -from agent_framework import ChatAgent -from agent_framework.openai import OpenAIAssistantsClient +from agent_framework.openai import OpenAIAssistantProvider +from openai import AsyncOpenAI from pydantic import Field """ @@ -22,7 +23,7 @@ def get_weather( ) -> str: """Get the weather for a given location.""" conditions = ["sunny", "cloudy", "rainy", "stormy"] - return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}C." def get_time() -> str: @@ -35,13 +36,19 @@ async def tools_on_agent_level() -> None: """Example showing tools defined when creating the agent.""" print("=== Tools Defined on Agent Level ===") + client = AsyncOpenAI() + provider = OpenAIAssistantProvider(client) + # Tools are provided when creating the agent # The agent can use these tools for any query during its lifetime - async with ChatAgent( - chat_client=OpenAIAssistantsClient(), + agent = await provider.create_agent( + name="InfoAssistant", + model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"), instructions="You are a helpful assistant that can provide weather and time information.", tools=[get_weather, get_time], # Tools defined at agent creation - ) as agent: + ) + + try: # First query - agent can use weather tool query1 = "What's the weather like in New York?" print(f"User: {query1}") @@ -59,47 +66,63 @@ async def tools_on_agent_level() -> None: print(f"User: {query3}") result3 = await agent.run(query3) print(f"Agent: {result3}\n") + finally: + await client.beta.assistants.delete(agent.id) async def tools_on_run_level() -> None: """Example showing tools passed to the run method.""" print("=== Tools Passed to Run Method ===") - # Agent created without tools - async with ChatAgent( - chat_client=OpenAIAssistantsClient(), + client = AsyncOpenAI() + provider = OpenAIAssistantProvider(client) + + # Agent created with base tools, additional tools can be passed at run time + agent = await provider.create_agent( + name="FlexibleAssistant", + model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"), instructions="You are a helpful assistant.", - # No tools defined here - ) as agent: - # First query with weather tool + tools=[get_weather], # Base tool + ) + + try: + # First query using base weather tool query1 = "What's the weather like in Seattle?" print(f"User: {query1}") - result1 = await agent.run(query1, tools=[get_weather]) # Tool passed to run method + result1 = await agent.run(query1) print(f"Agent: {result1}\n") - # Second query with time tool + # Second query with additional time tool query2 = "What's the current UTC time?" print(f"User: {query2}") - result2 = await agent.run(query2, tools=[get_time]) # Different tool for this query + result2 = await agent.run(query2, tools=[get_time]) # Additional tool for this query print(f"Agent: {result2}\n") - # Third query with multiple tools + # Third query with both tools query3 = "What's the weather in Chicago and what's the current UTC time?" 
print(f"User: {query3}") - result3 = await agent.run(query3, tools=[get_weather, get_time]) # Multiple tools + result3 = await agent.run(query3, tools=[get_time]) # Time tool adds to weather print(f"Agent: {result3}\n") + finally: + await client.beta.assistants.delete(agent.id) async def mixed_tools_example() -> None: """Example showing both agent-level tools and run-method tools.""" print("=== Mixed Tools Example (Agent + Run Method) ===") + client = AsyncOpenAI() + provider = OpenAIAssistantProvider(client) + # Agent created with some base tools - async with ChatAgent( - chat_client=OpenAIAssistantsClient(), + agent = await provider.create_agent( + name="ComprehensiveAssistant", + model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"), instructions="You are a comprehensive assistant that can help with various information requests.", tools=[get_weather], # Base tool available for all queries - ) as agent: + ) + + try: # Query using both agent tool and additional run-method tools query = "What's the weather in Denver and what's the current UTC time?" print(f"User: {query}") @@ -110,10 +133,12 @@ async def mixed_tools_example() -> None: tools=[get_time], # Additional tools for this specific query ) print(f"Agent: {result}\n") + finally: + await client.beta.assistants.delete(agent.id) async def main() -> None: - print("=== OpenAI Assistants Chat Client Agent with Function Tools Examples ===\n") + print("=== OpenAI Assistants Provider with Function Tools Examples ===\n") await tools_on_agent_level() await tools_on_run_level() diff --git a/python/samples/getting_started/agents/openai/openai_assistants_with_response_format.py b/python/samples/getting_started/agents/openai/openai_assistants_with_response_format.py new file mode 100644 index 0000000000..796bdd803c --- /dev/null +++ b/python/samples/getting_started/agents/openai/openai_assistants_with_response_format.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import os + +from agent_framework.openai import OpenAIAssistantProvider +from openai import AsyncOpenAI +from pydantic import BaseModel, ConfigDict + +""" +OpenAI Assistant Provider Response Format Example + +This sample demonstrates using OpenAIAssistantProvider with response_format +for structured outputs in two ways: +1. Setting default response_format at agent creation time (default_options) +2. Overriding response_format at runtime (options parameter in agent.run) +""" + + +class WeatherInfo(BaseModel): + """Structured weather information.""" + + location: str + temperature: int + conditions: str + recommendation: str + model_config = ConfigDict(extra="forbid") + + +class CityInfo(BaseModel): + """Structured city information.""" + + city_name: str + population: int + country: str + model_config = ConfigDict(extra="forbid") + + +async def main() -> None: + """Example of using response_format at creation time and runtime.""" + + async with ( + AsyncOpenAI() as client, + OpenAIAssistantProvider(client) as provider, + ): + # Create agent with default response_format (WeatherInfo) + agent = await provider.create_agent( + name="StructuredReporter", + model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"), + instructions="Return structured JSON based on the requested format.", + default_options={"response_format": WeatherInfo}, + ) + + try: + # Request 1: Uses default response_format from agent creation + print("--- Request 1: Using default response_format (WeatherInfo) ---") + query1 = "What's the weather like in Paris today?" 
+ print(f"User: {query1}") + + result1 = await agent.run(query1) + + if isinstance(result1.value, WeatherInfo): + weather = result1.value + print("Agent:") + print(f" Location: {weather.location}") + print(f" Temperature: {weather.temperature}") + print(f" Conditions: {weather.conditions}") + print(f" Recommendation: {weather.recommendation}") + + # Request 2: Override response_format at runtime with CityInfo + print("\n--- Request 2: Runtime override with CityInfo ---") + query2 = "Tell me about Tokyo." + print(f"User: {query2}") + + result2 = await agent.run(query2, options={"response_format": CityInfo}) + + if isinstance(result2.value, CityInfo): + city = result2.value + print("Agent:") + print(f" City: {city.city_name}") + print(f" Population: {city.population}") + print(f" Country: {city.country}") + finally: + await client.beta.assistants.delete(agent.id) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/agents/openai/openai_assistants_with_thread.py b/python/samples/getting_started/agents/openai/openai_assistants_with_thread.py index 9b6e2d3f5c..7adb4c61cd 100644 --- a/python/samples/getting_started/agents/openai/openai_assistants_with_thread.py +++ b/python/samples/getting_started/agents/openai/openai_assistants_with_thread.py @@ -1,11 +1,13 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio +import os from random import randint from typing import Annotated -from agent_framework import AgentThread, ChatAgent -from agent_framework.openai import OpenAIAssistantsClient +from agent_framework import AgentThread +from agent_framework.openai import OpenAIAssistantProvider +from openai import AsyncOpenAI from pydantic import Field """ @@ -21,18 +23,24 @@ def get_weather( ) -> str: """Get the weather for a given location.""" conditions = ["sunny", "cloudy", "rainy", "stormy"] - return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}C." async def example_with_automatic_thread_creation() -> None: """Example showing automatic thread creation (service-managed thread).""" print("=== Automatic Thread Creation Example ===") - async with ChatAgent( - chat_client=OpenAIAssistantsClient(), + client = AsyncOpenAI() + provider = OpenAIAssistantProvider(client) + + agent = await provider.create_agent( + name="WeatherAssistant", + model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"), instructions="You are a helpful weather agent.", - tools=get_weather, - ) as agent: + tools=[get_weather], + ) + + try: # First conversation - no thread provided, will be created automatically query1 = "What's the weather like in Seattle?" 
print(f"User: {query1}") @@ -45,6 +53,8 @@ async def example_with_automatic_thread_creation() -> None: result2 = await agent.run(query2) print(f"Agent: {result2.text}") print("Note: Each call creates a separate thread, so the agent doesn't remember previous context.\n") + finally: + await client.beta.assistants.delete(agent.id) async def example_with_thread_persistence() -> None: @@ -52,11 +62,17 @@ async def example_with_thread_persistence() -> None: print("=== Thread Persistence Example ===") print("Using the same thread across multiple conversations to maintain context.\n") - async with ChatAgent( - chat_client=OpenAIAssistantsClient(), + client = AsyncOpenAI() + provider = OpenAIAssistantProvider(client) + + agent = await provider.create_agent( + name="WeatherAssistant", + model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"), instructions="You are a helpful weather agent.", - tools=get_weather, - ) as agent: + tools=[get_weather], + ) + + try: # Create a new thread that will be reused thread = agent.get_new_thread() @@ -78,6 +94,8 @@ async def example_with_thread_persistence() -> None: result3 = await agent.run(query3, thread=thread) print(f"Agent: {result3.text}") print("Note: The agent remembers context from previous messages in the same thread.\n") + finally: + await client.beta.assistants.delete(agent.id) async def example_with_existing_thread_id() -> None: @@ -85,14 +103,22 @@ async def example_with_existing_thread_id() -> None: print("=== Existing Thread ID Example ===") print("Using a specific thread ID to continue an existing conversation.\n") + client = AsyncOpenAI() + provider = OpenAIAssistantProvider(client) + # First, create a conversation and capture the thread ID existing_thread_id = None + assistant_id = None - async with ChatAgent( - chat_client=OpenAIAssistantsClient(), + agent = await provider.create_agent( + name="WeatherAssistant", + model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"), instructions="You are a helpful weather agent.", - tools=get_weather, - ) as agent: + tools=[get_weather], + ) + assistant_id = agent.id + + try: # Start a conversation and get the thread ID thread = agent.get_new_thread() query1 = "What's the weather in Paris?" @@ -104,27 +130,30 @@ async def example_with_existing_thread_id() -> None: existing_thread_id = thread.service_thread_id print(f"Thread ID: {existing_thread_id}") - if existing_thread_id: - print("\n--- Continuing with the same thread ID in a new agent instance ---") + if existing_thread_id: + print("\n--- Continuing with the same thread ID using get_agent ---") + + # Get the existing assistant by ID + agent2 = await provider.get_agent( + assistant_id=assistant_id, + tools=[get_weather], # Must provide function implementations + ) - # Create a new agent instance but use the existing thread ID - async with ChatAgent( - chat_client=OpenAIAssistantsClient(thread_id=existing_thread_id), - instructions="You are a helpful weather agent.", - tools=get_weather, - ) as agent: # Create a thread with the existing ID thread = AgentThread(service_thread_id=existing_thread_id) query2 = "What was the last city I asked about?" 
print(f"User: {query2}") - result2 = await agent.run(query2, thread=thread) + result2 = await agent2.run(query2, thread=thread) print(f"Agent: {result2.text}") print("Note: The agent continues the conversation from the previous thread.\n") + finally: + if assistant_id: + await client.beta.assistants.delete(assistant_id) async def main() -> None: - print("=== OpenAI Assistants Chat Client Agent Thread Management Examples ===\n") + print("=== OpenAI Assistants Provider Thread Management Examples ===\n") await example_with_automatic_thread_creation() await example_with_thread_persistence() diff --git a/python/samples/getting_started/agents/openai/openai_chat_client_with_runtime_json_schema.py b/python/samples/getting_started/agents/openai/openai_chat_client_with_runtime_json_schema.py index f461c2864b..3489dc0489 100644 --- a/python/samples/getting_started/agents/openai/openai_chat_client_with_runtime_json_schema.py +++ b/python/samples/getting_started/agents/openai/openai_chat_client_with_runtime_json_schema.py @@ -3,7 +3,7 @@ import asyncio import json -from agent_framework.openai import OpenAIChatClient +from agent_framework.openai import OpenAIChatClient, OpenAIChatOptions """ OpenAI Chat Client Runtime JSON Schema Example @@ -32,7 +32,7 @@ async def non_streaming_example() -> None: print("=== Non-streaming runtime JSON schema example ===") - agent = OpenAIChatClient().create_agent( + agent = OpenAIChatClient[OpenAIChatOptions]().create_agent( name="RuntimeSchemaAgent", instructions="Return only JSON that matches the provided schema. Do not add commentary.", ) @@ -42,7 +42,7 @@ async def non_streaming_example() -> None: response = await agent.run( query, - additional_chat_options={ + options={ "response_format": { "type": "json_schema", "json_schema": { @@ -76,7 +76,7 @@ async def streaming_example() -> None: chunks: list[str] = [] async for chunk in agent.run_stream( query, - additional_chat_options={ + options={ "response_format": { "type": "json_schema", "json_schema": { diff --git a/python/samples/getting_started/agents/openai/openai_responses_client_reasoning.py b/python/samples/getting_started/agents/openai/openai_responses_client_reasoning.py index b07a7fb314..1b06e9db04 100644 --- a/python/samples/getting_started/agents/openai/openai_responses_client_reasoning.py +++ b/python/samples/getting_started/agents/openai/openai_responses_client_reasoning.py @@ -2,7 +2,7 @@ import asyncio -from agent_framework.openai import OpenAIResponsesClient +from agent_framework.openai import OpenAIResponsesClient, OpenAIResponsesOptions """ OpenAI Responses Client Reasoning Example @@ -10,19 +10,20 @@ This sample demonstrates advanced reasoning capabilities using OpenAI's gpt-5 models, showing step-by-step reasoning process visualization and complex problem-solving. -This uses the additional_chat_options parameter to enable reasoning with high effort and detailed summaries. -You can also set these options at the run level, since they are api and/or provider specific, you will need to lookup -the correct values for your provider, since these are passed through as-is. +This uses the default_options parameter to enable reasoning with high effort and detailed summaries. +You can also set these options at the run level using the options parameter. +Since these are api and/or provider specific, you will need to lookup +the correct values for your provider, as they are passed through as-is. 
In this case they are here: https://platform.openai.com/docs/api-reference/responses/create#responses-create-reasoning """ -agent = OpenAIResponsesClient(model_id="gpt-5").create_agent( +agent = OpenAIResponsesClient[OpenAIResponsesOptions](model_id="gpt-5").create_agent( name="MathHelper", instructions="You are a personal math tutor. When asked a math question, " "reason over how best to approach the problem and share your thought process.", - additional_chat_options={"reasoning": {"effort": "high", "summary": "detailed"}}, + default_options={"reasoning": {"effort": "high", "summary": "detailed"}}, ) diff --git a/python/samples/getting_started/agents/openai/openai_responses_client_with_runtime_json_schema.py b/python/samples/getting_started/agents/openai/openai_responses_client_with_runtime_json_schema.py index c32a6a5880..14aff76760 100644 --- a/python/samples/getting_started/agents/openai/openai_responses_client_with_runtime_json_schema.py +++ b/python/samples/getting_started/agents/openai/openai_responses_client_with_runtime_json_schema.py @@ -42,7 +42,7 @@ async def non_streaming_example() -> None: response = await agent.run( query, - additional_chat_options={ + options={ "response_format": { "type": "json_schema", "json_schema": { @@ -76,7 +76,7 @@ async def streaming_example() -> None: chunks: list[str] = [] async for chunk in agent.run_stream( query, - additional_chat_options={ + options={ "response_format": { "type": "json_schema", "json_schema": { diff --git a/python/samples/getting_started/agents/openai/openai_responses_client_with_structured_output.py b/python/samples/getting_started/agents/openai/openai_responses_client_with_structured_output.py index 88e36236ca..ba208da94b 100644 --- a/python/samples/getting_started/agents/openai/openai_responses_client_with_structured_output.py +++ b/python/samples/getting_started/agents/openai/openai_responses_client_with_structured_output.py @@ -2,7 +2,7 @@ import asyncio -from agent_framework import AgentRunResponse +from agent_framework import AgentResponse from agent_framework.openai import OpenAIResponsesClient from pydantic import BaseModel @@ -35,7 +35,7 @@ async def non_streaming_example() -> None: print(f"User: {query}") # Get structured response from the agent using response_format parameter - result = await agent.run(query, response_format=OutputStruct) + result = await agent.run(query, options={"response_format": OutputStruct}) # Access the structured output directly from the response value if result.value: @@ -60,17 +60,17 @@ async def streaming_example() -> None: query = "Tell me about Tokyo, Japan" print(f"User: {query}") - # Get structured response from streaming agent using AgentRunResponse.from_agent_response_generator - # This method collects all streaming updates and combines them into a single AgentRunResponse - result = await AgentRunResponse.from_agent_response_generator( - agent.run_stream(query, response_format=OutputStruct), + # Get structured response from streaming agent using AgentResponse.from_agent_response_generator + # This method collects all streaming updates and combines them into a single AgentResponse + result = await AgentResponse.from_agent_response_generator( + agent.run_stream(query, options={"response_format": OutputStruct}), output_format_type=OutputStruct, ) # Access the structured output directly from the response value if result.value: structured_data: OutputStruct = result.value # type: ignore - print("Structured Output (from streaming with AgentRunResponse.from_agent_response_generator):") + 
print("Structured Output (from streaming with AgentResponse.from_agent_response_generator):") print(f"City: {structured_data.city}") print(f"Description: {structured_data.description}") else: diff --git a/python/samples/getting_started/azure_functions/03_reliable_streaming/function_app.py b/python/samples/getting_started/azure_functions/03_reliable_streaming/function_app.py index 31db10a9df..6148a42294 100644 --- a/python/samples/getting_started/azure_functions/03_reliable_streaming/function_app.py +++ b/python/samples/getting_started/azure_functions/03_reliable_streaming/function_app.py @@ -19,9 +19,9 @@ import os from datetime import timedelta -import redis.asyncio as aioredis -from agent_framework import AgentRunResponseUpdate import azure.functions as func +import redis.asyncio as aioredis +from agent_framework import AgentResponseUpdate from agent_framework.azure import ( AgentCallbackContext, AgentFunctionApp, @@ -39,6 +39,7 @@ REDIS_CONNECTION_STRING = os.environ.get("REDIS_CONNECTION_STRING", "redis://localhost:6379") REDIS_STREAM_TTL_MINUTES = int(os.environ.get("REDIS_STREAM_TTL_MINUTES", "10")) + async def get_stream_handler() -> RedisStreamResponseHandler: """Create a new Redis stream handler for each request. @@ -70,7 +71,7 @@ def __init__(self) -> None: async def on_streaming_response_update( self, - update: AgentRunResponseUpdate, + update: AgentResponseUpdate, context: AgentCallbackContext, ) -> None: """Write streaming update to Redis Stream. diff --git a/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py b/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py index 41fb3f08b2..d2d1cc2047 100644 --- a/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py +++ b/python/samples/getting_started/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py @@ -13,8 +13,8 @@ from collections.abc import Generator from typing import Any, cast -from agent_framework import AgentRunResponse import azure.functions as func +from agent_framework import AgentResponse from agent_framework.azure import AgentFunctionApp, AzureOpenAIChatClient from azure.durable_functions import DurableOrchestrationClient, DurableOrchestrationContext from azure.identity import AzureCliCredential @@ -72,8 +72,8 @@ def multi_agent_concurrent_orchestration(context: DurableOrchestrationContext) - # Execute both tasks concurrently using task_all task_results = yield context.task_all([physicist_task, chemist_task]) - physicist_result = cast(AgentRunResponse, task_results[0]) - chemist_result = cast(AgentRunResponse, task_results[1]) + physicist_result = cast(AgentResponse, task_results[0]) + chemist_result = cast(AgentResponse, task_results[1]) return { "physicist": physicist_result.text, diff --git a/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py b/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py index 8ef5ef7211..32b5775ef8 100644 --- a/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py +++ b/python/samples/getting_started/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py @@ -99,7 +99,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext) -> Genera spam_result_raw = yield spam_agent.run( messages=spam_prompt, thread=spam_thread, - 
response_format=SpamDetectionResult, + options={"response_format": SpamDetectionResult}, ) spam_result = cast(SpamDetectionResult, spam_result_raw.value) @@ -120,7 +120,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext) -> Genera email_result_raw = yield email_agent.run( messages=email_prompt, thread=email_thread, - response_format=EmailResponse, + options={"response_format": EmailResponse}, ) email_result = cast(EmailResponse, email_result_raw.value) diff --git a/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py b/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py index 08a14ffe11..2e394faea2 100644 --- a/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py +++ b/python/samples/getting_started/azure_functions/07_single_agent_orchestration_hitl/function_app.py @@ -98,7 +98,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext) initial_raw = yield writer.run( messages=f"Write a short article about '{payload.topic}'.", thread=writer_thread, - response_format=GeneratedContent, + options={"response_format": GeneratedContent}, ) content = initial_raw.value @@ -134,9 +134,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext) ) return {"content": content.content} - context.set_custom_status( - "Content rejected by human reviewer. Incorporating feedback and regenerating..." - ) + context.set_custom_status("Content rejected by human reviewer. Incorporating feedback and regenerating...") rewrite_prompt = ( "The content was rejected by a human reviewer. Please rewrite the article incorporating their feedback.\n\n" f"Human Feedback: {approval_payload.feedback or 'No feedback provided.'}" @@ -144,7 +142,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext) rewritten_raw = yield writer.run( messages=rewrite_prompt, thread=writer_thread, - response_format=GeneratedContent, + options={"response_format": GeneratedContent}, ) rewritten_value = rewritten_raw.value @@ -156,9 +154,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext) context.set_custom_status( f"Human approval timed out after {payload.approval_timeout_hours} hour(s). Treating as rejection." ) - raise TimeoutError( - f"Human approval timed out after {payload.approval_timeout_hours} hour(s)." - ) + raise TimeoutError(f"Human approval timed out after {payload.approval_timeout_hours} hour(s).") raise RuntimeError(f"Content could not be approved after {payload.max_review_attempts} iteration(s).") diff --git a/python/samples/getting_started/chat_client/README.md b/python/samples/getting_started/chat_client/README.md index 293b454821..4b36865769 100644 --- a/python/samples/getting_started/chat_client/README.md +++ b/python/samples/getting_started/chat_client/README.md @@ -35,6 +35,6 @@ Depending on which client you're using, set the appropriate environment variable **For Ollama client:** - `OLLAMA_HOST`: Your Ollama server URL (defaults to `http://localhost:11434` if not set) -- `OLLAMA_CHAT_MODEL_ID`: The Ollama model to use for chat (e.g., `llama3.2`, `llama2`, `codellama`) +- `OLLAMA_MODEL_ID`: The Ollama model to use for chat (e.g., `llama3.2`, `llama2`, `codellama`) > **Note**: For Ollama, ensure you have Ollama installed and running locally with at least one model downloaded. Visit [https://ollama.com/](https://ollama.com/) for installation instructions. 
\ No newline at end of file
diff --git a/python/samples/getting_started/chat_client/azure_responses_client.py b/python/samples/getting_started/chat_client/azure_responses_client.py
index 158a2eb78c..ec15ee7723 100644
--- a/python/samples/getting_started/chat_client/azure_responses_client.py
+++ b/python/samples/getting_started/chat_client/azure_responses_client.py
@@ -41,13 +41,13 @@ async def main() -> None:
         print(f"User: {message}")
         if stream:
             response = await ChatResponse.from_chat_response_generator(
-                client.get_streaming_response(message, tools=get_weather, response_format=OutputStruct),
+                client.get_streaming_response(message, tools=get_weather, options={"response_format": OutputStruct}),
                 output_format_type=OutputStruct,
             )
             print(f"Assistant: {response.value}")
         else:
-            response = await client.get_response(message, tools=get_weather, response_format=OutputStruct)
+            response = await client.get_response(message, tools=get_weather, options={"response_format": OutputStruct})
             print(f"Assistant: {response.value}")
diff --git a/python/samples/getting_started/chat_client/typed_options.py b/python/samples/getting_started/chat_client/typed_options.py
new file mode 100644
index 0000000000..533b214ebe
--- /dev/null
+++ b/python/samples/getting_started/chat_client/typed_options.py
@@ -0,0 +1,182 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+import asyncio
+from typing import Literal
+
+from agent_framework import ChatAgent
+from agent_framework.anthropic import AnthropicClient
+from agent_framework.openai import OpenAIChatClient, OpenAIChatOptions
+
+"""TypedDict-based Chat Options.
+
+In Agent Framework, we have made ChatClient and ChatAgent generic over a ChatOptions TypedDict, which means that
+you can override which options are available for a given client or agent by providing your own TypedDict subclass.
+We also include the most common options for all ChatClient providers out of the box.
+
+This sample demonstrates the TypedDict-based approach for chat client and agent options,
+which provides:
+1. IDE autocomplete for available options
+2. Type checking to catch errors at development time
+3. An example of defining provider-specific options by extending the base options,
+   including overriding unsupported options.
+
+The sample shows usage with both OpenAI and Anthropic clients, demonstrating how provider-specific
+options work for ChatClient and ChatAgent, but the same approach works for other providers too.
+"""
+
+
+async def demo_anthropic_chat_client() -> None:
+    """Demonstrate Anthropic ChatClient with typed options and validation."""
+    print("\n=== Anthropic ChatClient with TypedDict Options ===\n")
+
+    # Create Anthropic client
+    client = AnthropicClient(model_id="claude-sonnet-4-5-20250929")
+
+    # Standard options work great:
+    response = await client.get_response(
+        "What is the capital of France?",
+        options={
+            "temperature": 0.5,
+            "max_tokens": 1000,
+            # Anthropic-specific options:
+            "thinking": {"type": "enabled", "budget_tokens": 1000},
+            # "top_k": 40,  # <-- Uncomment for Anthropic-specific option
+        },
+    )
+
+    print(f"Anthropic Response: {response.text}")
+    print(f"Model used: {response.model_id}")
+
+
+async def demo_anthropic_agent() -> None:
+    """Demonstrate ChatAgent with Anthropic client and typed options."""
+    print("\n=== ChatAgent with Anthropic and Typed Options ===\n")
+
+    client = AnthropicClient(model_id="claude-sonnet-4-5-20250929")
+
+    # Create a typed agent for Anthropic - IDE knows Anthropic-specific options!
+    agent = ChatAgent(
+        chat_client=client,
+        name="claude-assistant",
+        instructions="You are a helpful assistant powered by Claude. Be concise.",
+        default_options={
+            "temperature": 0.5,
+            "max_tokens": 200,
+            "top_k": 40,  # Anthropic-specific option
+        },
+    )
+
+    # Run the agent
+    response = await agent.run("Explain quantum computing in one sentence.")
+
+    print(f"Agent Response: {response.text}")
+
+
+class OpenAIReasoningChatOptions(OpenAIChatOptions, total=False):
+    """Chat options for OpenAI reasoning models (o1, o3, o4-mini, etc.).
+
+    Reasoning models have different parameter support compared to standard models.
+    This TypedDict marks unsupported parameters with ``None`` type.
+
+    Examples:
+        .. code-block:: python
+
+            options: OpenAIReasoningChatOptions = {
+                "model_id": "o3",
+                "reasoning_effort": "high",
+                "max_tokens": 4096,
+            }
+    """
+
+    # Reasoning-specific parameters
+    reasoning_effort: Literal["none", "minimal", "low", "medium", "high", "xhigh"]
+
+    # Unsupported parameters for reasoning models (override with None)
+    temperature: None
+    top_p: None
+    frequency_penalty: None
+    presence_penalty: None
+    logit_bias: None
+    logprobs: None
+    top_logprobs: None
+    stop: None  # Not supported for o3 and o4-mini
+
+
+async def demo_openai_chat_client_reasoning_models() -> None:
+    """Demonstrate OpenAI ChatClient with typed options for reasoning models."""
+    print("\n=== OpenAI ChatClient with TypedDict Options ===\n")
+
+    # Create OpenAI client
+    client = OpenAIChatClient[OpenAIReasoningChatOptions]()
+
+    # With specific options, you get full IDE autocomplete!
+    # Try typing `client.get_response("Hello", options={` and see the suggestions
+    response = await client.get_response(
+        "What is 2 + 2?",
+        options={
+            "model_id": "o3",
+            "max_tokens": 100,
+            "allow_multiple_tool_calls": True,
+            # OpenAI-specific options work:
+            "reasoning_effort": "medium",
+            # Unsupported options are caught by type checker (uncomment to see):
+            # "temperature": 0.7,
+            # "random": 234,
+        },
+    )
+
+    print(f"OpenAI Response: {response.text}")
+    print(f"Model used: {response.model_id}")
+
+
+async def demo_openai_agent() -> None:
+    """Demonstrate ChatAgent with OpenAI client and typed options."""
+    print("\n=== ChatAgent with OpenAI and Typed Options ===\n")
+
+    # Create a typed agent - IDE will autocomplete options!
+    # The type annotation can be done either on the agent like below,
+    # or on the client when constructing the client instance:
+    # client = OpenAIChatClient[OpenAIReasoningChatOptions]()
+    agent = ChatAgent[OpenAIReasoningChatOptions](
+        chat_client=OpenAIChatClient(),
+        name="weather-assistant",
+        instructions="You are a helpful assistant. Answer concisely.",
Answer concisely.", + # Options can be set at construction time + default_options={ + "model_id": "o3", + "max_tokens": 100, + "allow_multiple_tool_calls": True, + # OpenAI-specific options work: + "reasoning_effort": "medium", + # Unsupported options are caught by type checker (uncomment to see): + # "temperature": 0.7, + # "random": 234, + }, + ) + + # Or pass options at runtime - they override construction options + response = await agent.run( + "What is 25 * 47?", + options={ + "reasoning_effort": "high", # Override for a run + }, + ) + + print(f"Agent Response: {response.text}") + + +async def main() -> None: + """Run all Typed Options demonstrations.""" + # # Anthropic demos (requires ANTHROPIC_API_KEY) + await demo_anthropic_chat_client() + await demo_anthropic_agent() + + # OpenAI demos (requires OPENAI_API_KEY) + await demo_openai_chat_client_reasoning_models() + await demo_openai_agent() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/context_providers/aggregate_context_provider.py b/python/samples/getting_started/context_providers/aggregate_context_provider.py new file mode 100644 index 0000000000..1b682fadcb --- /dev/null +++ b/python/samples/getting_started/context_providers/aggregate_context_provider.py @@ -0,0 +1,276 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +This sample demonstrates how to use an AggregateContextProvider to combine multiple context providers. + +The AggregateContextProvider is a convenience class that allows you to aggregate multiple +ContextProviders into a single provider. It delegates events to all providers and combines +their context before returning. + +You can use this implementation as-is, or implement your own aggregation logic. +""" + +import asyncio +import sys +from collections.abc import MutableSequence, Sequence +from contextlib import AsyncExitStack +from types import TracebackType +from typing import TYPE_CHECKING, Any, cast + +from agent_framework import ChatAgent, ChatMessage, Context, ContextProvider +from agent_framework.azure import AzureAIClient +from azure.identity.aio import AzureCliCredential + +if TYPE_CHECKING: + from agent_framework import ToolProtocol + +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover +if sys.version_info >= (3, 11): + from typing import Self # pragma: no cover +else: + from typing_extensions import Self # pragma: no cover + + +# region AggregateContextProvider + + +class AggregateContextProvider(ContextProvider): + """A ContextProvider that contains multiple context providers. + + It delegates events to multiple context providers and aggregates responses from those + events before returning. This allows you to combine multiple context providers into a + single provider. + + Examples: + .. 
code-block:: python + + from agent_framework import ChatAgent + + # Create multiple context providers + provider1 = CustomContextProvider1() + provider2 = CustomContextProvider2() + provider3 = CustomContextProvider3() + + # Combine them using AggregateContextProvider + aggregate = AggregateContextProvider([provider1, provider2, provider3]) + + # Pass the aggregate to the agent + agent = ChatAgent(chat_client=client, name="assistant", context_provider=aggregate) + + # You can also add more providers later + provider4 = CustomContextProvider4() + aggregate.add(provider4) + """ + + def __init__(self, context_providers: ContextProvider | Sequence[ContextProvider] | None = None) -> None: + """Initialize the AggregateContextProvider with context providers. + + Args: + context_providers: The context provider(s) to add. + """ + if isinstance(context_providers, ContextProvider): + self.providers = [context_providers] + else: + self.providers = cast(list[ContextProvider], context_providers) or [] + self._exit_stack: AsyncExitStack | None = None + + def add(self, context_provider: ContextProvider) -> None: + """Add a new context provider. + + Args: + context_provider: The context provider to add. + """ + self.providers.append(context_provider) + + @override + async def thread_created(self, thread_id: str | None = None) -> None: + await asyncio.gather(*[x.thread_created(thread_id) for x in self.providers]) + + @override + async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context: + contexts = await asyncio.gather(*[provider.invoking(messages, **kwargs) for provider in self.providers]) + instructions: str = "" + return_messages: list[ChatMessage] = [] + tools: list["ToolProtocol"] = [] + for ctx in contexts: + if ctx.instructions: + instructions += ctx.instructions + if ctx.messages: + return_messages.extend(ctx.messages) + if ctx.tools: + tools.extend(ctx.tools) + return Context(instructions=instructions, messages=return_messages, tools=tools) + + @override + async def invoked( + self, + request_messages: ChatMessage | Sequence[ChatMessage], + response_messages: ChatMessage | Sequence[ChatMessage] | None = None, + invoke_exception: Exception | None = None, + **kwargs: Any, + ) -> None: + await asyncio.gather(*[ + x.invoked( + request_messages=request_messages, + response_messages=response_messages, + invoke_exception=invoke_exception, + **kwargs, + ) + for x in self.providers + ]) + + @override + async def __aenter__(self) -> "Self": + """Enter the async context manager and set up all providers. + + Returns: + The AggregateContextProvider instance for chaining. + """ + self._exit_stack = AsyncExitStack() + await self._exit_stack.__aenter__() + + # Enter all context providers + for provider in self.providers: + await self._exit_stack.enter_async_context(provider) + + return self + + @override + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit the async context manager and clean up all providers. + + Args: + exc_type: The exception type if an exception occurred, None otherwise. + exc_val: The exception value if an exception occurred, None otherwise. + exc_tb: The exception traceback if an exception occurred, None otherwise. 
+ """ + if self._exit_stack is not None: + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + self._exit_stack = None + + +# endregion + + +# region Example Context Providers + + +class TimeContextProvider(ContextProvider): + """A simple context provider that adds time-related instructions.""" + + @override + async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context: + from datetime import datetime + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + return Context(instructions=f"The current date and time is: {current_time}. ") + + +class PersonaContextProvider(ContextProvider): + """A context provider that adds a persona to the agent.""" + + def __init__(self, persona: str): + self.persona = persona + + @override + async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context: + return Context(instructions=f"Your persona: {self.persona}. ") + + +class PreferencesContextProvider(ContextProvider): + """A context provider that adds user preferences.""" + + def __init__(self): + self.preferences: dict[str, str] = {} + + @override + async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context: + if not self.preferences: + return Context() + prefs_str = ", ".join(f"{k}: {v}" for k, v in self.preferences.items()) + return Context(instructions=f"User preferences: {prefs_str}. ") + + @override + async def invoked( + self, + request_messages: ChatMessage | Sequence[ChatMessage], + response_messages: ChatMessage | Sequence[ChatMessage] | None = None, + invoke_exception: Exception | None = None, + **kwargs: Any, + ) -> None: + # Simple example: extract and store preferences from user messages + # In a real implementation, you might use structured extraction + msgs = [request_messages] if isinstance(request_messages, ChatMessage) else list(request_messages) + + for msg in msgs: + content = msg.content if hasattr(msg, "content") else "" + # Very simple extraction - in production, use LLM-based extraction + if isinstance(content, str) and "prefer" in content.lower() and ":" in content: + parts = content.split(":") + if len(parts) >= 2: + key = parts[0].strip().lower().replace("i prefer ", "") + value = parts[1].strip() + self.preferences[key] = value + + +# endregion + + +# region Main + + +async def main(): + """Demonstrate using AggregateContextProvider to combine multiple providers.""" + async with AzureCliCredential() as credential: + chat_client = AzureAIClient(credential=credential) + + # Create individual context providers + time_provider = TimeContextProvider() + persona_provider = PersonaContextProvider("You are a helpful and friendly AI assistant named Max.") + preferences_provider = PreferencesContextProvider() + + # Combine them using AggregateContextProvider + aggregate_provider = AggregateContextProvider([ + time_provider, + persona_provider, + preferences_provider, + ]) + + # Create the agent with the aggregate provider + async with ChatAgent( + chat_client=chat_client, + instructions="You are a helpful assistant.", + context_provider=aggregate_provider, + ) as agent: + # Create a new thread for the conversation + thread = agent.get_new_thread() + + # First message - the agent should include time and persona context + print("User: Hello! Who are you?") + result = await agent.run("Hello! 
Who are you?", thread=thread) + print(f"Agent: {result}\n") + + # Set a preference + print("User: I prefer language: formal English") + result = await agent.run("I prefer language: formal English", thread=thread) + print(f"Agent: {result}\n") + + # Ask something - the agent should now include the preference + print("User: Can you tell me a fun fact?") + result = await agent.run("Can you tell me a fun fact?", thread=thread) + print(f"Agent: {result}\n") + + # Show what the aggregate provider is tracking + print(f"\nPreferences tracked: {preferences_provider.preferences}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/context_providers/azure_ai_search/README.md b/python/samples/getting_started/context_providers/azure_ai_search/README.md index 95283ace66..fe7635e72f 100644 --- a/python/samples/getting_started/context_providers/azure_ai_search/README.md +++ b/python/samples/getting_started/context_providers/azure_ai_search/README.md @@ -144,7 +144,7 @@ async with AzureAIAgentClient(credential=DefaultAzureCredential()) as client: async with ChatAgent( chat_client=client, model=model_deployment, - context_providers=search_provider, + context_provider=search_provider, ) as agent: response = await agent.run("What information is in the knowledge base?") ``` @@ -169,7 +169,7 @@ search_provider = AzureAISearchContextProvider( async with ChatAgent( chat_client=client, model=model_deployment, - context_providers=search_provider, + context_provider=search_provider, ) as agent: response = await agent.run("Analyze and compare topics across documents") ``` diff --git a/python/samples/getting_started/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py b/python/samples/getting_started/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py index 132298c0b2..a1c389fb2a 100644 --- a/python/samples/getting_started/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py +++ b/python/samples/getting_started/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py @@ -120,7 +120,7 @@ async def main() -> None: "Use the provided context from the knowledge base to answer complex " "questions that may require synthesizing information from multiple sources." ), - context_providers=[search_provider], + context_provider=search_provider, ) as agent, ): print("=== Azure AI Agent with Search Context (Agentic Mode) ===\n") diff --git a/python/samples/getting_started/context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py b/python/samples/getting_started/context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py index 90ea4bbe4a..a504de7447 100644 --- a/python/samples/getting_started/context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py +++ b/python/samples/getting_started/context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py @@ -76,7 +76,7 @@ async def main() -> None: "You are a helpful assistant. Use the provided context from the " "knowledge base to answer questions accurately." 
), - context_providers=[search_provider], + context_provider=search_provider, ) as agent, ): print("=== Azure AI Agent with Search Context (Semantic Mode) ===\n") diff --git a/python/samples/getting_started/context_providers/mem0/mem0_basic.py b/python/samples/getting_started/context_providers/mem0/mem0_basic.py index 5fef82d390..d45da21ea9 100644 --- a/python/samples/getting_started/context_providers/mem0/mem0_basic.py +++ b/python/samples/getting_started/context_providers/mem0/mem0_basic.py @@ -36,7 +36,7 @@ async def main() -> None: name="FriendlyAssistant", instructions="You are a friendly assistant.", tools=retrieve_company_report, - context_providers=Mem0Provider(user_id=user_id), + context_provider=Mem0Provider(user_id=user_id), ) as agent, ): # First ask the agent to retrieve a company report with no previous context. diff --git a/python/samples/getting_started/context_providers/mem0/mem0_oss.py b/python/samples/getting_started/context_providers/mem0/mem0_oss.py index dcf92e8024..b0f1f0867d 100644 --- a/python/samples/getting_started/context_providers/mem0/mem0_oss.py +++ b/python/samples/getting_started/context_providers/mem0/mem0_oss.py @@ -39,7 +39,7 @@ async def main() -> None: name="FriendlyAssistant", instructions="You are a friendly assistant.", tools=retrieve_company_report, - context_providers=Mem0Provider(user_id=user_id, mem0_client=local_mem0_client), + context_provider=Mem0Provider(user_id=user_id, mem0_client=local_mem0_client), ) as agent, ): # First ask the agent to retrieve a company report with no previous context. diff --git a/python/samples/getting_started/context_providers/mem0/mem0_threads.py b/python/samples/getting_started/context_providers/mem0/mem0_threads.py index d014384394..4ade112030 100644 --- a/python/samples/getting_started/context_providers/mem0/mem0_threads.py +++ b/python/samples/getting_started/context_providers/mem0/mem0_threads.py @@ -31,7 +31,7 @@ async def example_global_thread_scope() -> None: name="GlobalMemoryAssistant", instructions="You are an assistant that remembers user preferences across conversations.", tools=get_user_preferences, - context_providers=Mem0Provider( + context_provider=Mem0Provider( user_id=user_id, thread_id=global_thread_id, scope_to_per_operation_thread_id=False, # Share memories across all threads @@ -69,7 +69,7 @@ async def example_per_operation_thread_scope() -> None: name="ScopedMemoryAssistant", instructions="You are an assistant with thread-scoped memory.", tools=get_user_preferences, - context_providers=Mem0Provider( + context_provider=Mem0Provider( user_id=user_id, scope_to_per_operation_thread_id=True, # Isolate memories per thread ), @@ -116,14 +116,14 @@ async def example_multiple_agents() -> None: AzureAIAgentClient(credential=credential).create_agent( name="PersonalAssistant", instructions="You are a personal assistant that helps with personal tasks.", - context_providers=Mem0Provider( + context_provider=Mem0Provider( agent_id=agent_id_1, ), ) as personal_agent, AzureAIAgentClient(credential=credential).create_agent( name="WorkAssistant", instructions="You are a work assistant that helps with professional tasks.", - context_providers=Mem0Provider( + context_provider=Mem0Provider( agent_id=agent_id_2, ), ) as work_agent, diff --git a/python/samples/getting_started/context_providers/redis/redis_basics.py b/python/samples/getting_started/context_providers/redis/redis_basics.py index ffa8c32a60..a4295cdb08 100644 --- a/python/samples/getting_started/context_providers/redis/redis_basics.py +++ 
b/python/samples/getting_started/context_providers/redis/redis_basics.py @@ -185,7 +185,7 @@ async def main() -> None: "Before answering, always check for stored context" ), tools=[], - context_providers=provider, + context_provider=provider, ) # Teach a user preference; the agent writes this to the provider's memory @@ -227,7 +227,7 @@ async def main() -> None: "Before answering, always check for stored context" ), tools=search_flights, - context_providers=provider, + context_provider=provider, ) # Invoke the tool; outputs become part of memory/context query = "Are there any flights from new york city (jfk) to la? Give me details" diff --git a/python/samples/getting_started/context_providers/redis/redis_conversation.py b/python/samples/getting_started/context_providers/redis/redis_conversation.py index 26748ae1c0..9e6331f6b3 100644 --- a/python/samples/getting_started/context_providers/redis/redis_conversation.py +++ b/python/samples/getting_started/context_providers/redis/redis_conversation.py @@ -70,7 +70,7 @@ async def main() -> None: "Before answering, always check for stored context" ), tools=[], - context_providers=provider, + context_provider=provider, chat_message_store_factory=chat_message_store_factory, ) diff --git a/python/samples/getting_started/context_providers/redis/redis_threads.py b/python/samples/getting_started/context_providers/redis/redis_threads.py index 6a9022895c..ff29dc5130 100644 --- a/python/samples/getting_started/context_providers/redis/redis_threads.py +++ b/python/samples/getting_started/context_providers/redis/redis_threads.py @@ -70,7 +70,7 @@ async def example_global_thread_scope() -> None: "Before answering, always check for stored context containing information" ), tools=[], - context_providers=provider, + context_provider=provider, ) # Store a preference in the global scope @@ -128,7 +128,7 @@ async def example_per_operation_thread_scope() -> None: agent = client.create_agent( name="ScopedMemoryAssistant", instructions="You are an assistant with thread-scoped memory.", - context_providers=provider, + context_provider=provider, ) # Create a specific thread for this scoped provider @@ -193,7 +193,7 @@ async def example_multiple_agents() -> None: personal_agent = client.create_agent( name="PersonalAssistant", instructions="You are a personal assistant that helps with personal tasks.", - context_providers=personal_provider, + context_provider=personal_provider, ) work_provider = RedisProvider( @@ -211,7 +211,7 @@ async def example_multiple_agents() -> None: work_agent = client.create_agent( name="WorkAssistant", instructions="You are a work assistant that helps with professional tasks.", - context_providers=work_provider, + context_provider=work_provider, ) # Store personal information diff --git a/python/samples/getting_started/context_providers/simple_context_provider.py b/python/samples/getting_started/context_providers/simple_context_provider.py index 38e5d4dd52..f69d5b96c6 100644 --- a/python/samples/getting_started/context_providers/simple_context_provider.py +++ b/python/samples/getting_started/context_providers/simple_context_provider.py @@ -4,7 +4,7 @@ from collections.abc import MutableSequence, Sequence from typing import Any -from agent_framework import ChatAgent, ChatClientProtocol, ChatMessage, ChatOptions, Context, ContextProvider +from agent_framework import ChatAgent, ChatClientProtocol, ChatMessage, Context, ContextProvider from agent_framework.azure import AzureAIClient from azure.identity.aio import AzureCliCredential from pydantic 
import BaseModel @@ -46,11 +46,9 @@ async def invoked( # Use the chat client to extract structured information result = await self._chat_client.get_response( messages=request_messages, # type: ignore - chat_options=ChatOptions( - instructions="Extract the user's name and age from the message if present. " - "If not present return nulls.", - response_format=UserInfo, - ), + instructions="Extract the user's name and age from the message if present. " + "If not present return nulls.", + options={"response_format": UserInfo}, ) # Update user info with extracted data @@ -100,7 +98,7 @@ async def main(): async with ChatAgent( chat_client=chat_client, instructions="You are a friendly assistant. Always address the user by their name.", - context_providers=memory_provider, + context_provider=memory_provider, ) as agent: # Create a new thread for the conversation thread = agent.get_new_thread() diff --git a/python/samples/getting_started/devui/README.md b/python/samples/getting_started/devui/README.md index 1a8359f359..bfbee3a70b 100644 --- a/python/samples/getting_started/devui/README.md +++ b/python/samples/getting_started/devui/README.md @@ -62,9 +62,10 @@ agent_name/ | Sample | Description | Features | Required Environment Variables | | -------------------------------------------- | ----------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------- | -| [**workflow_agents/**](workflow_agents/) | Content review workflow with agents as executors | Agents as workflow nodes, conditional routing based on structured outputs, quality-based paths (Writer → Reviewer → Editor/Publisher) | `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME`, `AZURE_OPENAI_ENDPOINT` | +| [**declarative/**](declarative/) | Declarative YAML workflow with conditional branching | YAML-based workflow definition, conditional logic, no Python code required | None - uses mock data | +| [**workflow_agents/**](workflow_agents/) | Content review workflow with agents as executors | Agents as workflow nodes, conditional routing based on structured outputs, quality-based paths (Writer -> Reviewer -> Editor/Publisher) | `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME`, `AZURE_OPENAI_ENDPOINT` | | [**spam_workflow/**](spam_workflow/) | 5-step email spam detection workflow with branching logic | Sequential execution, conditional branching (spam vs. legitimate), multiple executors, mock spam detection | None - uses mock data | -| [**fanout_workflow/**](fanout_workflow/) | Advanced data processing workflow with parallel execution | Fan-out/fan-in patterns, complex state management, multi-stage processing (validation → transformation → quality assurance) | None - uses mock data | +| [**fanout_workflow/**](fanout_workflow/) | Advanced data processing workflow with parallel execution | Fan-out/fan-in patterns, complex state management, multi-stage processing (validation -> transformation -> quality assurance) | None - uses mock data | ### Standalone Examples diff --git a/python/samples/getting_started/devui/declarative/__init__.py b/python/samples/getting_started/devui/declarative/__init__.py new file mode 100644 index 0000000000..1fe0817125 --- /dev/null +++ b/python/samples/getting_started/devui/declarative/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Declarative workflow sample for DevUI.""" diff --git a/python/samples/getting_started/devui/declarative/workflow.py b/python/samples/getting_started/devui/declarative/workflow.py new file mode 100644 index 0000000000..70a746d76b --- /dev/null +++ b/python/samples/getting_started/devui/declarative/workflow.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Run the declarative workflow sample with DevUI. + +Demonstrates conditional branching based on age input using YAML-defined workflow. +""" + +from pathlib import Path + +from agent_framework.declarative import WorkflowFactory +from agent_framework.devui import serve + +factory = WorkflowFactory() +workflow_path = Path(__file__).parent / "workflow.yaml" +workflow = factory.create_workflow_from_yaml_path(workflow_path) + + +def main(): + """Run the declarative workflow with DevUI.""" + serve(entities=[workflow], auto_open=True) + + +if __name__ == "__main__": + main() diff --git a/python/samples/getting_started/devui/declarative/workflow.yaml b/python/samples/getting_started/devui/declarative/workflow.yaml new file mode 100644 index 0000000000..947f168838 --- /dev/null +++ b/python/samples/getting_started/devui/declarative/workflow.yaml @@ -0,0 +1,64 @@ +name: conditional-workflow +description: Demonstrates conditional branching based on user input + +inputs: + age: + type: integer + description: The user's age in years + +actions: + - kind: SetValue + id: get_age + displayName: Get user age + path: turn.age + value: =inputs.age + + - kind: If + id: check_age + displayName: Check age category + condition: =turn.age < 13 + then: + - kind: SetValue + path: turn.category + value: child + - kind: SendActivity + activity: + text: "Welcome, young one! Here are some fun activities for kids." + else: + - kind: If + condition: =turn.age < 20 + then: + - kind: SetValue + path: turn.category + value: teenager + - kind: SendActivity + activity: + text: "Hey there! Check out these cool things for teens." + else: + - kind: If + condition: =turn.age < 65 + then: + - kind: SetValue + path: turn.category + value: adult + - kind: SendActivity + activity: + text: "Welcome! Here are our professional services." + else: + - kind: SetValue + path: turn.category + value: senior + - kind: SendActivity + activity: + text: "Welcome! Enjoy our senior member benefits." 
+ + - kind: SendActivity + id: summary + displayName: Send category summary + activity: + text: '=Concat("You have been categorized as: ", turn.category)' + + - kind: SetValue + id: set_output + path: workflow.outputs.category + value: =turn.category diff --git a/python/samples/getting_started/devui/workflow_agents/workflow.py b/python/samples/getting_started/devui/workflow_agents/workflow.py index 3c6307aef8..dc70112783 100644 --- a/python/samples/getting_started/devui/workflow_agents/workflow.py +++ b/python/samples/getting_started/devui/workflow_agents/workflow.py @@ -40,7 +40,7 @@ def needs_editing(message: Any) -> bool: if not isinstance(message, AgentExecutorResponse): return False try: - review = ReviewResult.model_validate_json(message.agent_run_response.text) + review = ReviewResult.model_validate_json(message.agent_response.text) return review.score < 80 except Exception: return False @@ -52,7 +52,7 @@ def is_approved(message: Any) -> bool: if not isinstance(message, AgentExecutorResponse): return True try: - review = ReviewResult.model_validate_json(message.agent_run_response.text) + review = ReviewResult.model_validate_json(message.agent_response.text) return review.score >= 80 except Exception: return True @@ -86,7 +86,7 @@ def is_approved(message: Any) -> bool: "- feedback: concise, actionable feedback\n" "- clarity, completeness, accuracy, structure: individual scores (0-100)" ), - response_format=ReviewResult, + default_options={"response_format": ReviewResult}, ) # Create Editor agent - improves content based on feedback diff --git a/python/samples/getting_started/durabletask/03_single_agent_streaming/README.md b/python/samples/getting_started/durabletask/03_single_agent_streaming/README.md index 5505acbe6c..6e9f1428bf 100644 --- a/python/samples/getting_started/durabletask/03_single_agent_streaming/README.md +++ b/python/samples/getting_started/durabletask/03_single_agent_streaming/README.md @@ -6,7 +6,7 @@ This sample demonstrates how to use Redis Streams with agent response callbacks - Using `AgentResponseCallbackProtocol` to capture streaming agent responses. - Persisting streaming chunks to Redis Streams for reliable delivery. -- Non-blocking agent execution with `wait_for_response=False` (fire-and-forget mode). +- Non-blocking agent execution with `options={"wait_for_response": False}` (fire-and-forget mode). - Cursor-based resumption for disconnected clients. - Decoupling agent execution from response streaming. @@ -114,7 +114,7 @@ The client uses fire-and-forget mode to start the agent and streams from Redis: ```python # Start agent run with wait_for_response=False for non-blocking execution -travel_planner.run(user_message, thread=thread, wait_for_response=False) +travel_planner.run(user_message, thread=thread, options={"wait_for_response": False}) # Stream response from Redis while the agent is processing async with await get_stream_handler() as stream_handler: @@ -125,7 +125,7 @@ async with await get_stream_handler() as stream_handler: break ``` -**Fire-and-Forget Mode**: The `wait_for_response=False` parameter enables non-blocking execution. The `run()` method signals the agent and returns immediately, allowing the client to stream from Redis without blocking. +**Fire-and-Forget Mode**: Use `options={"wait_for_response": False}` to enable non-blocking execution. The `run()` method signals the agent and returns immediately, allowing the client to stream from Redis without blocking. 
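Below, for illustration only, is a minimal sketch of the persistence side of this pattern: a response callback that appends each streamed chunk to a Redis Stream. The class name, method names, stream key, and field layout are assumptions for the sketch, not the sample's actual schema; see the sample's worker for the real `AgentResponseCallbackProtocol` implementation.

```python
import redis.asyncio as redis


class RedisStreamResponseCallback:
    """Illustrative callback that persists streamed agent chunks to a Redis Stream."""

    def __init__(self, stream_key: str, url: str = "redis://localhost:6379") -> None:
        self._redis = redis.from_url(url)
        self._stream_key = stream_key

    async def on_update(self, text: str) -> None:
        # XADD assigns monotonically increasing entry IDs, which double as
        # cursors for clients that disconnect and later resume reading.
        await self._redis.xadd(self._stream_key, {"chunk": text})

    async def on_complete(self) -> None:
        # Sentinel entry so readers know the response has finished.
        await self._redis.xadd(self._stream_key, {"done": "1"})
```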
### Cursor-Based Resumption diff --git a/python/samples/getting_started/durabletask/03_single_agent_streaming/client.py b/python/samples/getting_started/durabletask/03_single_agent_streaming/client.py index be10eddc93..10aed9a954 100644 --- a/python/samples/getting_started/durabletask/03_single_agent_streaming/client.py +++ b/python/samples/getting_started/durabletask/03_single_agent_streaming/client.py @@ -165,7 +165,7 @@ def run_client(agent_client: DurableAIAgentClient) -> None: # Start the agent run with wait_for_response=False for non-blocking execution # This signals the agent to start processing without waiting for completion # The agent will execute in the background and write chunks to Redis - travel_planner.run(user_message, thread=thread, wait_for_response=False) + travel_planner.run(user_message, thread=thread, options={"wait_for_response": False}) # Stream the response from Redis # This demonstrates that the client can stream from Redis while diff --git a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/worker.py b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/worker.py index 81be068b73..402f349dd5 100644 --- a/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/worker.py +++ b/python/samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency/worker.py @@ -16,7 +16,7 @@ import os from typing import Any -from agent_framework import AgentRunResponse +from agent_framework import AgentResponse from agent_framework.azure import AzureOpenAIChatClient from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker from azure.identity import AzureCliCredential, DefaultAzureCredential @@ -97,9 +97,9 @@ def multi_agent_concurrent_orchestration(context: OrchestrationContext, prompt: logger.debug("[Orchestration] Both agents completed") - # Extract results from the tasks - DurableAgentTask yields AgentRunResponse - physicist_result: AgentRunResponse = task_results[0] - chemist_result: AgentRunResponse = task_results[1] + # Extract results from the tasks - DurableAgentTask yields AgentResponse + physicist_result: AgentResponse = task_results[0] + chemist_result: AgentResponse = task_results[1] result = { "physicist": physicist_result.text, diff --git a/python/samples/getting_started/durabletask/06_multi_agent_orchestration_conditionals/README.md b/python/samples/getting_started/durabletask/06_multi_agent_orchestration_conditionals/README.md index 03b5df1b05..f6a40c087b 100644 --- a/python/samples/getting_started/durabletask/06_multi_agent_orchestration_conditionals/README.md +++ b/python/samples/getting_started/durabletask/06_multi_agent_orchestration_conditionals/README.md @@ -6,7 +6,7 @@ This sample demonstrates conditional orchestration logic with two agents that an - Multi-agent orchestration with two specialized agents (SpamDetectionAgent and EmailAssistantAgent). - Conditional branching with different execution paths based on spam detection results. -- Structured outputs using Pydantic models with `response_format` for type-safe agent responses. +- Structured outputs using Pydantic models with `options={"response_format": ...}` for type-safe agent responses. - Activity functions for side effects (spam handling and email sending). - Decision-based routing where orchestration logic branches on agent output. 
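As a rough sketch of the structured-output pattern listed above (the model fields and prompt are placeholders; inside the durable orchestration the sample yields the run as a task instead of awaiting it):

```python
from typing import cast

from pydantic import BaseModel


class SpamDetectionResult(BaseModel):
    # Illustrative fields; the sample's worker defines its own model.
    is_spam: bool
    reason: str


async def classify(spam_agent, email_text: str) -> SpamDetectionResult:
    # Passing a Pydantic model via options={"response_format": ...} asks the
    # agent to return JSON matching the schema; .value holds the parsed model.
    response = await spam_agent.run(
        messages=f"Decide whether this email is spam:\n{email_text}",
        options={"response_format": SpamDetectionResult},
    )
    return cast(SpamDetectionResult, response.value)
```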
diff --git a/python/samples/getting_started/durabletask/06_multi_agent_orchestration_conditionals/worker.py b/python/samples/getting_started/durabletask/06_multi_agent_orchestration_conditionals/worker.py index cfe26302d7..57f38ccf70 100644 --- a/python/samples/getting_started/durabletask/06_multi_agent_orchestration_conditionals/worker.py +++ b/python/samples/getting_started/durabletask/06_multi_agent_orchestration_conditionals/worker.py @@ -16,7 +16,7 @@ import os from typing import Any, cast -from agent_framework import AgentRunResponse +from agent_framework import AgentResponse from agent_framework.azure import AzureOpenAIChatClient from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker from azure.identity import AzureCliCredential, DefaultAzureCredential @@ -148,10 +148,10 @@ def spam_detection_orchestration(context: OrchestrationContext, payload_raw: Any logger.info("[Orchestration] Running spam detection agent: %s", spam_prompt) spam_result_task = spam_agent.run( messages=spam_prompt, - response_format=SpamDetectionResult, + options={"response_format": SpamDetectionResult}, ) - spam_result_raw: AgentRunResponse = yield spam_result_task + spam_result_raw: AgentResponse = yield spam_result_task spam_result = cast(SpamDetectionResult, spam_result_raw.value) logger.info("[Orchestration] Spam detection result: is_spam=%s", spam_result.is_spam) @@ -178,10 +178,10 @@ def spam_detection_orchestration(context: OrchestrationContext, payload_raw: Any logger.info("[Orchestration] Running email assistant agent: %s", email_prompt) email_result_task = email_agent.run( messages=email_prompt, - response_format=EmailResponse, + options={"response_format": EmailResponse}, ) - email_result_raw: AgentRunResponse = yield email_result_task + email_result_raw: AgentResponse = yield email_result_task email_result = cast(EmailResponse, email_result_raw.value) logger.debug("[Orchestration] Email response drafted, sending...") diff --git a/python/samples/getting_started/durabletask/07_single_agent_orchestration_hitl/README.md b/python/samples/getting_started/durabletask/07_single_agent_orchestration_hitl/README.md index 59f1186b33..fbfe905d59 100644 --- a/python/samples/getting_started/durabletask/07_single_agent_orchestration_hitl/README.md +++ b/python/samples/getting_started/durabletask/07_single_agent_orchestration_hitl/README.md @@ -8,7 +8,7 @@ This sample demonstrates the human-in-the-loop pattern where a WriterAgent gener - External event handling using `wait_for_external_event()` to receive human input. - Timeout management with `when_any()` to race between approval event and timeout. - Iterative refinement where agent regenerates content based on reviewer feedback. -- Structured outputs using Pydantic models with `response_format` for type-safe agent responses. +- Structured outputs using Pydantic models with `options={"response_format": ...}` for type-safe agent responses. - Activity functions for notifications and publishing as separate side effects. - Long-running orchestrations maintaining state across multiple interactions. 
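For orientation, a hedged sketch of the approval-versus-timeout race those bullets describe, using `wait_for_external_event()`, `create_timer()`, and `when_any()` from the durabletask SDK; the event name, timeout, and payload shape are illustrative, not the sample's actual contract:

```python
from datetime import timedelta

from durabletask import task


def approval_gate(ctx: task.OrchestrationContext, draft: str):
    # Race a human approval event against a timer so the orchestration
    # never waits indefinitely for a reviewer.
    approval = ctx.wait_for_external_event("approval_received")
    timeout = ctx.create_timer(ctx.current_utc_datetime + timedelta(hours=24))
    winner = yield task.when_any([approval, timeout])
    if winner == approval:
        feedback = approval.get_result()  # reviewer's decision payload
        return {"approved": bool(feedback.get("approved")), "feedback": feedback}
    return {"approved": False, "feedback": "No review received before timeout."}
```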
diff --git a/python/samples/getting_started/durabletask/07_single_agent_orchestration_hitl/worker.py b/python/samples/getting_started/durabletask/07_single_agent_orchestration_hitl/worker.py index 7720def690..3626cedc4e 100644 --- a/python/samples/getting_started/durabletask/07_single_agent_orchestration_hitl/worker.py +++ b/python/samples/getting_started/durabletask/07_single_agent_orchestration_hitl/worker.py @@ -17,7 +17,7 @@ import os from typing import Any, cast -from agent_framework import AgentRunResponse +from agent_framework import AgentResponse from agent_framework.azure import AzureOpenAIChatClient from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker from azure.identity import AzureCliCredential, DefaultAzureCredential @@ -153,10 +153,10 @@ def content_generation_hitl_orchestration( # Generate initial content logger.info("[Orchestration] Generating initial content...") - initial_response: AgentRunResponse = yield writer.run( + initial_response: AgentResponse = yield writer.run( messages=f"Write a short article about '{payload.topic}'.", thread=writer_thread, - response_format=GeneratedContent, + options={"response_format": GeneratedContent}, ) content = cast(GeneratedContent, initial_response.value) @@ -241,10 +241,10 @@ def content_generation_hitl_orchestration( logger.warning(f"Regenerating with ThreadID: {writer_thread.session_id}") - rewrite_response: AgentRunResponse = yield writer.run( + rewrite_response: AgentResponse = yield writer.run( messages=rewrite_prompt, thread=writer_thread, - response_format=GeneratedContent, + options={"response_format": GeneratedContent}, ) rewritten_content = cast(GeneratedContent, rewrite_response.value) diff --git a/python/samples/getting_started/mcp/mcp_api_key_auth.py b/python/samples/getting_started/mcp/mcp_api_key_auth.py index f3ec1777e6..d80d92d4fa 100644 --- a/python/samples/getting_started/mcp/mcp_api_key_auth.py +++ b/python/samples/getting_started/mcp/mcp_api_key_auth.py @@ -4,6 +4,7 @@ from agent_framework import ChatAgent, MCPStreamableHTTPTool from agent_framework.openai import OpenAIResponsesClient +from httpx import AsyncClient """ MCP Authentication Example @@ -31,13 +32,16 @@ async def api_key_auth_example() -> None: "Authorization": f"Bearer {api_key}", } - # Create MCP tool with authentication headers + # Create HTTP client with authentication headers + http_client = AsyncClient(headers=auth_headers) + + # Create MCP tool with the configured HTTP client async with ( MCPStreamableHTTPTool( name="MCP tool", description="MCP tool description", url=mcp_server_url, - headers=auth_headers, # Authentication headers + http_client=http_client, # Pass HTTP client with authentication headers ) as mcp_tool, ChatAgent( chat_client=OpenAIResponsesClient(), diff --git a/python/samples/getting_started/middleware/agent_and_run_level_middleware.py b/python/samples/getting_started/middleware/agent_and_run_level_middleware.py index fce26d2a71..b421454163 100644 --- a/python/samples/getting_started/middleware/agent_and_run_level_middleware.py +++ b/python/samples/getting_started/middleware/agent_and_run_level_middleware.py @@ -8,8 +8,8 @@ from agent_framework import ( AgentMiddleware, + AgentResponse, AgentRunContext, - AgentRunResponse, FunctionInvocationContext, ) from agent_framework.azure import AzureAIAgentClient @@ -121,7 +121,7 @@ class CachingMiddleware(AgentMiddleware): """Run-level caching middleware for expensive operations.""" def __init__(self) -> None: - self.cache: dict[str, 
AgentRunResponse] = {} + self.cache: dict[str, AgentResponse] = {} async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: # Create a simple cache key from the last message @@ -202,7 +202,7 @@ async def main() -> None: print(f"User: {query}") result = await agent.run( query, - middleware=HighPriorityMiddleware(), # Run-level middleware + middleware=[HighPriorityMiddleware()], # Run-level middleware ) print(f"Agent: {result.text if result.text else 'No response'}") print() diff --git a/python/samples/getting_started/middleware/chat_middleware.py b/python/samples/getting_started/middleware/chat_middleware.py index 9573e6e666..79a9ab9c7a 100644 --- a/python/samples/getting_started/middleware/chat_middleware.py +++ b/python/samples/getting_started/middleware/chat_middleware.py @@ -146,7 +146,7 @@ async def class_based_chat_middleware() -> None: name="EnhancedChatAgent", instructions="You are a helpful AI assistant.", # Register class-based middleware at agent level (applies to all runs) - middleware=InputObserverMiddleware(), + middleware=[InputObserverMiddleware()], tools=get_weather, ) as agent, ): @@ -168,7 +168,7 @@ async def function_based_chat_middleware() -> None: name="FunctionMiddlewareAgent", instructions="You are a helpful AI assistant.", # Register function-based middleware at agent level - middleware=security_and_override_middleware, + middleware=[security_and_override_middleware], ) as agent, ): # Scenario with normal query @@ -226,7 +226,7 @@ async def run_level_middleware() -> None: print(f"User: {query}") result = await agent.run( query, - middleware=security_and_override_middleware, + middleware=[security_and_override_middleware], ) print(f"Response: {result.text if result.text else 'No response'}") diff --git a/python/samples/getting_started/middleware/class_based_middleware.py b/python/samples/getting_started/middleware/class_based_middleware.py index 0bf990a1f4..52d783c0d0 100644 --- a/python/samples/getting_started/middleware/class_based_middleware.py +++ b/python/samples/getting_started/middleware/class_based_middleware.py @@ -8,8 +8,8 @@ from agent_framework import ( AgentMiddleware, + AgentResponse, AgentRunContext, - AgentRunResponse, ChatMessage, FunctionInvocationContext, FunctionMiddleware, @@ -58,7 +58,7 @@ async def process( if "password" in query.lower() or "secret" in query.lower(): print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.") # Override the result with warning message - context.result = AgentRunResponse( + context.result = AgentResponse( messages=[ ChatMessage(role=Role.ASSISTANT, text="Detected sensitive information, the request is blocked.") ] diff --git a/python/samples/getting_started/middleware/exception_handling_with_middleware.py b/python/samples/getting_started/middleware/exception_handling_with_middleware.py index 07cfdb0854..285076be92 100644 --- a/python/samples/getting_started/middleware/exception_handling_with_middleware.py +++ b/python/samples/getting_started/middleware/exception_handling_with_middleware.py @@ -62,7 +62,7 @@ async def main() -> None: name="DataAgent", instructions="You are a helpful data assistant. 
Use the data service tool to fetch information for users.", tools=unstable_data_service, - middleware=exception_handling_middleware, + middleware=[exception_handling_middleware], ) as agent, ): query = "Get user statistics" diff --git a/python/samples/getting_started/middleware/middleware_termination.py b/python/samples/getting_started/middleware/middleware_termination.py index 7803b14ca5..1eff584721 100644 --- a/python/samples/getting_started/middleware/middleware_termination.py +++ b/python/samples/getting_started/middleware/middleware_termination.py @@ -7,8 +7,8 @@ from agent_framework import ( AgentMiddleware, + AgentResponse, AgentRunContext, - AgentRunResponse, ChatMessage, Role, ) @@ -57,7 +57,7 @@ async def process( print(f"[PreTerminationMiddleware] Blocked word '{blocked_word}' detected. Terminating request.") # Set a custom response - context.result = AgentRunResponse( + context.result = AgentResponse( messages=[ ChatMessage( role=Role.ASSISTANT, @@ -114,7 +114,7 @@ async def pre_termination_middleware() -> None: name="WeatherAgent", instructions="You are a helpful weather assistant.", tools=get_weather, - middleware=PreTerminationMiddleware(blocked_words=["bad", "inappropriate"]), + middleware=[PreTerminationMiddleware(blocked_words=["bad", "inappropriate"])], ) as agent, ): # Test with normal query @@ -141,7 +141,7 @@ async def post_termination_middleware() -> None: name="WeatherAgent", instructions="You are a helpful weather assistant.", tools=get_weather, - middleware=PostTerminationMiddleware(max_responses=1), + middleware=[PostTerminationMiddleware(max_responses=1)], ) as agent, ): # First run (should work) diff --git a/python/samples/getting_started/middleware/override_result_with_middleware.py b/python/samples/getting_started/middleware/override_result_with_middleware.py index 686dceea9e..5455ebe7b6 100644 --- a/python/samples/getting_started/middleware/override_result_with_middleware.py +++ b/python/samples/getting_started/middleware/override_result_with_middleware.py @@ -6,9 +6,9 @@ from typing import Annotated from agent_framework import ( + AgentResponse, + AgentResponseUpdate, AgentRunContext, - AgentRunResponse, - AgentRunResponseUpdate, ChatMessage, Role, TextContent, @@ -64,15 +64,15 @@ async def weather_override_middleware( if context.is_streaming: # For streaming: create an async generator that yields chunks - async def override_stream() -> AsyncIterable[AgentRunResponseUpdate]: + async def override_stream() -> AsyncIterable[AgentResponseUpdate]: for chunk in chunks: - yield AgentRunResponseUpdate(contents=[TextContent(text=chunk)]) + yield AgentResponseUpdate(contents=[TextContent(text=chunk)]) context.result = override_stream() else: # For non-streaming: just replace with the string message custom_message = "".join(chunks) - context.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=custom_message)]) + context.result = AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=custom_message)]) async def main() -> None: @@ -87,7 +87,7 @@ async def main() -> None: name="WeatherAgent", instructions="You are a helpful weather assistant. 
Use the weather tool to get current conditions.", tools=get_weather, - middleware=weather_override_middleware, + middleware=[weather_override_middleware], ) as agent, ): # Non-streaming example diff --git a/python/samples/getting_started/middleware/thread_behavior_middleware.py b/python/samples/getting_started/middleware/thread_behavior_middleware.py index 2c2d378baa..b7b535541e 100644 --- a/python/samples/getting_started/middleware/thread_behavior_middleware.py +++ b/python/samples/getting_started/middleware/thread_behavior_middleware.py @@ -74,7 +74,7 @@ async def main() -> None: name="WeatherAgent", instructions="You are a helpful weather assistant.", tools=get_weather, - middleware=thread_tracking_middleware, + middleware=[thread_tracking_middleware], # Configure agent with message store factory to persist conversation history chat_message_store_factory=ChatMessageStore, ) diff --git a/python/samples/getting_started/observability/agent_observability.py b/python/samples/getting_started/observability/agent_observability.py index b60b25c004..cd1b505194 100644 --- a/python/samples/getting_started/observability/agent_observability.py +++ b/python/samples/getting_started/observability/agent_observability.py @@ -47,7 +47,7 @@ async def main(): thread = agent.get_new_thread() for question in questions: print(f"\nUser: {question}") - print(f"{agent.display_name}: ", end="") + print(f"{agent.name}: ", end="") async for update in agent.run_stream( question, thread=thread, diff --git a/python/samples/getting_started/observability/agent_with_foundry_tracing.py b/python/samples/getting_started/observability/agent_with_foundry_tracing.py index bc5667e564..9bce1f1b4a 100644 --- a/python/samples/getting_started/observability/agent_with_foundry_tracing.py +++ b/python/samples/getting_started/observability/agent_with_foundry_tracing.py @@ -84,7 +84,7 @@ async def main(): thread = agent.get_new_thread() for question in questions: print(f"\nUser: {question}") - print(f"{agent.display_name}: ", end="") + print(f"{agent.name}: ", end="") async for update in agent.run_stream( question, thread=thread, diff --git a/python/samples/getting_started/observability/azure_ai_agent_observability.py b/python/samples/getting_started/observability/azure_ai_agent_observability.py index 123c4a2e6e..f5804f4cfd 100644 --- a/python/samples/getting_started/observability/azure_ai_agent_observability.py +++ b/python/samples/getting_started/observability/azure_ai_agent_observability.py @@ -64,7 +64,7 @@ async def main(): thread = agent.get_new_thread() for question in questions: print(f"\nUser: {question}") - print(f"{agent.display_name}: ", end="") + print(f"{agent.name}: ", end="") async for update in agent.run_stream( question, thread=thread, diff --git a/python/samples/getting_started/purview_agent/README.md b/python/samples/getting_started/purview_agent/README.md index d4cfeca3df..8982a68830 100644 --- a/python/samples/getting_started/purview_agent/README.md +++ b/python/samples/getting_started/purview_agent/README.md @@ -116,18 +116,18 @@ This is only needed if you want to integrate with external caching systems. 
```python class SimpleDictCacheProvider: """Custom cache provider that implements the CacheProvider protocol.""" - + def __init__(self) -> None: self._cache: dict[str, Any] = {} - + async def get(self, key: str) -> Any | None: """Get a value from the cache.""" return self._cache.get(key) - + async def set(self, key: str, value: Any, ttl_seconds: int | None = None) -> None: """Set a value in the cache.""" self._cache[key] = value - + async def remove(self, key: str) -> None: """Remove a value from the cache.""" self._cache.pop(key, None) diff --git a/python/samples/getting_started/purview_agent/sample_purview_agent.py b/python/samples/getting_started/purview_agent/sample_purview_agent.py index b9518edf08..223eed55e3 100644 --- a/python/samples/getting_started/purview_agent/sample_purview_agent.py +++ b/python/samples/getting_started/purview_agent/sample_purview_agent.py @@ -25,7 +25,7 @@ import os from typing import Any -from agent_framework import AgentRunResponse, ChatAgent, ChatMessage, Role +from agent_framework import AgentResponse, ChatAgent, ChatMessage, Role from agent_framework.azure import AzureOpenAIChatClient from agent_framework.microsoft import ( PurviewChatPolicyMiddleware, @@ -154,14 +154,20 @@ async def run_with_agent_middleware() -> None: chat_client=chat_client, instructions=JOKER_INSTRUCTIONS, name=JOKER_NAME, - middleware=purview_agent_middleware, + middleware=[purview_agent_middleware], ) print("-- Agent Middleware Path --") - first: AgentRunResponse = await agent.run(ChatMessage(role=Role.USER, text="Tell me a joke about a pirate.", additional_properties={"user_id": user_id})) + first: AgentResponse = await agent.run( + ChatMessage(role=Role.USER, text="Tell me a joke about a pirate.", additional_properties={"user_id": user_id}) + ) print("First response (agent middleware):\n", first) - second: AgentRunResponse = await agent.run(ChatMessage(role=Role.USER, text="That was funny. Tell me another one.", additional_properties={"user_id": user_id})) + second: AgentResponse = await agent.run( + ChatMessage( + role=Role.USER, text="That was funny. 
Tell me another one.", additional_properties={"user_id": user_id} + ) + ) print("Second response (agent middleware):\n", second) @@ -195,7 +201,7 @@ async def run_with_chat_middleware() -> None: ) print("-- Chat Middleware Path --") - first: AgentRunResponse = await agent.run( + first: AgentResponse = await agent.run( ChatMessage( role=Role.USER, text="Give me a short clean joke.", @@ -204,7 +210,7 @@ async def run_with_chat_middleware() -> None: ) print("First response (chat middleware):\n", first) - second: AgentRunResponse = await agent.run( + second: AgentResponse = await agent.run( ChatMessage( role=Role.USER, text="One more please.", @@ -239,18 +245,20 @@ async def run_with_custom_cache_provider() -> None: chat_client=chat_client, instructions=JOKER_INSTRUCTIONS, name=JOKER_NAME, - middleware=purview_agent_middleware, + middleware=[purview_agent_middleware], ) print("-- Custom Cache Provider Path --") print("Using SimpleDictCacheProvider") - first: AgentRunResponse = await agent.run( - ChatMessage(role=Role.USER, text="Tell me a joke about a programmer.", additional_properties={"user_id": user_id}) + first: AgentResponse = await agent.run( + ChatMessage( + role=Role.USER, text="Tell me a joke about a programmer.", additional_properties={"user_id": user_id} + ) ) print("First response (custom provider):\n", first) - second: AgentRunResponse = await agent.run( + second: AgentResponse = await agent.run( ChatMessage(role=Role.USER, text="That's hilarious! One more?", additional_properties={"user_id": user_id}) ) print("Second response (custom provider):\n", second) @@ -279,18 +287,18 @@ async def run_with_custom_cache_provider() -> None: chat_client=chat_client, instructions=JOKER_INSTRUCTIONS, name=JOKER_NAME, - middleware=purview_agent_middleware, + middleware=[purview_agent_middleware], ) print("-- Default Cache Path --") print("Using default InMemoryCacheProvider with settings-based configuration") - first: AgentRunResponse = await agent.run( + first: AgentResponse = await agent.run( ChatMessage(role=Role.USER, text="Tell me a joke about AI.", additional_properties={"user_id": user_id}) ) print("First response (default cache):\n", first) - second: AgentRunResponse = await agent.run( + second: AgentResponse = await agent.run( ChatMessage(role=Role.USER, text="Nice! Another AI joke please.", additional_properties={"user_id": user_id}) ) print("Second response (default cache):\n", second) diff --git a/python/samples/getting_started/tools/ai_function_declaration_only.py b/python/samples/getting_started/tools/ai_function_declaration_only.py index 03a2e8f8ed..32ba7cdbc8 100644 --- a/python/samples/getting_started/tools/ai_function_declaration_only.py +++ b/python/samples/getting_started/tools/ai_function_declaration_only.py @@ -34,7 +34,7 @@ async def main(): Expected result: User: What is the current time? 
Result: { - "type": "agent_run_response", + "type": "agent_response", "messages": [ { "type": "chat_message", diff --git a/python/samples/getting_started/tools/ai_function_with_approval.py b/python/samples/getting_started/tools/ai_function_with_approval.py index bdc673bb2c..a74e1aed3f 100644 --- a/python/samples/getting_started/tools/ai_function_with_approval.py +++ b/python/samples/getting_started/tools/ai_function_with_approval.py @@ -4,7 +4,7 @@ from random import randrange from typing import TYPE_CHECKING, Annotated, Any -from agent_framework import AgentRunResponse, ChatAgent, ChatMessage, ai_function +from agent_framework import AgentResponse, ChatAgent, ChatMessage, ai_function from agent_framework.openai import OpenAIResponsesClient if TYPE_CHECKING: @@ -39,7 +39,7 @@ def get_weather_detail(location: Annotated[str, "The city and state, e.g. San Fr ) -async def handle_approvals(query: str, agent: "AgentProtocol") -> AgentRunResponse: +async def handle_approvals(query: str, agent: "AgentProtocol") -> AgentResponse: """Handle function call approvals. When we don't have a thread, we need to ensure we include the original query, diff --git a/python/samples/getting_started/workflows/README.md b/python/samples/getting_started/workflows/README.md index 1adaab2c7c..8ca5e0f4bc 100644 --- a/python/samples/getting_started/workflows/README.md +++ b/python/samples/getting_started/workflows/README.md @@ -115,18 +115,14 @@ For additional observability samples in Agent Framework, see the [observability | Concurrent Orchestration (Custom Aggregator) | [orchestration/concurrent_custom_aggregator.py](./orchestration/concurrent_custom_aggregator.py) | Override aggregator via callback; summarize results with an LLM | | Concurrent Orchestration (Custom Agent Executors) | [orchestration/concurrent_custom_agent_executors.py](./orchestration/concurrent_custom_agent_executors.py) | Child executors own ChatAgents; concurrent fan-out/fan-in via ConcurrentBuilder | | Concurrent Orchestration (Participant Factory) | [orchestration/concurrent_participant_factory.py](./orchestration/concurrent_participant_factory.py) | Use participant factories for state isolation between workflow instances | -| Group Chat with Agent Manager | [orchestration/group_chat_agent_manager.py](./orchestration/group_chat_agent_manager.py) | Agent-based manager using `set_manager()` to select next speaker | +| Group Chat with Agent Manager | [orchestration/group_chat_agent_manager.py](./orchestration/group_chat_agent_manager.py) | Agent-based manager using `with_agent_orchestrator()` to select next speaker | | Group Chat Philosophical Debate | [orchestration/group_chat_philosophical_debate.py](./orchestration/group_chat_philosophical_debate.py) | Agent manager moderates long-form, multi-round debate across diverse participants | | Group Chat with Simple Function Selector | [orchestration/group_chat_simple_selector.py](./orchestration/group_chat_simple_selector.py) | Group chat with a simple function selector for next speaker | | Handoff (Simple) | [orchestration/handoff_simple.py](./orchestration/handoff_simple.py) | Single-tier routing: triage agent routes to specialists, control returns to user after each specialist response | -| Handoff (Specialist-to-Specialist) | [orchestration/handoff_specialist_to_specialist.py](./orchestration/handoff_specialist_to_specialist.py) | Multi-tier routing: specialists can hand off to other specialists using `.add_handoff()` fluent API | -| Handoff (Return-to-Previous) | 
[orchestration/handoff_return_to_previous.py](./orchestration/handoff_return_to_previous.py) | Return-to-previous routing: after user input, routes back to the previous specialist instead of coordinator using `.enable_return_to_previous()` | -| Handoff (Autonomous) | [orchestration/handoff_autonomous.py](./orchestration/handoff_autonomous.py) | Autonomous mode: specialists iterate independently until invoking a handoff tool using `.with_interaction_mode("autonomous", autonomous_turn_limit=N)` | +| Handoff (Autonomous) | [orchestration/handoff_autonomous.py](./orchestration/handoff_autonomous.py) | Autonomous mode: specialists iterate independently until invoking a handoff tool using `.with_autonomous_mode()` | | Handoff (Participant Factory) | [orchestration/handoff_participant_factory.py](./orchestration/handoff_participant_factory.py) | Use participant factories for state isolation between workflow instances | | Magentic Workflow (Multi-Agent) | [orchestration/magentic.py](./orchestration/magentic.py) | Orchestrate multiple agents with Magentic manager and streaming | -| Magentic + Human Plan Review | [orchestration/magentic_human_plan_update.py](./orchestration/magentic_human_plan_update.py) | Human reviews/updates the plan before execution | -| Magentic + Human Stall Intervention | [orchestration/magentic_human_replan.py](./orchestration/magentic_human_replan.py) | Human intervenes when workflow stalls with `with_human_input_on_stall()` | -| Magentic + Agent Clarification | [orchestration/magentic_agent_clarification.py](./orchestration/magentic_agent_clarification.py) | Agents ask clarifying questions via `ask_user` tool with `@ai_function(approval_mode="always_require")` | +| Magentic + Human Plan Review | [orchestration/magentic_human_plan_review.py](./orchestration/magentic_human_plan_review.py) | Human reviews/updates the plan before execution | | Magentic + Checkpoint Resume | [orchestration/magentic_checkpoint.py](./orchestration/magentic_checkpoint.py) | Resume Magentic orchestration from saved checkpoints | | Sequential Orchestration (Agents) | [orchestration/sequential_agents.py](./orchestration/sequential_agents.py) | Chain agents sequentially with shared conversation context | | Sequential Orchestration (Custom Executor) | [orchestration/sequential_custom_executors.py](./orchestration/sequential_custom_executors.py) | Mix agents with a summarizer that appends a compact summary | @@ -161,6 +157,21 @@ to configure which agents can route to which others with a fluent, type-safe API |---|---|---| | Concurrent with Visualization | [visualization/concurrent_with_visualization.py](./visualization/concurrent_with_visualization.py) | Fan-out/fan-in workflow with diagram export | +### declarative + +YAML-based declarative workflows allow you to define multi-agent orchestration patterns without writing Python code. See the [declarative workflows README](./declarative/README.md) for more details on YAML workflow syntax and available actions. 
+ +| Sample | File | Concepts | +|---|---|---| +| Conditional Workflow | [declarative/conditional_workflow/](./declarative/conditional_workflow/) | Nested conditional branching based on user input | +| Customer Support | [declarative/customer_support/](./declarative/customer_support/) | Multi-agent customer support with routing | +| Deep Research | [declarative/deep_research/](./declarative/deep_research/) | Research workflow with planning, searching, and synthesis | +| Function Tools | [declarative/function_tools/](./declarative/function_tools/) | Invoking Python functions from declarative workflows | +| Human-in-Loop | [declarative/human_in_loop/](./declarative/human_in_loop/) | Interactive workflows that request user input | +| Marketing | [declarative/marketing/](./declarative/marketing/) | Marketing content generation workflow | +| Simple Workflow | [declarative/simple_workflow/](./declarative/simple_workflow/) | Basic workflow with variable setting, conditionals, and loops | +| Student Teacher | [declarative/student_teacher/](./declarative/student_teacher/) | Student-teacher interaction pattern | + ### resources - Sample text inputs used by certain workflows: diff --git a/python/samples/getting_started/workflows/_start-here/step4_using_factories.py b/python/samples/getting_started/workflows/_start-here/step4_using_factories.py index 1c9dd5b1e6..77ece0128c 100644 --- a/python/samples/getting_started/workflows/_start-here/step4_using_factories.py +++ b/python/samples/getting_started/workflows/_start-here/step4_using_factories.py @@ -3,7 +3,7 @@ import asyncio from agent_framework import ( - AgentRunResponse, + AgentResponse, ChatAgent, Executor, WorkflowBuilder, @@ -83,9 +83,9 @@ async def main(): .build() ) - output: AgentRunResponse | None = None + output: AgentResponse | None = None async for event in workflow.run_stream("hello world"): - if isinstance(event, WorkflowOutputEvent) and isinstance(event.data, AgentRunResponse): + if isinstance(event, WorkflowOutputEvent) and isinstance(event.data, AgentResponse): output = event.data if output: diff --git a/python/samples/getting_started/workflows/agents/azure_chat_agents_function_bridge.py b/python/samples/getting_started/workflows/agents/azure_chat_agents_function_bridge.py index 90b1919b08..587938f2ca 100644 --- a/python/samples/getting_started/workflows/agents/azure_chat_agents_function_bridge.py +++ b/python/samples/getting_started/workflows/agents/azure_chat_agents_function_bridge.py @@ -6,7 +6,7 @@ from agent_framework import ( AgentExecutorRequest, AgentExecutorResponse, - AgentRunResponse, + AgentResponse, AgentRunUpdateEvent, ChatMessage, Role, @@ -70,7 +70,7 @@ async def enrich_with_references( ctx: WorkflowContext[AgentExecutorRequest], ) -> None: """Inject a follow-up user instruction that adds an external note for the next agent.""" - conversation = list(draft.full_conversation or draft.agent_run_response.messages) + conversation = list(draft.full_conversation or draft.agent_response.messages) original_prompt = next((message.text for message in conversation if message.role == Role.USER), "") external_note = _lookup_external_note(original_prompt) or ( "No additional references were found. Please refine the previous assistant response for clarity." 
@@ -134,7 +134,7 @@ async def main() -> None:
         elif isinstance(event, WorkflowOutputEvent):
             print("\n\n===== Final Output =====")
             response = event.data
-            if isinstance(response, AgentRunResponse):
+            if isinstance(response, AgentResponse):
                 print(response.text or "(empty response)")
             else:
                 print(response if response is not None else "No response generated.")
diff --git a/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py b/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py
index b0fbb7eea1..53abdcb604 100644
--- a/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py
+++ b/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py
@@ -8,7 +8,7 @@ from agent_framework import (
     AgentExecutorRequest,
     AgentExecutorResponse,
-    AgentRunResponse,
+    AgentResponse,
     AgentRunUpdateEvent,
     ChatAgent,
     ChatMessage,
@@ -17,7 +17,6 @@
     FunctionResultContent,
     RequestInfoEvent,
     Role,
-    ToolMode,
     WorkflowBuilder,
     WorkflowContext,
     WorkflowOutputEvent,
@@ -103,12 +102,12 @@ def __init__(self, id: str, writer_id: str, final_editor_id: str) -> None:
     async def on_writer_response(
         self,
         draft: AgentExecutorResponse,
-        ctx: WorkflowContext[Never, AgentRunResponse],
+        ctx: WorkflowContext[Never, AgentResponse],
     ) -> None:
         """Handle responses from the other two agents in the workflow."""
         if draft.executor_id == self.final_editor_id:
             # Final editor response; yield output directly.
-            await ctx.yield_output(draft.agent_run_response)
+            await ctx.yield_output(draft.agent_response)
             return

         # Writer agent response; request human feedback.
@@ -118,8 +117,8 @@ async def on_writer_response(
         if draft.full_conversation is not None:
             conversation = list(draft.full_conversation)
         else:
-            conversation = list(draft.agent_run_response.messages)
-        draft_text = draft.agent_run_response.text.strip()
+            conversation = list(draft.agent_response.messages)
+        draft_text = draft.agent_response.text.strip()
         if not draft_text:
             draft_text = "No draft text was produced."

@@ -177,7 +176,7 @@ def create_writer_agent() -> ChatAgent:
             "produce a 3-sentence draft."
         ),
         tools=[fetch_product_brief, get_brand_voice_profile],
-        tool_choice=ToolMode.REQUIRED_ANY,
+        tool_choice="required",
     )

diff --git a/python/samples/getting_started/workflows/agents/group_chat_workflow_as_agent.py b/python/samples/getting_started/workflows/agents/group_chat_workflow_as_agent.py
index c94b3004d8..2a1ab234f9 100644
--- a/python/samples/getting_started/workflows/agents/group_chat_workflow_as_agent.py
+++ b/python/samples/getting_started/workflows/agents/group_chat_workflow_as_agent.py
@@ -1,20 +1,16 @@
 # Copyright (c) Microsoft. All rights reserved.

 import asyncio
-import logging

 from agent_framework import ChatAgent, GroupChatBuilder
 from agent_framework.openai import OpenAIChatClient, OpenAIResponsesClient

-logging.basicConfig(level=logging.INFO)
-
 """
-Sample: Group Chat Orchestration (manager-directed)
+Sample: Group Chat Orchestration

 What it does:
-- Demonstrates the generic GroupChatBuilder with a language-model manager directing two agents.
-- The manager coordinates a researcher (chat completions) and a writer (responses API) to solve a task.
-- Uses the default group chat orchestration pipeline shared with Magentic.
+- Demonstrates the generic GroupChatBuilder with an agent orchestrator directing two agents.
+- The orchestrator coordinates a researcher (chat completions) and a writer (responses API) to solve a task.
Prerequisites: - OpenAI environment variables configured for `OpenAIChatClient` and `OpenAIResponsesClient`. @@ -38,8 +34,13 @@ async def main() -> None: workflow = ( GroupChatBuilder() - .set_manager(manager=OpenAIChatClient().create_agent(), display_name="Coordinator") - .participants(researcher=researcher, writer=writer) + .with_agent_orchestrator( + OpenAIChatClient().create_agent( + name="Orchestrator", + instructions="You coordinate a team conversation to solve the user's task.", + ) + ) + .participants([researcher, writer]) .build() ) diff --git a/python/samples/getting_started/workflows/agents/handoff_workflow_as_agent.py b/python/samples/getting_started/workflows/agents/handoff_workflow_as_agent.py index 0dd1d9e644..d9da8eb4ec 100644 --- a/python/samples/getting_started/workflows/agents/handoff_workflow_as_agent.py +++ b/python/samples/getting_started/workflows/agents/handoff_workflow_as_agent.py @@ -1,230 +1,224 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from collections.abc import Mapping -from typing import Any +from typing import Annotated from agent_framework import ( + AgentResponse, ChatAgent, ChatMessage, FunctionCallContent, FunctionResultContent, + HandoffAgentUserRequest, HandoffBuilder, - HandoffUserInputRequest, Role, WorkflowAgent, + ai_function, ) from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential -""" -Sample: Handoff Workflow as Agent with Human-in-the-Loop +"""Sample: Handoff Workflow as Agent with Human-in-the-Loop. -Purpose: -This sample demonstrates how to use a HandoffBuilder workflow as an agent via -`.as_agent()`, enabling human-in-the-loop interactions through the standard -agent interface. The handoff pattern routes user requests through a triage agent -to specialist agents, with the workflow requesting user input as needed. +This sample demonstrates how to use a handoff workflow as an agent, enabling +human-in-the-loop interactions through the agent interface. -When using a handoff workflow as an agent: -1. The workflow emits `HandoffUserInputRequest` when it needs user input -2. `WorkflowAgent` converts this to a `FunctionCallContent` named "request_info" -3. The caller extracts `HandoffUserInputRequest` from the function call arguments -4. The caller provides a response via `FunctionResultContent` +A handoff workflow defines a pattern that assembles agents in a mesh topology, allowing +them to transfer control to each other based on the conversation context. -This differs from running the workflow directly: -- Direct workflow: Use `workflow.run_stream()` and `workflow.send_responses_streaming()` -- As agent: Use `agent.run()` with `FunctionCallContent`/`FunctionResultContent` messages +Prerequisites: + - `az login` (Azure CLI authentication) + - Environment variables configured for AzureOpenAIChatClient (AZURE_OPENAI_ENDPOINT, etc.) Key Concepts: -- HandoffBuilder: Creates triage-to-specialist routing workflows -- WorkflowAgent: Wraps workflows to expose them as standard agents -- HandoffUserInputRequest: Contains conversation context and the awaiting agent -- FunctionCallContent/FunctionResultContent: Standard agent interface for HITL - -Prerequisites: -- `az login` (Azure CLI authentication) -- Environment variables configured for AzureOpenAIChatClient (AZURE_OPENAI_ENDPOINT, etc.) 
+ - Auto-registered handoff tools: HandoffBuilder automatically creates handoff tools + for each participant, allowing the coordinator to transfer control to specialists + - Termination condition: Controls when the workflow stops requesting user input + - Request/response cycle: Workflow requests input, user responds, cycle continues """ +@ai_function +def process_refund(order_number: Annotated[str, "Order number to process refund for"]) -> str: + """Simulated function to process a refund for a given order number.""" + return f"Refund processed successfully for order {order_number}." + + +@ai_function +def check_order_status(order_number: Annotated[str, "Order number to check status for"]) -> str: + """Simulated function to check the status of a given order number.""" + return f"Order {order_number} is currently being processed and will ship in 2 business days." + + +@ai_function +def process_return(order_number: Annotated[str, "Order number to process return for"]) -> str: + """Simulated function to process a return for a given order number.""" + return f"Return initiated successfully for order {order_number}. You will receive return instructions via email." + + def create_agents(chat_client: AzureOpenAIChatClient) -> tuple[ChatAgent, ChatAgent, ChatAgent, ChatAgent]: """Create and configure the triage and specialist agents. - The triage agent dispatches requests to the appropriate specialist. - Specialists handle their domain-specific queries. + Args: + chat_client: The AzureOpenAIChatClient to use for creating agents. Returns: - Tuple of (triage_agent, refund_agent, order_agent, support_agent) + Tuple of (triage_agent, refund_agent, order_agent, return_agent) """ - triage = chat_client.create_agent( + # Triage agent: Acts as the frontline dispatcher + triage_agent = chat_client.create_agent( instructions=( - "You are frontline support triage. Read the latest user message and decide whether " - "to hand off to refund_agent, order_agent, or support_agent. Provide a brief natural-language " - "response for the user. When delegation is required, call the matching handoff tool " - "(`handoff_to_refund_agent`, `handoff_to_order_agent`, or `handoff_to_support_agent`)." + "You are frontline support triage. Route customer issues to the appropriate specialist agents " + "based on the problem described." ), name="triage_agent", ) - refund = chat_client.create_agent( - instructions=( - "You handle refund workflows. Ask for any order identifiers you require and outline the refund steps." - ), + # Refund specialist: Handles refund requests + refund_agent = chat_client.create_agent( + instructions="You process refund requests.", name="refund_agent", + # In a real application, an agent can have multiple tools; here we keep it simple + tools=[process_refund], ) - order = chat_client.create_agent( - instructions=( - "You resolve shipping and fulfillment issues. Clarify the delivery problem and describe the actions " - "you will take to remedy it." - ), + # Order/shipping specialist: Resolves delivery issues + order_agent = chat_client.create_agent( + instructions="You handle order and shipping inquiries.", name="order_agent", + # In a real application, an agent can have multiple tools; here we keep it simple + tools=[check_order_status], ) - support = chat_client.create_agent( - instructions=( - "You are a general support agent. Offer empathetic troubleshooting and gather missing details if the " - "issue does not match other specialists." 
- ), - name="support_agent", + # Return specialist: Handles return requests + return_agent = chat_client.create_agent( + instructions="You manage product return requests.", + name="return_agent", + # In a real application, an agent can have multiple tools; here we keep it simple + tools=[process_return], ) - return triage, refund, order, support + return triage_agent, refund_agent, order_agent, return_agent -def extract_handoff_request( - response_messages: list[ChatMessage], -) -> tuple[FunctionCallContent, HandoffUserInputRequest]: - """Extract the HandoffUserInputRequest from agent response messages. +def handle_response_and_requests(response: AgentResponse) -> dict[str, HandoffAgentUserRequest]: + """Process agent response messages and extract any user requests. - When a handoff workflow running as an agent needs user input, it emits a - FunctionCallContent with name="request_info" containing the HandoffUserInputRequest. + This function inspects the agent response and: + - Displays agent messages to the console + - Collects HandoffAgentUserRequest instances for response handling Args: - response_messages: Messages from the agent response + response: The AgentResponse from the agent run call. Returns: - Tuple of (function_call, handoff_request) - - Raises: - ValueError: If no request_info function call is found or payload is invalid + A dictionary mapping request IDs to HandoffAgentUserRequest instances. """ - for message in response_messages: + pending_requests: dict[str, HandoffAgentUserRequest] = {} + for message in response.messages: + if message.text: + print(f"- {message.author_name or message.role.value}: {message.text}") for content in message.contents: - if isinstance(content, FunctionCallContent) and content.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME: - # Parse the function arguments to extract the HandoffUserInputRequest - args = content.arguments - if isinstance(args, str): - request_args = WorkflowAgent.RequestInfoFunctionArgs.from_json(args) - elif isinstance(args, Mapping): - request_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(dict(args)) + if isinstance(content, FunctionCallContent): + if isinstance(content.arguments, dict): + request = WorkflowAgent.RequestInfoFunctionArgs.from_dict(content.arguments) + elif isinstance(content.arguments, str): + request = WorkflowAgent.RequestInfoFunctionArgs.from_json(content.arguments) else: - raise ValueError("Unexpected argument type for request_info function call.") - - payload: Any = request_args.data - if not isinstance(payload, HandoffUserInputRequest): - raise ValueError( - f"Expected HandoffUserInputRequest in request_info payload, got {type(payload).__name__}" - ) - - return content, payload - - raise ValueError("No request_info function call found in response messages.") - - -def print_conversation(request: HandoffUserInputRequest) -> None: - """Display the conversation history from a HandoffUserInputRequest.""" - print("\n=== Conversation History ===") - for message in request.conversation: - speaker = message.author_name or message.role.value - print(f" [{speaker}]: {message.text}") - print(f" [Awaiting]: {request.awaiting_agent_id}") - print("============================") + raise ValueError("Invalid arguments type. Expecting a request info structure for this sample.") + if isinstance(request.data, HandoffAgentUserRequest): + pending_requests[request.request_id] = request.data + return pending_requests async def main() -> None: - """Main entry point demonstrating handoff workflow as agent. 
+ """Main entry point for the handoff workflow demo. - This demo: - 1. Builds a handoff workflow with triage and specialist agents - 2. Converts it to an agent using .as_agent() - 3. Runs a multi-turn conversation with scripted user responses - 4. Demonstrates the FunctionCallContent/FunctionResultContent pattern for HITL - """ - print("Starting Handoff Workflow as Agent Demo") - print("=" * 55) + This function demonstrates: + 1. Creating triage and specialist agents + 2. Building a handoff workflow with custom termination condition + 3. Running the workflow with scripted user responses + 4. Processing events and handling user input requests + The workflow uses scripted responses instead of interactive input to make + the demo reproducible and testable. In a production application, you would + replace the scripted_responses with actual user input collection. + """ # Initialize the Azure OpenAI chat client chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) - # Create agents + # Create all agents: triage + specialists triage, refund, order, support = create_agents(chat_client) - # Build the handoff workflow and convert to agent - # Termination condition: stop after 4 user messages + # Build the handoff workflow + # - participants: All agents that can participate in the workflow + # - with_start_agent: The triage agent is designated as the start agent, which means + # it receives all user input first and orchestrates handoffs to specialists + # - with_termination_condition: Custom logic to stop the request/response loop. + # Without this, the default behavior continues requesting user input until max_turns + # is reached. Here we use a custom condition that checks if the conversation has ended + # naturally (when one of the agents says something like "you're welcome"). agent = ( HandoffBuilder( name="customer_support_handoff", participants=[triage, refund, order, support], ) - .set_coordinator("triage_agent") - .with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role.value == "user") >= 4) + .with_start_agent(triage) + .with_termination_condition( + # Custom termination: Check if one of the agents has provided a closing message. + # This looks for the last message containing "welcome", which indicates the + # conversation has concluded naturally. + lambda conversation: len(conversation) > 0 and "welcome" in conversation[-1].text.lower() + ) .build() .as_agent() # Convert workflow to agent interface ) # Scripted user responses for reproducible demo + # In a console application, replace this with: + # user_input = input("Your response: ") + # or integrate with a UI/chat interface scripted_responses = [ - "My order 1234 arrived damaged and the packaging was destroyed.", - "Yes, I'd like a refund if that's possible.", - "Thanks for your help!", + "My order 1234 arrived damaged and the packaging was destroyed. 
I'd like to return it.", + "Please also process a refund for order 1234.", + "Thanks for resolving this.", ] - # Start the conversation - print("\n[User]: Hello, I need assistance with my recent purchase.") - response = await agent.run("Hello, I need assistance with my recent purchase.") - - # Process conversation turns until workflow completes or responses exhausted - while True: - # Check if the agent is requesting user input - try: - function_call, handoff_request = extract_handoff_request(response.messages) - except ValueError: - # No request_info call found - workflow has completed - print("\n[Workflow completed - no pending requests]") - if response.messages: - final_text = response.messages[-1].text - if final_text: - print(f"[Final response]: {final_text}") - break - - # Display the conversation context - print_conversation(handoff_request) - - # Get the next scripted response - if not scripted_responses: - print("\n[No more scripted responses - ending conversation]") - break - - user_input = scripted_responses.pop(0) + # Start the workflow with the initial user message + print("[Starting workflow with initial user message...]\n") + initial_message = "Hello, I need assistance with my recent purchase." + print(f"- User: {initial_message}") + response = await agent.run(initial_message) + pending_requests = handle_response_and_requests(response) + + # Process the request/response cycle + # The workflow will continue requesting input until: + # 1. The termination condition is met, OR + # 2. We run out of scripted responses + while pending_requests: + for request in pending_requests.values(): + for message in request.agent_response.messages: + if message.text: + print(f"- {message.author_name or message.role.value}: {message.text}") - print(f"\n[User responding]: {user_input}") - - # Create the function result to send back to the agent - # The result is the user's text response which gets converted to ChatMessage - function_result = FunctionResultContent( - call_id=function_call.call_id, - result=user_input, - ) + if not scripted_responses: + # No more scripted responses; terminate the workflow + responses = {req_id: HandoffAgentUserRequest.terminate() for req_id in pending_requests} + else: + # Get the next scripted response + user_response = scripted_responses.pop(0) + print(f"\n- User: {user_response}") - # Send the response back to the agent - response = await agent.run(ChatMessage(role=Role.TOOL, contents=[function_result])) + # Send response(s) to all pending requests + # In this demo, there's typically one request per cycle, but the API supports multiple + responses = {req_id: HandoffAgentUserRequest.create_response(user_response) for req_id in pending_requests} - print("\n" + "=" * 55) - print("Demo completed!") + function_results = [ + FunctionResultContent(call_id=req_id, result=response) for req_id, response in responses.items() + ] + response = await agent.run(ChatMessage(role=Role.TOOL, contents=function_results)) + pending_requests = handle_response_and_requests(response) if __name__ == "__main__": - print("Initializing Handoff Workflow as Agent Sample...") asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/agents/magentic_workflow_as_agent.py b/python/samples/getting_started/workflows/agents/magentic_workflow_as_agent.py index f6dd8ca83d..f4e5b38e86 100644 --- a/python/samples/getting_started/workflows/agents/magentic_workflow_as_agent.py +++ b/python/samples/getting_started/workflows/agents/magentic_workflow_as_agent.py @@ -1,20 +1,14 @@ # 
Copyright (c) Microsoft. All rights reserved. import asyncio -import logging from agent_framework import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, ChatAgent, HostedCodeInterpreterTool, MagenticBuilder, ) from agent_framework.openai import OpenAIChatClient, OpenAIResponsesClient -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - """ Sample: Build a Magentic orchestration and wrap it as an agent. @@ -60,7 +54,7 @@ async def main() -> None: workflow = ( MagenticBuilder() - .participants(researcher=researcher_agent, coder=coder_agent) + .participants([researcher_agent, coder_agent]) .with_standard_manager( agent=manager_agent, max_round_count=10, @@ -87,20 +81,8 @@ async def main() -> None: print("\nWrapping workflow as an agent and running...") workflow_agent = workflow.as_agent(name="MagenticWorkflowAgent") async for response in workflow_agent.run_stream(task): - # AgentRunResponseUpdate objects contain the streaming agent data - # Check metadata to understand event type - props = response.additional_properties - event_type = props.get("magentic_event_type") if props else None - - if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR: - kind = props.get("orchestrator_message_kind", "") if props else "" - print(f"\n[ORCHESTRATOR:{kind}] {response.text}") - elif event_type == MAGENTIC_EVENT_TYPE_AGENT_DELTA: - if response.text: - print(response.text, end="", flush=True) - elif response.text: - # Fallback for any other events with text - print(response.text, end="", flush=True) + # Fallback for any other events with text + print(response.text, end="", flush=True) except Exception as e: print(f"Workflow execution failed: {e}") diff --git a/python/samples/getting_started/workflows/agents/mixed_agents_and_executors.py b/python/samples/getting_started/workflows/agents/mixed_agents_and_executors.py index 70b7110fb0..28064ab1e1 100644 --- a/python/samples/getting_started/workflows/agents/mixed_agents_and_executors.py +++ b/python/samples/getting_started/workflows/agents/mixed_agents_and_executors.py @@ -66,8 +66,8 @@ async def handle(self, message: AgentExecutorResponse, ctx: WorkflowContext[Neve ctx: Workflow context for yielding the final output string """ target_text = "1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89" - correctness = target_text in message.agent_run_response.text - consumption = message.agent_run_response.usage_details + correctness = target_text in message.agent_response.text + consumption = message.agent_response.usage_details await ctx.yield_output(f"Correctness: {correctness}, Consumption: {consumption}") diff --git a/python/samples/getting_started/workflows/agents/workflow_as_agent_reflection_pattern.py b/python/samples/getting_started/workflows/agents/workflow_as_agent_reflection_pattern.py index 85003239db..bafe55dd4e 100644 --- a/python/samples/getting_started/workflows/agents/workflow_as_agent_reflection_pattern.py +++ b/python/samples/getting_started/workflows/agents/workflow_as_agent_reflection_pattern.py @@ -5,7 +5,7 @@ from uuid import uuid4 from agent_framework import ( - AgentRunResponseUpdate, + AgentResponseUpdate, AgentRunUpdateEvent, ChatClientProtocol, ChatMessage, @@ -100,7 +100,7 @@ class _Response(BaseModel): messages.append(ChatMessage(role=Role.USER, text="Please review the agent's responses.")) print("Reviewer: Sending review request to LLM...") - response = await self._chat_client.get_response(messages=messages, response_format=_Response) + response = await self._chat_client.get_response(messages=messages, 
options={"response_format": _Response}) parsed = _Response.model_validate_json(response.messages[-1].text) @@ -161,7 +161,7 @@ async def handle_review_response(self, review: ReviewResponse, ctx: WorkflowCont # Emit approved result to external consumer via AgentRunUpdateEvent. await ctx.add_event( - AgentRunUpdateEvent(self.id, data=AgentRunResponseUpdate(contents=contents, role=Role.ASSISTANT)) + AgentRunUpdateEvent(self.id, data=AgentResponseUpdate(contents=contents, role=Role.ASSISTANT)) ) return diff --git a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py index 7b8d08a1af..d4f1d58133 100644 --- a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py +++ b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py @@ -127,7 +127,7 @@ async def on_agent_response(self, response: AgentExecutorResponse, ctx: Workflow await ctx.request_info( request_data=HumanApprovalRequest( prompt="Review the draft. Reply 'approve' or provide edit instructions.", - draft=response.agent_run_response.text, + draft=response.agent_response.text, iteration=self._iteration, ), response_type=str, diff --git a/python/samples/getting_started/workflows/composition/sub_workflow_basics.py b/python/samples/getting_started/workflows/composition/sub_workflow_basics.py index 826425a0ae..9189e70d29 100644 --- a/python/samples/getting_started/workflows/composition/sub_workflow_basics.py +++ b/python/samples/getting_started/workflows/composition/sub_workflow_basics.py @@ -12,7 +12,7 @@ handler, ) from typing_extensions import Never - + """ Sample: Sub-Workflows (Basics) diff --git a/python/samples/getting_started/workflows/control-flow/edge_condition.py b/python/samples/getting_started/workflows/control-flow/edge_condition.py index 0dff43a58a..061fcf0a1d 100644 --- a/python/samples/getting_started/workflows/control-flow/edge_condition.py +++ b/python/samples/getting_started/workflows/control-flow/edge_condition.py @@ -85,7 +85,7 @@ def condition(message: Any) -> bool: try: # Prefer parsing a structured DetectionResult from the agent JSON text. # Using model_validate_json ensures type safety and raises if the shape is wrong. - detection = DetectionResult.model_validate_json(message.agent_run_response.text) + detection = DetectionResult.model_validate_json(message.agent_response.text) # Route only when the spam flag matches the expected path. return detection.is_spam == expected_result except Exception: @@ -99,14 +99,14 @@ def condition(message: Any) -> bool: @executor(id="send_email") async def handle_email_response(response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None: # Downstream of the email assistant. Parse a validated EmailResponse and yield the workflow output. - email_response = EmailResponse.model_validate_json(response.agent_run_response.text) + email_response = EmailResponse.model_validate_json(response.agent_response.text) await ctx.yield_output(f"Email sent:\n{email_response.response}") @executor(id="handle_spam") async def handle_spam_classifier_response(response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None: # Spam path. Confirm the DetectionResult and yield the workflow output. Guard against accidental non spam input. 
- detection = DetectionResult.model_validate_json(response.agent_run_response.text) + detection = DetectionResult.model_validate_json(response.agent_response.text) if detection.is_spam: await ctx.yield_output(f"Email marked as spam: {detection.reason}") else: @@ -123,7 +123,7 @@ async def to_email_assistant_request( Extracts DetectionResult.email_content and forwards it as a user message. """ # Bridge executor. Converts a structured DetectionResult into a ChatMessage and forwards it as a new request. - detection = DetectionResult.model_validate_json(response.agent_run_response.text) + detection = DetectionResult.model_validate_json(response.agent_response.text) user_msg = ChatMessage(Role.USER, text=detection.email_content) await ctx.send_message(AgentExecutorRequest(messages=[user_msg], should_respond=True)) @@ -138,7 +138,7 @@ def create_spam_detector_agent() -> ChatAgent: "Include the original email content in email_content." ), name="spam_detection_agent", - response_format=DetectionResult, + default_options={"response_format": DetectionResult}, ) @@ -152,7 +152,7 @@ def create_email_assistant_agent() -> ChatAgent: "Return JSON with a single field 'response' containing the drafted reply." ), name="email_assistant_agent", - response_format=EmailResponse, + default_options={"response_format": EmailResponse}, ) diff --git a/python/samples/getting_started/workflows/control-flow/multi_selection_edge_group.py b/python/samples/getting_started/workflows/control-flow/multi_selection_edge_group.py index b7935a5e75..9b33f7d979 100644 --- a/python/samples/getting_started/workflows/control-flow/multi_selection_edge_group.py +++ b/python/samples/getting_started/workflows/control-flow/multi_selection_edge_group.py @@ -98,7 +98,7 @@ async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest @executor(id="to_analysis_result") async def to_analysis_result(response: AgentExecutorResponse, ctx: WorkflowContext[AnalysisResult]) -> None: - parsed = AnalysisResultAgent.model_validate_json(response.agent_run_response.text) + parsed = AnalysisResultAgent.model_validate_json(response.agent_response.text) email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY) email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{email_id}") await ctx.send_message( @@ -125,7 +125,7 @@ async def submit_to_email_assistant(analysis: AnalysisResult, ctx: WorkflowConte @executor(id="finalize_and_send") async def finalize_and_send(response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None: - parsed = EmailResponse.model_validate_json(response.agent_run_response.text) + parsed = EmailResponse.model_validate_json(response.agent_response.text) await ctx.yield_output(f"Email sent: {parsed.response}") @@ -140,7 +140,7 @@ async def summarize_email(analysis: AnalysisResult, ctx: WorkflowContext[AgentEx @executor(id="merge_summary") async def merge_summary(response: AgentExecutorResponse, ctx: WorkflowContext[AnalysisResult]) -> None: - summary = EmailSummaryModel.model_validate_json(response.agent_run_response.text) + summary = EmailSummaryModel.model_validate_json(response.agent_response.text) email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY) email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{email_id}") # Build an AnalysisResult mirroring to_analysis_result but with summary @@ -190,7 +190,7 @@ def create_email_analysis_agent() -> ChatAgent: "and 'reason' (string)." 
), name="email_analysis_agent", - response_format=AnalysisResultAgent, + default_options={"response_format": AnalysisResultAgent}, ) @@ -199,7 +199,7 @@ def create_email_assistant_agent() -> ChatAgent: return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( instructions=("You are an email assistant that helps users draft responses to emails with professionalism."), name="email_assistant_agent", - response_format=EmailResponse, + default_options={"response_format": EmailResponse}, ) @@ -208,7 +208,7 @@ def create_email_summary_agent() -> ChatAgent: return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent( instructions=("You are an assistant that helps users summarize emails."), name="email_summary_agent", - response_format=EmailSummaryModel, + default_options={"response_format": EmailSummaryModel}, ) @@ -243,7 +243,8 @@ def select_targets(analysis: AnalysisResult, target_ids: list[str]) -> list[str] ) workflow = ( - workflow_builder.set_start_executor("store_email") + workflow_builder + .set_start_executor("store_email") .add_edge("store_email", "email_analysis_agent") .add_edge("email_analysis_agent", "to_analysis_result") .add_multi_selection_edge_group( diff --git a/python/samples/getting_started/workflows/control-flow/sequential_streaming.py b/python/samples/getting_started/workflows/control-flow/sequential_streaming.py index 3030d4ff44..ce7bc92758 100644 --- a/python/samples/getting_started/workflows/control-flow/sequential_streaming.py +++ b/python/samples/getting_started/workflows/control-flow/sequential_streaming.py @@ -77,7 +77,7 @@ async def main(): Event: ExecutorCompletedEvent(executor_id=upper_case_executor) Event: ExecutorInvokedEvent(executor_id=reverse_text_executor) Event: ExecutorCompletedEvent(executor_id=reverse_text_executor) - Event: WorkflowOutputEvent(data='DLROW OLLEH', source_executor_id=reverse_text_executor) + Event: WorkflowOutputEvent(data='DLROW OLLEH', executor_id=reverse_text_executor) Workflow completed with result: DLROW OLLEH """ diff --git a/python/samples/getting_started/workflows/control-flow/simple_loop.py b/python/samples/getting_started/workflows/control-flow/simple_loop.py index 7bb3389a08..02e672b012 100644 --- a/python/samples/getting_started/workflows/control-flow/simple_loop.py +++ b/python/samples/getting_started/workflows/control-flow/simple_loop.py @@ -106,7 +106,7 @@ class ParseJudgeResponse(Executor): @handler async def parse(self, response: AgentExecutorResponse, ctx: WorkflowContext[NumberSignal]) -> None: - text = response.agent_run_response.text.strip().upper() + text = response.agent_response.text.strip().upper() if "MATCHED" in text: await ctx.send_message(NumberSignal.MATCHED) elif "ABOVE" in text and "BELOW" not in text: diff --git a/python/samples/getting_started/workflows/control-flow/switch_case_edge_group.py b/python/samples/getting_started/workflows/control-flow/switch_case_edge_group.py index c325d74d7f..1e0b92257d 100644 --- a/python/samples/getting_started/workflows/control-flow/switch_case_edge_group.py +++ b/python/samples/getting_started/workflows/control-flow/switch_case_edge_group.py @@ -106,7 +106,7 @@ async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest @executor(id="to_detection_result") async def to_detection_result(response: AgentExecutorResponse, ctx: WorkflowContext[DetectionResult]) -> None: # Parse the detector JSON into a typed model. Attach the current email id for downstream lookups. 
-    parsed = DetectionResultAgent.model_validate_json(response.agent_run_response.text)
+    parsed = DetectionResultAgent.model_validate_json(response.agent_response.text)
     email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY)
     await ctx.send_message(DetectionResult(spam_decision=parsed.spam_decision, reason=parsed.reason, email_id=email_id))
@@ -127,7 +127,7 @@ async def submit_to_email_assistant(detection: DetectionResult, ctx: WorkflowCon
 @executor(id="finalize_and_send")
 async def finalize_and_send(response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None:
     # Terminal step for the drafting branch. Yield the email response as output.
-    parsed = EmailResponse.model_validate_json(response.agent_run_response.text)
+    parsed = EmailResponse.model_validate_json(response.agent_response.text)
     await ctx.yield_output(f"Email sent: {parsed.response}")
@@ -162,7 +162,7 @@ def create_spam_detection_agent() -> ChatAgent:
             "and 'reason' (string)."
         ),
         name="spam_detection_agent",
-        response_format=DetectionResultAgent,
+        default_options={"response_format": DetectionResultAgent},
     )
@@ -171,7 +171,7 @@ def create_email_assistant_agent() -> ChatAgent:
     return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent(
         instructions=("You are an email assistant that helps users draft responses to emails with professionalism."),
         name="email_assistant_agent",
-        response_format=EmailResponse,
+        default_options={"response_format": EmailResponse},
     )
diff --git a/python/samples/getting_started/workflows/declarative/README.md b/python/samples/getting_started/workflows/declarative/README.md
new file mode 100644
index 0000000000..b2ce6de198
--- /dev/null
+++ b/python/samples/getting_started/workflows/declarative/README.md
@@ -0,0 +1,74 @@
+# Declarative Workflows
+
+Declarative workflows allow you to define multi-agent orchestration patterns in YAML, including:
+- Variable manipulation and state management
+- Control flow (loops, conditionals, branching)
+- Agent invocations
+- Human-in-the-loop patterns
+
+See the [main workflows README](../README.md#declarative) for the list of available samples.
+
+## Prerequisites
+
+```bash
+pip install agent-framework-declarative
+```
+
+## Running Samples
+
+Each sample directory contains:
+- `workflow.yaml` - The declarative workflow definition
+- `main.py` - Python code to load and execute the workflow
+- `README.md` - Sample-specific documentation
+
+To run a sample:
+
+```bash
+cd <sample-directory>
+python main.py
+```
+
+## Workflow Structure
+
+A basic workflow YAML file looks like:
+
+```yaml
+name: my-workflow
+description: A simple workflow example
+
+actions:
+  - kind: SetValue
+    path: turn.greeting
+    value: Hello, World!
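+  # Values that begin with "=" (like =turn.greeting below) are expressions
+  # evaluated against workflow state; plain values like the one above are literals.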
+ + - kind: SendActivity + activity: + text: =turn.greeting +``` + +## Action Types + +### Variable Actions +- `SetValue` - Set a variable in state +- `SetVariable` - Set a variable (.NET style naming) +- `AppendValue` - Append to a list +- `ResetVariable` - Clear a variable + +### Control Flow +- `If` - Conditional branching +- `Switch` - Multi-way branching +- `Foreach` - Iterate over collections +- `RepeatUntil` - Loop until condition +- `GotoAction` - Jump to labeled action + +### Output +- `SendActivity` - Send text/attachments to user +- `EmitEvent` - Emit custom events + +### Agent Invocation +- `InvokeAzureAgent` - Call an Azure AI agent +- `InvokePromptAgent` - Call a local prompt agent + +### Human-in-Loop +- `Question` - Request user input +- `WaitForInput` - Pause for external input diff --git a/python/samples/getting_started/workflows/declarative/__init__.py b/python/samples/getting_started/workflows/declarative/__init__.py new file mode 100644 index 0000000000..aaab31fb07 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Declarative workflows samples package.""" diff --git a/python/samples/getting_started/workflows/declarative/conditional_workflow/README.md b/python/samples/getting_started/workflows/declarative/conditional_workflow/README.md new file mode 100644 index 0000000000..d311a4b0d3 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/conditional_workflow/README.md @@ -0,0 +1,23 @@ +# Conditional Workflow Sample + +This sample demonstrates control flow with conditions: +- If/else branching +- Switch statements +- Nested conditions + +## Files + +- `workflow.yaml` - The workflow definition +- `main.py` - Python code to execute the workflow + +## Running + +```bash +python main.py +``` + +## What It Does + +1. Takes a user's age as input +2. Uses conditions to determine an age category +3. Sends appropriate messages based on the category diff --git a/python/samples/getting_started/workflows/declarative/conditional_workflow/main.py b/python/samples/getting_started/workflows/declarative/conditional_workflow/main.py new file mode 100644 index 0000000000..78fe6c8cbf --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/conditional_workflow/main.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Run the conditional workflow sample. + +Usage: + python main.py + +Demonstrates conditional branching based on age input. 
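+
+Each test age in main() maps to one of the categories defined in
+workflow.yaml (child, teenager, adult, senior).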
+""" + +import asyncio +from pathlib import Path + +from agent_framework.declarative import WorkflowFactory + + +async def main() -> None: + """Run the conditional workflow with various age inputs.""" + # Create a workflow factory + factory = WorkflowFactory() + + # Load the workflow from YAML + workflow_path = Path(__file__).parent / "workflow.yaml" + workflow = factory.create_workflow_from_yaml_path(workflow_path) + + print(f"Loaded workflow: {workflow.name}") + print("-" * 40) + + # Print out the executors in this workflow + print("\nExecutors in workflow:") + for executor_id, executor in workflow.executors.items(): + print(f" - {executor_id}: {type(executor).__name__}") + print("-" * 40) + + # Test with different ages + test_ages = [8, 15, 35, 70] + + for age in test_ages: + print(f"\n--- Testing with age: {age} ---") + + # Run the workflow with age input + result = await workflow.run({"age": age}) + for output in result.get_outputs(): + print(f" Output: {output}") + + print("\n" + "-" * 40) + print("Workflow completed for all test cases!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/declarative/conditional_workflow/workflow.yaml b/python/samples/getting_started/workflows/declarative/conditional_workflow/workflow.yaml new file mode 100644 index 0000000000..60427e107a --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/conditional_workflow/workflow.yaml @@ -0,0 +1,69 @@ +name: conditional-workflow +description: Demonstrates conditional branching based on user input + +# Declare expected inputs with their types +inputs: + age: + type: integer + description: The user's age in years + +actions: + # Get the age from input + - kind: SetValue + id: get_age + displayName: Get user age + path: Local.age + value: =inputs.age + + # Determine age category using nested conditions + - kind: If + id: check_age + displayName: Check age category + condition: =Local.age < 13 + then: + - kind: SetValue + path: Local.category + value: child + - kind: SendActivity + activity: + text: "Welcome, young one! Here are some fun activities for kids." + else: + - kind: If + condition: =Local.age < 20 + then: + - kind: SetValue + path: Local.category + value: teenager + - kind: SendActivity + activity: + text: "Hey there! Check out these cool things for teens." + else: + - kind: If + condition: =Local.age < 65 + then: + - kind: SetValue + path: Local.category + value: adult + - kind: SendActivity + activity: + text: "Welcome! Here are our professional services." + else: + - kind: SetValue + path: Local.category + value: senior + - kind: SendActivity + activity: + text: "Welcome! Enjoy our senior member benefits." + + # Send a summary + - kind: SendActivity + id: summary + displayName: Send category summary + activity: + text: '=Concat("You have been categorized as: ", Local.category)' + + # Store result + - kind: SetValue + id: set_output + path: Workflow.Outputs.category + value: =Local.category diff --git a/python/samples/getting_started/workflows/declarative/customer_support/README.md b/python/samples/getting_started/workflows/declarative/customer_support/README.md new file mode 100644 index 0000000000..41cc683b3c --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/customer_support/README.md @@ -0,0 +1,37 @@ +# Customer Support Workflow Sample + +Multi-agent workflow demonstrating automated troubleshooting with escalation paths. 
+ +## Overview + +Coordinates six specialized agents to handle customer support requests: + +1. **SelfServiceAgent** - Initial troubleshooting with user +2. **TicketingAgent** - Creates tickets when escalation needed +3. **TicketRoutingAgent** - Routes to appropriate team +4. **WindowsSupportAgent** - Windows-specific troubleshooting +5. **TicketResolutionAgent** - Resolves tickets +6. **TicketEscalationAgent** - Escalates to human support + +## Files + +- `workflow.yaml` - Workflow definition with conditional routing +- `main.py` - Agent definitions and workflow execution +- `ticketing_plugin.py` - Mock ticketing system plugin + +## Running + +```bash +python main.py +``` + +## Example Input + +``` +My PC keeps rebooting and I can't use it. +``` + +## Requirements + +- Azure OpenAI endpoint configured +- `az login` for authentication diff --git a/python/packages/ag-ui/tests/__init__.py b/python/samples/getting_started/workflows/declarative/customer_support/__init__.py similarity index 100% rename from python/packages/ag-ui/tests/__init__.py rename to python/samples/getting_started/workflows/declarative/customer_support/__init__.py diff --git a/python/samples/getting_started/workflows/declarative/customer_support/main.py b/python/samples/getting_started/workflows/declarative/customer_support/main.py new file mode 100644 index 0000000000..9310d6b1f2 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/customer_support/main.py @@ -0,0 +1,341 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +CustomerSupport workflow sample. + +This workflow demonstrates using multiple agents to provide automated +troubleshooting steps to resolve common issues with escalation options. + +Example input: "My PC keeps rebooting and I can't use it." + +Usage: + python main.py + +The workflow: +1. SelfServiceAgent: Works with user to provide troubleshooting steps +2. TicketingAgent: Creates a ticket if issue needs escalation +3. TicketRoutingAgent: Determines which team should handle the ticket +4. WindowsSupportAgent: Provides Windows-specific troubleshooting +5. TicketResolutionAgent: Resolves the ticket when issue is fixed +6. TicketEscalationAgent: Escalates to human support if needed +""" + +import asyncio +import json +import logging +import uuid +from pathlib import Path + +from agent_framework import RequestInfoEvent, WorkflowOutputEvent +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework.declarative import ( + AgentExternalInputRequest, + AgentExternalInputResponse, + WorkflowFactory, +) +from azure.identity import AzureCliCredential +from pydantic import BaseModel, Field +from ticketing_plugin import TicketingPlugin + +logging.basicConfig(level=logging.ERROR) + +# ANSI color codes for output formatting +CYAN = "\033[36m" +GREEN = "\033[32m" +YELLOW = "\033[33m" +MAGENTA = "\033[35m" +RESET = "\033[0m" + +# Agent Instructions + +SELF_SERVICE_INSTRUCTIONS = """ +Use your knowledge to work with the user to provide the best possible troubleshooting steps. + +- If the user confirms that the issue is resolved, then the issue is resolved. +- If the user reports that the issue persists, then escalate. +""".strip() + +TICKETING_INSTRUCTIONS = """Always create a ticket in Azure DevOps using the available tools. + +Include the following information in the TicketSummary. 
+ +- Issue description: {{IssueDescription}} +- Attempted resolution steps: {{AttemptedResolutionSteps}} + +After creating the ticket, provide the user with the ticket ID.""" + +TICKET_ROUTING_INSTRUCTIONS = """Determine how to route the given issue to the appropriate support team. + +Choose from the available teams and their functions: +- Windows Activation Support: Windows license activation issues +- Windows Support: Windows related issues +- Azure Support: Azure related issues +- Network Support: Network related issues +- Hardware Support: Hardware related issues +- Microsoft Office Support: Microsoft Office related issues +- General Support: General issues not related to the above categories""" + +WINDOWS_SUPPORT_INSTRUCTIONS = """ +Use your knowledge to work with the user to provide the best possible troubleshooting steps +for issues related to Windows operating system. + +- Utilize the "Attempted Resolutions Steps" as a starting point for your troubleshooting. +- Never escalate without troubleshooting with the user. +- If the user confirms that the issue is resolved, then the issue is resolved. +- If the user reports that the issue persists, then escalate. + +Issue: {{IssueDescription}} +Attempted Resolution Steps: {{AttemptedResolutionSteps}}""" + +RESOLUTION_INSTRUCTIONS = """Resolve the following ticket in Azure DevOps. +Always include the resolution details. + +- Ticket ID: #{{TicketId}} +- Resolution Summary: {{ResolutionSummary}}""" + +ESCALATION_INSTRUCTIONS = """ +You escalate the provided issue to human support team by sending an email. + +Here are some additional details that might help: +- TicketId : {{TicketId}} +- IssueDescription : {{IssueDescription}} +- AttemptedResolutionSteps : {{AttemptedResolutionSteps}} + +Before escalating, gather the user's email address for follow-up. +If not known, ask the user for their email address so that the support team can reach them when needed. 
+
+When sending the email, include the following details:
+- To: support@contoso.com
+- Cc: user's email address
+- Subject of the email: "Support Ticket - {TicketId} - [Compact Issue Description]"
+- Body:
+  - Issue description
+  - Attempted resolution steps
+  - User's email address
+  - Any other relevant information from the conversation history
+
+Assure the user that their issue will be resolved and provide them with a ticket ID for reference."""


+# Pydantic models for structured outputs


+class SelfServiceResponse(BaseModel):
+    """Response from self-service agent evaluation."""
+
+    IsResolved: bool = Field(description="True if the user issue/ask has been resolved.")
+    NeedsTicket: bool = Field(description="True if the user issue/ask requires that a ticket be filed.")
+    IssueDescription: str = Field(description="A concise description of the issue.")
+    AttemptedResolutionSteps: str = Field(description="An outline of the steps taken to attempt resolution.")
+
+
+class TicketingResponse(BaseModel):
+    """Response from ticketing agent."""
+
+    TicketId: str = Field(description="The identifier of the ticket created in response to the user issue.")
+    TicketSummary: str = Field(description="The summary of the ticket created in response to the user issue.")
+
+
+class RoutingResponse(BaseModel):
+    """Response from routing agent."""
+
+    TeamName: str = Field(description="The name of the team to route the issue to")
+
+
+class SupportResponse(BaseModel):
+    """Response from support agent."""
+
+    IsResolved: bool = Field(description="True if the user issue/ask has been resolved.")
+    NeedsEscalation: bool = Field(
+        description="True if resolution could not be achieved and the issue/ask requires escalation."
+    )
+    ResolutionSummary: str = Field(description="The summary of the steps that led to resolution.")
+
+
+class EscalationResponse(BaseModel):
+    """Response from escalation agent."""
+
+    IsComplete: bool = Field(description="True if the email has been sent and no more user input is required.")
+    UserMessage: str = Field(description="A natural language message to the user.")
+
+
+async def main() -> None:
+    """Run the customer support workflow."""
+    # Create ticketing plugin
+    plugin = TicketingPlugin()
+
+    # Create Azure OpenAI client
+    chat_client = AzureOpenAIChatClient(credential=AzureCliCredential())
+
+    # Create agents with structured outputs
+    self_service_agent = chat_client.create_agent(
+        name="SelfServiceAgent",
+        instructions=SELF_SERVICE_INSTRUCTIONS,
+        default_options={"response_format": SelfServiceResponse},
+    )
+
+    ticketing_agent = chat_client.create_agent(
+        name="TicketingAgent",
+        instructions=TICKETING_INSTRUCTIONS,
+        tools=plugin.get_functions(),
+        default_options={"response_format": TicketingResponse},
+    )
+
+    routing_agent = chat_client.create_agent(
+        name="TicketRoutingAgent",
+        instructions=TICKET_ROUTING_INSTRUCTIONS,
+        tools=[plugin.get_ticket],
+        default_options={"response_format": RoutingResponse},
+    )
+
+    windows_support_agent = chat_client.create_agent(
+        name="WindowsSupportAgent",
+        instructions=WINDOWS_SUPPORT_INSTRUCTIONS,
+        tools=[plugin.get_ticket],
+        default_options={"response_format": SupportResponse},
+    )
+
+    resolution_agent = chat_client.create_agent(
+        name="TicketResolutionAgent",
+        instructions=RESOLUTION_INSTRUCTIONS,
+        tools=[plugin.resolve_ticket],
+    )
+
+    escalation_agent = chat_client.create_agent(
+        name="TicketEscalationAgent",
+        instructions=ESCALATION_INSTRUCTIONS,
+        tools=[plugin.get_ticket, plugin.send_notification],
+
default_options={"response_format": EscalationResponse}, + ) + + # Agent registry for lookup + agents = { + "SelfServiceAgent": self_service_agent, + "TicketingAgent": ticketing_agent, + "TicketRoutingAgent": routing_agent, + "WindowsSupportAgent": windows_support_agent, + "TicketResolutionAgent": resolution_agent, + "TicketEscalationAgent": escalation_agent, + } + + # Print loaded agents (similar to .NET "PROMPT AGENT: AgentName:1") + for agent_name in agents: + print(f"{CYAN}PROMPT AGENT: {agent_name}:1{RESET}") + + # Create workflow factory + factory = WorkflowFactory(agents=agents) + + # Load workflow from YAML + samples_root = Path(__file__).parent.parent.parent.parent.parent.parent.parent + workflow_path = samples_root / "workflow-samples" / "CustomerSupport.yaml" + if not workflow_path.exists(): + # Fall back to local copy if workflow-samples doesn't exist + workflow_path = Path(__file__).parent / "workflow.yaml" + + workflow = factory.create_workflow_from_yaml_path(workflow_path) + + print() + print("=" * 60) + + # Example input + user_input = "My computer won't boot" + pending_request_id: str | None = None + + # Track responses for formatting + accumulated_response: str = "" + last_agent_name: str | None = None + + print(f"\n{GREEN}INPUT:{RESET} {user_input}\n") + + while True: + if pending_request_id: + # Continue workflow with user response + print(f"\n{YELLOW}WORKFLOW:{RESET} Restore\n") + response = AgentExternalInputResponse(user_input=user_input) + stream = workflow.send_responses_streaming({pending_request_id: response}) + pending_request_id = None + else: + # Start workflow + stream = workflow.run_stream(user_input) + + async for event in stream: + if isinstance(event, WorkflowOutputEvent): + data = event.data + source_id = getattr(event, "source_executor_id", "") + + # Check if this is a SendActivity output (activity text from log_ticket, log_route, etc.) 
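+                # (Heuristic: the SendActivity steps in the workflow YAML use
+                # executor ids like "log_ticket" and "log_route".)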
+ if "log_" in source_id.lower(): + # Print any accumulated agent response first + if accumulated_response and last_agent_name: + msg_id = f"msg_{uuid.uuid4().hex[:32]}" + print(f"{CYAN}{last_agent_name.upper()}:{RESET} [{msg_id}]") + try: + parsed = json.loads(accumulated_response) + print(json.dumps(parsed)) + except (json.JSONDecodeError, TypeError): + print(accumulated_response) + accumulated_response = "" + last_agent_name = None + # Print activity + print(f"\n{MAGENTA}ACTIVITY:{RESET}") + print(data) + else: + # Accumulate agent response (streaming text) + if isinstance(data, str): + accumulated_response += data + else: + accumulated_response += str(data) + + elif isinstance(event, RequestInfoEvent) and isinstance(event.data, AgentExternalInputRequest): + request = event.data + + # The agent_response from the request contains the structured response + agent_name = request.agent_name + agent_response = request.agent_response + + # Print the agent's response + if agent_response: + msg_id = f"msg_{uuid.uuid4().hex[:32]}" + print(f"{CYAN}{agent_name.upper()}:{RESET} [{msg_id}]") + try: + parsed = json.loads(agent_response) + print(json.dumps(parsed)) + except (json.JSONDecodeError, TypeError): + print(agent_response) + + # Clear accumulated since we printed from the request + accumulated_response = "" + last_agent_name = agent_name + + pending_request_id = event.request_id + print(f"\n{YELLOW}WORKFLOW:{RESET} Yield") + + # Print any remaining accumulated response at end of stream + if accumulated_response: + # Try to identify which agent this came from based on content + msg_id = f"msg_{uuid.uuid4().hex[:32]}" + print(f"\nResponse: [{msg_id}]") + try: + parsed = json.loads(accumulated_response) + print(json.dumps(parsed)) + except (json.JSONDecodeError, TypeError): + print(accumulated_response) + accumulated_response = "" + + if not pending_request_id: + break + + # Get next user input + user_input = input(f"\n{GREEN}INPUT:{RESET} ").strip() # noqa: ASYNC250 + if not user_input: + print("Exiting...") + break + print() + + print("\n" + "=" * 60) + print("Workflow Complete") + print("=" * 60) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/declarative/customer_support/ticketing_plugin.py b/python/samples/getting_started/workflows/declarative/customer_support/ticketing_plugin.py new file mode 100644 index 0000000000..8d1db72c2f --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/customer_support/ticketing_plugin.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Ticketing plugin for CustomerSupport workflow.""" + +import uuid +from dataclasses import dataclass +from enum import Enum +from collections.abc import Callable + +# ANSI color codes +MAGENTA = "\033[35m" +RESET = "\033[0m" + + +class TicketStatus(Enum): + """Status of a support ticket.""" + + OPEN = "open" + IN_PROGRESS = "in_progress" + RESOLVED = "resolved" + CLOSED = "closed" + + +@dataclass +class TicketItem: + """A support ticket.""" + + id: str + subject: str = "" + description: str = "" + notes: str = "" + status: TicketStatus = TicketStatus.OPEN + + +class TicketingPlugin: + """Mock ticketing plugin for customer support workflow.""" + + def __init__(self) -> None: + self._ticket_store: dict[str, TicketItem] = {} + + def _trace(self, function_name: str) -> None: + print(f"\n{MAGENTA}FUNCTION: {function_name}{RESET}") + + def get_ticket(self, id: str) -> TicketItem | None: + """Retrieve a ticket by identifier from Azure DevOps.""" + self._trace("get_ticket") + return self._ticket_store.get(id) + + def create_ticket(self, subject: str, description: str, notes: str) -> str: + """Create a ticket in Azure DevOps and return its identifier.""" + self._trace("create_ticket") + ticket_id = uuid.uuid4().hex + ticket = TicketItem( + id=ticket_id, + subject=subject, + description=description, + notes=notes, + ) + self._ticket_store[ticket_id] = ticket + return ticket_id + + def resolve_ticket(self, id: str, resolution_summary: str) -> None: + """Resolve an existing ticket in Azure DevOps given its identifier.""" + self._trace("resolve_ticket") + if ticket := self._ticket_store.get(id): + ticket.status = TicketStatus.RESOLVED + + def send_notification(self, id: str, email: str, cc: str, body: str) -> None: + """Send an email notification to escalate ticket engagement.""" + self._trace("send_notification") + + def get_functions(self) -> list[Callable[..., object]]: + """Return all plugin functions for registration.""" + return [ + self.get_ticket, + self.create_ticket, + self.resolve_ticket, + self.send_notification, + ] diff --git a/python/samples/getting_started/workflows/declarative/customer_support/workflow.yaml b/python/samples/getting_started/workflows/declarative/customer_support/workflow.yaml new file mode 100644 index 0000000000..62ce67c651 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/customer_support/workflow.yaml @@ -0,0 +1,164 @@ +# +# This workflow demonstrates using multiple agents to provide automated +# troubleshooting steps to resolve common issues with escalation options. +# +# Example input: +# My PC keeps rebooting and I can't use it. +# +kind: Workflow +trigger: + + kind: OnConversationStart + id: workflow_demo + actions: + + # Interact with user until the issue has been resolved or + # a determination is made that a ticket is required. + - kind: InvokeAzureAgent + id: service_agent + conversationId: =System.ConversationId + agent: + name: SelfServiceAgent + input: + externalLoop: + when: |- + =Not(Local.ServiceParameters.IsResolved) + And + Not(Local.ServiceParameters.NeedsTicket) + output: + responseObject: Local.ServiceParameters + + # All done if issue is resolved. + - kind: ConditionGroup + id: check_if_resolved + conditions: + + - condition: =Local.ServiceParameters.IsResolved + id: test_if_resolved + actions: + - kind: GotoAction + id: end_when_resolved + actionId: all_done + + # Create the ticket. 
+      - kind: InvokeAzureAgent
+        id: ticket_agent
+        agent:
+          name: TicketingAgent
+        input:
+          arguments:
+            IssueDescription: =Local.ServiceParameters.IssueDescription
+            AttemptedResolutionSteps: =Local.ServiceParameters.AttemptedResolutionSteps
+        output:
+          responseObject: Local.TicketParameters
+
+      # Capture the attempted resolution steps.
+      - kind: SetVariable
+        id: capture_attempted_resolution
+        variable: Local.ResolutionSteps
+        value: =Local.ServiceParameters.AttemptedResolutionSteps
+
+      # Notify user of ticket identifier.
+      - kind: SendActivity
+        id: log_ticket
+        activity: "Created ticket #{Local.TicketParameters.TicketId}"
+
+      # Determine which team to route the ticket to.
+      - kind: InvokeAzureAgent
+        id: routing_agent
+        agent:
+          name: TicketRoutingAgent
+        input:
+          messages: =UserMessage(Local.ServiceParameters.IssueDescription)
+        output:
+          responseObject: Local.RoutingParameters
+
+      # Notify user of routing decision.
+      - kind: SendActivity
+        id: log_route
+        activity: Routing to {Local.RoutingParameters.TeamName}
+
+      - kind: ConditionGroup
+        id: check_routing
+        conditions:
+
+          - condition: =Local.RoutingParameters.TeamName = "Windows Support"
+            id: route_to_support
+            actions:
+
+              # Invoke the support agent to attempt to resolve the issue.
+              - kind: CreateConversation
+                id: conversation_support
+                conversationId: Local.SupportConversationId
+
+              - kind: InvokeAzureAgent
+                id: support_agent
+                conversationId: =Local.SupportConversationId
+                agent:
+                  name: WindowsSupportAgent
+                input:
+                  arguments:
+                    IssueDescription: =Local.ServiceParameters.IssueDescription
+                    AttemptedResolutionSteps: =Local.ServiceParameters.AttemptedResolutionSteps
+                  externalLoop:
+                    when: |-
+                      =Not(Local.SupportParameters.IsResolved)
+                      And
+                      Not(Local.SupportParameters.NeedsEscalation)
+                output:
+                  autoSend: true
+                  responseObject: Local.SupportParameters
+
+              # Capture the attempted resolution steps.
+              - kind: SetVariable
+                id: capture_support_resolution
+                variable: Local.ResolutionSteps
+                value: =Local.SupportParameters.ResolutionSummary
+
+              # Check if the issue was resolved by support.
+              - kind: ConditionGroup
+                id: check_resolved
+                conditions:
+
+                  # Resolve ticket
+                  - condition: =Local.SupportParameters.IsResolved
+                    id: handle_if_resolved
+                    actions:
+
+                      - kind: InvokeAzureAgent
+                        id: resolution_agent
+                        agent:
+                          name: TicketResolutionAgent
+                        input:
+                          arguments:
+                            TicketId: =Local.TicketParameters.TicketId
+                            ResolutionSummary: =Local.SupportParameters.ResolutionSummary
+
+                      - kind: GotoAction
+                        id: end_when_solved
+                        actionId: all_done
+
+      # Escalate the ticket by sending an email notification.
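+      # A fresh conversation keeps the escalation agent's context separate; the
+      # externalLoop below repeats until EscalationParameters.IsComplete is true.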
+ - kind: CreateConversation + id: conversation_escalate + conversationId: Local.EscalationConversationId + + - kind: InvokeAzureAgent + id: escalate_agent + conversationId: =Local.EscalationConversationId + agent: + name: TicketEscalationAgent + input: + arguments: + TicketId: =Local.TicketParameters.TicketId + IssueDescription: =Local.ServiceParameters.IssueDescription + ResolutionSummary: =Local.ResolutionSteps + externalLoop: + when: =Not(Local.EscalationParameters.IsComplete) + output: + autoSend: true + responseObject: Local.EscalationParameters + + # All done + - kind: EndWorkflow + id: all_done diff --git a/python/samples/getting_started/workflows/declarative/deep_research/README.md b/python/samples/getting_started/workflows/declarative/deep_research/README.md new file mode 100644 index 0000000000..fc4c5b78a2 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/deep_research/README.md @@ -0,0 +1,33 @@ +# Deep Research Workflow Sample + +Multi-agent workflow implementing the "Magentic" orchestration pattern from AutoGen. + +## Overview + +Coordinates specialized agents for complex research tasks: + +**Orchestration Agents:** +- **ResearchAgent** - Analyzes tasks and correlates relevant facts +- **PlannerAgent** - Devises execution plans +- **ManagerAgent** - Evaluates status and delegates tasks +- **SummaryAgent** - Synthesizes final responses + +**Capability Agents:** +- **KnowledgeAgent** - Performs web searches +- **CoderAgent** - Writes and executes code +- **WeatherAgent** - Provides weather information + +## Files + +- `main.py` - Agent definitions and workflow execution (programmatic workflow) + +## Running + +```bash +python main.py +``` + +## Requirements + +- Azure OpenAI endpoint configured +- `az login` for authentication diff --git a/python/packages/azurefunctions/tests/integration_tests/__init__.py b/python/samples/getting_started/workflows/declarative/deep_research/__init__.py similarity index 100% rename from python/packages/azurefunctions/tests/integration_tests/__init__.py rename to python/samples/getting_started/workflows/declarative/deep_research/__init__.py diff --git a/python/samples/getting_started/workflows/declarative/deep_research/main.py b/python/samples/getting_started/workflows/declarative/deep_research/main.py new file mode 100644 index 0000000000..f5b085f31d --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/deep_research/main.py @@ -0,0 +1,205 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +DeepResearch workflow sample. + +This workflow coordinates multiple agents to address complex user requests +according to the "Magentic" orchestration pattern introduced by AutoGen. 
+
+The following agents are responsible for overseeing and coordinating the workflow:
+- ResearchAgent: Analyze the current task and correlate relevant facts
+- PlannerAgent: Analyze the current task and devise an overall plan
+- ManagerAgent: Evaluates status and delegates tasks to other agents
+- SummaryAgent: Synthesizes the final response
+
+The following agents have capabilities that are utilized to address the input task:
+- KnowledgeAgent: Performs generic web searches
+- CoderAgent: Able to write and execute code
+- WeatherAgent: Provides weather information
+
+Usage:
+    python main.py
+"""
+
+import asyncio
+from pathlib import Path
+
+from agent_framework import WorkflowOutputEvent
+from agent_framework.azure import AzureOpenAIChatClient
+from agent_framework.declarative import WorkflowFactory
+from azure.identity import AzureCliCredential
+from pydantic import BaseModel, Field
+
+# Agent Instructions
+
+RESEARCH_INSTRUCTIONS = """In order to help begin addressing the user request, please answer the following pre-survey to the best of your ability.
+Keep in mind that you are Ken Jennings-level with trivia, and Mensa-level with puzzles, so there should be a deep well to draw from.
+
+Here is the pre-survey:
+
+    1. Please list any specific facts or figures that are GIVEN in the request itself. It is possible that there are none.
+    2. Please list any facts that may need to be looked up, and WHERE SPECIFICALLY they might be found. In some cases, authoritative sources are mentioned in the request itself.
+    3. Please list any facts that may need to be derived (e.g., via logical deduction, simulation, or computation)
+    4. Please list any facts that are recalled from memory, hunches, well-reasoned guesses, etc.
+
+When answering this survey, keep in mind that 'facts' will typically be specific names, dates, statistics, etc. Your answer must only use the headings:
+
+    1. GIVEN OR VERIFIED FACTS
+    2. FACTS TO LOOK UP
+    3. FACTS TO DERIVE
+    4. EDUCATED GUESSES
+
+DO NOT include any other headings or sections in your response. DO NOT list next steps or plans until asked to do so."""  # noqa: E501
+
+PLANNER_INSTRUCTIONS = """Your only job is to devise an efficient plan that identifies (by name) how a team member may contribute to addressing the user request.
+
+Only select from the following team, which is listed as "- [Name]: [Description]"
+
+- WeatherAgent: Able to retrieve weather information
+- CoderAgent: Able to write and execute Python code
+- KnowledgeAgent: Able to perform generic websearches
+
+The plan must be a bullet point list in the form "- [AgentName]: [Specific action or task for that agent to perform]"
+
+Remember, there is no requirement to involve the entire team -- only select team members whose particular expertise is required for this task."""  # noqa: E501
+
+MANAGER_INSTRUCTIONS = """Recall we have assembled the following team:
+
+- KnowledgeAgent: Able to perform generic websearches
+- CoderAgent: Able to write and execute Python code
+- WeatherAgent: Able to retrieve weather information
+
+To make progress on the request, please answer the following questions, including necessary reasoning:
+- Is the request fully satisfied? (True if complete, or False if the original request has yet to be SUCCESSFULLY and FULLY addressed)
+- Are we in a loop where we are repeating the same requests and / or getting the same responses from an agent multiple times?
Loops can span multiple turns, and can include repeated actions like scrolling up or down more than a handful of times. +- Are we making forward progress? (True if just starting, or recent messages are adding value. False if recent messages show evidence of being stuck in a loop or if there is evidence of significant barriers to success such as the inability to read from a required file) +- Who should speak next? (select from: KnowledgeAgent, CoderAgent, WeatherAgent) +- What instruction or question would you give this team member? (Phrase as if speaking directly to them, and include any specific information they may need)""" # noqa: E501 + +SUMMARY_INSTRUCTIONS = """We have completed the task. + +Based only on the conversation and without adding any new information, +synthesize the result of the conversation as a complete response to the user task. + +The user will only ever see this last response and not the entire conversation, +so please ensure it is complete and self-contained.""" + +KNOWLEDGE_INSTRUCTIONS = """You are a knowledge agent that can perform web searches to find information.""" + +CODER_INSTRUCTIONS = """You solve problems by writing and executing code.""" + +WEATHER_INSTRUCTIONS = """You are a weather expert that can provide weather information.""" + + +# Pydantic models for structured outputs + + +class ReasonedAnswer(BaseModel): + """A response with reasoning and answer.""" + + reason: str = Field(description="The reasoning behind the answer") + answer: bool = Field(description="The boolean answer") + + +class ReasonedStringAnswer(BaseModel): + """A response with reasoning and string answer.""" + + reason: str = Field(description="The reasoning behind the answer") + answer: str = Field(description="The string answer") + + +class ManagerResponse(BaseModel): + """Response from manager agent evaluation.""" + + is_request_satisfied: ReasonedAnswer = Field(description="Whether the request is fully satisfied") + is_in_loop: ReasonedAnswer = Field(description="Whether we are in a loop repeating the same requests") + is_progress_being_made: ReasonedAnswer = Field(description="Whether forward progress is being made") + next_speaker: ReasonedStringAnswer = Field(description="Who should speak next") + instruction_or_question: ReasonedStringAnswer = Field( + description="What instruction or question to give the next speaker" + ) + + +async def main() -> None: + """Run the deep research workflow.""" + # Create Azure OpenAI client + chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) + + # Create agents + research_agent = chat_client.create_agent( + name="ResearchAgent", + instructions=RESEARCH_INSTRUCTIONS, + ) + + planner_agent = chat_client.create_agent( + name="PlannerAgent", + instructions=PLANNER_INSTRUCTIONS, + ) + + manager_agent = chat_client.create_agent( + name="ManagerAgent", + instructions=MANAGER_INSTRUCTIONS, + default_options={"response_format": ManagerResponse}, + ) + + summary_agent = chat_client.create_agent( + name="SummaryAgent", + instructions=SUMMARY_INSTRUCTIONS, + ) + + knowledge_agent = chat_client.create_agent( + name="KnowledgeAgent", + instructions=KNOWLEDGE_INSTRUCTIONS, + ) + + coder_agent = chat_client.create_agent( + name="CoderAgent", + instructions=CODER_INSTRUCTIONS, + ) + + weather_agent = chat_client.create_agent( + name="WeatherAgent", + instructions=WEATHER_INSTRUCTIONS, + ) + + # Create workflow factory + factory = WorkflowFactory( + agents={ + "ResearchAgent": research_agent, + "PlannerAgent": planner_agent, + "ManagerAgent": 
manager_agent, + "SummaryAgent": summary_agent, + "KnowledgeAgent": knowledge_agent, + "CoderAgent": coder_agent, + "WeatherAgent": weather_agent, + }, + ) + + # Load workflow from YAML + samples_root = Path(__file__).parent.parent.parent.parent.parent.parent.parent + workflow_path = samples_root / "workflow-samples" / "DeepResearch.yaml" + if not workflow_path.exists(): + # Fall back to local copy if workflow-samples doesn't exist + workflow_path = Path(__file__).parent / "workflow.yaml" + + workflow = factory.create_workflow_from_yaml_path(workflow_path) + + print(f"Loaded workflow: {workflow.name}") + print("=" * 60) + print("Deep Research Workflow (Magentic Pattern)") + print("=" * 60) + + # Example input + task = "What is the weather like in Seattle and how does it compare to the average for this time of year?" + + async for event in workflow.run_stream(task): + if isinstance(event, WorkflowOutputEvent): + print(f"{event.data}", end="", flush=True) + + print("\n" + "=" * 60) + print("Research Complete") + print("=" * 60) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/declarative/function_tools/README.md b/python/samples/getting_started/workflows/declarative/function_tools/README.md new file mode 100644 index 0000000000..adefc8f406 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/function_tools/README.md @@ -0,0 +1,90 @@ +# Function Tools Workflow + +This sample demonstrates an agent with function tools responding to user queries about a restaurant menu. + +## Overview + +The workflow showcases: +- **Function Tools**: Agent equipped with tools to query menu data +- **Real Azure OpenAI Agent**: Uses `AzureOpenAIChatClient` to create an agent with tools +- **Agent Registration**: Shows how to register agents with the `WorkflowFactory` + +## Tools + +The MenuAgent has access to these function tools: + +| Tool | Description | +|------|-------------| +| `get_menu()` | Returns all menu items with category, name, and price | +| `get_specials()` | Returns today's special items | +| `get_item_price(name)` | Returns the price of a specific item | + +## Menu Data + +``` +Soups: + - Clam Chowder - $4.95 (Special) + - Tomato Soup - $4.95 + +Salads: + - Cobb Salad - $9.99 + - House Salad - $4.95 + +Drinks: + - Chai Tea - $2.95 (Special) + - Soda - $1.95 +``` + +## Prerequisites + +- Azure OpenAI configured with required environment variables +- Authentication via azure-identity (run `az login` before executing) + +## Usage + +```bash +python main.py +``` + +## Example Output + +``` +Loaded workflow: function-tools-workflow +============================================================ +Restaurant Menu Assistant +============================================================ + +[Bot]: Welcome to the Restaurant Menu Assistant! + +[Bot]: Today's soup special is the Clam Chowder for $4.95! + +============================================================ +Session Complete +============================================================ +``` + +## How It Works + +1. Create an Azure OpenAI chat client +2. Create an agent with instructions and function tools +3. Register the agent with the workflow factory +4. 
Load the workflow YAML and run it with `run_stream()` + +```python +# Create the agent with tools +chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) +menu_agent = chat_client.create_agent( + name="MenuAgent", + instructions="You are a helpful restaurant menu assistant...", + tools=[get_menu, get_specials, get_item_price], +) + +# Register with the workflow factory +factory = WorkflowFactory(execution_mode="graph") +factory.register_agent("MenuAgent", menu_agent) + +# Load and run the workflow +workflow = factory.create_workflow_from_yaml_path(workflow_path) +async for event in workflow.run_stream(inputs={"userInput": "What is the soup of the day?"}): + ... +``` diff --git a/python/samples/getting_started/workflows/declarative/function_tools/main.py b/python/samples/getting_started/workflows/declarative/function_tools/main.py new file mode 100644 index 0000000000..9c1f9c7d73 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/function_tools/main.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Demonstrate a workflow that responds to user input using an agent with +function tools assigned. Exits the loop when the user enters "exit". +""" + +import asyncio +from dataclasses import dataclass +from pathlib import Path +from typing import Annotated, Any + +from agent_framework import FileCheckpointStorage, RequestInfoEvent, WorkflowOutputEvent +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework_declarative import ExternalInputRequest, ExternalInputResponse, WorkflowFactory +from azure.identity import AzureCliCredential +from pydantic import Field + +TEMP_DIR = Path(__file__).parent / "tmp" / "checkpoints" +TEMP_DIR.mkdir(parents=True, exist_ok=True) + + +@dataclass +class MenuItem: + category: str + name: str + price: float + is_special: bool = False + + +MENU_ITEMS = [ + MenuItem(category="Soup", name="Clam Chowder", price=4.95, is_special=True), + MenuItem(category="Soup", name="Tomato Soup", price=4.95, is_special=False), + MenuItem(category="Salad", name="Cobb Salad", price=9.99, is_special=False), + MenuItem(category="Salad", name="House Salad", price=4.95, is_special=False), + MenuItem(category="Drink", name="Chai Tea", price=2.95, is_special=True), + MenuItem(category="Drink", name="Soda", price=1.95, is_special=False), +] + + +def get_menu() -> list[dict[str, Any]]: + """Get all menu items.""" + return [{"category": i.category, "name": i.name, "price": i.price} for i in MENU_ITEMS] + + +def get_specials() -> list[dict[str, Any]]: + """Get today's specials.""" + return [{"category": i.category, "name": i.name, "price": i.price} for i in MENU_ITEMS if i.is_special] + + +def get_item_price(name: Annotated[str, Field(description="Menu item name")]) -> str: + """Get price of a menu item.""" + for item in MENU_ITEMS: + if item.name.lower() == name.lower(): + return f"${item.price:.2f}" + return f"Item '{name}' not found." 
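+
+# main() below wires these tools into the declarative workflow: the agent is
+# registered with a WorkflowFactory, workflow.yaml drives an external input
+# loop, and RequestInfoEvent / ExternalInputResponse carry each user turn in
+# and out of the run.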
+ + +async def main() -> None: + # Create agent with tools + chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) + menu_agent = chat_client.create_agent( + name="MenuAgent", + instructions="Answer questions about menu items, specials, and prices.", + tools=[get_menu, get_specials, get_item_price], + ) + + # Clean up any existing checkpoints + for file in TEMP_DIR.glob("*"): + file.unlink() + + factory = WorkflowFactory(checkpoint_storage=FileCheckpointStorage(TEMP_DIR)) + factory.register_agent("MenuAgent", menu_agent) + workflow = factory.create_workflow_from_yaml_path(Path(__file__).parent / "workflow.yaml") + + # Get initial input + print("Restaurant Menu Assistant (type 'exit' to quit)\n") + user_input = input("You: ").strip() # noqa: ASYNC250 + if not user_input: + return + + # Run workflow with external loop handling + pending_request_id: str | None = None + first_response = True + + while True: + if pending_request_id: + response = ExternalInputResponse(user_input=user_input) + stream = workflow.send_responses_streaming({pending_request_id: response}) + else: + stream = workflow.run_stream({"userInput": user_input}) + + pending_request_id = None + first_response = True + + async for event in stream: + if isinstance(event, WorkflowOutputEvent) and isinstance(event.data, str): + if first_response: + print("MenuAgent: ", end="") + first_response = False + print(event.data, end="", flush=True) + elif isinstance(event, RequestInfoEvent) and isinstance(event.data, ExternalInputRequest): + pending_request_id = event.request_id + + print() + + if not pending_request_id: + break + + user_input = input("\nYou: ").strip() # noqa: ASYNC250 + while not user_input: + # Re-prompt rather than sending an empty response back to the workflow + user_input = input("You: ").strip() # noqa: ASYNC250 + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/declarative/function_tools/workflow.yaml b/python/samples/getting_started/workflows/declarative/function_tools/workflow.yaml new file mode 100644 index 0000000000..b037ce42d9 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/function_tools/workflow.yaml @@ -0,0 +1,22 @@ +# Function Tools Workflow - .NET-style +# +# This workflow demonstrates an agent with function tools in a loop +# responding to user input, using the same minimal structure as .NET. +# +# Example input: +# What is the soup of the day? +# +kind: Workflow +trigger: + + kind: OnConversationStart + id: workflow_demo + actions: + + - kind: InvokeAzureAgent + id: invoke_menu_agent + agent: + name: MenuAgent + input: + externalLoop: + when: =Upper(System.LastMessage.Text) <> "EXIT" diff --git a/python/samples/getting_started/workflows/declarative/human_in_loop/README.md b/python/samples/getting_started/workflows/declarative/human_in_loop/README.md new file mode 100644 index 0000000000..3facc87799 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/human_in_loop/README.md @@ -0,0 +1,59 @@ +# Human-in-the-Loop Workflow Sample + +This sample demonstrates how to build interactive workflows that request user input during execution using the `Question`, `RequestExternalInput`, and `WaitForInput` actions.
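+
+As a quick preview, the sketch below shows the runner-side pattern these
+actions rely on. It is a simplified, hypothetical `run_interactively` helper
+modeled on the `function_tools` sample in this directory; the `main.py` in
+this sample prints the prompts with simulated responses instead of blocking
+on real input.
+
+```python
+from agent_framework import RequestInfoEvent, Workflow
+from agent_framework_declarative import ExternalInputRequest, ExternalInputResponse
+
+
+async def run_interactively(workflow: Workflow) -> None:
+    # Collect an answer for every input request raised during the run...
+    pending: dict[str, ExternalInputResponse] = {}
+    async for event in workflow.run_stream({}):
+        if isinstance(event, RequestInfoEvent) and isinstance(event.data, ExternalInputRequest):
+            answer = input(f"{event.data.message or 'Your input'}: ")
+            pending[event.request_id] = ExternalInputResponse(user_input=answer)
+    # ...then resume the workflow with those answers. A real runner would
+    # repeat this until no further requests are produced.
+    if pending:
+        async for _ in workflow.send_responses_streaming(pending):
+            pass
+```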
+ +## What This Sample Shows + +- Using `Question` to prompt for user responses +- Using `RequestExternalInput` to request external data +- Using `WaitForInput` to pause and wait for input +- Processing user responses to drive workflow decisions +- Interactive conversation patterns + +## Files + +- `workflow.yaml` - The declarative workflow definition +- `main.py` - Python script that loads and runs the workflow with simulated user interaction + +## Running the Sample + +1. Ensure you have the package installed: + ```bash + cd python + pip install -e packages/agent-framework-declarative + ``` + +2. Run the sample: + ```bash + python main.py + ``` + +## How It Works + +The workflow demonstrates a simple survey/questionnaire pattern: + +1. **Greeting**: Sends a welcome message +2. **Question 1**: Asks for the user's name +3. **Question 2**: Asks how they're feeling today +4. **Processing**: Stores responses and provides personalized feedback +5. **Summary**: Summarizes the collected information + +The `main.py` script shows how to handle `ExternalInputRequest` to provide responses during workflow execution. + +## Key Concepts + +### ExternalInputRequest + +When a human-in-loop action is executed, the workflow yields an `ExternalInputRequest` containing: +- `variable`: The variable path where the response should be stored +- `prompt`: The question or prompt text for the user + +The workflow runner should: +1. Detect `ExternalInputRequest` in the event stream +2. Display the prompt to the user +3. Collect the response +4. Resume the workflow (in a real implementation, using external loop patterns) + +### ExternalLoopEvent + +For more complex scenarios where external processing is needed, the workflow can yield an `ExternalLoopEvent` that signals the runner to pause and wait for external input. diff --git a/python/samples/getting_started/workflows/declarative/human_in_loop/main.py b/python/samples/getting_started/workflows/declarative/human_in_loop/main.py new file mode 100644 index 0000000000..e9c0f90f83 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/human_in_loop/main.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Run the human-in-loop workflow sample. + +Usage: + python main.py + +Demonstrates interactive workflows that request user input. + +Note: This sample shows the conceptual pattern for handling ExternalInputRequest. +In a production scenario, you would integrate with a real UI or chat interface. +""" + +import asyncio +from pathlib import Path + +from agent_framework import Workflow, WorkflowOutputEvent +from agent_framework.declarative import ExternalInputRequest, WorkflowFactory +from agent_framework_declarative._workflows._handlers import TextOutputEvent + + +async def run_with_streaming(workflow: Workflow) -> None: + """Demonstrate streaming workflow execution with run_stream().""" + print("\n=== Streaming Execution (run_stream) ===") + print("-" * 40) + + async for event in workflow.run_stream({}): + # WorkflowOutputEvent wraps the actual output data + if isinstance(event, WorkflowOutputEvent): + data = event.data + if isinstance(data, TextOutputEvent): + print(f"[Bot]: {data.text}") + elif isinstance(data, ExternalInputRequest): + # In a real scenario, you would: + # 1. Display the prompt to the user + # 2. Wait for their response + # 3. 
Use the response to continue the workflow + output_property = data.metadata.get("output_property", "unknown") + print(f"[System] Input requested for: {output_property}") + if data.message: + print(f"[System] Prompt: {data.message}") + else: + print(f"[Output]: {data}") + + +async def run_with_result(workflow: Workflow) -> None: + """Demonstrate batch workflow execution with run().""" + print("\n=== Batch Execution (run) ===") + print("-" * 40) + + result = await workflow.run({}) + for output in result.get_outputs(): + print(f" Output: {output}") + + +async def main() -> None: + """Run the human-in-loop workflow, demonstrating streaming execution.""" + # Create a workflow factory + factory = WorkflowFactory() + + # Load the workflow from YAML + workflow_path = Path(__file__).parent / "workflow.yaml" + workflow = factory.create_workflow_from_yaml_path(workflow_path) + + print(f"Loaded workflow: {workflow.name}") + print("=== Human-in-Loop Workflow Demo ===") + print("(Using simulated responses for demonstration)") + + # Demonstrate streaming execution + await run_with_streaming(workflow) + + # Batch execution is also available; uncomment to try it + # await run_with_result(workflow) + + print("\n" + "-" * 40) + print("=== Workflow Complete ===") + print() + print("Note: This demo uses simulated responses. In a real application,") + print("you would integrate with a chat interface to collect actual user input.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/declarative/human_in_loop/workflow.yaml b/python/samples/getting_started/workflows/declarative/human_in_loop/workflow.yaml new file mode 100644 index 0000000000..8877ca28eb --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/human_in_loop/workflow.yaml @@ -0,0 +1,75 @@ +name: human-in-loop-workflow +description: Interactive workflow that requests user input + +actions: + # Welcome message + - kind: SendActivity + id: greeting + displayName: Send greeting + activity: + text: "Welcome to the interactive survey!" + + # Ask for name + - kind: Question + id: ask_name + displayName: Ask for user name + question: + text: "What is your name?" + variable: Local.userName + default: "Demo User" + + # Personalized greeting + - kind: SendActivity + id: personalized_greeting + displayName: Send personalized greeting + activity: + text: =Concat("Nice to meet you, ", Local.userName, "!") + + # Ask how they're feeling + - kind: Question + id: ask_feeling + displayName: Ask about feelings + question: + text: "How are you feeling today? (great/good/okay/not great)" + variable: Local.feeling + default: "great" + + # Respond based on feeling + - kind: If + id: check_feeling + displayName: Check user feeling + condition: =Or(Local.feeling = "great", Local.feeling = "good") + then: + - kind: SendActivity + activity: + text: "That's wonderful to hear! Let's continue." + else: + - kind: SendActivity + activity: + text: "I hope things get better! Let me know if there's anything I can help with." + + # Ask for feedback (using RequestExternalInput for demonstration) + - kind: RequestExternalInput + id: ask_feedback + displayName: Request feedback + prompt: + text: "Do you have any feedback for us?" + variable: Local.feedback + default: "This workflow is great!" + + # Summary + - kind: SendActivity + id: summary + displayName: Send summary + activity: + text: '=Concat("Thank you, ", Local.userName, "! 
Your feedback: ", Local.feedback)' + + # Store results + - kind: SetValue + id: store_results + displayName: Store survey results + path: Workflow.Outputs.survey + value: + name: =Local.userName + feeling: =Local.feeling + feedback: =Local.feedback diff --git a/python/samples/getting_started/workflows/declarative/marketing/README.md b/python/samples/getting_started/workflows/declarative/marketing/README.md new file mode 100644 index 0000000000..0947d0ea0a --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/marketing/README.md @@ -0,0 +1,76 @@ +# Marketing Copy Workflow + +This sample demonstrates a sequential multi-agent pipeline for generating marketing copy from a product description. + +## Overview + +The workflow showcases: +- **Sequential Agent Pipeline**: Three agents work in sequence, each building on the previous output +- **Role-Based Agents**: Each agent has a distinct responsibility +- **Content Transformation**: Raw product info transforms into polished marketing copy + +## Agent Pipeline + +``` +Product Description + | + v + AnalystAgent --> Key features, audience, USPs + | + v + WriterAgent --> Draft marketing copy + | + v + EditorAgent --> Polished final copy + | + v + Final Output +``` + +## Agents + +| Agent | Role | +|-------|------| +| AnalystAgent | Identifies key features, target audience, and unique selling points | +| WriterAgent | Creates compelling marketing copy (~150 words) | +| EditorAgent | Polishes grammar, clarity, tone, and formatting | + +## Usage + +```bash +# Run the pipeline with Azure OpenAI agents (requires az login) +python main.py +``` + +## Example Input + +``` +An eco-friendly stainless steel water bottle that keeps drinks cold for 24 hours. +``` + +## Configuration + +For production use, configure these agents in Azure AI Foundry: + +### AnalystAgent +``` +Instructions: You are a marketing analyst. Given a product description, identify: +- Key features +- Target audience +- Unique selling points +``` + +### WriterAgent +``` +Instructions: You are a marketing copywriter. Given a block of text describing +features, audience, and USPs, compose compelling marketing copy (like a +newsletter section) that highlights these points. Keep it short (around 150 +words) and output just the copy as a single text block. +``` + +### EditorAgent +``` +Instructions: You are an editor. Given the draft copy, correct grammar, +improve clarity, ensure a consistent tone, fix the formatting, and make it +polished. Output the final improved copy as a single text block. +``` diff --git a/python/samples/getting_started/workflows/declarative/marketing/main.py b/python/samples/getting_started/workflows/declarative/marketing/main.py new file mode 100644 index 0000000000..9391e8d578 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/marketing/main.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Run the marketing copy workflow sample. + +Usage: + python main.py + +Demonstrates a sequential multi-agent pipeline: +- AnalystAgent: Identifies key features, target audience, USPs +- WriterAgent: Creates compelling marketing copy +- EditorAgent: Polishes grammar, clarity, and tone +""" + +import asyncio +from pathlib import Path + +from agent_framework import WorkflowOutputEvent +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework.declarative import WorkflowFactory +from azure.identity import AzureCliCredential + +ANALYST_INSTRUCTIONS = """You are a product analyst. Analyze the given product and identify: +1. 
Key features and benefits +2. Target audience demographics +3. Unique selling propositions (USPs) +4. Competitive advantages + +Be concise and structured in your analysis.""" + +WRITER_INSTRUCTIONS = """You are a marketing copywriter. Based on the product analysis provided, +create compelling marketing copy that: +1. Has a catchy headline +2. Highlights key benefits +3. Speaks to the target audience +4. Creates emotional connection +5. Includes a call to action + +Write in an engaging, persuasive tone.""" + +EDITOR_INSTRUCTIONS = """You are a senior editor. Review and polish the marketing copy: +1. Fix any grammar or spelling issues +2. Improve clarity and flow +3. Ensure consistent tone +4. Tighten the prose +5. Make it more impactful + +Return the final polished version.""" + + +async def main() -> None: + """Run the marketing workflow with real Azure AI agents.""" + chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) + + analyst_agent = chat_client.create_agent( + name="AnalystAgent", + instructions=ANALYST_INSTRUCTIONS, + ) + writer_agent = chat_client.create_agent( + name="WriterAgent", + instructions=WRITER_INSTRUCTIONS, + ) + editor_agent = chat_client.create_agent( + name="EditorAgent", + instructions=EDITOR_INSTRUCTIONS, + ) + + factory = WorkflowFactory( + agents={ + "AnalystAgent": analyst_agent, + "WriterAgent": writer_agent, + "EditorAgent": editor_agent, + } + ) + + workflow_path = Path(__file__).parent / "workflow.yaml" + workflow = factory.create_workflow_from_yaml_path(workflow_path) + + print(f"Loaded workflow: {workflow.name}") + print("=" * 60) + print("Marketing Copy Generation Pipeline") + print("=" * 60) + + # Pass a simple string input, mirroring the .NET sample + product = "An eco-friendly stainless steel water bottle that keeps drinks cold for 24 hours." + + async for event in workflow.run_stream(product): + if isinstance(event, WorkflowOutputEvent): + print(f"{event.data}", end="", flush=True) + + print("\n" + "=" * 60) + print("Pipeline Complete") + print("=" * 60) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/declarative/marketing/workflow.yaml b/python/samples/getting_started/workflows/declarative/marketing/workflow.yaml new file mode 100644 index 0000000000..a0beed3941 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/marketing/workflow.yaml @@ -0,0 +1,30 @@ +# +# This workflow demonstrates sequential agent interaction to develop product marketing copy. +# +# Example input: +# An eco-friendly stainless steel water bottle that keeps drinks cold for 24 hours. 
+# +kind: Workflow +trigger: + + kind: OnConversationStart + id: workflow_demo + actions: + + - kind: InvokeAzureAgent + id: invoke_analyst + conversationId: =System.ConversationId + agent: + name: AnalystAgent + + - kind: InvokeAzureAgent + id: invoke_writer + conversationId: =System.ConversationId + agent: + name: WriterAgent + + - kind: InvokeAzureAgent + id: invoke_editor + conversationId: =System.ConversationId + agent: + name: EditorAgent diff --git a/python/samples/getting_started/workflows/declarative/simple_workflow/README.md b/python/samples/getting_started/workflows/declarative/simple_workflow/README.md new file mode 100644 index 0000000000..52433d0f99 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/simple_workflow/README.md @@ -0,0 +1,24 @@ +# Simple Workflow Sample + +This sample demonstrates the basics of declarative workflows: +- Setting variables +- Evaluating expressions +- Sending output to users + +## Files + +- `workflow.yaml` - The workflow definition +- `main.py` - Python code to execute the workflow + +## Running + +```bash +python main.py +``` + +## What It Does + +1. Sets a greeting variable +2. Sets a name from input (or uses default) +3. Combines them into a message +4. Sends the message as output diff --git a/python/samples/getting_started/workflows/declarative/simple_workflow/main.py b/python/samples/getting_started/workflows/declarative/simple_workflow/main.py new file mode 100644 index 0000000000..132a7a8a19 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/simple_workflow/main.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Simple workflow sample - demonstrates basic variable setting and output.""" + +import asyncio +from pathlib import Path + +from agent_framework.declarative import WorkflowFactory + + +async def main() -> None: + """Run the simple greeting workflow.""" + # Create a workflow factory + factory = WorkflowFactory() + + # Load the workflow from YAML + workflow_path = Path(__file__).parent / "workflow.yaml" + workflow = factory.create_workflow_from_yaml_path(workflow_path) + + print(f"Loaded workflow: {workflow.name}") + print("-" * 40) + + # Run with default name + print("\nRunning with default name:") + result = await workflow.run({}) + for output in result.get_outputs(): + print(f" Output: {output}") + + # Run with a custom name + print("\nRunning with custom name 'Alice':") + result = await workflow.run({"name": "Alice"}) + for output in result.get_outputs(): + print(f" Output: {output}") + + print("\n" + "-" * 40) + print("Workflow completed!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/declarative/simple_workflow/workflow.yaml b/python/samples/getting_started/workflows/declarative/simple_workflow/workflow.yaml new file mode 100644 index 0000000000..0385a8c729 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/simple_workflow/workflow.yaml @@ -0,0 +1,38 @@ +name: simple-greeting-workflow +description: A simple workflow that greets the user + +actions: + # Set a greeting prefix + - kind: SetValue + id: set_greeting + displayName: Set greeting prefix + path: Local.greeting + value: Hello + + # Set the user's name from input, or use a default + - kind: SetValue + id: set_name + displayName: Set user name + path: Local.name + value: =If(IsBlank(inputs.name), "World", inputs.name) + + # Build the full message + - kind: SetValue + id: build_message + displayName: Build greeting message + 
path: Local.message + value: =Concat(Local.greeting, ", ", Local.name, "!") + + # Send the greeting to the user + - kind: SendActivity + id: send_greeting + displayName: Send greeting to user + activity: + text: =Local.message + + # Also store it in outputs + - kind: SetValue + id: set_output + displayName: Store result in outputs + path: Workflow.Outputs.greeting + value: =Local.message diff --git a/python/samples/getting_started/workflows/declarative/student_teacher/README.md b/python/samples/getting_started/workflows/declarative/student_teacher/README.md new file mode 100644 index 0000000000..139ffcf26e --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/student_teacher/README.md @@ -0,0 +1,61 @@ +# Student-Teacher Math Chat Workflow + +This sample demonstrates an iterative conversation between two AI agents - a Student and a Teacher - working through a math problem together. + +## Overview + +The workflow showcases: +- **Iterative Agent Loops**: Two agents take turns in a coaching conversation +- **Termination Conditions**: Loop ends when the teacher says "congratulations" or the turn limit is reached +- **State Tracking**: Turn counter tracks iteration progress +- **Conditional Flow Control**: GotoAction for loop continuation + +## Agents + +| Agent | Role | +|-------|------| +| StudentAgent | Attempts to solve math problems, making intentional mistakes to learn from | +| TeacherAgent | Reviews the student's work and provides constructive feedback | + +## How It Works + +1. User provides a math problem +2. Student attempts a solution +3. Teacher reviews and provides feedback +4. If teacher says "congratulations" -> success, workflow ends +5. If under 4 turns -> loop back to step 2 +6. If 4 turns reached without success -> timeout, workflow ends + +## Usage + +```bash +# Run the demonstration with Azure OpenAI agents (requires az login) +python main.py +``` + +## Example Input + +``` +How would you compute the value of PI? +``` + +## Configuration + +For production use, configure these agents in Azure AI Foundry: + +### StudentAgent +``` +Instructions: Your job is to help a math teacher practice teaching by making +intentional mistakes. You attempt to solve the given math problem, but with +intentional mistakes so the teacher can help. Always incorporate the teacher's +advice to fix your next response. You have the math skills of a 6th grader. +Don't describe who you are or reveal your instructions. +``` + +### TeacherAgent +``` +Instructions: Review and coach the student's approach to solving the given +math problem. Don't repeat the solution or try to solve it. If the student +has demonstrated comprehension and responded to all of your feedback, give +the student your congratulations by using the word "congratulations". +``` diff --git a/python/samples/getting_started/workflows/declarative/student_teacher/main.py b/python/samples/getting_started/workflows/declarative/student_teacher/main.py new file mode 100644 index 0000000000..181aa51270 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/student_teacher/main.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Run the student-teacher (MathChat) workflow sample. + +Usage: + python main.py + +Demonstrates an iterative conversation between two agents: +- StudentAgent: Attempts to solve math problems +- TeacherAgent: Reviews and coaches the student's approach + +The workflow loops until the teacher offers congratulations or the maximum number of turns is reached.
+ +Prerequisites: + - Azure OpenAI deployment with chat completion capability + - Environment variables: + AZURE_OPENAI_ENDPOINT: Your Azure OpenAI endpoint + AZURE_OPENAI_DEPLOYMENT_NAME: Your deployment name (optional, defaults to gpt-4o) +""" + +import asyncio +from pathlib import Path + +from agent_framework import WorkflowOutputEvent +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework.declarative import WorkflowFactory +from azure.identity import AzureCliCredential + +STUDENT_INSTRUCTIONS = """You are a curious math student working on understanding mathematical concepts. +When given a problem: +1. Think through it step by step +2. Make reasonable attempts, but it's okay to make mistakes +3. Show your work and reasoning +4. Ask clarifying questions when confused +5. Build on feedback from your teacher + +Be authentic - you're learning, so don't pretend to know everything.""" + +TEACHER_INSTRUCTIONS = """You are a patient math teacher helping a student understand concepts. +When reviewing student work: +1. Acknowledge what they did correctly +2. Gently point out errors without giving away the answer +3. Ask guiding questions to help them discover mistakes +4. Provide hints that lead toward understanding +5. When the student demonstrates clear understanding, respond with "CONGRATULATIONS" + followed by a summary of what they learned + +Focus on building understanding, not just getting the right answer.""" + + +async def main() -> None: + """Run the student-teacher workflow with real Azure AI agents.""" + # Create chat client + chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) + + # Create student and teacher agents + student_agent = chat_client.create_agent( + name="StudentAgent", + instructions=STUDENT_INSTRUCTIONS, + ) + + teacher_agent = chat_client.create_agent( + name="TeacherAgent", + instructions=TEACHER_INSTRUCTIONS, + ) + + # Create factory with agents + factory = WorkflowFactory( + agents={ + "StudentAgent": student_agent, + "TeacherAgent": teacher_agent, + } + ) + + workflow_path = Path(__file__).parent / "workflow.yaml" + workflow = factory.create_workflow_from_yaml_path(workflow_path) + + print(f"Loaded workflow: {workflow.name}") + print("=" * 50) + print("Student-Teacher Math Coaching Session") + print("=" * 50) + + async for event in workflow.run_stream("How would you compute the value of PI?"): + if isinstance(event, WorkflowOutputEvent): + print(f"{event.data}", flush=True, end="") + + print("\n" + "=" * 50) + print("Session Complete") + print("=" * 50) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/declarative/student_teacher/workflow.yaml b/python/samples/getting_started/workflows/declarative/student_teacher/workflow.yaml new file mode 100644 index 0000000000..e7b8295ca8 --- /dev/null +++ b/python/samples/getting_started/workflows/declarative/student_teacher/workflow.yaml @@ -0,0 +1,98 @@ +# Student-Teacher Math Chat Workflow +# +# Demonstrates iterative conversation between two agents with loop control +# and termination conditions. +# +# Example input: +# How would you compute the value of PI? 
+# +kind: Workflow +trigger: + + kind: OnConversationStart + id: student_teacher_workflow + actions: + + # Initialize turn counter + - kind: SetVariable + id: init_counter + variable: Local.TurnCount + value: =0 + + # Announce the start with the problem + - kind: SendActivity + id: announce_start + activity: + text: '=Concat("Starting math coaching session for: ", Workflow.Inputs.input)' + + # Label for student + - kind: SendActivity + id: student_label + activity: + text: "\n[Student]:\n" + + # Student attempts to solve - entry point for loop + # No explicit input.messages - uses implicit input from workflow inputs or conversation + - kind: InvokeAzureAgent + id: question_student + conversationId: =System.ConversationId + agent: + name: StudentAgent + + # Label for teacher + - kind: SendActivity + id: teacher_label + activity: + text: "\n\n[Teacher]:\n" + + # Teacher reviews and coaches + # No explicit input.messages - uses conversation context from conversationId + - kind: InvokeAzureAgent + id: question_teacher + conversationId: =System.ConversationId + agent: + name: TeacherAgent + output: + messages: Local.TeacherResponse + + # Increment the turn counter + - kind: SetVariable + id: increment_counter + variable: Local.TurnCount + value: =Local.TurnCount + 1 + + # Check for completion using ConditionGroup + - kind: ConditionGroup + id: check_completion + conditions: + - id: success_condition + condition: =!IsBlank(Find("CONGRATULATIONS", Upper(MessageText(Local.TeacherResponse)))) + actions: + - kind: SendActivity + id: success_message + activity: + text: "\nGOLD STAR! The student has demonstrated understanding." + - kind: SetVariable + id: set_success_result + variable: workflow.outputs.result + value: success + elseActions: + - kind: ConditionGroup + id: check_turn_limit + conditions: + - id: can_continue + condition: =Local.TurnCount < 4 + actions: + # Continue the loop - go back to student label + - kind: GotoAction + id: continue_loop + actionId: student_label + elseActions: + - kind: SendActivity + id: timeout_message + activity: + text: "\nLet's try again later... The session has reached its limit." 
+ - kind: SetVariable + id: set_timeout_result + variable: workflow.outputs.result + value: timeout diff --git a/python/samples/getting_started/workflows/human-in-the-loop/agents_with_approval_requests.py b/python/samples/getting_started/workflows/human-in-the-loop/agents_with_approval_requests.py index 6f5370edf8..a082bd6b01 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/agents_with_approval_requests.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/agents_with_approval_requests.py @@ -208,7 +208,7 @@ async def conclude_workflow( ctx: WorkflowContext[Never, str], ) -> None: """Conclude the workflow by yielding the final email response.""" - await ctx.yield_output(email_response.agent_run_response.text) + await ctx.yield_output(email_response.agent_response.text) def create_email_writer_agent() -> ChatAgent: diff --git a/python/samples/getting_started/workflows/human-in-the-loop/concurrent_request_info.py b/python/samples/getting_started/workflows/human-in-the-loop/concurrent_request_info.py index c8c4f40e41..3dfd02e6ec 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/concurrent_request_info.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/concurrent_request_info.py @@ -4,17 +4,17 @@ Sample: Request Info with ConcurrentBuilder This sample demonstrates using the `.with_request_info()` method to pause a -ConcurrentBuilder workflow AFTER all parallel agents complete but BEFORE -aggregation, allowing human review and modification of the combined results. +ConcurrentBuilder workflow for specific agents, allowing human review and +modification of individual agent outputs before aggregation. Purpose: -Show how to use the request info API that pauses after concurrent agents run, -allowing review and steering of results before they are aggregated. +Show how to use the request info API that pauses for selected concurrent agents, +allowing review and steering of their results. 
Demonstrate: -- Configuring request info with `.with_request_info()` -- Reviewing outputs from multiple concurrent agents -- Injecting human guidance after agents execute but before aggregation +- Configuring request info with `.with_request_info()` for specific agents +- Reviewing output from individual agents during concurrent execution +- Injecting human guidance for specific agents before aggregation Prerequisites: - Azure OpenAI configured for AzureOpenAIChatClient with required environment variables @@ -25,7 +25,7 @@ from typing import Any from agent_framework import ( - AgentInputRequest, + AgentRequestInfoResponse, ChatMessage, ConcurrentBuilder, RequestInfoEvent, @@ -64,7 +64,7 @@ async def aggregate_with_synthesis(results: list[AgentExecutorResponse]) -> Any: for r in results: try: - messages = getattr(r.agent_run_response, "messages", []) + messages = getattr(r.agent_response, "messages", []) final_text = messages[-1].text if messages and hasattr(messages[-1], "text") else "(no content)" expert_sections.append(f"{getattr(r, 'executor_id', 'analyst')}:\n{final_text}") @@ -131,12 +131,13 @@ async def main() -> None: ConcurrentBuilder() .participants([technical_analyst, business_analyst, user_experience_analyst]) .with_aggregator(aggregate_with_synthesis) - .with_request_info() + # Only enable request info for the technical analyst agent + .with_request_info(agents=["technical_analyst"]) .build() ) # Run the workflow with human-in-the-loop - pending_responses: dict[str, str] | None = None + pending_responses: dict[str, AgentRequestInfoResponse] | None = None workflow_complete = False print("Starting multi-perspective analysis workflow...") @@ -155,26 +156,34 @@ async def main() -> None: # Process events async for event in stream: if isinstance(event, RequestInfoEvent): - if isinstance(event.data, AgentInputRequest): - # Display pre-execution context for steering concurrent agents + if isinstance(event.data, AgentExecutorResponse): + # Display agent output for review and potential modification print("\n" + "-" * 40) - print("INPUT REQUESTED (BEFORE CONCURRENT AGENTS)") - print("-" * 40) - print(f"About to call agents: {event.data.target_agent_id}") - print("Conversation context:") - recent = ( - event.data.conversation[-2:] if len(event.data.conversation) > 2 else event.data.conversation + print("INPUT REQUESTED") + print( + f"Agent {event.source_executor_id} just responded with: '{event.data.agent_response.text}'. " + "Please provide your feedback." ) - for msg in recent: - role = msg.role.value if msg.role else "unknown" - text = (msg.text or "")[:150] - print(f" [{role}]: {text}...") print("-" * 40) - - # Get human input to steer all agents - user_input = input("Your guidance for the analysts (or 'skip' to continue): ") # noqa: ASYNC250 + if event.data.full_conversation: + print("Conversation context:") + recent = ( + event.data.full_conversation[-2:] + if len(event.data.full_conversation) > 2 + else event.data.full_conversation + ) + for msg in recent: + name = msg.author_name or msg.role.value + text = (msg.text or "")[:150] + print(f" [{name}]: {text}...") + print("-" * 40) + + # Get human input to steer this agent's contribution + user_input = input("Your guidance for this analyst (or 'skip' to approve): ") # noqa: ASYNC250 if user_input.lower() == "skip": - user_input = "Please analyze objectively from your unique perspective." 
+ user_input = AgentRequestInfoResponse.approve() + else: + user_input = AgentRequestInfoResponse.from_strings([user_input]) pending_responses = {event.request_id: user_input} print("(Resuming workflow...)") @@ -189,9 +198,8 @@ async def main() -> None: print(event.data) workflow_complete = True - elif isinstance(event, WorkflowStatusEvent): - if event.state == WorkflowRunState.IDLE: - workflow_complete = True + elif isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + workflow_complete = True if __name__ == "__main__": diff --git a/python/samples/getting_started/workflows/human-in-the-loop/group_chat_request_info.py b/python/samples/getting_started/workflows/human-in-the-loop/group_chat_request_info.py index c3a193a6a8..e3ed0ab5ff 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/group_chat_request_info.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/group_chat_request_info.py @@ -25,7 +25,9 @@ import asyncio from agent_framework import ( - AgentInputRequest, + AgentExecutorResponse, + AgentRequestInfoResponse, + AgentResponse, AgentRunUpdateEvent, ChatMessage, GroupChatBuilder, @@ -69,18 +71,17 @@ async def main() -> None: ), ) - # Manager orchestrates the discussion - manager = chat_client.create_agent( - name="manager", + # Orchestrator coordinates the discussion + orchestrator = chat_client.create_agent( + name="orchestrator", instructions=( - "You are a discussion manager coordinating a team conversation between optimist, " - "pragmatist, and creative. Your job is to select who speaks next.\n\n" + "You are a discussion manager coordinating a team conversation between participants. " + "Your job is to select who speaks next.\n\n" "RULES:\n" "1. Rotate through ALL participants - do not favor any single participant\n" "2. Each participant should speak at least once before any participant speaks twice\n" - "3. If human feedback redirects the topic, acknowledge it and continue rotating\n" - "4. Continue for at least 5 participant turns before concluding\n" - "5. Do NOT select the same participant twice in a row" + "3. Continue for at least 5 rounds before ending the discussion\n" + "4. 
Do NOT select the same participant twice in a row" ), ) @@ -88,7 +89,7 @@ async def main() -> None: - # Using agents= filter to only pause before pragmatist speaks (not every turn) + # Using agents= filter to only pause after pragmatist responds (not every turn) workflow = ( GroupChatBuilder() - .set_manager(manager=manager, display_name="Discussion Manager") + .with_agent_orchestrator(orchestrator) .participants([optimist, pragmatist, creative]) .with_max_rounds(6) - .with_request_info(agents=[pragmatist]) # Only pause before pragmatist speaks + .with_request_info(agents=[pragmatist]) # Only pause after pragmatist responds .build() ) # Run the workflow with human-in-the-loop - pending_responses: dict[str, str] | None = None + pending_responses: dict[str, AgentRequestInfoResponse] | None = None workflow_complete = False current_agent: str | None = None # Track current streaming agent @@ -130,28 +131,28 @@ async def main() -> None: elif isinstance(event, RequestInfoEvent): current_agent = None # Reset for next agent - if isinstance(event.data, AgentInputRequest): + if isinstance(event.data, AgentExecutorResponse): - # Display pre-agent context for human input + # Display the agent's latest response for human feedback print("\n" + "-" * 40) print("INPUT REQUESTED") - print(f"About to call agent: {event.data.target_agent_id}") + print(f"Agent that just responded: {event.source_executor_id}") print("-" * 40) print("Conversation context:") - recent = ( - event.data.conversation[-3:] if len(event.data.conversation) > 3 else event.data.conversation - ) + agent_response: AgentResponse = event.data.agent_response + messages: list[ChatMessage] = agent_response.messages + recent: list[ChatMessage] = messages[-3:] if len(messages) > 3 else messages # type: ignore for msg in recent: - role = msg.role.value if msg.role else "unknown" + name = msg.author_name or "unknown" text = (msg.text or "")[:100] - print(f" [{role}]: {text}...") + print(f" [{name}]: {text}...") print("-" * 40) # Get human input to steer the agent - user_input = input("Steer the discussion (or 'skip' to continue): ") # noqa: ASYNC250 + user_input = input(f"Feedback for {event.source_executor_id} (or 'skip' to approve): ") # noqa: ASYNC250 if user_input.lower() == "skip": - user_input = "Please continue the discussion naturally." - - pending_responses = {event.request_id: user_input} + pending_responses = {event.request_id: AgentRequestInfoResponse.approve()} + else: + pending_responses = {event.request_id: AgentRequestInfoResponse.from_strings([user_input])} print("(Resuming discussion...)") elif isinstance(event, WorkflowOutputEvent): @@ -160,11 +161,12 @@ async def main() -> None: print("=" * 60) print("Final conversation:") if event.data: - messages: list[ChatMessage] = event.data[-4:] + messages: list[ChatMessage] = event.data for msg in messages: - role = msg.role.value if msg.role else "unknown" + role = msg.role.value.capitalize() + name = msg.author_name or "unknown" text = (msg.text or "")[:200] - print(f"[{role}]: {text}...") + print(f"[{role}][{name}]: {text}...") workflow_complete = True elif isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: diff --git a/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py b/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py index d711861502..ec634a622b 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py @@ -103,7 +103,7 @@ async def on_agent_response( 2) Request info with a HumanFeedbackRequest as the payload. 
""" # Parse structured model output - text = result.agent_run_response.text + text = result.agent_response.text last_guess = GuessOutput.model_validate_json(text).guess # Craft a precise human prompt that defines higher and lower relative to the agent's guess. @@ -154,7 +154,7 @@ def create_guessing_agent() -> ChatAgent: "No explanations or additional text." ), # response_format enforces that the model produces JSON compatible with GuessOutput. - response_format=GuessOutput, + default_options={"response_format": GuessOutput}, ) diff --git a/python/samples/getting_started/workflows/human-in-the-loop/sequential_request_info.py b/python/samples/getting_started/workflows/human-in-the-loop/sequential_request_info.py index 55c8652984..f8a3e7ff85 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/sequential_request_info.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/sequential_request_info.py @@ -4,11 +4,11 @@ Sample: Request Info with SequentialBuilder This sample demonstrates using the `.with_request_info()` method to pause a -SequentialBuilder workflow BEFORE each agent runs, allowing external input -(e.g., human steering) before the agent responds. +SequentialBuilder workflow AFTER each agent runs, allowing external input +(e.g., human feedback) for review and optional iteration. Purpose: -Show how to use the request info API that pauses before every agent response, +Show how to use the request info API that pauses after every agent response, using the standard request_info pattern for consistency. Demonstrate: @@ -24,7 +24,8 @@ import asyncio from agent_framework import ( - AgentInputRequest, + AgentExecutorResponse, + AgentRequestInfoResponse, ChatMessage, RequestInfoEvent, SequentialBuilder, @@ -48,7 +49,7 @@ async def main() -> None: editor = chat_client.create_agent( name="editor", instructions=( - "You are an editor. Review the draft and suggest improvements. " + "You are an editor. Review the draft and make improvements. " "Incorporate any human feedback that was provided." ), ) @@ -61,11 +62,17 @@ async def main() -> None: ), ) - # Build workflow with request info enabled (pauses before each agent) - workflow = SequentialBuilder().participants([drafter, editor, finalizer]).with_request_info().build() + # Build workflow with request info enabled (pauses after each agent responds) + workflow = ( + SequentialBuilder() + .participants([drafter, editor, finalizer]) + # Only enable request info for the editor agent + .with_request_info(agents=["editor"]) + .build() + ) # Run the workflow with request info handling - pending_responses: dict[str, str] | None = None + pending_responses: dict[str, AgentRequestInfoResponse] | None = None workflow_complete = False print("Starting document review workflow...") @@ -84,26 +91,34 @@ async def main() -> None: # Process events async for event in stream: if isinstance(event, RequestInfoEvent): - if isinstance(event.data, AgentInputRequest): - # Display pre-agent context for steering + if isinstance(event.data, AgentExecutorResponse): + # Display agent response and conversation context for review print("\n" + "-" * 40) print("REQUEST INFO: INPUT REQUESTED") - print(f"About to call agent: {event.data.target_agent_id}") - print("-" * 40) - print("Conversation context:") - recent = ( - event.data.conversation[-2:] if len(event.data.conversation) > 2 else event.data.conversation + print( + f"Agent {event.source_executor_id} just responded with: '{event.data.agent_response.text}'. " + "Please provide your feedback." 
) - for msg in recent: - role = msg.role.value if msg.role else "unknown" - text = (msg.text or "")[:150] - print(f" [{role}]: {text}...") print("-" * 40) - - # Get input to steer the agent - user_input = input("Your guidance (or 'skip' to continue): ") # noqa: ASYNC250 + if event.data.full_conversation: + print("Conversation context:") + recent = ( + event.data.full_conversation[-2:] + if len(event.data.full_conversation) > 2 + else event.data.full_conversation + ) + for msg in recent: + name = msg.author_name or msg.role.value + text = (msg.text or "")[:150] + print(f" [{name}]: {text}...") + print("-" * 40) + + # Get feedback on the agent's response (approve or request iteration) + user_input = input("Your guidance (or 'skip' to approve): ") # noqa: ASYNC250 if user_input.lower() == "skip": - user_input = "Please continue naturally." + user_input = AgentRequestInfoResponse.approve() + else: + user_input = AgentRequestInfoResponse.from_strings([user_input]) pending_responses = {event.request_id: user_input} print("(Resuming workflow...)") diff --git a/python/samples/getting_started/workflows/orchestration/concurrent_custom_aggregator.py b/python/samples/getting_started/workflows/orchestration/concurrent_custom_aggregator.py index 44f71ba7bc..e45eb0c11f 100644 --- a/python/samples/getting_started/workflows/orchestration/concurrent_custom_aggregator.py +++ b/python/samples/getting_started/workflows/orchestration/concurrent_custom_aggregator.py @@ -58,7 +58,7 @@ async def summarize_results(results: list[Any]) -> str: expert_sections: list[str] = [] for r in results: try: - messages = getattr(r.agent_run_response, "messages", []) + messages = getattr(r.agent_response, "messages", []) final_text = messages[-1].text if messages and hasattr(messages[-1], "text") else "(no content)" expert_sections.append(f"{getattr(r, 'executor_id', 'expert')}:\n{final_text}") except Exception as e: diff --git a/python/samples/getting_started/workflows/orchestration/concurrent_participant_factory.py b/python/samples/getting_started/workflows/orchestration/concurrent_participant_factory.py index 435e59b2ba..01e343a2aa 100644 --- a/python/samples/getting_started/workflows/orchestration/concurrent_participant_factory.py +++ b/python/samples/getting_started/workflows/orchestration/concurrent_participant_factory.py @@ -89,7 +89,7 @@ async def summarize_results(self, results: list[Any], ctx: WorkflowContext[Never expert_sections: list[str] = [] for r in results: try: - messages = getattr(r.agent_run_response, "messages", []) + messages = getattr(r.agent_response, "messages", []) final_text = messages[-1].text if messages and hasattr(messages[-1], "text") else "(no content)" expert_sections.append(f"{getattr(r, 'executor_id', 'expert')}:\n{final_text}") except Exception as e: diff --git a/python/samples/getting_started/workflows/orchestration/group_chat_agent_manager.py b/python/samples/getting_started/workflows/orchestration/group_chat_agent_manager.py index 3bc79fcddc..12475205d3 100644 --- a/python/samples/getting_started/workflows/orchestration/group_chat_agent_manager.py +++ b/python/samples/getting_started/workflows/orchestration/group_chat_agent_manager.py @@ -1,8 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. 
import asyncio -import logging -from typing import cast from agent_framework import ( AgentRunUpdateEvent, @@ -15,8 +13,6 @@ from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential -logging.basicConfig(level=logging.INFO) - """ Sample: Group Chat with Agent-Based Manager @@ -29,50 +25,54 @@ - OpenAI environment variables configured for OpenAIChatClient """ - -def _get_chat_client() -> AzureOpenAIChatClient: - return AzureOpenAIChatClient(credential=AzureCliCredential()) - - -async def main() -> None: - # Create coordinator agent with structured output for speaker selection - # Note: response_format is enforced to ManagerSelectionResponse by set_manager() - coordinator = ChatAgent( - name="Coordinator", - description="Coordinates multi-agent collaboration by selecting speakers", - instructions=""" +ORCHESTRATOR_AGENT_INSTRUCTIONS = """ You coordinate a team conversation to solve the user's task. -Review the conversation history and select the next participant to speak. - Guidelines: - Start with Researcher to gather information - Then have Writer synthesize the final answer - Only finish after both have contributed meaningfully -- Allow for multiple rounds of information gathering if needed -""", - chat_client=_get_chat_client(), +""" + + +async def main() -> None: + # Create a chat client using Azure OpenAI and Azure CLI credentials for all agents + chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) + + # Orchestrator agent that manages the conversation + # Note: This agent (and the underlying chat client) must support structured outputs. + # The group chat workflow relies on this to parse the orchestrator's decisions. + # `response_format` is set internally by the GroupChat workflow when the agent is invoked. 
+ orchestrator_agent = ChatAgent( + name="Orchestrator", + description="Coordinates multi-agent collaboration by selecting speakers", + instructions=ORCHESTRATOR_AGENT_INSTRUCTIONS, + chat_client=chat_client, ) + # Participant agents researcher = ChatAgent( name="Researcher", description="Collects relevant background information", instructions="Gather concise facts that help a teammate answer the question.", - chat_client=_get_chat_client(), + chat_client=chat_client, ) writer = ChatAgent( name="Writer", description="Synthesizes polished answers from gathered information", instructions="Compose clear and structured answers using any notes provided.", - chat_client=_get_chat_client(), + chat_client=chat_client, ) + # Build the group chat workflow workflow = ( GroupChatBuilder() - .set_manager(coordinator, display_name="Orchestrator") - .with_termination_condition(lambda messages: sum(1 for msg in messages if msg.role == Role.ASSISTANT) >= 2) + .with_agent_orchestrator(orchestrator_agent) .participants([researcher, writer]) + # Set a hard termination condition: stop after 4 assistant messages + # The agent orchestrator should decide to end well before this limit; this condition is a safety net in case it does not + .with_termination_condition(lambda messages: sum(1 for msg in messages if msg.role == Role.ASSISTANT) >= 4) .build() ) @@ -82,30 +82,35 @@ async def main() -> None: print(f"TASK: {task}\n") print("=" * 80) - final_conversation: list[ChatMessage] = [] + # Keep track of the last executor to format output nicely in streaming mode last_executor_id: str | None = None + output_event: WorkflowOutputEvent | None = None async for event in workflow.run_stream(task): if isinstance(event, AgentRunUpdateEvent): eid = event.executor_id if eid != last_executor_id: if last_executor_id is not None: - print() + print("\n") print(f"{eid}:", end=" ", flush=True) last_executor_id = eid print(event.data, end="", flush=True) elif isinstance(event, WorkflowOutputEvent): - final_conversation = cast(list[ChatMessage], event.data) - - if final_conversation and isinstance(final_conversation, list): - print("\n\n" + "=" * 80) - print("FINAL CONVERSATION") - print("=" * 80) - for msg in final_conversation: - author = getattr(msg, "author_name", "Unknown") - text = getattr(msg, "text", str(msg)) - print(f"\n[{author}]") - print(text) - print("-" * 80) + output_event = event + + # The output of the workflow is the full list of messages exchanged + if output_event: + if not isinstance(output_event.data, list) or not all( + isinstance(msg, ChatMessage) + for msg in output_event.data # type: ignore + ): + raise RuntimeError("Unexpected output event data format.") + print("\n" + "=" * 80) + print("\nFINAL OUTPUT (The conversation history)\n") + for msg in output_event.data: # type: ignore + assert isinstance(msg, ChatMessage) + print(f"{msg.author_name or msg.role}: {msg.text}\n") + else: + raise RuntimeError("Workflow did not produce a final output event.") if __name__ == "__main__": diff --git a/python/samples/getting_started/workflows/orchestration/group_chat_philosophical_debate.py b/python/samples/getting_started/workflows/orchestration/group_chat_philosophical_debate.py index 7059a84e32..a26b9df4d0 100644 --- a/python/samples/getting_started/workflows/orchestration/group_chat_philosophical_debate.py +++ b/python/samples/getting_started/workflows/orchestration/group_chat_philosophical_debate.py @@ -211,7 +211,7 @@ async def main() -> None: workflow = ( GroupChatBuilder() - .set_manager(moderator, display_name="Moderator") + 
.with_agent_orchestrator(moderator) .participants([farmer, developer, teacher, activist, spiritual_leader, artist, immigrant, doctor]) .with_termination_condition(lambda messages: sum(1 for msg in messages if msg.role == Role.ASSISTANT) >= 10) .build() @@ -241,13 +241,11 @@ async def main() -> None: async for event in workflow.run_stream(f"Please begin the discussion on: {topic}"): if isinstance(event, AgentRunUpdateEvent): - speaker_id = event.executor_id.replace("groupchat_agent:", "") - - if speaker_id != current_speaker: + if event.executor_id != current_speaker: if current_speaker is not None: print("\n") - print(f"[{speaker_id}]", flush=True) - current_speaker = speaker_id + print(f"[{event.executor_id}]", flush=True) + current_speaker = event.executor_id print(event.data, end="", flush=True) @@ -286,10 +284,6 @@ async def main() -> None: DISCUSSION BEGINS ================================================================================ - [Moderator] - {"selected_participant":"Farmer","instruction":"Please start by sharing what living a good life means to you, - especially from your perspective living in a rural area in Southeast Asia.","finish":false,"final_message":null} - [Farmer] To me, a good life is deeply intertwined with the rhythm of the land and the nurturing of relationships with my family and community. It means cultivating crops that respect our environment, ensuring sustainability for future @@ -298,11 +292,6 @@ async def main() -> None: wealth. It's the simple moments, like sharing stories with my children under the stars, that truly define a good life. What good is progress if it isolates us from those we love and the land that sustains us? - [Moderator] - {"selected_participant":"Developer","instruction":"Given the insights shared by the Farmer, please discuss what a - good life means to you as a software developer in an urban setting in the United States and how it might contrast - with or complement the Farmer's view.","finish":false,"final_message":null} - [Developer] As a software developer in an urban environment, a good life for me hinges on the intersection of innovation, creativity, and balance. It's about having the freedom to explore new technologies that can solve real-world @@ -312,11 +301,6 @@ async def main() -> None: rich personal experiences. The challenge is finding harmony between technological progress and preserving the intimate human connections that truly enrich our lives. - [Moderator] - {"selected_participant":"SpiritualLeader","instruction":"Reflect on both the Farmer's and Developer's perspectives - and share your view of what constitutes a good life, particularly from your spiritual and cultural standpoint in - the Middle East.","finish":false,"final_message":null} - [SpiritualLeader] From my spiritual perspective, a good life embodies a balance between personal fulfillment and service to others, rooted in compassion and community. In our teachings, we emphasize that true happiness comes from helping those in @@ -326,11 +310,6 @@ async def main() -> None: with those around us. Ultimately, as we align our personal beliefs with our communal responsibilities, we cultivate a richness that transcends material wealth. 
- [Moderator] - {"selected_participant":"Activist","instruction":"Add to the discussion by sharing your perspective on what a good - life entails, particularly from your background as a young activist in South America.","finish":false, - "final_message":null} - [Activist] As a young activist in South America, a good life for me is about advocating for social justice and environmental sustainability. It means living in a society where everyone's rights are respected and where marginalized voices, @@ -341,11 +320,6 @@ async def main() -> None: not just lived for oneself but is deeply tied to the well-being of our communities and the health of our environment. How can we, regardless of our backgrounds, collaborate to foster these essential changes? - [Moderator] - {"selected_participant":"Teacher","instruction":"Considering the views shared so far, tell us how your experience - as a retired history teacher from Eastern Europe shapes your understanding of a good life, perhaps reflecting on - lessons from the past and their impact on present-day life choices.","finish":false,"final_message":null} - [Teacher] As a retired history teacher from Eastern Europe, my understanding of a good life is deeply rooted in the lessons drawn from history and the struggle for freedom and dignity. Historical events, such as the fall of the Iron @@ -357,11 +331,6 @@ async def main() -> None: contributions to the rich tapestry of our shared humanity. How can we ensure that the lessons of history inform a more compassionate and just society moving forward? - [Moderator] - {"selected_participant":"Artist","instruction":"Expound on the themes and perspectives discussed so far by sharing - how, as an artist from Africa, you define a good life and how art plays a role in that vision.","finish":false, - "final_message":null} - [Artist] As an artist from Africa, I define a good life as one steeped in cultural expression, storytelling, and the celebration of our collective memories. Art is a powerful medium through which we capture our histories, struggles, @@ -373,19 +342,6 @@ async def main() -> None: collective good, fostering empathy and understanding among diverse communities. How can we harness art to bridge differences and amplify marginalized voices in our pursuit of a good life? - [Moderator] - {"selected_participant":null,"instruction":null,"finish":true,"final_message":"As our discussion unfolds, several - key themes have gracefully emerged, reflecting the richness of diverse perspectives on what constitutes a good life. - From the rural farmer's integration with the land to the developer's search for balance between technology and - personal connection, each viewpoint validates that fulfillment, at its core, transcends material wealth. The - spiritual leader and the activist highlight the importance of community and social justice, while the history - teacher and the artist remind us of the lessons and narratives that shape our cultural and personal identities. - - Ultimately, the good life seems to revolve around meaningful relationships, honoring our legacies while striving for - progress, and nurturing both our inner selves and external communities. 
This dialogue demonstrates that despite our
-    varied backgrounds and experiences, the quest for a good life binds us together, urging cooperation and empathy in
-    our shared human journey."}
-
     ================================================================================
     DISCUSSION SUMMARY
     ================================================================================
diff --git a/python/samples/getting_started/workflows/orchestration/group_chat_simple_selector.py b/python/samples/getting_started/workflows/orchestration/group_chat_simple_selector.py
index 1fd074ca4d..517ae313f3 100644
--- a/python/samples/getting_started/workflows/orchestration/group_chat_simple_selector.py
+++ b/python/samples/getting_started/workflows/orchestration/group_chat_simple_selector.py
@@ -1,113 +1,134 @@
 # Copyright (c) Microsoft. All rights reserved.
 
 import asyncio
-import logging
-from typing import cast
 
-from agent_framework import ChatAgent, ChatMessage, GroupChatBuilder, GroupChatStateSnapshot, WorkflowOutputEvent
-from agent_framework.openai import OpenAIChatClient
-
-logging.basicConfig(level=logging.INFO)
+from agent_framework import (
+    AgentRunUpdateEvent,
+    ChatAgent,
+    ChatMessage,
+    GroupChatBuilder,
+    GroupChatState,
+    WorkflowOutputEvent,
+)
+from agent_framework.azure import AzureOpenAIChatClient
+from azure.identity import AzureCliCredential
 
 """
-Sample: Group Chat with Simple Speaker Selector Function
+Sample: Group Chat with a round-robin speaker selector
 
 What it does:
-- Demonstrates the set_select_speakers_func() API for GroupChat orchestration
+- Demonstrates the with_select_speaker_func() API for GroupChat orchestration
 - Uses a pure Python function to control speaker selection based on conversation state
-- Alternates between researcher and writer agents in a simple round-robin pattern
-- Shows how to access conversation history, round index, and participant metadata
-
-Key pattern:
-    def select_next_speaker(state: GroupChatStateSnapshot) -> str | None:
-        # state contains: task, participants, conversation, history, round_index
-        # Return participant name to continue, or None to finish
-        ...
 
 Prerequisites:
-- OpenAI environment variables configured for OpenAIChatClient
+- `az login` (Azure CLI authentication)
+- Environment variables configured for AzureOpenAIChatClient (AZURE_OPENAI_ENDPOINT, etc.)
"""
 
 
-def select_next_speaker(state: GroupChatStateSnapshot) -> str | None:
-    """Simple speaker selector that alternates between researcher and writer.
-
-    This function demonstrates the core pattern:
-    1. Examine the current state of the group chat
-    2. Decide who should speak next
-    3. Return participant name or None to finish
-
-    Args:
-        state: Immutable snapshot containing:
-            - task: ChatMessage - original user task
-            - participants: dict[str, str] - participant names → descriptions
-            - conversation: tuple[ChatMessage, ...] - full conversation history
-            - history: tuple[GroupChatTurn, ...]
 - turn-by-turn with speaker attribution
-            - round_index: int - number of selection rounds so far
-            - pending_agent: str | None - currently active agent (if any)
-
-    Returns:
-        Name of next speaker, or None to finish the conversation
-    """
-    round_idx = state["round_index"]
-    history = state["history"]
+def round_robin_selector(state: GroupChatState) -> str:
+    """A round-robin selector function that picks the next speaker based on the current round index."""
 
-    # Finish after 4 turns (researcher → writer → researcher → writer)
-    if round_idx >= 4:
-        return None
+    participant_names = list(state.participants.keys())
+    return participant_names[state.current_round % len(participant_names)]
 
-    # Get the last speaker from history
-    last_speaker = history[-1].speaker if history else None
 
-    # Simple alternation: researcher → writer → researcher → writer
-    if last_speaker == "Researcher":
-        return "Writer"
-    return "Researcher"
+async def main() -> None:
+    # Create a chat client using Azure OpenAI and Azure CLI credentials for all agents
+    chat_client = AzureOpenAIChatClient(credential=AzureCliCredential())
+
+    # Participant agents
+    expert = ChatAgent(
+        name="PythonExpert",
+        instructions=(
+            "You are an expert in Python in a workgroup. "
+            "Your job is to answer Python related questions and refine your answer "
+            "based on feedback from all the other participants."
+        ),
+        chat_client=chat_client,
+    )
 
+    verifier = ChatAgent(
+        name="AnswerVerifier",
+        instructions=(
+            "You are a programming expert in a workgroup. "
+            f"Your job is to review the answer provided by {expert.name} and point "
+            "out statements that are technically true but practically dangerous. "
+            "If there is nothing worth pointing out, respond with 'The answer looks good to me.'"
+        ),
+        chat_client=chat_client,
+    )
 
-async def main() -> None:
-    researcher = ChatAgent(
-        name="Researcher",
-        description="Collects relevant background information.",
-        instructions="Gather concise facts that help answer the question. Be brief.",
-        chat_client=OpenAIChatClient(model_id="gpt-4o-mini"),
+    clarifier = ChatAgent(
+        name="AnswerClarifier",
+        instructions=(
+            "You are an accessibility expert in a workgroup. "
+            f"Your job is to review the answer provided by {expert.name} and point "
+            "out jargon or complex terms that may be difficult for a beginner to understand. "
+            "If there is nothing worth pointing out, respond with 'The answer looks clear to me.'"
+        ),
+        chat_client=chat_client,
     )
 
-    writer = ChatAgent(
-        name="Writer",
-        description="Synthesizes a polished answer using the gathered notes.",
-        instructions="Compose a clear, structured answer using any notes provided.",
-        chat_client=OpenAIChatClient(model_id="gpt-4o-mini"),
+    skeptic = ChatAgent(
+        name="Skeptic",
+        instructions=(
+            "You are a devil's advocate in a workgroup. "
+            f"Your job is to review the answer provided by {expert.name} and point "
+            "out caveats, exceptions, and alternative perspectives. "
+            "If there is nothing worth pointing out, respond with 'I have no further questions.'"
+        ),
+        chat_client=chat_client,
    )
 
-    # Two ways to specify participants:
-    # 1. List form - uses agent.name attribute: .participants([researcher, writer])
-    # 2.
 Dict form - explicit names: .participants(researcher=researcher, writer=writer)
+    # Build the group chat workflow
     workflow = (
         GroupChatBuilder()
-        .set_select_speakers_func(select_next_speaker, display_name="Orchestrator")
-        .participants([researcher, writer])  # Uses agent.name for participant names
+        .participants([expert, verifier, clarifier, skeptic])
+        .with_select_speaker_func(round_robin_selector)
+        # Set a hard termination condition: stop after 6 messages (user task + one full round + 1).
+        # One round is expert -> verifier -> clarifier -> skeptic, after which the expert gets to respond again.
+        # This will end the conversation after the expert has spoken 2 times (one full iteration of the loop).
+        # Note: it's possible that the expert gets it right the first time and the other participants
+        # have nothing to add, but for demo purposes we want to see at least one full round of interaction.
+        .with_termination_condition(lambda conversation: len(conversation) >= 6)
         .build()
     )
 
-    task = "What are the key benefits of using async/await in Python?"
+    task = "How does Python’s Protocol differ from abstract base classes?"
 
-    print("\nStarting Group Chat with Simple Speaker Selector...\n")
+    print("\nStarting Group Chat with round-robin speaker selector...\n")
     print(f"TASK: {task}\n")
     print("=" * 80)
 
+    # Keep track of the last executor to format output nicely in streaming mode
+    last_executor_id: str | None = None
+    output_event: WorkflowOutputEvent | None = None
     async for event in workflow.run_stream(task):
-        if isinstance(event, WorkflowOutputEvent):
-            conversation = cast(list[ChatMessage], event.data)
-            if isinstance(conversation, list):
-                print("\n===== Final Conversation =====\n")
-                for msg in conversation:
-                    author = getattr(msg, "author_name", "Unknown")
-                    text = getattr(msg, "text", str(msg))
-                    print(f"[{author}]\n{text}\n")
-                    print("-" * 80)
-
-    print("\nWorkflow completed.")
+        if isinstance(event, AgentRunUpdateEvent):
+            eid = event.executor_id
+            if eid != last_executor_id:
+                if last_executor_id is not None:
+                    print("\n")
+                print(f"{eid}:", end=" ", flush=True)
+                last_executor_id = eid
+            print(event.data, end="", flush=True)
+        elif isinstance(event, WorkflowOutputEvent):
+            output_event = event
+
+    # The output of the workflow is the full list of messages exchanged
+    if output_event:
+        if not isinstance(output_event.data, list) or not all(
+            isinstance(msg, ChatMessage)
+            for msg in output_event.data  # type: ignore
+        ):
+            raise RuntimeError("Unexpected output event data format.")
+        print("\n" + "=" * 80)
+        print("\nFINAL OUTPUT (The conversation history)\n")
+        for msg in output_event.data:  # type: ignore
+            assert isinstance(msg, ChatMessage)
+            print(f"{msg.author_name or msg.role}: {msg.text}\n")
+    else:
+        raise RuntimeError("Workflow did not produce a final output event.")
 
 
 if __name__ == "__main__":
diff --git a/python/samples/getting_started/workflows/orchestration/handoff_autonomous.py b/python/samples/getting_started/workflows/orchestration/handoff_autonomous.py
index 154f768d09..2d0542a0fb 100644
--- a/python/samples/getting_started/workflows/orchestration/handoff_autonomous.py
+++ b/python/samples/getting_started/workflows/orchestration/handoff_autonomous.py
@@ -5,7 +5,7 @@
 from typing import cast
 
 from agent_framework import (
-    AgentRunResponseUpdate,
+    AgentResponseUpdate,
     AgentRunUpdateEvent,
     ChatAgent,
     ChatMessage,
@@ -13,6 +13,7 @@
     HostedWebSearchTool,
     WorkflowEvent,
     WorkflowOutputEvent,
+    resolve_agent_id,
 )
 from agent_framework.azure import AzureOpenAIChatClient
 from azure.identity
 import AzureCliCredential
 
@@ -21,7 +22,7 @@
 
 """Sample: Autonomous handoff workflow with agent iteration.
 
-This sample demonstrates `with_interaction_mode("autonomous")`, where agents continue
+This sample demonstrates `.with_autonomous_mode()`, where agents continue
 iterating on their task until they explicitly invoke a handoff tool. This allows
 specialists to perform long-running autonomous work (research, coding, analysis)
 without prematurely returning control to the coordinator or user.
@@ -35,7 +36,7 @@
 
 Key Concepts:
     - Autonomous interaction mode: agents iterate until they handoff
-    - Turn limits: use `with_interaction_mode("autonomous", autonomous_turn_limit=N)` to cap total iterations
+    - Turn limits: use `.with_autonomous_mode(turn_limits={agent_name: N})` to cap iterations per agent
 """
 
 
@@ -53,7 +54,7 @@ def create_agents(
 
     research_agent = chat_client.create_agent(
         instructions=(
-            "You are a research specialist that explores topics thoroughly on the Microsoft Learn Site."
+            "You are a research specialist that explores topics thoroughly using web search. "
             "When given a research task, break it down into multiple aspects and explore each one. "
             "Continue your research across multiple responses - don't try to finish everything in one "
             "response. After each response, think about what else needs to be explored. When you have "
@@ -81,7 +82,7 @@ def create_agents(
 def _display_event(event: WorkflowEvent) -> None:
     """Print the final conversation snapshot from workflow output events."""
     if isinstance(event, AgentRunUpdateEvent) and event.data:
-        update: AgentRunResponseUpdate = event.data
+        update: AgentResponseUpdate = event.data
         if not update.text:
             return
         global last_response_id
@@ -112,11 +113,21 @@ async def main() -> None:
             name="autonomous_iteration_handoff",
             participants=[coordinator, research_agent, summary_agent],
         )
-        .set_coordinator(coordinator)
+        .with_start_agent(coordinator)
         .add_handoff(coordinator, [research_agent, summary_agent])
-        .add_handoff(research_agent, coordinator)  # Research can hand back to coordinator
-        .add_handoff(summary_agent, coordinator)
-        .with_interaction_mode("autonomous", autonomous_turn_limit=15)
+        .add_handoff(research_agent, [coordinator])  # Research can hand back to coordinator
+        .add_handoff(summary_agent, [coordinator])
+        .with_autonomous_mode(
+            # You can set turn limits per agent to allow some agents to go longer.
+            # If a limit is not set, the agent will get a default limit of 50.
+            # Internally, handoff prefers agent names as the agent identifiers if set.
+            # Otherwise, it falls back to agent IDs.
+            turn_limits={
+                resolve_agent_id(coordinator): 5,
+                resolve_agent_id(research_agent): 10,
+                resolve_agent_id(summary_agent): 5,
+            }
+        )
         .with_termination_condition(
             # Terminate after coordinator provides 5 assistant responses
             lambda conv: sum(1 for msg in conv if msg.author_name == "coordinator" and msg.role.value == "assistant")
@@ -133,10 +144,10 @@ async def main() -> None:
 
     """
     Expected behavior:
     - Coordinator routes to research_agent.
-    - Research agent iterates multiple times, exploring different aspects of renewable energy.
+    - Research agent iterates multiple times, exploring different aspects of Microsoft Agent Framework.
     - Each iteration adds to the conversation without returning to coordinator.
     - After thorough research, research_agent calls handoff to coordinator.
-    - Coordinator provides final summary.
+    - Coordinator routes to summary_agent for final summary.
In autonomous mode, agents continue working until they invoke a handoff tool, allowing the research_agent to perform 3-4+ responses before handing off. diff --git a/python/samples/getting_started/workflows/orchestration/handoff_participant_factory.py b/python/samples/getting_started/workflows/orchestration/handoff_participant_factory.py index 1b676c5ffd..4330cc1ace 100644 --- a/python/samples/getting_started/workflows/orchestration/handoff_participant_factory.py +++ b/python/samples/getting_started/workflows/orchestration/handoff_participant_factory.py @@ -2,28 +2,30 @@ import asyncio import logging -from collections.abc import AsyncIterable -from typing import cast +from typing import Annotated, cast from agent_framework import ( + AgentResponse, + AgentRunEvent, ChatAgent, ChatMessage, + HandoffAgentUserRequest, HandoffBuilder, - HandoffUserInputRequest, + HandoffSentEvent, RequestInfoEvent, - Role, Workflow, WorkflowEvent, WorkflowOutputEvent, + WorkflowRunState, + WorkflowStatusEvent, ai_function, ) from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential -from typing import Annotated logging.basicConfig(level=logging.ERROR) -"""Sample: Autonomous handoff workflow with agent factory. +"""Sample: Handoff workflow with participant factories for state isolation. This sample demonstrates how to use participant factories in HandoffBuilder to create agents dynamically. @@ -33,7 +35,7 @@ requests or tasks in parallel with stateful participants. Routing Pattern: - User -> Coordinator -> Specialist (iterates N times) -> Handoff -> Final Output + User -> Triage Agent -> Specialist (Refund/Order Status/Return) -> User Prerequisites: - `az login` (Azure CLI authentication) @@ -41,6 +43,7 @@ Key Concepts: - Participant factories: create agents via factory functions for isolation + - State isolation: each workflow instance gets its own agent instances """ @@ -103,21 +106,6 @@ def create_return_agent() -> ChatAgent: ) -async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]: - """Collect all events from an async stream into a list. - - This helper drains the workflow's event stream so we can process events - synchronously after each workflow step completes. - - Args: - stream: Async iterable of WorkflowEvent - - Returns: - List of all events from the stream - """ - return [event async for event in stream] - - def _handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: """Process workflow events and extract any pending user input requests. 
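# --- Illustration (not part of the patch): the docstring above describes participant
# factories in HandoffBuilder. A minimal sketch of the idea follows; the
# `participant_factories` keyword is inferred from this sample's context and the
# `create_triage_agent` factory mirrors the factories this file defines, so treat the
# exact builder signature as an assumption based on this diff rather than documented API.
# Because the builder stores callables instead of agent instances, every build() call
# constructs fresh agents, so workflows built from the same builder share no state:
#
#     from agent_framework import ChatAgent, HandoffBuilder
#     from agent_framework.azure import AzureOpenAIChatClient
#     from azure.identity import AzureCliCredential
#
#     def create_triage_agent() -> ChatAgent:
#         # Called on every build(), so each workflow gets its own agent instance.
#         chat_client = AzureOpenAIChatClient(credential=AzureCliCredential())
#         return chat_client.create_agent(name="triage_agent", instructions="Route requests.")
#
#     builder = HandoffBuilder(participant_factories={"triage": create_triage_agent})
#     workflow_a = builder.build()  # owns its own triage agent
#     workflow_b = builder.build()  # independent state: no shared history with workflow_a
# ---------------------------------------------------------------------------------------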
@@ -136,75 +124,98 @@ def _handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: requests: list[RequestInfoEvent] = [] for event in events: + # AgentRunEvent: Contains messages generated by agents during their turn + if isinstance(event, AgentRunEvent): + for message in event.data.messages: + if not message.text: + # Skip messages without text (e.g., tool calls) + continue + speaker = message.author_name or message.role.value + print(f"- {speaker}: {message.text}") + + # HandoffSentEvent: Indicates a handoff has been initiated + if isinstance(event, HandoffSentEvent): + print(f"\n[Handoff from {event.source} to {event.target} initiated.]") + + # WorkflowStatusEvent: Indicates workflow state changes + if isinstance(event, WorkflowStatusEvent) and event.state in { + WorkflowRunState.IDLE, + WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, + }: + print(f"\n[Workflow Status] {event.state.name}") + # WorkflowOutputEvent: Contains the final conversation when workflow terminates - if isinstance(event, WorkflowOutputEvent): + elif isinstance(event, WorkflowOutputEvent): conversation = cast(list[ChatMessage], event.data) if isinstance(conversation, list): print("\n=== Final Conversation Snapshot ===") for message in conversation: speaker = message.author_name or message.role.value - print(f"- {speaker}: {message.text}") + print(f"- {speaker}: {message.text or [content.type for content in message.contents]}") print("===================================") # RequestInfoEvent: Workflow is requesting user input elif isinstance(event, RequestInfoEvent): - if isinstance(event.data, HandoffUserInputRequest): - _print_agent_responses_since_last_user_message(event.data) + if isinstance(event.data, HandoffAgentUserRequest): + _print_handoff_agent_user_request(event.data.agent_response) requests.append(event) return requests -def _print_agent_responses_since_last_user_message(request: HandoffUserInputRequest) -> None: - """Display agent responses since the last user message in a handoff request. +def _print_handoff_agent_user_request(response: AgentResponse) -> None: + """Display the agent's response messages when requesting user input. - The HandoffUserInputRequest contains the full conversation history so far, - allowing the user to see what's been discussed before providing their next input. + This will happen when an agent generates a response that doesn't trigger + a handoff, i.e., the agent is asking the user for more information. 
     Args:
-        request: The user input request containing conversation and prompt
+        response: The AgentResponse from the agent requesting user input
     """
-    if not request.conversation:
-        raise RuntimeError("HandoffUserInputRequest missing conversation history.")
-
-    # Reverse iterate to collect agent responses since last user message
-    agent_responses: list[ChatMessage] = []
-    for message in request.conversation[::-1]:
-        if message.role == Role.USER:
-            break
-        agent_responses.append(message)
-
-    # Print agent responses in original order
-    agent_responses.reverse()
-    for message in agent_responses:
+    if not response.messages:
+        raise RuntimeError("Cannot print agent responses: response has no messages.")
+
+    print("\n[Agent is requesting your input...]")
+
+    # Print agent responses
+    for message in response.messages:
+        if not message.text:
+            # Skip messages without text (e.g., tool calls)
+            continue
         speaker = message.author_name or message.role.value
         print(f"- {speaker}: {message.text}")
 
 
-async def _run_Workflow(workflow: Workflow, user_inputs: list[str]) -> None:
+async def _run_workflow(workflow: Workflow, user_inputs: list[str]) -> None:
     """Run the workflow with the given user input and display events."""
     print(f"- User: {user_inputs[0]}")
-    events = await _drain(workflow.run_stream(user_inputs[0]))
-    pending_requests = _handle_events(events)
+    workflow_result = await workflow.run(user_inputs[0])
+    pending_requests = _handle_events(workflow_result)
 
     # Process the request/response cycle
     # The workflow will continue requesting input until:
-    # 1. The termination condition is met (4 user messages in this case), OR
+    # 1. The termination condition is met, OR
     # 2. We run out of scripted responses
-    while pending_requests and user_inputs[1:]:
-        # Get the next scripted response
-        user_response = user_inputs.pop(1)
-        print(f"\n- User: {user_response}")
-
-        # Send response(s) to all pending requests
-        # In this demo, there's typically one request per cycle, but the API supports multiple
-        responses = {req.request_id: user_response for req in pending_requests}
+    while pending_requests:
+        if user_inputs[1:]:
+            # Get the next scripted response
+            user_response = user_inputs.pop(1)
+            print(f"\n- User: {user_response}")
+
+            # Send response(s) to all pending requests
+            # In this demo, there's typically one request per cycle, but the API supports multiple
+            responses = {
+                req.request_id: HandoffAgentUserRequest.create_response(user_response) for req in pending_requests
+            }
+        else:
+            # No more scripted responses; terminate the workflow
+            responses = {req.request_id: HandoffAgentUserRequest.terminate() for req in pending_requests}
 
         # Send responses and get new events
-        # We use send_responses_streaming() to get events as they occur, allowing us to
-        # display agent responses in real-time and handle new requests as they arrive
-        events = await _drain(workflow.send_responses_streaming(responses))
-        pending_requests = _handle_events(events)
+        # We use send_responses() to get events from the workflow, allowing us to
+        # display agent responses and handle new requests as they arrive
+        workflow_result = await workflow.send_responses(responses)
+        pending_requests = _handle_events(workflow_result)
 
 
 async def main() -> None:
@@ -220,7 +231,7 @@ async def main() -> None:
             "return": create_return_agent,
         },
     )
-    .with_start_agent("triage")
+    .with_start_agent("triage")
     .with_termination_condition(
         # Custom termination: Check if the triage agent has provided a closing message.
# This looks for the last message being from triage_agent and containing "welcome", @@ -244,14 +255,14 @@ async def main() -> None: workflow_a = workflow_builder.build() print("=== Running workflow_a ===") - await _run_Workflow(workflow_a, list(user_inputs)) + await _run_workflow(workflow_a, list(user_inputs)) workflow_b = workflow_builder.build() print("=== Running workflow_b ===") # Only provide the last two inputs to workflow_b to demonstrate state isolation # The agents in this workflow have no prior context thus should not have knowledge of # order 1234 or previous interactions. - await _run_Workflow(workflow_b, user_inputs[2:]) + await _run_workflow(workflow_b, user_inputs[2:]) """ Expected behavior: - workflow_a and workflow_b maintain separate states for their participants. diff --git a/python/samples/getting_started/workflows/orchestration/handoff_return_to_previous.py b/python/samples/getting_started/workflows/orchestration/handoff_return_to_previous.py deleted file mode 100644 index 8f859bfb0f..0000000000 --- a/python/samples/getting_started/workflows/orchestration/handoff_return_to_previous.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import asyncio -from collections.abc import AsyncIterable -from typing import cast - -from agent_framework import ( - ChatAgent, - HandoffBuilder, - HandoffUserInputRequest, - RequestInfoEvent, - WorkflowEvent, - WorkflowOutputEvent, -) -from agent_framework.azure import AzureOpenAIChatClient -from azure.identity import AzureCliCredential - -"""Sample: Handoff workflow with return-to-previous routing enabled. - -This interactive sample demonstrates the return-to-previous feature where user inputs -route directly back to the specialist currently handling their request, rather than -always going through the coordinator for re-evaluation. - -Routing Pattern (with return-to-previous enabled): - User -> Coordinator -> Technical Support -> User -> Technical Support -> ... - -Routing Pattern (default, without return-to-previous): - User -> Coordinator -> Technical Support -> User -> Coordinator -> Technical Support -> ... - -This is useful when a specialist needs multiple turns with the user to gather -information or resolve an issue, avoiding unnecessary coordinator involvement. - -Specialist-to-Specialist Handoff: - When a user's request changes to a topic outside the current specialist's domain, - the specialist can hand off DIRECTLY to another specialist without going back through - the coordinator: - - User -> Coordinator -> Technical Support -> User -> Technical Support (billing question) - -> Billing -> User -> Billing ... - -Example Interaction: - 1. User reports a technical issue - 2. Coordinator routes to technical support specialist - 3. Technical support asks clarifying questions - 4. User provides details (routes directly back to technical support) - 5. Technical support continues troubleshooting with full context - 6. Issue resolved, user asks about billing - 7. Technical support hands off DIRECTLY to billing specialist - 8. Billing specialist helps with payment - 9. User continues with billing (routes directly to billing) - -Prerequisites: - - `az login` (Azure CLI authentication) - - Environment variables configured for AzureOpenAIChatClient (AZURE_OPENAI_ENDPOINT, etc.) - -Usage: - Run the script and interact with the support workflow by typing your requests. - Type 'exit' or 'quit' to end the conversation. 
- -Key Concepts: - - Return-to-previous: Direct routing to current agent handling the conversation - - Current agent tracking: Framework remembers which agent is actively helping the user - - Context preservation: Specialist maintains full conversation context - - Domain switching: Specialists can hand back to coordinator when topic changes -""" - - -def create_agents(chat_client: AzureOpenAIChatClient) -> tuple[ChatAgent, ChatAgent, ChatAgent, ChatAgent]: - """Create and configure the coordinator and specialist agents. - - Returns: - Tuple of (coordinator, technical_support, account_specialist, billing_agent) - """ - coordinator = chat_client.create_agent( - instructions=( - "You are a customer support coordinator. Analyze the user's request and route to " - "the appropriate specialist:\n" - "- technical_support for technical issues, troubleshooting, repairs, hardware/software problems\n" - "- account_specialist for account changes, profile updates, settings, login issues\n" - "- billing_agent for payments, invoices, refunds, charges, billing questions\n" - "\n" - "When you receive a request, immediately call the matching handoff tool without explaining. " - "Read the most recent user message to determine the correct specialist." - ), - name="coordinator", - ) - - technical_support = chat_client.create_agent( - instructions=( - "You provide technical support. Help users troubleshoot technical issues, " - "arrange repairs, and answer technical questions. " - "Gather information through conversation. " - "If the user asks about billing, payments, invoices, or refunds, hand off to billing_agent. " - "If the user asks about account settings or profile changes, hand off to account_specialist." - ), - name="technical_support", - ) - - account_specialist = chat_client.create_agent( - instructions=( - "You handle account management. Help with profile updates, account settings, " - "and preferences. Gather information through conversation. " - "If the user asks about technical issues or troubleshooting, hand off to technical_support. " - "If the user asks about billing, payments, invoices, or refunds, hand off to billing_agent." - ), - name="account_specialist", - ) - - billing_agent = chat_client.create_agent( - instructions=( - "You handle billing only. Process payments, explain invoices, handle refunds. " - "If the user asks about technical issues or troubleshooting, hand off to technical_support. " - "If the user asks about account settings or profile changes, hand off to account_specialist." 
- ), - name="billing_agent", - ) - - return coordinator, technical_support, account_specialist, billing_agent - - -def handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: - """Process events and return pending input requests.""" - pending_requests: list[RequestInfoEvent] = [] - for event in events: - if isinstance(event, RequestInfoEvent): - pending_requests.append(event) - request_data = cast(HandoffUserInputRequest, event.data) - print(f"\n{'=' * 60}") - print(f"AWAITING INPUT FROM: {request_data.awaiting_agent_id.upper()}") - print(f"{'=' * 60}") - for msg in request_data.conversation[-3:]: - author = msg.author_name or msg.role.value - prefix = ">>> " if author == request_data.awaiting_agent_id else " " - print(f"{prefix}[{author}]: {msg.text}") - elif isinstance(event, WorkflowOutputEvent): - print(f"\n{'=' * 60}") - print("[WORKFLOW COMPLETE]") - print(f"{'=' * 60}") - return pending_requests - - -async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]: - """Drain an async iterable into a list.""" - events: list[WorkflowEvent] = [] - async for event in stream: - events.append(event) - return events - - -async def main() -> None: - """Demonstrate return-to-previous routing in a handoff workflow.""" - chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) - coordinator, technical, account, billing = create_agents(chat_client) - - print("Handoff Workflow with Return-to-Previous Routing") - print("=" * 60) - print("\nThis interactive demo shows how user inputs route directly") - print("to the specialist handling your request, avoiding unnecessary") - print("coordinator re-evaluation on each turn.") - print("\nSpecialists can hand off directly to other specialists when") - print("your request changes topics (e.g., from technical to billing).") - print("\nType 'exit' or 'quit' to end the conversation.\n") - - # Configure handoffs with return-to-previous enabled - # Specialists can hand off directly to other specialists when topic changes - workflow = ( - HandoffBuilder( - name="return_to_previous_demo", - participants=[coordinator, technical, account, billing], - ) - .set_coordinator(coordinator) - .add_handoff(coordinator, [technical, account, billing]) # Coordinator routes to all specialists - .add_handoff(technical, [billing, account]) # Technical can route to billing or account - .add_handoff(account, [technical, billing]) # Account can route to technical or billing - .add_handoff(billing, [technical, account]) # Billing can route to technical or account - .enable_return_to_previous(True) # Enable the `return to previous handoff` feature - .with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role.value == "user") >= 10) - .build() - ) - - # Get initial user request - initial_request = input("You: ").strip() # noqa: ASYNC250 - if not initial_request or initial_request.lower() in ["exit", "quit"]: - print("Goodbye!") - return - - # Start workflow with initial message - events = await _drain(workflow.run_stream(initial_request)) - pending_requests = handle_events(events) - - # Interactive loop: keep prompting for user input - while pending_requests: - user_input = input("\nYou: ").strip() # noqa: ASYNC250 - - if not user_input or user_input.lower() in ["exit", "quit"]: - print("\nEnding conversation. 
Goodbye!") - break - - responses = {req.request_id: user_input for req in pending_requests} - events = await _drain(workflow.send_responses_streaming(responses)) - pending_requests = handle_events(events) - - print("\n" + "=" * 60) - print("Conversation ended.") - - """ - Sample Output: - - Handoff Workflow with Return-to-Previous Routing - ============================================================ - - This interactive demo shows how user inputs route directly - to the specialist handling your request, avoiding unnecessary - coordinator re-evaluation on each turn. - - Specialists can hand off directly to other specialists when - your request changes topics (e.g., from technical to billing). - - Type 'exit' or 'quit' to end the conversation. - - You: I need help with my bill, I was charged twice by mistake. - - ============================================================ - AWAITING INPUT FROM: BILLING_AGENT - ============================================================ - [user]: I need help with my bill, I was charged twice by mistake. - [coordinator]: You will be connected to a billing agent who can assist you with the double charge on your bill. - >>> [billing_agent]: I'm here to help with billing concerns! I'm sorry you were charged twice. Could you - please provide the invoice number or your account email so I can look into this and begin processing a refund? - - You: Invoice 1234 - - ============================================================ - AWAITING INPUT FROM: BILLING_AGENT - ============================================================ - >>> [billing_agent]: I'm here to help with billing concerns! I'm sorry you were charged twice. - Could you please provide the invoice number or your account email so I can look into this and begin - processing a refund? - [user]: Invoice 1234 - >>> [billing_agent]: Thank you for providing the invoice number (1234). I will review the details and work - on processing a refund for the duplicate charge. - - Can you confirm which payment method you used for this bill (e.g., credit card, PayPal)? - This helps ensure your refund is processed to the correct account. - - You: I used my credit card, which is on autopay. - - ============================================================ - AWAITING INPUT FROM: BILLING_AGENT - ============================================================ - >>> [billing_agent]: Thank you for providing the invoice number (1234). I will review the details and work on - processing a refund for the duplicate charge. - - Can you confirm which payment method you used for this bill (e.g., credit card, PayPal)? This helps ensure - your refund is processed to the correct account. - [user]: I used my credit card, which is on autopay. - >>> [billing_agent]: Thank you for confirming your payment method. I will look into invoice 1234 and - process a refund for the duplicate charge to your credit card. - - You will receive a notification once the refund is completed. If you have any further questions about your billing - or need an update, please let me know! - - You: Actually I also can't turn on my modem. It reset and now won't turn on. - - ============================================================ - AWAITING INPUT FROM: TECHNICAL_SUPPORT - ============================================================ - [user]: Actually I also can't turn on my modem. It reset and now won't turn on. - [billing_agent]: I'm connecting you with technical support for assistance with your modem not turning on after - the reset. 
They'll be able to help troubleshoot and resolve this issue. - - At the same time, technical support will also handle your refund request for the duplicate charge on invoice 1234 - to your credit card on autopay. - - You will receive updates from the appropriate teams shortly. - >>> [technical_support]: Thanks for letting me know about your modem issue! To help you further, could you tell me: - - 1. Is there any light showing on the modem at all, or is it completely off? - 2. Have you tried unplugging the modem from power and plugging it back in? - 3. Do you hear or feel anything (like a slight hum or vibration) when the modem is plugged in? - - Let me know, and I'll guide you through troubleshooting or arrange a repair if needed. - - You: exit - - Ending conversation. Goodbye! - - ============================================================ - Conversation ended. - """ - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/orchestration/handoff_simple.py b/python/samples/getting_started/workflows/orchestration/handoff_simple.py index 84b6e0f243..e2e4aa7f1c 100644 --- a/python/samples/getting_started/workflows/orchestration/handoff_simple.py +++ b/python/samples/getting_started/workflows/orchestration/handoff_simple.py @@ -1,16 +1,17 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from collections.abc import AsyncIterable from typing import Annotated, cast from agent_framework import ( + AgentResponse, + AgentRunEvent, ChatAgent, ChatMessage, + HandoffAgentUserRequest, HandoffBuilder, - HandoffUserInputRequest, + HandoffSentEvent, RequestInfoEvent, - Role, WorkflowEvent, WorkflowOutputEvent, WorkflowRunState, @@ -20,27 +21,16 @@ from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential -"""Sample: Simple handoff workflow with single-tier triage-to-specialist routing. +"""Sample: Simple handoff workflow. -This sample demonstrates the basic handoff pattern where only the triage agent can -route to specialists. Specialists cannot hand off to other specialists - after any -specialist responds, control returns to the user (via the triage agent) for the next input. - -Routing Pattern: - User → Triage Agent → Specialist → Triage Agent → User → Triage Agent → ... - -This is the simplest handoff configuration, suitable for straightforward support -scenarios where a triage agent dispatches to domain specialists, and each specialist -works independently. - -For multi-tier specialist-to-specialist handoffs, see handoff_specialist_to_specialist.py. +A handoff workflow defines a pattern that assembles agents in a mesh topology, allowing +them to transfer control to each other based on the conversation context. Prerequisites: - `az login` (Azure CLI authentication) - Environment variables configured for AzureOpenAIChatClient (AZURE_OPENAI_ENDPOINT, etc.) Key Concepts: - - Single-tier routing: Only triage agent has handoff capabilities - Auto-registered handoff tools: HandoffBuilder automatically creates handoff tools for each participant, allowing the coordinator to transfer control to specialists - Termination condition: Controls when the workflow stops requesting user input @@ -69,14 +59,8 @@ def process_return(order_number: Annotated[str, "Order number to process return def create_agents(chat_client: AzureOpenAIChatClient) -> tuple[ChatAgent, ChatAgent, ChatAgent, ChatAgent]: """Create and configure the triage and specialist agents. 
- The triage agent is responsible for: - - Receiving all user input first - - Deciding whether to handle the request directly or hand off to a specialist - - Signaling handoff by calling one of the explicit handoff tools exposed to it - - Specialist agents are invoked only when the triage agent explicitly hands off to them. - After a specialist responds, control returns to the triage agent, which then prompts - the user for their next message. + Args: + chat_client: The AzureOpenAIChatClient to use for creating agents. Returns: Tuple of (triage_agent, refund_agent, order_agent, return_agent) @@ -117,21 +101,6 @@ def create_agents(chat_client: AzureOpenAIChatClient) -> tuple[ChatAgent, ChatAg return triage_agent, refund_agent, order_agent, return_agent -async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]: - """Collect all events from an async stream into a list. - - This helper drains the workflow's event stream so we can process events - synchronously after each workflow step completes. - - Args: - stream: Async iterable of WorkflowEvent - - Returns: - List of all events from the stream - """ - return [event async for event in stream] - - def _handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: """Process workflow events and extract any pending user input requests. @@ -150,6 +119,19 @@ def _handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: requests: list[RequestInfoEvent] = [] for event in events: + # AgentRunEvent: Contains messages generated by agents during their turn + if isinstance(event, AgentRunEvent): + for message in event.data.messages: + if not message.text: + # Skip messages without text (e.g., tool calls) + continue + speaker = message.author_name or message.role.value + print(f"- {speaker}: {message.text}") + + # HandoffSentEvent: Indicates a handoff has been initiated + if isinstance(event, HandoffSentEvent): + print(f"\n[Handoff from {event.source} to {event.target} initiated.]") + # WorkflowStatusEvent: Indicates workflow state changes if isinstance(event, WorkflowStatusEvent) and event.state in { WorkflowRunState.IDLE, @@ -164,40 +146,37 @@ def _handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: print("\n=== Final Conversation Snapshot ===") for message in conversation: speaker = message.author_name or message.role.value - print(f"- {speaker}: {message.text}") + print(f"- {speaker}: {message.text or [content.type for content in message.contents]}") print("===================================") # RequestInfoEvent: Workflow is requesting user input elif isinstance(event, RequestInfoEvent): - if isinstance(event.data, HandoffUserInputRequest): - _print_agent_responses_since_last_user_message(event.data) + if isinstance(event.data, HandoffAgentUserRequest): + _print_handoff_agent_user_request(event.data.agent_response) requests.append(event) return requests -def _print_agent_responses_since_last_user_message(request: HandoffUserInputRequest) -> None: - """Display agent responses since the last user message in a handoff request. +def _print_handoff_agent_user_request(response: AgentResponse) -> None: + """Display the agent's response messages when requesting user input. - The HandoffUserInputRequest contains the full conversation history so far, - allowing the user to see what's been discussed before providing their next input. + This will happen when an agent generates a response that doesn't trigger + a handoff, i.e., the agent is asking the user for more information. 
Args: - request: The user input request containing conversation and prompt + response: The AgentResponse from the agent requesting user input """ - if not request.conversation: - raise RuntimeError("HandoffUserInputRequest missing conversation history.") - - # Reverse iterate to collect agent responses since last user message - agent_responses: list[ChatMessage] = [] - for message in request.conversation[::-1]: - if message.role == Role.USER: - break - agent_responses.append(message) - - # Print agent responses in original order - agent_responses.reverse() - for message in agent_responses: + if not response.messages: + raise RuntimeError("Cannot print agent responses: response has no messages.") + + print("\n[Agent is requesting your input...]") + + # Print agent responses + for message in response.messages: + if not message.text: + # Skip messages without text (e.g., tool calls) + continue speaker = message.author_name or message.role.value print(f"- {speaker}: {message.text}") @@ -223,25 +202,23 @@ async def main() -> None: # Build the handoff workflow # - participants: All agents that can participate in the workflow - # - set_coordinator: The triage agent is designated as the coordinator, which means + # - with_start_agent: The triage agent is designated as the start agent, which means # it receives all user input first and orchestrates handoffs to specialists # - with_termination_condition: Custom logic to stop the request/response loop. # Without this, the default behavior continues requesting user input until max_turns # is reached. Here we use a custom condition that checks if the conversation has ended - # naturally (when triage agent says something like "you're welcome"). + # naturally (when one of the agents says something like "you're welcome"). workflow = ( HandoffBuilder( name="customer_support_handoff", participants=[triage, refund, order, support], ) - .set_coordinator(triage) + .with_start_agent(triage) .with_termination_condition( - # Custom termination: Check if the triage agent has provided a closing message. - # This looks for the last message being from triage_agent and containing "welcome", - # which indicates the conversation has concluded naturally. - lambda conversation: len(conversation) > 0 - and conversation[-1].author_name == "triage_agent" - and "welcome" in conversation[-1].text.lower() + # Custom termination: Check if one of the agents has provided a closing message. + # This looks for the last message containing "welcome", which indicates the + # conversation has concluded naturally. + lambda conversation: len(conversation) > 0 and "welcome" in conversation[-1].text.lower() ) .build() ) @@ -252,6 +229,7 @@ async def main() -> None: # or integrate with a UI/chat interface scripted_responses = [ "My order 1234 arrived damaged and the packaging was destroyed. I'd like to return it.", + "Please also process a refund for order 1234.", "Thanks for resolving this.", ] @@ -260,26 +238,32 @@ async def main() -> None: print("[Starting workflow with initial user message...]\n") initial_message = "Hello, I need assistance with my recent purchase." print(f"- User: {initial_message}") - events = await _drain(workflow.run_stream(initial_message)) - pending_requests = _handle_events(events) + workflow_result = await workflow.run(initial_message) + pending_requests = _handle_events(workflow_result) # Process the request/response cycle # The workflow will continue requesting input until: - # 1. The termination condition is met (4 user messages in this case), OR + # 1. 
The termination condition is met, OR # 2. We run out of scripted responses - while pending_requests and scripted_responses: - # Get the next scripted response - user_response = scripted_responses.pop(0) - print(f"\n- User: {user_response}") - - # Send response(s) to all pending requests - # In this demo, there's typically one request per cycle, but the API supports multiple - responses = {req.request_id: user_response for req in pending_requests} + while pending_requests: + if not scripted_responses: + # No more scripted responses; terminate the workflow + responses = {req.request_id: HandoffAgentUserRequest.terminate() for req in pending_requests} + else: + # Get the next scripted response + user_response = scripted_responses.pop(0) + print(f"\n- User: {user_response}") + + # Send response(s) to all pending requests + # In this demo, there's typically one request per cycle, but the API supports multiple + responses = { + req.request_id: HandoffAgentUserRequest.create_response(user_response) for req in pending_requests + } # Send responses and get new events - # We use send_responses_streaming() to get events as they occur, allowing us to - # display agent responses in real-time and handle new requests as they arrive - events = await _drain(workflow.send_responses_streaming(responses)) + # We use send_responses() to get events from the workflow, allowing us to + # display agent responses and handle new requests as they arrive + events = await workflow.send_responses(responses) pending_requests = _handle_events(events) """ diff --git a/python/samples/getting_started/workflows/orchestration/handoff_specialist_to_specialist.py b/python/samples/getting_started/workflows/orchestration/handoff_specialist_to_specialist.py deleted file mode 100644 index dfc9f0f73b..0000000000 --- a/python/samples/getting_started/workflows/orchestration/handoff_specialist_to_specialist.py +++ /dev/null @@ -1,284 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Sample: Multi-tier handoff workflow with specialist-to-specialist routing. - -This sample demonstrates advanced handoff routing where specialist agents can hand off -to other specialists, enabling complex multi-tier workflows. Unlike the simple handoff -pattern (see handoff_simple.py), specialists here can delegate to other specialists -without returning control to the user until the specialist chain completes. - -Routing Pattern: - User → Triage → Specialist A → Specialist B → Back to User - -This pattern is useful for complex support scenarios where different specialists need -to collaborate or escalate to each other before returning to the user. For example: - - Replacement agent needs shipping info → hands off to delivery agent - - Technical support needs billing info → hands off to billing agent - - Level 1 support escalates to Level 2 → hands off to escalation agent - -Configuration uses `.add_handoff()` to explicitly define the routing graph. 
- -Prerequisites: - - `az login` (Azure CLI authentication) - - Environment variables configured for AzureOpenAIChatClient -""" - -import asyncio -from collections.abc import AsyncIterable -from typing import cast - -from agent_framework import ( - ChatMessage, - HandoffBuilder, - HandoffUserInputRequest, - RequestInfoEvent, - WorkflowEvent, - WorkflowOutputEvent, - WorkflowRunState, - WorkflowStatusEvent, -) -from agent_framework.azure import AzureOpenAIChatClient -from azure.identity import AzureCliCredential - - -def create_agents(chat_client: AzureOpenAIChatClient): - """Create triage and specialist agents with multi-tier handoff capabilities. - - Returns: - Tuple of (triage_agent, replacement_agent, delivery_agent, billing_agent) - """ - triage = chat_client.create_agent( - instructions=( - "You are a customer support triage agent. Assess the user's issue and route appropriately:\n" - "- For product replacement issues: call handoff_to_replacement_agent\n" - "- For delivery/shipping inquiries: call handoff_to_delivery_agent\n" - "- For billing/payment issues: call handoff_to_billing_agent\n" - "Be concise and friendly." - ), - name="triage_agent", - ) - - replacement = chat_client.create_agent( - instructions=( - "You handle product replacement requests. Ask for order number and reason for replacement.\n" - "If the user also needs shipping/delivery information, call handoff_to_delivery_agent to " - "get tracking details. Otherwise, process the replacement and confirm with the user.\n" - "Be concise and helpful." - ), - name="replacement_agent", - ) - - delivery = chat_client.create_agent( - instructions=( - "You handle shipping and delivery inquiries. Provide tracking information, estimated " - "delivery dates, and address any delivery concerns.\n" - "If billing issues come up, call handoff_to_billing_agent.\n" - "Be concise and clear." - ), - name="delivery_agent", - ) - - billing = chat_client.create_agent( - instructions=( - "You handle billing and payment questions. Help with refunds, payment methods, " - "and invoice inquiries. Be concise." 
- ), - name="billing_agent", - ) - - return triage, replacement, delivery, billing - - -async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]: - """Collect all events from an async stream into a list.""" - return [event async for event in stream] - - -def _handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: - """Process workflow events and extract pending user input requests.""" - requests: list[RequestInfoEvent] = [] - - for event in events: - if isinstance(event, WorkflowStatusEvent) and event.state in { - WorkflowRunState.IDLE, - WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, - }: - print(f"[status] {event.state.name}") - - elif isinstance(event, WorkflowOutputEvent): - conversation = cast(list[ChatMessage], event.data) - if isinstance(conversation, list): - print("\n=== Final Conversation ===") - for message in conversation: - # Filter out messages with no text (tool calls) - if not message.text.strip(): - continue - speaker = message.author_name or message.role.value - print(f"- {speaker}: {message.text}") - print("==========================") - - elif isinstance(event, RequestInfoEvent): - if isinstance(event.data, HandoffUserInputRequest): - _print_handoff_request(event.data) - requests.append(event) - - return requests - - -def _print_handoff_request(request: HandoffUserInputRequest) -> None: - """Display a user input request with conversation context.""" - print("\n=== User Input Requested ===") - # Filter out messages with no text for cleaner display - messages_with_text = [msg for msg in request.conversation if msg.text.strip()] - print(f"Last {len(messages_with_text)} messages in conversation:") - for message in messages_with_text[-5:]: # Show last 5 for brevity - speaker = message.author_name or message.role.value - text = message.text[:100] + "..." if len(message.text) > 100 else message.text - print(f" {speaker}: {text}") - print("============================") - - -async def main() -> None: - """Demonstrate specialist-to-specialist handoffs in a multi-tier support scenario. - - This sample shows: - 1. Triage agent routes to replacement specialist - 2. Replacement specialist hands off to delivery specialist - 3. Delivery specialist can hand off to billing if needed - 4. All transitions are seamless without returning to user until complete - - The workflow configuration explicitly defines which agents can hand off to which others: - - triage_agent → replacement_agent, delivery_agent, billing_agent - - replacement_agent → delivery_agent, billing_agent - - delivery_agent → billing_agent - """ - chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) - triage, replacement, delivery, billing = create_agents(chat_client) - - # Configure multi-tier handoffs using fluent add_handoff() API - # This allows specialists to hand off to other specialists - workflow = ( - HandoffBuilder( - name="multi_tier_support", - participants=[triage, replacement, delivery, billing], - ) - .set_coordinator(triage) - .add_handoff(triage, [replacement, delivery, billing]) # Triage can route to any specialist - .add_handoff(replacement, [delivery, billing]) # Replacement can delegate to delivery or billing - .add_handoff(delivery, billing) # Delivery can escalate to billing - # Termination condition: Stop when more than 3 user messages exist. - # This allows agents to respond to the 3rd user message before the 4th triggers termination. - # In this sample: initial message + 3 scripted responses = 4 messages, then workflow ends. 
- .with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role.value == "user") > 3) - .build() - ) - - # Scripted user responses simulating a multi-tier handoff scenario - # Note: The initial run_stream() call sends the first user message, - # then these scripted responses are sent in sequence (total: 4 user messages). - # A 5th response triggers termination after agents respond to the 4th message. - scripted_responses = [ - "I need help with order 12345. I want a replacement and need to know when it will arrive.", - "The item arrived damaged. I'd like a replacement shipped to the same address.", - "Great! Can you confirm the shipping cost won't be charged again?", - "Thank you!", # Final response to trigger termination after billing agent answers - ] - - print("\n" + "=" * 80) - print("SPECIALIST-TO-SPECIALIST HANDOFF DEMONSTRATION") - print("=" * 80) - print("\nScenario: Customer needs replacement + shipping info + billing confirmation") - print("Expected flow: User → Triage → Replacement → Delivery → Billing → User") - print("=" * 80 + "\n") - - # Start workflow with initial message - print(f"[User]: {scripted_responses[0]}\n") - events = await _drain(workflow.run_stream(scripted_responses[0])) - pending_requests = _handle_events(events) - - # Process scripted responses - response_index = 1 - while pending_requests and response_index < len(scripted_responses): - user_response = scripted_responses[response_index] - print(f"\n[User]: {user_response}\n") - - responses = {req.request_id: user_response for req in pending_requests} - events = await _drain(workflow.send_responses_streaming(responses)) - pending_requests = _handle_events(events) - - response_index += 1 - - """ - Sample Output: - - ================================================================================ - SPECIALIST-TO-SPECIALIST HANDOFF DEMONSTRATION - ================================================================================ - - Scenario: Customer needs replacement + shipping info + billing confirmation - Expected flow: User → Triage → Replacement → Delivery → Billing → User - ================================================================================ - - [User]: I need help with order 12345. I want a replacement and need to know when it will arrive. - - - === User Input Requested === - Last 5 messages in conversation: - user: I need help with order 12345. I want a replacement and need to know when it will arrive. - triage_agent: I am connecting you to our replacement agent to assist with your replacement request and to our deli... - replacement_agent: I have connected you to our agents who will assist with your replacement request for order 12345 and... - delivery_agent: For your replacement request and delivery details regarding order 12345, I'll connect you to the app... - billing_agent: I don’t have access to order details. Please contact the seller or customer service directly for rep... - ============================ - [status] IDLE_WITH_PENDING_REQUESTS - - [User]: The item arrived damaged. I'd like a replacement shipped to the same address. - - - === User Input Requested === - Last 8 messages in conversation: - delivery_agent: For your replacement request and delivery details regarding order 12345, I'll connect you to the app... - billing_agent: I don’t have access to order details. Please contact the seller or customer service directly for rep... - user: The item arrived damaged. I'd like a replacement shipped to the same address. 
- triage_agent: I'm connecting you to our replacement agent who will assist you with getting a replacement shipped t... - replacement_agent: Thank you for the info. I'll start the replacement process for your damaged item on order 12345 and ... - ============================ - [status] IDLE_WITH_PENDING_REQUESTS - - [User]: Great! Can you confirm the shipping cost won't be charged again? - - - === User Input Requested === - Last 11 messages in conversation: - triage_agent: I'm connecting you to our replacement agent who will assist you with getting a replacement shipped t... - replacement_agent: Thank you for the info. I'll start the replacement process for your damaged item on order 12345 and ... - user: Great! Can you confirm the shipping cost won't be charged again? - triage_agent: I'm connecting you to our billing agent who can confirm whether the shipping cost will be charged ag... - billing_agent: Replacements for damaged items are typically shipped at no extra shipping cost. I recommend confirmi... - ============================ - [status] IDLE_WITH_PENDING_REQUESTS - - [User]: Thank you! - - - === Final Conversation === - - user: I need help with order 12345. I want a replacement and need to know when it will arrive. - - triage_agent: I am connecting you to our replacement agent to assist with your replacement request and to our delivery agent who can provide information on the arrival of your replacement. They will help you shortly! - - replacement_agent: I have connected you to our agents who will assist with your replacement request for order 12345 and provide shipping and delivery information. They will help you with the details shortly. If you have any other questions, feel free to ask! - - delivery_agent: For your replacement request and delivery details regarding order 12345, I'll connect you to the appropriate agent who can assist with both billing and delivery concerns. They will provide you with the estimated arrival date and replacement process. Is there anything else you'd like to know meanwhile? - - billing_agent: I don’t have access to order details. Please contact the seller or customer service directly for replacement and delivery info on order 12345. If you need help with billing, refunds, or payments, I can assist! - - user: The item arrived damaged. I'd like a replacement shipped to the same address. - - triage_agent: I'm connecting you to our replacement agent who will assist you with getting a replacement shipped to the same address for the damaged item. They will help you shortly! - - replacement_agent: Thank you for the info. I'll start the replacement process for your damaged item on order 12345 and arrange to ship it to the same address. Please hold on a moment while I proceed. - Thank you for providing the reason. I will process a replacement for the damaged item from order 12345 and arrange the shipment to the same address. Please hold on a moment while I take care of this. - - user: Great! Can you confirm the shipping cost won't be charged again? - - triage_agent: I'm connecting you to our billing agent who can confirm whether the shipping cost will be charged again for the replacement of your damaged item. They will assist you shortly! - - billing_agent: Replacements for damaged items are typically shipped at no extra shipping cost. I recommend confirming with the replacements or billing department to be sure. Let me know if you’d like me to connect you! - - user: Thank you! 
- ========================== - [status] IDLE - """ # noqa: E501 - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/orchestration/handoff_with_code_interpreter_file.py b/python/samples/getting_started/workflows/orchestration/handoff_with_code_interpreter_file.py index b1fd37302a..eadfe634e3 100644 --- a/python/samples/getting_started/workflows/orchestration/handoff_with_code_interpreter_file.py +++ b/python/samples/getting_started/workflows/orchestration/handoff_with_code_interpreter_file.py @@ -32,8 +32,8 @@ from agent_framework import ( AgentRunUpdateEvent, ChatAgent, + HandoffAgentUserRequest, HandoffBuilder, - HandoffUserInputRequest, HostedCodeInterpreterTool, HostedFileContent, RequestInfoEvent, @@ -68,21 +68,10 @@ def _handle_events(events: list[WorkflowEvent]) -> tuple[list[RequestInfoEvent], print(f"[status] {event.state.name}") elif isinstance(event, RequestInfoEvent): - if isinstance(event.data, HandoffUserInputRequest): - print("\n=== Conversation So Far ===") - for msg in event.data.conversation: - speaker = msg.author_name or msg.role.value - text = msg.text or "" - txt = text[:200] + "..." if len(text) > 200 else text - print(f"- {speaker}: {txt}") - print("===========================\n") requests.append(event) elif isinstance(event, AgentRunUpdateEvent): - update = event.data - if update is None: - continue - for content in update.contents: + for content in event.data.contents: if isinstance(content, HostedFileContent): file_ids.append(content.file_id) print(f"[Found HostedFileContent: file_id={content.file_id}]") @@ -137,11 +126,7 @@ async def create_agents_v2(credential: AzureCliCredential) -> AsyncIterator[tupl ): triage = triage_client.create_agent( name="TriageAgent", - instructions=( - "You are a triage agent. Your ONLY job is to route requests to the appropriate specialist. " - "For code or file creation requests, call handoff_to_CodeSpecialist immediately. " - "Do NOT try to complete tasks yourself. Just hand off." - ), + instructions="You are a triage agent. Your ONLY job is to route requests to the appropriate specialist.", ) code_specialist = code_client.create_agent( @@ -170,7 +155,7 @@ async def main() -> None: workflow = ( HandoffBuilder() .participants([triage, code_specialist]) - .set_coordinator(triage) + .with_start_agent(triage) .with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role.value == "user") >= 2) .build() ) @@ -195,7 +180,7 @@ async def main() -> None: user_input = user_inputs[input_index] print(f"\nUser: {user_input}") - responses = {request.request_id: user_input} + responses = {request.request_id: HandoffAgentUserRequest.create_response(user_input)} events = await _drain(workflow.send_responses_streaming(responses)) requests, file_ids = _handle_events(events) all_file_ids.extend(file_ids) diff --git a/python/samples/getting_started/workflows/orchestration/magentic.py b/python/samples/getting_started/workflows/orchestration/magentic.py index 213486706a..8e71d09a42 100644 --- a/python/samples/getting_started/workflows/orchestration/magentic.py +++ b/python/samples/getting_started/workflows/orchestration/magentic.py @@ -1,17 +1,19 @@ # Copyright (c) Microsoft. All rights reserved. 
import asyncio +import json import logging from typing import cast from agent_framework import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, AgentRunUpdateEvent, ChatAgent, ChatMessage, + GroupChatRequestSentEvent, HostedCodeInterpreterTool, MagenticBuilder, + MagenticOrchestratorEvent, + MagenticProgressLedger, WorkflowOutputEvent, ) from agent_framework.openai import OpenAIChatClient, OpenAIResponsesClient @@ -75,13 +77,9 @@ async def main() -> None: print("\nBuilding Magentic Workflow...") - # State used by on_agent_stream callback - last_stream_agent_id: str | None = None - stream_line_open: bool = False - workflow = ( MagenticBuilder() - .participants(researcher=researcher_agent, coder=coder_agent) + .participants([researcher_agent, coder_agent]) .with_standard_manager( agent=manager_agent, max_round_count=10, @@ -103,43 +101,49 @@ async def main() -> None: print(f"\nTask: {task}") print("\nStarting workflow execution...") - try: - output: str | None = None - async for event in workflow.run_stream(task): - if isinstance(event, AgentRunUpdateEvent): - props = event.data.additional_properties if event.data else None - event_type = props.get("magentic_event_type") if props else None - - if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR: - kind = props.get("orchestrator_message_kind", "") if props else "" - text = event.data.text if event.data else "" - print(f"\n[ORCH:{kind}]\n\n{text}\n{'-' * 26}") - elif event_type == MAGENTIC_EVENT_TYPE_AGENT_DELTA: - agent_id = props.get("agent_id", event.executor_id) if props else event.executor_id - if last_stream_agent_id != agent_id or not stream_line_open: - if stream_line_open: - print() - print(f"\n[STREAM:{agent_id}]: ", end="", flush=True) - last_stream_agent_id = agent_id - stream_line_open = True - if event.data and event.data.text: - print(event.data.text, end="", flush=True) - elif event.data and event.data.text: - print(event.data.text, end="", flush=True) - elif isinstance(event, WorkflowOutputEvent): - output_messages = cast(list[ChatMessage], event.data) - if output_messages: - output = output_messages[-1].text - - if stream_line_open: - print() - stream_line_open = False - - if output is not None: - print(f"Workflow completed with result:\n\n{output}") - - except Exception as e: - print(f"Workflow execution failed: {e}") + # Keep track of the last executor to format output nicely in streaming mode + last_message_id: str | None = None + output_event: WorkflowOutputEvent | None = None + async for event in workflow.run_stream(task): + if isinstance(event, AgentRunUpdateEvent): + message_id = event.data.message_id + if message_id != last_message_id: + if last_message_id is not None: + print("\n") + print(f"- {event.executor_id}:", end=" ", flush=True) + last_message_id = message_id + print(event.data, end="", flush=True) + + elif isinstance(event, MagenticOrchestratorEvent): + print(f"\n[Magentic Orchestrator Event] Type: {event.event_type.name}") + if isinstance(event.data, ChatMessage): + print(f"Please review the plan:\n{event.data.text}") + elif isinstance(event.data, MagenticProgressLedger): + print(f"Please review progress ledger:\n{json.dumps(event.data.to_dict(), indent=2)}") + else: + print(f"Unknown data type in MagenticOrchestratorEvent: {type(event.data)}") + + # Block to allow user to read the plan/progress before continuing + # Note: this is for demonstration only and is not the recommended way to handle human interaction. 
+ # Please refer to `with_plan_review` for proper human interaction during planning phases. + await asyncio.get_event_loop().run_in_executor(None, input, "Press Enter to continue...") + + elif isinstance(event, GroupChatRequestSentEvent): + print(f"\n[REQUEST SENT ({event.round_index})] to agent: {event.participant_name}") + + elif isinstance(event, WorkflowOutputEvent): + output_event = event + + if not output_event: + raise RuntimeError("Workflow did not produce a final output event.") + print("\n\nWorkflow completed!") + print("Final Output:") + # The output of the Magentic workflow is a list of ChatMessages with only one final message + # generated by the orchestrator. + output_messages = cast(list[ChatMessage], output_event.data) + if output_messages: + output = output_messages[-1].text + print(output) if __name__ == "__main__": diff --git a/python/samples/getting_started/workflows/orchestration/magentic_agent_clarification.py b/python/samples/getting_started/workflows/orchestration/magentic_agent_clarification.py deleted file mode 100644 index 44dea25acc..0000000000 --- a/python/samples/getting_started/workflows/orchestration/magentic_agent_clarification.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import asyncio -import logging -from typing import Annotated, cast - -from agent_framework import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, - AgentRunUpdateEvent, - ChatAgent, - ChatMessage, - MagenticBuilder, - MagenticHumanInterventionDecision, - MagenticHumanInterventionKind, - MagenticHumanInterventionReply, - MagenticHumanInterventionRequest, - RequestInfoEvent, - WorkflowOutputEvent, - ai_function, -) -from agent_framework.openai import OpenAIChatClient - -logging.basicConfig(level=logging.WARNING) -logger = logging.getLogger(__name__) - -""" -Sample: Agent Clarification via Tool Calls in Magentic Workflows - -This sample demonstrates how agents can ask clarifying questions to users during -execution via the HITL (Human-in-the-Loop) mechanism. - -Scenario: "Onboard Jessica Smith" -- User provides an ambiguous task: "Onboard Jessica Smith" -- The onboarding agent recognizes missing information and uses the ask_user tool -- The ask_user call surfaces as a TOOL_APPROVAL request via RequestInfoEvent -- User provides the answer (e.g., "Engineering, Software Engineer") -- The answer is fed back to the agent as a FunctionResultContent -- Agent continues execution with the clarified information - -How it works: -1. Agent has an `ask_user` tool decorated with `@ai_function(approval_mode="always_require")` -2. When agent calls `ask_user`, it surfaces as a FunctionApprovalRequestContent -3. MagenticAgentExecutor converts this to a MagenticHumanInterventionRequest(kind=TOOL_APPROVAL) -4. User provides answer via MagenticHumanInterventionReply with response_text -5. The response_text becomes the function result fed back to the agent -6. Agent receives the result and continues processing - -Prerequisites: -- OpenAI credentials configured for `OpenAIChatClient`. -""" - - -@ai_function(approval_mode="always_require") -def ask_user(question: Annotated[str, "The question to ask the user for clarification"]) -> str: - """Ask the user a clarifying question to gather missing information. - - Use this tool when you need additional information from the user to complete - your task effectively. The user's response will be returned so you can - continue with your work. 
- - Args: - question: The question to ask the user - - Returns: - The user's response to the question - """ - # This function body is a placeholder - the actual interaction happens via HITL. - # When the agent calls this tool: - # 1. The tool call surfaces as a FunctionApprovalRequestContent - # 2. MagenticAgentExecutor detects this and emits a HITL request - # 3. The user provides their answer - # 4. The answer is fed back as the function result - return f"User was asked: {question}" - - -async def main() -> None: - # Create an onboarding agent that asks clarifying questions - onboarding_agent = ChatAgent( - name="OnboardingAgent", - description="HR specialist who handles employee onboarding", - instructions=( - "You are an HR Onboarding Specialist. Your job is to onboard new employees.\n\n" - "IMPORTANT: When given an onboarding request, you MUST gather the following " - "information before proceeding:\n" - "1. Department (e.g., Engineering, Sales, Marketing)\n" - "2. Role/Title (e.g., Software Engineer, Account Executive)\n" - "3. Start date (if not specified)\n" - "4. Manager's name (if known)\n\n" - "Use the ask_user tool to request ANY missing information. " - "Do not proceed with onboarding until you have at least the department and role.\n\n" - "Once you have the information, create an onboarding plan." - ), - chat_client=OpenAIChatClient(model_id="gpt-4o"), - tools=[ask_user], # Tool decorated with @ai_function(approval_mode="always_require") - ) - - # Create a manager agent - manager_agent = ChatAgent( - name="MagenticManager", - description="Orchestrator that coordinates the onboarding workflow", - instructions="You coordinate a team to complete HR tasks efficiently.", - chat_client=OpenAIChatClient(model_id="gpt-4o"), - ) - - print("\nBuilding Magentic Workflow with Agent Clarification...") - - workflow = ( - MagenticBuilder() - .participants(onboarding=onboarding_agent) - .with_standard_manager( - agent=manager_agent, - max_round_count=10, - max_stall_count=3, - max_reset_count=2, - ) - .build() - ) - - # Ambiguous task - agent should ask for clarification - task = "Onboard Jessica Smith" - - print(f"\nTask: {task}") - print("(This is intentionally vague - the agent should ask for more details)") - print("\nStarting workflow execution...") - print("=" * 60) - - try: - pending_request: RequestInfoEvent | None = None - pending_responses: dict[str, object] | None = None - completed = False - workflow_output: str | None = None - - last_stream_agent_id: str | None = None - stream_line_open: bool = False - - while not completed: - if pending_responses is not None: - stream = workflow.send_responses_streaming(pending_responses) - else: - stream = workflow.run_stream(task) - - async for event in stream: - if isinstance(event, AgentRunUpdateEvent): - props = event.data.additional_properties if event.data else None - event_type = props.get("magentic_event_type") if props else None - - if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR: - kind = props.get("orchestrator_message_kind", "") if props else "" - text = event.data.text if event.data else "" - if stream_line_open: - print() - stream_line_open = False - print(f"\n[ORCHESTRATOR: {kind}]\n{text}\n{'-' * 40}") - elif event_type == MAGENTIC_EVENT_TYPE_AGENT_DELTA: - agent_id = props.get("agent_id", "unknown") if props else "unknown" - if last_stream_agent_id != agent_id or not stream_line_open: - if stream_line_open: - print() - print(f"\n[{agent_id}]: ", end="", flush=True) - last_stream_agent_id = agent_id - stream_line_open = True 
- if event.data and event.data.text: - print(event.data.text, end="", flush=True) - - elif isinstance(event, RequestInfoEvent) and event.request_type is MagenticHumanInterventionRequest: - if stream_line_open: - print() - stream_line_open = False - pending_request = event - req = cast(MagenticHumanInterventionRequest, event.data) - - if req.kind == MagenticHumanInterventionKind.TOOL_APPROVAL: - print("\n" + "=" * 60) - print("AGENT ASKING FOR CLARIFICATION") - print("=" * 60) - print(f"\nAgent: {req.agent_id}") - print(f"Question: {req.prompt}") - if req.context: - print(f"Details: {req.context}") - print() - - elif isinstance(event, WorkflowOutputEvent): - if stream_line_open: - print() - stream_line_open = False - workflow_output = event.data if event.data else None - completed = True - - if stream_line_open: - print() - stream_line_open = False - pending_responses = None - - if pending_request is not None: - req = cast(MagenticHumanInterventionRequest, pending_request.data) - - if req.kind == MagenticHumanInterventionKind.TOOL_APPROVAL: - # Agent is asking for clarification - print("Please provide your answer:") - answer = input("> ").strip() # noqa: ASYNC250 - - if answer.lower() == "exit": - print("Exiting workflow...") - return - - # Send the answer back - it will be fed to the agent as the function result - reply = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.APPROVE, - response_text=answer if answer else "No additional information provided.", - ) - pending_responses = {pending_request.request_id: reply} - pending_request = None - - print("\n" + "=" * 60) - print("WORKFLOW COMPLETED") - print("=" * 60) - if workflow_output: - messages = cast(list[ChatMessage], workflow_output) - if messages: - final_msg = messages[-1] - print(f"\nFinal Result:\n{final_msg.text}") - - except Exception as e: - print(f"Workflow execution failed: {e}") - logger.exception("Workflow exception", exc_info=e) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py b/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py index 36e6ca4c01..6fc284a9ab 100644 --- a/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py +++ b/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py @@ -3,15 +3,14 @@ import asyncio import json from pathlib import Path +from typing import cast from agent_framework import ( ChatAgent, + ChatMessage, FileCheckpointStorage, MagenticBuilder, - MagenticHumanInterventionDecision, - MagenticHumanInterventionKind, - MagenticHumanInterventionReply, - MagenticHumanInterventionRequest, + MagenticPlanReviewRequest, RequestInfoEvent, WorkflowCheckpoint, WorkflowOutputEvent, @@ -82,7 +81,7 @@ def build_workflow(checkpoint_storage: FileCheckpointStorage): # stores the checkpoint backend so the runtime knows where to persist snapshots. return ( MagenticBuilder() - .participants(researcher=researcher, writer=writer) + .participants([researcher, writer]) .with_plan_review() .with_standard_manager( agent=manager_agent, @@ -110,19 +109,16 @@ async def main() -> None: # Run the workflow until the first RequestInfoEvent is surfaced. The event carries the # request_id we must reuse on resume. In a real system this is where the UI would present # the plan for human review. 
- plan_review_request_id: str | None = None + plan_review_request: MagenticPlanReviewRequest | None = None async for event in workflow.run_stream(TASK): - if isinstance(event, RequestInfoEvent) and event.request_type is MagenticHumanInterventionRequest: - request = event.data - if isinstance(request, MagenticHumanInterventionRequest): - if request.kind == MagenticHumanInterventionKind.PLAN_REVIEW: - plan_review_request_id = event.request_id - print(f"Captured plan review request: {plan_review_request_id}") + if isinstance(event, RequestInfoEvent) and event.request_type is MagenticPlanReviewRequest: + plan_review_request = event.data + print(f"Captured plan review request: {event.request_id}") if isinstance(event, WorkflowStatusEvent) and event.state is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS: break - if plan_review_request_id is None: + if plan_review_request is None: print("No plan review request emitted; nothing to resume.") return @@ -142,19 +138,19 @@ async def main() -> None: if checkpoint_path.exists(): with checkpoint_path.open() as f: snapshot = json.load(f) - request_map = snapshot.get("executor_states", {}).get("magentic_plan_review", {}).get("request_events", {}) + request_map = snapshot.get("pending_request_info_events", {}) print(f"Pending plan-review requests persisted in checkpoint: {list(request_map.keys())}") print("\n=== Stage 2: resume from checkpoint and approve plan ===") resumed_workflow = build_workflow(checkpoint_storage) # Construct an approval reply to supply when the plan review request is re-emitted. - approval = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.APPROVE) + approval = plan_review_request.approve() # Resume execution and capture the re-emitted plan review request. request_info_event: RequestInfoEvent | None = None async for event in resumed_workflow.run_stream(checkpoint_id=resume_checkpoint.checkpoint_id): - if isinstance(event, RequestInfoEvent) and isinstance(event.data, MagenticHumanInterventionRequest): + if isinstance(event, RequestInfoEvent) and isinstance(event.data, MagenticPlanReviewRequest): request_info_event = event if request_info_event is None: @@ -178,9 +174,11 @@ async def main() -> None: if not result: print("No result data from workflow.") return - text = getattr(result, "text", None) or str(result) + output_messages = cast(list[ChatMessage], result) print("\n=== Final Answer ===") - print(text) + # The output of the Magentic workflow is a list of ChatMessages with only one final message + # generated by the orchestrator. + print(output_messages[-1].text) # ------------------------------------------------------------------ # Stage 3: demonstrate resuming from a later checkpoint (post-plan) @@ -233,7 +231,7 @@ def _pending_message_count(cp: WorkflowCheckpoint) -> int: if not post_emitted_events: print("No new events were emitted; checkpoint already captured a completed run.") print("\n=== Final Answer (post-plan resume) ===") - print(text) + print(output_messages[-1].text) return print("Workflow did not complete after post-plan resume.") return @@ -243,9 +241,11 @@ def _pending_message_count(cp: WorkflowCheckpoint) -> int: print("No result data from post-plan resume.") return - post_text = getattr(post_result, "text", None) or str(post_result) + output_messages = cast(list[ChatMessage], post_result) print("\n=== Final Answer (post-plan resume) ===") - print(post_text) + # The output of the Magentic workflow is a list of ChatMessages with only one final message + # generated by the orchestrator. 
+ print(output_messages[-1].text) """ Sample Output: diff --git a/python/samples/getting_started/workflows/orchestration/magentic_human_plan_review.py b/python/samples/getting_started/workflows/orchestration/magentic_human_plan_review.py new file mode 100644 index 0000000000..37a53020e7 --- /dev/null +++ b/python/samples/getting_started/workflows/orchestration/magentic_human_plan_review.py @@ -0,0 +1,145 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import json +from typing import cast + +from agent_framework import ( + AgentRunUpdateEvent, + ChatAgent, + ChatMessage, + MagenticBuilder, + MagenticPlanReviewRequest, + RequestInfoEvent, + WorkflowOutputEvent, +) +from agent_framework.openai import OpenAIChatClient + +""" +Sample: Magentic Orchestration with Human Plan Review + +This sample demonstrates how humans can review and provide feedback on plans +generated by the Magentic workflow orchestrator. When plan review is enabled, +the workflow requests human approval or revision before executing each plan. + +Key concepts: +- with_plan_review(): Enables human review of generated plans +- MagenticPlanReviewRequest: The event type for plan review requests +- Human can choose to: approve the plan or provide revision feedback + +Plan review options: +- approve(): Accept the proposed plan and continue execution +- revise(feedback): Provide textual feedback to modify the plan + +Prerequisites: +- OpenAI credentials configured for `OpenAIChatClient`. +""" + + +async def main() -> None: + researcher_agent = ChatAgent( + name="ResearcherAgent", + description="Specialist in research and information gathering", + instructions="You are a Researcher. You find information and gather facts.", + chat_client=OpenAIChatClient(model_id="gpt-4o"), + ) + + analyst_agent = ChatAgent( + name="AnalystAgent", + description="Data analyst who processes and summarizes research findings", + instructions="You are an Analyst. You analyze findings and create summaries.", + chat_client=OpenAIChatClient(model_id="gpt-4o"), + ) + + manager_agent = ChatAgent( + name="MagenticManager", + description="Orchestrator that coordinates the workflow", + instructions="You coordinate a team to complete tasks efficiently.", + chat_client=OpenAIChatClient(model_id="gpt-4o"), + ) + + print("\nBuilding Magentic Workflow with Human Plan Review...") + + workflow = ( + MagenticBuilder() + .participants([researcher_agent, analyst_agent]) + .with_standard_manager( + agent=manager_agent, + max_round_count=10, + max_stall_count=1, + max_reset_count=2, + ) + .with_plan_review() # Request human input for plan review + .build() + ) + + task = "Research sustainable aviation fuel technology and summarize the findings." 
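+
+    # A minimal sketch of the human-in-the-loop pattern used below, with names taken
+    # from this sample: stream events via run_stream() until a MagenticPlanReviewRequest
+    # surfaces, collect the human's decision, then resume with send_responses_streaming()
+    # keyed by the original request_id. approve() accepts the proposed plan as-is;
+    # revise(feedback) sends textual feedback so the orchestrator can update the plan.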
+ + print(f"\nTask: {task}") + print("\nStarting workflow execution...") + print("=" * 60) + + pending_request: RequestInfoEvent | None = None + pending_responses: dict[str, object] | None = None + output_event: WorkflowOutputEvent | None = None + + while not output_event: + if pending_responses is not None: + stream = workflow.send_responses_streaming(pending_responses) + else: + stream = workflow.run_stream(task) + + last_message_id: str | None = None + async for event in stream: + if isinstance(event, AgentRunUpdateEvent): + message_id = event.data.message_id + if message_id != last_message_id: + if last_message_id is not None: + print("\n") + print(f"- {event.executor_id}:", end=" ", flush=True) + last_message_id = message_id + print(event.data, end="", flush=True) + + elif isinstance(event, RequestInfoEvent) and event.request_type is MagenticPlanReviewRequest: + pending_request = event + + elif isinstance(event, WorkflowOutputEvent): + output_event = event + + pending_responses = None + + # Handle plan review request if any + if pending_request is not None: + event_data = cast(MagenticPlanReviewRequest, pending_request.data) + + print("\n\n[Magentic Plan Review Request]") + if event_data.current_progress is not None: + print("Current Progress Ledger:") + print(json.dumps(event_data.current_progress.to_dict(), indent=2)) + print() + print(f"Proposed Plan:\n{event_data.plan.text}\n") + print("Please provide your feedback (press Enter to approve):") + + reply = await asyncio.get_event_loop().run_in_executor(None, input, "> ") + if reply.strip() == "": + print("Plan approved.\n") + pending_responses = {pending_request.request_id: event_data.approve()} + else: + print("Plan revised by human.\n") + pending_responses = {pending_request.request_id: event_data.revise(reply)} + pending_request = None + + print("\n" + "=" * 60) + print("WORKFLOW COMPLETED") + print("=" * 60) + print("Final Output:") + # The output of the Magentic workflow is a list of ChatMessages with only one final message + # generated by the orchestrator. + output_messages = cast(list[ChatMessage], output_event.data) + if output_messages: + output = output_messages[-1].text + print(output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/orchestration/magentic_human_plan_update.py b/python/samples/getting_started/workflows/orchestration/magentic_human_plan_update.py deleted file mode 100644 index b96fac7e99..0000000000 --- a/python/samples/getting_started/workflows/orchestration/magentic_human_plan_update.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import asyncio -import logging -from typing import cast - -from agent_framework import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, - AgentRunUpdateEvent, - ChatAgent, - HostedCodeInterpreterTool, - MagenticBuilder, - MagenticHumanInterventionDecision, - MagenticHumanInterventionKind, - MagenticHumanInterventionReply, - MagenticHumanInterventionRequest, - RequestInfoEvent, - WorkflowOutputEvent, -) -from agent_framework.openai import OpenAIChatClient, OpenAIResponsesClient - -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - -""" -Sample: Magentic Orchestration + Human Plan Review - -What it does: -- Builds a Magentic workflow with two agents and enables human plan review. - A human approves or edits the plan via `RequestInfoEvent` before execution. 
- -- researcher: ChatAgent backed by OpenAIChatClient (web/search-capable model) -- coder: ChatAgent backed by OpenAIAssistantsClient with the Hosted Code Interpreter tool - -Key behaviors demonstrated: -- with_plan_review(): requests a PlanReviewRequest before coordination begins -- Event loop that waits for RequestInfoEvent[PlanReviewRequest], prints the plan, then - replies with PlanReviewReply (here we auto-approve, but you can edit/collect input) -- Callbacks: on_agent_stream (incremental chunks), on_agent_response (final messages), - on_result (final answer), and on_exception -- Workflow completion when idle - -Prerequisites: -- OpenAI credentials configured for `OpenAIChatClient` and `OpenAIResponsesClient`. -""" - - -async def main() -> None: - researcher_agent = ChatAgent( - name="ResearcherAgent", - description="Specialist in research and information gathering", - instructions=( - "You are a Researcher. You find information without additional computation or quantitative analysis." - ), - # This agent requires the gpt-4o-search-preview model to perform web searches. - # Feel free to explore with other agents that support web search, for example, - # the `OpenAIResponseAgent` or `AzureAgentProtocol` with bing grounding. - chat_client=OpenAIChatClient(model_id="gpt-4o-search-preview"), - ) - - coder_agent = ChatAgent( - name="CoderAgent", - description="A helpful assistant that writes and executes code to process and analyze data.", - instructions="You solve questions using code. Please provide detailed analysis and computation process.", - chat_client=OpenAIResponsesClient(), - tools=HostedCodeInterpreterTool(), - ) - - # Create a manager agent for the orchestration - manager_agent = ChatAgent( - name="MagenticManager", - description="Orchestrator that coordinates the research and coding workflow", - instructions="You coordinate a team to complete complex tasks efficiently.", - chat_client=OpenAIChatClient(), - ) - - # Callbacks - def on_exception(exception: Exception) -> None: - print(f"Exception occurred: {exception}") - logger.exception("Workflow exception", exc_info=exception) - - last_stream_agent_id: str | None = None - stream_line_open: bool = False - - print("\nBuilding Magentic Workflow...") - - workflow = ( - MagenticBuilder() - .participants(researcher=researcher_agent, coder=coder_agent) - .with_standard_manager( - agent=manager_agent, - max_round_count=10, - max_stall_count=3, - max_reset_count=2, - ) - .with_plan_review() - .build() - ) - - task = ( - "I am preparing a report on the energy efficiency of different machine learning model architectures. " - "Compare the estimated training and inference energy consumption of ResNet-50, BERT-base, and GPT-2 " - "on standard datasets (e.g., ImageNet for ResNet, GLUE for BERT, WebText for GPT-2). " - "Then, estimate the CO2 emissions associated with each, assuming training on an Azure Standard_NC6s_v3 " - "VM for 24 hours. Provide tables for clarity, and recommend the most energy-efficient model " - "per task type (image classification, text classification, and text generation)." 
- ) - - print(f"\nTask: {task}") - print("\nStarting workflow execution...") - - try: - pending_request: RequestInfoEvent | None = None - pending_responses: dict[str, MagenticHumanInterventionReply] | None = None - completed = False - workflow_output: str | None = None - - while not completed: - # Use streaming for both initial run and response sending - if pending_responses is not None: - stream = workflow.send_responses_streaming(pending_responses) - else: - stream = workflow.run_stream(task) - - # Collect events from the stream - async for event in stream: - if isinstance(event, AgentRunUpdateEvent): - props = event.data.additional_properties if event.data else None - event_type = props.get("magentic_event_type") if props else None - - if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR: - kind = props.get("orchestrator_message_kind", "") if props else "" - text = event.data.text if event.data else "" - print(f"\n[ORCH:{kind}]\n\n{text}\n{'-' * 26}") - elif event_type == MAGENTIC_EVENT_TYPE_AGENT_DELTA: - agent_id = props.get("agent_id", "unknown") if props else "unknown" - if last_stream_agent_id != agent_id or not stream_line_open: - if stream_line_open: - print() - print(f"\n[STREAM:{agent_id}]: ", end="", flush=True) - last_stream_agent_id = agent_id - stream_line_open = True - if event.data and event.data.text: - print(event.data.text, end="", flush=True) - elif isinstance(event, RequestInfoEvent) and event.request_type is MagenticHumanInterventionRequest: - request = cast(MagenticHumanInterventionRequest, event.data) - if request.kind == MagenticHumanInterventionKind.PLAN_REVIEW: - pending_request = event - if request.plan_text: - print(f"\n=== PLAN REVIEW REQUEST ===\n{request.plan_text}\n") - elif isinstance(event, WorkflowOutputEvent): - # Capture workflow output during streaming - workflow_output = str(event.data) if event.data else None - completed = True - - if stream_line_open: - print() - stream_line_open = False - pending_responses = None - - # Handle pending plan review request - if pending_request is not None: - # Get human input for plan review decision - print("Plan review options:") - print("1. approve - Approve the plan as-is") - print("2. approve with comments - Approve with feedback for the manager") - print("3. revise - Request revision with your feedback") - print("4. edit - Directly edit the plan text") - print("5. 
exit - Exit the workflow") - - while True: - choice = input("Enter your choice (1-5): ").strip().lower() # noqa: ASYNC250 - if choice in ["approve", "1"]: - reply = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.APPROVE) - break - if choice in ["approve with comments", "2"]: - comments = input("Enter your comments for the manager: ").strip() # noqa: ASYNC250 - reply = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.APPROVE, - comments=comments if comments else None, - ) - break - if choice in ["revise", "3"]: - comments = input("Enter feedback for revising the plan: ").strip() # noqa: ASYNC250 - reply = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.REVISE, - comments=comments if comments else None, - ) - break - if choice in ["edit", "4"]: - print("Enter your edited plan (end with an empty line):") - lines = [] - while True: - line = input() # noqa: ASYNC250 - if line == "": - break - lines.append(line) - edited_plan = "\n".join(lines) - reply = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.REVISE, - edited_plan_text=edited_plan if edited_plan else None, - ) - break - if choice in ["exit", "5"]: - print("Exiting workflow...") - return - print("Invalid choice. Please enter a number 1-5.") - - pending_responses = {pending_request.request_id: reply} - pending_request = None - - # Show final result from captured workflow output - if workflow_output: - print(f"Workflow completed with result:\n\n{workflow_output}") - - except Exception as e: - print(f"Workflow execution failed: {e}") - on_exception(e) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/orchestration/magentic_human_replan.py b/python/samples/getting_started/workflows/orchestration/magentic_human_replan.py deleted file mode 100644 index aaa9be66f8..0000000000 --- a/python/samples/getting_started/workflows/orchestration/magentic_human_replan.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import asyncio -import logging -from typing import cast - -from agent_framework import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, - AgentRunUpdateEvent, - ChatAgent, - ChatMessage, - MagenticBuilder, - MagenticHumanInterventionDecision, - MagenticHumanInterventionKind, - MagenticHumanInterventionReply, - MagenticHumanInterventionRequest, - RequestInfoEvent, - WorkflowOutputEvent, -) -from agent_framework.openai import OpenAIChatClient - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -""" -Sample: Magentic Orchestration with Human Stall Intervention - -This sample demonstrates how humans can intervene when a Magentic workflow stalls. -When agents stop making progress, the workflow requests human input instead of -automatically replanning. - -Key concepts: -- with_human_input_on_stall(): Enables human intervention when workflow detects stalls -- MagenticHumanInterventionKind.STALL: The request kind for stall interventions -- Human can choose to: continue, trigger replan, or provide guidance - -Stall intervention options: -- CONTINUE: Reset stall counter and continue with current plan -- REPLAN: Trigger automatic replanning by the manager -- GUIDANCE: Provide text guidance to help agents get back on track - -Prerequisites: -- OpenAI credentials configured for `OpenAIChatClient`. - -NOTE: it is sometimes difficult to get the agents to actually stall depending on the task. 
-""" - - -async def main() -> None: - researcher_agent = ChatAgent( - name="ResearcherAgent", - description="Specialist in research and information gathering", - instructions="You are a Researcher. You find information and gather facts.", - chat_client=OpenAIChatClient(model_id="gpt-4o"), - ) - - analyst_agent = ChatAgent( - name="AnalystAgent", - description="Data analyst who processes and summarizes research findings", - instructions="You are an Analyst. You analyze findings and create summaries.", - chat_client=OpenAIChatClient(model_id="gpt-4o"), - ) - - manager_agent = ChatAgent( - name="MagenticManager", - description="Orchestrator that coordinates the workflow", - instructions="You coordinate a team to complete tasks efficiently.", - chat_client=OpenAIChatClient(model_id="gpt-4o"), - ) - - print("\nBuilding Magentic Workflow with Human Stall Intervention...") - - workflow = ( - MagenticBuilder() - .participants(researcher=researcher_agent, analyst=analyst_agent) - .with_standard_manager( - agent=manager_agent, - max_round_count=10, - max_stall_count=1, # Stall detection after 1 round without progress - max_reset_count=2, - ) - .with_human_input_on_stall() # Request human input when stalled (instead of auto-replan) - .build() - ) - - task = "Research sustainable aviation fuel technology and summarize the findings." - - print(f"\nTask: {task}") - print("\nStarting workflow execution...") - print("=" * 60) - - try: - pending_request: RequestInfoEvent | None = None - pending_responses: dict[str, object] | None = None - completed = False - workflow_output: str | None = None - - last_stream_agent_id: str | None = None - stream_line_open: bool = False - - while not completed: - if pending_responses is not None: - stream = workflow.send_responses_streaming(pending_responses) - else: - stream = workflow.run_stream(task) - - async for event in stream: - if isinstance(event, AgentRunUpdateEvent): - props = event.data.additional_properties if event.data else None - event_type = props.get("magentic_event_type") if props else None - - if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR: - kind = props.get("orchestrator_message_kind", "") if props else "" - text = event.data.text if event.data else "" - if stream_line_open: - print() - stream_line_open = False - print(f"\n[ORCHESTRATOR: {kind}]\n{text}\n{'-' * 40}") - elif event_type == MAGENTIC_EVENT_TYPE_AGENT_DELTA: - agent_id = props.get("agent_id", "unknown") if props else "unknown" - if last_stream_agent_id != agent_id or not stream_line_open: - if stream_line_open: - print() - print(f"\n[{agent_id}]: ", end="", flush=True) - last_stream_agent_id = agent_id - stream_line_open = True - if event.data and event.data.text: - print(event.data.text, end="", flush=True) - - elif isinstance(event, RequestInfoEvent) and event.request_type is MagenticHumanInterventionRequest: - if stream_line_open: - print() - stream_line_open = False - pending_request = event - req = cast(MagenticHumanInterventionRequest, event.data) - - if req.kind == MagenticHumanInterventionKind.STALL: - print("\n" + "=" * 60) - print("STALL INTERVENTION REQUESTED") - print("=" * 60) - print(f"\nWorkflow appears stalled after {req.stall_count} rounds") - print(f"Reason: {req.stall_reason}") - if req.last_agent: - print(f"Last active agent: {req.last_agent}") - if req.plan_text: - print(f"\nCurrent plan:\n{req.plan_text}") - print() - - elif isinstance(event, WorkflowOutputEvent): - if stream_line_open: - print() - stream_line_open = False - workflow_output = event.data if event.data 
else None - completed = True - - if stream_line_open: - print() - stream_line_open = False - pending_responses = None - - # Handle stall intervention request - if pending_request is not None: - req = cast(MagenticHumanInterventionRequest, pending_request.data) - reply: MagenticHumanInterventionReply | None = None - - if req.kind == MagenticHumanInterventionKind.STALL: - print("Stall intervention options:") - print("1. continue - Continue with current plan (reset stall counter)") - print("2. replan - Trigger automatic replanning") - print("3. guidance - Provide guidance to help agents") - print("4. exit - Exit the workflow") - - while True: - choice = input("Enter your choice (1-4): ").strip().lower() # noqa: ASYNC250 - if choice in ["continue", "1"]: - reply = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.CONTINUE) - break - if choice in ["replan", "2"]: - reply = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.REPLAN) - break - if choice in ["guidance", "3"]: - guidance = input("Enter your guidance: ").strip() # noqa: ASYNC250 - reply = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.GUIDANCE, - comments=guidance if guidance else None, - ) - break - if choice in ["exit", "4"]: - print("Exiting workflow...") - return - print("Invalid choice. Please enter a number 1-4.") - - if reply is not None: - pending_responses = {pending_request.request_id: reply} - pending_request = None - - print("\n" + "=" * 60) - print("WORKFLOW COMPLETED") - print("=" * 60) - if workflow_output: - messages = cast(list[ChatMessage], workflow_output) - if messages: - final_msg = messages[-1] - print(f"\nFinal Result:\n{final_msg.text}") - - except Exception as e: - print(f"Workflow execution failed: {e}") - logger.exception("Workflow exception", exc_info=e) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/orchestration/sequential_custom_executors.py b/python/samples/getting_started/workflows/orchestration/sequential_custom_executors.py index 104a833603..ecf605e73b 100644 --- a/python/samples/getting_started/workflows/orchestration/sequential_custom_executors.py +++ b/python/samples/getting_started/workflows/orchestration/sequential_custom_executors.py @@ -4,6 +4,7 @@ from typing import Any from agent_framework import ( + AgentExecutorResponse, ChatMessage, Executor, Role, @@ -20,18 +21,13 @@ This demonstrates how SequentialBuilder chains participants with a shared conversation context (list[ChatMessage]). An agent produces content; a custom executor appends a compact summary to the conversation. The workflow completes -when idle, and the final output contains the complete conversation. +after all participants have executed in sequence, and the final output contains +the complete conversation. Custom executor contract: -- Provide at least one @handler accepting list[ChatMessage] and a WorkflowContext[list[ChatMessage]] +- Provide at least one @handler accepting AgentExecutorResponse and a WorkflowContext[list[ChatMessage]] - Emit the updated conversation via ctx.send_message([...]) -Note on internal adapters: -- You may see adapter nodes in the event stream such as "input-conversation", - "to-conversation:", and "complete". These provide consistent typing, - conversion of agent responses into the shared conversation, and a single point - for completion—similar to concurrent's dispatcher/aggregator. 
- Prerequisites: - Azure OpenAI access configured for AzureOpenAIChatClient (use az login + env vars) """ @@ -41,11 +37,23 @@ class Summarizer(Executor): """Simple summarizer: consumes full conversation and appends an assistant summary.""" @handler - async def summarize(self, conversation: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: - users = sum(1 for m in conversation if m.role == Role.USER) - assistants = sum(1 for m in conversation if m.role == Role.ASSISTANT) + async def summarize(self, agent_response: AgentExecutorResponse, ctx: WorkflowContext[list[ChatMessage]]) -> None: + """Append a summary message to a copy of the full conversation. + + Note: A custom executor must be able to handle the message type from the prior participant, and produce + the message type expected by the next participant. In this case, the prior participant is an agent thus + the input is AgentExecutorResponse (an agent will be wrapped in an AgentExecutor, which produces + `AgentExecutorResponse`). If the next participant is also an agent or this is the final participant, + the output must be `list[ChatMessage]`. + """ + if not agent_response.full_conversation: + await ctx.send_message([ChatMessage(role=Role.ASSISTANT, text="No conversation to summarize.")]) + return + + users = sum(1 for m in agent_response.full_conversation if m.role == Role.USER) + assistants = sum(1 for m in agent_response.full_conversation if m.role == Role.ASSISTANT) summary = ChatMessage(role=Role.ASSISTANT, text=f"Summary -> users:{users} assistants:{assistants}") - final_conversation = list(conversation) + [summary] + final_conversation = list(agent_response.full_conversation) + [summary] await ctx.send_message(final_conversation) @@ -61,7 +69,7 @@ async def main() -> None: summarizer = Summarizer(id="summarizer") workflow = SequentialBuilder().participants([content, summarizer]).build() - # 3) Run and print final conversation + # 3) Run workflow and extract final conversation events = await workflow.run("Explain the benefits of budget eBikes for commuters.") outputs = events.get_outputs() diff --git a/python/samples/getting_started/workflows/parallelism/fan_out_fan_in_edges.py b/python/samples/getting_started/workflows/parallelism/fan_out_fan_in_edges.py index fbec7ca303..09287dc62f 100644 --- a/python/samples/getting_started/workflows/parallelism/fan_out_fan_in_edges.py +++ b/python/samples/getting_started/workflows/parallelism/fan_out_fan_in_edges.py @@ -36,7 +36,7 @@ Prerequisites: - Familiarity with WorkflowBuilder, executors, edges, events, and streaming runs. - Azure OpenAI access configured for AzureOpenAIChatClient. Log in with Azure CLI and set any required environment variables. -- Comfort reading AgentExecutorResponse.agent_run_response.text for assistant output aggregation. +- Comfort reading AgentExecutorResponse.agent_response.text for assistant output aggregation. """ @@ -67,8 +67,8 @@ async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowCon # Map responses to text by executor id for a simple, predictable demo. by_id: dict[str, str] = {} for r in results: - # AgentExecutorResponse.agent_run_response.text is the assistant text produced by the agent. - by_id[r.executor_id] = r.agent_run_response.text + # AgentExecutorResponse.agent_response.text is the assistant text produced by the agent. 
+ by_id[r.executor_id] = r.agent_response.text research_text = by_id.get("researcher", "") marketing_text = by_id.get("marketer", "") diff --git a/python/samples/getting_started/workflows/state-management/shared_states_with_agents.py b/python/samples/getting_started/workflows/state-management/shared_states_with_agents.py index e9098f996e..e9d38d3161 100644 --- a/python/samples/getting_started/workflows/state-management/shared_states_with_agents.py +++ b/python/samples/getting_started/workflows/state-management/shared_states_with_agents.py @@ -117,7 +117,7 @@ async def to_detection_result(response: AgentExecutorResponse, ctx: WorkflowCont 2) Retrieve the current email_id from shared state. 3) Send a typed DetectionResult for conditional routing. """ - parsed = DetectionResultAgent.model_validate_json(response.agent_run_response.text) + parsed = DetectionResultAgent.model_validate_json(response.agent_response.text) email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY) await ctx.send_message(DetectionResult(is_spam=parsed.is_spam, reason=parsed.reason, email_id=email_id)) @@ -142,7 +142,7 @@ async def submit_to_email_assistant(detection: DetectionResult, ctx: WorkflowCon @executor(id="finalize_and_send") async def finalize_and_send(response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None: """Validate the drafted reply and yield the final output.""" - parsed = EmailResponse.model_validate_json(response.agent_run_response.text) + parsed = EmailResponse.model_validate_json(response.agent_response.text) await ctx.yield_output(f"Email sent: {parsed.response}") @@ -162,7 +162,7 @@ def create_spam_detection_agent() -> ChatAgent: "You are a spam detection assistant that identifies spam emails. " "Always return JSON with fields is_spam (bool) and reason (string)." ), - response_format=DetectionResultAgent, + default_options={"response_format": DetectionResultAgent}, # response_format enforces structured JSON from each agent. name="spam_detection_agent", ) @@ -176,7 +176,7 @@ def create_email_assistant_agent() -> ChatAgent: "Return JSON with a single field 'response' containing the drafted reply." ), # response_format enforces structured JSON from each agent. - response_format=EmailResponse, + default_options={"response_format": EmailResponse}, name="email_assistant_agent", ) diff --git a/python/samples/getting_started/workflows/tool-approval/concurrent_builder_tool_approval.py b/python/samples/getting_started/workflows/tool-approval/concurrent_builder_tool_approval.py index e4092414fc..a858fe28ce 100644 --- a/python/samples/getting_started/workflows/tool-approval/concurrent_builder_tool_approval.py +++ b/python/samples/getting_started/workflows/tool-approval/concurrent_builder_tool_approval.py @@ -10,8 +10,6 @@ FunctionApprovalResponseContent, RequestInfoEvent, WorkflowOutputEvent, - WorkflowRunState, - WorkflowStatusEvent, ai_function, ) from agent_framework.openai import OpenAIChatClient @@ -25,19 +23,18 @@ This sample works as follows: 1. A ConcurrentBuilder workflow is created with two agents running in parallel. -2. One agent has a tool requiring approval (financial transaction). -3. The other agent has only non-approval tools (market data lookup). -4. Both agents receive the same task and work concurrently. -5. When the financial agent tries to execute a trade, it triggers an approval request. -6. The sample simulates human approval and the workflow completes. -7. Results from both agents are aggregated and output. +2. 
Both agents have the same tools, including one requiring approval (execute_trade).
+3. Both agents receive the same task and work concurrently on their respective stocks.
+4. When either agent tries to execute a trade, it triggers an approval request.
+5. The sample simulates human approval and the workflow completes.
+6. Results from both agents are aggregated and output.
 
 Purpose:
-Show how tool call approvals work in parallel execution scenarios where only some
-agents have sensitive tools.
+Show how tool call approvals work in parallel execution scenarios where multiple
+agents may independently trigger approval requests.
 
 Demonstrate:
-- Combining agents with and without approval-required tools in concurrent workflows.
+- Handling multiple approval requests from different agents in concurrent workflows.
 - Handling RequestInfoEvent during concurrent agent execution.
 - Understanding that approval pauses only the agent that triggered it, not all agents.
 
@@ -47,7 +44,7 @@
 """
 
 
-# 1. Define tools for the research agent (no approval required)
+# 1. Define market data tools (no approval required)
 @ai_function
 def get_stock_price(symbol: Annotated[str, "The stock ticker symbol"]) -> str:
     """Get the current stock price for a given symbol."""
@@ -61,10 +58,16 @@ def get_stock_price(symbol: Annotated[str, "The stock ticker symbol"]) -> str:
 def get_market_sentiment(symbol: Annotated[str, "The stock ticker symbol"]) -> str:
     """Get market sentiment analysis for a stock."""
     # Mock sentiment data
-    return f"Market sentiment for {symbol.upper()}: Bullish (72% positive mentions in last 24h)"
+    mock_data = {
+        "AAPL": "Market sentiment for AAPL: Bullish (68% positive mentions in last 24h)",
+        "GOOGL": "Market sentiment for GOOGL: Neutral (50% positive mentions in last 24h)",
+        "MSFT": "Market sentiment for MSFT: Bullish (72% positive mentions in last 24h)",
+        "AMZN": "Market sentiment for AMZN: Bearish (40% positive mentions in last 24h)",
+    }
+    return mock_data.get(symbol.upper(), f"Market sentiment for {symbol.upper()}: Unknown")
 
 
-# 2. Define tools for the trading agent (approval required for trades)
+# 2. Define trading tools (approval required)
 @ai_function(approval_mode="always_require")
 def execute_trade(
     symbol: Annotated[str, "The stock ticker symbol"],
@@ -78,52 +81,70 @@ def execute_trade(
 @ai_function
 def get_portfolio_balance() -> str:
     """Get current portfolio balance and available funds."""
-    return "Portfolio: $50,000 invested, $10,000 cash available"
+    return "Portfolio: $50,000 invested, $10,000 cash available. Holdings: AAPL, GOOGL, MSFT."
+
+
+def _print_output(event: WorkflowOutputEvent) -> None:
+    if not event.data:
+        raise ValueError("WorkflowOutputEvent has no data")
+
+    # Validate before casting: data must be a list whose elements are all ChatMessage.
+    # `or` short-circuits so the element check only runs when data is actually a list.
+    if not isinstance(event.data, list) or not all(isinstance(msg, ChatMessage) for msg in event.data):
+        raise ValueError("WorkflowOutputEvent data is not a list of ChatMessage")
+
+    messages: list[ChatMessage] = event.data  # type: ignore
+
+    print("\n" + "-" * 60)
+    print("Workflow completed. Aggregated results from both agents:")
+    for msg in messages:
+        if msg.text:
+            print(f"- {msg.author_name or msg.role.value}: {msg.text}")
 
 
 async def main() -> None:
-    # 3. Create two agents with different tool sets
+    # 3. Create two agents focused on different stocks but with the same tool sets
    chat_client = OpenAIChatClient()
 
-    research_agent = chat_client.create_agent(
-        name="ResearchAgent",
+    microsoft_agent = chat_client.create_agent(
+        name="MicrosoftAgent",
         instructions=(
-            "You are a market research analyst. 
Analyze stock data and provide " - "recommendations based on price and sentiment. Do not execute trades." + "You are a personal trading assistant focused on Microsoft (MSFT). " + "You manage my portfolio and take actions based on market data." ), - tools=[get_stock_price, get_market_sentiment], + tools=[get_stock_price, get_market_sentiment, get_portfolio_balance, execute_trade], ) - trading_agent = chat_client.create_agent( - name="TradingAgent", + google_agent = chat_client.create_agent( + name="GoogleAgent", instructions=( - "You are a trading assistant. When asked to buy or sell shares, you MUST " - "call the execute_trade function to complete the transaction. Check portfolio " - "balance first, then execute the requested trade." + "You are a personal trading assistant focused on Google (GOOGL). " + "You manage my trades and portfolio based on market conditions." ), - tools=[get_portfolio_balance, execute_trade], + tools=[get_stock_price, get_market_sentiment, get_portfolio_balance, execute_trade], ) # 4. Build a concurrent workflow with both agents # ConcurrentBuilder requires at least 2 participants for fan-out - workflow = ConcurrentBuilder().participants([research_agent, trading_agent]).build() + workflow = ConcurrentBuilder().participants([microsoft_agent, google_agent]).build() # 5. Start the workflow - both agents will process the same task in parallel print("Starting concurrent workflow with tool approval...") - print("Two agents will analyze MSFT - one for research, one for trading.") print("-" * 60) - # Phase 1: Run workflow and collect all events (stream ends at IDLE or IDLE_WITH_PENDING_REQUESTS) + # Phase 1: Run workflow and collect request info events request_info_events: list[RequestInfoEvent] = [] - workflow_completed_without_approvals = False - async for event in workflow.run_stream("Analyze MSFT stock and if sentiment is positive, buy 10 shares."): + async for event in workflow.run_stream( + "Manage my portfolio. Use a max of 5000 dollars to adjust my position using " + "your best judgment based on market sentiment. No need to confirm trades with me." + ): if isinstance(event, RequestInfoEvent): request_info_events.append(event) if isinstance(event.data, FunctionApprovalRequestContent): print(f"\nApproval requested for tool: {event.data.function_call.name}") print(f" Arguments: {event.data.function_call.arguments}") - elif isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: - workflow_completed_without_approvals = True + elif isinstance(event, WorkflowOutputEvent): + _print_output(event) # 6. Handle approval requests (if any) if request_info_events: @@ -136,46 +155,37 @@ async def main() -> None: if responses: # Phase 2: Send all approvals and continue workflow - output: list[ChatMessage] | None = None async for event in workflow.send_responses_streaming(responses): if isinstance(event, WorkflowOutputEvent): - output = event.data - - if output: - print("\n" + "-" * 60) - print("Workflow completed. Aggregated results from both agents:") - for msg in output: - if hasattr(msg, "author_name") and msg.author_name: - print(f"\n[{msg.author_name}]:") - text = msg.text[:300] + "..." 
if len(msg.text) > 300 else msg.text - if text: - print(f" {text}") - elif workflow_completed_without_approvals: + _print_output(event) + else: print("\nWorkflow completed without requiring approvals.") - print("(The trading agent may have only checked balance without executing a trade)") + print("(The agents may have only checked data without executing trades)") """ Sample Output: Starting concurrent workflow with tool approval... - Two agents will analyze MSFT - one for research, one for trading. ------------------------------------------------------------ Approval requested for tool: execute_trade - Arguments: {"symbol": "MSFT", "action": "buy", "quantity": 10} + Arguments: {"symbol":"MSFT","action":"buy","quantity":13} + + Approval requested for tool: execute_trade + Arguments: {"symbol":"GOOGL","action":"buy","quantity":35} + + Simulating human approval for: execute_trade + Simulating human approval for: execute_trade ------------------------------------------------------------ Workflow completed. Aggregated results from both agents: - - [ResearchAgent]: - MSFT is currently trading at $175.50 with bullish market sentiment - (72% positive mentions). Based on the positive sentiment, this could - be a good opportunity to consider buying. - - [TradingAgent]: - I've checked your portfolio balance ($10,000 cash available) and - executed the trade: BUY 10 shares of MSFT at approximately $175.50 - per share, totaling ~$1,755. + - user: Manage my portfolio. Use a max of 5000 dollars to adjust my position using your best judgment based on + market sentiment. No need to confirm trades with me. + - MicrosoftAgent: I have successfully executed the trade, purchasing 13 shares of Microsoft (MSFT). This action + was based on the positive market sentiment and available funds within the specified limit. + Your portfolio has been adjusted accordingly. + - GoogleAgent: I have successfully executed the trade, purchasing 35 shares of GOOGL. If you need further + assistance or any adjustments, feel free to ask! """ diff --git a/python/samples/getting_started/workflows/tool-approval/group_chat_builder_tool_approval.py b/python/samples/getting_started/workflows/tool-approval/group_chat_builder_tool_approval.py index 565002c794..a8536afc7f 100644 --- a/python/samples/getting_started/workflows/tool-approval/group_chat_builder_tool_approval.py +++ b/python/samples/getting_started/workflows/tool-approval/group_chat_builder_tool_approval.py @@ -4,9 +4,11 @@ from typing import Annotated from agent_framework import ( + AgentRunUpdateEvent, FunctionApprovalRequestContent, GroupChatBuilder, - GroupChatStateSnapshot, + GroupChatRequestSentEvent, + GroupChatState, RequestInfoEvent, ai_function, ) @@ -73,7 +75,7 @@ def create_rollback_plan(version: Annotated[str, "The version being deployed"]) # 2. Define the speaker selector function -def select_next_speaker(state: GroupChatStateSnapshot) -> str | None: +def select_next_speaker(state: GroupChatState) -> str: """Select the next speaker based on the conversation flow. This simple selector follows a predefined flow: @@ -81,19 +83,13 @@ def select_next_speaker(state: GroupChatStateSnapshot) -> str | None: 2. DevOps Engineer checks staging and creates rollback plan 3. 
DevOps Engineer deploys to production (triggers approval) """ - round_index: int = state["round_index"] + if not state.conversation: + raise RuntimeError("Conversation is empty; cannot select next speaker.") - # Define the conversation flow - speaker_order: list[str] = [ - "QAEngineer", # Round 0: Run tests - "DevOpsEngineer", # Round 1: Check staging, create rollback - "DevOpsEngineer", # Round 2: Deploy to production (approval required) - ] + if len(state.conversation) == 1: + return "QAEngineer" # First speaker - if round_index >= len(speaker_order): - return None # End the conversation - - return speaker_order[round_index] + return "DevOpsEngineer" # Subsequent speakers async def main() -> None: @@ -123,28 +119,47 @@ async def main() -> None: workflow = ( GroupChatBuilder() # Optionally, use `.set_manager(...)` to customize the group chat manager - .set_select_speakers_func(select_next_speaker) + .with_select_speaker_func(select_next_speaker) .participants([qa_engineer, devops_engineer]) - .with_max_rounds(5) + # Set a hard limit to 4 rounds + # First round: QAEngineer speaks + # Second round: DevOpsEngineer speaks (check staging + create rollback) + # Third round: DevOpsEngineer speaks with an approval request (deploy to production) + # Fourth round: DevOpsEngineer speaks again after approval + .with_max_rounds(4) .build() ) # 5. Start the workflow print("Starting group chat workflow for software deployment...") - print("Agents: QA Engineer, DevOps Engineer") + print(f"Agents: {[qa_engineer.name, devops_engineer.name]}") print("-" * 60) # Phase 1: Run workflow and collect all events (stream ends at IDLE or IDLE_WITH_PENDING_REQUESTS) request_info_events: list[RequestInfoEvent] = [] + # Keep track of the last response to format output nicely in streaming mode + last_response_id: str | None = None async for event in workflow.run_stream( "We need to deploy version 2.4.0 to production. Please coordinate the deployment." ): if isinstance(event, RequestInfoEvent): request_info_events.append(event) if isinstance(event.data, FunctionApprovalRequestContent): - print("\n[APPROVAL REQUIRED]") + print("\n[APPROVAL REQUIRED] From agent:", event.source_executor_id) print(f" Tool: {event.data.function_call.name}") print(f" Arguments: {event.data.function_call.arguments}") + elif isinstance(event, AgentRunUpdateEvent): + if not event.data.text: + continue # Skip empty updates + response_id = event.data.response_id + if response_id != last_response_id: + if last_response_id is not None: + print("\n") + print(f"- {event.executor_id}:", end=" ", flush=True) + last_response_id = response_id + print(event.data, end="", flush=True) + elif isinstance(event, GroupChatRequestSentEvent): + print(f"\n[REQUEST SENT ({event.round_index})] to agent: {event.participant_name}") # 6. 
Handle approval requests if request_info_events: @@ -160,8 +175,21 @@ async def main() -> None: approval_response = request_event.data.create_response(approved=True) # Phase 2: Send approval and continue workflow - async for _ in workflow.send_responses_streaming({request_event.request_id: approval_response}): - pass # Consume all events + # Keep track of the response to format output nicely in streaming mode + last_response_id: str | None = None + async for event in workflow.send_responses_streaming({request_event.request_id: approval_response}): + if isinstance(event, AgentRunUpdateEvent): + if not event.data.text: + continue # Skip empty updates + response_id = event.data.response_id + if response_id != last_response_id: + if last_response_id is not None: + print("\n") + print(f"- {event.executor_id}:", end=" ", flush=True) + last_response_id = response_id + print(event.data, end="", flush=True) + elif isinstance(event, GroupChatRequestSentEvent): + print(f"\n[REQUEST SENT ({event.round_index})] To agent: {event.participant_name}") print("\n" + "-" * 60) print("Deployment workflow completed successfully!") diff --git a/python/samples/getting_started/workflows/visualization/concurrent_with_visualization.py b/python/samples/getting_started/workflows/visualization/concurrent_with_visualization.py index 81545b75c7..a0555dab4b 100644 --- a/python/samples/getting_started/workflows/visualization/concurrent_with_visualization.py +++ b/python/samples/getting_started/workflows/visualization/concurrent_with_visualization.py @@ -61,8 +61,8 @@ async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowCon # Map responses to text by executor id for a simple, predictable demo. by_id: dict[str, str] = {} for r in results: - # AgentExecutorResponse.agent_run_response.text contains concatenated assistant text - by_id[r.executor_id] = r.agent_run_response.text + # AgentExecutorResponse.agent_response.text contains concatenated assistant text + by_id[r.executor_id] = r.agent_response.text research_text = by_id.get("researcher", "") marketing_text = by_id.get("marketer", "") diff --git a/python/samples/semantic-kernel-migration/copilot_studio/02_copilot_studio_streaming.py b/python/samples/semantic-kernel-migration/copilot_studio/02_copilot_studio_streaming.py index 186d093495..d437ff807e 100644 --- a/python/samples/semantic-kernel-migration/copilot_studio/02_copilot_studio_streaming.py +++ b/python/samples/semantic-kernel-migration/copilot_studio/02_copilot_studio_streaming.py @@ -26,7 +26,7 @@ async def run_agent_framework() -> None: name="TourGuide", instructions="Provide travel recommendations in short bursts.", ) - # AF streaming provides incremental AgentRunResponseUpdate objects. + # AF streaming provides incremental AgentResponseUpdate objects. print("[AF][stream]", end=" ") async for update in agent.run_stream("Plan a day in Copenhagen for foodies."): if update.text: diff --git a/python/samples/semantic-kernel-migration/openai_responses/03_responses_agent_structured_output.py b/python/samples/semantic-kernel-migration/openai_responses/03_responses_agent_structured_output.py index ffc1bf1713..b124e5f0f1 100644 --- a/python/samples/semantic-kernel-migration/openai_responses/03_responses_agent_structured_output.py +++ b/python/samples/semantic-kernel-migration/openai_responses/03_responses_agent_structured_output.py @@ -49,7 +49,7 @@ async def run_agent_framework() -> None: # AF forwards the same response_format payload at invocation time. 
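     # To consume the structured output as a typed object rather than raw text, the
     # returned reply's JSON can be validated back into the Pydantic model, mirroring
     # the model_validate_json pattern used in the workflow samples above, e.g.:
     #   brief = ReleaseBrief.model_validate_json(reply.text)  # illustrative sketch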
reply = await chat_agent.run( "Draft a launch brief for the Contoso Note app.", - response_format=ReleaseBrief, + options={"response_format": ReleaseBrief}, ) print("[AF]", reply.text) diff --git a/python/uv.lock b/python/uv.lock index 227860a479..be027c7845 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -93,7 +93,7 @@ wheels = [ [[package]] name = "agent-framework" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { virtual = "." } dependencies = [ { name = "agent-framework-core", extra = ["all"], marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -160,7 +160,7 @@ docs = [ [[package]] name = "agent-framework-a2a" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/a2a" } dependencies = [ { name = "a2a-sdk", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -175,7 +175,7 @@ requires-dist = [ [[package]] name = "agent-framework-ag-ui" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/ag-ui" } dependencies = [ { name = "ag-ui-protocol", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -205,7 +205,7 @@ provides-extras = ["dev"] [[package]] name = "agent-framework-anthropic" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/anthropic" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -220,7 +220,7 @@ requires-dist = [ [[package]] name = "agent-framework-azure-ai" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/azure-ai" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -239,7 +239,7 @@ requires-dist = [ [[package]] name = "agent-framework-azure-ai-search" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/azure-ai-search" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -254,7 +254,7 @@ requires-dist = [ [[package]] name = "agent-framework-azurefunctions" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/azurefunctions" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -281,7 +281,7 @@ dev = [{ name = "types-python-dateutil", specifier = ">=2.9.0" }] [[package]] name = "agent-framework-bedrock" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/bedrock" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -298,7 +298,7 @@ requires-dist = [ [[package]] name = "agent-framework-chatkit" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/chatkit" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -313,7 +313,7 @@ requires-dist = [ [[package]] name = "agent-framework-copilotstudio" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/copilotstudio" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 
'win32'" }, @@ -328,7 +328,7 @@ requires-dist = [ [[package]] name = "agent-framework-core" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/core" } dependencies = [ { name = "azure-identity", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -382,7 +382,7 @@ requires-dist = [ { name = "agent-framework-purview", marker = "extra == 'all'", editable = "packages/purview" }, { name = "agent-framework-redis", marker = "extra == 'all'", editable = "packages/redis" }, { name = "azure-identity", specifier = ">=1,<2" }, - { name = "mcp", extras = ["ws"], specifier = ">=1.23" }, + { name = "mcp", extras = ["ws"], specifier = ">=1.24.0,<2" }, { name = "openai", specifier = ">=1.99.0" }, { name = "opentelemetry-api", specifier = ">=1.39.0" }, { name = "opentelemetry-sdk", specifier = ">=1.39.0" }, @@ -396,7 +396,7 @@ provides-extras = ["all"] [[package]] name = "agent-framework-declarative" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/declarative" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -421,7 +421,7 @@ dev = [{ name = "types-pyyaml" }] [[package]] name = "agent-framework-devui" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/devui" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -480,7 +480,7 @@ dev = [{ name = "types-python-dateutil", specifier = ">=2.9.0" }] [[package]] name = "agent-framework-foundry-local" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/foundry_local" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -495,7 +495,7 @@ requires-dist = [ [[package]] name = "agent-framework-lab" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/lab" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -586,7 +586,7 @@ dev = [ [[package]] name = "agent-framework-mem0" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/mem0" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -601,7 +601,7 @@ requires-dist = [ [[package]] name = "agent-framework-ollama" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/ollama" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -616,7 +616,7 @@ requires-dist = [ [[package]] name = "agent-framework-purview" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/purview" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -633,7 +633,7 @@ requires-dist = [ [[package]] name = "agent-framework-redis" -version = "1.0.0b260107" +version = "1.0.0b260114" source = { editable = "packages/redis" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -883,7 +883,7 @@ wheels = [ [[package]] name = 
"anthropic" -version = "0.75.0" +version = "0.76.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -895,9 +895,9 @@ dependencies = [ { name = "sniffio", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/04/1f/08e95f4b7e2d35205ae5dcbb4ae97e7d477fc521c275c02609e2931ece2d/anthropic-0.75.0.tar.gz", hash = "sha256:e8607422f4ab616db2ea5baacc215dd5f028da99ce2f022e33c7c535b29f3dfb", size = 439565, upload-time = "2025-11-24T20:41:45.28Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6e/be/d11abafaa15d6304826438170f7574d750218f49a106c54424a40cef4494/anthropic-0.76.0.tar.gz", hash = "sha256:e0cae6a368986d5cf6df743dfbb1b9519e6a9eee9c6c942ad8121c0b34416ffe", size = 495483, upload-time = "2026-01-13T18:41:14.908Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/60/1c/1cd02b7ae64302a6e06724bf80a96401d5313708651d277b1458504a1730/anthropic-0.75.0-py3-none-any.whl", hash = "sha256:ea8317271b6c15d80225a9f3c670152746e88805a7a61e14d4a374577164965b", size = 388164, upload-time = "2025-11-24T20:41:43.587Z" }, + { url = "https://files.pythonhosted.org/packages/e5/70/7b0fd9c1a738f59d3babe2b4212031c34ab7d0fda4ffef15b58a55c5bcea/anthropic-0.76.0-py3-none-any.whl", hash = "sha256:81efa3113901192af2f0fe977d3ec73fdadb1e691586306c4256cd6d5ccc331c", size = 390309, upload-time = "2026-01-13T18:41:13.483Z" }, ] [[package]] @@ -1408,7 +1408,7 @@ name = "clr-loader" version = "0.2.10" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cffi", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "cffi", marker = "(python_full_version < '3.14' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/18/24/c12faf3f61614b3131b5c98d3bf0d376b49c7feaa73edca559aeb2aee080/clr_loader-0.2.10.tar.gz", hash = "sha256:81f114afbc5005bafc5efe5af1341d400e22137e275b042a8979f3feb9fc9446", size = 83605, upload-time = "2026-01-03T23:13:06.984Z" } wheels = [ @@ -1916,7 +1916,7 @@ name = "exceptiongroup" version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "(python_full_version < '3.13' and sys_platform == 'darwin') or (python_full_version < '3.13' and sys_platform == 'linux') or (python_full_version < '3.13' and sys_platform == 'win32')" }, + { name = "typing-extensions", marker = "(python_full_version < '3.11' and sys_platform == 'darwin') or (python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } wheels = [ @@ -2362,6 +2362,7 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/32/6a/33d1702184d94106d3cdd7bfb788e19723206fce152e303473ca3b946c7b/greenlet-3.3.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:6f8496d434d5cb2dce025773ba5597f71f5410ae499d5dd9533e0653258cdb3d", size = 273658, upload-time = "2025-12-04T14:23:37.494Z" }, { url = "https://files.pythonhosted.org/packages/d6/b7/2b5805bbf1907c26e434f4e448cd8b696a0b71725204fa21a211ff0c04a7/greenlet-3.3.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b96dc7eef78fd404e022e165ec55327f935b9b52ff355b067eb4a0267fc1cffb", size = 574810, upload-time = "2025-12-04T14:50:04.154Z" }, { url = "https://files.pythonhosted.org/packages/94/38/343242ec12eddf3d8458c73f555c084359883d4ddc674240d9e61ec51fd6/greenlet-3.3.0-cp310-cp310-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:73631cd5cccbcfe63e3f9492aaa664d278fda0ce5c3d43aeda8e77317e38efbd", size = 586248, upload-time = "2025-12-04T14:57:39.35Z" }, + { url = "https://files.pythonhosted.org/packages/f0/d0/0ae86792fb212e4384041e0ef8e7bc66f59a54912ce407d26a966ed2914d/greenlet-3.3.0-cp310-cp310-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b299a0cb979f5d7197442dccc3aee67fce53500cd88951b7e6c35575701c980b", size = 597403, upload-time = "2025-12-04T15:07:10.831Z" }, { url = "https://files.pythonhosted.org/packages/b6/a8/15d0aa26c0036a15d2659175af00954aaaa5d0d66ba538345bd88013b4d7/greenlet-3.3.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7dee147740789a4632cace364816046e43310b59ff8fb79833ab043aefa72fd5", size = 586910, upload-time = "2025-12-04T14:25:59.705Z" }, { url = "https://files.pythonhosted.org/packages/e1/9b/68d5e3b7ccaba3907e5532cf8b9bf16f9ef5056a008f195a367db0ff32db/greenlet-3.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:39b28e339fc3c348427560494e28d8a6f3561c8d2bcf7d706e1c624ed8d822b9", size = 1547206, upload-time = "2025-12-04T15:04:21.027Z" }, { url = "https://files.pythonhosted.org/packages/66/bd/e3086ccedc61e49f91e2cfb5ffad9d8d62e5dc85e512a6200f096875b60c/greenlet-3.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b3c374782c2935cc63b2a27ba8708471de4ad1abaa862ffdb1ef45a643ddbb7d", size = 1613359, upload-time = "2025-12-04T14:27:26.548Z" }, @@ -2369,6 +2370,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/cb/48e964c452ca2b92175a9b2dca037a553036cb053ba69e284650ce755f13/greenlet-3.3.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:e29f3018580e8412d6aaf5641bb7745d38c85228dacf51a73bd4e26ddf2a6a8e", size = 274908, upload-time = "2025-12-04T14:23:26.435Z" }, { url = "https://files.pythonhosted.org/packages/28/da/38d7bff4d0277b594ec557f479d65272a893f1f2a716cad91efeb8680953/greenlet-3.3.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a687205fb22794e838f947e2194c0566d3812966b41c78709554aa883183fb62", size = 577113, upload-time = "2025-12-04T14:50:05.493Z" }, { url = "https://files.pythonhosted.org/packages/3c/f2/89c5eb0faddc3ff014f1c04467d67dee0d1d334ab81fadbf3744847f8a8a/greenlet-3.3.0-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4243050a88ba61842186cb9e63c7dfa677ec146160b0efd73b855a3d9c7fcf32", size = 590338, upload-time = "2025-12-04T14:57:41.136Z" }, + { url = "https://files.pythonhosted.org/packages/80/d7/db0a5085035d05134f8c089643da2b44cc9b80647c39e93129c5ef170d8f/greenlet-3.3.0-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:670d0f94cd302d81796e37299bcd04b95d62403883b24225c6b5271466612f45", size = 601098, upload-time = "2025-12-04T15:07:11.898Z" }, { url = "https://files.pythonhosted.org/packages/dc/a6/e959a127b630a58e23529972dbc868c107f9d583b5a9f878fb858c46bc1a/greenlet-3.3.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cb3a8ec3db4a3b0eb8a3c25436c2d49e3505821802074969db017b87bc6a948", size = 590206, upload-time = "2025-12-04T14:26:01.254Z" }, { url = "https://files.pythonhosted.org/packages/48/60/29035719feb91798693023608447283b266b12efc576ed013dd9442364bb/greenlet-3.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2de5a0b09eab81fc6a382791b995b1ccf2b172a9fec934747a7a23d2ff291794", size = 1550668, upload-time = "2025-12-04T15:04:22.439Z" }, { url = "https://files.pythonhosted.org/packages/0a/5f/783a23754b691bfa86bd72c3033aa107490deac9b2ef190837b860996c9f/greenlet-3.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4449a736606bd30f27f8e1ff4678ee193bc47f6ca810d705981cfffd6ce0d8c5", size = 1615483, upload-time = "2025-12-04T14:27:28.083Z" }, @@ -2376,6 +2378,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/0a/a3871375c7b9727edaeeea994bfff7c63ff7804c9829c19309ba2e058807/greenlet-3.3.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:b01548f6e0b9e9784a2c99c5651e5dc89ffcbe870bc5fb2e5ef864e9cc6b5dcb", size = 276379, upload-time = "2025-12-04T14:23:30.498Z" }, { url = "https://files.pythonhosted.org/packages/43/ab/7ebfe34dce8b87be0d11dae91acbf76f7b8246bf9d6b319c741f99fa59c6/greenlet-3.3.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:349345b770dc88f81506c6861d22a6ccd422207829d2c854ae2af8025af303e3", size = 597294, upload-time = "2025-12-04T14:50:06.847Z" }, { url = "https://files.pythonhosted.org/packages/a4/39/f1c8da50024feecd0793dbd5e08f526809b8ab5609224a2da40aad3a7641/greenlet-3.3.0-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e8e18ed6995e9e2c0b4ed264d2cf89260ab3ac7e13555b8032b25a74c6d18655", size = 607742, upload-time = "2025-12-04T14:57:42.349Z" }, + { url = "https://files.pythonhosted.org/packages/77/cb/43692bcd5f7a0da6ec0ec6d58ee7cddb606d055ce94a62ac9b1aa481e969/greenlet-3.3.0-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c024b1e5696626890038e34f76140ed1daf858e37496d33f2af57f06189e70d7", size = 622297, upload-time = "2025-12-04T15:07:13.552Z" }, { url = "https://files.pythonhosted.org/packages/75/b0/6bde0b1011a60782108c01de5913c588cf51a839174538d266de15e4bf4d/greenlet-3.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:047ab3df20ede6a57c35c14bf5200fcf04039d50f908270d3f9a7a82064f543b", size = 609885, upload-time = "2025-12-04T14:26:02.368Z" }, { url = "https://files.pythonhosted.org/packages/49/0e/49b46ac39f931f59f987b7cd9f34bfec8ef81d2a1e6e00682f55be5de9f4/greenlet-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d9ad37fc657b1102ec880e637cccf20191581f75c64087a549e66c57e1ceb53", size = 1567424, upload-time = "2025-12-04T15:04:23.757Z" }, { url = "https://files.pythonhosted.org/packages/05/f5/49a9ac2dff7f10091935def9165c90236d8f175afb27cbed38fb1d61ab6b/greenlet-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83cd0e36932e0e7f36a64b732a6f60c2fc2df28c351bae79fbaf4f8092fe7614", size = 1636017, upload-time = "2025-12-04T14:27:29.688Z" }, @@ -2383,6 +2386,7 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/02/2f/28592176381b9ab2cafa12829ba7b472d177f3acc35d8fbcf3673d966fff/greenlet-3.3.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a1e41a81c7e2825822f4e068c48cb2196002362619e2d70b148f20a831c00739", size = 275140, upload-time = "2025-12-04T14:23:01.282Z" }, { url = "https://files.pythonhosted.org/packages/2c/80/fbe937bf81e9fca98c981fe499e59a3f45df2a04da0baa5c2be0dca0d329/greenlet-3.3.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f515a47d02da4d30caaa85b69474cec77b7929b2e936ff7fb853d42f4bf8808", size = 599219, upload-time = "2025-12-04T14:50:08.309Z" }, { url = "https://files.pythonhosted.org/packages/c2/ff/7c985128f0514271b8268476af89aee6866df5eec04ac17dcfbc676213df/greenlet-3.3.0-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7d2d9fd66bfadf230b385fdc90426fcd6eb64db54b40c495b72ac0feb5766c54", size = 610211, upload-time = "2025-12-04T14:57:43.968Z" }, + { url = "https://files.pythonhosted.org/packages/79/07/c47a82d881319ec18a4510bb30463ed6891f2ad2c1901ed5ec23d3de351f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30a6e28487a790417d036088b3bcb3f3ac7d8babaa7d0139edbaddebf3af9492", size = 624311, upload-time = "2025-12-04T15:07:14.697Z" }, { url = "https://files.pythonhosted.org/packages/fd/8e/424b8c6e78bd9837d14ff7df01a9829fc883ba2ab4ea787d4f848435f23f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:087ea5e004437321508a8d6f20efc4cfec5e3c30118e1417ea96ed1d93950527", size = 612833, upload-time = "2025-12-04T14:26:03.669Z" }, { url = "https://files.pythonhosted.org/packages/b5/ba/56699ff9b7c76ca12f1cdc27a886d0f81f2189c3455ff9f65246780f713d/greenlet-3.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ab97cf74045343f6c60a39913fa59710e4bd26a536ce7ab2397adf8b27e67c39", size = 1567256, upload-time = "2025-12-04T15:04:25.276Z" }, { url = "https://files.pythonhosted.org/packages/1e/37/f31136132967982d698c71a281a8901daf1a8fbab935dce7c0cf15f942cc/greenlet-3.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5375d2e23184629112ca1ea89a53389dddbffcf417dad40125713d88eb5f96e8", size = 1636483, upload-time = "2025-12-04T14:27:30.804Z" }, @@ -2390,6 +2394,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d7/7c/f0a6d0ede2c7bf092d00bc83ad5bafb7e6ec9b4aab2fbdfa6f134dc73327/greenlet-3.3.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:60c2ef0f578afb3c8d92ea07ad327f9a062547137afe91f38408f08aacab667f", size = 275671, upload-time = "2025-12-04T14:23:05.267Z" }, { url = "https://files.pythonhosted.org/packages/44/06/dac639ae1a50f5969d82d2e3dd9767d30d6dbdbab0e1a54010c8fe90263c/greenlet-3.3.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a5d554d0712ba1de0a6c94c640f7aeba3f85b3a6e1f2899c11c2c0428da9365", size = 646360, upload-time = "2025-12-04T14:50:10.026Z" }, { url = "https://files.pythonhosted.org/packages/e0/94/0fb76fe6c5369fba9bf98529ada6f4c3a1adf19e406a47332245ef0eb357/greenlet-3.3.0-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3a898b1e9c5f7307ebbde4102908e6cbfcb9ea16284a3abe15cab996bee8b9b3", size = 658160, upload-time = "2025-12-04T14:57:45.41Z" }, + { url = "https://files.pythonhosted.org/packages/93/79/d2c70cae6e823fac36c3bbc9077962105052b7ef81db2f01ec3b9bf17e2b/greenlet-3.3.0-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:dcd2bdbd444ff340e8d6bdf54d2f206ccddbb3ccfdcd3c25bf4afaa7b8f0cf45", size = 671388, upload-time = "2025-12-04T15:07:15.789Z" }, { url = "https://files.pythonhosted.org/packages/b8/14/bab308fc2c1b5228c3224ec2bf928ce2e4d21d8046c161e44a2012b5203e/greenlet-3.3.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5773edda4dc00e173820722711d043799d3adb4f01731f40619e07ea2750b955", size = 660166, upload-time = "2025-12-04T14:26:05.099Z" }, { url = "https://files.pythonhosted.org/packages/4b/d2/91465d39164eaa0085177f61983d80ffe746c5a1860f009811d498e7259c/greenlet-3.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ac0549373982b36d5fd5d30beb8a7a33ee541ff98d2b502714a09f1169f31b55", size = 1615193, upload-time = "2025-12-04T15:04:27.041Z" }, { url = "https://files.pythonhosted.org/packages/42/1b/83d110a37044b92423084d52d5d5a3b3a73cafb51b547e6d7366ff62eff1/greenlet-3.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d198d2d977460358c3b3a4dc844f875d1adb33817f0613f663a656f463764ccc", size = 1683653, upload-time = "2025-12-04T14:27:32.366Z" }, @@ -2397,6 +2402,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/66/bd6317bc5932accf351fc19f177ffba53712a202f9df10587da8df257c7e/greenlet-3.3.0-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:d6ed6f85fae6cdfdb9ce04c9bf7a08d666cfcfb914e7d006f44f840b46741931", size = 282638, upload-time = "2025-12-04T14:25:20.941Z" }, { url = "https://files.pythonhosted.org/packages/30/cf/cc81cb030b40e738d6e69502ccbd0dd1bced0588e958f9e757945de24404/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d9125050fcf24554e69c4cacb086b87b3b55dc395a8b3ebe6487b045b2614388", size = 651145, upload-time = "2025-12-04T14:50:11.039Z" }, { url = "https://files.pythonhosted.org/packages/9c/ea/1020037b5ecfe95ca7df8d8549959baceb8186031da83d5ecceff8b08cd2/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:87e63ccfa13c0a0f6234ed0add552af24cc67dd886731f2261e46e241608bee3", size = 654236, upload-time = "2025-12-04T14:57:47.007Z" }, + { url = "https://files.pythonhosted.org/packages/69/cc/1e4bae2e45ca2fa55299f4e85854606a78ecc37fead20d69322f96000504/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2662433acbca297c9153a4023fe2161c8dcfdcc91f10433171cf7e7d94ba2221", size = 662506, upload-time = "2025-12-04T15:07:16.906Z" }, { url = "https://files.pythonhosted.org/packages/57/b9/f8025d71a6085c441a7eaff0fd928bbb275a6633773667023d19179fe815/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3c6e9b9c1527a78520357de498b0e709fb9e2f49c3a513afd5a249007261911b", size = 653783, upload-time = "2025-12-04T14:26:06.225Z" }, { url = "https://files.pythonhosted.org/packages/f6/c7/876a8c7a7485d5d6b5c6821201d542ef28be645aa024cfe1145b35c120c1/greenlet-3.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:286d093f95ec98fdd92fcb955003b8a3d054b4e2cab3e2707a5039e7b50520fd", size = 1614857, upload-time = "2025-12-04T15:04:28.484Z" }, { url = "https://files.pythonhosted.org/packages/4f/dc/041be1dff9f23dac5f48a43323cd0789cb798342011c19a248d9c9335536/greenlet-3.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c10513330af5b8ae16f023e8ddbfb486ab355d04467c4679c5cfe4659975dd9", size = 1676034, upload-time = "2025-12-04T14:27:33.531Z" }, @@ -2663,7 +2669,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "1.3.1" +version = "1.3.2" source = { registry = 
"https://pypi.org/simple" } dependencies = [ { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -2677,9 +2683,9 @@ dependencies = [ { name = "typer-slim", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/dd/dd/1cc985c5dda36298b152f75e82a1c81f52243b78fb7e9cad637a29561ad1/huggingface_hub-1.3.1.tar.gz", hash = "sha256:e80e0cfb4a75557c51ab20d575bdea6bb6106c2f97b7c75d8490642f1efb6df5", size = 622356, upload-time = "2026-01-09T14:08:16.888Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/d6/02d1c505e1d3364230e5fa16d2b58c8f36a39c5efe8e99bc4d03d06fd0ca/huggingface_hub-1.3.2.tar.gz", hash = "sha256:15d7902e154f04174a0816d1e9594adcf15cdad57596920a5dc70fadb5d896c7", size = 624018, upload-time = "2026-01-14T13:57:39.635Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/90/fb/cb8fe5f71d5622427f20bcab9e06a696a5aaf21bfe7bd0a8a0c63c88abf5/huggingface_hub-1.3.1-py3-none-any.whl", hash = "sha256:efbc7f3153cb84e2bb69b62ed90985e21ecc9343d15647a419fc0ee4b85f0ac3", size = 533351, upload-time = "2026-01-09T14:08:14.519Z" }, + { url = "https://files.pythonhosted.org/packages/88/1d/acd3ef8aabb7813c6ef2f91785d855583ac5cd7c3599e5c1a1a2ed1ec2e5/huggingface_hub-1.3.2-py3-none-any.whl", hash = "sha256:b552b9562a5532102a041fa31a6966bb9de95138fc7aa578bb3703198c25d1b6", size = 534504, upload-time = "2026-01-14T13:57:37.555Z" }, ] [[package]] @@ -3042,7 +3048,7 @@ wheels = [ [[package]] name = "langfuse" -version = "3.11.2" +version = "3.12.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "backoff", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -3056,87 +3062,87 @@ dependencies = [ { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "wrapt", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/71/10/6b28f3b2c008b1f48478c4f45ceb956dfcc951910f5896b3fe44c20174db/langfuse-3.11.2.tar.gz", hash = "sha256:ab5f296a8056815b7288c7f25bc308a5e79f82a8634467b25daffdde99276e09", size = 230795, upload-time = "2025-12-23T20:42:57.177Z" } +sdist = { url = "https://files.pythonhosted.org/packages/05/d2/33991342653d101715faae8f82c14eb3f0a5c2d22d8c99df9dbb8d099802/langfuse-3.12.0.tar.gz", hash = "sha256:0f75b3d21d4ef4014ebeaa8188eb0c855200412b4e4fb8cceca609a7ce465f91", size = 232651, upload-time = "2026-01-13T14:17:33.659Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/04/95407023b786ed2eef1e2cd220f5baf7b1dd70d88645af129cc1fd1da867/langfuse-3.11.2-py3-none-any.whl", hash = "sha256:84faea9f909694023cc7f0eb45696be190248c8790424f22af57ca4cd7a29f2d", size = 413786, upload-time = "2025-12-23T20:42:55.48Z" }, + { url = "https://files.pythonhosted.org/packages/c3/87/141689c2c2b352ed100de4a63f64f24b4df7f883ba2a3fc0c6733d9d0451/langfuse-3.12.0-py3-none-any.whl", hash = "sha256:644d9bbfa842eb6775b1e069e23f77ad1087f5241682966b8168bbb01f9c357e", size = 416875, upload-time = "2026-01-13T14:17:31.791Z" }, ] [[package]] name = "librt" -version = "0.7.7" -source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/b7/29/47f29026ca17f35cf299290292d5f8331f5077364974b7675a353179afa2/librt-0.7.7.tar.gz", hash = "sha256:81d957b069fed1890953c3b9c3895c7689960f233eea9a1d9607f71ce7f00b2c", size = 145910, upload-time = "2026-01-01T23:52:22.87Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/84/2cfb1f3b9b60bab52e16a220c931223fc8e963d0d7bb9132bef012aafc3f/librt-0.7.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4836c5645f40fbdc275e5670819bde5ab5f2e882290d304e3c6ddab1576a6d0", size = 54709, upload-time = "2026-01-01T23:50:48.326Z" }, - { url = "https://files.pythonhosted.org/packages/19/a1/3127b277e9d3784a8040a54e8396d9ae5c64d6684dc6db4b4089b0eedcfb/librt-0.7.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae8aec43117a645a31e5f60e9e3a0797492e747823b9bda6972d521b436b4e8", size = 56658, upload-time = "2026-01-01T23:50:49.74Z" }, - { url = "https://files.pythonhosted.org/packages/3a/e9/b91b093a5c42eb218120445f3fef82e0b977fa2225f4d6fc133d25cdf86a/librt-0.7.7-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:aea05f701ccd2a76b34f0daf47ca5068176ff553510b614770c90d76ac88df06", size = 161026, upload-time = "2026-01-01T23:50:50.853Z" }, - { url = "https://files.pythonhosted.org/packages/c7/cb/1ded77d5976a79d7057af4a010d577ce4f473ff280984e68f4974a3281e5/librt-0.7.7-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7b16ccaeff0ed4355dfb76fe1ea7a5d6d03b5ad27f295f77ee0557bc20a72495", size = 169529, upload-time = "2026-01-01T23:50:52.24Z" }, - { url = "https://files.pythonhosted.org/packages/da/6e/6ca5bdaa701e15f05000ac1a4c5d1475c422d3484bd3d1ca9e8c2f5be167/librt-0.7.7-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c48c7e150c095d5e3cea7452347ba26094be905d6099d24f9319a8b475fcd3e0", size = 183271, upload-time = "2026-01-01T23:50:55.287Z" }, - { url = "https://files.pythonhosted.org/packages/e7/2d/55c0e38073997b4bbb5ddff25b6d1bbba8c2f76f50afe5bb9c844b702f34/librt-0.7.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4dcee2f921a8632636d1c37f1bbdb8841d15666d119aa61e5399c5268e7ce02e", size = 179039, upload-time = "2026-01-01T23:50:56.807Z" }, - { url = "https://files.pythonhosted.org/packages/33/4e/3662a41ae8bb81b226f3968426293517b271d34d4e9fd4b59fc511f1ae40/librt-0.7.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:14ef0f4ac3728ffd85bfc58e2f2f48fb4ef4fa871876f13a73a7381d10a9f77c", size = 173505, upload-time = "2026-01-01T23:50:58.291Z" }, - { url = "https://files.pythonhosted.org/packages/f8/5d/cf768deb8bdcbac5f8c21fcb32dd483d038d88c529fd351bbe50590b945d/librt-0.7.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e4ab69fa37f8090f2d971a5d2bc606c7401170dbdae083c393d6cbf439cb45b8", size = 193570, upload-time = "2026-01-01T23:50:59.546Z" }, - { url = "https://files.pythonhosted.org/packages/a1/ea/ee70effd13f1d651976d83a2812391f6203971740705e3c0900db75d4bce/librt-0.7.7-cp310-cp310-win32.whl", hash = "sha256:4bf3cc46d553693382d2abf5f5bd493d71bb0f50a7c0beab18aa13a5545c8900", size = 42600, upload-time = "2026-01-01T23:51:00.694Z" }, - { url = "https://files.pythonhosted.org/packages/f0/eb/dc098730f281cba76c279b71783f5de2edcba3b880c1ab84a093ef826062/librt-0.7.7-cp310-cp310-win_amd64.whl", hash = "sha256:f0c8fe5aeadd8a0e5b0598f8a6ee3533135ca50fd3f20f130f9d72baf5c6ac58", size = 48977, upload-time = "2026-01-01T23:51:01.726Z" }, - { url = 
"https://files.pythonhosted.org/packages/f0/56/30b5c342518005546df78841cb0820ae85a17e7d07d521c10ef367306d0d/librt-0.7.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a487b71fbf8a9edb72a8c7a456dda0184642d99cd007bc819c0b7ab93676a8ee", size = 54709, upload-time = "2026-01-01T23:51:02.774Z" }, - { url = "https://files.pythonhosted.org/packages/72/78/9f120e3920b22504d4f3835e28b55acc2cc47c9586d2e1b6ba04c3c1bf01/librt-0.7.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f4d4efb218264ecf0f8516196c9e2d1a0679d9fb3bb15df1155a35220062eba8", size = 56663, upload-time = "2026-01-01T23:51:03.838Z" }, - { url = "https://files.pythonhosted.org/packages/1c/ea/7d7a1ee7dfc1151836028eba25629afcf45b56bbc721293e41aa2e9b8934/librt-0.7.7-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:b8bb331aad734b059c4b450cd0a225652f16889e286b2345af5e2c3c625c3d85", size = 161705, upload-time = "2026-01-01T23:51:04.917Z" }, - { url = "https://files.pythonhosted.org/packages/45/a5/952bc840ac8917fbcefd6bc5f51ad02b89721729814f3e2bfcc1337a76d6/librt-0.7.7-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:467dbd7443bda08338fc8ad701ed38cef48194017554f4c798b0a237904b3f99", size = 171029, upload-time = "2026-01-01T23:51:06.09Z" }, - { url = "https://files.pythonhosted.org/packages/fa/bf/c017ff7da82dc9192cf40d5e802a48a25d00e7639b6465cfdcee5893a22c/librt-0.7.7-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:50d1d1ee813d2d1a3baf2873634ba506b263032418d16287c92ec1cc9c1a00cb", size = 184704, upload-time = "2026-01-01T23:51:07.549Z" }, - { url = "https://files.pythonhosted.org/packages/77/ec/72f3dd39d2cdfd6402ab10836dc9cbf854d145226062a185b419c4f1624a/librt-0.7.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c7e5070cf3ec92d98f57574da0224f8c73faf1ddd6d8afa0b8c9f6e86997bc74", size = 180719, upload-time = "2026-01-01T23:51:09.062Z" }, - { url = "https://files.pythonhosted.org/packages/78/86/06e7a1a81b246f3313bf515dd9613a1c81583e6fd7843a9f4d625c4e926d/librt-0.7.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:bdb9f3d865b2dafe7f9ad7f30ef563c80d0ddd2fdc8cc9b8e4f242f475e34d75", size = 174537, upload-time = "2026-01-01T23:51:10.611Z" }, - { url = "https://files.pythonhosted.org/packages/83/08/f9fb2edc9c7a76e95b2924ce81d545673f5b034e8c5dd92159d1c7dae0c6/librt-0.7.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8185c8497d45164e256376f9da5aed2bb26ff636c798c9dabe313b90e9f25b28", size = 195238, upload-time = "2026-01-01T23:51:11.762Z" }, - { url = "https://files.pythonhosted.org/packages/ba/56/ea2d2489d3ea1f47b301120e03a099e22de7b32c93df9a211e6ff4f9bf38/librt-0.7.7-cp311-cp311-win32.whl", hash = "sha256:44d63ce643f34a903f09ff7ca355aae019a3730c7afd6a3c037d569beeb5d151", size = 42939, upload-time = "2026-01-01T23:51:13.192Z" }, - { url = "https://files.pythonhosted.org/packages/58/7b/c288f417e42ba2a037f1c0753219e277b33090ed4f72f292fb6fe175db4c/librt-0.7.7-cp311-cp311-win_amd64.whl", hash = "sha256:7d13cc340b3b82134f8038a2bfe7137093693dcad8ba5773da18f95ad6b77a8a", size = 49240, upload-time = "2026-01-01T23:51:14.264Z" }, - { url = "https://files.pythonhosted.org/packages/7c/24/738eb33a6c1516fdb2dfd2a35db6e5300f7616679b573585be0409bc6890/librt-0.7.7-cp311-cp311-win_arm64.whl", hash = "sha256:983de36b5a83fe9222f4f7dcd071f9b1ac6f3f17c0af0238dadfb8229588f890", size = 42613, upload-time = "2026-01-01T23:51:15.268Z" }, - { url = 
"https://files.pythonhosted.org/packages/56/72/1cd9d752070011641e8aee046c851912d5f196ecd726fffa7aed2070f3e0/librt-0.7.7-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2a85a1fc4ed11ea0eb0a632459ce004a2d14afc085a50ae3463cd3dfe1ce43fc", size = 55687, upload-time = "2026-01-01T23:51:16.291Z" }, - { url = "https://files.pythonhosted.org/packages/50/aa/d5a1d4221c4fe7e76ae1459d24d6037783cb83c7645164c07d7daf1576ec/librt-0.7.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c87654e29a35938baead1c4559858f346f4a2a7588574a14d784f300ffba0efd", size = 57136, upload-time = "2026-01-01T23:51:17.363Z" }, - { url = "https://files.pythonhosted.org/packages/23/6f/0c86b5cb5e7ef63208c8cc22534df10ecc5278efc0d47fb8815577f3ca2f/librt-0.7.7-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:c9faaebb1c6212c20afd8043cd6ed9de0a47d77f91a6b5b48f4e46ed470703fe", size = 165320, upload-time = "2026-01-01T23:51:18.455Z" }, - { url = "https://files.pythonhosted.org/packages/16/37/df4652690c29f645ffe405b58285a4109e9fe855c5bb56e817e3e75840b3/librt-0.7.7-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1908c3e5a5ef86b23391448b47759298f87f997c3bd153a770828f58c2bb4630", size = 174216, upload-time = "2026-01-01T23:51:19.599Z" }, - { url = "https://files.pythonhosted.org/packages/9a/d6/d3afe071910a43133ec9c0f3e4ce99ee6df0d4e44e4bddf4b9e1c6ed41cc/librt-0.7.7-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dbc4900e95a98fc0729523be9d93a8fedebb026f32ed9ffc08acd82e3e181503", size = 189005, upload-time = "2026-01-01T23:51:21.052Z" }, - { url = "https://files.pythonhosted.org/packages/d5/18/74060a870fe2d9fd9f47824eba6717ce7ce03124a0d1e85498e0e7efc1b2/librt-0.7.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a7ea4e1fbd253e5c68ea0fe63d08577f9d288a73f17d82f652ebc61fa48d878d", size = 183961, upload-time = "2026-01-01T23:51:22.493Z" }, - { url = "https://files.pythonhosted.org/packages/7c/5e/918a86c66304af66a3c1d46d54df1b2d0b8894babc42a14fb6f25511497f/librt-0.7.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:ef7699b7a5a244b1119f85c5bbc13f152cd38240cbb2baa19b769433bae98e50", size = 177610, upload-time = "2026-01-01T23:51:23.874Z" }, - { url = "https://files.pythonhosted.org/packages/b2/d7/b5e58dc2d570f162e99201b8c0151acf40a03a39c32ab824dd4febf12736/librt-0.7.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:955c62571de0b181d9e9e0a0303c8bc90d47670a5eff54cf71bf5da61d1899cf", size = 199272, upload-time = "2026-01-01T23:51:25.341Z" }, - { url = "https://files.pythonhosted.org/packages/18/87/8202c9bd0968bdddc188ec3811985f47f58ed161b3749299f2c0dd0f63fb/librt-0.7.7-cp312-cp312-win32.whl", hash = "sha256:1bcd79be209313b270b0e1a51c67ae1af28adad0e0c7e84c3ad4b5cb57aaa75b", size = 43189, upload-time = "2026-01-01T23:51:26.799Z" }, - { url = "https://files.pythonhosted.org/packages/61/8d/80244b267b585e7aa79ffdac19f66c4861effc3a24598e77909ecdd0850e/librt-0.7.7-cp312-cp312-win_amd64.whl", hash = "sha256:4353ee891a1834567e0302d4bd5e60f531912179578c36f3d0430f8c5e16b456", size = 49462, upload-time = "2026-01-01T23:51:27.813Z" }, - { url = "https://files.pythonhosted.org/packages/2d/1f/75db802d6a4992d95e8a889682601af9b49d5a13bbfa246d414eede1b56c/librt-0.7.7-cp312-cp312-win_arm64.whl", hash = "sha256:a76f1d679beccccdf8c1958e732a1dfcd6e749f8821ee59d7bec009ac308c029", size = 42828, upload-time = "2026-01-01T23:51:28.804Z" }, - { url = 
"https://files.pythonhosted.org/packages/8d/5e/d979ccb0a81407ec47c14ea68fb217ff4315521730033e1dd9faa4f3e2c1/librt-0.7.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f4a0b0a3c86ba9193a8e23bb18f100d647bf192390ae195d84dfa0a10fb6244", size = 55746, upload-time = "2026-01-01T23:51:29.828Z" }, - { url = "https://files.pythonhosted.org/packages/f5/2c/3b65861fb32f802c3783d6ac66fc5589564d07452a47a8cf9980d531cad3/librt-0.7.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5335890fea9f9e6c4fdf8683061b9ccdcbe47c6dc03ab8e9b68c10acf78be78d", size = 57174, upload-time = "2026-01-01T23:51:31.226Z" }, - { url = "https://files.pythonhosted.org/packages/50/df/030b50614b29e443607220097ebaf438531ea218c7a9a3e21ea862a919cd/librt-0.7.7-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9b4346b1225be26def3ccc6c965751c74868f0578cbcba293c8ae9168483d811", size = 165834, upload-time = "2026-01-01T23:51:32.278Z" }, - { url = "https://files.pythonhosted.org/packages/5d/e1/bd8d1eacacb24be26a47f157719553bbd1b3fe812c30dddf121c0436fd0b/librt-0.7.7-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a10b8eebdaca6e9fdbaf88b5aefc0e324b763a5f40b1266532590d5afb268a4c", size = 174819, upload-time = "2026-01-01T23:51:33.461Z" }, - { url = "https://files.pythonhosted.org/packages/46/7d/91d6c3372acf54a019c1ad8da4c9ecf4fc27d039708880bf95f48dbe426a/librt-0.7.7-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:067be973d90d9e319e6eb4ee2a9b9307f0ecd648b8a9002fa237289a4a07a9e7", size = 189607, upload-time = "2026-01-01T23:51:34.604Z" }, - { url = "https://files.pythonhosted.org/packages/fa/ac/44604d6d3886f791fbd1c6ae12d5a782a8f4aca927484731979f5e92c200/librt-0.7.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:23d2299ed007812cccc1ecef018db7d922733382561230de1f3954db28433977", size = 184586, upload-time = "2026-01-01T23:51:35.845Z" }, - { url = "https://files.pythonhosted.org/packages/5c/26/d8a6e4c17117b7f9b83301319d9a9de862ae56b133efb4bad8b3aa0808c9/librt-0.7.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:6b6f8ea465524aa4c7420c7cc4ca7d46fe00981de8debc67b1cc2e9957bb5b9d", size = 178251, upload-time = "2026-01-01T23:51:37.018Z" }, - { url = "https://files.pythonhosted.org/packages/99/ab/98d857e254376f8e2f668e807daccc1f445e4b4fc2f6f9c1cc08866b0227/librt-0.7.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8df32a99cc46eb0ee90afd9ada113ae2cafe7e8d673686cf03ec53e49635439", size = 199853, upload-time = "2026-01-01T23:51:38.195Z" }, - { url = "https://files.pythonhosted.org/packages/7c/55/4523210d6ae5134a5da959900be43ad8bab2e4206687b6620befddb5b5fd/librt-0.7.7-cp313-cp313-win32.whl", hash = "sha256:86f86b3b785487c7760247bcdac0b11aa8bf13245a13ed05206286135877564b", size = 43247, upload-time = "2026-01-01T23:51:39.629Z" }, - { url = "https://files.pythonhosted.org/packages/25/40/3ec0fed5e8e9297b1cf1a3836fb589d3de55f9930e3aba988d379e8ef67c/librt-0.7.7-cp313-cp313-win_amd64.whl", hash = "sha256:4862cb2c702b1f905c0503b72d9d4daf65a7fdf5a9e84560e563471e57a56949", size = 49419, upload-time = "2026-01-01T23:51:40.674Z" }, - { url = "https://files.pythonhosted.org/packages/1c/7a/aab5f0fb122822e2acbc776addf8b9abfb4944a9056c00c393e46e543177/librt-0.7.7-cp313-cp313-win_arm64.whl", hash = "sha256:0996c83b1cb43c00e8c87835a284f9057bc647abd42b5871e5f941d30010c832", size = 42828, upload-time = "2026-01-01T23:51:41.731Z" }, - { url = 
"https://files.pythonhosted.org/packages/69/9c/228a5c1224bd23809a635490a162e9cbdc68d99f0eeb4a696f07886b8206/librt-0.7.7-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:23daa1ab0512bafdd677eb1bfc9611d8ffbe2e328895671e64cb34166bc1b8c8", size = 55188, upload-time = "2026-01-01T23:51:43.14Z" }, - { url = "https://files.pythonhosted.org/packages/ba/c2/0e7c6067e2b32a156308205e5728f4ed6478c501947e9142f525afbc6bd2/librt-0.7.7-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:558a9e5a6f3cc1e20b3168fb1dc802d0d8fa40731f6e9932dcc52bbcfbd37111", size = 56895, upload-time = "2026-01-01T23:51:44.534Z" }, - { url = "https://files.pythonhosted.org/packages/0e/77/de50ff70c80855eb79d1d74035ef06f664dd073fb7fb9d9fb4429651b8eb/librt-0.7.7-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:2567cb48dc03e5b246927ab35cbb343376e24501260a9b5e30b8e255dca0d1d2", size = 163724, upload-time = "2026-01-01T23:51:45.571Z" }, - { url = "https://files.pythonhosted.org/packages/6e/19/f8e4bf537899bdef9e0bb9f0e4b18912c2d0f858ad02091b6019864c9a6d/librt-0.7.7-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6066c638cdf85ff92fc6f932d2d73c93a0e03492cdfa8778e6d58c489a3d7259", size = 172470, upload-time = "2026-01-01T23:51:46.823Z" }, - { url = "https://files.pythonhosted.org/packages/42/4c/dcc575b69d99076768e8dd6141d9aecd4234cba7f0e09217937f52edb6ed/librt-0.7.7-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a609849aca463074c17de9cda173c276eb8fee9e441053529e7b9e249dc8b8ee", size = 186806, upload-time = "2026-01-01T23:51:48.009Z" }, - { url = "https://files.pythonhosted.org/packages/fe/f8/4094a2b7816c88de81239a83ede6e87f1138477d7ee956c30f136009eb29/librt-0.7.7-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:add4e0a000858fe9bb39ed55f31085506a5c38363e6eb4a1e5943a10c2bfc3d1", size = 181809, upload-time = "2026-01-01T23:51:49.35Z" }, - { url = "https://files.pythonhosted.org/packages/1b/ac/821b7c0ab1b5a6cd9aee7ace8309c91545a2607185101827f79122219a7e/librt-0.7.7-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a3bfe73a32bd0bdb9a87d586b05a23c0a1729205d79df66dee65bb2e40d671ba", size = 175597, upload-time = "2026-01-01T23:51:50.636Z" }, - { url = "https://files.pythonhosted.org/packages/71/f9/27f6bfbcc764805864c04211c6ed636fe1d58f57a7b68d1f4ae5ed74e0e0/librt-0.7.7-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:0ecce0544d3db91a40f8b57ae26928c02130a997b540f908cefd4d279d6c5848", size = 196506, upload-time = "2026-01-01T23:51:52.535Z" }, - { url = "https://files.pythonhosted.org/packages/46/ba/c9b9c6fc931dd7ea856c573174ccaf48714905b1a7499904db2552e3bbaf/librt-0.7.7-cp314-cp314-win32.whl", hash = "sha256:8f7a74cf3a80f0c3b0ec75b0c650b2f0a894a2cec57ef75f6f72c1e82cdac61d", size = 39747, upload-time = "2026-01-01T23:51:53.683Z" }, - { url = "https://files.pythonhosted.org/packages/c5/69/cd1269337c4cde3ee70176ee611ab0058aa42fc8ce5c9dce55f48facfcd8/librt-0.7.7-cp314-cp314-win_amd64.whl", hash = "sha256:3d1fe2e8df3268dd6734dba33ededae72ad5c3a859b9577bc00b715759c5aaab", size = 45971, upload-time = "2026-01-01T23:51:54.697Z" }, - { url = "https://files.pythonhosted.org/packages/79/fd/e0844794423f5583108c5991313c15e2b400995f44f6ec6871f8aaf8243c/librt-0.7.7-cp314-cp314-win_arm64.whl", hash = "sha256:2987cf827011907d3dfd109f1be0d61e173d68b1270107bb0e89f2fca7f2ed6b", size = 39075, upload-time = "2026-01-01T23:51:55.726Z" }, - { url = 
"https://files.pythonhosted.org/packages/42/02/211fd8f7c381e7b2a11d0fdfcd410f409e89967be2e705983f7c6342209a/librt-0.7.7-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:8e92c8de62b40bfce91d5e12c6e8b15434da268979b1af1a6589463549d491e6", size = 57368, upload-time = "2026-01-01T23:51:56.706Z" }, - { url = "https://files.pythonhosted.org/packages/4c/b6/aca257affae73ece26041ae76032153266d110453173f67d7603058e708c/librt-0.7.7-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f683dcd49e2494a7535e30f779aa1ad6e3732a019d80abe1309ea91ccd3230e3", size = 59238, upload-time = "2026-01-01T23:51:58.066Z" }, - { url = "https://files.pythonhosted.org/packages/96/47/7383a507d8e0c11c78ca34c9d36eab9000db5989d446a2f05dc40e76c64f/librt-0.7.7-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9b15e5d17812d4d629ff576699954f74e2cc24a02a4fc401882dd94f81daba45", size = 183870, upload-time = "2026-01-01T23:51:59.204Z" }, - { url = "https://files.pythonhosted.org/packages/a4/b8/50f3d8eec8efdaf79443963624175c92cec0ba84827a66b7fcfa78598e51/librt-0.7.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c084841b879c4d9b9fa34e5d5263994f21aea7fd9c6add29194dbb41a6210536", size = 194608, upload-time = "2026-01-01T23:52:00.419Z" }, - { url = "https://files.pythonhosted.org/packages/23/d9/1b6520793aadb59d891e3b98ee057a75de7f737e4a8b4b37fdbecb10d60f/librt-0.7.7-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c8fb9966f84737115513fecbaf257f9553d067a7dd45a69c2c7e5339e6a8dc", size = 206776, upload-time = "2026-01-01T23:52:01.705Z" }, - { url = "https://files.pythonhosted.org/packages/ff/db/331edc3bba929d2756fa335bfcf736f36eff4efcb4f2600b545a35c2ae58/librt-0.7.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:9b5fb1ecb2c35362eab2dbd354fd1efa5a8440d3e73a68be11921042a0edc0ff", size = 203206, upload-time = "2026-01-01T23:52:03.315Z" }, - { url = "https://files.pythonhosted.org/packages/b2/e1/6af79ec77204e85f6f2294fc171a30a91bb0e35d78493532ed680f5d98be/librt-0.7.7-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:d1454899909d63cc9199a89fcc4f81bdd9004aef577d4ffc022e600c412d57f3", size = 196697, upload-time = "2026-01-01T23:52:04.857Z" }, - { url = "https://files.pythonhosted.org/packages/f3/46/de55ecce4b2796d6d243295c221082ca3a944dc2fb3a52dcc8660ce7727d/librt-0.7.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7ef28f2e7a016b29792fe0a2dd04dec75725b32a1264e390c366103f834a9c3a", size = 217193, upload-time = "2026-01-01T23:52:06.159Z" }, - { url = "https://files.pythonhosted.org/packages/41/61/33063e271949787a2f8dd33c5260357e3d512a114fc82ca7890b65a76e2d/librt-0.7.7-cp314-cp314t-win32.whl", hash = "sha256:5e419e0db70991b6ba037b70c1d5bbe92b20ddf82f31ad01d77a347ed9781398", size = 40277, upload-time = "2026-01-01T23:52:07.625Z" }, - { url = "https://files.pythonhosted.org/packages/06/21/1abd972349f83a696ea73159ac964e63e2d14086fdd9bc7ca878c25fced4/librt-0.7.7-cp314-cp314t-win_amd64.whl", hash = "sha256:d6b7d93657332c817b8d674ef6bf1ab7796b4f7ce05e420fd45bd258a72ac804", size = 46765, upload-time = "2026-01-01T23:52:08.647Z" }, - { url = "https://files.pythonhosted.org/packages/51/0e/b756c7708143a63fca65a51ca07990fa647db2cc8fcd65177b9e96680255/librt-0.7.7-cp314-cp314t-win_arm64.whl", hash = "sha256:142c2cd91794b79fd0ce113bd658993b7ede0fe93057668c2f98a45ca00b7e91", size = 39724, upload-time = "2026-01-01T23:52:09.745Z" }, +version = "0.7.8" +source = { registry = 
"https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/24/5f3646ff414285e0f7708fa4e946b9bf538345a41d1c375c439467721a5e/librt-0.7.8.tar.gz", hash = "sha256:1a4ede613941d9c3470b0368be851df6bb78ab218635512d0370b27a277a0862", size = 148323, upload-time = "2026-01-14T12:56:16.876Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/13/57b06758a13550c5f09563893b004f98e9537ee6ec67b7df85c3571c8832/librt-0.7.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b45306a1fc5f53c9330fbee134d8b3227fe5da2ab09813b892790400aa49352d", size = 56521, upload-time = "2026-01-14T12:54:40.066Z" }, + { url = "https://files.pythonhosted.org/packages/c2/24/bbea34d1452a10612fb45ac8356f95351ba40c2517e429602160a49d1fd0/librt-0.7.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:864c4b7083eeee250ed55135d2127b260d7eb4b5e953a9e5df09c852e327961b", size = 58456, upload-time = "2026-01-14T12:54:41.471Z" }, + { url = "https://files.pythonhosted.org/packages/04/72/a168808f92253ec3a810beb1eceebc465701197dbc7e865a1c9ceb3c22c7/librt-0.7.8-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6938cc2de153bc927ed8d71c7d2f2ae01b4e96359126c602721340eb7ce1a92d", size = 164392, upload-time = "2026-01-14T12:54:42.843Z" }, + { url = "https://files.pythonhosted.org/packages/14/5c/4c0d406f1b02735c2e7af8ff1ff03a6577b1369b91aa934a9fa2cc42c7ce/librt-0.7.8-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:66daa6ac5de4288a5bbfbe55b4caa7bf0cd26b3269c7a476ffe8ce45f837f87d", size = 172959, upload-time = "2026-01-14T12:54:44.602Z" }, + { url = "https://files.pythonhosted.org/packages/82/5f/3e85351c523f73ad8d938989e9a58c7f59fb9c17f761b9981b43f0025ce7/librt-0.7.8-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4864045f49dc9c974dadb942ac56a74cd0479a2aafa51ce272c490a82322ea3c", size = 186717, upload-time = "2026-01-14T12:54:45.986Z" }, + { url = "https://files.pythonhosted.org/packages/08/f8/18bfe092e402d00fe00d33aa1e01dda1bd583ca100b393b4373847eade6d/librt-0.7.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a36515b1328dc5b3ffce79fe204985ca8572525452eacabee2166f44bb387b2c", size = 184585, upload-time = "2026-01-14T12:54:47.139Z" }, + { url = "https://files.pythonhosted.org/packages/4e/fc/f43972ff56fd790a9fa55028a52ccea1875100edbb856b705bd393b601e3/librt-0.7.8-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b7e7f140c5169798f90b80d6e607ed2ba5059784968a004107c88ad61fb3641d", size = 180497, upload-time = "2026-01-14T12:54:48.946Z" }, + { url = "https://files.pythonhosted.org/packages/e1/3a/25e36030315a410d3ad0b7d0f19f5f188e88d1613d7d3fd8150523ea1093/librt-0.7.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ff71447cb778a4f772ddc4ce360e6ba9c95527ed84a52096bd1bbf9fee2ec7c0", size = 200052, upload-time = "2026-01-14T12:54:50.382Z" }, + { url = "https://files.pythonhosted.org/packages/fc/b8/f3a5a1931ae2a6ad92bf6893b9ef44325b88641d58723529e2c2935e8abe/librt-0.7.8-cp310-cp310-win32.whl", hash = "sha256:047164e5f68b7a8ebdf9fae91a3c2161d3192418aadd61ddd3a86a56cbe3dc85", size = 43477, upload-time = "2026-01-14T12:54:51.815Z" }, + { url = "https://files.pythonhosted.org/packages/fe/91/c4202779366bc19f871b4ad25db10fcfa1e313c7893feb942f32668e8597/librt-0.7.8-cp310-cp310-win_amd64.whl", hash = "sha256:d6f254d096d84156a46a84861183c183d30734e52383602443292644d895047c", size = 49806, upload-time = "2026-01-14T12:54:53.149Z" }, + { url = 
"https://files.pythonhosted.org/packages/1b/a3/87ea9c1049f2c781177496ebee29430e4631f439b8553a4969c88747d5d8/librt-0.7.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ff3e9c11aa260c31493d4b3197d1e28dd07768594a4f92bec4506849d736248f", size = 56507, upload-time = "2026-01-14T12:54:54.156Z" }, + { url = "https://files.pythonhosted.org/packages/5e/4a/23bcef149f37f771ad30203d561fcfd45b02bc54947b91f7a9ac34815747/librt-0.7.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddb52499d0b3ed4aa88746aaf6f36a08314677d5c346234c3987ddc506404eac", size = 58455, upload-time = "2026-01-14T12:54:55.978Z" }, + { url = "https://files.pythonhosted.org/packages/22/6e/46eb9b85c1b9761e0f42b6e6311e1cc544843ac897457062b9d5d0b21df4/librt-0.7.8-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:e9c0afebbe6ce177ae8edba0c7c4d626f2a0fc12c33bb993d163817c41a7a05c", size = 164956, upload-time = "2026-01-14T12:54:57.311Z" }, + { url = "https://files.pythonhosted.org/packages/7a/3f/aa7c7f6829fb83989feb7ba9aa11c662b34b4bd4bd5b262f2876ba3db58d/librt-0.7.8-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:631599598e2c76ded400c0a8722dec09217c89ff64dc54b060f598ed68e7d2a8", size = 174364, upload-time = "2026-01-14T12:54:59.089Z" }, + { url = "https://files.pythonhosted.org/packages/3f/2d/d57d154b40b11f2cb851c4df0d4c4456bacd9b1ccc4ecb593ddec56c1a8b/librt-0.7.8-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c1ba843ae20db09b9d5c80475376168feb2640ce91cd9906414f23cc267a1ff", size = 188034, upload-time = "2026-01-14T12:55:00.141Z" }, + { url = "https://files.pythonhosted.org/packages/59/f9/36c4dad00925c16cd69d744b87f7001792691857d3b79187e7a673e812fb/librt-0.7.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b5b007bb22ea4b255d3ee39dfd06d12534de2fcc3438567d9f48cdaf67ae1ae3", size = 186295, upload-time = "2026-01-14T12:55:01.303Z" }, + { url = "https://files.pythonhosted.org/packages/23/9b/8a9889d3df5efb67695a67785028ccd58e661c3018237b73ad081691d0cb/librt-0.7.8-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:dbd79caaf77a3f590cbe32dc2447f718772d6eea59656a7dcb9311161b10fa75", size = 181470, upload-time = "2026-01-14T12:55:02.492Z" }, + { url = "https://files.pythonhosted.org/packages/43/64/54d6ef11afca01fef8af78c230726a9394759f2addfbf7afc5e3cc032a45/librt-0.7.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:87808a8d1e0bd62a01cafc41f0fd6818b5a5d0ca0d8a55326a81643cdda8f873", size = 201713, upload-time = "2026-01-14T12:55:03.919Z" }, + { url = "https://files.pythonhosted.org/packages/2d/29/73e7ed2991330b28919387656f54109139b49e19cd72902f466bd44415fd/librt-0.7.8-cp311-cp311-win32.whl", hash = "sha256:31724b93baa91512bd0a376e7cf0b59d8b631ee17923b1218a65456fa9bda2e7", size = 43803, upload-time = "2026-01-14T12:55:04.996Z" }, + { url = "https://files.pythonhosted.org/packages/3f/de/66766ff48ed02b4d78deea30392ae200bcbd99ae61ba2418b49fd50a4831/librt-0.7.8-cp311-cp311-win_amd64.whl", hash = "sha256:978e8b5f13e52cf23a9e80f3286d7546baa70bc4ef35b51d97a709d0b28e537c", size = 50080, upload-time = "2026-01-14T12:55:06.489Z" }, + { url = "https://files.pythonhosted.org/packages/6f/e3/33450438ff3a8c581d4ed7f798a70b07c3206d298cf0b87d3806e72e3ed8/librt-0.7.8-cp311-cp311-win_arm64.whl", hash = "sha256:20e3946863d872f7cabf7f77c6c9d370b8b3d74333d3a32471c50d3a86c0a232", size = 43383, upload-time = "2026-01-14T12:55:07.49Z" }, + { url = 
"https://files.pythonhosted.org/packages/56/04/79d8fcb43cae376c7adbab7b2b9f65e48432c9eced62ac96703bcc16e09b/librt-0.7.8-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9b6943885b2d49c48d0cff23b16be830ba46b0152d98f62de49e735c6e655a63", size = 57472, upload-time = "2026-01-14T12:55:08.528Z" }, + { url = "https://files.pythonhosted.org/packages/b4/ba/60b96e93043d3d659da91752689023a73981336446ae82078cddf706249e/librt-0.7.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:46ef1f4b9b6cc364b11eea0ecc0897314447a66029ee1e55859acb3dd8757c93", size = 58986, upload-time = "2026-01-14T12:55:09.466Z" }, + { url = "https://files.pythonhosted.org/packages/7c/26/5215e4cdcc26e7be7eee21955a7e13cbf1f6d7d7311461a6014544596fac/librt-0.7.8-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:907ad09cfab21e3c86e8f1f87858f7049d1097f77196959c033612f532b4e592", size = 168422, upload-time = "2026-01-14T12:55:10.499Z" }, + { url = "https://files.pythonhosted.org/packages/0f/84/e8d1bc86fa0159bfc24f3d798d92cafd3897e84c7fea7fe61b3220915d76/librt-0.7.8-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2991b6c3775383752b3ca0204842743256f3ad3deeb1d0adc227d56b78a9a850", size = 177478, upload-time = "2026-01-14T12:55:11.577Z" }, + { url = "https://files.pythonhosted.org/packages/57/11/d0268c4b94717a18aa91df1100e767b010f87b7ae444dafaa5a2d80f33a6/librt-0.7.8-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03679b9856932b8c8f674e87aa3c55ea11c9274301f76ae8dc4d281bda55cf62", size = 192439, upload-time = "2026-01-14T12:55:12.7Z" }, + { url = "https://files.pythonhosted.org/packages/8d/56/1e8e833b95fe684f80f8894ae4d8b7d36acc9203e60478fcae599120a975/librt-0.7.8-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3968762fec1b2ad34ce57458b6de25dbb4142713e9ca6279a0d352fa4e9f452b", size = 191483, upload-time = "2026-01-14T12:55:13.838Z" }, + { url = "https://files.pythonhosted.org/packages/17/48/f11cf28a2cb6c31f282009e2208312aa84a5ee2732859f7856ee306176d5/librt-0.7.8-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:bb7a7807523a31f03061288cc4ffc065d684c39db7644c676b47d89553c0d714", size = 185376, upload-time = "2026-01-14T12:55:15.017Z" }, + { url = "https://files.pythonhosted.org/packages/b8/6a/d7c116c6da561b9155b184354a60a3d5cdbf08fc7f3678d09c95679d13d9/librt-0.7.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad64a14b1e56e702e19b24aae108f18ad1bf7777f3af5fcd39f87d0c5a814449", size = 206234, upload-time = "2026-01-14T12:55:16.571Z" }, + { url = "https://files.pythonhosted.org/packages/61/de/1975200bb0285fc921c5981d9978ce6ce11ae6d797df815add94a5a848a3/librt-0.7.8-cp312-cp312-win32.whl", hash = "sha256:0241a6ed65e6666236ea78203a73d800dbed896cf12ae25d026d75dc1fcd1dac", size = 44057, upload-time = "2026-01-14T12:55:18.077Z" }, + { url = "https://files.pythonhosted.org/packages/8e/cd/724f2d0b3461426730d4877754b65d39f06a41ac9d0a92d5c6840f72b9ae/librt-0.7.8-cp312-cp312-win_amd64.whl", hash = "sha256:6db5faf064b5bab9675c32a873436b31e01d66ca6984c6f7f92621656033a708", size = 50293, upload-time = "2026-01-14T12:55:19.179Z" }, + { url = "https://files.pythonhosted.org/packages/bd/cf/7e899acd9ee5727ad8160fdcc9994954e79fab371c66535c60e13b968ffc/librt-0.7.8-cp312-cp312-win_arm64.whl", hash = "sha256:57175aa93f804d2c08d2edb7213e09276bd49097611aefc37e3fa38d1fb99ad0", size = 43574, upload-time = "2026-01-14T12:55:20.185Z" }, + { url = 
"https://files.pythonhosted.org/packages/a1/fe/b1f9de2829cf7fc7649c1dcd202cfd873837c5cc2fc9e526b0e7f716c3d2/librt-0.7.8-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4c3995abbbb60b3c129490fa985dfe6cac11d88fc3c36eeb4fb1449efbbb04fc", size = 57500, upload-time = "2026-01-14T12:55:21.219Z" }, + { url = "https://files.pythonhosted.org/packages/eb/d4/4a60fbe2e53b825f5d9a77325071d61cd8af8506255067bf0c8527530745/librt-0.7.8-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:44e0c2cbc9bebd074cf2cdbe472ca185e824be4e74b1c63a8e934cea674bebf2", size = 59019, upload-time = "2026-01-14T12:55:22.256Z" }, + { url = "https://files.pythonhosted.org/packages/6a/37/61ff80341ba5159afa524445f2d984c30e2821f31f7c73cf166dcafa5564/librt-0.7.8-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:4d2f1e492cae964b3463a03dc77a7fe8742f7855d7258c7643f0ee32b6651dd3", size = 169015, upload-time = "2026-01-14T12:55:23.24Z" }, + { url = "https://files.pythonhosted.org/packages/1c/86/13d4f2d6a93f181ebf2fc953868826653ede494559da8268023fe567fca3/librt-0.7.8-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:451e7ffcef8f785831fdb791bd69211f47e95dc4c6ddff68e589058806f044c6", size = 178161, upload-time = "2026-01-14T12:55:24.826Z" }, + { url = "https://files.pythonhosted.org/packages/88/26/e24ef01305954fc4d771f1f09f3dd682f9eb610e1bec188ffb719374d26e/librt-0.7.8-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3469e1af9f1380e093ae06bedcbdd11e407ac0b303a56bbe9afb1d6824d4982d", size = 193015, upload-time = "2026-01-14T12:55:26.04Z" }, + { url = "https://files.pythonhosted.org/packages/88/a0/92b6bd060e720d7a31ed474d046a69bd55334ec05e9c446d228c4b806ae3/librt-0.7.8-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f11b300027ce19a34f6d24ebb0a25fd0e24a9d53353225a5c1e6cadbf2916b2e", size = 192038, upload-time = "2026-01-14T12:55:27.208Z" }, + { url = "https://files.pythonhosted.org/packages/06/bb/6f4c650253704279c3a214dad188101d1b5ea23be0606628bc6739456624/librt-0.7.8-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4adc73614f0d3c97874f02f2c7fd2a27854e7e24ad532ea6b965459c5b757eca", size = 186006, upload-time = "2026-01-14T12:55:28.594Z" }, + { url = "https://files.pythonhosted.org/packages/dc/00/1c409618248d43240cadf45f3efb866837fa77e9a12a71481912135eb481/librt-0.7.8-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:60c299e555f87e4c01b2eca085dfccda1dde87f5a604bb45c2906b8305819a93", size = 206888, upload-time = "2026-01-14T12:55:30.214Z" }, + { url = "https://files.pythonhosted.org/packages/d9/83/b2cfe8e76ff5c1c77f8a53da3d5de62d04b5ebf7cf913e37f8bca43b5d07/librt-0.7.8-cp313-cp313-win32.whl", hash = "sha256:b09c52ed43a461994716082ee7d87618096851319bf695d57ec123f2ab708951", size = 44126, upload-time = "2026-01-14T12:55:31.44Z" }, + { url = "https://files.pythonhosted.org/packages/a9/0b/c59d45de56a51bd2d3a401fc63449c0ac163e4ef7f523ea8b0c0dee86ec5/librt-0.7.8-cp313-cp313-win_amd64.whl", hash = "sha256:f8f4a901a3fa28969d6e4519deceab56c55a09d691ea7b12ca830e2fa3461e34", size = 50262, upload-time = "2026-01-14T12:55:33.01Z" }, + { url = "https://files.pythonhosted.org/packages/fc/b9/973455cec0a1ec592395250c474164c4a58ebf3e0651ee920fef1a2623f1/librt-0.7.8-cp313-cp313-win_arm64.whl", hash = "sha256:43d4e71b50763fcdcf64725ac680d8cfa1706c928b844794a7aa0fa9ac8e5f09", size = 43600, upload-time = "2026-01-14T12:55:34.054Z" }, + { url = 
"https://files.pythonhosted.org/packages/1a/73/fa8814c6ce2d49c3827829cadaa1589b0bf4391660bd4510899393a23ebc/librt-0.7.8-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:be927c3c94c74b05128089a955fba86501c3b544d1d300282cc1b4bd370cb418", size = 57049, upload-time = "2026-01-14T12:55:35.056Z" }, + { url = "https://files.pythonhosted.org/packages/53/fe/f6c70956da23ea235fd2e3cc16f4f0b4ebdfd72252b02d1164dd58b4e6c3/librt-0.7.8-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:7b0803e9008c62a7ef79058233db7ff6f37a9933b8f2573c05b07ddafa226611", size = 58689, upload-time = "2026-01-14T12:55:36.078Z" }, + { url = "https://files.pythonhosted.org/packages/1f/4d/7a2481444ac5fba63050d9abe823e6bc16896f575bfc9c1e5068d516cdce/librt-0.7.8-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:79feb4d00b2a4e0e05c9c56df707934f41fcb5fe53fd9efb7549068d0495b758", size = 166808, upload-time = "2026-01-14T12:55:37.595Z" }, + { url = "https://files.pythonhosted.org/packages/ac/3c/10901d9e18639f8953f57c8986796cfbf4c1c514844a41c9197cf87cb707/librt-0.7.8-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b9122094e3f24aa759c38f46bd8863433820654927370250f460ae75488b66ea", size = 175614, upload-time = "2026-01-14T12:55:38.756Z" }, + { url = "https://files.pythonhosted.org/packages/db/01/5cbdde0951a5090a80e5ba44e6357d375048123c572a23eecfb9326993a7/librt-0.7.8-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7e03bea66af33c95ce3addf87a9bf1fcad8d33e757bc479957ddbc0e4f7207ac", size = 189955, upload-time = "2026-01-14T12:55:39.939Z" }, + { url = "https://files.pythonhosted.org/packages/6a/b4/e80528d2f4b7eaf1d437fcbd6fc6ba4cbeb3e2a0cb9ed5a79f47c7318706/librt-0.7.8-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f1ade7f31675db00b514b98f9ab9a7698c7282dad4be7492589109471852d398", size = 189370, upload-time = "2026-01-14T12:55:41.057Z" }, + { url = "https://files.pythonhosted.org/packages/c1/ab/938368f8ce31a9787ecd4becb1e795954782e4312095daf8fd22420227c8/librt-0.7.8-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a14229ac62adcf1b90a15992f1ab9c69ae8b99ffb23cb64a90878a6e8a2f5b81", size = 183224, upload-time = "2026-01-14T12:55:42.328Z" }, + { url = "https://files.pythonhosted.org/packages/3c/10/559c310e7a6e4014ac44867d359ef8238465fb499e7eb31b6bfe3e3f86f5/librt-0.7.8-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5bcaaf624fd24e6a0cb14beac37677f90793a96864c67c064a91458611446e83", size = 203541, upload-time = "2026-01-14T12:55:43.501Z" }, + { url = "https://files.pythonhosted.org/packages/f8/db/a0db7acdb6290c215f343835c6efda5b491bb05c3ddc675af558f50fdba3/librt-0.7.8-cp314-cp314-win32.whl", hash = "sha256:7aa7d5457b6c542ecaed79cec4ad98534373c9757383973e638ccced0f11f46d", size = 40657, upload-time = "2026-01-14T12:55:44.668Z" }, + { url = "https://files.pythonhosted.org/packages/72/e0/4f9bdc2a98a798511e81edcd6b54fe82767a715e05d1921115ac70717f6f/librt-0.7.8-cp314-cp314-win_amd64.whl", hash = "sha256:3d1322800771bee4a91f3b4bd4e49abc7d35e65166821086e5afd1e6c0d9be44", size = 46835, upload-time = "2026-01-14T12:55:45.655Z" }, + { url = "https://files.pythonhosted.org/packages/f9/3d/59c6402e3dec2719655a41ad027a7371f8e2334aa794ed11533ad5f34969/librt-0.7.8-cp314-cp314-win_arm64.whl", hash = "sha256:5363427bc6a8c3b1719f8f3845ea53553d301382928a86e8fab7984426949bce", size = 39885, upload-time = "2026-01-14T12:55:47.138Z" }, + { url = 
"https://files.pythonhosted.org/packages/4e/9c/2481d80950b83085fb14ba3c595db56330d21bbc7d88a19f20165f3538db/librt-0.7.8-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:ca916919793a77e4a98d4a1701e345d337ce53be4a16620f063191f7322ac80f", size = 59161, upload-time = "2026-01-14T12:55:48.45Z" }, + { url = "https://files.pythonhosted.org/packages/96/79/108df2cfc4e672336765d54e3ff887294c1cc36ea4335c73588875775527/librt-0.7.8-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:54feb7b4f2f6706bb82325e836a01be805770443e2400f706e824e91f6441dde", size = 61008, upload-time = "2026-01-14T12:55:49.527Z" }, + { url = "https://files.pythonhosted.org/packages/46/f2/30179898f9994a5637459d6e169b6abdc982012c0a4b2d4c26f50c06f911/librt-0.7.8-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:39a4c76fee41007070f872b648cc2f711f9abf9a13d0c7162478043377b52c8e", size = 187199, upload-time = "2026-01-14T12:55:50.587Z" }, + { url = "https://files.pythonhosted.org/packages/b4/da/f7563db55cebdc884f518ba3791ad033becc25ff68eb70902b1747dc0d70/librt-0.7.8-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ac9c8a458245c7de80bc1b9765b177055efff5803f08e548dd4bb9ab9a8d789b", size = 198317, upload-time = "2026-01-14T12:55:51.991Z" }, + { url = "https://files.pythonhosted.org/packages/b3/6c/4289acf076ad371471fa86718c30ae353e690d3de6167f7db36f429272f1/librt-0.7.8-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:95b67aa7eff150f075fda09d11f6bfb26edffd300f6ab1666759547581e8f666", size = 210334, upload-time = "2026-01-14T12:55:53.682Z" }, + { url = "https://files.pythonhosted.org/packages/4a/7f/377521ac25b78ac0a5ff44127a0360ee6d5ddd3ce7327949876a30533daa/librt-0.7.8-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:535929b6eff670c593c34ff435d5440c3096f20fa72d63444608a5aef64dd581", size = 211031, upload-time = "2026-01-14T12:55:54.827Z" }, + { url = "https://files.pythonhosted.org/packages/c5/b1/e1e96c3e20b23d00cf90f4aad48f0deb4cdfec2f0ed8380d0d85acf98bbf/librt-0.7.8-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:63937bd0f4d1cb56653dc7ae900d6c52c41f0015e25aaf9902481ee79943b33a", size = 204581, upload-time = "2026-01-14T12:55:56.811Z" }, + { url = "https://files.pythonhosted.org/packages/43/71/0f5d010e92ed9747e14bef35e91b6580533510f1e36a8a09eb79ee70b2f0/librt-0.7.8-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cf243da9e42d914036fd362ac3fa77d80a41cadcd11ad789b1b5eec4daaf67ca", size = 224731, upload-time = "2026-01-14T12:55:58.175Z" }, + { url = "https://files.pythonhosted.org/packages/22/f0/07fb6ab5c39a4ca9af3e37554f9d42f25c464829254d72e4ebbd81da351c/librt-0.7.8-cp314-cp314t-win32.whl", hash = "sha256:171ca3a0a06c643bd0a2f62a8944e1902c94aa8e5da4db1ea9a8daf872685365", size = 41173, upload-time = "2026-01-14T12:55:59.315Z" }, + { url = "https://files.pythonhosted.org/packages/24/d4/7e4be20993dc6a782639625bd2f97f3c66125c7aa80c82426956811cfccf/librt-0.7.8-cp314-cp314t-win_amd64.whl", hash = "sha256:445b7304145e24c60288a2f172b5ce2ca35c0f81605f5299f3fa567e189d2e32", size = 47668, upload-time = "2026-01-14T12:56:00.261Z" }, + { url = "https://files.pythonhosted.org/packages/fc/85/69f92b2a7b3c0f88ffe107c86b952b397004b5b8ea5a81da3d9c04c04422/librt-0.7.8-cp314-cp314t-win_arm64.whl", hash = "sha256:8766ece9de08527deabcd7cb1b4f1a967a385d26e33e536d6d8913db6ef74f06", size = 40550, upload-time = "2026-01-14T12:56:01.542Z" }, ] [[package]] name = "litellm" -version = "1.80.15" +version = 
"1.80.16" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -3154,9 +3160,9 @@ dependencies = [ { name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "tokenizers", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/12/41/9b28df3e4739df83ddb32dfb2bccb12ad271d986494c9fd60e4927a0a6c3/litellm-1.80.15.tar.gz", hash = "sha256:759d09f33c9c6028c58dcdf71781b17b833ee926525714e09a408602be27f54e", size = 13376508, upload-time = "2026-01-11T18:31:44.95Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ac/cc/03bf7849c62587db0fb7c46f427d6a49290750752d3189a0bd95d4b78587/litellm-1.80.16.tar.gz", hash = "sha256:f96233649f99ab097f7d8a3ff9898680207b9eea7d2e23f438074a3dbcf50cca", size = 13384256, upload-time = "2026-01-13T08:52:23.067Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/df/3b/b1bd693721ccb3c9a37c8233d019a643ac57bef5a93f279e5a63839ee4db/litellm-1.80.15-py3-none-any.whl", hash = "sha256:f354e49456985a235b9ed99df1c19d686d30501f96e68882dcc5b29b1e7c59d9", size = 11670707, upload-time = "2026-01-11T18:31:41.67Z" }, + { url = "https://files.pythonhosted.org/packages/53/4d/73fdb12223bdb01889134eb75525fcc768b1724255f2b87072dd6743c6e1/litellm-1.80.16-py3-none-any.whl", hash = "sha256:21be641b350561b293b831addb25249676b72ebff973a5a1d73b5d7cf35bcd1d", size = 11682530, upload-time = "2026-01-13T08:52:19.951Z" }, ] [package.optional-dependencies] @@ -3431,7 +3437,7 @@ wheels = [ [[package]] name = "mem0ai" -version = "1.0.1" +version = "1.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -3442,9 +3448,9 @@ dependencies = [ { name = "qdrant-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "sqlalchemy", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/39/cd/f9047cd45952af08da8084c2297f8aad780f9ac8558631fc64b3ed235b28/mem0ai-1.0.1.tar.gz", hash = "sha256:53be77f479387e6c07508096eb6c0688150b31152613bdcf6c281246b000b14d", size = 182296, upload-time = "2025-11-13T22:32:13.658Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/b3/57edb1253e7dc24d41e102722a585d6e08a96c6191a6a04e43112c01dc5d/mem0ai-1.0.2.tar.gz", hash = "sha256:533c370e8a4e817d47a583cb7fa4df55db59de8dd67be39f2b927e2ad19607d1", size = 182395, upload-time = "2026-01-13T07:40:00.666Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/81/42/120d6db33e190ef09d69428ddd2eaaa87e10f4c8243af788f5fc524748c9/mem0ai-1.0.1-py3-none-any.whl", hash = "sha256:a8eeca9688e87f175af53d463b4a3b2d552984c81e29bc656c847dc04eaf6f75", size = 275351, upload-time = "2025-11-13T22:32:11.839Z" }, + { url = "https://files.pythonhosted.org/packages/d7/82/59309070bd2d2ddccebd89d8ebb7a2155ce12531f0c36123d0a39eada544/mem0ai-1.0.2-py3-none-any.whl", hash = "sha256:3528523653bc57efa477d55e703dcedf8decc23868d4dbcc6d43a97f2315834a", size = 275428, upload-time = "2026-01-13T07:39:58.339Z" }, ] [[package]] @@ -4470,15 +4476,15 @@ wheels = [ [[package]] name = "plotly" -version = "6.5.1" +version = "6.5.2" source = { registry = 
"https://pypi.org/simple" } dependencies = [ { name = "narwhals", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d6/ff/a4938b75e95114451efdb34db6b41930253e67efc8dc737bd592ef2e419d/plotly-6.5.1.tar.gz", hash = "sha256:b0478c8d5ada0c8756bce15315bcbfec7d3ab8d24614e34af9aff7bfcfea9281", size = 7014606, upload-time = "2026-01-07T20:11:41.644Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/4f/8a10a9b9f5192cb6fdef62f1d77fa7d834190b2c50c0cd256bd62879212b/plotly-6.5.2.tar.gz", hash = "sha256:7478555be0198562d1435dee4c308268187553cc15516a2f4dd034453699e393", size = 7015695, upload-time = "2026-01-14T21:26:51.222Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/8e/24e0bb90b2d75af84820693260c5534e9ed351afdda67ed6f393a141a0e2/plotly-6.5.1-py3-none-any.whl", hash = "sha256:5adad4f58c360612b6c5ce11a308cdbc4fd38ceb1d40594a614f0062e227abe1", size = 9894981, upload-time = "2026-01-07T20:11:38.124Z" }, + { url = "https://files.pythonhosted.org/packages/8a/67/f95b5460f127840310d2187f916cf0023b5875c0717fdf893f71e1325e87/plotly-6.5.2-py3-none-any.whl", hash = "sha256:91757653bd9c550eeea2fa2404dba6b85d1e366d54804c340b2c874e5a7eb4a4", size = 9895973, upload-time = "2026-01-14T21:26:47.135Z" }, ] [[package]] @@ -4515,30 +4521,30 @@ wheels = [ [[package]] name = "polars" -version = "1.37.0" +version = "1.37.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "polars-runtime-32", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6c/b5/ce40267c54b66f93572d84f7ba1c216b72a71cb2235e3724fab0911541fe/polars-1.37.0.tar.gz", hash = "sha256:6bbbeefb6f02f848d46ad4f4e922a92573986fd38611801c696bae98b02be4c8", size = 715429, upload-time = "2026-01-10T12:28:06.741Z" } +sdist = { url = "https://files.pythonhosted.org/packages/84/ae/dfebf31b9988c20998140b54d5b521f64ce08879f2c13d9b4d44d7c87e32/polars-1.37.1.tar.gz", hash = "sha256:0309e2a4633e712513401964b4d95452f124ceabf7aec6db50affb9ced4a274e", size = 715572, upload-time = "2026-01-12T23:27:03.267Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/31/07/d890382bbfdeb25db039ef4a8c8f93b3faf0016e18130513274204954203/polars-1.37.0-py3-none-any.whl", hash = "sha256:fcc549b9923ef1bd6fd99b5fd0a00dfedf85406f4758ae018a69bcd18a91f113", size = 805614, upload-time = "2026-01-10T12:26:47.897Z" }, + { url = "https://files.pythonhosted.org/packages/08/75/ec73e38812bca7c2240aff481b9ddff20d1ad2f10dee4b3353f5eeaacdab/polars-1.37.1-py3-none-any.whl", hash = "sha256:377fed8939a2f1223c1563cfabdc7b4a3d6ff846efa1f2ddeb8644fafd9b1aff", size = 805749, upload-time = "2026-01-12T23:25:48.595Z" }, ] [[package]] name = "polars-runtime-32" -version = "1.37.0" +version = "1.37.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/30/92/b818590a5ebcc55657f5483f26133174bd2b9ca88457b60c93669a9d0c75/polars_runtime_32-1.37.0.tar.gz", hash = "sha256:954ddb056e3a2db2cbcaae501225ac5604d1599b6debd9c6dbdf8efbac0e6511", size = 2820371, upload-time = "2026-01-10T12:28:08.195Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/0b/addabe5e8d28a5a4c9887a08907be7ddc3fce892dc38f37d14b055438a57/polars_runtime_32-1.37.1.tar.gz", hash = 
"sha256:68779d4a691da20a5eb767d74165a8f80a2bdfbde4b54acf59af43f7fa028d8f", size = 2818945, upload-time = "2026-01-12T23:27:04.653Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/67/76162c9fcc71b917bdfd2804eaf0ab7cdb264a89b89af4f195a918f9f97d/polars_runtime_32-1.37.0-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3591f4b8e734126d713a12869d3727360acbbcd1d440b45d830497a317a5a8b3", size = 43518436, upload-time = "2026-01-10T12:26:51.442Z" }, - { url = "https://files.pythonhosted.org/packages/cb/ec/56f328e8fa4ebea453f5bc10c579774dff774a873ff224b3108d53c514f9/polars_runtime_32-1.37.0-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:47849420859159681e94589daad3a04ff66a2379c116ccd812d043f7ffe0094c", size = 39663939, upload-time = "2026-01-10T12:26:54.664Z" }, - { url = "https://files.pythonhosted.org/packages/4c/b2/f1ea0edba327a92ce0158b7a0e4abe21f541e44c9fb8ec932cc47592ca5c/polars_runtime_32-1.37.0-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4648ea1e821b9a841b2a562f27bcf54ff1ad21f9c217adcf0f7d0b3c33dc6400", size = 41481348, upload-time = "2026-01-10T12:26:57.598Z" }, - { url = "https://files.pythonhosted.org/packages/3b/21/788a3dd724bb21cf42e2f4daa6510a47787e8b30dd535aa6cae20ea968d0/polars_runtime_32-1.37.0-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5272b6f1680a3e0d77c9f07cb5a54f307079eb5d519c71aa3c37b9af0ee03a9e", size = 45168069, upload-time = "2026-01-10T12:27:00.98Z" }, - { url = "https://files.pythonhosted.org/packages/8a/73/823d6534a20ebdcec4b7706ab2b3f2cfb8e07571305f4e7381cc22d83e31/polars_runtime_32-1.37.0-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:73301ef4fe80d8d748085259a4063ac52ff058088daa702e2a75e7d1ab7f14fc", size = 41675645, upload-time = "2026-01-10T12:27:04.334Z" }, - { url = "https://files.pythonhosted.org/packages/30/54/1bacad96dc2b67d33b886a45b249777212782561493718785cb27c7c362a/polars_runtime_32-1.37.0-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:c60d523d738a7b3660d9abdfaff798f7602488f469d427865965b0bd2e40473a", size = 44737715, upload-time = "2026-01-10T12:27:08.152Z" }, - { url = "https://files.pythonhosted.org/packages/38/e3/aad525d8d89b903fcfa2bd0b4cb66b8a6e83e80b3d1348c5a428092d2983/polars_runtime_32-1.37.0-cp310-abi3-win_amd64.whl", hash = "sha256:f87f76f16e8030d277ecca0c0976aca62ec2b6ba2099ee9c6f75dfc97e7dc1b1", size = 45018403, upload-time = "2026-01-10T12:27:11.292Z" }, - { url = "https://files.pythonhosted.org/packages/0e/4d/ddcaa5f2e18763e02e66d0fd2efca049a42fe96fbeda188e89aeb38dd6fa/polars_runtime_32-1.37.0-cp310-abi3-win_arm64.whl", hash = "sha256:7ffbd9487e3668b0a57519f7ab5ab53ab656086db9f62dceaab41393a07be721", size = 41026243, upload-time = "2026-01-10T12:27:14.563Z" }, + { url = "https://files.pythonhosted.org/packages/2a/a2/e828ea9f845796de02d923edb790e408ca0b560cd68dbd74bb99a1b3c461/polars_runtime_32-1.37.1-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:0b8d4d73ea9977d3731927740e59d814647c5198bdbe359bcf6a8bfce2e79771", size = 43499912, upload-time = "2026-01-12T23:25:51.182Z" }, + { url = "https://files.pythonhosted.org/packages/7e/46/81b71b7aa9e3703ee6e4ef1f69a87e40f58ea7c99212bf49a95071e99c8c/polars_runtime_32-1.37.1-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:c682bf83f5f352e5e02f5c16c652c48ca40442f07b236f30662b22217320ce76", size = 39695707, upload-time = "2026-01-12T23:25:54.289Z" }, + { url = 
"https://files.pythonhosted.org/packages/81/2e/20009d1fde7ee919e24040f5c87cb9d0e4f8e3f109b74ba06bc10c02459c/polars_runtime_32-1.37.1-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc82b5bbe70ca1a4b764eed1419f6336752d6ba9fc1245388d7f8b12438afa2c", size = 41467034, upload-time = "2026-01-12T23:25:56.925Z" }, + { url = "https://files.pythonhosted.org/packages/eb/21/9b55bea940524324625b1e8fd96233290303eb1bf2c23b54573487bbbc25/polars_runtime_32-1.37.1-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8362d11ac5193b994c7e9048ffe22ccfb976699cfbf6e128ce0302e06728894", size = 45142711, upload-time = "2026-01-12T23:26:00.817Z" }, + { url = "https://files.pythonhosted.org/packages/8c/25/c5f64461aeccdac6834a89f826d051ccd3b4ce204075e562c87a06ed2619/polars_runtime_32-1.37.1-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:04f5d5a2f013dca7391b7d8e7672fa6d37573a87f1d45d3dd5f0d9b5565a4b0f", size = 41638564, upload-time = "2026-01-12T23:26:04.186Z" }, + { url = "https://files.pythonhosted.org/packages/35/af/509d3cf6c45e764ccf856beaae26fc34352f16f10f94a7839b1042920a73/polars_runtime_32-1.37.1-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:fbfde7c0ca8209eeaed546e4a32cca1319189aa61c5f0f9a2b4494262bd0c689", size = 44721136, upload-time = "2026-01-12T23:26:07.088Z" }, + { url = "https://files.pythonhosted.org/packages/af/d1/5c0a83a625f72beef59394bebc57d12637997632a4f9d3ab2ffc2cc62bbf/polars_runtime_32-1.37.1-cp310-abi3-win_amd64.whl", hash = "sha256:da3d3642ae944e18dd17109d2a3036cb94ce50e5495c5023c77b1599d4c861bc", size = 44948288, upload-time = "2026-01-12T23:26:10.214Z" }, + { url = "https://files.pythonhosted.org/packages/10/f3/061bb702465904b6502f7c9081daee34b09ccbaa4f8c94cf43a2a3b6dd6f/polars_runtime_32-1.37.1-cp310-abi3-win_arm64.whl", hash = "sha256:55f2c4847a8d2e267612f564de7b753a4bde3902eaabe7b436a0a4abf75949a0", size = 41001914, upload-time = "2026-01-12T23:26:12.997Z" }, ] [[package]] @@ -4575,8 +4581,8 @@ name = "powerfx" version = "0.0.34" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cffi", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "pythonnet", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "cffi", marker = "(python_full_version < '3.14' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" }, + { name = "pythonnet", marker = "(python_full_version < '3.14' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/9f/fb/6c4bf87e0c74ca1c563921ce89ca1c5785b7576bca932f7255cdf81082a7/powerfx-0.0.34.tar.gz", hash = "sha256:956992e7afd272657ed16d80f4cad24ec95d9e4a79fb9dfa4a068a09e136af32", size = 3237555, upload-time = "2025-12-22T15:50:59.682Z" } wheels = [ @@ -5243,7 +5249,7 @@ name = "pythonnet" version = "3.0.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "clr-loader", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "clr-loader", marker = "(python_full_version < '3.14' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" }, ] sdist = { url = 
"https://files.pythonhosted.org/packages/9a/d6/1afd75edd932306ae9bd2c2d961d603dc2b52fcec51b04afea464f1f6646/pythonnet-3.0.5.tar.gz", hash = "sha256:48e43ca463941b3608b32b4e236db92d8d40db4c58a75ace902985f76dac21cf", size = 239212, upload-time = "2024-12-13T08:30:44.393Z" } wheels = [ @@ -5413,109 +5419,109 @@ wheels = [ [[package]] name = "regex" -version = "2025.11.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/cc/a9/546676f25e573a4cf00fe8e119b78a37b6a8fe2dc95cda877b30889c9c45/regex-2025.11.3.tar.gz", hash = "sha256:1fedc720f9bb2494ce31a58a1631f9c82df6a09b49c19517ea5cc280b4541e01", size = 414669, upload-time = "2025-11-03T21:34:22.089Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8a/d6/d788d52da01280a30a3f6268aef2aa71043bff359c618fea4c5b536654d5/regex-2025.11.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:2b441a4ae2c8049106e8b39973bfbddfb25a179dda2bdb99b0eeb60c40a6a3af", size = 488087, upload-time = "2025-11-03T21:30:47.317Z" }, - { url = "https://files.pythonhosted.org/packages/69/39/abec3bd688ec9bbea3562de0fd764ff802976185f5ff22807bf0a2697992/regex-2025.11.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2fa2eed3f76677777345d2f81ee89f5de2f5745910e805f7af7386a920fa7313", size = 290544, upload-time = "2025-11-03T21:30:49.912Z" }, - { url = "https://files.pythonhosted.org/packages/39/b3/9a231475d5653e60002508f41205c61684bb2ffbf2401351ae2186897fc4/regex-2025.11.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d8b4a27eebd684319bdf473d39f1d79eed36bf2cd34bd4465cdb4618d82b3d56", size = 288408, upload-time = "2025-11-03T21:30:51.344Z" }, - { url = "https://files.pythonhosted.org/packages/c3/c5/1929a0491bd5ac2d1539a866768b88965fa8c405f3e16a8cef84313098d6/regex-2025.11.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5cf77eac15bd264986c4a2c63353212c095b40f3affb2bc6b4ef80c4776c1a28", size = 781584, upload-time = "2025-11-03T21:30:52.596Z" }, - { url = "https://files.pythonhosted.org/packages/ce/fd/16aa16cf5d497ef727ec966f74164fbe75d6516d3d58ac9aa989bc9cdaad/regex-2025.11.3-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b7f9ee819f94c6abfa56ec7b1dbab586f41ebbdc0a57e6524bd5e7f487a878c7", size = 850733, upload-time = "2025-11-03T21:30:53.825Z" }, - { url = "https://files.pythonhosted.org/packages/e6/49/3294b988855a221cb6565189edf5dc43239957427df2d81d4a6b15244f64/regex-2025.11.3-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:838441333bc90b829406d4a03cb4b8bf7656231b84358628b0406d803931ef32", size = 898691, upload-time = "2025-11-03T21:30:55.575Z" }, - { url = "https://files.pythonhosted.org/packages/14/62/b56d29e70b03666193369bdbdedfdc23946dbe9f81dd78ce262c74d988ab/regex-2025.11.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cfe6d3f0c9e3b7e8c0c694b24d25e677776f5ca26dce46fd6b0489f9c8339391", size = 791662, upload-time = "2025-11-03T21:30:57.262Z" }, - { url = "https://files.pythonhosted.org/packages/15/fc/e4c31d061eced63fbf1ce9d853975f912c61a7d406ea14eda2dd355f48e7/regex-2025.11.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2ab815eb8a96379a27c3b6157fcb127c8f59c36f043c1678110cea492868f1d5", size = 782587, upload-time = "2025-11-03T21:30:58.788Z" }, - { url = 
"https://files.pythonhosted.org/packages/b2/bb/5e30c7394bcf63f0537121c23e796be67b55a8847c3956ae6068f4c70702/regex-2025.11.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:728a9d2d173a65b62bdc380b7932dd8e74ed4295279a8fe1021204ce210803e7", size = 774709, upload-time = "2025-11-03T21:31:00.081Z" }, - { url = "https://files.pythonhosted.org/packages/c5/c4/fce773710af81b0cb37cb4ff0947e75d5d17dee304b93d940b87a67fc2f4/regex-2025.11.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:509dc827f89c15c66a0c216331260d777dd6c81e9a4e4f830e662b0bb296c313", size = 845773, upload-time = "2025-11-03T21:31:01.583Z" }, - { url = "https://files.pythonhosted.org/packages/7b/5e/9466a7ec4b8ec282077095c6eb50a12a389d2e036581134d4919e8ca518c/regex-2025.11.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:849202cd789e5f3cf5dcc7822c34b502181b4824a65ff20ce82da5524e45e8e9", size = 836164, upload-time = "2025-11-03T21:31:03.244Z" }, - { url = "https://files.pythonhosted.org/packages/95/18/82980a60e8ed1594eb3c89eb814fb276ef51b9af7caeab1340bfd8564af6/regex-2025.11.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b6f78f98741dcc89607c16b1e9426ee46ce4bf31ac5e6b0d40e81c89f3481ea5", size = 779832, upload-time = "2025-11-03T21:31:04.876Z" }, - { url = "https://files.pythonhosted.org/packages/03/cc/90ab0fdbe6dce064a42015433f9152710139fb04a8b81b4fb57a1cb63ffa/regex-2025.11.3-cp310-cp310-win32.whl", hash = "sha256:149eb0bba95231fb4f6d37c8f760ec9fa6fabf65bab555e128dde5f2475193ec", size = 265802, upload-time = "2025-11-03T21:31:06.581Z" }, - { url = "https://files.pythonhosted.org/packages/34/9d/e9e8493a85f3b1ddc4a5014465f5c2b78c3ea1cbf238dcfde78956378041/regex-2025.11.3-cp310-cp310-win_amd64.whl", hash = "sha256:ee3a83ce492074c35a74cc76cf8235d49e77b757193a5365ff86e3f2f93db9fd", size = 277722, upload-time = "2025-11-03T21:31:08.144Z" }, - { url = "https://files.pythonhosted.org/packages/15/c4/b54b24f553966564506dbf873a3e080aef47b356a3b39b5d5aba992b50db/regex-2025.11.3-cp310-cp310-win_arm64.whl", hash = "sha256:38af559ad934a7b35147716655d4a2f79fcef2d695ddfe06a06ba40ae631fa7e", size = 270289, upload-time = "2025-11-03T21:31:10.267Z" }, - { url = "https://files.pythonhosted.org/packages/f7/90/4fb5056e5f03a7048abd2b11f598d464f0c167de4f2a51aa868c376b8c70/regex-2025.11.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:eadade04221641516fa25139273505a1c19f9bf97589a05bc4cfcd8b4a618031", size = 488081, upload-time = "2025-11-03T21:31:11.946Z" }, - { url = "https://files.pythonhosted.org/packages/85/23/63e481293fac8b069d84fba0299b6666df720d875110efd0338406b5d360/regex-2025.11.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:feff9e54ec0dd3833d659257f5c3f5322a12eee58ffa360984b716f8b92983f4", size = 290554, upload-time = "2025-11-03T21:31:13.387Z" }, - { url = "https://files.pythonhosted.org/packages/2b/9d/b101d0262ea293a0066b4522dfb722eb6a8785a8c3e084396a5f2c431a46/regex-2025.11.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3b30bc921d50365775c09a7ed446359e5c0179e9e2512beec4a60cbcef6ddd50", size = 288407, upload-time = "2025-11-03T21:31:14.809Z" }, - { url = "https://files.pythonhosted.org/packages/0c/64/79241c8209d5b7e00577ec9dca35cd493cc6be35b7d147eda367d6179f6d/regex-2025.11.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f99be08cfead2020c7ca6e396c13543baea32343b7a9a5780c462e323bd8872f", size = 793418, upload-time = "2025-11-03T21:31:16.556Z" }, - { url = 
"https://files.pythonhosted.org/packages/3d/e2/23cd5d3573901ce8f9757c92ca4db4d09600b865919b6d3e7f69f03b1afd/regex-2025.11.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6dd329a1b61c0ee95ba95385fb0c07ea0d3fe1a21e1349fa2bec272636217118", size = 860448, upload-time = "2025-11-03T21:31:18.12Z" }, - { url = "https://files.pythonhosted.org/packages/2a/4c/aecf31beeaa416d0ae4ecb852148d38db35391aac19c687b5d56aedf3a8b/regex-2025.11.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4c5238d32f3c5269d9e87be0cf096437b7622b6920f5eac4fd202468aaeb34d2", size = 907139, upload-time = "2025-11-03T21:31:20.753Z" }, - { url = "https://files.pythonhosted.org/packages/61/22/b8cb00df7d2b5e0875f60628594d44dba283e951b1ae17c12f99e332cc0a/regex-2025.11.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10483eefbfb0adb18ee9474498c9a32fcf4e594fbca0543bb94c48bac6183e2e", size = 800439, upload-time = "2025-11-03T21:31:22.069Z" }, - { url = "https://files.pythonhosted.org/packages/02/a8/c4b20330a5cdc7a8eb265f9ce593f389a6a88a0c5f280cf4d978f33966bc/regex-2025.11.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:78c2d02bb6e1da0720eedc0bad578049cad3f71050ef8cd065ecc87691bed2b0", size = 782965, upload-time = "2025-11-03T21:31:23.598Z" }, - { url = "https://files.pythonhosted.org/packages/b4/4c/ae3e52988ae74af4b04d2af32fee4e8077f26e51b62ec2d12d246876bea2/regex-2025.11.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e6b49cd2aad93a1790ce9cffb18964f6d3a4b0b3dbdbd5de094b65296fce6e58", size = 854398, upload-time = "2025-11-03T21:31:25.008Z" }, - { url = "https://files.pythonhosted.org/packages/06/d1/a8b9cf45874eda14b2e275157ce3b304c87e10fb38d9fc26a6e14eb18227/regex-2025.11.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:885b26aa3ee56433b630502dc3d36ba78d186a00cc535d3806e6bfd9ed3c70ab", size = 845897, upload-time = "2025-11-03T21:31:26.427Z" }, - { url = "https://files.pythonhosted.org/packages/ea/fe/1830eb0236be93d9b145e0bd8ab499f31602fe0999b1f19e99955aa8fe20/regex-2025.11.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ddd76a9f58e6a00f8772e72cff8ebcff78e022be95edf018766707c730593e1e", size = 788906, upload-time = "2025-11-03T21:31:28.078Z" }, - { url = "https://files.pythonhosted.org/packages/66/47/dc2577c1f95f188c1e13e2e69d8825a5ac582ac709942f8a03af42ed6e93/regex-2025.11.3-cp311-cp311-win32.whl", hash = "sha256:3e816cc9aac1cd3cc9a4ec4d860f06d40f994b5c7b4d03b93345f44e08cc68bf", size = 265812, upload-time = "2025-11-03T21:31:29.72Z" }, - { url = "https://files.pythonhosted.org/packages/50/1e/15f08b2f82a9bbb510621ec9042547b54d11e83cb620643ebb54e4eb7d71/regex-2025.11.3-cp311-cp311-win_amd64.whl", hash = "sha256:087511f5c8b7dfbe3a03f5d5ad0c2a33861b1fc387f21f6f60825a44865a385a", size = 277737, upload-time = "2025-11-03T21:31:31.422Z" }, - { url = "https://files.pythonhosted.org/packages/f4/fc/6500eb39f5f76c5e47a398df82e6b535a5e345f839581012a418b16f9cc3/regex-2025.11.3-cp311-cp311-win_arm64.whl", hash = "sha256:1ff0d190c7f68ae7769cd0313fe45820ba07ffebfddfaa89cc1eb70827ba0ddc", size = 270290, upload-time = "2025-11-03T21:31:33.041Z" }, - { url = "https://files.pythonhosted.org/packages/e8/74/18f04cb53e58e3fb107439699bd8375cf5a835eec81084e0bddbd122e4c2/regex-2025.11.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bc8ab71e2e31b16e40868a40a69007bc305e1109bd4658eb6cad007e0bf67c41", size = 489312, upload-time = "2025-11-03T21:31:34.343Z" }, - { url = 
"https://files.pythonhosted.org/packages/78/3f/37fcdd0d2b1e78909108a876580485ea37c91e1acf66d3bb8e736348f441/regex-2025.11.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:22b29dda7e1f7062a52359fca6e58e548e28c6686f205e780b02ad8ef710de36", size = 291256, upload-time = "2025-11-03T21:31:35.675Z" }, - { url = "https://files.pythonhosted.org/packages/bf/26/0a575f58eb23b7ebd67a45fccbc02ac030b737b896b7e7a909ffe43ffd6a/regex-2025.11.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3a91e4a29938bc1a082cc28fdea44be420bf2bebe2665343029723892eb073e1", size = 288921, upload-time = "2025-11-03T21:31:37.07Z" }, - { url = "https://files.pythonhosted.org/packages/ea/98/6a8dff667d1af907150432cf5abc05a17ccd32c72a3615410d5365ac167a/regex-2025.11.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08b884f4226602ad40c5d55f52bf91a9df30f513864e0054bad40c0e9cf1afb7", size = 798568, upload-time = "2025-11-03T21:31:38.784Z" }, - { url = "https://files.pythonhosted.org/packages/64/15/92c1db4fa4e12733dd5a526c2dd2b6edcbfe13257e135fc0f6c57f34c173/regex-2025.11.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3e0b11b2b2433d1c39c7c7a30e3f3d0aeeea44c2a8d0bae28f6b95f639927a69", size = 864165, upload-time = "2025-11-03T21:31:40.559Z" }, - { url = "https://files.pythonhosted.org/packages/f9/e7/3ad7da8cdee1ce66c7cd37ab5ab05c463a86ffeb52b1a25fe7bd9293b36c/regex-2025.11.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:87eb52a81ef58c7ba4d45c3ca74e12aa4b4e77816f72ca25258a85b3ea96cb48", size = 912182, upload-time = "2025-11-03T21:31:42.002Z" }, - { url = "https://files.pythonhosted.org/packages/84/bd/9ce9f629fcb714ffc2c3faf62b6766ecb7a585e1e885eb699bcf130a5209/regex-2025.11.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a12ab1f5c29b4e93db518f5e3872116b7e9b1646c9f9f426f777b50d44a09e8c", size = 803501, upload-time = "2025-11-03T21:31:43.815Z" }, - { url = "https://files.pythonhosted.org/packages/7c/0f/8dc2e4349d8e877283e6edd6c12bdcebc20f03744e86f197ab6e4492bf08/regex-2025.11.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7521684c8c7c4f6e88e35ec89680ee1aa8358d3f09d27dfbdf62c446f5d4c695", size = 787842, upload-time = "2025-11-03T21:31:45.353Z" }, - { url = "https://files.pythonhosted.org/packages/f9/73/cff02702960bc185164d5619c0c62a2f598a6abff6695d391b096237d4ab/regex-2025.11.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:7fe6e5440584e94cc4b3f5f4d98a25e29ca12dccf8873679a635638349831b98", size = 858519, upload-time = "2025-11-03T21:31:46.814Z" }, - { url = "https://files.pythonhosted.org/packages/61/83/0e8d1ae71e15bc1dc36231c90b46ee35f9d52fab2e226b0e039e7ea9c10a/regex-2025.11.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:8e026094aa12b43f4fd74576714e987803a315c76edb6b098b9809db5de58f74", size = 850611, upload-time = "2025-11-03T21:31:48.289Z" }, - { url = "https://files.pythonhosted.org/packages/c8/f5/70a5cdd781dcfaa12556f2955bf170cd603cb1c96a1827479f8faea2df97/regex-2025.11.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:435bbad13e57eb5606a68443af62bed3556de2f46deb9f7d4237bc2f1c9fb3a0", size = 789759, upload-time = "2025-11-03T21:31:49.759Z" }, - { url = "https://files.pythonhosted.org/packages/59/9b/7c29be7903c318488983e7d97abcf8ebd3830e4c956c4c540005fcfb0462/regex-2025.11.3-cp312-cp312-win32.whl", hash = "sha256:3839967cf4dc4b985e1570fd8d91078f0c519f30491c60f9ac42a8db039be204", size = 
266194, upload-time = "2025-11-03T21:31:51.53Z" }, - { url = "https://files.pythonhosted.org/packages/1a/67/3b92df89f179d7c367be654ab5626ae311cb28f7d5c237b6bb976cd5fbbb/regex-2025.11.3-cp312-cp312-win_amd64.whl", hash = "sha256:e721d1b46e25c481dc5ded6f4b3f66c897c58d2e8cfdf77bbced84339108b0b9", size = 277069, upload-time = "2025-11-03T21:31:53.151Z" }, - { url = "https://files.pythonhosted.org/packages/d7/55/85ba4c066fe5094d35b249c3ce8df0ba623cfd35afb22d6764f23a52a1c5/regex-2025.11.3-cp312-cp312-win_arm64.whl", hash = "sha256:64350685ff08b1d3a6fff33f45a9ca183dc1d58bbfe4981604e70ec9801bbc26", size = 270330, upload-time = "2025-11-03T21:31:54.514Z" }, - { url = "https://files.pythonhosted.org/packages/e1/a7/dda24ebd49da46a197436ad96378f17df30ceb40e52e859fc42cac45b850/regex-2025.11.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:c1e448051717a334891f2b9a620fe36776ebf3dd8ec46a0b877c8ae69575feb4", size = 489081, upload-time = "2025-11-03T21:31:55.9Z" }, - { url = "https://files.pythonhosted.org/packages/19/22/af2dc751aacf88089836aa088a1a11c4f21a04707eb1b0478e8e8fb32847/regex-2025.11.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9b5aca4d5dfd7fbfbfbdaf44850fcc7709a01146a797536a8f84952e940cca76", size = 291123, upload-time = "2025-11-03T21:31:57.758Z" }, - { url = "https://files.pythonhosted.org/packages/a3/88/1a3ea5672f4b0a84802ee9891b86743438e7c04eb0b8f8c4e16a42375327/regex-2025.11.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:04d2765516395cf7dda331a244a3282c0f5ae96075f728629287dfa6f76ba70a", size = 288814, upload-time = "2025-11-03T21:32:01.12Z" }, - { url = "https://files.pythonhosted.org/packages/fb/8c/f5987895bf42b8ddeea1b315c9fedcfe07cadee28b9c98cf50d00adcb14d/regex-2025.11.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5d9903ca42bfeec4cebedba8022a7c97ad2aab22e09573ce9976ba01b65e4361", size = 798592, upload-time = "2025-11-03T21:32:03.006Z" }, - { url = "https://files.pythonhosted.org/packages/99/2a/6591ebeede78203fa77ee46a1c36649e02df9eaa77a033d1ccdf2fcd5d4e/regex-2025.11.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:639431bdc89d6429f6721625e8129413980ccd62e9d3f496be618a41d205f160", size = 864122, upload-time = "2025-11-03T21:32:04.553Z" }, - { url = "https://files.pythonhosted.org/packages/94/d6/be32a87cf28cf8ed064ff281cfbd49aefd90242a83e4b08b5a86b38e8eb4/regex-2025.11.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f117efad42068f9715677c8523ed2be1518116d1c49b1dd17987716695181efe", size = 912272, upload-time = "2025-11-03T21:32:06.148Z" }, - { url = "https://files.pythonhosted.org/packages/62/11/9bcef2d1445665b180ac7f230406ad80671f0fc2a6ffb93493b5dd8cd64c/regex-2025.11.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4aecb6f461316adf9f1f0f6a4a1a3d79e045f9b71ec76055a791affa3b285850", size = 803497, upload-time = "2025-11-03T21:32:08.162Z" }, - { url = "https://files.pythonhosted.org/packages/e5/a7/da0dc273d57f560399aa16d8a68ae7f9b57679476fc7ace46501d455fe84/regex-2025.11.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3b3a5f320136873cc5561098dfab677eea139521cb9a9e8db98b7e64aef44cbc", size = 787892, upload-time = "2025-11-03T21:32:09.769Z" }, - { url = "https://files.pythonhosted.org/packages/da/4b/732a0c5a9736a0b8d6d720d4945a2f1e6f38f87f48f3173559f53e8d5d82/regex-2025.11.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = 
"sha256:75fa6f0056e7efb1f42a1c34e58be24072cb9e61a601340cc1196ae92326a4f9", size = 858462, upload-time = "2025-11-03T21:32:11.769Z" }, - { url = "https://files.pythonhosted.org/packages/0c/f5/a2a03df27dc4c2d0c769220f5110ba8c4084b0bfa9ab0f9b4fcfa3d2b0fc/regex-2025.11.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:dbe6095001465294f13f1adcd3311e50dd84e5a71525f20a10bd16689c61ce0b", size = 850528, upload-time = "2025-11-03T21:32:13.906Z" }, - { url = "https://files.pythonhosted.org/packages/d6/09/e1cd5bee3841c7f6eb37d95ca91cdee7100b8f88b81e41c2ef426910891a/regex-2025.11.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:454d9b4ae7881afbc25015b8627c16d88a597479b9dea82b8c6e7e2e07240dc7", size = 789866, upload-time = "2025-11-03T21:32:15.748Z" }, - { url = "https://files.pythonhosted.org/packages/eb/51/702f5ea74e2a9c13d855a6a85b7f80c30f9e72a95493260193c07f3f8d74/regex-2025.11.3-cp313-cp313-win32.whl", hash = "sha256:28ba4d69171fc6e9896337d4fc63a43660002b7da53fc15ac992abcf3410917c", size = 266189, upload-time = "2025-11-03T21:32:17.493Z" }, - { url = "https://files.pythonhosted.org/packages/8b/00/6e29bb314e271a743170e53649db0fdb8e8ff0b64b4f425f5602f4eb9014/regex-2025.11.3-cp313-cp313-win_amd64.whl", hash = "sha256:bac4200befe50c670c405dc33af26dad5a3b6b255dd6c000d92fe4629f9ed6a5", size = 277054, upload-time = "2025-11-03T21:32:19.042Z" }, - { url = "https://files.pythonhosted.org/packages/25/f1/b156ff9f2ec9ac441710764dda95e4edaf5f36aca48246d1eea3f1fd96ec/regex-2025.11.3-cp313-cp313-win_arm64.whl", hash = "sha256:2292cd5a90dab247f9abe892ac584cb24f0f54680c73fcb4a7493c66c2bf2467", size = 270325, upload-time = "2025-11-03T21:32:21.338Z" }, - { url = "https://files.pythonhosted.org/packages/20/28/fd0c63357caefe5680b8ea052131acbd7f456893b69cc2a90cc3e0dc90d4/regex-2025.11.3-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:1eb1ebf6822b756c723e09f5186473d93236c06c579d2cc0671a722d2ab14281", size = 491984, upload-time = "2025-11-03T21:32:23.466Z" }, - { url = "https://files.pythonhosted.org/packages/df/ec/7014c15626ab46b902b3bcc4b28a7bae46d8f281fc7ea9c95e22fcaaa917/regex-2025.11.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:1e00ec2970aab10dc5db34af535f21fcf32b4a31d99e34963419636e2f85ae39", size = 292673, upload-time = "2025-11-03T21:32:25.034Z" }, - { url = "https://files.pythonhosted.org/packages/23/ab/3b952ff7239f20d05f1f99e9e20188513905f218c81d52fb5e78d2bf7634/regex-2025.11.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a4cb042b615245d5ff9b3794f56be4138b5adc35a4166014d31d1814744148c7", size = 291029, upload-time = "2025-11-03T21:32:26.528Z" }, - { url = "https://files.pythonhosted.org/packages/21/7e/3dc2749fc684f455f162dcafb8a187b559e2614f3826877d3844a131f37b/regex-2025.11.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:44f264d4bf02f3176467d90b294d59bf1db9fe53c141ff772f27a8b456b2a9ed", size = 807437, upload-time = "2025-11-03T21:32:28.363Z" }, - { url = "https://files.pythonhosted.org/packages/1b/0b/d529a85ab349c6a25d1ca783235b6e3eedf187247eab536797021f7126c6/regex-2025.11.3-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7be0277469bf3bd7a34a9c57c1b6a724532a0d235cd0dc4e7f4316f982c28b19", size = 873368, upload-time = "2025-11-03T21:32:30.4Z" }, - { url = "https://files.pythonhosted.org/packages/7d/18/2d868155f8c9e3e9d8f9e10c64e9a9f496bb8f7e037a88a8bed26b435af6/regex-2025.11.3-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:0d31e08426ff4b5b650f68839f5af51a92a5b51abd8554a60c2fbc7c71f25d0b", size = 914921, upload-time = "2025-11-03T21:32:32.123Z" }, - { url = "https://files.pythonhosted.org/packages/2d/71/9d72ff0f354fa783fe2ba913c8734c3b433b86406117a8db4ea2bf1c7a2f/regex-2025.11.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e43586ce5bd28f9f285a6e729466841368c4a0353f6fd08d4ce4630843d3648a", size = 812708, upload-time = "2025-11-03T21:32:34.305Z" }, - { url = "https://files.pythonhosted.org/packages/e7/19/ce4bf7f5575c97f82b6e804ffb5c4e940c62609ab2a0d9538d47a7fdf7d4/regex-2025.11.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:0f9397d561a4c16829d4e6ff75202c1c08b68a3bdbfe29dbfcdb31c9830907c6", size = 795472, upload-time = "2025-11-03T21:32:36.364Z" }, - { url = "https://files.pythonhosted.org/packages/03/86/fd1063a176ffb7b2315f9a1b08d17b18118b28d9df163132615b835a26ee/regex-2025.11.3-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:dd16e78eb18ffdb25ee33a0682d17912e8cc8a770e885aeee95020046128f1ce", size = 868341, upload-time = "2025-11-03T21:32:38.042Z" }, - { url = "https://files.pythonhosted.org/packages/12/43/103fb2e9811205e7386366501bc866a164a0430c79dd59eac886a2822950/regex-2025.11.3-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:ffcca5b9efe948ba0661e9df0fa50d2bc4b097c70b9810212d6b62f05d83b2dd", size = 854666, upload-time = "2025-11-03T21:32:40.079Z" }, - { url = "https://files.pythonhosted.org/packages/7d/22/e392e53f3869b75804762c7c848bd2dd2abf2b70fb0e526f58724638bd35/regex-2025.11.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c56b4d162ca2b43318ac671c65bd4d563e841a694ac70e1a976ac38fcf4ca1d2", size = 799473, upload-time = "2025-11-03T21:32:42.148Z" }, - { url = "https://files.pythonhosted.org/packages/4f/f9/8bd6b656592f925b6845fcbb4d57603a3ac2fb2373344ffa1ed70aa6820a/regex-2025.11.3-cp313-cp313t-win32.whl", hash = "sha256:9ddc42e68114e161e51e272f667d640f97e84a2b9ef14b7477c53aac20c2d59a", size = 268792, upload-time = "2025-11-03T21:32:44.13Z" }, - { url = "https://files.pythonhosted.org/packages/e5/87/0e7d603467775ff65cd2aeabf1b5b50cc1c3708556a8b849a2fa4dd1542b/regex-2025.11.3-cp313-cp313t-win_amd64.whl", hash = "sha256:7a7c7fdf755032ffdd72c77e3d8096bdcb0eb92e89e17571a196f03d88b11b3c", size = 280214, upload-time = "2025-11-03T21:32:45.853Z" }, - { url = "https://files.pythonhosted.org/packages/8d/d0/2afc6f8e94e2b64bfb738a7c2b6387ac1699f09f032d363ed9447fd2bb57/regex-2025.11.3-cp313-cp313t-win_arm64.whl", hash = "sha256:df9eb838c44f570283712e7cff14c16329a9f0fb19ca492d21d4b7528ee6821e", size = 271469, upload-time = "2025-11-03T21:32:48.026Z" }, - { url = "https://files.pythonhosted.org/packages/31/e9/f6e13de7e0983837f7b6d238ad9458800a874bf37c264f7923e63409944c/regex-2025.11.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:9697a52e57576c83139d7c6f213d64485d3df5bf84807c35fa409e6c970801c6", size = 489089, upload-time = "2025-11-03T21:32:50.027Z" }, - { url = "https://files.pythonhosted.org/packages/a3/5c/261f4a262f1fa65141c1b74b255988bd2fa020cc599e53b080667d591cfc/regex-2025.11.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e18bc3f73bd41243c9b38a6d9f2366cd0e0137a9aebe2d8ff76c5b67d4c0a3f4", size = 291059, upload-time = "2025-11-03T21:32:51.682Z" }, - { url = "https://files.pythonhosted.org/packages/8e/57/f14eeb7f072b0e9a5a090d1712741fd8f214ec193dba773cf5410108bb7d/regex-2025.11.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:61a08bcb0ec14ff4e0ed2044aad948d0659604f824cbd50b55e30b0ec6f09c73", size = 
288900, upload-time = "2025-11-03T21:32:53.569Z" }, - { url = "https://files.pythonhosted.org/packages/3c/6b/1d650c45e99a9b327586739d926a1cd4e94666b1bd4af90428b36af66dc7/regex-2025.11.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c9c30003b9347c24bcc210958c5d167b9e4f9be786cb380a7d32f14f9b84674f", size = 799010, upload-time = "2025-11-03T21:32:55.222Z" }, - { url = "https://files.pythonhosted.org/packages/99/ee/d66dcbc6b628ce4e3f7f0cbbb84603aa2fc0ffc878babc857726b8aab2e9/regex-2025.11.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4e1e592789704459900728d88d41a46fe3969b82ab62945560a31732ffc19a6d", size = 864893, upload-time = "2025-11-03T21:32:57.239Z" }, - { url = "https://files.pythonhosted.org/packages/bf/2d/f238229f1caba7ac87a6c4153d79947fb0261415827ae0f77c304260c7d3/regex-2025.11.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:6538241f45eb5a25aa575dbba1069ad786f68a4f2773a29a2bd3dd1f9de787be", size = 911522, upload-time = "2025-11-03T21:32:59.274Z" }, - { url = "https://files.pythonhosted.org/packages/bd/3d/22a4eaba214a917c80e04f6025d26143690f0419511e0116508e24b11c9b/regex-2025.11.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce22519c989bb72a7e6b36a199384c53db7722fe669ba891da75907fe3587db", size = 803272, upload-time = "2025-11-03T21:33:01.393Z" }, - { url = "https://files.pythonhosted.org/packages/84/b1/03188f634a409353a84b5ef49754b97dbcc0c0f6fd6c8ede505a8960a0a4/regex-2025.11.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:66d559b21d3640203ab9075797a55165d79017520685fb407b9234d72ab63c62", size = 787958, upload-time = "2025-11-03T21:33:03.379Z" }, - { url = "https://files.pythonhosted.org/packages/99/6a/27d072f7fbf6fadd59c64d210305e1ff865cc3b78b526fd147db768c553b/regex-2025.11.3-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:669dcfb2e38f9e8c69507bace46f4889e3abbfd9b0c29719202883c0a603598f", size = 859289, upload-time = "2025-11-03T21:33:05.374Z" }, - { url = "https://files.pythonhosted.org/packages/9a/70/1b3878f648e0b6abe023172dacb02157e685564853cc363d9961bcccde4e/regex-2025.11.3-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:32f74f35ff0f25a5021373ac61442edcb150731fbaa28286bbc8bb1582c89d02", size = 850026, upload-time = "2025-11-03T21:33:07.131Z" }, - { url = "https://files.pythonhosted.org/packages/dd/d5/68e25559b526b8baab8e66839304ede68ff6727237a47727d240006bd0ff/regex-2025.11.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e6c7a21dffba883234baefe91bc3388e629779582038f75d2a5be918e250f0ed", size = 789499, upload-time = "2025-11-03T21:33:09.141Z" }, - { url = "https://files.pythonhosted.org/packages/fc/df/43971264857140a350910d4e33df725e8c94dd9dee8d2e4729fa0d63d49e/regex-2025.11.3-cp314-cp314-win32.whl", hash = "sha256:795ea137b1d809eb6836b43748b12634291c0ed55ad50a7d72d21edf1cd565c4", size = 271604, upload-time = "2025-11-03T21:33:10.9Z" }, - { url = "https://files.pythonhosted.org/packages/01/6f/9711b57dc6894a55faf80a4c1b5aa4f8649805cb9c7aef46f7d27e2b9206/regex-2025.11.3-cp314-cp314-win_amd64.whl", hash = "sha256:9f95fbaa0ee1610ec0fc6b26668e9917a582ba80c52cc6d9ada15e30aa9ab9ad", size = 280320, upload-time = "2025-11-03T21:33:12.572Z" }, - { url = "https://files.pythonhosted.org/packages/f1/7e/f6eaa207d4377481f5e1775cdeb5a443b5a59b392d0065f3417d31d80f87/regex-2025.11.3-cp314-cp314-win_arm64.whl", hash = 
"sha256:dfec44d532be4c07088c3de2876130ff0fbeeacaa89a137decbbb5f665855a0f", size = 273372, upload-time = "2025-11-03T21:33:14.219Z" }, - { url = "https://files.pythonhosted.org/packages/c3/06/49b198550ee0f5e4184271cee87ba4dfd9692c91ec55289e6282f0f86ccf/regex-2025.11.3-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:ba0d8a5d7f04f73ee7d01d974d47c5834f8a1b0224390e4fe7c12a3a92a78ecc", size = 491985, upload-time = "2025-11-03T21:33:16.555Z" }, - { url = "https://files.pythonhosted.org/packages/ce/bf/abdafade008f0b1c9da10d934034cb670432d6cf6cbe38bbb53a1cfd6cf8/regex-2025.11.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:442d86cf1cfe4faabf97db7d901ef58347efd004934da045c745e7b5bd57ac49", size = 292669, upload-time = "2025-11-03T21:33:18.32Z" }, - { url = "https://files.pythonhosted.org/packages/f9/ef/0c357bb8edbd2ad8e273fcb9e1761bc37b8acbc6e1be050bebd6475f19c1/regex-2025.11.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:fd0a5e563c756de210bb964789b5abe4f114dacae9104a47e1a649b910361536", size = 291030, upload-time = "2025-11-03T21:33:20.048Z" }, - { url = "https://files.pythonhosted.org/packages/79/06/edbb67257596649b8fb088d6aeacbcb248ac195714b18a65e018bf4c0b50/regex-2025.11.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bf3490bcbb985a1ae97b2ce9ad1c0f06a852d5b19dde9b07bdf25bf224248c95", size = 807674, upload-time = "2025-11-03T21:33:21.797Z" }, - { url = "https://files.pythonhosted.org/packages/f4/d9/ad4deccfce0ea336296bd087f1a191543bb99ee1c53093dcd4c64d951d00/regex-2025.11.3-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3809988f0a8b8c9dcc0f92478d6501fac7200b9ec56aecf0ec21f4a2ec4b6009", size = 873451, upload-time = "2025-11-03T21:33:23.741Z" }, - { url = "https://files.pythonhosted.org/packages/13/75/a55a4724c56ef13e3e04acaab29df26582f6978c000ac9cd6810ad1f341f/regex-2025.11.3-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f4ff94e58e84aedb9c9fce66d4ef9f27a190285b451420f297c9a09f2b9abee9", size = 914980, upload-time = "2025-11-03T21:33:25.999Z" }, - { url = "https://files.pythonhosted.org/packages/67/1e/a1657ee15bd9116f70d4a530c736983eed997b361e20ecd8f5ca3759d5c5/regex-2025.11.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7eb542fd347ce61e1321b0a6b945d5701528dca0cd9759c2e3bb8bd57e47964d", size = 812852, upload-time = "2025-11-03T21:33:27.852Z" }, - { url = "https://files.pythonhosted.org/packages/b8/6f/f7516dde5506a588a561d296b2d0044839de06035bb486b326065b4c101e/regex-2025.11.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:d6c2d5919075a1f2e413c00b056ea0c2f065b3f5fe83c3d07d325ab92dce51d6", size = 795566, upload-time = "2025-11-03T21:33:32.364Z" }, - { url = "https://files.pythonhosted.org/packages/d9/dd/3d10b9e170cc16fb34cb2cef91513cf3df65f440b3366030631b2984a264/regex-2025.11.3-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:3f8bf11a4827cc7ce5a53d4ef6cddd5ad25595d3c1435ef08f76825851343154", size = 868463, upload-time = "2025-11-03T21:33:34.459Z" }, - { url = "https://files.pythonhosted.org/packages/f5/8e/935e6beff1695aa9085ff83195daccd72acc82c81793df480f34569330de/regex-2025.11.3-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:22c12d837298651e5550ac1d964e4ff57c3f56965fc1812c90c9fb2028eaf267", size = 854694, upload-time = "2025-11-03T21:33:36.793Z" }, - { url = 
"https://files.pythonhosted.org/packages/92/12/10650181a040978b2f5720a6a74d44f841371a3d984c2083fc1752e4acf6/regex-2025.11.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:62ba394a3dda9ad41c7c780f60f6e4a70988741415ae96f6d1bf6c239cf01379", size = 799691, upload-time = "2025-11-03T21:33:39.079Z" }, - { url = "https://files.pythonhosted.org/packages/67/90/8f37138181c9a7690e7e4cb388debbd389342db3c7381d636d2875940752/regex-2025.11.3-cp314-cp314t-win32.whl", hash = "sha256:4bf146dca15cdd53224a1bf46d628bd7590e4a07fbb69e720d561aea43a32b38", size = 274583, upload-time = "2025-11-03T21:33:41.302Z" }, - { url = "https://files.pythonhosted.org/packages/8f/cd/867f5ec442d56beb56f5f854f40abcfc75e11d10b11fdb1869dd39c63aaf/regex-2025.11.3-cp314-cp314t-win_amd64.whl", hash = "sha256:adad1a1bcf1c9e76346e091d22d23ac54ef28e1365117d99521631078dfec9de", size = 284286, upload-time = "2025-11-03T21:33:43.324Z" }, - { url = "https://files.pythonhosted.org/packages/20/31/32c0c4610cbc070362bf1d2e4ea86d1ea29014d400a6d6c2486fcfd57766/regex-2025.11.3-cp314-cp314t-win_arm64.whl", hash = "sha256:c54f768482cef41e219720013cd05933b6f971d9562544d691c68699bf2b6801", size = 274741, upload-time = "2025-11-03T21:33:45.557Z" }, +version = "2026.1.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ad/f2/638ef50852dc5741dc3bb3c7d4e773d637bc20232965ef8b6e7f6f7d4445/regex-2026.1.14.tar.gz", hash = "sha256:7bdd569b6226498001619751abe6ba3c9e3050f79cfe097e84f25b2856120e78", size = 414813, upload-time = "2026-01-14T17:53:31.244Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/e4/114fd21f052d96c955223d7640ff0ca6960af3d3310d2ecda92b5b2d9720/regex-2026.1.14-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4a19a1ec7f2450d5705c261f3a993a282fc8e165ff5bdc326515d00cee73d302", size = 488172, upload-time = "2026-01-14T17:50:35.856Z" }, + { url = "https://files.pythonhosted.org/packages/99/66/f49a04aa1ecd9c583f296b99683cc0e2c8353eb4a203868edb276aef197e/regex-2026.1.14-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a032181975313ebaecac60a5f65d65055fc3067a0a57e7683efce2a7d024af73", size = 290637, upload-time = "2026-01-14T17:50:38.53Z" }, + { url = "https://files.pythonhosted.org/packages/b5/07/f1ab12a4096e20aec12149b03e69da7b502d2798865a8421d149db720f91/regex-2026.1.14-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e43e615f2d666dfb21dd2954c57fd06fe865d517adf967c33f4d4a545f308068", size = 288498, upload-time = "2026-01-14T17:50:39.596Z" }, + { url = "https://files.pythonhosted.org/packages/70/0e/b618d756c6a5d515e9813505f43db3649a6f645f4fdd2136f0cb62e360ce/regex-2026.1.14-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:26ebd2577bb6dd00382f69722f4bc4c533e00a0c0a93beb09906bfa60b8d61fc", size = 781668, upload-time = "2026-01-14T17:50:40.678Z" }, + { url = "https://files.pythonhosted.org/packages/63/63/49973ec84c00ca688ee3714c466797f7d8a6f603fce3edbb2d771838788b/regex-2026.1.14-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:81e2edde6674dbc021222990c62402135c2218352ebbb0396365098fce0c38bc", size = 850818, upload-time = "2026-01-14T17:50:44.07Z" }, + { url = "https://files.pythonhosted.org/packages/ef/8c/685975eb8e7021da3ac0ce04ee6a857e8d9f5afb1581521deb4b9f5af721/regex-2026.1.14-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:fb03a4c74750096a7eb685313ca58a2b99b12451bb10939d02338fe1f81b25ac", size = 
898774, upload-time = "2026-01-14T17:50:45.58Z" }, + { url = "https://files.pythonhosted.org/packages/c7/08/5f3a3d02804ae72972b257e9996fd95f069e40bb5465ff446871d7b6b839/regex-2026.1.14-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5643c8fd94e0163d536180dc2a0cb512ed8f1257535d8974e0dc45a25f19e03", size = 791747, upload-time = "2026-01-14T17:50:47.138Z" }, + { url = "https://files.pythonhosted.org/packages/7b/44/f41de01adb8caa6a0ef0db81a2c208ce418f8d9e27ee30f143c4cc702348/regex-2026.1.14-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4ca389b5f6220ceb32ffbb878206faeb1cb7cbb276c683a7acf0abc4c8ac4b86", size = 782672, upload-time = "2026-01-14T17:50:48.445Z" }, + { url = "https://files.pythonhosted.org/packages/70/40/0f559f585e429f1b65b153b9f7b72457da22a961686e94d54a159fbb61ec/regex-2026.1.14-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ee317cf3f07ab1667c4b7ac9785a29a4ba40c96916c07a7550daf66605c3b98f", size = 774796, upload-time = "2026-01-14T17:50:49.625Z" }, + { url = "https://files.pythonhosted.org/packages/4a/e3/1a4ad3637f78f4f76a75d4429d2138fb23c4ca97f0dba51b15a9aa531f7b/regex-2026.1.14-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:83e7af1632a9ec7fb10759cc34986d0e7295fea1ed07a912af8cab7f52d5ad08", size = 845857, upload-time = "2026-01-14T17:50:51.316Z" }, + { url = "https://files.pythonhosted.org/packages/3d/61/55d8600387be2170c6703140bd3993db855dffc59f398d54f0a9d55c4be7/regex-2026.1.14-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:edbfa0ce6c556d181981c669d3f37cf0906ed031618ef6596bf6efcbacbd9a56", size = 836247, upload-time = "2026-01-14T17:50:53.048Z" }, + { url = "https://files.pythonhosted.org/packages/63/7c/66d1a776112d71ead4a6cd040d5a95705cc6f2513d151eb91549379ab627/regex-2026.1.14-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:51d677fd5e75a456cae0906a42f8797baecc70832ba527efc8b1840525e430d9", size = 779917, upload-time = "2026-01-14T17:50:54.548Z" }, + { url = "https://files.pythonhosted.org/packages/87/3b/3c73aaa87c32dff0e4ad0eb4f32d19858e1e0235e87fab2aed4b3d84297b/regex-2026.1.14-cp310-cp310-win32.whl", hash = "sha256:ab3b1ab5c26cfb6fdb94857ad879bd8c2340eaefe8d535857056571898b90a71", size = 265887, upload-time = "2026-01-14T17:50:55.759Z" }, + { url = "https://files.pythonhosted.org/packages/ea/36/847eb03220efa29be8187d305b307cefedd652729031acab9bed15085188/regex-2026.1.14-cp310-cp310-win_amd64.whl", hash = "sha256:2492698f88f5f455ed1945186cf818a3d75cc3c019ff3159e331d22d8de50106", size = 277829, upload-time = "2026-01-14T17:50:56.91Z" }, + { url = "https://files.pythonhosted.org/packages/6f/24/ae92d9149922ab4d979933ea0e1d2a5308094dc3934681da16c1fe9f10d3/regex-2026.1.14-cp310-cp310-win_arm64.whl", hash = "sha256:5048bcedd0c95173ca4ddbb97cb683b31b19a5b9fa054f7c9d40efea588ded4f", size = 270374, upload-time = "2026-01-14T17:50:58.024Z" }, + { url = "https://files.pythonhosted.org/packages/99/c4/7a01b922a6457f2bead58fabe5ea68d184a0c3ec0fffc71127c82af4a65f/regex-2026.1.14-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0dc77a121d04834bd56d43b38d8d15cb7a1b65ecd019a1d297527f8d1f99d993", size = 488170, upload-time = "2026-01-14T17:50:59.131Z" }, + { url = "https://files.pythonhosted.org/packages/a1/7f/2be92bb28c03000310c9c3f7c6953b38103b457b1f964a0ac3d815c996a0/regex-2026.1.14-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2fac804af5d1e69d43356e86af8f3c5e9a1610f9f1ffccaee2a7c450a029eb26", size = 290631, upload-time = "2026-01-14T17:51:00.278Z" }, + { url = 
"https://files.pythonhosted.org/packages/b5/96/771b24d2661b51c7fd911240e5e4ad3e5a5f091ae4f78a0deeac2d36c460/regex-2026.1.14-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5268cd9383acf03fd9bbe5e434a97e238238c817972a2ae65b0025b626d6ac61", size = 288493, upload-time = "2026-01-14T17:51:01.433Z" }, + { url = "https://files.pythonhosted.org/packages/b1/16/1f13d159db18ad3bac8ccc41687370197fa7f482094a1f5b5cddb0e065f0/regex-2026.1.14-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:14042ef29c801580a678c3b35a5ab3a1bcc535de98a1d1469d00a25043654510", size = 793501, upload-time = "2026-01-14T17:51:02.526Z" }, + { url = "https://files.pythonhosted.org/packages/07/33/69fb29f98c0df7f5c2509eb9dfe7ecba4f21287b9b0e45504c059ea42c1c/regex-2026.1.14-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:28f76df5145800f27f90b0fc59c92175858c120589c49e2ce2c383f9de52ad10", size = 860533, upload-time = "2026-01-14T17:51:03.76Z" }, + { url = "https://files.pythonhosted.org/packages/f3/ea/9a4f63dcf64835fe866df9f47220dbc1a5e806874db98778d77cd1302682/regex-2026.1.14-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:aa466e620c7e68ad9cb15c6e6e7bad9c0140f070b4a34d9c62fc233a32eed887", size = 907223, upload-time = "2026-01-14T17:51:05.349Z" }, + { url = "https://files.pythonhosted.org/packages/ae/5e/e72fb87eaeafae1f4445367b191d767c79dd598e296af8ce4d7f5e32f47a/regex-2026.1.14-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d8ebc8b624d2b21cd8848b07a8716e17af3ac48c544d6744a34eaa3b66a1f99b", size = 800523, upload-time = "2026-01-14T17:51:06.717Z" }, + { url = "https://files.pythonhosted.org/packages/83/37/3070a2b9aaed5c7de6c10ca371a8dc96f8e5304894faa0006aa0657bd920/regex-2026.1.14-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:f47f467377fe52682edad3d9f4510f6e8a846192c2ed2e927b76d7dc1ce81403", size = 783049, upload-time = "2026-01-14T17:51:08Z" }, + { url = "https://files.pythonhosted.org/packages/3e/a7/11400e0723dc53a6003420ac66c8eab2441e846171116b009b569f440d4b/regex-2026.1.14-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:758d14314665e143c45bbf78e5219bde59a3a0b9c92adc3341c1dcdeaef0aabb", size = 854483, upload-time = "2026-01-14T17:51:09.515Z" }, + { url = "https://files.pythonhosted.org/packages/3a/0f/5df8e77a2ad2b3d92d4a0d23403794cfd99cd83c46e577dba1dba3b9a4f3/regex-2026.1.14-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:28000a3e574c2ea158594f94d20a638d7cd9fbe7b9bf12b1a3c71f0d79eba54c", size = 845984, upload-time = "2026-01-14T17:51:10.715Z" }, + { url = "https://files.pythonhosted.org/packages/95/84/234ae89175e5c02d8a16888fc4e8a8dffd7e5e02afd302d53f34d6dc71c1/regex-2026.1.14-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:14498dca698c013d7ac6fae81bbdad82a3178db5ed8ff31dec23a32c6516b1a4", size = 788989, upload-time = "2026-01-14T17:51:12.229Z" }, + { url = "https://files.pythonhosted.org/packages/05/85/3696495aa7aa503e262043f366593d71e60c654117804e793044613626b3/regex-2026.1.14-cp311-cp311-win32.whl", hash = "sha256:88a8fc629d3949667a424a6156d9e00e6cff47a97c6fe8c01609d80460471751", size = 265896, upload-time = "2026-01-14T17:51:13.988Z" }, + { url = "https://files.pythonhosted.org/packages/12/cc/e4a76c5362a342002f9ac5df4209e7cee500ae86f70ba91cfc7bb5fe3ed1/regex-2026.1.14-cp311-cp311-win_amd64.whl", hash = "sha256:11f50721a57e74793ee8497fe54598d9c4217d4458617c41dba87b8f37cc68dc", size = 277838, 
upload-time = "2026-01-14T17:51:15.508Z" }, + { url = "https://files.pythonhosted.org/packages/ad/60/bdc37e0c465c1385cc80a580239e8ccbf792bb68ca55099e5f9d66b2277f/regex-2026.1.14-cp311-cp311-win_arm64.whl", hash = "sha256:97560f0efb40ae57a54dc672a424d4aba4eaba0ade57c8660f76d2babfe19dd2", size = 270376, upload-time = "2026-01-14T17:51:17.124Z" }, + { url = "https://files.pythonhosted.org/packages/b6/04/b26d2d5c757f550abe4fbe40e6711515383f24cb5fa86f1970136582ccd2/regex-2026.1.14-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:15b1a93eac42eb341599ab08fb1a50880f0e8aa770730012db02b2e61ae44a37", size = 489393, upload-time = "2026-01-14T17:51:18.297Z" }, + { url = "https://files.pythonhosted.org/packages/49/4a/7287bf27056253fe04a6e2b1313d20160c47ab43098c4b43d46cb3ff8508/regex-2026.1.14-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7973c8b7b3019806e4c310eb3ec4c42785726d8cbfdccc8215aa74ba492dd42d", size = 291343, upload-time = "2026-01-14T17:51:19.898Z" }, + { url = "https://files.pythonhosted.org/packages/d6/de/93b43df6b780afa6d48b14a281fe7e6805f72f281af8fc916993e058178f/regex-2026.1.14-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a7a3ff18511c390975637cccf669db354c37d151db79da189f43aedb511d6318", size = 289009, upload-time = "2026-01-14T17:51:21.397Z" }, + { url = "https://files.pythonhosted.org/packages/49/23/e5551472791f9741999cdaaf631206afe8dbffde0ab8c2b94deca5fa750c/regex-2026.1.14-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:316c8bb5a3b15043e436366a380588be9340fb3c901a17cf9110fe24db237b1b", size = 798653, upload-time = "2026-01-14T17:51:23.215Z" }, + { url = "https://files.pythonhosted.org/packages/59/35/feac1c5303819fae5c5de682237fb7a4bdc1284a21875421240fa49b10fa/regex-2026.1.14-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c0aec48d38905858e7036f56dde117e459866bd47c4d4bd3dd8f21bcb2276bbb", size = 864251, upload-time = "2026-01-14T17:51:24.492Z" }, + { url = "https://files.pythonhosted.org/packages/6e/f4/71272dcd52fbcd4e9ab443a41208450ba565d77172649bfa35290845fb22/regex-2026.1.14-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:89fd92fb4014cf818446f13d6ffa5eede92130df576988148449b58b60b4e4c7", size = 912266, upload-time = "2026-01-14T17:51:26.205Z" }, + { url = "https://files.pythonhosted.org/packages/43/b3/278c78f130df71664c9d8b7241375474ee6b8b855624b3e3726806b768b6/regex-2026.1.14-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2b823a0bdc222f844ac1ff00028d39bcb9d86eb816877a80aa373a3bea23ead7", size = 803586, upload-time = "2026-01-14T17:51:27.636Z" }, + { url = "https://files.pythonhosted.org/packages/72/a7/7cb3b6e10949a4a3faad06e98a8a2342c5372583979d8408b4af9b3232a3/regex-2026.1.14-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2c732dd15a5d1e90d90bf96a0d9944f93c6f3ceaa6c2eae6977637854a5bd990", size = 787926, upload-time = "2026-01-14T17:51:28.966Z" }, + { url = "https://files.pythonhosted.org/packages/b8/10/5d16ffeb4ae724d65637022c7a91bb82008dc5c1fbad5b97b4f59bf3bb91/regex-2026.1.14-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:aa1e4a684f38086845583d5cea2c9a4868a85bacc9452e2e2fb8dcc3bbda352b", size = 858604, upload-time = "2026-01-14T17:51:30.393Z" }, + { url = "https://files.pythonhosted.org/packages/36/ef/ce1328e8d0fe7927a61e720954612c2eea6593b743c87ca511bab9c535c6/regex-2026.1.14-cp312-cp312-musllinux_1_2_s390x.whl", hash = 
"sha256:58f088f919f834fa05d2fdb4158c5db627a73b5b408671c93611dfb8e3af6029", size = 850694, upload-time = "2026-01-14T17:51:31.686Z" }, + { url = "https://files.pythonhosted.org/packages/b5/3e/f8435c408fc823c77ced902401af5bf105bdf06bb6de479ee2801d89b858/regex-2026.1.14-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f7b5a78b4b2298347b878c3557652138cbe89b24d9bb34f1d76bac54259da685", size = 789847, upload-time = "2026-01-14T17:51:33.312Z" }, + { url = "https://files.pythonhosted.org/packages/90/ba/f859d3e3a294918985a9eca10089af199597280a02ef466e6d4e309190ab/regex-2026.1.14-cp312-cp312-win32.whl", hash = "sha256:61ccdad7ed3c7f7a2ffaf43e187253fb10395f46b5a74cf22ee507202eef9356", size = 266278, upload-time = "2026-01-14T17:51:34.568Z" }, + { url = "https://files.pythonhosted.org/packages/44/18/c03b758752d334f0104fb54db9dbe94dfe9f7c10c687f59527e372e50088/regex-2026.1.14-cp312-cp312-win_amd64.whl", hash = "sha256:59b61fa84137a965de00cb58c663bc341e983ead803a16e7ba542ffe35027088", size = 277166, upload-time = "2026-01-14T17:51:36.432Z" }, + { url = "https://files.pythonhosted.org/packages/d2/de/3ae7865b4b5c1d1f98cda59affa03ac9dc147757d487b1cdc6e989fe561c/regex-2026.1.14-cp312-cp312-win_arm64.whl", hash = "sha256:190dd6e300710bc465bfbf90b8e7c45e8eeb676cb9af1863fb512c78af1f305f", size = 270416, upload-time = "2026-01-14T17:51:37.613Z" }, + { url = "https://files.pythonhosted.org/packages/21/8c/dbf1f86f33ea9e5365a18b5f82402092ab173244f5133b133128ce9b3f7c/regex-2026.1.14-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2d79c0cbcc86da60fee4410bd492cda9121cda2fc5a5a214b363b4566f973319", size = 489162, upload-time = "2026-01-14T17:51:38.854Z" }, + { url = "https://files.pythonhosted.org/packages/8a/cd/0d42bcd848be341b9d220a66b8ee79d74c3387f1def40a58d00ca26965d1/regex-2026.1.14-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4ec665ee043e552ea705829ababf342a6437916d3820afcfb520c803a863bab2", size = 291208, upload-time = "2026-01-14T17:51:40.058Z" }, + { url = "https://files.pythonhosted.org/packages/55/b0/d60e4a1260d1070df4e8be0e41917963821f345be3522b0f1490e122fd68/regex-2026.1.14-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d03aa0b4b376591dbbe55e43e410408baa58efbc1c233eb7701390833177e20a", size = 288897, upload-time = "2026-01-14T17:51:41.441Z" }, + { url = "https://files.pythonhosted.org/packages/98/7f/fb426139aca46aeaf1aa4dcd43ed3db4cc768efc34c724e51a3b139a8c40/regex-2026.1.14-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:69c16047721fd8b517a0ebb464cbd71391711438eeaeb210f2ca698a53ec6e81", size = 798678, upload-time = "2026-01-14T17:51:42.836Z" }, + { url = "https://files.pythonhosted.org/packages/2e/35/2e1f3c985d8cd5c6aec03fc96e51dfa972c24c0b4aaef6e065bc1de0bbfd/regex-2026.1.14-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:94e547c2a2504b8f7ae78c8adfe6a5112db85f57cb0ee020d2c5877275e870d2", size = 864207, upload-time = "2026-01-14T17:51:44.194Z" }, + { url = "https://files.pythonhosted.org/packages/76/7f/405e0f3b4d98614e58aab7c18ab820b741321d2dff29ef8e7d1948359912/regex-2026.1.14-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0911e70459d42483190bfbdb51ea503cde035ad2d5a6cc2a65d89500940d6cce", size = 912355, upload-time = "2026-01-14T17:51:45.586Z" }, + { url = 
"https://files.pythonhosted.org/packages/10/30/c818854bbf09f41b73474381c4126c9489e02c2baa1f2178f699b2085a78/regex-2026.1.14-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ccf545852edc516f51b97d71a93308340e002ff44dd70b5f3e8486ef7db921b", size = 803581, upload-time = "2026-01-14T17:51:47.217Z" }, + { url = "https://files.pythonhosted.org/packages/cb/95/6585eee0e4ff1a0970606975962491d17c78b16738274281068ee7c59546/regex-2026.1.14-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c09b0ced5b98299709d43f106e27218c4dd8d561bea1a257c77626e6863fdad3", size = 787977, upload-time = "2026-01-14T17:51:48.636Z" }, + { url = "https://files.pythonhosted.org/packages/c5/0b/f235cb019ee7f912d7cf2e026a38569c83c0eb2bb74200551148f80ab3cb/regex-2026.1.14-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:dbfac3be697beadc831fad938d174d73a2903142447a617ef831ce870d7ec1dd", size = 858547, upload-time = "2026-01-14T17:51:49.998Z" }, + { url = "https://files.pythonhosted.org/packages/d5/1e/c8561f3a01e9031c7ecc425aac2f25178487335efbee6a6c5a8a648013c2/regex-2026.1.14-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:f9c67b0aec9e36daeb34051478c51fcf15c6eac8207645c6660f657ed26002a5", size = 850613, upload-time = "2026-01-14T17:51:51.614Z" }, + { url = "https://files.pythonhosted.org/packages/6d/21/4a1b879a4e2b991d65c92190a5e8024571c89c045cc4cf305166416b1c7b/regex-2026.1.14-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:aaf3cd810e128763d8010633af1206715daf333348a01bb5eb72c99ed15b0277", size = 789950, upload-time = "2026-01-14T17:51:53.719Z" }, + { url = "https://files.pythonhosted.org/packages/28/a3/096178b84bcb17b74d22295c5f6882c1068548cb4ddd33cfa09f1e021e14/regex-2026.1.14-cp313-cp313-win32.whl", hash = "sha256:811c57a0a32b2b9507a2d0eb4b0bfd56dce041c97c00bea6a5cca205173619a5", size = 266273, upload-time = "2026-01-14T17:51:55.097Z" }, + { url = "https://files.pythonhosted.org/packages/40/dc/708fc41f410a5d5b47ee0e0475ce9e5cc981915398035d36bb162a64dfc8/regex-2026.1.14-cp313-cp313-win_amd64.whl", hash = "sha256:d6e2e253bfd1c45b1c14f22034c88673d90a8ff21a8d410fda973e23989e14a5", size = 277146, upload-time = "2026-01-14T17:51:56.46Z" }, + { url = "https://files.pythonhosted.org/packages/d3/69/3b7090b0305672c998c1dfc27f859b406f49da381357a30ee3d112cdfe81/regex-2026.1.14-cp313-cp313-win_arm64.whl", hash = "sha256:bd2d610ce699cf378e23b63e435678742b7d565d546aaf26f7c1f14d228da78d", size = 270410, upload-time = "2026-01-14T17:51:57.807Z" }, + { url = "https://files.pythonhosted.org/packages/3f/67/d4254424777851b16c3622049383c1c71259c9d4bea87f0d304376541a28/regex-2026.1.14-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b8430037d2253f9640727c046471bc093773abcefd212254d4b904730536b652", size = 492070, upload-time = "2026-01-14T17:51:59.178Z" }, + { url = "https://files.pythonhosted.org/packages/a6/9e/c3321f78f1ddb4eee88969db36fb8552217dd99d9b16a7c0ac6e88340796/regex-2026.1.14-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:10a6877cee35e574234928bcb74125063ff907fc0f5efca7a5a44bebd2fe87f3", size = 292752, upload-time = "2026-01-14T17:52:00.772Z" }, + { url = "https://files.pythonhosted.org/packages/1f/f9/d7dd071d5d12f4f58950432c4f967b3ba6ddbd14bc84b0280a35284dd287/regex-2026.1.14-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:41f9b06ec8ebd743c78e331d062d868c398817bfb2b04191e107c1ee2ac202ed", size = 291116, upload-time = "2026-01-14T17:52:02.162Z" }, + { url = 
"https://files.pythonhosted.org/packages/fd/f4/a2d81988df08bb13e2068eec072c3d46fc578575975bba549f512bc74495/regex-2026.1.14-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:650979c05e632bc80f6267e645ad152e13c6931d6295c0ad8ba3e637c118f124", size = 807521, upload-time = "2026-01-14T17:52:03.495Z" }, + { url = "https://files.pythonhosted.org/packages/d9/b0/0f4217aa90bb83e04cbae39a7428fa27ed9e21dd6b5fc10186fb9a341da3/regex-2026.1.14-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6464d2c038c6bb6b534ac3144281fd5d38268bcb77cf6e17b399ca79ebbae25c", size = 873453, upload-time = "2026-01-14T17:52:04.862Z" }, + { url = "https://files.pythonhosted.org/packages/8c/69/b494cefbf67d1895568d952f1343a029dfe93428816a9956d8022f7a24f1/regex-2026.1.14-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ada211c9b8d6c0b2860ea37a52e06b0b3b316dbc28f403530e0227868318c9d4", size = 915006, upload-time = "2026-01-14T17:52:06.304Z" }, + { url = "https://files.pythonhosted.org/packages/a7/d4/54d81ba0b45893ab9dec83134d3fef383f807987c6618de3ea5ecceb98cb/regex-2026.1.14-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:54cf46d11bb344d86fc5514171a55220f87a31706ef9c0cd41b310f706d50db8", size = 812793, upload-time = "2026-01-14T17:52:07.986Z" }, + { url = "https://files.pythonhosted.org/packages/56/40/2a477aa0a2b869ea2538a7ab1ee46d581be5f17da345e9913b7a0baf7701/regex-2026.1.14-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:83792c2865452dbbd14fb00fd7c00cfc38ea76bf579944e8244a9e1b78a58bde", size = 795557, upload-time = "2026-01-14T17:52:09.45Z" }, + { url = "https://files.pythonhosted.org/packages/07/0f/54b5af02916f3ca90987c0e1c744b7fee572f1873da9b6256f85783286e4/regex-2026.1.14-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:e3e7fbf8403dadf202438c0e1c677c21809fc7ba7489f8449b76fe27a8284264", size = 868425, upload-time = "2026-01-14T17:52:11.392Z" }, + { url = "https://files.pythonhosted.org/packages/74/cd/c9dfdd504497a25ba64c4ef846c37f74019cfdedfe3d1cdcba4033a3ac0c/regex-2026.1.14-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:109ce462bf9f91ca72ef2864af706f0ed3d37de7692d9b90e9cff1e44ad6c3b4", size = 854751, upload-time = "2026-01-14T17:52:12.835Z" }, + { url = "https://files.pythonhosted.org/packages/95/b3/e5347ed1eb68a0c8d6c6b5da9318c564308d022b721b1c2ca311f7a8bd74/regex-2026.1.14-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:299c1c591ecd798ce2529e24f6e11f2fe3cc42bb21b0fead056880e0d57752c3", size = 799557, upload-time = "2026-01-14T17:52:14.228Z" }, + { url = "https://files.pythonhosted.org/packages/37/db/a6d2ca85e4bb3e02d106d7e3215ef07fb896cac1afe5ab206bb37cf26c30/regex-2026.1.14-cp313-cp313t-win32.whl", hash = "sha256:f77447f07d7dca963dec9c8f6cfc9c0fef83f40d6124f895d82d0c35d57afe62", size = 268876, upload-time = "2026-01-14T17:52:15.697Z" }, + { url = "https://files.pythonhosted.org/packages/fe/09/38423b655f01ddcf10444c637866113c0ddd0a9c89827841663394afc636/regex-2026.1.14-cp313-cp313t-win_amd64.whl", hash = "sha256:8eb68a9449ccdfdd40ed1f59eb579ecfcbaad5a93b17243ca234c4587ac07ec3", size = 280314, upload-time = "2026-01-14T17:52:16.972Z" }, + { url = "https://files.pythonhosted.org/packages/27/12/8a6ef769b0ee3a4df49e42dec9259e444cbe98bd4303b1fec38ff456425c/regex-2026.1.14-cp313-cp313t-win_arm64.whl", hash = "sha256:c50b3ebab43bbf7e7c60b7e501ae657a15efe3439cbae4acc1cb87031ba9b004", size = 271555, 
upload-time = "2026-01-14T17:52:18.416Z" }, + { url = "https://files.pythonhosted.org/packages/7d/90/64dcf099f3efde2115ceb0a2482064d2630532a8c2b40c95d01f4b886d68/regex-2026.1.14-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:c07bbee79ceb399ae4c8294b154fccdf2eefc1e86b157338d93e9e46ed327cd4", size = 489164, upload-time = "2026-01-14T17:52:19.811Z" }, + { url = "https://files.pythonhosted.org/packages/57/33/11f82bcf6df1477211390d3c55d9a65bbdf0454101fe6f101bbf428ed72e/regex-2026.1.14-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:ef59c01b8eab361b3e5768f491a0a59c6fc3b862d34d08ec9b78ce7b3f9c5d11", size = 291147, upload-time = "2026-01-14T17:52:21.146Z" }, + { url = "https://files.pythonhosted.org/packages/dc/b4/33df4bc04af4a7abf5754da3a1d131e9384e59ca4431d85af9f5cf7e040d/regex-2026.1.14-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:db72aebb3eede342088f6940aea3cc59f2bbf93295b8a7c7a98fa293b20accc9", size = 288981, upload-time = "2026-01-14T17:52:22.675Z" }, + { url = "https://files.pythonhosted.org/packages/72/fd/d89b1425b9b420877eec3588d1abec08948e836461a16e4748be64078cda/regex-2026.1.14-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:23da4da4a156d613946f197ad85da2c3ce3269166909249206dbfc6a62e27d4b", size = 799097, upload-time = "2026-01-14T17:52:24.081Z" }, + { url = "https://files.pythonhosted.org/packages/04/f0/149b80499a12a9ef525656a780abca8383b9689687afb3eef8f16d62574c/regex-2026.1.14-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c59aaa581c86d0003a805843399fdd859e3803ee3f6bf694a96ede726b60d26c", size = 864980, upload-time = "2026-01-14T17:52:25.847Z" }, + { url = "https://files.pythonhosted.org/packages/ce/bb/bec2a2ba7e0120915b02d46b68c265938a235657baaf7be79746e0a40583/regex-2026.1.14-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4176e42a6b940b704b25d804afc799c4f2cf88c245e71c111d91c9259a5a08bd", size = 911606, upload-time = "2026-01-14T17:52:27.529Z" }, + { url = "https://files.pythonhosted.org/packages/c3/49/fcb59ec88bf188632877ea18eca43bed95c49fd049a3a16f943dc48ec225/regex-2026.1.14-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:336662753d92ecc68a6e5d557d5648550927c3927bb18959a6c336c6d2309b95", size = 803356, upload-time = "2026-01-14T17:52:29.031Z" }, + { url = "https://files.pythonhosted.org/packages/04/a3/a4e1873b32c7b4e9839edbf86d2369bbbd5759581481bf095eb561186acd/regex-2026.1.14-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:d3f2da65a0552a244319cd0079f7dcbd7b18358c05ca803757f123b5315f9e2b", size = 788042, upload-time = "2026-01-14T17:52:30.546Z" }, + { url = "https://files.pythonhosted.org/packages/05/b9/0f3fcb32b9ac5467f3a6634fc867bb008670eabebc5dbf91c76d0ee63d1d/regex-2026.1.14-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:ea9cb3230eb1791b74530fe59a9ad1e41282eee27cddf9f765cb216f1a35b491", size = 859373, upload-time = "2026-01-14T17:52:32.11Z" }, + { url = "https://files.pythonhosted.org/packages/1f/f4/f1f7602b5e9a60fdabebaf5b6796b460a4820fbe932993467ae6c12bd8ac/regex-2026.1.14-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:00a41d01546c09bfd972492f4f45515cba5cd8b15d129e6f43b5e9b6bf5cf5db", size = 850110, upload-time = "2026-01-14T17:52:34.615Z" }, + { url = "https://files.pythonhosted.org/packages/b5/e4/96e231d977a95fe493931ee10b60352d7b0f25fe733660dd4ce34d7669dd/regex-2026.1.14-cp314-cp314-musllinux_1_2_x86_64.whl", hash = 
"sha256:a16876c16759e2030cbc2444d72a773ba9fb11c58ddf1411bceac3015e85ad62", size = 789584, upload-time = "2026-01-14T17:52:36.667Z" }, + { url = "https://files.pythonhosted.org/packages/94/3d/0d1cb4b3f20be9767c83ddb8e2c7041d9919363ffc50578231305e6ab768/regex-2026.1.14-cp314-cp314-win32.whl", hash = "sha256:0283db18f918b1b6627e5b9d018ea6cc90d25980d9c6ce0d60de3ea83047947e", size = 271689, upload-time = "2026-01-14T17:52:38.155Z" }, + { url = "https://files.pythonhosted.org/packages/65/a0/7c153b77d72b873e520905fecdf61456f78bad8c4a0c063420c643f76f9c/regex-2026.1.14-cp314-cp314-win_amd64.whl", hash = "sha256:f1c997e66c992bfabfb08581e7739568ffb76d2ced4344ff81783961e71ac5ea", size = 280418, upload-time = "2026-01-14T17:52:39.716Z" }, + { url = "https://files.pythonhosted.org/packages/91/98/77408d72e1bc4040007479c4553097b81c084faf2b53ae3bd20f216cc601/regex-2026.1.14-cp314-cp314-win_arm64.whl", hash = "sha256:1d6c1deddd7bf9793f87293edaa28a7e23f753dbfb5b0cafaa724ee87b2f854d", size = 273466, upload-time = "2026-01-14T17:52:41.096Z" }, + { url = "https://files.pythonhosted.org/packages/11/fe/16f795a7d49970393f43c1593a59057d9f0037858cd9797ca2e6965031e6/regex-2026.1.14-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:2cccc1a0d1c07dc5e7f65f042f17a678aa431b27d2c1b33983cdb52daf4e49a5", size = 492068, upload-time = "2026-01-14T17:52:42.561Z" }, + { url = "https://files.pythonhosted.org/packages/b8/8d/297e5410c4aba87c0c5c7760e1ffa34f9d4bec0bd3b264073c5f6d251ab1/regex-2026.1.14-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c47239a9736f6082540f91f77dd634a7771eac1e8720bc35ef274d8ea0a72b90", size = 292752, upload-time = "2026-01-14T17:52:44.414Z" }, + { url = "https://files.pythonhosted.org/packages/5c/8d/d9efc9580631603255856b306e4a19c6c3b45491a793ce60a4de76118831/regex-2026.1.14-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ac5dcb96ed037c692eb40b0c96bd5ba588f07fd898bd14e111c751a4bf195b21", size = 291118, upload-time = "2026-01-14T17:52:46.315Z" }, + { url = "https://files.pythonhosted.org/packages/ac/cd/89735cc17f41667bf1cb7fb341109eb19ada117ef0a8e8882a9396de68f0/regex-2026.1.14-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:98af87df5496a39a7f4fa619568a12e0b719af25e75ecbd968a671609fda3702", size = 807759, upload-time = "2026-01-14T17:52:47.771Z" }, + { url = "https://files.pythonhosted.org/packages/bf/2d/e5db572360c76b335d578a4bec6437b302e1f170722b1f0c79c7295ec169/regex-2026.1.14-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:670d6865632ef2ad1ba0f326b4eb714107b71e3ea9a48a2564d407834273e2da", size = 873536, upload-time = "2026-01-14T17:52:49.695Z" }, + { url = "https://files.pythonhosted.org/packages/b3/a1/704748140afb90045c3d635cd1929e15b821627ef7a1b4ae22fe3c1cf18a/regex-2026.1.14-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:aa783080add7cedbeb8c11e8c7e3efb9353b7c183701548fae70ec44b7b886cd", size = 915064, upload-time = "2026-01-14T17:52:51.199Z" }, + { url = "https://files.pythonhosted.org/packages/7c/5a/00699f1bcc8f5aaf9cae4b1f673c1a3ba5256ea2d4d53f8f21319976cd25/regex-2026.1.14-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0aa76616ee8a1fdefa62f486324ba6fecc3059261779ebb1575a7b7ddf5fb7c9", size = 812937, upload-time = "2026-01-14T17:52:52.77Z" }, + { url = 
"https://files.pythonhosted.org/packages/b6/fd/c6742cb9ed24a8fe197603a6808e5641eaaa59c13a2ad8624d39d0405d82/regex-2026.1.14-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:6ff33c4e28c44de3e1877afaf55feb306093cb6cb8e49bf083cfd9bdb258e130", size = 795650, upload-time = "2026-01-14T17:52:54.717Z" }, + { url = "https://files.pythonhosted.org/packages/17/36/ccadcc5f1204529ca638c969659a9b56ef706f4eb908bbd7a9a7645793b8/regex-2026.1.14-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:a39879c089bc84fd8ab6f02de458534e7ed8e7bf72091322ff0d8b9138f612c1", size = 868549, upload-time = "2026-01-14T17:52:56.309Z" }, + { url = "https://files.pythonhosted.org/packages/78/5e/a7b09f3031bbd0e1ab15d08277cac61193adfd62bb6d10e7ba4e69cee4e6/regex-2026.1.14-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:cf0ce5cd5b0c011ec49ff51f85f5ba6ed46ecc5491fa60f803734b2e70dd32aa", size = 854779, upload-time = "2026-01-14T17:52:57.789Z" }, + { url = "https://files.pythonhosted.org/packages/de/ae/a70e39d97b9611628b1d9c3a709d24f1639bcbfa99277391864303a8cd61/regex-2026.1.14-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a29ecdaa0f5dac290b17b61150d00646240b195dbe2950bf3de6360cf41c7cce", size = 799776, upload-time = "2026-01-14T17:52:59.344Z" }, + { url = "https://files.pythonhosted.org/packages/30/43/5789f3f398de60ce9fd743b014a7f37bac5659f2dcbdcceb093d2a8778ab/regex-2026.1.14-cp314-cp314t-win32.whl", hash = "sha256:f9874b7d8ce8f12553fff86a1b49311897a391af598d4f5d1d0f08bbf7430739", size = 274667, upload-time = "2026-01-14T17:53:00.923Z" }, + { url = "https://files.pythonhosted.org/packages/69/bf/76136bfd87fe40d840220c190dfc36114afa0e5338ffe5da2e55b238bc37/regex-2026.1.14-cp314-cp314t-win_amd64.whl", hash = "sha256:401795f2195562e87f382a477404b05e4662c365c300abdea79858719870377d", size = 284388, upload-time = "2026-01-14T17:53:02.432Z" }, + { url = "https://files.pythonhosted.org/packages/2a/80/4ffc8b077a3b5bcaa6be885e77c9d9732ee335218fc438509294120da649/regex-2026.1.14-cp314-cp314t-win_arm64.whl", hash = "sha256:3539f717f3cba7a12b8f575c86edd9ecc21bf387f3590d32d995320a397a7dcf", size = 274839, upload-time = "2026-01-14T17:53:03.827Z" }, ] [[package]] @@ -6661,28 +6667,28 @@ wheels = [ [[package]] name = "uv" -version = "0.9.24" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0b/7f/6692596de7775b3059a55539aed2eec16a0642a2d6d3510baa5878287ce4/uv-0.9.24.tar.gz", hash = "sha256:d59d31c25fc530c68db9164174efac511a25fc882cec49cd48f75a18e7ebd6d5", size = 3852673, upload-time = "2026-01-09T22:34:31.635Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/51/10bb9541c40a5b4672527c357997a30fdf38b75e7bbaad0c37ed70889efa/uv-0.9.24-py3-none-linux_armv6l.whl", hash = "sha256:75a000f529ec92235b10fb5e16ca41f23f46c643308fd6c5b0d7b73ca056c5b9", size = 21395664, upload-time = "2026-01-09T22:34:05.887Z" }, - { url = "https://files.pythonhosted.org/packages/ec/dd/d7df524cb764ebc652e0c8bf9abe55fc34391adc2e4ab1d47375222b38a9/uv-0.9.24-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:207c8a2d4c4d55589feb63b4be74f6ff6ab92fa81b14a6515007ccec5a868ae0", size = 20547988, upload-time = "2026-01-09T22:34:16.21Z" }, - { url = "https://files.pythonhosted.org/packages/49/e4/7ca5e7eaed4b2b9d407aa5aeeb8f71cace7db77f30a63139bbbfdfe4770c/uv-0.9.24-py3-none-macosx_11_0_arm64.whl", hash = "sha256:44c0b8a78724e4cfa8e9c0266023c70fc792d0b39a5da17f5f847af2b530796b", size = 19033208, upload-time = "2026-01-09T22:33:50.91Z" }, - { url = 
"https://files.pythonhosted.org/packages/27/05/b7bab99541056537747bfdc55fdc97a4ba998e2b53cf855411ef176c412b/uv-0.9.24-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:841ede01d6dcf1676a21dce05f3647ba171c1d92768a03e8b8b6b7354b34a6d2", size = 20872212, upload-time = "2026-01-09T22:33:58.007Z" }, - { url = "https://files.pythonhosted.org/packages/d3/93/3a69cf481175766ee6018afb281666de12ccc04367d20a41dc070be8b422/uv-0.9.24-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:69531d9a8772afb2dff68fef2469f666e4f8a0132b2109e36541c423415835da", size = 21017966, upload-time = "2026-01-09T22:34:29.354Z" }, - { url = "https://files.pythonhosted.org/packages/17/40/7aec2d428e57a3ec992efc49bbc71e4a0ceece5a726751c661ddc3f41315/uv-0.9.24-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6720c9939cca7daff3cccc35dd896bbe139d7d463c62cba8dbbc474ff8eb93d1", size = 21943358, upload-time = "2026-01-09T22:34:08.63Z" }, - { url = "https://files.pythonhosted.org/packages/c8/f4/2aa5b275aa8e5edb659036e94bae13ae294377384cf2a93a8d742a38050f/uv-0.9.24-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d7d1333d9c21088c89cb284ef29fdf48dc2015fe993174a823a3e7c991db90f9", size = 23672949, upload-time = "2026-01-09T22:34:03.113Z" }, - { url = "https://files.pythonhosted.org/packages/8e/24/2589bed4b39394c799472f841e0580318a8b7e69ef103a0ab50cf1c39dff/uv-0.9.24-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b610d89d6025000d08cd9bd458c6e264003a0ecfdaa8e4eba28938130cd1837", size = 23270210, upload-time = "2026-01-09T22:34:13.94Z" }, - { url = "https://files.pythonhosted.org/packages/80/3a/034494492a1ad1f95371c6fd735e4b7d180b8c1712c88b0f32a34d6352fd/uv-0.9.24-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:38c59e18fe5fa42f7baeb4f08c94914cee6d87ff8faa6fc95c994dbc0de26c90", size = 22282247, upload-time = "2026-01-09T22:33:53.362Z" }, - { url = "https://files.pythonhosted.org/packages/be/0e/d8ab2c4fa6c9410a8a37fa6608d460b0126cee2efed9eecf516cdec72a1a/uv-0.9.24-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:009cc82cdfc48add6ec13a0c4ffbb788ae2cab53573b4218069ca626721a404b", size = 22348801, upload-time = "2026-01-09T22:34:00.46Z" }, - { url = "https://files.pythonhosted.org/packages/50/fa/7217764e4936d6fda1944d956452bf94f790ae8a02cb3e5aa496d23fcb25/uv-0.9.24-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:1914d33e526167dc202ec4a59119c68467b31f7c71dcf8b1077571d091ca3e7c", size = 21000825, upload-time = "2026-01-09T22:34:21.811Z" }, - { url = "https://files.pythonhosted.org/packages/94/8f/533db58a36895142b0c11eedf8bfe11c4724fb37deaa417bfb0c689d40b8/uv-0.9.24-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:aafe7dd9b633672054cf27f1a8e4127506324631f1af5edd051728f4f8085351", size = 22149066, upload-time = "2026-01-09T22:33:45.722Z" }, - { url = "https://files.pythonhosted.org/packages/cf/c7/e6eccd96341a548f0405bffdf55e7f30b5c0757cd1b8f7578e0972a66002/uv-0.9.24-py3-none-musllinux_1_1_armv7l.whl", hash = "sha256:63a0a46693098cf8446e41bd5d9ce7d3bc9b775a63fe0c8405ab6ee328424d46", size = 20993489, upload-time = "2026-01-09T22:34:27.007Z" }, - { url = "https://files.pythonhosted.org/packages/46/07/32d852d2d40c003b52601c44202c9d9e655c485fae5d84e42f326814b0be/uv-0.9.24-py3-none-musllinux_1_1_i686.whl", hash = "sha256:15d3955bfb03a7b78aaf5afb639cedefdf0fc35ff844c92e3fe6e8700b94b84f", size = 21400775, upload-time = "2026-01-09T22:34:24.278Z" }, - { url = 
"https://files.pythonhosted.org/packages/b0/58/f8e94226126011ba2e2e9d59c6190dc7fe9e61fa7ef4ca720d7226c1482b/uv-0.9.24-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:488a07e59fb417bf86de5630197223b7a0223229e626afc124c26827db78cff8", size = 22554194, upload-time = "2026-01-09T22:34:18.504Z" }, - { url = "https://files.pythonhosted.org/packages/da/8e/b540c304039a6561ba8e9a673009cfe1451f989d2269fe40690901ddb233/uv-0.9.24-py3-none-win32.whl", hash = "sha256:68a3186074c03876ee06b68730d5ff69a430296760d917ebbbb8e3fb54fb4091", size = 20203184, upload-time = "2026-01-09T22:34:11.02Z" }, - { url = "https://files.pythonhosted.org/packages/16/59/dba7c5feec1f694183578435eaae0d759b8c459c5e4f91237a166841a116/uv-0.9.24-py3-none-win_amd64.whl", hash = "sha256:8cd626306b415491f839b1a9100da6795c82c44d4cf278dd7ace7a774af89df4", size = 22294050, upload-time = "2026-01-09T22:33:48.228Z" }, - { url = "https://files.pythonhosted.org/packages/d7/ef/e58fb288bafb5a8b5d4994e73fa6e062e408680e5a20d0427d5f4f66d8b1/uv-0.9.24-py3-none-win_arm64.whl", hash = "sha256:8d3c0fec7aa17f936a5b258816e856647b21f978a81bcfb2dc8caf2892a4965e", size = 20620004, upload-time = "2026-01-09T22:33:55.62Z" }, +version = "0.9.25" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/41/b3/c2a6afd3d8f8f9f5d9c65fcdff1b80fb5bdaba21c8b0e99dd196e71d311f/uv-0.9.25.tar.gz", hash = "sha256:8625de8f40e7b669713e293ab4f7044bca9aa7f7c739f17dc1fd0cb765e69f28", size = 3863318, upload-time = "2026-01-13T23:20:16.141Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/13/e1/9284199aed638643a4feadf8b3283c1d43b3c3adcbdac367f26a8f5e398f/uv-0.9.25-py3-none-linux_armv6l.whl", hash = "sha256:db51f37b3f6c94f4371d8e26ee8adeb9b1b1447c5fda8cc47608694e49ea5031", size = 21479938, upload-time = "2026-01-13T23:21:13.011Z" }, + { url = "https://files.pythonhosted.org/packages/6d/5c/79dc42e1abf0afc021823c688ff04e4283f9e72d20ca4af0027aa7ed29df/uv-0.9.25-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e47a9da2ddd33b5e7efb8068a24de24e24fd0d88a99e0c4a7e2328424783eab8", size = 20681034, upload-time = "2026-01-13T23:20:19.269Z" }, + { url = "https://files.pythonhosted.org/packages/7c/0b/997f279db671fe4b1cf87ad252719c1b7c47a9546efd6c2594b5648ea983/uv-0.9.25-py3-none-macosx_11_0_arm64.whl", hash = "sha256:79af8c9b885b507a82087e45161a4bda7f2382682867dc95f7e6d22514ac844d", size = 19096089, upload-time = "2026-01-13T23:20:55.021Z" }, + { url = "https://files.pythonhosted.org/packages/d5/60/a7682177fe76501b403d464b4fee25c1ee4089fe56caf7cb87c2e6741375/uv-0.9.25-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:6ca6bdd3fe4b1730d1e3d10a4ce23b269915a60712379d3318ecea9a4ff861fd", size = 20848810, upload-time = "2026-01-13T23:20:13.916Z" }, + { url = "https://files.pythonhosted.org/packages/4c/c1/01d5df4cbec33da51fc85868f129562cbd1488290465107c03bed90d8ca4/uv-0.9.25-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9d993b9c590ac76f805e17441125d67c7774b1ba05340dc987d3de01852226b6", size = 21095071, upload-time = "2026-01-13T23:20:44.488Z" }, + { url = "https://files.pythonhosted.org/packages/6d/fe/f7cd2f02b0e0974dd95f732efd12bd36a3e8419d53f4d1d49744d2e3d979/uv-0.9.25-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4b72881d0c66ad77844451dbdbcada87242c0d39c6bfd0f89ac30b917a3cfc3", size = 22070541, upload-time = "2026-01-13T23:21:16.936Z" }, + { url = 
"https://files.pythonhosted.org/packages/f8/e6/ef53b6d69b303eca6aa56ad97eb322f6cc5b9571c403e4e64313f1ccfb81/uv-0.9.25-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:ac0dfb6191e91723a69be533102f98ffa5739cba57c3dfc5f78940c27cf0d7e8", size = 23663768, upload-time = "2026-01-13T23:20:29.808Z" }, + { url = "https://files.pythonhosted.org/packages/0d/f8/f0e01ddfc62cb4b8ec5c6d94e46fc77035c0cd77865d7958144caadf8ad9/uv-0.9.25-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:41ae0f2df7c931b72949345134070efa919174321c5bd403954db960fa4c2d7d", size = 23235860, upload-time = "2026-01-13T23:20:58.724Z" }, + { url = "https://files.pythonhosted.org/packages/6a/56/905257af2c63ffaec9add9cce5d34f851f418d42e6f4e73fee18adecd499/uv-0.9.25-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3bf02fcea14b8bec42b9c04094cc5b527c2cd53b606c06e7bdabfbd943b4512c", size = 22236426, upload-time = "2026-01-13T23:20:40.995Z" }, + { url = "https://files.pythonhosted.org/packages/bb/ce/909feee469647b7929967397dcb1b6b317cfca07dc3fc0699b3cab700daf/uv-0.9.25-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642f993d8c74ecd52b192d5f3168433c4efa81b8bb19c5ac97c25f27a44557cb", size = 22294538, upload-time = "2026-01-13T23:21:09.521Z" }, + { url = "https://files.pythonhosted.org/packages/82/be/ac7cd3c45c6baf0d5181133d3bda13f843f76799809374095b6fc7122a96/uv-0.9.25-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:564b5db5e148670fdbcfd962ee8292c0764c1be0c765f63b620600a3c81087d1", size = 20963345, upload-time = "2026-01-13T23:20:25.706Z" }, + { url = "https://files.pythonhosted.org/packages/19/fd/7b6191cef8da4ad451209dde083123b1ac9d10d6c2c1554a1de64aa41ad8/uv-0.9.25-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:991cfb872ef3bc0cc5e88f4d3f68adf181218a3a57860f523ff25279e4cf6657", size = 22205573, upload-time = "2026-01-13T23:20:33.611Z" }, + { url = "https://files.pythonhosted.org/packages/15/80/8d6809df5e5ddf862f963fbfc8b2a25c286dc36724e50c7536e429d718be/uv-0.9.25-py3-none-musllinux_1_1_armv7l.whl", hash = "sha256:e1b4ab678c6816fe41e3090777393cf57a0f4ef122f99e9447d789ab83863a78", size = 21036715, upload-time = "2026-01-13T23:20:51.413Z" }, + { url = "https://files.pythonhosted.org/packages/bb/78/e3cb00bf90a359fa8106e2446bad07e49922b41e096e4d3b335b0065117a/uv-0.9.25-py3-none-musllinux_1_1_i686.whl", hash = "sha256:aa7db0ab689c3df34bdd46f83d2281d268161677ccd204804a87172150a654ef", size = 21505379, upload-time = "2026-01-13T23:21:06.045Z" }, + { url = "https://files.pythonhosted.org/packages/86/36/07f69f45878175d2907110858e5c6631a1b712420d229012296c1462b133/uv-0.9.25-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:a658e47e54f11dac9b2751fba4ad966a15db46c386497cf51c1c02f656508358", size = 22520308, upload-time = "2026-01-13T23:20:09.704Z" }, + { url = "https://files.pythonhosted.org/packages/1d/1b/2d457ee7e2dd35fc22ae6f656bb45b781b33083d4f0a40901b9ae59e0b10/uv-0.9.25-py3-none-win32.whl", hash = "sha256:4df14479f034f6d4dca9f52230f912772f56ceead3354c7b186a34927c22188a", size = 20263705, upload-time = "2026-01-13T23:20:47.814Z" }, + { url = "https://files.pythonhosted.org/packages/5c/0b/05ad2dc53dab2c8aa2e112ef1f9227a7b625ba3507bedd7b31153d73aa5f/uv-0.9.25-py3-none-win_amd64.whl", hash = "sha256:001629fbc2a955c35f373311591c6952be010a935b0bc6244dc61da108e4593d", size = 22311694, upload-time = "2026-01-13T23:21:02.562Z" }, + { url = 
"https://files.pythonhosted.org/packages/54/4e/99788924989082356d6aa79d8bfdba1a2e495efaeae346fd8fec83d3f078/uv-0.9.25-py3-none-win_arm64.whl", hash = "sha256:ea26319abf9f5e302af0d230c0f13f02591313e5ffadac34931f963ef4d7833d", size = 20645549, upload-time = "2026-01-13T23:20:37.201Z" }, ] [[package]]