diff --git a/.github/upgrades/prompts/SemanticKernelToAgentFramework.md b/.github/upgrades/prompts/SemanticKernelToAgentFramework.md index a8c3dcb0a6..44985bba98 100644 --- a/.github/upgrades/prompts/SemanticKernelToAgentFramework.md +++ b/.github/upgrades/prompts/SemanticKernelToAgentFramework.md @@ -839,7 +839,7 @@ var agentOptions = new ChatClientAgentRunOptions(new ChatOptions { MaxOutputTokens = 8000, // Breaking glass to access provider-specific options - RawRepresentationFactory = (_) => new OpenAI.Responses.ResponseCreationOptions() + RawRepresentationFactory = (_) => new OpenAI.Responses.CreateResponseOptions() { ReasoningOptions = new() { diff --git a/.github/workflows/dotnet-build-and-test.yml b/.github/workflows/dotnet-build-and-test.yml index 4a41b343aa..692f3e7c45 100644 --- a/.github/workflows/dotnet-build-and-test.yml +++ b/.github/workflows/dotnet-build-and-test.yml @@ -35,7 +35,8 @@ jobs: contents: read pull-requests: read outputs: - dotnetChanges: ${{ steps.filter.outputs.dotnet}} + dotnetChanges: ${{ steps.filter.outputs.dotnet }} + cosmosDbChanges: ${{ steps.filter.outputs.cosmosdb }} steps: - uses: actions/checkout@v6 - uses: dorny/paths-filter@v3 @@ -44,10 +45,15 @@ jobs: filters: | dotnet: - 'dotnet/**' + cosmosdb: + - 'dotnet/src/Microsoft.Agents.AI.CosmosNoSql/**' # run only if 'dotnet' files were changed - name: dotnet tests if: steps.filter.outputs.dotnet == 'true' run: echo "Dotnet file" + - name: dotnet CosmosDB tests + if: steps.filter.outputs.cosmosdb == 'true' + run: echo "Dotnet CosmosDB changes" # run only if not 'dotnet' files were changed - name: not dotnet tests if: steps.filter.outputs.dotnet != 'true' @@ -77,6 +83,16 @@ jobs: dotnet python workflow-samples + + # Start Cosmos DB Emulator for all integration tests and only for unit tests when CosmosDB changes happened) + - name: Start Azure Cosmos DB Emulator + if: ${{ runner.os == 'Windows' && (needs.paths-filter.outputs.cosmosDbChanges == 'true' || (github.event_name != 'pull_request' && matrix.integration-tests)) }} + shell: pwsh + run: | + Write-Host "Launching Azure Cosmos DB Emulator" + Import-Module "$env:ProgramFiles\Azure Cosmos DB Emulator\PSModules\Microsoft.Azure.CosmosDB.Emulator" + Start-CosmosDbEmulator -NoUI -Key "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" + echo "COSMOS_EMULATOR_AVAILABLE=true" >> $env:GITHUB_ENV - name: Setup dotnet uses: actions/setup-dotnet@v5.0.1 @@ -123,17 +139,7 @@ jobs: popd popd rm -rf "$TEMP_DIR" - - # Start Cosmos DB Emulator for Cosmos-based unit tests (only on Windows) - - name: Start Azure Cosmos DB Emulator - if: runner.os == 'Windows' - shell: pwsh - run: | - Write-Host "Launching Azure Cosmos DB Emulator" - Import-Module "$env:ProgramFiles\Azure Cosmos DB Emulator\PSModules\Microsoft.Azure.CosmosDB.Emulator" - Start-CosmosDbEmulator -NoUI -Key "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" - echo "COSMOS_EMULATOR_AVAILABLE=true" >> $env:GITHUB_ENV - + - name: Run Unit Tests shell: bash run: | @@ -225,7 +231,7 @@ jobs: - name: Upload coverage report artifact if: matrix.targetFramework == env.COVERAGE_FRAMEWORK - uses: actions/upload-artifact@v5 + uses: actions/upload-artifact@v6 with: name: CoverageReport-${{ matrix.os }}-${{ matrix.targetFramework }}-${{ matrix.configuration }} # Artifact name path: ./TestResults/Reports # Directory containing files to upload diff --git a/.github/workflows/python-code-quality.yml b/.github/workflows/python-code-quality.yml index 176eb3db99..4139d47156 100644 --- a/.github/workflows/python-code-quality.yml +++ b/.github/workflows/python-code-quality.yml @@ -39,7 +39,7 @@ jobs: env: # Configure a constant location for the uv cache UV_CACHE_DIR: /tmp/.uv-cache - - uses: actions/cache@v4 + - uses: actions/cache@v5 with: path: ~/.cache/pre-commit key: pre-commit|${{ matrix.python-version }}|${{ hashFiles('python/.pre-commit-config.yaml') }} diff --git a/.github/workflows/python-test-coverage-report.yml b/.github/workflows/python-test-coverage-report.yml index fa36073fc6..e09d9c8870 100644 --- a/.github/workflows/python-test-coverage-report.yml +++ b/.github/workflows/python-test-coverage-report.yml @@ -21,7 +21,7 @@ jobs: steps: - uses: actions/checkout@v6 - name: Download coverage report - uses: actions/download-artifact@v6 + uses: actions/download-artifact@v7 with: github-token: ${{ secrets.GH_ACTIONS_PR_WRITE }} run-id: ${{ github.event.workflow_run.id }} diff --git a/.github/workflows/python-test-coverage.yml b/.github/workflows/python-test-coverage.yml index 6268e7d47d..03cca20e06 100644 --- a/.github/workflows/python-test-coverage.yml +++ b/.github/workflows/python-test-coverage.yml @@ -38,7 +38,7 @@ jobs: - name: Run all tests with coverage report run: uv run poe all-tests-cov --cov-report=xml:python-coverage.xml -q --junitxml=pytest.xml - name: Upload coverage report - uses: actions/upload-artifact@v5 + uses: actions/upload-artifact@v6 with: path: | python/python-coverage.xml diff --git a/docs/features/durable-agents/durable-agents-ttl.md b/docs/features/durable-agents/durable-agents-ttl.md new file mode 100644 index 0000000000..1a4a4e32d6 --- /dev/null +++ b/docs/features/durable-agents/durable-agents-ttl.md @@ -0,0 +1,147 @@ +# Time-To-Live (TTL) for durable agent sessions + +## Overview + +The durable agents automatically maintain conversation history and state for each session. Without automatic cleanup, this state can accumulate indefinitely, consuming storage resources and increasing costs. The Time-To-Live (TTL) feature provides automatic cleanup of idle agent sessions, ensuring that sessions are automatically deleted after a period of inactivity. + +## What is TTL? + +Time-To-Live (TTL) is a configurable duration that determines how long an agent session state will be retained after its last interaction. When an agent session is idle (no messages sent to it) for longer than the TTL period, the session state is automatically deleted. Each new interaction with an agent resets the TTL timer, extending the session's lifetime. + +## Benefits + +- **Automatic cleanup**: No manual intervention required to clean up idle agent sessions +- **Cost optimization**: Reduces storage costs by automatically removing unused session state +- **Resource management**: Prevents unbounded growth of agent session state in storage +- **Configurable**: Set TTL globally or per-agent type to match your application's needs + +## Configuration + +TTL can be configured at two levels: + +1. **Global default TTL**: Applies to all agent sessions unless overridden +2. **Per-agent type TTL**: Overrides the global default for specific agent types + +Additionally, you can configure a **minimum deletion delay** that controls how frequently deletion operations are scheduled. The default value is 5 minutes, and the maximum allowed value is also 5 minutes. + +> [!NOTE] +> Reducing the minimum deletion delay below 5 minutes can be useful for testing or for ensuring rapid cleanup of short-lived agent sessions. However, this can also increase the load on the system and should be used with caution. + +### Default values + +- **Default TTL**: 14 days +- **Minimum TTL deletion delay**: 5 minutes (maximum allowed value, subject to change in future releases) + +### Configuration examples + +#### .NET + +```csharp +// Configure global default TTL and minimum signal delay +services.ConfigureDurableAgents( + options => + { + // Set global default TTL to 7 days + options.DefaultTimeToLive = TimeSpan.FromDays(7); + + // Add agents (will use global default TTL) + options.AddAIAgent(myAgent); + }); + +// Configure per-agent TTL +services.ConfigureDurableAgents( + options => + { + options.DefaultTimeToLive = TimeSpan.FromDays(14); // Global default + + // Agent with custom TTL of 1 day + options.AddAIAgent(shortLivedAgent, timeToLive: TimeSpan.FromDays(1)); + + // Agent with custom TTL of 90 days + options.AddAIAgent(longLivedAgent, timeToLive: TimeSpan.FromDays(90)); + + // Agent using global default (14 days) + options.AddAIAgent(defaultAgent); + }); + +// Disable TTL for specific agents by setting TTL to null +services.ConfigureDurableAgents( + options => + { + options.DefaultTimeToLive = TimeSpan.FromDays(14); + + // Agent with no TTL (never expires) + options.AddAIAgent(permanentAgent, timeToLive: null); + }); +``` + +## How TTL works + +The following sections describe how TTL works in detail. + +### Expiration tracking + +Each agent session maintains an expiration timestamp in its internally managed state that is updated whenever the session processes a message: + +1. When a message is sent to an agent session, the expiration time is set to `current time + TTL` +2. The runtime schedules a delete operation for the expiration time (subject to minimum delay constraints) +3. When the delete operation runs, if the current time is past the expiration time, the session state is deleted. Otherwise, the delete operation is rescheduled for the next expiration time. + +### State deletion + +When an agent session expires, its entire state is deleted, including: + +- Conversation history +- Any custom state data +- Expiration timestamps + +After deletion, if a message is sent to the same agent session, a new session is created with a fresh conversation history. + +## Behavior examples + +The following examples illustrate how TTL works in different scenarios. + +### Example 1: Agent session expires after TTL + +1. Agent configured with 30-day TTL +2. User sends message at Day 0 → agent session created, expiration set to Day 30 +3. No further messages sent +4. At Day 30 → Agent session is deleted +5. User sends message at Day 31 → New agent session created with fresh conversation history + +### Example 2: TTL reset on interaction + +1. Agent configured with 30-day TTL +2. User sends message at Day 0 → agent session created, expiration set to Day 30 +3. User sends message at Day 15 → Expiration reset to Day 45 +4. User sends message at Day 40 → Expiration reset to Day 70 +5. Agent session remains active as long as there are regular interactions + +## Logging + +The TTL feature includes comprehensive logging to track state changes: + +- **Expiration time updated**: Logged when TTL expiration time is set or updated +- **Deletion scheduled**: Logged when a deletion check signal is scheduled +- **Deletion check**: Logged when a deletion check operation runs +- **Session expired**: Logged when an agent session is deleted due to expiration +- **TTL rescheduled**: Logged when a deletion signal is rescheduled + +These logs help monitor TTL behavior and troubleshoot any issues. + +## Best practices + +1. **Choose appropriate TTL values**: Balance between storage costs and user experience. Too short TTLs may delete active sessions, while too long TTLs may accumulate unnecessary state. + +2. **Use per-agent TTLs**: Different agents may have different usage patterns. Configure TTLs per-agent based on expected session lifetimes. + +3. **Monitor expiration logs**: Review logs to understand TTL behavior and adjust configuration as needed. + +4. **Test with short TTLs**: During development, use short TTLs (e.g., minutes) to verify TTL behavior without waiting for long periods. + +## Limitations + +- TTL is based on wall-clock time, not activity time. The expiration timer starts from the last message timestamp. +- Deletion checks are durably scheduled operations and may have slight delays depending on system load. +- Once an agent session is deleted, its conversation history cannot be recovered. +- TTL deletion requires at least one worker to be available to process the deletion operation message. diff --git a/dotnet/.editorconfig b/dotnet/.editorconfig index c0d0d04fe9..fea0183976 100644 --- a/dotnet/.editorconfig +++ b/dotnet/.editorconfig @@ -209,6 +209,7 @@ dotnet_diagnostic.CA2000.severity = none # Call System.IDisposable.Dispose on ob dotnet_diagnostic.CA2225.severity = none # Operator overloads have named alternates dotnet_diagnostic.CA2227.severity = none # Change to be read-only by removing the property setter dotnet_diagnostic.CA2249.severity = suggestion # Consider using 'Contains' method instead of 'IndexOf' method +dotnet_diagnostic.CA2252.severity = none # Requires preview dotnet_diagnostic.CA2253.severity = none # Named placeholders in the logging message template should not be comprised of only numeric characters dotnet_diagnostic.CA2253.severity = none # Named placeholders in the logging message template should not be comprised of only numeric characters dotnet_diagnostic.CA2263.severity = suggestion # Use generic overload diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index c7a051bf83..32adaae308 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -11,7 +11,7 @@ - + @@ -19,10 +19,10 @@ - - + + - + @@ -61,10 +61,9 @@ - - - - + + + @@ -101,11 +100,10 @@ - - + diff --git a/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/AgentWebChat.AgentHost.csproj b/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/AgentWebChat.AgentHost.csproj index 3f2a832a69..f71becf5d3 100644 --- a/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/AgentWebChat.AgentHost.csproj +++ b/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/AgentWebChat.AgentHost.csproj @@ -25,7 +25,6 @@ - diff --git a/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/Utilities/ChatClientExtensions.cs b/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/Utilities/ChatClientExtensions.cs index 6cd3d888c8..7b1f2d86b4 100644 --- a/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/Utilities/ChatClientExtensions.cs +++ b/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/Utilities/ChatClientExtensions.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using AgentWebChat.AgentHost.Utilities; -using Azure; -using Azure.AI.Inference; using Microsoft.Extensions.AI; using OllamaSharp; @@ -24,7 +22,6 @@ public static ChatClientBuilder AddChatClient(this IHostApplicationBuilder build ClientChatProvider.Ollama => builder.AddOllamaClient(connectionName, connectionInfo), ClientChatProvider.OpenAI => builder.AddOpenAIClient(connectionName, connectionInfo), ClientChatProvider.AzureOpenAI => builder.AddAzureOpenAIClient(connectionName).AddChatClient(connectionInfo.SelectedModel), - ClientChatProvider.AzureAIInference => builder.AddAzureInferenceClient(connectionName, connectionInfo), _ => throw new NotSupportedException($"Unsupported provider: {connectionInfo.Provider}") }; @@ -44,16 +41,6 @@ private static ChatClientBuilder AddOpenAIClient(this IHostApplicationBuilder bu }) .AddChatClient(connectionInfo.SelectedModel); - private static ChatClientBuilder AddAzureInferenceClient(this IHostApplicationBuilder builder, string connectionName, ChatClientConnectionInfo connectionInfo) => - builder.Services.AddChatClient(sp => - { - var credential = new AzureKeyCredential(connectionInfo.AccessKey!); - - var client = new ChatCompletionsClient(connectionInfo.Endpoint, credential, new AzureAIInferenceClientOptions()); - - return client.AsIChatClient(connectionInfo.SelectedModel); - }); - private static ChatClientBuilder AddOllamaClient(this IHostApplicationBuilder builder, string connectionName, ChatClientConnectionInfo connectionInfo) { var httpKey = $"{connectionName}_http"; @@ -83,7 +70,6 @@ public static ChatClientBuilder AddKeyedChatClient(this IHostApplicationBuilder ClientChatProvider.Ollama => builder.AddKeyedOllamaClient(connectionName, connectionInfo), ClientChatProvider.OpenAI => builder.AddKeyedOpenAIClient(connectionName, connectionInfo), ClientChatProvider.AzureOpenAI => builder.AddKeyedAzureOpenAIClient(connectionName).AddKeyedChatClient(connectionName, connectionInfo.SelectedModel), - ClientChatProvider.AzureAIInference => builder.AddKeyedAzureInferenceClient(connectionName, connectionInfo), _ => throw new NotSupportedException($"Unsupported provider: {connectionInfo.Provider}") }; @@ -103,16 +89,6 @@ private static ChatClientBuilder AddKeyedOpenAIClient(this IHostApplicationBuild }) .AddKeyedChatClient(connectionName, connectionInfo.SelectedModel); - private static ChatClientBuilder AddKeyedAzureInferenceClient(this IHostApplicationBuilder builder, string connectionName, ChatClientConnectionInfo connectionInfo) => - builder.Services.AddKeyedChatClient(connectionName, sp => - { - var credential = new AzureKeyCredential(connectionInfo.AccessKey!); - - var client = new ChatCompletionsClient(connectionInfo.Endpoint, credential, new AzureAIInferenceClientOptions()); - - return client.AsIChatClient(connectionInfo.SelectedModel); - }); - private static ChatClientBuilder AddKeyedOllamaClient(this IHostApplicationBuilder builder, string connectionName, ChatClientConnectionInfo connectionInfo) { var httpKey = $"{connectionName}_http"; diff --git a/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIResponsesAgentClient.cs b/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIResponsesAgentClient.cs index 7cc85b97c3..d0121a6165 100644 --- a/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIResponsesAgentClient.cs +++ b/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIResponsesAgentClient.cs @@ -27,7 +27,7 @@ public override async IAsyncEnumerable RunStreamingAsync Transport = new HttpClientPipelineTransport(httpClient) }; - var openAiClient = new OpenAIResponseClient(model: agentName, credential: new ApiKeyCredential("dummy-key"), options: options).AsIChatClient(); + var openAiClient = new ResponsesClient(model: agentName, credential: new ApiKeyCredential("dummy-key"), options: options).AsIChatClient(); var chatOptions = new ChatOptions() { ConversationId = threadId diff --git a/dotnet/samples/AzureFunctions/01_SingleAgent/Program.cs b/dotnet/samples/AzureFunctions/01_SingleAgent/Program.cs index b3d40a120c..39b020e137 100644 --- a/dotnet/samples/AzureFunctions/01_SingleAgent/Program.cs +++ b/dotnet/samples/AzureFunctions/01_SingleAgent/Program.cs @@ -32,6 +32,6 @@ using IHost app = FunctionsApplication .CreateBuilder(args) .ConfigureFunctionsWebApplication() - .ConfigureDurableAgents(options => options.AddAIAgent(agent)) + .ConfigureDurableAgents(options => options.AddAIAgent(agent, timeToLive: TimeSpan.FromHours(1))) .Build(); app.Run(); diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIResponses/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIResponses/Program.cs index 83d5619382..5ce85b2b91 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIResponses/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIResponses/Program.cs @@ -13,7 +13,7 @@ AIAgent agent = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) - .GetOpenAIResponseClient(deploymentName) + .GetResponsesClient(deploymentName) .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); // Invoke the agent and output the text result. diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index 8f1039251d..a7dafbfc50 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -45,7 +45,7 @@ public override async Task RunAsync(IEnumerable m } // Clone the input messages and turn them into response messages with upper case text. - List responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList(); + List responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the thread of the input and output messages. await typedThread.MessageStore.AddMessagesAsync(messages.Concat(responseMessages), cancellationToken); @@ -69,7 +69,7 @@ public override async IAsyncEnumerable RunStreamingAsync } // Clone the input messages and turn them into response messages with upper case text. - List responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList(); + List responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the thread of the input and output messages. await typedThread.MessageStore.AddMessagesAsync(messages.Concat(responseMessages), cancellationToken); @@ -79,7 +79,7 @@ public override async IAsyncEnumerable RunStreamingAsync yield return new AgentRunResponseUpdate { AgentId = this.Id, - AuthorName = this.DisplayName, + AuthorName = message.AuthorName, Role = ChatRole.Assistant, Contents = message.Contents, ResponseId = Guid.NewGuid().ToString("N"), @@ -88,7 +88,7 @@ public override async IAsyncEnumerable RunStreamingAsync } } - private static IEnumerable CloneAndToUpperCase(IEnumerable messages, string agentName) => messages.Select(x => + private static IEnumerable CloneAndToUpperCase(IEnumerable messages, string? agentName) => messages.Select(x => { // Clone the message and update its author to be the agent. var messageClone = x.Clone(); diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIResponses/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIResponses/Program.cs index df53ba8869..b0d0285928 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIResponses/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIResponses/Program.cs @@ -11,7 +11,7 @@ AIAgent agent = new OpenAIClient( apiKey) - .GetOpenAIResponseClient(model) + .GetResponsesClient(model) .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); // Invoke the agent and output the text result. diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step02_Reasoning/Program.cs b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step02_Reasoning/Program.cs index e06a8cc76f..aa18fdd286 100644 --- a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step02_Reasoning/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step02_Reasoning/Program.cs @@ -11,11 +11,11 @@ var model = Environment.GetEnvironmentVariable("OPENAI_MODEL") ?? "gpt-5"; var client = new OpenAIClient(apiKey) - .GetOpenAIResponseClient(model) + .GetResponsesClient(model) .AsIChatClient().AsBuilder() .ConfigureOptions(o => { - o.RawRepresentationFactory = _ => new ResponseCreationOptions() + o.RawRepresentationFactory = _ => new CreateResponseOptions() { ReasoningOptions = new() { diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/OpenAIResponseClientAgent.cs b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/OpenAIResponseClientAgent.cs index 456de02836..622223307c 100644 --- a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/OpenAIResponseClientAgent.cs +++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/OpenAIResponseClientAgent.cs @@ -16,13 +16,13 @@ public class OpenAIResponseClientAgent : DelegatingAIAgent /// /// Initialize an instance of . /// - /// Instance of + /// Instance of /// Optional instructions for the agent. /// Optional name for the agent. /// Optional description for the agent. /// Optional instance of public OpenAIResponseClientAgent( - OpenAIResponseClient client, + ResponsesClient client, string? instructions = null, string? name = null, string? description = null, @@ -39,11 +39,11 @@ public OpenAIResponseClientAgent( /// /// Initialize an instance of . /// - /// Instance of + /// Instance of /// Options to create the agent. /// Optional instance of public OpenAIResponseClientAgent( - OpenAIResponseClient client, ChatClientAgentOptions options, ILoggerFactory? loggerFactory = null) : + ResponsesClient client, ChatClientAgentOptions options, ILoggerFactory? loggerFactory = null) : base(new ChatClientAgent((client ?? throw new ArgumentNullException(nameof(client))).AsIChatClient(), options, loggerFactory)) { } @@ -55,8 +55,8 @@ public OpenAIResponseClientAgent( /// The conversation thread to continue with this invocation. If not provided, creates a new thread. The thread will be mutated with the provided messages and agent response. /// Optional parameters for agent invocation. /// The to monitor for cancellation requests. The default is . - /// A containing the list of items. - public virtual async Task RunAsync( + /// A containing the list of items. + public virtual async Task RunAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -74,7 +74,7 @@ public virtual async Task RunAsync( /// The conversation thread to continue with this invocation. If not provided, creates a new thread. The thread will be mutated with the provided messages and agent response. /// Optional parameters for agent invocation. /// The to monitor for cancellation requests. The default is . - /// A containing the list of items. + /// A containing the list of items. public virtual async IAsyncEnumerable RunStreamingAsync( IEnumerable messages, AgentThread? thread = null, diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/Program.cs b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/Program.cs index 89a96bc0fb..5c229cc57d 100644 --- a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/Program.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -// This sample demonstrates how to create OpenAIResponseClientAgent directly from an OpenAIResponseClient instance. +// This sample demonstrates how to create OpenAIResponseClientAgent directly from an ResponsesClient instance. using OpenAI; using OpenAI.Responses; @@ -9,16 +9,16 @@ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("OPENAI_API_KEY is not set."); var model = Environment.GetEnvironmentVariable("OPENAI_MODEL") ?? "gpt-4o-mini"; -// Create an OpenAIResponseClient directly from OpenAIClient -OpenAIResponseClient responseClient = new OpenAIClient(apiKey).GetOpenAIResponseClient(model); +// Create a ResponsesClient directly from OpenAIClient +ResponsesClient responseClient = new OpenAIClient(apiKey).GetResponsesClient(model); -// Create an agent directly from the OpenAIResponseClient using OpenAIResponseClientAgent +// Create an agent directly from the ResponsesClient using OpenAIResponseClientAgent OpenAIResponseClientAgent agent = new(responseClient, instructions: "You are good at telling jokes.", name: "Joker"); ResponseItem userMessage = ResponseItem.CreateUserMessageItem("Tell me a joke about a pirate."); // Invoke the agent and output the text result. -OpenAIResponse response = await agent.RunAsync([userMessage]); +ResponseResult response = await agent.RunAsync([userMessage]); Console.WriteLine(response.GetOutputText()); // Invoke the agent with streaming support. diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Program.cs b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Program.cs index 9f81a27dda..8aebebdfa0 100644 --- a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Program.cs @@ -21,8 +21,8 @@ OpenAIClient openAIClient = new(apiKey); ConversationClient conversationClient = openAIClient.GetConversationClient(); -// Create an agent directly from the OpenAIResponseClient using OpenAIResponseClientAgent -ChatClientAgent agent = new(openAIClient.GetOpenAIResponseClient(model).AsIChatClient(), instructions: "You are a helpful assistant.", name: "ConversationAgent"); +// Create an agent directly from the ResponsesClient using OpenAIResponseClientAgent +ChatClientAgent agent = new(openAIClient.GetResponsesClient(model).AsIChatClient(), instructions: "You are a helpful assistant.", name: "ConversationAgent"); ClientResult createConversationResult = await conversationClient.CreateConversationAsync(BinaryContent.Create(BinaryData.FromString("{}"))); diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step13_BackgroundResponsesWithToolsAndPersistence/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step13_BackgroundResponsesWithToolsAndPersistence/Program.cs index 41493f6d79..29dc347b4a 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step13_BackgroundResponsesWithToolsAndPersistence/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step13_BackgroundResponsesWithToolsAndPersistence/Program.cs @@ -22,7 +22,7 @@ AIAgent agent = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) - .GetOpenAIResponseClient(deploymentName) + .GetResponsesClient(deploymentName) .CreateAIAgent( name: "SpaceNovelWriter", instructions: "You are a space novel writer. Always research relevant facts and generate character profiles for the main characters before writing novels." + diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step17_BackgroundResponses/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step17_BackgroundResponses/Program.cs index 510a5dfbd0..3e172a95b5 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step17_BackgroundResponses/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step17_BackgroundResponses/Program.cs @@ -13,7 +13,7 @@ AIAgent agent = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) - .GetOpenAIResponseClient(deploymentName) + .GetResponsesClient(deploymentName) .CreateAIAgent(); // Enable background responses (only supported by OpenAI Responses at this time). diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step15_ComputerUse/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step15_ComputerUse/Program.cs index 05fb39bbf4..ff4f57924a 100644 --- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step15_ComputerUse/Program.cs +++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step15_ComputerUse/Program.cs @@ -73,7 +73,7 @@ private static async Task InvokeComputerUseAgentAsync(AIAgent agent) Dictionary screenshots = ComputerUseUtil.LoadScreenshotAssets(); ChatOptions chatOptions = new(); - ResponseCreationOptions responseCreationOptions = new() + CreateResponseOptions responseCreationOptions = new() { TruncationMode = ResponseTruncationMode.Auto }; diff --git a/dotnet/samples/GettingStarted/ModelContextProtocol/ResponseAgent_Hosted_MCP/Program.cs b/dotnet/samples/GettingStarted/ModelContextProtocol/ResponseAgent_Hosted_MCP/Program.cs index ba4249c765..13ee28d6a1 100644 --- a/dotnet/samples/GettingStarted/ModelContextProtocol/ResponseAgent_Hosted_MCP/Program.cs +++ b/dotnet/samples/GettingStarted/ModelContextProtocol/ResponseAgent_Hosted_MCP/Program.cs @@ -30,7 +30,7 @@ AIAgent agent = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) - .GetOpenAIResponseClient(deploymentName) + .GetResponsesClient(deploymentName) .CreateAIAgent( instructions: "You answer questions by searching the Microsoft Learn content only.", name: "MicrosoftLearnAgent", @@ -57,7 +57,7 @@ AIAgent agentWithRequiredApproval = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) - .GetOpenAIResponseClient(deploymentName) + .GetResponsesClient(deploymentName) .CreateAIAgent( instructions: "You answer questions by searching the Microsoft Learn content only.", name: "MicrosoftLearnAgentWithApproval", diff --git a/dotnet/samples/Purview/AgentWithPurview/Program.cs b/dotnet/samples/Purview/AgentWithPurview/Program.cs index 842917b427..a4b27c47cd 100644 --- a/dotnet/samples/Purview/AgentWithPurview/Program.cs +++ b/dotnet/samples/Purview/AgentWithPurview/Program.cs @@ -27,7 +27,7 @@ using IChatClient client = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) - .GetOpenAIResponseClient(deploymentName) + .GetResponsesClient(deploymentName) .AsIChatClient() .AsBuilder() .WithPurview(browserCredential, new PurviewSettings("Agent Framework Test App")) diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs index e804fbb389..e326151b13 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs @@ -30,7 +30,6 @@ internal sealed class A2AAgent : AIAgent private readonly string? _id; private readonly string? _name; private readonly string? _description; - private readonly string? _displayName; private readonly ILogger _logger; /// @@ -40,9 +39,8 @@ internal sealed class A2AAgent : AIAgent /// The unique identifier for the agent. /// The the name of the agent. /// The description of the agent. - /// The display name of the agent. /// Optional logger factory to use for logging. - public A2AAgent(A2AClient a2aClient, string? id = null, string? name = null, string? description = null, string? displayName = null, ILoggerFactory? loggerFactory = null) + public A2AAgent(A2AClient a2aClient, string? id = null, string? name = null, string? description = null, ILoggerFactory? loggerFactory = null) { _ = Throw.IfNull(a2aClient); @@ -50,7 +48,6 @@ public A2AAgent(A2AClient a2aClient, string? id = null, string? name = null, str this._id = id; this._name = name; this._description = description; - this._displayName = displayName; this._logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } @@ -203,9 +200,6 @@ public override async IAsyncEnumerable RunStreamingAsync /// public override string? Name => this._name ?? base.Name; - /// - public override string DisplayName => this._displayName ?? base.DisplayName; - /// public override string? Description => this._description ?? base.Description; diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2AClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2AClientExtensions.cs index 095481c0d4..d57ed4cb42 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2AClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/Extensions/A2AClientExtensions.cs @@ -33,9 +33,8 @@ public static class A2AClientExtensions /// The unique identifier for the agent. /// The the name of the agent. /// The description of the agent. - /// The display name of the agent. /// Optional logger factory for enabling logging within the agent. /// An instance backed by the A2A agent. - public static AIAgent GetAIAgent(this A2AClient client, string? id = null, string? name = null, string? description = null, string? displayName = null, ILoggerFactory? loggerFactory = null) => - new A2AAgent(client, id, name, description, displayName, loggerFactory); + public static AIAgent GetAIAgent(this A2AClient client, string? id = null, string? name = null, string? description = null, ILoggerFactory? loggerFactory = null) => + new A2AAgent(client, id, name, description, loggerFactory); } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs index 4cff385dcc..1f39a2758f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs @@ -60,18 +60,6 @@ public abstract class AIAgent /// public virtual string? Name { get; } - /// - /// Gets a display-friendly name for the agent. - /// - /// - /// The agent's if available, otherwise the . - /// - /// - /// This property provides a guaranteed non-null string suitable for display in user interfaces, - /// logs, or other contexts where a readable identifier is needed. - /// - public virtual string DisplayName => this.Name ?? this.Id; - /// /// Gets a description of the agent's purpose, capabilities, or behavior. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs index 72f1980b1c..4c0ff1a36d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs @@ -25,7 +25,7 @@ namespace Microsoft.Agents.AI; /// Derived classes can override specific methods to add custom behavior while maintaining compatibility with the agent interface. /// /// -public class DelegatingAIAgent : AIAgent +public abstract class DelegatingAIAgent : AIAgent { /// /// Initializes a new instance of the class with the specified inner agent. diff --git a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClient.cs b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClient.cs index 8acafc8fc3..f31c570508 100644 --- a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClient.cs +++ b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClient.cs @@ -23,11 +23,6 @@ internal sealed class AzureAIProjectChatClient : DelegatingChatClient private readonly AgentRecord? _agentRecord; private readonly ChatOptions? _chatOptions; private readonly AgentReference _agentReference; - /// - /// The usage of a no-op model is a necessary change to avoid OpenAIClients to throw exceptions when - /// used with Azure AI Agents as the model used is now defined at the agent creation time. - /// - private const string NoOpModel = "no-op"; /// /// Initializes a new instance of the class. @@ -42,7 +37,7 @@ internal sealed class AzureAIProjectChatClient : DelegatingChatClient internal AzureAIProjectChatClient(AIProjectClient aiProjectClient, AgentReference agentReference, string? defaultModelId, ChatOptions? chatOptions) : base(Throw.IfNull(aiProjectClient) .GetProjectOpenAIClient() - .GetOpenAIResponseClient(defaultModelId ?? NoOpModel) + .GetProjectResponsesClientForAgent(agentReference) .AsIChatClient()) { this._agentClient = aiProjectClient; @@ -132,13 +127,15 @@ private ChatOptions GetAgentEnabledChatOptions(ChatOptions? options) agentEnabledChatOptions.RawRepresentationFactory = (client) => { - if (originalFactory?.Invoke(this) is not ResponseCreationOptions responseCreationOptions) + if (originalFactory?.Invoke(this) is not CreateResponseOptions responseCreationOptions) { - responseCreationOptions = new ResponseCreationOptions(); + responseCreationOptions = new CreateResponseOptions(); } - ResponseCreationOptionsExtensions.set_Agent(responseCreationOptions, this._agentReference); - ResponseCreationOptionsExtensions.set_Model(responseCreationOptions, null); + responseCreationOptions.Agent = this._agentReference; +#pragma warning disable SCME0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + responseCreationOptions.Patch.Remove("$.model"u8); +#pragma warning restore SCME0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. return responseCreationOptions; }; diff --git a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs index dfbdad8e98..7319bb13eb 100644 --- a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs @@ -400,7 +400,7 @@ public static ChatClientAgent CreateAIAgent( }; // Attempt to capture breaking glass options from the raw representation factory that match the agent definition. - if (options.ChatOptions?.RawRepresentationFactory?.Invoke(new NoOpChatClient()) is ResponseCreationOptions respCreationOptions) + if (options.ChatOptions?.RawRepresentationFactory?.Invoke(new NoOpChatClient()) is CreateResponseOptions respCreationOptions) { agentDefinition.ReasoningOptions = respCreationOptions.ReasoningOptions; } @@ -466,7 +466,7 @@ public static async Task CreateAIAgentAsync( }; // Attempt to capture breaking glass options from the raw representation factory that match the agent definition. - if (options.ChatOptions?.RawRepresentationFactory?.Invoke(new NoOpChatClient()) is ResponseCreationOptions respCreationOptions) + if (options.ChatOptions?.RawRepresentationFactory?.Invoke(new NoOpChatClient()) is CreateResponseOptions respCreationOptions) { agentDefinition.ReasoningOptions = respCreationOptions.ReasoningOptions; } diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosCheckpointStore.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosCheckpointStore.cs index 62987b1dfc..e0073feaf9 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosCheckpointStore.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosCheckpointStore.cs @@ -217,9 +217,7 @@ protected virtual void Dispose(bool disposing) } } - /// - /// Represents a checkpoint document stored in Cosmos DB. - /// + /// Represents a checkpoint document stored in Cosmos DB. internal sealed class CosmosCheckpointDocument { [JsonProperty("id")] diff --git a/dotnet/src/Microsoft.Agents.AI.DevUI/EntitiesApiExtensions.cs b/dotnet/src/Microsoft.Agents.AI.DevUI/EntitiesApiExtensions.cs index 3271b40853..8dcc46b53c 100644 --- a/dotnet/src/Microsoft.Agents.AI.DevUI/EntitiesApiExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.DevUI/EntitiesApiExtensions.cs @@ -231,7 +231,7 @@ private static EntityInfo CreateAgentEntityInfo(AIAgent agent) return new EntityInfo( Id: entityId, Type: "agent", - Name: agent.DisplayName, + Name: agent.Name ?? agent.Id, Description: agent.Description, Framework: "agent_framework", Tools: tools, diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs index 166799a124..ec4ba3acf6 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs @@ -16,29 +16,34 @@ internal class AgentEntity(IServiceProvider services, CancellationToken cancella private readonly DurableTaskClient _client = services.GetRequiredService(); private readonly ILoggerFactory _loggerFactory = services.GetRequiredService(); private readonly IAgentResponseHandler? _messageHandler = services.GetService(); + private readonly DurableAgentsOptions _options = services.GetRequiredService(); private readonly CancellationToken _cancellationToken = cancellationToken != default ? cancellationToken : services.GetService()?.ApplicationStopping ?? CancellationToken.None; - public async Task RunAgentAsync(RunRequest request) + public Task RunAgentAsync(RunRequest request) { - AgentSessionId sessionId = this.Context.Id; - IReadOnlyDictionary> agents = - this._services.GetRequiredService>>(); - if (!agents.TryGetValue(sessionId.Name, out Func? agentFactory)) - { - throw new InvalidOperationException($"Agent '{sessionId.Name}' not found"); - } + return this.Run(request); + } - AIAgent agent = agentFactory(this._services); + // IDE1006 and VSTHRD200 disabled to allow method name to match the common cross-platform entity operation name. +#pragma warning disable IDE1006 +#pragma warning disable VSTHRD200 + public async Task Run(RunRequest request) +#pragma warning restore VSTHRD200 +#pragma warning restore IDE1006 + { + AgentSessionId sessionId = this.Context.Id; + AIAgent agent = this.GetAgent(sessionId); EntityAgentWrapper agentWrapper = new(agent, this.Context, request, this._services); // Logger category is Microsoft.DurableTask.Agents.{agentName}.{sessionId} - ILogger logger = this._loggerFactory.CreateLogger($"Microsoft.DurableTask.Agents.{agent.Name}.{sessionId.Key}"); + ILogger logger = this.GetLogger(agent.Name!, sessionId.Key); if (request.Messages.Count == 0) { logger.LogInformation("Ignoring empty request"); + return new AgentRunResponse(); } this.State.Data.ConversationHistory.Add(DurableAgentStateRequest.FromRunRequest(request)); @@ -113,6 +118,36 @@ async IAsyncEnumerable StreamResultsAsync() response.Usage?.TotalTokenCount); } + // Update TTL expiration time. Only schedule deletion check on first interaction. + // Subsequent interactions just update the expiration time; CheckAndDeleteIfExpiredAsync + // will reschedule the deletion check when it runs. + TimeSpan? timeToLive = this._options.GetTimeToLive(sessionId.Name); + if (timeToLive.HasValue) + { + DateTime newExpirationTime = DateTime.UtcNow.Add(timeToLive.Value); + bool isFirstInteraction = this.State.Data.ExpirationTimeUtc is null; + + this.State.Data.ExpirationTimeUtc = newExpirationTime; + logger.LogTTLExpirationTimeUpdated(sessionId, newExpirationTime); + + // Only schedule deletion check on the first interaction when entity is created. + // On subsequent interactions, we just update the expiration time. The scheduled + // CheckAndDeleteIfExpiredAsync will reschedule itself if the entity hasn't expired. + if (isFirstInteraction) + { + this.ScheduleDeletionCheck(sessionId, logger, timeToLive.Value); + } + } + else + { + // TTL is disabled. Clear the expiration time if it was previously set. + if (this.State.Data.ExpirationTimeUtc.HasValue) + { + logger.LogTTLExpirationTimeCleared(sessionId); + this.State.Data.ExpirationTimeUtc = null; + } + } + return response; } finally @@ -121,4 +156,78 @@ async IAsyncEnumerable StreamResultsAsync() DurableAgentContext.ClearCurrent(); } } + + /// + /// Checks if the entity has expired and deletes it if so, otherwise reschedules the deletion check. + /// + /// + /// This method is called by the durable task runtime when a CheckAndDeleteIfExpired signal is received. + /// + public void CheckAndDeleteIfExpired() + { + AgentSessionId sessionId = this.Context.Id; + AIAgent agent = this.GetAgent(sessionId); + ILogger logger = this.GetLogger(agent.Name!, sessionId.Key); + + DateTime currentTime = DateTime.UtcNow; + DateTime? expirationTime = this.State.Data.ExpirationTimeUtc; + + logger.LogTTLDeletionCheck(sessionId, expirationTime, currentTime); + + if (expirationTime.HasValue) + { + if (currentTime >= expirationTime.Value) + { + // Entity has expired, delete it + logger.LogTTLEntityExpired(sessionId, expirationTime.Value); + this.State = null!; + } + else + { + // Entity hasn't expired yet, reschedule the deletion check + TimeSpan? timeToLive = this._options.GetTimeToLive(sessionId.Name); + if (timeToLive.HasValue) + { + this.ScheduleDeletionCheck(sessionId, logger, timeToLive.Value); + } + } + } + } + + private void ScheduleDeletionCheck(AgentSessionId sessionId, ILogger logger, TimeSpan timeToLive) + { + DateTime currentTime = DateTime.UtcNow; + DateTime expirationTime = this.State.Data.ExpirationTimeUtc ?? currentTime.Add(timeToLive); + TimeSpan minimumDelay = this._options.MinimumTimeToLiveSignalDelay; + + // To avoid excessive scheduling, we schedule the deletion check for no less than the minimum delay. + DateTime scheduledTime = expirationTime > currentTime.Add(minimumDelay) + ? expirationTime + : currentTime.Add(minimumDelay); + + logger.LogTTLDeletionScheduled(sessionId, scheduledTime); + + // Schedule a signal to self to check for expiration + this.Context.SignalEntity( + this.Context.Id, + nameof(CheckAndDeleteIfExpired), // self-signal + options: new SignalEntityOptions { SignalTime = scheduledTime }); + } + + private AIAgent GetAgent(AgentSessionId sessionId) + { + IReadOnlyDictionary> agents = + this._services.GetRequiredService>>(); + if (!agents.TryGetValue(sessionId.Name, out Func? agentFactory)) + { + throw new InvalidOperationException($"Agent '{sessionId.Name}' not found"); + } + + return agentFactory(this._services); + } + + private ILogger GetLogger(string agentName, string sessionKey) + { + return this._loggerFactory.CreateLogger($"Microsoft.DurableTask.Agents.{agentName}.{sessionKey}"); + } } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/CHANGELOG.md b/dotnet/src/Microsoft.Agents.AI.DurableTask/CHANGELOG.md index d2cdc7cd41..ccc6aa7181 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/CHANGELOG.md +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/CHANGELOG.md @@ -1,5 +1,12 @@ # Release History +## [Unreleased] + +### Changed + +- Added TTL configuration for durable agent entities ([#2679](https://github.com/microsoft/agent-framework/pull/2679)) +- Switch to new "Run" method name ([#2843](https://github.com/microsoft/agent-framework/pull/2843)) + ## v1.0.0-preview.251204.1 - Added orchestration ID to durable agent entity state ([#2137](https://github.com/microsoft/agent-framework/pull/2137)) diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DefaultDurableAgentClient.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DefaultDurableAgentClient.cs index 2086a00ecb..9005641860 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DefaultDurableAgentClient.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DefaultDurableAgentClient.cs @@ -22,7 +22,7 @@ public async Task RunAgentAsync( await this._client.Entities.SignalEntityAsync( sessionId, - nameof(AgentEntity.RunAgentAsync), + nameof(AgentEntity.Run), request, cancellation: cancellationToken); diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs index 021c8f22c7..2035b792fd 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs @@ -107,7 +107,7 @@ public override async Task RunAsync( { return await this._context.Entities.CallEntityAsync( durableThread.SessionId, - nameof(AgentEntity.RunAgentAsync), + nameof(AgentEntity.Run), request); } catch (EntityOperationFailedException e) when (e.FailureDetails.ErrorType == "EntityTaskNotFound") diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentsOptions.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentsOptions.cs index f2ac3f4c9a..cefcad323a 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentsOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentsOptions.cs @@ -9,23 +9,67 @@ public sealed class DurableAgentsOptions { // Agent names are case-insensitive private readonly Dictionary> _agentFactories = new(StringComparer.OrdinalIgnoreCase); + private readonly Dictionary _agentTimeToLive = new(StringComparer.OrdinalIgnoreCase); internal DurableAgentsOptions() { } + /// + /// Gets or sets the default time-to-live (TTL) for agent entities. + /// + /// + /// If an agent entity is idle for this duration, it will be automatically deleted. + /// Defaults to 14 days. Set to to disable TTL for agents without explicit TTL configuration. + /// + public TimeSpan? DefaultTimeToLive { get; set; } = TimeSpan.FromDays(14); + + /// + /// Gets or sets the minimum delay for scheduling TTL deletion signals. Defaults to 5 minutes. + /// + /// + /// This property is primarily useful for testing (where shorter delays are needed) or for + /// shorter-lived agents in workflows that need more rapid cleanup. The maximum allowed value is 5 minutes. + /// Reducing the minimum deletion delay below 5 minutes can be useful for testing or for ensuring rapid cleanup of short-lived agent sessions. + /// However, this can also increase the load on the system and should be used with caution. + /// + /// Thrown when the value exceeds 5 minutes. + public TimeSpan MinimumTimeToLiveSignalDelay + { + get; + set + { + const int MaximumDelayMinutes = 5; + if (value > TimeSpan.FromMinutes(MaximumDelayMinutes)) + { + throw new ArgumentOutOfRangeException( + nameof(value), + value, + $"The minimum time-to-live signal delay cannot exceed {MaximumDelayMinutes} minutes."); + } + + field = value; + } + } = TimeSpan.FromMinutes(5); + /// /// Adds an AI agent factory to the options. /// /// The name of the agent. /// The factory function to create the agent. + /// Optional time-to-live for this agent's entities. If not specified, uses . /// The options instance. /// Thrown when or is null. - public DurableAgentsOptions AddAIAgentFactory(string name, Func factory) + public DurableAgentsOptions AddAIAgentFactory(string name, Func factory, TimeSpan? timeToLive = null) { ArgumentNullException.ThrowIfNull(name); ArgumentNullException.ThrowIfNull(factory); this._agentFactories.Add(name, factory); + if (timeToLive.HasValue) + { + this._agentTimeToLive[name] = timeToLive; + } + return this; } @@ -50,12 +94,13 @@ public DurableAgentsOptions AddAIAgents(params IEnumerable agents) /// Adds an AI agent to the options. /// /// The agent to add. + /// Optional time-to-live for this agent's entities. If not specified, uses . /// The options instance. /// Thrown when is null. /// /// Thrown when is null or whitespace or when an agent with the same name has already been registered. /// - public DurableAgentsOptions AddAIAgent(AIAgent agent) + public DurableAgentsOptions AddAIAgent(AIAgent agent, TimeSpan? timeToLive = null) { ArgumentNullException.ThrowIfNull(agent); @@ -70,6 +115,11 @@ public DurableAgentsOptions AddAIAgent(AIAgent agent) } this._agentFactories.Add(agent.Name, sp => agent); + if (timeToLive.HasValue) + { + this._agentTimeToLive[agent.Name] = timeToLive; + } + return this; } @@ -81,4 +131,14 @@ internal IReadOnlyDictionary> GetAgentFa { return this._agentFactories.AsReadOnly(); } + + /// + /// Gets the time-to-live for a specific agent, or the default TTL if not specified. + /// + /// The name of the agent. + /// The time-to-live for the agent, or the default TTL if not specified. + internal TimeSpan? GetTimeToLive(string agentName) + { + return this._agentTimeToLive.TryGetValue(agentName, out TimeSpan? ttl) ? ttl : this.DefaultTimeToLive; + } } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/Logs.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/Logs.cs index 0bec1e149c..ba310441df 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/Logs.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/Logs.cs @@ -46,4 +46,58 @@ public static partial void LogAgentResponse( Level = LogLevel.Information, Message = "Found response for agent with session ID '{SessionId}' with correlation ID '{CorrelationId}'")] public static partial void LogDonePollingForResponse(this ILogger logger, AgentSessionId sessionId, string correlationId); + + [LoggerMessage( + EventId = 6, + Level = LogLevel.Information, + Message = "[{SessionId}] TTL expiration time updated to {ExpirationTime:O}")] + public static partial void LogTTLExpirationTimeUpdated( + this ILogger logger, + AgentSessionId sessionId, + DateTime expirationTime); + + [LoggerMessage( + EventId = 7, + Level = LogLevel.Information, + Message = "[{SessionId}] TTL deletion signal scheduled for {ScheduledTime:O}")] + public static partial void LogTTLDeletionScheduled( + this ILogger logger, + AgentSessionId sessionId, + DateTime scheduledTime); + + [LoggerMessage( + EventId = 8, + Level = LogLevel.Information, + Message = "[{SessionId}] TTL deletion check running. Expiration time: {ExpirationTime:O}, Current time: {CurrentTime:O}")] + public static partial void LogTTLDeletionCheck( + this ILogger logger, + AgentSessionId sessionId, + DateTime? expirationTime, + DateTime currentTime); + + [LoggerMessage( + EventId = 9, + Level = LogLevel.Information, + Message = "[{SessionId}] Entity expired and deleted due to TTL. Expiration time: {ExpirationTime:O}")] + public static partial void LogTTLEntityExpired( + this ILogger logger, + AgentSessionId sessionId, + DateTime expirationTime); + + [LoggerMessage( + EventId = 10, + Level = LogLevel.Information, + Message = "[{SessionId}] TTL deletion signal rescheduled for {ScheduledTime:O}")] + public static partial void LogTTLRescheduled( + this ILogger logger, + AgentSessionId sessionId, + DateTime scheduledTime); + + [LoggerMessage( + EventId = 11, + Level = LogLevel.Information, + Message = "[{SessionId}] TTL expiration time cleared (TTL disabled)")] + public static partial void LogTTLExpirationTimeCleared( + this ILogger logger, + AgentSessionId sessionId); } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/ServiceCollectionExtensions.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/ServiceCollectionExtensions.cs index 2f435e0541..79d44924ca 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/ServiceCollectionExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/ServiceCollectionExtensions.cs @@ -85,6 +85,9 @@ internal static DurableAgentsOptions ConfigureDurableAgents( // The agent dictionary contains the real agent factories, which is used by the agent entities. services.AddSingleton(agents); + // Register the options so AgentEntity can access TTL configuration + services.AddSingleton(options); + // The keyed services are used to resolve durable agent *proxy* instances for external clients. foreach (var factory in agents) { diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/State/DurableAgentStateData.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/State/DurableAgentStateData.cs index f51820dcf5..745f619f48 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/State/DurableAgentStateData.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/State/DurableAgentStateData.cs @@ -17,6 +17,13 @@ internal sealed class DurableAgentStateData [JsonPropertyName("conversationHistory")] public IList ConversationHistory { get; init; } = []; + /// + /// Gets or sets the expiration time (UTC) for this agent entity. + /// If the entity is idle beyond this time, it will be automatically deleted. + /// + [JsonPropertyName("expirationTimeUtc")] + public DateTime? ExpirationTimeUtc { get; set; } + /// /// Gets any additional data found during deserialization that does not map to known properties. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/EndpointRouteBuilderExtensions.ChatCompletions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/EndpointRouteBuilderExtensions.ChatCompletions.cs index 3fcc9cad27..92c817b124 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/EndpointRouteBuilderExtensions.ChatCompletions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/EndpointRouteBuilderExtensions.ChatCompletions.cs @@ -61,7 +61,7 @@ public static IEndpointConventionBuilder MapOpenAIChatCompletions( path ??= $"/{agent.Name}/v1/chat/completions"; var group = endpoints.MapGroup(path); - var endpointAgentName = agent.DisplayName; + var endpointAgentName = agent.Name ?? agent.Id; group.MapPost("/", async ([FromBody] CreateChatCompletion request, CancellationToken cancellationToken) => await AIAgentChatCompletionsProcessor.CreateChatCompletionAsync(agent, request, cancellationToken).ConfigureAwait(false)) diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/EndpointRouteBuilderExtensions.Responses.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/EndpointRouteBuilderExtensions.Responses.cs index 9a395b9b12..ae96636f16 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/EndpointRouteBuilderExtensions.Responses.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/EndpointRouteBuilderExtensions.Responses.cs @@ -76,7 +76,7 @@ public static IEndpointConventionBuilder MapOpenAIResponses( var handlers = new ResponsesHttpHandler(responsesService); var group = endpoints.MapGroup(responsesPath); - var endpointAgentName = agent.DisplayName; + var endpointAgentName = agent.Name ?? agent.Id; // Create response endpoint group.MapPost("/", handlers.CreateResponseAsync) diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/Models/ConversationReference.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/Models/ConversationReference.cs index dc38375331..d5a1d96240 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/Models/ConversationReference.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/Models/ConversationReference.cs @@ -84,22 +84,18 @@ public override void Write(Utf8JsonWriter writer, ConversationReference value, J return; } - // If only ID is present and no metadata, serialize as a simple string - if (value.Metadata is null || value.Metadata.Count == 0) + // Ideally if only ID is present and no metadata, we would serialize as a simple string. + // However, while a request's "conversation" property can be either a string or an object + // containing a string, a response's "conversation" property is always an object. Since + // here we don't know which scenario we're in, we always serialize as an object, which works + // in any scenario. + writer.WriteStartObject(); + writer.WriteString("id", value.Id); + if (value.Metadata is not null) { - writer.WriteStringValue(value.Id); - } - else - { - // Otherwise, serialize as an object - writer.WriteStartObject(); - writer.WriteString("id", value.Id); - if (value.Metadata is not null) - { - writer.WritePropertyName("metadata"); - JsonSerializer.Serialize(writer, value.Metadata, OpenAIHostingJsonContext.Default.DictionaryStringString); - } - writer.WriteEndObject(); + writer.WritePropertyName("metadata"); + JsonSerializer.Serialize(writer, value.Metadata, OpenAIHostingJsonContext.Default.DictionaryStringString); } + writer.WriteEndObject(); } } diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AIAgentWithOpenAIExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AIAgentWithOpenAIExtensions.cs index 4abc6915a6..d487ba00e1 100644 --- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AIAgentWithOpenAIExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AIAgentWithOpenAIExtensions.cs @@ -73,22 +73,22 @@ public static AsyncCollectionResult RunStreamingA } /// - /// Runs the AI agent with a collection of OpenAI response items and returns the response as a native OpenAI . + /// Runs the AI agent with a collection of OpenAI response items and returns the response as a native OpenAI . /// /// The AI agent to run. /// The collection of OpenAI response items to send to the agent. /// The conversation thread to continue with this invocation. If not provided, creates a new thread. The thread will be mutated with the provided messages and agent response. /// Optional parameters for agent invocation. /// The to monitor for cancellation requests. The default is . - /// A representing the asynchronous operation that returns a native OpenAI response. + /// A representing the asynchronous operation that returns a native OpenAI response. /// Thrown when or is . - /// Thrown when the agent's response cannot be converted to an , typically when the underlying representation is not an OpenAI response. + /// Thrown when the agent's response cannot be converted to an , typically when the underlying representation is not an OpenAI response. /// Thrown when any message in has a type that is not supported by the message conversion method. /// /// This method converts the OpenAI response items to the Microsoft Extensions AI format using the appropriate conversion method, - /// runs the agent with the converted message collection, and then extracts the native OpenAI from the response using . + /// runs the agent with the converted message collection, and then extracts the native OpenAI from the response using . /// - public static async Task RunAsync(this AIAgent agent, IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + public static async Task RunAsync(this AIAgent agent, IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { Throw.IfNull(agent); Throw.IfNull(messages); diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentRunResponseExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentRunResponseExtensions.cs index 9a164d862b..44844e64f5 100644 --- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentRunResponseExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentRunResponseExtensions.cs @@ -29,17 +29,17 @@ response.RawRepresentation as ChatCompletion ?? } /// - /// Creates or extracts a native OpenAI object from an . + /// Creates or extracts a native OpenAI object from an . /// /// The agent response. - /// The OpenAI object. + /// The OpenAI object. /// is . - public static OpenAIResponse AsOpenAIResponse(this AgentRunResponse response) + public static ResponseResult AsOpenAIResponse(this AgentRunResponse response) { Throw.IfNull(response); return - response.RawRepresentation as OpenAIResponse ?? - response.AsChatResponse().AsOpenAIResponse(); + response.RawRepresentation as ResponseResult ?? + response.AsChatResponse().AsOpenAIResponseResult(); } } diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIResponseClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIResponseClientExtensions.cs index 0d48147c77..224bf5db95 100644 --- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIResponseClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIResponseClientExtensions.cs @@ -8,7 +8,7 @@ namespace OpenAI.Responses; /// -/// Provides extension methods for +/// Provides extension methods for /// to simplify the creation of AI agents that work with OpenAI services. /// /// @@ -20,9 +20,9 @@ namespace OpenAI.Responses; public static class OpenAIResponseClientExtensions { /// - /// Creates an AI agent from an using the OpenAI Response API. + /// Creates an AI agent from an using the OpenAI Response API. /// - /// The to use for the agent. + /// The to use for the agent. /// Optional system instructions that define the agent's behavior and personality. /// Optional name for the agent for identification purposes. /// Optional description of the agent's capabilities and purpose. @@ -33,7 +33,7 @@ public static class OpenAIResponseClientExtensions /// An instance backed by the OpenAI Response service. /// Thrown when is . public static ChatClientAgent CreateAIAgent( - this OpenAIResponseClient client, + this ResponsesClient client, string? instructions = null, string? name = null, string? description = null, @@ -61,9 +61,9 @@ public static ChatClientAgent CreateAIAgent( } /// - /// Creates an AI agent from an using the OpenAI Response API. + /// Creates an AI agent from an using the OpenAI Response API. /// - /// The to use for the agent. + /// The to use for the agent. /// Full set of options to configure the agent. /// Provides a way to customize the creation of the underlying used by the agent. /// Optional logger factory for enabling logging within the agent. @@ -71,7 +71,7 @@ public static ChatClientAgent CreateAIAgent( /// An instance backed by the OpenAI Response service. /// Thrown when or is . public static ChatClientAgent CreateAIAgent( - this OpenAIResponseClient client, + this ResponsesClient client, ChatClientAgentOptions options, Func? clientFactory = null, ILoggerFactory? loggerFactory = null, diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.AzureAI/AzureAgentProvider.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.AzureAI/AzureAgentProvider.cs index c4a613901c..d4010a43c2 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.AzureAI/AzureAgentProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.AzureAI/AzureAgentProvider.cs @@ -111,7 +111,7 @@ public override async IAsyncEnumerable InvokeAgentAsync( if (inputArguments is not null) { JsonNode jsonNode = ConvertDictionaryToJson(inputArguments); - ResponseCreationOptions responseCreationOptions = new(); + CreateResponseOptions responseCreationOptions = new(); #pragma warning disable SCME0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. responseCreationOptions.Patch.Set("$.structured_inputs"u8, BinaryData.FromString(jsonNode.ToJsonString())); #pragma warning restore SCME0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. @@ -206,7 +206,7 @@ private async Task GetAgentAsync(AgentVersion agentVersion, Cancellatio public override async Task GetMessageAsync(string conversationId, string messageId, CancellationToken cancellationToken = default) { AgentResponseItem responseItem = await this.GetConversationClient().GetProjectConversationItemAsync(conversationId, messageId, include: null, cancellationToken).ConfigureAwait(false); - ResponseItem[] items = [responseItem.AsOpenAIResponseItem()]; + ResponseItem[] items = [responseItem.AsResponseResultItem()]; return items.AsChatMessages().Single(); } @@ -223,7 +223,7 @@ public override async IAsyncEnumerable GetMessagesAsync( await foreach (AgentResponseItem responseItem in this.GetConversationClient().GetProjectConversationItemsAsync(conversationId, null, limit, order.ToString(), after, before, include: null, cancellationToken).ConfigureAwait(false)) { - ResponseItem[] items = [responseItem.AsOpenAIResponseItem()]; + ResponseItem[] items = [responseItem.AsResponseResultItem()]; foreach (ChatMessage message in items.AsChatMessages()) { yield return message; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/HandoffsWorkflowBuilder.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/HandoffsWorkflowBuilder.cs index 9e5b61ac42..9a3abfe960 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/HandoffsWorkflowBuilder.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/HandoffsWorkflowBuilder.cs @@ -125,14 +125,14 @@ public HandoffsWorkflowBuilder WithHandoff(AIAgent from, AIAgent to, string? han { Throw.ArgumentException( nameof(to), - $"The provided target agent '{to.DisplayName}' has no description, name, or instructions, and no handoff description has been provided. " + + $"The provided target agent '{to.Name ?? to.Id}' has no description, name, or instructions, and no handoff description has been provided. " + "At least one of these is required to register a handoff so that the appropriate target agent can be chosen."); } } if (!handoffs.Add(new(to, handoffReason))) { - Throw.InvalidOperationException($"A handoff from agent '{from.DisplayName}' to agent '{to.DisplayName}' has already been registered."); + Throw.InvalidOperationException($"A handoff from agent '{from.Name ?? from.Id}' to agent '{to.Name ?? to.Id}' has already been registered."); } return this; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AgentRunStreamingExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AgentRunStreamingExecutor.cs index ea80f646f0..ae3a932feb 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AgentRunStreamingExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AgentRunStreamingExecutor.cs @@ -20,7 +20,7 @@ internal sealed class AgentRunStreamingExecutor(AIAgent agent, bool includeInput protected override async ValueTask TakeTurnAsync(List messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default) { - List? roleChanged = messages.ChangeAssistantToUserForOtherParticipants(agent.DisplayName); + List? roleChanged = messages.ChangeAssistantToUserForOtherParticipants(agent.Name ?? agent.Id); List updates = []; await foreach (var update in agent.RunStreamingAsync(messages, cancellationToken: cancellationToken).ConfigureAwait(false)) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs index 59dc49f143..24e0eea3cb 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs @@ -67,7 +67,7 @@ protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) => List updates = []; List allMessages = handoffState.Messages; - List? roleChanges = allMessages.ChangeAssistantToUserForOtherParticipants(this._agent.DisplayName); + List? roleChanges = allMessages.ChangeAssistantToUserForOtherParticipants(this._agent.Name ?? this._agent.Id); await foreach (var update in this._agent.RunStreamingAsync(allMessages, options: this._agentOptions, @@ -85,7 +85,7 @@ await AddUpdateAsync( new AgentRunResponseUpdate { AgentId = this._agent.Id, - AuthorName = this._agent.DisplayName, + AuthorName = this._agent.Name ?? this._agent.Id, Contents = [new FunctionResultContent(fcc.CallId, "Transferred.")], CreatedAt = DateTimeOffset.UtcNow, MessageId = Guid.NewGuid().ToString("N"), diff --git a/dotnet/src/Microsoft.Agents.AI/OpenTelemetryAgent.cs b/dotnet/src/Microsoft.Agents.AI/OpenTelemetryAgent.cs index 7cd3c27b70..22da0c99da 100644 --- a/dotnet/src/Microsoft.Agents.AI/OpenTelemetryAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/OpenTelemetryAgent.cs @@ -114,7 +114,9 @@ private void UpdateCurrentActivity(Activity? previousActivity) // Override information set by OpenTelemetryChatClient to make it specific to invoke_agent. - activity.DisplayName = $"{OpenTelemetryConsts.GenAI.InvokeAgent} {this.DisplayName}"; + activity.DisplayName = string.IsNullOrWhiteSpace(this.Name) + ? $"{OpenTelemetryConsts.GenAI.InvokeAgent} {this.Id}" + : $"{OpenTelemetryConsts.GenAI.InvokeAgent} {this.Name}({this.Id})"; activity.SetTag(OpenTelemetryConsts.GenAI.Operation.Name, OpenTelemetryConsts.GenAI.InvokeAgent); if (!string.IsNullOrWhiteSpace(this._providerName)) diff --git a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs index e982c8081f..883b317f5e 100644 --- a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs +++ b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs @@ -89,7 +89,7 @@ private async Task> GetChatHistoryFromConversationAsync(string List messages = []; await foreach (AgentResponseItem item in this._client.GetProjectOpenAIClient().GetProjectConversationsClient().GetProjectConversationItemsAsync(conversationId, order: "asc")) { - var openAIItem = item.AsOpenAIResponseItem(); + var openAIItem = item.AsResponseResultItem(); if (openAIItem is MessageResponseItem messageItem) { messages.Add(new ChatMessage diff --git a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs index 9869d47f6b..0b491fb303 100644 --- a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs @@ -42,16 +42,14 @@ public void Constructor_WithAllParameters_InitializesPropertiesCorrectly() const string TestId = "test-id"; const string TestName = "test-name"; const string TestDescription = "test-description"; - const string TestDisplayName = "test-display-name"; // Act - var agent = new A2AAgent(this._a2aClient, TestId, TestName, TestDescription, TestDisplayName); + var agent = new A2AAgent(this._a2aClient, TestId, TestName, TestDescription); // Assert Assert.Equal(TestId, agent.Id); Assert.Equal(TestName, agent.Name); Assert.Equal(TestDescription, agent.Description); - Assert.Equal(TestDisplayName, agent.DisplayName); } [Fact] @@ -70,7 +68,6 @@ public void Constructor_WithDefaultParameters_UsesBaseProperties() Assert.NotEmpty(agent.Id); Assert.Null(agent.Name); Assert.Null(agent.Description); - Assert.Equal(agent.Id, agent.DisplayName); } [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AClientExtensionsTests.cs index e21035003e..5b84324e8b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AClientExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/Extensions/A2AClientExtensionsTests.cs @@ -19,10 +19,9 @@ public void GetAIAgent_WithAllParameters_ReturnsA2AAgentWithSpecifiedProperties( const string TestId = "test-agent-id"; const string TestName = "Test Agent"; const string TestDescription = "This is a test agent description"; - const string TestDisplayName = "Test Display Name"; // Act - var agent = a2aClient.GetAIAgent(TestId, TestName, TestDescription, TestDisplayName); + var agent = a2aClient.GetAIAgent(TestId, TestName, TestDescription); // Assert Assert.NotNull(agent); @@ -30,6 +29,5 @@ public void GetAIAgent_WithAllParameters_ReturnsA2AAgentWithSpecifiedProperties( Assert.Equal(TestId, agent.Id); Assert.Equal(TestName, agent.Name); Assert.Equal(TestDescription, agent.Description); - Assert.Equal(TestDisplayName, agent.DisplayName); } } diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatMessageStoreTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatMessageStoreTests.cs index 7ce0611c99..3dbd3ec367 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatMessageStoreTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatMessageStoreTests.cs @@ -10,7 +10,6 @@ using Azure.Identity; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.AI; -using Xunit; namespace Microsoft.Agents.AI.CosmosNoSql.UnitTests; @@ -59,6 +58,9 @@ public sealed class CosmosChatMessageStoreTests : IAsyncLifetime, IDisposable public async Task InitializeAsync() { + // Fail fast if emulator is not available + this.SkipIfEmulatorNotAvailable(); + // Check environment variable to determine if we should preserve containers // Set COSMOS_PRESERVE_CONTAINERS=true to keep containers and data for inspection this._preserveContainer = string.Equals(Environment.GetEnvironmentVariable("COSMOS_PRESERVE_CONTAINERS"), "true", StringComparison.OrdinalIgnoreCase); diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosCheckpointStoreTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosCheckpointStoreTests.cs index 8f5749b187..dc75b34758 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosCheckpointStoreTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosCheckpointStoreTests.cs @@ -7,7 +7,6 @@ using Microsoft.Agents.AI.Workflows; using Microsoft.Agents.AI.Workflows.Checkpointing; using Microsoft.Azure.Cosmos; -using Xunit; namespace Microsoft.Agents.AI.CosmosNoSql.UnitTests; @@ -58,6 +57,9 @@ private static JsonSerializerOptions CreateJsonOptions() public async Task InitializeAsync() { + // Fail fast if emulator is not available + this.SkipIfEmulatorNotAvailable(); + // Check environment variable to determine if we should preserve containers // Set COSMOS_PRESERVE_CONTAINERS=true to keep containers and data for inspection this._preserveContainer = string.Equals(Environment.GetEnvironmentVariable("COSMOS_PRESERVE_CONTAINERS"), "true", StringComparison.OrdinalIgnoreCase); diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/AgentEntityTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/AgentEntityTests.cs index 98e40ad4fb..b615bf1cd6 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/AgentEntityTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/AgentEntityTests.cs @@ -81,6 +81,64 @@ await simpleAgentProxy.RunAsync( Assert.Null(request.OrchestrationId); } + [Theory] + [InlineData("run")] + [InlineData("Run")] + [InlineData("RunAgentAsync")] + public async Task RunAgentMethodNamesAllWorkAsync(string runAgentMethodName) + { + // Setup + AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + name: "TestAgent", + instructions: "You are a helpful assistant that always responds with a friendly greeting." + ); + + using TestHelper testHelper = TestHelper.Start([simpleAgent], this._outputHelper); + + // A proxy agent is needed to call the hosted test agent + AIAgent simpleAgentProxy = simpleAgent.AsDurableAgentProxy(testHelper.Services); + + AgentThread thread = simpleAgentProxy.GetNewThread(); + + DurableTaskClient client = testHelper.GetClient(); + + AgentSessionId sessionId = thread.GetService(); + EntityInstanceId expectedEntityId = new($"dafx-{simpleAgent.Name}", sessionId.Key); + + EntityMetadata? entity = await client.Entities.GetEntityAsync(expectedEntityId, false, this.TestTimeoutToken); + + Assert.Null(entity); + + // Act: send a prompt to the agent + await client.Entities.SignalEntityAsync( + expectedEntityId, + runAgentMethodName, + new RunRequest("Hello!"), + cancellation: this.TestTimeoutToken); + + while (!this.TestTimeoutToken.IsCancellationRequested) + { + await Task.Delay(500, this.TestTimeoutToken); + + // Assert: verify the agent state was stored with the correct entity name prefix + entity = await client.Entities.GetEntityAsync(expectedEntityId, true, this.TestTimeoutToken); + + if (entity is not null) + { + break; + } + } + + Assert.NotNull(entity); + Assert.True(entity.IncludesState); + + DurableAgentState state = entity.State.ReadAs(); + + DurableAgentStateRequest request = Assert.Single(state.Data.ConversationHistory.OfType()); + + Assert.Null(request.OrchestrationId); + } + [Fact] public async Task OrchestrationIdSetDuringOrchestrationAsync() { diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/TimeToLiveTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/TimeToLiveTests.cs new file mode 100644 index 0000000000..25d40a1c5a --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/TimeToLiveTests.cs @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics; +using System.Reflection; +using Microsoft.Agents.AI.DurableTask.State; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.Extensions.Configuration; +using OpenAI.Chat; +using Xunit.Abstractions; + +namespace Microsoft.Agents.AI.DurableTask.IntegrationTests; + +/// +/// Tests for Time-To-Live (TTL) functionality of durable agent entities. +/// +[Collection("Sequential")] +[Trait("Category", "Integration")] +public sealed class TimeToLiveTests(ITestOutputHelper outputHelper) : IDisposable +{ + private static readonly TimeSpan s_defaultTimeout = Debugger.IsAttached + ? TimeSpan.FromMinutes(5) + : TimeSpan.FromSeconds(30); + + private static readonly IConfiguration s_configuration = + new ConfigurationBuilder() + .AddUserSecrets(Assembly.GetExecutingAssembly()) + .AddEnvironmentVariables() + .Build(); + + private readonly ITestOutputHelper _outputHelper = outputHelper; + private readonly CancellationTokenSource _cts = new(delay: s_defaultTimeout); + + private CancellationToken TestTimeoutToken => this._cts.Token; + + public void Dispose() => this._cts.Dispose(); + + [Fact] + public async Task EntityExpiresAfterTTLAsync() + { + // Arrange: Create agent with short TTL (10 seconds) + TimeSpan ttl = TimeSpan.FromSeconds(10); + AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + name: "TTLTestAgent", + instructions: "You are a helpful assistant." + ); + + using TestHelper testHelper = TestHelper.Start( + this._outputHelper, + options => + { + options.DefaultTimeToLive = ttl; + options.MinimumTimeToLiveSignalDelay = TimeSpan.FromSeconds(1); + options.AddAIAgent(simpleAgent); + }); + + AIAgent agentProxy = simpleAgent.AsDurableAgentProxy(testHelper.Services); + AgentThread thread = agentProxy.GetNewThread(); + DurableTaskClient client = testHelper.GetClient(); + AgentSessionId sessionId = thread.GetService(); + + // Act: Send a message to the agent + await agentProxy.RunAsync( + message: "Hello!", + thread, + cancellationToken: this.TestTimeoutToken); + + // Verify entity exists and get expiration time + EntityMetadata? entity = await client.Entities.GetEntityAsync(sessionId, true, this.TestTimeoutToken); + Assert.NotNull(entity); + Assert.True(entity.IncludesState); + + DurableAgentState state = entity.State.ReadAs(); + Assert.NotNull(state.Data.ExpirationTimeUtc); + DateTime expirationTime = state.Data.ExpirationTimeUtc.Value; + Assert.True(expirationTime > DateTime.UtcNow); + + // Calculate how long to wait: expiration time + buffer for signal processing + TimeSpan waitTime = expirationTime - DateTime.UtcNow + TimeSpan.FromSeconds(1); + if (waitTime > TimeSpan.Zero) + { + await Task.Delay(waitTime, this.TestTimeoutToken); + } + + // Poll the entity state until it's deleted (with timeout) + DateTime pollTimeout = DateTime.UtcNow.AddSeconds(10); + bool entityDeleted = false; + while (DateTime.UtcNow < pollTimeout && !entityDeleted) + { + entity = await client.Entities.GetEntityAsync(sessionId, true, this.TestTimeoutToken); + entityDeleted = entity is null; + + if (!entityDeleted) + { + await Task.Delay(TimeSpan.FromSeconds(1), this.TestTimeoutToken); + } + } + + // Assert: Verify entity state is deleted + Assert.True(entityDeleted, "Entity should have been deleted after TTL expiration"); + } + + [Fact] + public async Task EntityTTLResetsOnInteractionAsync() + { + // Arrange: Create agent with short TTL + TimeSpan ttl = TimeSpan.FromSeconds(6); + AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + name: "TTLResetTestAgent", + instructions: "You are a helpful assistant." + ); + + using TestHelper testHelper = TestHelper.Start( + this._outputHelper, + options => + { + options.DefaultTimeToLive = ttl; + options.MinimumTimeToLiveSignalDelay = TimeSpan.FromSeconds(1); + options.AddAIAgent(simpleAgent); + }); + + AIAgent agentProxy = simpleAgent.AsDurableAgentProxy(testHelper.Services); + AgentThread thread = agentProxy.GetNewThread(); + DurableTaskClient client = testHelper.GetClient(); + AgentSessionId sessionId = thread.GetService(); + + // Act: Send first message + await agentProxy.RunAsync( + message: "Hello!", + thread, + cancellationToken: this.TestTimeoutToken); + + EntityMetadata? entity = await client.Entities.GetEntityAsync(sessionId, true, this.TestTimeoutToken); + Assert.NotNull(entity); + Assert.True(entity.IncludesState); + + DurableAgentState state = entity.State.ReadAs(); + DateTime firstExpirationTime = state.Data.ExpirationTimeUtc!.Value; + + // Wait partway through TTL + await Task.Delay(TimeSpan.FromSeconds(3), this.TestTimeoutToken); + + // Send second message (should reset TTL) + await agentProxy.RunAsync( + message: "Hello again!", + thread, + cancellationToken: this.TestTimeoutToken); + + // Verify expiration time was updated + entity = await client.Entities.GetEntityAsync(sessionId, true, this.TestTimeoutToken); + Assert.NotNull(entity); + Assert.True(entity.IncludesState); + + state = entity.State.ReadAs(); + DateTime secondExpirationTime = state.Data.ExpirationTimeUtc!.Value; + Assert.True(secondExpirationTime > firstExpirationTime); + + // Calculate when the original expiration time would have been + DateTime originalExpirationTime = firstExpirationTime; + TimeSpan waitUntilOriginalExpiration = originalExpirationTime - DateTime.UtcNow + TimeSpan.FromSeconds(2); + + if (waitUntilOriginalExpiration > TimeSpan.Zero) + { + await Task.Delay(waitUntilOriginalExpiration, this.TestTimeoutToken); + } + + // Assert: Entity should still exist because TTL was reset + // The new expiration time should be in the future + entity = await client.Entities.GetEntityAsync(sessionId, true, this.TestTimeoutToken); + Assert.NotNull(entity); + Assert.True(entity.IncludesState); + + state = entity.State.ReadAs(); + Assert.NotNull(state); + Assert.NotNull(state.Data.ExpirationTimeUtc); + Assert.True( + state.Data.ExpirationTimeUtc > DateTime.UtcNow, + "Entity should still be valid because TTL was reset"); + + // Wait for the entity to be deleted + DateTime pollTimeout = DateTime.UtcNow.AddSeconds(10); + bool entityDeleted = false; + while (DateTime.UtcNow < pollTimeout && !entityDeleted) + { + entity = await client.Entities.GetEntityAsync(sessionId, true, this.TestTimeoutToken); + entityDeleted = entity is null; + + if (!entityDeleted) + { + await Task.Delay(TimeSpan.FromSeconds(1), this.TestTimeoutToken); + } + } + + // Assert: Entity should have been deleted + Assert.True(entityDeleted, "Entity should have been deleted after TTL expiration"); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.OpenAI.UnitTests/OpenAIResponsesIntegrationTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.OpenAI.UnitTests/OpenAIResponsesIntegrationTests.cs index abf66a732f..2dd5b85e5f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.OpenAI.UnitTests/OpenAIResponsesIntegrationTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.OpenAI.UnitTests/OpenAIResponsesIntegrationTests.cs @@ -49,7 +49,7 @@ public async Task CreateResponseStreaming_WithSimpleMessage_ReturnsStreamingUpda const string ExpectedResponse = "One Two Three"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Count to 3"); @@ -90,10 +90,10 @@ public async Task CreateResponse_WithSimpleMessage_ReturnsCompleteResponseAsync( const string ExpectedResponse = "Hello! How can I help you today?"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Hello"); + ResponseResult response = await responseClient.CreateResponseAsync("Hello"); // Assert Assert.NotNull(response); @@ -117,7 +117,7 @@ public async Task CreateResponseStreaming_WithMultipleChunks_StreamsAllContentAs const string ExpectedResponse = "This is a test response with multiple words"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -162,12 +162,12 @@ public async Task CreateResponse_WithMultipleAgents_EachAgentRespondsCorrectlyAs (Agent1Name, Agent1Instructions, Agent1Response), (Agent2Name, Agent2Instructions, Agent2Response)); - OpenAIResponseClient responseClient1 = this.CreateResponseClient(Agent1Name); - OpenAIResponseClient responseClient2 = this.CreateResponseClient(Agent2Name); + ResponsesClient responseClient1 = this.CreateResponseClient(Agent1Name); + ResponsesClient responseClient2 = this.CreateResponseClient(Agent2Name); // Act - OpenAIResponse response1 = await responseClient1.CreateResponseAsync("Hello"); - OpenAIResponse response2 = await responseClient2.CreateResponseAsync("Hello"); + ResponseResult response1 = await responseClient1.CreateResponseAsync("Hello"); + ResponseResult response2 = await responseClient2.CreateResponseAsync("Hello"); // Assert string content1 = response1.GetOutputText(); @@ -190,10 +190,10 @@ public async Task CreateResponse_SameAgentStreamingAndNonStreaming_BothWorkCorre const string ExpectedResponse = "This is the response"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - Non-streaming - OpenAIResponse nonStreamingResponse = await responseClient.CreateResponseAsync("Test"); + ResponseResult nonStreamingResponse = await responseClient.CreateResponseAsync("Test"); // Act - Streaming AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -224,10 +224,10 @@ public async Task CreateResponse_CompletedResponse_HasCorrectStatusAsync() const string ExpectedResponse = "Complete"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Test"); + ResponseResult response = await responseClient.CreateResponseAsync("Test"); // Assert Assert.Equal(ResponseStatus.Completed, response.Status); @@ -247,7 +247,7 @@ public async Task CreateResponseStreaming_VerifyEventSequence_ContainsExpectedEv const string ExpectedResponse = "Test response with multiple words"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -286,7 +286,7 @@ public async Task CreateResponseStreaming_EmptyResponse_HandlesGracefullyAsync() const string ExpectedResponse = ""; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -316,10 +316,10 @@ public async Task CreateResponse_IncludesMetadata_HasRequiredFieldsAsync() const string ExpectedResponse = "Response with metadata"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Test"); + ResponseResult response = await responseClient.CreateResponseAsync("Test"); // Assert Assert.NotNull(response.Id); @@ -340,7 +340,7 @@ public async Task CreateResponseStreaming_LongText_StreamsAllContentAsync() string expectedResponse = string.Join(" ", Enumerable.Range(1, 100).Select(i => $"Word{i}")); this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, expectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Generate long text"); @@ -371,7 +371,7 @@ public async Task CreateResponseStreaming_OutputIndices_AreConsistentAsync() const string ExpectedResponse = "Test output index"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -407,7 +407,7 @@ public async Task CreateResponseStreaming_SingleWord_StreamsCorrectlyAsync() const string ExpectedResponse = "Hello"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -437,7 +437,7 @@ public async Task CreateResponseStreaming_SpecialCharacters_PreservesFormattingA const string ExpectedResponse = "Hello! How are you? I'm fine. 100% great!"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -467,10 +467,10 @@ public async Task CreateResponse_SpecialCharacters_PreservesContentAsync() const string ExpectedResponse = "Symbols: @#$%^&*() Quotes: \"Hello\" 'World' Unicode: 你好 🌍"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Test"); + ResponseResult response = await responseClient.CreateResponseAsync("Test"); // Assert string content = response.GetOutputText(); @@ -489,7 +489,7 @@ public async Task CreateResponseStreaming_ItemIds_AreConsistentAsync() const string ExpectedResponse = "Testing item IDs"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -525,12 +525,12 @@ public async Task CreateResponse_MultipleSequentialRequests_AllSucceedAsync() const string ExpectedResponse = "Response"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act & Assert - Make 5 sequential requests for (int i = 0; i < 5; i++) { - OpenAIResponse response = await responseClient.CreateResponseAsync($"Request {i}"); + ResponseResult response = await responseClient.CreateResponseAsync($"Request {i}"); Assert.NotNull(response); Assert.Equal(ResponseStatus.Completed, response.Status); Assert.Equal(ExpectedResponse, response.GetOutputText()); @@ -549,7 +549,7 @@ public async Task CreateResponseStreaming_MultipleSequentialRequests_AllStreamCo const string ExpectedResponse = "Streaming response"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act & Assert - Make 3 sequential streaming requests for (int i = 0; i < 3; i++) @@ -581,13 +581,13 @@ public async Task CreateResponse_MultipleRequests_GenerateUniqueIdsAsync() const string ExpectedResponse = "Response"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act List responseIds = []; for (int i = 0; i < 10; i++) { - OpenAIResponse response = await responseClient.CreateResponseAsync($"Request {i}"); + ResponseResult response = await responseClient.CreateResponseAsync($"Request {i}"); responseIds.Add(response.Id); } @@ -608,7 +608,7 @@ public async Task CreateResponseStreaming_SequenceNumbers_AreMonotonicallyIncrea const string ExpectedResponse = "Test sequence numbers with multiple words"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -641,10 +641,10 @@ public async Task CreateResponse_ModelInformation_IsCorrectAsync() const string ExpectedResponse = "Test model info"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Test"); + ResponseResult response = await responseClient.CreateResponseAsync("Test"); // Assert Assert.NotNull(response.Model); @@ -663,7 +663,7 @@ public async Task CreateResponseStreaming_Punctuation_PreservesContentAsync() const string ExpectedResponse = "Hello, world! How are you today? I'm doing well."; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -693,10 +693,10 @@ public async Task CreateResponse_ShortInput_ReturnsValidResponseAsync() const string ExpectedResponse = "OK"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Hi"); + ResponseResult response = await responseClient.CreateResponseAsync("Hi"); // Assert Assert.NotNull(response); @@ -716,7 +716,7 @@ public async Task CreateResponseStreaming_ContentIndices_AreConsistentAsync() const string ExpectedResponse = "Test content indices"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -748,10 +748,10 @@ public async Task CreateResponse_Newlines_PreservesFormattingAsync() const string ExpectedResponse = "Line 1\nLine 2\nLine 3"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Test"); + ResponseResult response = await responseClient.CreateResponseAsync("Test"); // Assert string content = response.GetOutputText(); @@ -771,7 +771,7 @@ public async Task CreateResponseStreaming_Newlines_PreservesFormattingAsync() const string ExpectedResponse = "First line\nSecond line\nThird line"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -807,10 +807,10 @@ public async Task CreateResponse_ImageContent_ReturnsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.ImageContentMockChatClient(ImageUrl)); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Show me an image"); + ResponseResult response = await responseClient.CreateResponseAsync("Show me an image"); // Assert Assert.NotNull(response); @@ -834,7 +834,7 @@ public async Task CreateResponseStreaming_ImageContent_StreamsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.ImageContentMockChatClient(ImageUrl)); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Show me an image"); @@ -868,10 +868,10 @@ public async Task CreateResponse_AudioContent_ReturnsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.AudioContentMockChatClient(AudioData, Transcript)); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Generate audio"); + ResponseResult response = await responseClient.CreateResponseAsync("Generate audio"); // Assert Assert.NotNull(response); @@ -896,7 +896,7 @@ public async Task CreateResponseStreaming_AudioContent_StreamsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.AudioContentMockChatClient(AudioData, Transcript)); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Generate audio"); @@ -930,10 +930,10 @@ public async Task CreateResponse_FunctionCall_ReturnsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.FunctionCallMockChatClient(FunctionName, Arguments)); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("What's the weather?"); + ResponseResult response = await responseClient.CreateResponseAsync("What's the weather?"); // Assert Assert.NotNull(response); @@ -957,7 +957,7 @@ public async Task CreateResponseStreaming_FunctionCall_StreamsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.FunctionCallMockChatClient(FunctionName, Arguments)); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Calculate 2+2"); @@ -988,10 +988,10 @@ public async Task CreateResponse_MixedContent_ReturnsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.MixedContentMockChatClient()); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Show me various content"); + ResponseResult response = await responseClient.CreateResponseAsync("Show me various content"); // Assert Assert.NotNull(response); @@ -1014,7 +1014,7 @@ public async Task CreateResponseStreaming_MixedContent_StreamsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.MixedContentMockChatClient()); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Show me various content"); @@ -1047,7 +1047,7 @@ public async Task CreateResponseStreaming_TextDone_IncludesDoneEventAsync() const string ExpectedResponse = "Complete text response"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -1075,7 +1075,7 @@ public async Task CreateResponseStreaming_ContentPartAdded_IncludesEventAsync() const string ExpectedResponse = "Response with content parts"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -1122,7 +1122,7 @@ public async Task CreateResponse_WithConversationId_DoesNotForwardConversationId string conversationId = convDoc.RootElement.GetProperty("id").GetString()!; // Act - Send request with conversation ID using raw HTTP - // (OpenAI SDK doesn't expose ConversationId directly on ResponseCreationOptions) + // (OpenAI SDK doesn't expose ConversationId directly on CreateResponseOptions) var requestBody = new { input = "Test", @@ -1201,9 +1201,9 @@ public async Task CreateResponseStreaming_WithConversationId_DoesNotForwardConve Assert.Null(mockChatClient.LastChatOptions.ConversationId); } - private OpenAIResponseClient CreateResponseClient(string agentName) + private ResponsesClient CreateResponseClient(string agentName) { - return new OpenAIResponseClient( + return new ResponsesClient( model: "test-model", credential: new ApiKeyCredential("test-api-key"), options: new OpenAIClientOptions diff --git a/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs index 781ccb123e..127fe1a58f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs @@ -55,9 +55,9 @@ public async IAsyncEnumerable GetStreamingResponseAsync( } /// - /// Creates a test OpenAIResponseClient implementation for testing. + /// Creates a test ResponsesClient implementation for testing. /// - private sealed class TestOpenAIResponseClient : OpenAIResponseClient + private sealed class TestOpenAIResponseClient : ResponsesClient { public TestOpenAIResponseClient() { @@ -147,7 +147,7 @@ public void CreateAIAgent_WithNullClient_ThrowsArgumentNullException() { // Act & Assert var exception = Assert.Throws(() => - ((OpenAIResponseClient)null!).CreateAIAgent()); + ((ResponsesClient)null!).CreateAIAgent()); Assert.Equal("client", exception.ParamName); } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/LoggingAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/LoggingAgentTests.cs index 57b3051197..58e9536491 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/LoggingAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/LoggingAgentTests.cs @@ -42,7 +42,6 @@ public void Properties_DelegateToInnerAgent() Assert.Equal("TestAgent", agent.Name); Assert.Equal("This is a test agent.", agent.Description); Assert.Equal(innerAgent.Id, agent.Id); - Assert.Equal(innerAgent.DisplayName, agent.DisplayName); } [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/OpenTelemetryAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/OpenTelemetryAgentTests.cs index b9b04b7228..405832763c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/OpenTelemetryAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/OpenTelemetryAgentTests.cs @@ -45,7 +45,6 @@ public void Properties_DelegateToInnerAgent() Assert.Equal("TestAgent", agent.Name); Assert.Equal("This is a test agent.", agent.Description); Assert.Equal(innerAgent.Id, agent.Id); - Assert.Equal(innerAgent.DisplayName, agent.DisplayName); } [Fact] @@ -170,7 +169,7 @@ async static IAsyncEnumerable CallbackAsync( Assert.Equal("localhost", activity.GetTagItem("server.address")); Assert.Equal(12345, (int)activity.GetTagItem("server.port")!); - Assert.Equal("invoke_agent TestAgent", activity.DisplayName); + Assert.Equal($"invoke_agent {agent.Name}({agent.Id})", activity.DisplayName); Assert.Equal("invoke_agent", activity.GetTagItem("gen_ai.operation.name")); Assert.Equal("TestAgentProviderFromAIAgentMetadata", activity.GetTagItem("gen_ai.provider.name")); Assert.Equal(innerAgent.Name, activity.GetTagItem("gen_ai.agent.name")); @@ -431,7 +430,15 @@ async static IAsyncEnumerable CallbackAsync( Assert.Equal("localhost", activity.GetTagItem("server.address")); Assert.Equal(12345, (int)activity.GetTagItem("server.port")!); - Assert.Equal($"invoke_agent {innerAgent.DisplayName}", activity.DisplayName); + if (string.IsNullOrWhiteSpace(innerAgent.Name)) + { + Assert.Equal($"invoke_agent {innerAgent.Id}", activity.DisplayName); + } + else + { + Assert.Equal($"invoke_agent {innerAgent.Name}({innerAgent.Id})", activity.DisplayName); + } + Assert.Equal("invoke_agent", activity.GetTagItem("gen_ai.operation.name")); Assert.Equal("TestAgentProviderFromAIAgentMetadata", activity.GetTagItem("gen_ai.provider.name")); Assert.Equal(innerAgent.Name, activity.GetTagItem("gen_ai.agent.name")); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/MediaInputTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/MediaInputTest.cs index 32adb93ddb..d5976a3174 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/MediaInputTest.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/MediaInputTest.cs @@ -24,7 +24,7 @@ public sealed class MediaInputTest(ITestOutputHelper output) : IntegrationTest(o private const string ImageReference = "https://sample-files.com/downloads/images/jpg/web_optimized_1200x800_97kb.jpg"; [Theory] - [InlineData(ImageReference, "image/jpeg")] + [InlineData(ImageReference, "image/jpeg", Skip = "Failing consistently in the agent service api")] [InlineData(PdfReference, "application/pdf", Skip = "Not currently supported by agent service api")] public async Task ValidateFileUrlAsync(string fileSource, string mediaType) { diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs index 824a75d5d0..c319a0ac32 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs @@ -30,7 +30,7 @@ protected override IEnumerable GetEpilogueMessages(AgentRunOptions? { return [new(ChatRole.Assistant, [new FunctionCallContent(Guid.NewGuid().ToString("N"), handoff.Name)]) { - AuthorName = this.DisplayName, + AuthorName = this.Name ?? this.Id, MessageId = Guid.NewGuid().ToString("N"), CreatedAt = DateTime.UtcNow }]; diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs index 9ddc94cf71..caec2a0631 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs @@ -47,7 +47,7 @@ IEnumerable echoMessages select UpdateThread(new ChatMessage(ChatRole.Assistant, $"{prefix}{message.Text}") { - AuthorName = this.DisplayName, + AuthorName = this.Name ?? this.Id, CreatedAt = DateTimeOffset.Now, MessageId = Guid.NewGuid().ToString("N") }, thread as InMemoryAgentThread); diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunStreamingTests.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunStreamingTests.cs index 669a4dd2a0..80a148d7fc 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunStreamingTests.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunStreamingTests.cs @@ -3,11 +3,11 @@ using System.Threading.Tasks; using AgentConformance.IntegrationTests; -namespace OpenAIResponse.IntegrationTests; +namespace ResponseResult.IntegrationTests; public class OpenAIResponseStoreTrueChatClientAgentRunStreamingTests() : ChatClientAgentRunStreamingTests(() => new(store: true)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithInstructionsAndNoMessageReturnsExpectedResultAsync() => @@ -16,7 +16,7 @@ public override Task RunWithInstructionsAndNoMessageReturnsExpectedResultAsync() public class OpenAIResponseStoreFalseChatClientAgentRunStreamingTests() : ChatClientAgentRunStreamingTests(() => new(store: false)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithInstructionsAndNoMessageReturnsExpectedResultAsync() => diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunTests.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunTests.cs index af2f1c14ec..8b742e2964 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunTests.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunTests.cs @@ -3,11 +3,11 @@ using System.Threading.Tasks; using AgentConformance.IntegrationTests; -namespace OpenAIResponse.IntegrationTests; +namespace ResponseResult.IntegrationTests; public class OpenAIResponseStoreTrueChatClientAgentRunTests() : ChatClientAgentRunTests(() => new(store: true)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithInstructionsAndNoMessageReturnsExpectedResultAsync() => @@ -16,7 +16,7 @@ public override Task RunWithInstructionsAndNoMessageReturnsExpectedResultAsync() public class OpenAIResponseStoreFalseChatClientAgentRunTests() : ChatClientAgentRunTests(() => new(store: false)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithInstructionsAndNoMessageReturnsExpectedResultAsync() => diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs index a58583fbca..c6c84db569 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs @@ -12,13 +12,13 @@ using OpenAI.Responses; using Shared.IntegrationTests; -namespace OpenAIResponse.IntegrationTests; +namespace ResponseResult.IntegrationTests; public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture { private static readonly OpenAIConfiguration s_config = TestConfiguration.LoadSection(); - private OpenAIResponseClient _openAIResponseClient = null!; + private ResponsesClient _openAIResponseClient = null!; private ChatClientAgent _agent = null!; public AIAgent Agent => this._agent; @@ -77,7 +77,7 @@ public async Task CreateChatClientAgentAsync( { Instructions = instructions, Tools = aiTools, - RawRepresentationFactory = new Func(_ => new ResponseCreationOptions() { StoredOutputEnabled = store }) + RawRepresentationFactory = new Func(_ => new CreateResponseOptions() { StoredOutputEnabled = store }) }, }); @@ -92,7 +92,7 @@ public Task DeleteThreadAsync(AgentThread thread) => public async Task InitializeAsync() { this._openAIResponseClient = new OpenAIClient(s_config.ApiKey) - .GetOpenAIResponseClient(s_config.ChatModelId); + .GetResponsesClient(s_config.ChatModelId); this._agent = await this.CreateChatClientAgentAsync(); } diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunStreamingTests.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunStreamingTests.cs index e2e7e28bbd..c12f8f2db5 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunStreamingTests.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunStreamingTests.cs @@ -3,11 +3,11 @@ using System.Threading.Tasks; using AgentConformance.IntegrationTests; -namespace OpenAIResponse.IntegrationTests; +namespace ResponseResult.IntegrationTests; public class OpenAIResponseStoreTrueRunStreamingTests() : RunStreamingTests(() => new(store: true)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithNoMessageDoesNotFailAsync() => Task.CompletedTask; @@ -15,7 +15,7 @@ public override Task RunWithNoMessageDoesNotFailAsync() => public class OpenAIResponseStoreFalseRunStreamingTests() : RunStreamingTests(() => new(store: false)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithNoMessageDoesNotFailAsync() => diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunTests.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunTests.cs index 41c5254474..423ac583c7 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunTests.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunTests.cs @@ -3,11 +3,11 @@ using System.Threading.Tasks; using AgentConformance.IntegrationTests; -namespace OpenAIResponse.IntegrationTests; +namespace ResponseResult.IntegrationTests; public class OpenAIResponseStoreTrueRunTests() : RunTests(() => new(store: true)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithNoMessageDoesNotFailAsync() => Task.CompletedTask; @@ -15,7 +15,7 @@ public override Task RunWithNoMessageDoesNotFailAsync() => public class OpenAIResponseStoreFalseRunTests() : RunTests(() => new(store: false)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithNoMessageDoesNotFailAsync() => diff --git a/python/.env.example b/python/.env.example index f864f18f72..c09300d775 100644 --- a/python/.env.example +++ b/python/.env.example @@ -33,7 +33,6 @@ ANTHROPIC_MODEL="" OLLAMA_ENDPOINT="" OLLAMA_MODEL="" # Observability -ENABLE_OTEL=true +ENABLE_INSTRUMENTATION=true ENABLE_SENSITIVE_DATA=true -OTLP_ENDPOINT="http://localhost:4317/" -# APPLICATIONINSIGHTS_CONNECTION_STRING="..." +OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4317/" diff --git a/python/CHANGELOG.md b/python/CHANGELOG.md index b9bb82d6e7..7b012ccf23 100644 --- a/python/CHANGELOG.md +++ b/python/CHANGELOG.md @@ -7,6 +7,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- **agent-framework-azurefunctions**: Durable Agents: platforms should use consistent entity method names (#2234) + +## [1.0.0b251216] - 2025-12-16 + +### Added + +- **agent-framework-ollama**: Ollama connector for Agent Framework (#1104) +- **agent-framework-core**: Added custom args and thread object to `ai_function` kwargs (#2769) +- **agent-framework-core**: Enable checkpointing for `WorkflowAgent` (#2774) + +### Changed + +- **agent-framework-core**: [BREAKING] Observability updates (#2782) +- **agent-framework-core**: Use agent description in `HandoffBuilder` auto-generated tools (#2714) +- **agent-framework-core**: Remove warnings from workflow builder when not using factories (#2808) + +### Fixed + +- **agent-framework-core**: Fix `WorkflowAgent` to include thread conversation history (#2774) +- **agent-framework-core**: Fix context duplication in handoff workflows when restoring from checkpoint (#2867) +- **agent-framework-core**: Fix middleware terminate flag to exit function calling loop immediately (#2868) +- **agent-framework-core**: Fix `WorkflowAgent` to emit `yield_output` as agent response (#2866) +- **agent-framework-core**: Filter framework kwargs from MCP tool invocations (#2870) + ## [1.0.0b251211] - 2025-12-11 ### Added @@ -366,7 +392,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 For more information, see the [announcement blog post](https://devblogs.microsoft.com/foundry/introducing-microsoft-agent-framework-the-open-source-engine-for-agentic-ai-apps/). -[Unreleased]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251211...HEAD +[Unreleased]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251216...HEAD +[1.0.0b251216]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251211...python-1.0.0b251216 [1.0.0b251211]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251209...python-1.0.0b251211 [1.0.0b251209]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251204...python-1.0.0b251209 [1.0.0b251204]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251120...python-1.0.0b251204 diff --git a/python/DEV_SETUP.md b/python/DEV_SETUP.md index 2d4b9b92b1..fc261bdde0 100644 --- a/python/DEV_SETUP.md +++ b/python/DEV_SETUP.md @@ -154,6 +154,14 @@ Example: chat_completion = OpenAIChatClient(env_file_path="openai.env") ``` +# Method naming inside connectors + +When naming methods inside connectors, we have a loose preference for using the following conventions: +- Use `_prepare__for_` as a prefix for methods that prepare data for sending to the external service. +- Use `_parse__from_` as a prefix for methods that process data received from the external service. + +This is not a strict rule, but a guideline to help maintain consistency across the codebase. + ## Tests All the tests are located in the `tests` folder of each package. There are tests that are marked with a `@skip_if_..._integration_tests_disabled` decorator, these are integration tests that require an external service to be running, like OpenAI or Azure OpenAI. diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 4f86eb5afc..cd85509a40 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -5,7 +5,7 @@ import re import uuid from collections.abc import AsyncIterable, Sequence -from typing import Any, cast +from typing import Any, Final, cast import httpx from a2a.client import Client, ClientConfig, ClientFactory, minimal_agent_card @@ -38,6 +38,7 @@ UriContent, prepend_agent_framework_to_user_agent, ) +from agent_framework.observability import use_agent_instrumentation __all__ = ["A2AAgent"] @@ -58,6 +59,7 @@ def _get_uri_data(uri: str) -> str: return match.group("base64_data") +@use_agent_instrumentation class A2AAgent(BaseAgent): """Agent2Agent (A2A) protocol implementation. @@ -69,6 +71,8 @@ class A2AAgent(BaseAgent): Can be initialized with a URL, AgentCard, or existing A2A Client instance. """ + AGENT_PROVIDER_NAME: Final[str] = "A2A" + def __init__( self, *, @@ -233,14 +237,14 @@ async def run_stream( An agent response item. """ messages = self._normalize_messages(messages) - a2a_message = self._chat_message_to_a2a_message(messages[-1]) + a2a_message = self._prepare_message_for_a2a(messages[-1]) response_stream = self.client.send_message(a2a_message) async for item in response_stream: if isinstance(item, Message): # Process A2A Message - contents = self._a2a_parts_to_contents(item.parts) + contents = self._parse_contents_from_a2a(item.parts) yield AgentRunResponseUpdate( contents=contents, role=Role.ASSISTANT if item.role == A2ARole.agent else Role.USER, @@ -251,7 +255,7 @@ async def run_stream( task, _update_event = item if isinstance(task, Task) and task.status.state in TERMINAL_TASK_STATES: # Convert Task artifacts to ChatMessages and yield as separate updates - task_messages = self._task_to_chat_messages(task) + task_messages = self._parse_messages_from_task(task) if task_messages: for message in task_messages: # Use the artifact's ID from raw_representation as message_id for unique identification @@ -276,8 +280,8 @@ async def run_stream( msg = f"Only Message and Task responses are supported from A2A agents. Received: {type(item)}" raise NotImplementedError(msg) - def _chat_message_to_a2a_message(self, message: ChatMessage) -> A2AMessage: - """Convert a ChatMessage to an A2A Message. + def _prepare_message_for_a2a(self, message: ChatMessage) -> A2AMessage: + """Prepare a ChatMessage for the A2A protocol. Transforms Agent Framework ChatMessage objects into A2A protocol Messages by: - Converting all message contents to appropriate A2A Part types @@ -357,8 +361,8 @@ def _chat_message_to_a2a_message(self, message: ChatMessage) -> A2AMessage: metadata=cast(dict[str, Any], message.additional_properties), ) - def _a2a_parts_to_contents(self, parts: Sequence[A2APart]) -> list[Contents]: - """Convert A2A Parts to Agent Framework Contents. + def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]: + """Parse A2A Parts into Agent Framework Contents. Transforms A2A protocol Parts into framework-native Content objects, handling text, file (URI/bytes), and data parts with metadata preservation. @@ -406,17 +410,17 @@ def _a2a_parts_to_contents(self, parts: Sequence[A2APart]) -> list[Contents]: raise ValueError(f"Unknown Part kind: {inner_part.kind}") return contents - def _task_to_chat_messages(self, task: Task) -> list[ChatMessage]: - """Convert A2A Task artifacts to ChatMessages with ASSISTANT role.""" + def _parse_messages_from_task(self, task: Task) -> list[ChatMessage]: + """Parse A2A Task artifacts into ChatMessages with ASSISTANT role.""" messages: list[ChatMessage] = [] if task.artifacts is not None: for artifact in task.artifacts: - messages.append(self._artifact_to_chat_message(artifact)) + messages.append(self._parse_message_from_artifact(artifact)) elif task.history is not None and len(task.history) > 0: # Include the last history item as the agent response history_item = task.history[-1] - contents = self._a2a_parts_to_contents(history_item.parts) + contents = self._parse_contents_from_a2a(history_item.parts) messages.append( ChatMessage( role=Role.ASSISTANT if history_item.role == A2ARole.agent else Role.USER, @@ -427,9 +431,9 @@ def _task_to_chat_messages(self, task: Task) -> list[ChatMessage]: return messages - def _artifact_to_chat_message(self, artifact: Artifact) -> ChatMessage: - """Convert A2A Artifact to ChatMessage using part contents.""" - contents = self._a2a_parts_to_contents(artifact.parts) + def _parse_message_from_artifact(self, artifact: Artifact) -> ChatMessage: + """Parse A2A Artifact into ChatMessage using part contents.""" + contents = self._parse_contents_from_a2a(artifact.parts) return ChatMessage( role=Role.ASSISTANT, contents=contents, diff --git a/python/packages/a2a/pyproject.toml b/python/packages/a2a/pyproject.toml index 8ac7a0abbc..56d79ce7fe 100644 --- a/python/packages/a2a/pyproject.toml +++ b/python/packages/a2a/pyproject.toml @@ -4,7 +4,7 @@ description = "A2A integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251211" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index 82d3d02875..58ab18fee4 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -197,18 +197,18 @@ async def test_run_with_unknown_response_type_raises_error(a2a_agent: A2AAgent, await a2a_agent.run("Test message") -def test_task_to_chat_messages_empty_artifacts(a2a_agent: A2AAgent) -> None: - """Test _task_to_chat_messages with task containing no artifacts.""" +def test_parse_messages_from_task_empty_artifacts(a2a_agent: A2AAgent) -> None: + """Test _parse_messages_from_task with task containing no artifacts.""" task = MagicMock() task.artifacts = None - result = a2a_agent._task_to_chat_messages(task) + result = a2a_agent._parse_messages_from_task(task) assert len(result) == 0 -def test_task_to_chat_messages_with_artifacts(a2a_agent: A2AAgent) -> None: - """Test _task_to_chat_messages with task containing artifacts.""" +def test_parse_messages_from_task_with_artifacts(a2a_agent: A2AAgent) -> None: + """Test _parse_messages_from_task with task containing artifacts.""" task = MagicMock() # Create mock artifacts @@ -232,7 +232,7 @@ def test_task_to_chat_messages_with_artifacts(a2a_agent: A2AAgent) -> None: task.artifacts = [artifact1, artifact2] - result = a2a_agent._task_to_chat_messages(task) + result = a2a_agent._parse_messages_from_task(task) assert len(result) == 2 assert result[0].text == "Content 1" @@ -240,8 +240,8 @@ def test_task_to_chat_messages_with_artifacts(a2a_agent: A2AAgent) -> None: assert all(msg.role == Role.ASSISTANT for msg in result) -def test_artifact_to_chat_message(a2a_agent: A2AAgent) -> None: - """Test _artifact_to_chat_message conversion.""" +def test_parse_message_from_artifact(a2a_agent: A2AAgent) -> None: + """Test _parse_message_from_artifact conversion.""" artifact = MagicMock() artifact.artifact_id = "test-artifact" @@ -253,7 +253,7 @@ def test_artifact_to_chat_message(a2a_agent: A2AAgent) -> None: artifact.parts = [text_part] - result = a2a_agent._artifact_to_chat_message(artifact) + result = a2a_agent._parse_message_from_artifact(artifact) assert isinstance(result, ChatMessage) assert result.role == Role.ASSISTANT @@ -276,7 +276,7 @@ def test_get_uri_data_invalid_uri() -> None: _get_uri_data("not-a-valid-data-uri") -def test_a2a_parts_to_contents_conversion(a2a_agent: A2AAgent) -> None: +def test_parse_contents_from_a2a_conversion(a2a_agent: A2AAgent) -> None: """Test A2A parts to contents conversion.""" agent = A2AAgent(name="Test Agent", client=MockA2AClient(), _http_client=None) @@ -285,7 +285,7 @@ def test_a2a_parts_to_contents_conversion(a2a_agent: A2AAgent) -> None: parts = [Part(root=TextPart(text="First part")), Part(root=TextPart(text="Second part"))] # Convert to contents - contents = agent._a2a_parts_to_contents(parts) + contents = agent._parse_contents_from_a2a(parts) # Verify conversion assert len(contents) == 2 @@ -295,30 +295,30 @@ def test_a2a_parts_to_contents_conversion(a2a_agent: A2AAgent) -> None: assert contents[1].text == "Second part" -def test_chat_message_to_a2a_message_with_error_content(a2a_agent: A2AAgent) -> None: - """Test _chat_message_to_a2a_message with ErrorContent.""" +def test_prepare_message_for_a2a_with_error_content(a2a_agent: A2AAgent) -> None: + """Test _prepare_message_for_a2a with ErrorContent.""" # Create ChatMessage with ErrorContent error_content = ErrorContent(message="Test error message") message = ChatMessage(role=Role.USER, contents=[error_content]) # Convert to A2A message - a2a_message = a2a_agent._chat_message_to_a2a_message(message) + a2a_message = a2a_agent._prepare_message_for_a2a(message) # Verify conversion assert len(a2a_message.parts) == 1 assert a2a_message.parts[0].root.text == "Test error message" -def test_chat_message_to_a2a_message_with_uri_content(a2a_agent: A2AAgent) -> None: - """Test _chat_message_to_a2a_message with UriContent.""" +def test_prepare_message_for_a2a_with_uri_content(a2a_agent: A2AAgent) -> None: + """Test _prepare_message_for_a2a with UriContent.""" # Create ChatMessage with UriContent uri_content = UriContent(uri="http://example.com/file.pdf", media_type="application/pdf") message = ChatMessage(role=Role.USER, contents=[uri_content]) # Convert to A2A message - a2a_message = a2a_agent._chat_message_to_a2a_message(message) + a2a_message = a2a_agent._prepare_message_for_a2a(message) # Verify conversion assert len(a2a_message.parts) == 1 @@ -326,15 +326,15 @@ def test_chat_message_to_a2a_message_with_uri_content(a2a_agent: A2AAgent) -> No assert a2a_message.parts[0].root.file.mime_type == "application/pdf" -def test_chat_message_to_a2a_message_with_data_content(a2a_agent: A2AAgent) -> None: - """Test _chat_message_to_a2a_message with DataContent.""" +def test_prepare_message_for_a2a_with_data_content(a2a_agent: A2AAgent) -> None: + """Test _prepare_message_for_a2a with DataContent.""" # Create ChatMessage with DataContent (base64 data URI) data_content = DataContent(uri="data:text/plain;base64,SGVsbG8gV29ybGQ=", media_type="text/plain") message = ChatMessage(role=Role.USER, contents=[data_content]) # Convert to A2A message - a2a_message = a2a_agent._chat_message_to_a2a_message(message) + a2a_message = a2a_agent._prepare_message_for_a2a(message) # Verify conversion assert len(a2a_message.parts) == 1 @@ -342,14 +342,14 @@ def test_chat_message_to_a2a_message_with_data_content(a2a_agent: A2AAgent) -> N assert a2a_message.parts[0].root.file.mime_type == "text/plain" -def test_chat_message_to_a2a_message_empty_contents_raises_error(a2a_agent: A2AAgent) -> None: - """Test _chat_message_to_a2a_message with empty contents raises ValueError.""" +def test_prepare_message_for_a2a_empty_contents_raises_error(a2a_agent: A2AAgent) -> None: + """Test _prepare_message_for_a2a with empty contents raises ValueError.""" # Create ChatMessage with no contents message = ChatMessage(role=Role.USER, contents=[]) # Should raise ValueError for empty contents with raises(ValueError, match="ChatMessage.contents is empty"): - a2a_agent._chat_message_to_a2a_message(message) + a2a_agent._prepare_message_for_a2a(message) async def test_run_stream_with_message_response(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None: @@ -405,7 +405,7 @@ async def test_context_manager_no_cleanup_when_no_http_client() -> None: pass -def test_chat_message_to_a2a_message_with_multiple_contents() -> None: +def test_prepare_message_for_a2a_with_multiple_contents() -> None: """Test conversion of ChatMessage with multiple contents.""" agent = A2AAgent(client=MagicMock(), _http_client=None) @@ -421,7 +421,7 @@ def test_chat_message_to_a2a_message_with_multiple_contents() -> None: ], ) - result = agent._chat_message_to_a2a_message(message) + result = agent._prepare_message_for_a2a(message) # Should have converted all 4 contents to parts assert len(result.parts) == 4 @@ -433,7 +433,7 @@ def test_chat_message_to_a2a_message_with_multiple_contents() -> None: assert result.parts[3].root.kind == "text" # JSON text remains as text (no parsing) -def test_a2a_parts_to_contents_with_data_part() -> None: +def test_parse_contents_from_a2a_with_data_part() -> None: """Test conversion of A2A DataPart.""" agent = A2AAgent(client=MagicMock(), _http_client=None) @@ -441,7 +441,7 @@ def test_a2a_parts_to_contents_with_data_part() -> None: # Create DataPart data_part = Part(root=DataPart(data={"key": "value", "number": 42}, metadata={"source": "test"})) - contents = agent._a2a_parts_to_contents([data_part]) + contents = agent._parse_contents_from_a2a([data_part]) assert len(contents) == 1 @@ -450,7 +450,7 @@ def test_a2a_parts_to_contents_with_data_part() -> None: assert contents[0].additional_properties == {"source": "test"} -def test_a2a_parts_to_contents_unknown_part_kind() -> None: +def test_parse_contents_from_a2a_unknown_part_kind() -> None: """Test error handling for unknown A2A part kind.""" agent = A2AAgent(client=MagicMock(), _http_client=None) @@ -459,10 +459,10 @@ def test_a2a_parts_to_contents_unknown_part_kind() -> None: mock_part.root.kind = "unknown_kind" with raises(ValueError, match="Unknown Part kind: unknown_kind"): - agent._a2a_parts_to_contents([mock_part]) + agent._parse_contents_from_a2a([mock_part]) -def test_chat_message_to_a2a_message_with_hosted_file() -> None: +def test_prepare_message_for_a2a_with_hosted_file() -> None: """Test conversion of ChatMessage with HostedFileContent to A2A message.""" agent = A2AAgent(client=MagicMock(), _http_client=None) @@ -473,7 +473,7 @@ def test_chat_message_to_a2a_message_with_hosted_file() -> None: contents=[HostedFileContent(file_id="hosted://storage/document.pdf")], ) - result = agent._chat_message_to_a2a_message(message) # noqa: SLF001 + result = agent._prepare_message_for_a2a(message) # noqa: SLF001 # Verify the conversion assert len(result.parts) == 1 @@ -488,7 +488,7 @@ def test_chat_message_to_a2a_message_with_hosted_file() -> None: assert part.root.file.mime_type is None # HostedFileContent doesn't specify media_type -def test_a2a_parts_to_contents_with_hosted_file_uri() -> None: +def test_parse_contents_from_a2a_with_hosted_file_uri() -> None: """Test conversion of A2A FilePart with hosted file URI back to UriContent.""" agent = A2AAgent(client=MagicMock(), _http_client=None) @@ -503,7 +503,7 @@ def test_a2a_parts_to_contents_with_hosted_file_uri() -> None: ) ) - contents = agent._a2a_parts_to_contents([file_part]) # noqa: SLF001 + contents = agent._parse_contents_from_a2a([file_part]) # noqa: SLF001 assert len(contents) == 1 diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index ab7eb53940..db2f160a9d 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -23,7 +23,7 @@ from agent_framework._middleware import use_chat_middleware from agent_framework._tools import use_function_invocation from agent_framework._types import BaseContent, Contents -from agent_framework.observability import use_observability +from agent_framework.observability import use_instrumentation from ._event_converters import AGUIEventConverter from ._http_service import AGUIHttpService @@ -89,7 +89,7 @@ async def response_wrapper(self, *args: Any, **kwargs: Any) -> ChatResponse: @_apply_server_function_call_unwrap @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class AGUIChatClient(BaseChatClient): """Chat client for communicating with AG-UI compliant servers. diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 654498e371..6bdff552b6 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -86,7 +86,7 @@ def last_message(self): def run_id(self) -> str: """Get or generate run ID.""" if self._run_id is None: - self._run_id = self.input_data.get("run_id") or str(uuid.uuid4()) + self._run_id = self.input_data.get("run_id") or self.input_data.get("runId") or str(uuid.uuid4()) # This should never be None after the if block above, but satisfy type checkers if self._run_id is None: # pragma: no cover raise RuntimeError("Failed to initialize run_id") @@ -96,7 +96,7 @@ def run_id(self) -> str: def thread_id(self) -> str: """Get or generate thread ID.""" if self._thread_id is None: - self._thread_id = self.input_data.get("thread_id") or str(uuid.uuid4()) + self._thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") or str(uuid.uuid4()) # This should never be None after the if block above, but satisfy type checkers if self._thread_id is None: # pragma: no cover raise RuntimeError("Failed to initialize thread_id") diff --git a/python/packages/ag-ui/pyproject.toml b/python/packages/ag-ui/pyproject.toml index 805b2b55e4..8a4adceeee 100644 --- a/python/packages/ag-ui/pyproject.toml +++ b/python/packages/ag-ui/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "agent-framework-ag-ui" -version = "1.0.0b251211" +version = "1.0.0b251216" description = "AG-UI protocol integration for Agent Framework" readme = "README.md" license-files = ["LICENSE"] diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index 10843a259c..af90ea2e88 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -83,3 +83,71 @@ async def test_default_orchestrator_merges_client_tools() -> None: assert "server_tool" in tool_names assert "get_weather" in tool_names assert agent.chat_client.function_invocation_configuration.additional_tools + + +async def test_default_orchestrator_with_camel_case_ids() -> None: + """Client tool is able to extract camelCase IDs.""" + + agent = DummyAgent() + orchestrator = DefaultOrchestrator() + + input_data = { + "runId": "test-camelcase-runid", + "threadId": "test-camelcase-threadid", + "messages": [ + { + "role": "user", + "content": [{"type": "input_text", "text": "Hello"}], + } + ], + "tools": [], + } + + context = ExecutionContext( + input_data=input_data, + agent=agent, + config=AgentConfig(), + ) + + events = [] + async for event in orchestrator.run(context): + events.append(event) + + # assert the last event has the expected run_id and thread_id + last_event = events[-1] + assert last_event.run_id == "test-camelcase-runid" + assert last_event.thread_id == "test-camelcase-threadid" + + +async def test_default_orchestrator_with_snake_case_ids() -> None: + """Client tool is able to extract snake_case IDs.""" + + agent = DummyAgent() + orchestrator = DefaultOrchestrator() + + input_data = { + "run_id": "test-snakecase-runid", + "thread_id": "test-snakecase-threadid", + "messages": [ + { + "role": "user", + "content": [{"type": "input_text", "text": "Hello"}], + } + ], + "tools": [], + } + + context = ExecutionContext( + input_data=input_data, + agent=agent, + config=AgentConfig(), + ) + + events = [] + async for event in orchestrator.run(context): + events.append(event) + + # assert the last event has the expected run_id and thread_id + last_event = events[-1] + assert last_event.run_id == "test-snakecase-runid" + assert last_event.thread_id == "test-snakecase-threadid" diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index 96a70bc4a0..a5b169fbbf 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -25,7 +25,6 @@ TextContent, TextReasoningContent, TextSpanRegion, - ToolProtocol, UsageContent, UsageDetails, get_logger, @@ -35,7 +34,7 @@ ) from agent_framework._pydantic import AFBaseSettings from agent_framework.exceptions import ServiceInitializationError -from agent_framework.observability import use_observability +from agent_framework.observability import use_instrumentation from anthropic import AsyncAnthropic from anthropic.types.beta import ( BetaContentBlock, @@ -110,7 +109,7 @@ class AnthropicSettings(AFBaseSettings): @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class AnthropicClient(BaseChatClient): """Anthropic Chat client.""" @@ -214,9 +213,11 @@ async def _inner_get_response( chat_options: ChatOptions, **kwargs: Any, ) -> ChatResponse: - # Extract necessary state from messages and options - run_options = self._create_run_options(messages, chat_options, **kwargs) + # prepare + run_options = self._prepare_options(messages, chat_options, **kwargs) + # execute message = await self.anthropic_client.beta.messages.create(**run_options, stream=False) + # process return self._process_message(message) async def _inner_get_streaming_response( @@ -226,16 +227,17 @@ async def _inner_get_streaming_response( chat_options: ChatOptions, **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: - # Extract necessary state from messages and options - run_options = self._create_run_options(messages, chat_options, **kwargs) + # prepare + run_options = self._prepare_options(messages, chat_options, **kwargs) + # execute and process async for chunk in await self.anthropic_client.beta.messages.create(**run_options, stream=True): parsed_chunk = self._process_stream_event(chunk) if parsed_chunk: yield parsed_chunk - # region Create Run Options and Helpers + # region Prep methods - def _create_run_options( + def _prepare_options( self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, @@ -251,78 +253,91 @@ def _create_run_options( Returns: A dictionary of run options for the Anthropic client. """ - if chat_options.additional_properties and "additional_beta_flags" in chat_options.additional_properties: - betas = chat_options.additional_properties.pop("additional_beta_flags") - else: - betas = [] - run_options: dict[str, Any] = { - "model": chat_options.model_id or self.model_id, - "messages": self._convert_messages_to_anthropic_format(messages), - "max_tokens": chat_options.max_tokens or ANTHROPIC_DEFAULT_MAX_TOKENS, - "extra_headers": {"User-Agent": AGENT_FRAMEWORK_USER_AGENT}, - "betas": {*BETA_FLAGS, *self.additional_beta_flags, *betas}, + run_options: dict[str, Any] = chat_options.to_dict( + exclude={ + "type", + "instructions", # handled via system message + "tool_choice", # handled separately + "allow_multiple_tool_calls", # handled via tool_choice + "additional_properties", # handled separately + } + ) + + # translations between ChatOptions and Anthropic API + translations = { + "model_id": "model", + "stop": "stop_sequences", } + for old_key, new_key in translations.items(): + if old_key in run_options and old_key != new_key: + run_options[new_key] = run_options.pop(old_key) + + # model id + if not run_options.get("model"): + if not self.model_id: + raise ValueError("model_id must be a non-empty string") + run_options["model"] = self.model_id + + # max_tokens - Anthropic requires this, default if not provided + if not run_options.get("max_tokens"): + run_options["max_tokens"] = ANTHROPIC_DEFAULT_MAX_TOKENS - # Add any additional options from chat_options or kwargs - if chat_options.temperature is not None: - run_options["temperature"] = chat_options.temperature - if chat_options.top_p is not None: - run_options["top_p"] = chat_options.top_p - if chat_options.stop is not None: - run_options["stop_sequences"] = chat_options.stop + # messages + run_options["messages"] = self._prepare_messages_for_anthropic(messages) + + # system message - first system message is passed as instructions if messages and isinstance(messages[0], ChatMessage) and messages[0].role == Role.SYSTEM: - # first system message is passed as instructions run_options["system"] = messages[0].text - if chat_options.tool_choice is not None: - match ( - chat_options.tool_choice if isinstance(chat_options.tool_choice, str) else chat_options.tool_choice.mode - ): - case "auto": - run_options["tool_choice"] = {"type": "auto"} - if chat_options.allow_multiple_tool_calls is not None: - run_options["tool_choice"][ # type:ignore[reportArgumentType] - "disable_parallel_tool_use" - ] = not chat_options.allow_multiple_tool_calls - case "required": - if chat_options.tool_choice.required_function_name: - run_options["tool_choice"] = { - "type": "tool", - "name": chat_options.tool_choice.required_function_name, - } - if chat_options.allow_multiple_tool_calls is not None: - run_options["tool_choice"][ # type:ignore[reportArgumentType] - "disable_parallel_tool_use" - ] = not chat_options.allow_multiple_tool_calls - else: - run_options["tool_choice"] = {"type": "any"} - if chat_options.allow_multiple_tool_calls is not None: - run_options["tool_choice"][ # type:ignore[reportArgumentType] - "disable_parallel_tool_use" - ] = not chat_options.allow_multiple_tool_calls - case "none": - run_options["tool_choice"] = {"type": "none"} - case _: - logger.debug(f"Ignoring unsupported tool choice mode: {chat_options.tool_choice.mode} for now") - if tools_and_mcp := self._convert_tools_to_anthropic_format(chat_options.tools): - run_options.update(tools_and_mcp) - if chat_options.additional_properties: - run_options.update(chat_options.additional_properties) + + # betas + run_options["betas"] = self._prepare_betas(chat_options) + + # extra headers + run_options["extra_headers"] = {"User-Agent": AGENT_FRAMEWORK_USER_AGENT} + + # tools, mcp servers and tool choice + if tools_config := self._prepare_tools_for_anthropic(chat_options): + run_options.update(tools_config) + + # additional properties + additional_options = { + key: value + for key, value in chat_options.additional_properties.items() + if value is not None and key != "additional_beta_flags" + } + if additional_options: + run_options.update(additional_options) run_options.update(kwargs) return run_options - def _convert_messages_to_anthropic_format(self, messages: MutableSequence[ChatMessage]) -> list[dict[str, Any]]: - """Convert a list of ChatMessages to the format expected by the Anthropic client. + def _prepare_betas(self, chat_options: ChatOptions) -> set[str]: + """Prepare the beta flags for the Anthropic API request. + + Args: + chat_options: The chat options that may contain additional beta flags. + + Returns: + A set of beta flag strings to include in the request. + """ + return { + *BETA_FLAGS, + *self.additional_beta_flags, + *chat_options.additional_properties.get("additional_beta_flags", []), + } + + def _prepare_messages_for_anthropic(self, messages: MutableSequence[ChatMessage]) -> list[dict[str, Any]]: + """Prepare a list of ChatMessages for the Anthropic client. This skips the first message if it is a system message, as Anthropic expects system instructions as a separate parameter. """ # first system message is passed as instructions if messages and isinstance(messages[0], ChatMessage) and messages[0].role == Role.SYSTEM: - return [self._convert_message_to_anthropic_format(msg) for msg in messages[1:]] - return [self._convert_message_to_anthropic_format(msg) for msg in messages] + return [self._prepare_message_for_anthropic(msg) for msg in messages[1:]] + return [self._prepare_message_for_anthropic(msg) for msg in messages] - def _convert_message_to_anthropic_format(self, message: ChatMessage) -> dict[str, Any]: - """Convert a ChatMessage to the format expected by the Anthropic client. + def _prepare_message_for_anthropic(self, message: ChatMessage) -> dict[str, Any]: + """Prepare a ChatMessage for the Anthropic client. Args: message: The ChatMessage to convert. @@ -376,58 +391,96 @@ def _convert_message_to_anthropic_format(self, message: ChatMessage) -> dict[str "content": a_content, } - def _convert_tools_to_anthropic_format( - self, tools: list[ToolProtocol | MutableMapping[str, Any]] | None - ) -> dict[str, Any] | None: - if not tools: - return None - tool_list: list[MutableMapping[str, Any]] = [] - mcp_server_list: list[MutableMapping[str, Any]] = [] - for tool in tools: - match tool: - case MutableMapping(): - tool_list.append(tool) - case AIFunction(): - tool_list.append({ - "type": "custom", - "name": tool.name, - "description": tool.description, - "input_schema": tool.parameters(), - }) - case HostedWebSearchTool(): - search_tool: dict[str, Any] = { - "type": "web_search_20250305", - "name": "web_search", - } - if tool.additional_properties: - search_tool.update(tool.additional_properties) - tool_list.append(search_tool) - case HostedCodeInterpreterTool(): - code_tool: dict[str, Any] = { - "type": "code_execution_20250825", - "name": "code_execution", - } - tool_list.append(code_tool) - case HostedMCPTool(): - server_def: dict[str, Any] = { - "type": "url", - "name": tool.name, - "url": str(tool.url), - } - if tool.allowed_tools: - server_def["tool_configuration"] = {"allowed_tools": list(tool.allowed_tools)} - if tool.headers and (auth := tool.headers.get("authorization")): - server_def["authorization_token"] = auth - mcp_server_list.append(server_def) + def _prepare_tools_for_anthropic(self, chat_options: ChatOptions) -> dict[str, Any] | None: + """Prepare tools and tool choice configuration for the Anthropic API request. + + Args: + chat_options: The chat options containing tools and tool choice settings. + + Returns: + A dictionary with tools, mcp_servers, and tool_choice configuration, or None if empty. + """ + result: dict[str, Any] = {} + + # Process tools + if chat_options.tools: + tool_list: list[MutableMapping[str, Any]] = [] + mcp_server_list: list[MutableMapping[str, Any]] = [] + for tool in chat_options.tools: + match tool: + case MutableMapping(): + tool_list.append(tool) + case AIFunction(): + tool_list.append({ + "type": "custom", + "name": tool.name, + "description": tool.description, + "input_schema": tool.parameters(), + }) + case HostedWebSearchTool(): + search_tool: dict[str, Any] = { + "type": "web_search_20250305", + "name": "web_search", + } + if tool.additional_properties: + search_tool.update(tool.additional_properties) + tool_list.append(search_tool) + case HostedCodeInterpreterTool(): + code_tool: dict[str, Any] = { + "type": "code_execution_20250825", + "name": "code_execution", + } + tool_list.append(code_tool) + case HostedMCPTool(): + server_def: dict[str, Any] = { + "type": "url", + "name": tool.name, + "url": str(tool.url), + } + if tool.allowed_tools: + server_def["tool_configuration"] = {"allowed_tools": list(tool.allowed_tools)} + if tool.headers and (auth := tool.headers.get("authorization")): + server_def["authorization_token"] = auth + mcp_server_list.append(server_def) + case _: + logger.debug(f"Ignoring unsupported tool type: {type(tool)} for now") + + if tool_list: + result["tools"] = tool_list + if mcp_server_list: + result["mcp_servers"] = mcp_server_list + + # Process tool choice + if chat_options.tool_choice is not None: + tool_choice_mode = ( + chat_options.tool_choice if isinstance(chat_options.tool_choice, str) else chat_options.tool_choice.mode + ) + match tool_choice_mode: + case "auto": + tool_choice: dict[str, Any] = {"type": "auto"} + if chat_options.allow_multiple_tool_calls is not None: + tool_choice["disable_parallel_tool_use"] = not chat_options.allow_multiple_tool_calls + result["tool_choice"] = tool_choice + case "required": + if ( + not isinstance(chat_options.tool_choice, str) + and chat_options.tool_choice.required_function_name + ): + tool_choice = { + "type": "tool", + "name": chat_options.tool_choice.required_function_name, + } + else: + tool_choice = {"type": "any"} + if chat_options.allow_multiple_tool_calls is not None: + tool_choice["disable_parallel_tool_use"] = not chat_options.allow_multiple_tool_calls + result["tool_choice"] = tool_choice + case "none": + result["tool_choice"] = {"type": "none"} case _: - logger.debug(f"Ignoring unsupported tool type: {type(tool)} for now") + logger.debug(f"Ignoring unsupported tool choice mode: {tool_choice_mode} for now") - all_tools: dict[str, list[MutableMapping[str, Any]]] = {} - if tool_list: - all_tools["tools"] = tool_list - if mcp_server_list: - all_tools["mcp_servers"] = mcp_server_list - return all_tools + return result or None # region Response Processing Methods @@ -445,11 +498,11 @@ def _process_message(self, message: BetaMessage) -> ChatResponse: messages=[ ChatMessage( role=Role.ASSISTANT, - contents=self._parse_message_contents(message.content), + contents=self._parse_contents_from_anthropic(message.content), raw_representation=message, ) ], - usage_details=self._parse_message_usage(message.usage), + usage_details=self._parse_usage_from_anthropic(message.usage), model_id=message.model, finish_reason=FINISH_REASON_MAP.get(message.stop_reason) if message.stop_reason else None, raw_response=message, @@ -467,12 +520,12 @@ def _process_stream_event(self, event: BetaRawMessageStreamEvent) -> ChatRespons match event.type: case "message_start": usage_details: list[UsageContent] = [] - if event.message.usage and (details := self._parse_message_usage(event.message.usage)): + if event.message.usage and (details := self._parse_usage_from_anthropic(event.message.usage)): usage_details.append(UsageContent(details=details)) return ChatResponseUpdate( response_id=event.message.id, - contents=[*self._parse_message_contents(event.message.content), *usage_details], + contents=[*self._parse_contents_from_anthropic(event.message.content), *usage_details], model_id=event.message.model, finish_reason=FINISH_REASON_MAP.get(event.message.stop_reason) if event.message.stop_reason @@ -480,7 +533,7 @@ def _process_stream_event(self, event: BetaRawMessageStreamEvent) -> ChatRespons raw_response=event, ) case "message_delta": - usage = self._parse_message_usage(event.usage) + usage = self._parse_usage_from_anthropic(event.usage) return ChatResponseUpdate( contents=[UsageContent(details=usage, raw_representation=event.usage)] if usage else [], raw_response=event, @@ -488,13 +541,13 @@ def _process_stream_event(self, event: BetaRawMessageStreamEvent) -> ChatRespons case "message_stop": logger.debug("Received message_stop event; no content to process.") case "content_block_start": - contents = self._parse_message_contents([event.content_block]) + contents = self._parse_contents_from_anthropic([event.content_block]) return ChatResponseUpdate( contents=contents, raw_response=event, ) case "content_block_delta": - contents = self._parse_message_contents([event.delta]) + contents = self._parse_contents_from_anthropic([event.delta]) return ChatResponseUpdate( contents=contents, raw_response=event, @@ -505,7 +558,7 @@ def _process_stream_event(self, event: BetaRawMessageStreamEvent) -> ChatRespons logger.debug(f"Ignoring unsupported event type: {event.type}") return None - def _parse_message_usage(self, usage: BetaUsage | BetaMessageDeltaUsage | None) -> UsageDetails | None: + def _parse_usage_from_anthropic(self, usage: BetaUsage | BetaMessageDeltaUsage | None) -> UsageDetails | None: """Parse usage details from the Anthropic message usage.""" if not usage: return None @@ -518,7 +571,7 @@ def _parse_message_usage(self, usage: BetaUsage | BetaMessageDeltaUsage | None) usage_details.additional_counts["anthropic.cache_read_input_tokens"] = usage.cache_read_input_tokens return usage_details - def _parse_message_contents( + def _parse_contents_from_anthropic( self, content: Sequence[BetaContentBlock | BetaRawContentBlockDelta | BetaTextBlock] ) -> list[Contents]: """Parse contents from the Anthropic message.""" @@ -530,7 +583,7 @@ def _parse_message_contents( TextContent( text=content_block.text, raw_representation=content_block, - annotations=self._parse_citations(content_block), + annotations=self._parse_citations_from_anthropic(content_block), ) ) case "tool_use" | "mcp_tool_use" | "server_tool_use": @@ -549,7 +602,7 @@ def _parse_message_contents( FunctionResultContent( call_id=content_block.tool_use_id, name=name if name and call_id == content_block.tool_use_id else "mcp_tool", - result=self._parse_message_contents(content_block.content) + result=self._parse_contents_from_anthropic(content_block.content) if isinstance(content_block.content, list) else content_block.content, raw_representation=content_block, @@ -608,7 +661,7 @@ def _parse_message_contents( logger.debug(f"Ignoring unsupported content type: {content_block.type} for now") return contents - def _parse_citations( + def _parse_citations_from_anthropic( self, content_block: BetaContentBlock | BetaRawContentBlockDelta | BetaTextBlock ) -> list[Annotations] | None: content_citations = getattr(content_block, "citations", None) diff --git a/python/packages/anthropic/pyproject.toml b/python/packages/anthropic/pyproject.toml index 7cdc807541..55ef501a9f 100644 --- a/python/packages/anthropic/pyproject.toml +++ b/python/packages/anthropic/pyproject.toml @@ -4,7 +4,7 @@ description = "Anthropic integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251211" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/anthropic/tests/test_anthropic_client.py b/python/packages/anthropic/tests/test_anthropic_client.py index fa6061a998..e8a3ac9cb0 100644 --- a/python/packages/anthropic/tests/test_anthropic_client.py +++ b/python/packages/anthropic/tests/test_anthropic_client.py @@ -151,12 +151,12 @@ def test_anthropic_client_service_url(mock_anthropic_client: MagicMock) -> None: # Message Conversion Tests -def test_convert_message_to_anthropic_format_text(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_for_anthropic_text(mock_anthropic_client: MagicMock) -> None: """Test converting text message to Anthropic format.""" chat_client = create_test_anthropic_client(mock_anthropic_client) message = ChatMessage(role=Role.USER, text="Hello, world!") - result = chat_client._convert_message_to_anthropic_format(message) + result = chat_client._prepare_message_for_anthropic(message) assert result["role"] == "user" assert len(result["content"]) == 1 @@ -164,7 +164,7 @@ def test_convert_message_to_anthropic_format_text(mock_anthropic_client: MagicMo assert result["content"][0]["text"] == "Hello, world!" -def test_convert_message_to_anthropic_format_function_call(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_for_anthropic_function_call(mock_anthropic_client: MagicMock) -> None: """Test converting function call message to Anthropic format.""" chat_client = create_test_anthropic_client(mock_anthropic_client) message = ChatMessage( @@ -178,7 +178,7 @@ def test_convert_message_to_anthropic_format_function_call(mock_anthropic_client ], ) - result = chat_client._convert_message_to_anthropic_format(message) + result = chat_client._prepare_message_for_anthropic(message) assert result["role"] == "assistant" assert len(result["content"]) == 1 @@ -188,7 +188,7 @@ def test_convert_message_to_anthropic_format_function_call(mock_anthropic_client assert result["content"][0]["input"] == {"location": "San Francisco"} -def test_convert_message_to_anthropic_format_function_result(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_for_anthropic_function_result(mock_anthropic_client: MagicMock) -> None: """Test converting function result message to Anthropic format.""" chat_client = create_test_anthropic_client(mock_anthropic_client) message = ChatMessage( @@ -202,7 +202,7 @@ def test_convert_message_to_anthropic_format_function_result(mock_anthropic_clie ], ) - result = chat_client._convert_message_to_anthropic_format(message) + result = chat_client._prepare_message_for_anthropic(message) assert result["role"] == "user" assert len(result["content"]) == 1 @@ -214,7 +214,7 @@ def test_convert_message_to_anthropic_format_function_result(mock_anthropic_clie assert result["content"][0]["is_error"] is False -def test_convert_message_to_anthropic_format_text_reasoning(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_for_anthropic_text_reasoning(mock_anthropic_client: MagicMock) -> None: """Test converting text reasoning message to Anthropic format.""" chat_client = create_test_anthropic_client(mock_anthropic_client) message = ChatMessage( @@ -222,7 +222,7 @@ def test_convert_message_to_anthropic_format_text_reasoning(mock_anthropic_clien contents=[TextReasoningContent(text="Let me think about this...")], ) - result = chat_client._convert_message_to_anthropic_format(message) + result = chat_client._prepare_message_for_anthropic(message) assert result["role"] == "assistant" assert len(result["content"]) == 1 @@ -230,7 +230,7 @@ def test_convert_message_to_anthropic_format_text_reasoning(mock_anthropic_clien assert result["content"][0]["thinking"] == "Let me think about this..." -def test_convert_messages_to_anthropic_format_with_system(mock_anthropic_client: MagicMock) -> None: +def test_prepare_messages_for_anthropic_with_system(mock_anthropic_client: MagicMock) -> None: """Test converting messages list with system message.""" chat_client = create_test_anthropic_client(mock_anthropic_client) messages = [ @@ -238,7 +238,7 @@ def test_convert_messages_to_anthropic_format_with_system(mock_anthropic_client: ChatMessage(role=Role.USER, text="Hello!"), ] - result = chat_client._convert_messages_to_anthropic_format(messages) + result = chat_client._prepare_messages_for_anthropic(messages) # System message should be skipped assert len(result) == 1 @@ -246,7 +246,7 @@ def test_convert_messages_to_anthropic_format_with_system(mock_anthropic_client: assert result[0]["content"][0]["text"] == "Hello!" -def test_convert_messages_to_anthropic_format_without_system(mock_anthropic_client: MagicMock) -> None: +def test_prepare_messages_for_anthropic_without_system(mock_anthropic_client: MagicMock) -> None: """Test converting messages list without system message.""" chat_client = create_test_anthropic_client(mock_anthropic_client) messages = [ @@ -254,7 +254,7 @@ def test_convert_messages_to_anthropic_format_without_system(mock_anthropic_clie ChatMessage(role=Role.ASSISTANT, text="Hi there!"), ] - result = chat_client._convert_messages_to_anthropic_format(messages) + result = chat_client._prepare_messages_for_anthropic(messages) assert len(result) == 2 assert result[0]["role"] == "user" @@ -264,7 +264,7 @@ def test_convert_messages_to_anthropic_format_without_system(mock_anthropic_clie # Tool Conversion Tests -def test_convert_tools_to_anthropic_format_ai_function(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_ai_function(mock_anthropic_client: MagicMock) -> None: """Test converting AIFunction to Anthropic format.""" chat_client = create_test_anthropic_client(mock_anthropic_client) @@ -273,9 +273,8 @@ def get_weather(location: Annotated[str, Field(description="Location to get weat """Get weather for a location.""" return f"Weather for {location}" - tools = [get_weather] - - result = chat_client._convert_tools_to_anthropic_format(tools) + chat_options = ChatOptions(tools=[get_weather]) + result = chat_client._prepare_tools_for_anthropic(chat_options) assert result is not None assert "tools" in result @@ -285,12 +284,12 @@ def get_weather(location: Annotated[str, Field(description="Location to get weat assert "Get weather for a location" in result["tools"][0]["description"] -def test_convert_tools_to_anthropic_format_web_search(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_web_search(mock_anthropic_client: MagicMock) -> None: """Test converting HostedWebSearchTool to Anthropic format.""" chat_client = create_test_anthropic_client(mock_anthropic_client) - tools = [HostedWebSearchTool()] + chat_options = ChatOptions(tools=[HostedWebSearchTool()]) - result = chat_client._convert_tools_to_anthropic_format(tools) + result = chat_client._prepare_tools_for_anthropic(chat_options) assert result is not None assert "tools" in result @@ -299,12 +298,12 @@ def test_convert_tools_to_anthropic_format_web_search(mock_anthropic_client: Mag assert result["tools"][0]["name"] == "web_search" -def test_convert_tools_to_anthropic_format_code_interpreter(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_code_interpreter(mock_anthropic_client: MagicMock) -> None: """Test converting HostedCodeInterpreterTool to Anthropic format.""" chat_client = create_test_anthropic_client(mock_anthropic_client) - tools = [HostedCodeInterpreterTool()] + chat_options = ChatOptions(tools=[HostedCodeInterpreterTool()]) - result = chat_client._convert_tools_to_anthropic_format(tools) + result = chat_client._prepare_tools_for_anthropic(chat_options) assert result is not None assert "tools" in result @@ -313,12 +312,12 @@ def test_convert_tools_to_anthropic_format_code_interpreter(mock_anthropic_clien assert result["tools"][0]["name"] == "code_execution" -def test_convert_tools_to_anthropic_format_mcp_tool(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_mcp_tool(mock_anthropic_client: MagicMock) -> None: """Test converting HostedMCPTool to Anthropic format.""" chat_client = create_test_anthropic_client(mock_anthropic_client) - tools = [HostedMCPTool(name="test-mcp", url="https://example.com/mcp")] + chat_options = ChatOptions(tools=[HostedMCPTool(name="test-mcp", url="https://example.com/mcp")]) - result = chat_client._convert_tools_to_anthropic_format(tools) + result = chat_client._prepare_tools_for_anthropic(chat_options) assert result is not None assert "mcp_servers" in result @@ -328,18 +327,20 @@ def test_convert_tools_to_anthropic_format_mcp_tool(mock_anthropic_client: Magic assert result["mcp_servers"][0]["url"] == "https://example.com/mcp" -def test_convert_tools_to_anthropic_format_mcp_with_auth(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_mcp_with_auth(mock_anthropic_client: MagicMock) -> None: """Test converting HostedMCPTool with authorization headers.""" chat_client = create_test_anthropic_client(mock_anthropic_client) - tools = [ - HostedMCPTool( - name="test-mcp", - url="https://example.com/mcp", - headers={"authorization": "Bearer token123"}, - ) - ] + chat_options = ChatOptions( + tools=[ + HostedMCPTool( + name="test-mcp", + url="https://example.com/mcp", + headers={"authorization": "Bearer token123"}, + ) + ] + ) - result = chat_client._convert_tools_to_anthropic_format(tools) + result = chat_client._prepare_tools_for_anthropic(chat_options) assert result is not None assert "mcp_servers" in result @@ -348,12 +349,12 @@ def test_convert_tools_to_anthropic_format_mcp_with_auth(mock_anthropic_client: assert result["mcp_servers"][0]["authorization_token"] == "Bearer token123" -def test_convert_tools_to_anthropic_format_dict_tool(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_dict_tool(mock_anthropic_client: MagicMock) -> None: """Test converting dict tool to Anthropic format.""" chat_client = create_test_anthropic_client(mock_anthropic_client) - tools = [{"type": "custom", "name": "custom_tool", "description": "A custom tool"}] + chat_options = ChatOptions(tools=[{"type": "custom", "name": "custom_tool", "description": "A custom tool"}]) - result = chat_client._convert_tools_to_anthropic_format(tools) + result = chat_client._prepare_tools_for_anthropic(chat_options) assert result is not None assert "tools" in result @@ -361,11 +362,12 @@ def test_convert_tools_to_anthropic_format_dict_tool(mock_anthropic_client: Magi assert result["tools"][0]["name"] == "custom_tool" -def test_convert_tools_to_anthropic_format_none(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_none(mock_anthropic_client: MagicMock) -> None: """Test converting None tools.""" chat_client = create_test_anthropic_client(mock_anthropic_client) + chat_options = ChatOptions() - result = chat_client._convert_tools_to_anthropic_format(None) + result = chat_client._prepare_tools_for_anthropic(chat_options) assert result is None @@ -373,14 +375,14 @@ def test_convert_tools_to_anthropic_format_none(mock_anthropic_client: MagicMock # Run Options Tests -async def test_create_run_options_basic(mock_anthropic_client: MagicMock) -> None: - """Test _create_run_options with basic ChatOptions.""" +async def test_prepare_options_basic(mock_anthropic_client: MagicMock) -> None: + """Test _prepare_options with basic ChatOptions.""" chat_client = create_test_anthropic_client(mock_anthropic_client) messages = [ChatMessage(role=Role.USER, text="Hello")] chat_options = ChatOptions(max_tokens=100, temperature=0.7) - run_options = chat_client._create_run_options(messages, chat_options) + run_options = chat_client._prepare_options(messages, chat_options) assert run_options["model"] == chat_client.model_id assert run_options["max_tokens"] == 100 @@ -388,8 +390,8 @@ async def test_create_run_options_basic(mock_anthropic_client: MagicMock) -> Non assert "messages" in run_options -async def test_create_run_options_with_system_message(mock_anthropic_client: MagicMock) -> None: - """Test _create_run_options with system message.""" +async def test_prepare_options_with_system_message(mock_anthropic_client: MagicMock) -> None: + """Test _prepare_options with system message.""" chat_client = create_test_anthropic_client(mock_anthropic_client) messages = [ @@ -398,52 +400,52 @@ async def test_create_run_options_with_system_message(mock_anthropic_client: Mag ] chat_options = ChatOptions() - run_options = chat_client._create_run_options(messages, chat_options) + run_options = chat_client._prepare_options(messages, chat_options) assert run_options["system"] == "You are helpful." assert len(run_options["messages"]) == 1 # System message not in messages list -async def test_create_run_options_with_tool_choice_auto(mock_anthropic_client: MagicMock) -> None: - """Test _create_run_options with auto tool choice.""" +async def test_prepare_options_with_tool_choice_auto(mock_anthropic_client: MagicMock) -> None: + """Test _prepare_options with auto tool choice.""" chat_client = create_test_anthropic_client(mock_anthropic_client) messages = [ChatMessage(role=Role.USER, text="Hello")] chat_options = ChatOptions(tool_choice="auto") - run_options = chat_client._create_run_options(messages, chat_options) + run_options = chat_client._prepare_options(messages, chat_options) assert run_options["tool_choice"]["type"] == "auto" -async def test_create_run_options_with_tool_choice_required(mock_anthropic_client: MagicMock) -> None: - """Test _create_run_options with required tool choice.""" +async def test_prepare_options_with_tool_choice_required(mock_anthropic_client: MagicMock) -> None: + """Test _prepare_options with required tool choice.""" chat_client = create_test_anthropic_client(mock_anthropic_client) messages = [ChatMessage(role=Role.USER, text="Hello")] # For required with specific function, need to pass as dict chat_options = ChatOptions(tool_choice={"mode": "required", "required_function_name": "get_weather"}) - run_options = chat_client._create_run_options(messages, chat_options) + run_options = chat_client._prepare_options(messages, chat_options) assert run_options["tool_choice"]["type"] == "tool" assert run_options["tool_choice"]["name"] == "get_weather" -async def test_create_run_options_with_tool_choice_none(mock_anthropic_client: MagicMock) -> None: - """Test _create_run_options with none tool choice.""" +async def test_prepare_options_with_tool_choice_none(mock_anthropic_client: MagicMock) -> None: + """Test _prepare_options with none tool choice.""" chat_client = create_test_anthropic_client(mock_anthropic_client) messages = [ChatMessage(role=Role.USER, text="Hello")] chat_options = ChatOptions(tool_choice="none") - run_options = chat_client._create_run_options(messages, chat_options) + run_options = chat_client._prepare_options(messages, chat_options) assert run_options["tool_choice"]["type"] == "none" -async def test_create_run_options_with_tools(mock_anthropic_client: MagicMock) -> None: - """Test _create_run_options with tools.""" +async def test_prepare_options_with_tools(mock_anthropic_client: MagicMock) -> None: + """Test _prepare_options with tools.""" chat_client = create_test_anthropic_client(mock_anthropic_client) @ai_function @@ -454,32 +456,32 @@ def get_weather(location: str) -> str: messages = [ChatMessage(role=Role.USER, text="Hello")] chat_options = ChatOptions(tools=[get_weather]) - run_options = chat_client._create_run_options(messages, chat_options) + run_options = chat_client._prepare_options(messages, chat_options) assert "tools" in run_options assert len(run_options["tools"]) == 1 -async def test_create_run_options_with_stop_sequences(mock_anthropic_client: MagicMock) -> None: - """Test _create_run_options with stop sequences.""" +async def test_prepare_options_with_stop_sequences(mock_anthropic_client: MagicMock) -> None: + """Test _prepare_options with stop sequences.""" chat_client = create_test_anthropic_client(mock_anthropic_client) messages = [ChatMessage(role=Role.USER, text="Hello")] chat_options = ChatOptions(stop=["STOP", "END"]) - run_options = chat_client._create_run_options(messages, chat_options) + run_options = chat_client._prepare_options(messages, chat_options) assert run_options["stop_sequences"] == ["STOP", "END"] -async def test_create_run_options_with_top_p(mock_anthropic_client: MagicMock) -> None: - """Test _create_run_options with top_p.""" +async def test_prepare_options_with_top_p(mock_anthropic_client: MagicMock) -> None: + """Test _prepare_options with top_p.""" chat_client = create_test_anthropic_client(mock_anthropic_client) messages = [ChatMessage(role=Role.USER, text="Hello")] chat_options = ChatOptions(top_p=0.9) - run_options = chat_client._create_run_options(messages, chat_options) + run_options = chat_client._prepare_options(messages, chat_options) assert run_options["top_p"] == 0.9 @@ -540,41 +542,41 @@ def test_process_message_with_tool_use(mock_anthropic_client: MagicMock) -> None assert response.finish_reason == FinishReason.TOOL_CALLS -def test_parse_message_usage_basic(mock_anthropic_client: MagicMock) -> None: - """Test _parse_message_usage with basic usage.""" +def test_parse_usage_from_anthropic_basic(mock_anthropic_client: MagicMock) -> None: + """Test _parse_usage_from_anthropic with basic usage.""" chat_client = create_test_anthropic_client(mock_anthropic_client) usage = BetaUsage(input_tokens=10, output_tokens=5) - result = chat_client._parse_message_usage(usage) + result = chat_client._parse_usage_from_anthropic(usage) assert result is not None assert result.input_token_count == 10 assert result.output_token_count == 5 -def test_parse_message_usage_none(mock_anthropic_client: MagicMock) -> None: - """Test _parse_message_usage with None usage.""" +def test_parse_usage_from_anthropic_none(mock_anthropic_client: MagicMock) -> None: + """Test _parse_usage_from_anthropic with None usage.""" chat_client = create_test_anthropic_client(mock_anthropic_client) - result = chat_client._parse_message_usage(None) + result = chat_client._parse_usage_from_anthropic(None) assert result is None -def test_parse_message_contents_text(mock_anthropic_client: MagicMock) -> None: - """Test _parse_message_contents with text content.""" +def test_parse_contents_from_anthropic_text(mock_anthropic_client: MagicMock) -> None: + """Test _parse_contents_from_anthropic with text content.""" chat_client = create_test_anthropic_client(mock_anthropic_client) content = [BetaTextBlock(type="text", text="Hello!")] - result = chat_client._parse_message_contents(content) + result = chat_client._parse_contents_from_anthropic(content) assert len(result) == 1 assert isinstance(result[0], TextContent) assert result[0].text == "Hello!" -def test_parse_message_contents_tool_use(mock_anthropic_client: MagicMock) -> None: - """Test _parse_message_contents with tool use.""" +def test_parse_contents_from_anthropic_tool_use(mock_anthropic_client: MagicMock) -> None: + """Test _parse_contents_from_anthropic with tool use.""" chat_client = create_test_anthropic_client(mock_anthropic_client) content = [ @@ -585,7 +587,7 @@ def test_parse_message_contents_tool_use(mock_anthropic_client: MagicMock) -> No input={"location": "SF"}, ) ] - result = chat_client._parse_message_contents(content) + result = chat_client._parse_contents_from_anthropic(content) assert len(result) == 1 assert isinstance(result[0], FunctionCallContent) diff --git a/python/packages/azure-ai-search/pyproject.toml b/python/packages/azure-ai-search/pyproject.toml index ebb897df8c..d4227e3dd8 100644 --- a/python/packages/azure-ai-search/pyproject.toml +++ b/python/packages/azure-ai-search/pyproject.toml @@ -4,7 +4,7 @@ description = "Azure AI Search integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251211" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 0ea9ee1f05..50d18bbdc1 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -43,7 +43,7 @@ use_function_invocation, ) from agent_framework.exceptions import ServiceInitializationError, ServiceResponseException -from agent_framework.observability import use_observability +from agent_framework.observability import use_instrumentation from azure.ai.agents.aio import AgentsClient from azure.ai.agents.models import ( Agent, @@ -107,7 +107,7 @@ @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class AzureAIAgentClient(BaseChatClient): """Azure AI Agent Chat client.""" @@ -278,22 +278,13 @@ async def _inner_get_streaming_response( chat_options: ChatOptions, **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: - # Extract necessary state from messages and options - run_options, required_action_results = await self._create_run_options(messages, chat_options, **kwargs) - - # Get the thread ID - thread_id: str | None = ( - chat_options.conversation_id - if chat_options.conversation_id is not None - else run_options.get("conversation_id", self.thread_id) - ) - - # Determine which agent to use and create if needed + # prepare + run_options, required_action_results = await self._prepare_options(messages, chat_options, **kwargs) agent_id = await self._get_agent_id_or_create(run_options) - # Process and yield each update from the stream + # execute and process async for update in self._process_stream( - *(await self._create_agent_stream(thread_id, agent_id, run_options, required_action_results)) + *(await self._create_agent_stream(agent_id, run_options, required_action_results)) ): yield update @@ -342,7 +333,6 @@ async def _get_agent_id_or_create(self, run_options: dict[str, Any] | None = Non async def _create_agent_stream( self, - thread_id: str | None, agent_id: str, run_options: dict[str, Any], required_action_results: list[FunctionResultContent | FunctionApprovalResponseContent] | None, @@ -352,14 +342,14 @@ async def _create_agent_stream( Returns: tuple: (stream, final_thread_id) """ + thread_id = run_options.pop("thread_id", None) + # Get any active run for this thread thread_run = await self._get_active_thread_run(thread_id) stream: AsyncAgentRunStream[AsyncAgentEventHandler[Any]] | AsyncAgentEventHandler[Any] handler: AsyncAgentEventHandler[Any] = AsyncAgentEventHandler() - tool_run_id, tool_outputs, tool_approvals = self._convert_required_action_to_tool_output( - required_action_results - ) + tool_run_id, tool_outputs, tool_approvals = self._prepare_tool_outputs_for_azure_ai(required_action_results) if ( thread_run is not None @@ -421,19 +411,11 @@ async def _prepare_thread( # No thread ID was provided, so create a new thread. thread = await self.agents_client.threads.create( - tool_resources=run_options.get("tool_resources"), metadata=run_options.get("metadata") + tool_resources=run_options.get("tool_resources"), + metadata=run_options.get("metadata"), + messages=run_options.get("additional_messages"), ) - thread_id = thread.id - # workaround for: https://github.com/Azure/azure-sdk-for-python/issues/42805 - # this occurs when otel is enabled - # once fixed, in the function above, readd: - # `messages=run_options.pop("additional_messages")` - for msg in run_options.pop("additional_messages", []): - await self.agents_client.messages.create( - thread_id=thread_id, role=msg.role, content=msg.content, metadata=msg.metadata - ) - # and remove until here. - return thread_id + return thread.id def _extract_url_citations( self, message_delta_chunk: MessageDeltaChunk, azure_search_tool_calls: list[dict[str, Any]] @@ -611,7 +593,7 @@ async def _process_stream( "submit_tool_outputs", "submit_tool_approval", ]: - function_call_contents = self._create_function_call_contents( + function_call_contents = self._parse_function_calls_from_azure_ai( event_data, response_id ) if function_call_contents: @@ -753,8 +735,8 @@ def _capture_azure_search_tool_calls( except Exception as ex: logger.debug(f"Failed to capture Azure AI Search tool call: {ex}") - def _create_function_call_contents(self, event_data: ThreadRun, response_id: str | None) -> list[Contents]: - """Create function call contents from a tool action event.""" + def _parse_function_calls_from_azure_ai(self, event_data: ThreadRun, response_id: str | None) -> list[Contents]: + """Parse function call contents from an Azure AI tool action event.""" if isinstance(event_data, ThreadRun) and event_data.required_action is not None: if isinstance(event_data.required_action, SubmitToolOutputsAction): return [ @@ -815,117 +797,197 @@ def _prepare_tool_choice(self, chat_options: ChatOptions) -> None: chat_options.tool_choice = chat_tool_mode - async def _create_run_options( + async def _prepare_options( self, messages: MutableSequence[ChatMessage], - chat_options: ChatOptions | None, + chat_options: ChatOptions, **kwargs: Any, ) -> tuple[dict[str, Any], list[FunctionResultContent | FunctionApprovalResponseContent] | None]: - run_options: dict[str, Any] = {**kwargs} - agent_definition = await self._load_agent_definition_if_needed() - if chat_options is not None: - run_options["max_completion_tokens"] = chat_options.max_tokens - if chat_options.model_id is not None: - run_options["model"] = chat_options.model_id - else: - run_options["model"] = self.model_id - run_options["top_p"] = chat_options.top_p - run_options["temperature"] = chat_options.temperature - run_options["parallel_tool_calls"] = chat_options.allow_multiple_tool_calls - - tool_definitions: list[ToolDefinition | dict[str, Any]] = [] - - # Add tools from existing agent - if agent_definition is not None: - # Don't include function tools, since they will be passed through chat_options.tools - agent_tools = [tool for tool in agent_definition.tools if not isinstance(tool, FunctionToolDefinition)] - if agent_tools: - tool_definitions.extend(agent_tools) - if agent_definition.tool_resources: - run_options["tool_resources"] = agent_definition.tool_resources - - if chat_options.tool_choice is not None: - if chat_options.tool_choice != "none" and chat_options.tools: - # Add run tools - tool_definitions.extend(await self._prep_tools(chat_options.tools, run_options)) - - # Handle MCP tool resources for approval mode - mcp_tools = [tool for tool in chat_options.tools if isinstance(tool, HostedMCPTool)] - if mcp_tools: - mcp_resources = [] - for mcp_tool in mcp_tools: - server_label = mcp_tool.name.replace(" ", "_") - mcp_resource: dict[str, Any] = {"server_label": server_label} - - # Add headers if they exist - if mcp_tool.headers: - mcp_resource["headers"] = mcp_tool.headers - - if mcp_tool.approval_mode is not None: - match mcp_tool.approval_mode: - case str(): - # Map agent framework approval modes to Azure AI approval modes - approval_mode = ( - "always" if mcp_tool.approval_mode == "always_require" else "never" - ) - mcp_resource["require_approval"] = approval_mode - case _: - if "always_require_approval" in mcp_tool.approval_mode: - mcp_resource["require_approval"] = { - "always": mcp_tool.approval_mode["always_require_approval"] - } - elif "never_require_approval" in mcp_tool.approval_mode: - mcp_resource["require_approval"] = { - "never": mcp_tool.approval_mode["never_require_approval"] - } - - mcp_resources.append(mcp_resource) - - # Add MCP resources to tool_resources - if "tool_resources" not in run_options: - run_options["tool_resources"] = {} - run_options["tool_resources"]["mcp"] = mcp_resources - - if chat_options.tool_choice == "none": - run_options["tool_choice"] = AgentsToolChoiceOptionMode.NONE - elif chat_options.tool_choice == "auto": - run_options["tool_choice"] = AgentsToolChoiceOptionMode.AUTO - elif ( - isinstance(chat_options.tool_choice, ToolMode) - and chat_options.tool_choice == "required" - and chat_options.tool_choice.required_function_name is not None - ): - run_options["tool_choice"] = AgentsNamedToolChoice( - type=AgentsNamedToolChoiceType.FUNCTION, - function=FunctionName(name=chat_options.tool_choice.required_function_name), - ) + # Use to_dict with exclusions for properties handled separately + run_options: dict[str, Any] = chat_options.to_dict( + exclude={ + "type", + "instructions", # handled via messages + "tools", # handled separately + "tool_choice", # handled separately + "response_format", # handled separately + "additional_properties", # handled separately + "frequency_penalty", # not supported + "presence_penalty", # not supported + "user", # not supported + "stop", # not supported + "logit_bias", # not supported + "seed", # not supported + "store", # not supported + } + ) - if tool_definitions: - run_options["tools"] = tool_definitions + # Translation between ChatOptions and Azure AI Agents API + translations = { + "model_id": "model", + "allow_multiple_tool_calls": "parallel_tool_calls", + "max_tokens": "max_completion_tokens", + } + for old_key, new_key in translations.items(): + if old_key in run_options and old_key != new_key: + run_options[new_key] = run_options.pop(old_key) + + # model id fallback + if not run_options.get("model"): + run_options["model"] = self.model_id + + # tools and tool_choice + if tool_definitions := await self._prepare_tool_definitions_and_resources( + chat_options, agent_definition, run_options + ): + run_options["tools"] = tool_definitions - if chat_options.response_format is not None: - run_options["response_format"] = ResponseFormatJsonSchemaType( - json_schema=ResponseFormatJsonSchema( - name=chat_options.response_format.__name__, - schema=chat_options.response_format.model_json_schema(), - ) + if tool_choice := self._prepare_tool_choice_mode(chat_options): + run_options["tool_choice"] = tool_choice + + # response format + if chat_options.response_format is not None: + run_options["response_format"] = ResponseFormatJsonSchemaType( + json_schema=ResponseFormatJsonSchema( + name=chat_options.response_format.__name__, + schema=chat_options.response_format.model_json_schema(), ) + ) + + # messages + additional_messages, instructions, required_action_results = self._prepare_messages(messages) + if additional_messages: + run_options["additional_messages"] = additional_messages + + # Add instruction from existing agent at the beginning + if ( + agent_definition is not None + and agent_definition.instructions + and agent_definition.instructions not in instructions + ): + instructions.insert(0, agent_definition.instructions) + + if instructions: + run_options["instructions"] = "\n".join(instructions) + + # thread_id resolution (conversation_id takes precedence, then kwargs, then instance default) + run_options["thread_id"] = chat_options.conversation_id or kwargs.get("conversation_id") or self.thread_id + + return run_options, required_action_results + + def _prepare_tool_choice_mode( + self, chat_options: ChatOptions + ) -> AgentsToolChoiceOptionMode | AgentsNamedToolChoice | None: + """Prepare the tool choice mode for Azure AI Agents API.""" + if chat_options.tool_choice is None: + return None + if chat_options.tool_choice == "none": + return AgentsToolChoiceOptionMode.NONE + if chat_options.tool_choice == "auto": + return AgentsToolChoiceOptionMode.AUTO + if ( + isinstance(chat_options.tool_choice, ToolMode) + and chat_options.tool_choice == "required" + and chat_options.tool_choice.required_function_name is not None + ): + return AgentsNamedToolChoice( + type=AgentsNamedToolChoiceType.FUNCTION, + function=FunctionName(name=chat_options.tool_choice.required_function_name), + ) + return None + + async def _prepare_tool_definitions_and_resources( + self, + chat_options: ChatOptions, + agent_definition: Agent | None, + run_options: dict[str, Any], + ) -> list[ToolDefinition | dict[str, Any]]: + """Prepare tool definitions and resources for the run options.""" + tool_definitions: list[ToolDefinition | dict[str, Any]] = [] + + # Add tools from existing agent (exclude function tools - passed via chat_options.tools) + if agent_definition is not None: + agent_tools = [tool for tool in agent_definition.tools if not isinstance(tool, FunctionToolDefinition)] + if agent_tools: + tool_definitions.extend(agent_tools) + if agent_definition.tool_resources: + run_options["tool_resources"] = agent_definition.tool_resources + + # Add run tools if tool_choice allows + if chat_options.tool_choice is not None and chat_options.tool_choice != "none" and chat_options.tools: + tool_definitions.extend(await self._prepare_tools_for_azure_ai(chat_options.tools, run_options)) + + # Handle MCP tool resources + mcp_resources = self._prepare_mcp_resources(chat_options.tools) + if mcp_resources: + if "tool_resources" not in run_options: + run_options["tool_resources"] = {} + run_options["tool_resources"]["mcp"] = mcp_resources + return tool_definitions + + def _prepare_mcp_resources( + self, tools: Sequence["ToolProtocol | MutableMapping[str, Any]"] + ) -> list[dict[str, Any]]: + """Prepare MCP tool resources for approval mode configuration.""" + mcp_tools = [tool for tool in tools if isinstance(tool, HostedMCPTool)] + if not mcp_tools: + return [] + + mcp_resources: list[dict[str, Any]] = [] + for mcp_tool in mcp_tools: + server_label = mcp_tool.name.replace(" ", "_") + mcp_resource: dict[str, Any] = {"server_label": server_label} + + if mcp_tool.headers: + mcp_resource["headers"] = mcp_tool.headers + + if mcp_tool.approval_mode is not None: + match mcp_tool.approval_mode: + case str(): + # Map agent framework approval modes to Azure AI approval modes + approval_mode = "always" if mcp_tool.approval_mode == "always_require" else "never" + mcp_resource["require_approval"] = approval_mode + case _: + if "always_require_approval" in mcp_tool.approval_mode: + mcp_resource["require_approval"] = { + "always": mcp_tool.approval_mode["always_require_approval"] + } + elif "never_require_approval" in mcp_tool.approval_mode: + mcp_resource["require_approval"] = { + "never": mcp_tool.approval_mode["never_require_approval"] + } + + mcp_resources.append(mcp_resource) + + return mcp_resources + + def _prepare_messages( + self, messages: MutableSequence[ChatMessage] + ) -> tuple[ + list[ThreadMessageOptions] | None, + list[str], + list[FunctionResultContent | FunctionApprovalResponseContent] | None, + ]: + """Prepare messages for Azure AI Agents API. + + System/developer messages are turned into instructions, since there is no such message roles in Azure AI. + All other messages are added 1:1, treating assistant messages as agent messages + and everything else as user messages. + + Returns: + Tuple of (additional_messages, instructions, required_action_results) + """ instructions: list[str] = [] required_action_results: list[FunctionResultContent | FunctionApprovalResponseContent] | None = None - additional_messages: list[ThreadMessageOptions] | None = None - # System/developer messages are turned into instructions, since there is no such message roles in Azure AI. - # All other messages are added 1:1, treating assistant messages as agent messages - # and everything else as user messages. for chat_message in messages: if chat_message.role.value in ["system", "developer"]: for text_content in [content for content in chat_message.contents if isinstance(content, TextContent)]: instructions.append(text_content.text) - continue message_contents: list[MessageInputContentBlock] = [] @@ -942,7 +1004,7 @@ async def _create_run_options( elif isinstance(content.raw_representation, MessageInputContentBlock): message_contents.append(content.raw_representation) - if len(message_contents) > 0: + if message_contents: if additional_messages is None: additional_messages = [] additional_messages.append( @@ -952,26 +1014,12 @@ async def _create_run_options( ) ) - if additional_messages is not None: - run_options["additional_messages"] = additional_messages - - # Add instruction from existing agent at the beginning - if ( - agent_definition is not None - and agent_definition.instructions - and agent_definition.instructions not in instructions - ): - instructions.insert(0, agent_definition.instructions) - - if len(instructions) > 0: - run_options["instructions"] = "".join(instructions) - - return run_options, required_action_results + return additional_messages, instructions, required_action_results - async def _prep_tools( + async def _prepare_tools_for_azure_ai( self, tools: Sequence["ToolProtocol | MutableMapping[str, Any]"], run_options: dict[str, Any] | None = None ) -> list[ToolDefinition | dict[str, Any]]: - """Prepare tool definitions for the run options.""" + """Prepare tool definitions for the Azure AI Agents API.""" tool_definitions: list[ToolDefinition | dict[str, Any]] = [] for tool in tools: match tool: @@ -1044,10 +1092,11 @@ async def _prep_tools( raise ServiceInitializationError(f"Unsupported tool type: {type(tool)}") return tool_definitions - def _convert_required_action_to_tool_output( + def _prepare_tool_outputs_for_azure_ai( self, required_action_results: list[FunctionResultContent | FunctionApprovalResponseContent] | None, ) -> tuple[str | None, list[ToolOutput] | None, list[ToolApproval] | None]: + """Prepare function results and approvals for submission to the Azure AI API.""" run_id: str | None = None tool_outputs: list[ToolOutput] | None = None tool_approvals: list[ToolApproval] | None = None diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 63bd2b27df..e10fc19068 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -15,7 +15,7 @@ use_function_invocation, ) from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError -from agent_framework.observability import use_observability +from agent_framework.observability import use_instrumentation from agent_framework.openai._responses_client import OpenAIBaseResponsesClient from azure.ai.projects.aio import AIProjectClient from azure.ai.projects.models import ( @@ -28,10 +28,6 @@ ) from azure.core.credentials_async import AsyncTokenCredential from azure.core.exceptions import ResourceNotFoundError -from openai.types.responses.parsed_response import ( - ParsedResponse, -) -from openai.types.responses.response import Response as OpenAIResponse from pydantic import BaseModel, ValidationError from ._shared import AzureAISettings @@ -41,6 +37,11 @@ else: from typing_extensions import Self # pragma: no cover +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + logger = get_logger("agent_framework.azure") @@ -49,7 +50,7 @@ @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class AzureAIClient(OpenAIBaseResponsesClient): """Azure AI Agent client.""" @@ -164,27 +165,94 @@ def __init__( # Track whether we should close client connection self._should_close_client = should_close_client - async def setup_azure_ai_observability(self, enable_sensitive_data: bool | None = None) -> None: - """Use this method to setup tracing in your Azure AI Project. + async def configure_azure_monitor( + self, + enable_sensitive_data: bool = False, + **kwargs: Any, + ) -> None: + """Setup observability with Azure Monitor (Azure AI Foundry integration). + + This method configures Azure Monitor for telemetry collection using the + connection string from the Azure AI project client. + + Args: + enable_sensitive_data: Enable sensitive data logging (prompts, responses). + Should only be enabled in development/test environments. Default is False. + **kwargs: Additional arguments passed to configure_azure_monitor(). + Common options include: + - enable_live_metrics (bool): Enable Azure Monitor Live Metrics + - credential (TokenCredential): Azure credential for Entra ID auth + - resource (Resource): Custom OpenTelemetry resource + See https://learn.microsoft.com/python/api/azure-monitor-opentelemetry/azure.monitor.opentelemetry.configure_azure_monitor + for full list of options. + + Raises: + ImportError: If azure-monitor-opentelemetry-exporter is not installed. + + Examples: + .. code-block:: python + + from agent_framework.azure import AzureAIClient + from azure.ai.projects.aio import AIProjectClient + from azure.identity.aio import DefaultAzureCredential - This will take the connection string from the project project_client. - It will override any connection string that is set in the environment variables. - It will disable any OTLP endpoint that might have been set. + async with ( + DefaultAzureCredential() as credential, + AIProjectClient( + endpoint="https://your-project.api.azureml.ms", credential=credential + ) as project_client, + AzureAIClient(project_client=project_client) as client, + ): + # Setup observability with defaults + await client.configure_azure_monitor() + + # With live metrics enabled + await client.configure_azure_monitor(enable_live_metrics=True) + + # With sensitive data logging (dev/test only) + await client.configure_azure_monitor(enable_sensitive_data=True) + + Note: + This method retrieves the Application Insights connection string from the + Azure AI project client automatically. You must have Application Insights + configured in your Azure AI project for this to work. """ + # Get connection string from project client try: conn_string = await self.project_client.telemetry.get_application_insights_connection_string() except ResourceNotFoundError: logger.warning( - "No Application Insights connection string found for the Azure AI Project, " - "please call setup_observability() manually." + "No Application Insights connection string found for the Azure AI Project. " + "Please ensure Application Insights is configured in your Azure AI project, " + "or call configure_otel_providers() manually with custom exporters." ) return - from agent_framework.observability import setup_observability - setup_observability( - applicationinsights_connection_string=conn_string, enable_sensitive_data=enable_sensitive_data + # Import Azure Monitor with proper error handling + try: + from azure.monitor.opentelemetry import configure_azure_monitor + except ImportError as exc: + raise ImportError( + "azure-monitor-opentelemetry is required for Azure Monitor integration. " + "Install it with: pip install azure-monitor-opentelemetry" + ) from exc + + from agent_framework.observability import create_metric_views, create_resource, enable_instrumentation + + # Create resource if not provided in kwargs + if "resource" not in kwargs: + kwargs["resource"] = create_resource() + + # Configure Azure Monitor with connection string and kwargs + configure_azure_monitor( + connection_string=conn_string, + views=create_metric_views(), + **kwargs, ) + # Complete setup with core observability + enable_instrumentation(enable_sensitive_data=enable_sensitive_data) + async def __aenter__(self) -> "Self": """Async context manager entry.""" return self @@ -268,6 +336,10 @@ async def _get_agent_reference_or_create( if "tools" in run_options: args["tools"] = run_options["tools"] + if "temperature" in run_options: + args["temperature"] = run_options["temperature"] + if "top_p" in run_options: + args["top_p"] = run_options["top_p"] if "response_format" in run_options: response_format = run_options["response_format"] @@ -297,63 +369,57 @@ async def _close_client_if_needed(self) -> None: if self._should_close_client: await self.project_client.close() - def _prepare_input(self, messages: MutableSequence[ChatMessage]) -> tuple[list[ChatMessage], str | None]: - """Prepare input from messages and convert system/developer messages to instructions.""" - result: list[ChatMessage] = [] - instructions_list: list[str] = [] - instructions: str | None = None - - # System/developer messages are turned into instructions, since there is no such message roles in Azure AI. - for message in messages: - if message.role.value in ["system", "developer"]: - for text_content in [content for content in message.contents if isinstance(content, TextContent)]: - instructions_list.append(text_content.text) - else: - result.append(message) - - if len(instructions_list) > 0: - instructions = "".join(instructions_list) - - return result, instructions - - async def prepare_options( + @override + async def _prepare_options( self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any, ) -> dict[str, Any]: """Take ChatOptions and create the specific options for Azure AI.""" - prepared_messages, instructions = self._prepare_input(messages) - run_options = await super().prepare_options(prepared_messages, chat_options, **kwargs) - + prepared_messages, instructions = self._prepare_messages_for_azure_ai(messages) + run_options = await super()._prepare_options(prepared_messages, chat_options, **kwargs) if not self._is_application_endpoint: # Application-scoped response APIs do not support "agent" property. agent_reference = await self._get_agent_reference_or_create(run_options, instructions) run_options["extra_body"] = {"agent": agent_reference} - conversation_id = chat_options.conversation_id or self.conversation_id - - # Handle different conversation ID formats - if conversation_id: - if conversation_id.startswith("resp_"): - # For response IDs, set previous_response_id and remove conversation property - run_options.pop("conversation", None) - run_options["previous_response_id"] = conversation_id - elif conversation_id.startswith("conv_"): - # For conversation IDs, set conversation and remove previous_response_id property - run_options.pop("previous_response_id", None) - run_options["conversation"] = conversation_id - # Remove properties that are not supported on request level # but were configured on agent level - exclude = ["model", "tools", "response_format"] + exclude = ["model", "tools", "response_format", "temperature", "top_p"] for property in exclude: run_options.pop(property, None) return run_options - async def initialize_client(self) -> None: + @override + def _get_current_conversation_id(self, chat_options: ChatOptions, **kwargs: Any) -> str | None: + """Get the current conversation ID from chat options or kwargs.""" + return chat_options.conversation_id or kwargs.get("conversation_id") or self.conversation_id + + def _prepare_messages_for_azure_ai( + self, messages: MutableSequence[ChatMessage] + ) -> tuple[list[ChatMessage], str | None]: + """Prepare input from messages and convert system/developer messages to instructions.""" + result: list[ChatMessage] = [] + instructions_list: list[str] = [] + instructions: str | None = None + + # System/developer messages are turned into instructions, since there is no such message roles in Azure AI. + for message in messages: + if message.role.value in ["system", "developer"]: + for text_content in [content for content in message.contents if isinstance(content, TextContent)]: + instructions_list.append(text_content.text) + else: + result.append(message) + + if len(instructions_list) > 0: + instructions = "".join(instructions_list) + + return result, instructions + + async def _initialize_client(self) -> None: """Initialize OpenAI client.""" self.client = self.project_client.get_openai_client() # type: ignore @@ -371,7 +437,8 @@ def _update_agent_name_and_description(self, agent_name: str | None, description if description and not self.agent_description: self.agent_description = description - def get_mcp_tool(self, tool: HostedMCPTool) -> Any: + @staticmethod + def _prepare_mcp_tool(tool: HostedMCPTool) -> MCPTool: # type: ignore[override] """Get MCP tool from HostedMCPTool.""" mcp = MCPTool(server_label=tool.name.replace(" ", "_"), server_url=str(tool.url)) @@ -389,17 +456,3 @@ def get_mcp_tool(self, tool: HostedMCPTool) -> Any: mcp["require_approval"] = {"never": {"tool_names": list(never_require_approvals)}} return mcp - - def get_conversation_id( - self, response: OpenAIResponse | ParsedResponse[BaseModel], store: bool | None - ) -> str | None: - """Get the conversation ID from the response if store is True.""" - if store is False: - return None - # If conversation ID exists, it means that we operate with conversation - # so we use conversation ID as input and output. - if response.conversation and response.conversation.id: - return response.conversation.id - # If conversation ID doesn't exist, we operate with responses - # so we use response ID as input and output. - return response.id diff --git a/python/packages/azure-ai/pyproject.toml b/python/packages/azure-ai/pyproject.toml index c20a28dc2e..685172e2e4 100644 --- a/python/packages/azure-ai/pyproject.toml +++ b/python/packages/azure-ai/pyproject.toml @@ -4,7 +4,7 @@ description = "Azure AI Foundry integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251211" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py index f1b4dafb63..134a3586b0 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py @@ -367,33 +367,33 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_missing_model( await chat_client._get_agent_id_or_create() # type: ignore -async def test_azure_ai_chat_client_create_run_options_basic(mock_agents_client: MagicMock) -> None: - """Test _create_run_options with basic ChatOptions.""" +async def test_azure_ai_chat_client_prepare_options_basic(mock_agents_client: MagicMock) -> None: + """Test _prepare_options with basic ChatOptions.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) messages = [ChatMessage(role=Role.USER, text="Hello")] chat_options = ChatOptions(max_tokens=100, temperature=0.7) - run_options, tool_results = await chat_client._create_run_options(messages, chat_options) # type: ignore + run_options, tool_results = await chat_client._prepare_options(messages, chat_options) # type: ignore assert run_options is not None assert tool_results is None -async def test_azure_ai_chat_client_create_run_options_no_chat_options(mock_agents_client: MagicMock) -> None: - """Test _create_run_options with no ChatOptions.""" +async def test_azure_ai_chat_client_prepare_options_no_chat_options(mock_agents_client: MagicMock) -> None: + """Test _prepare_options with default ChatOptions.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) messages = [ChatMessage(role=Role.USER, text="Hello")] - run_options, tool_results = await chat_client._create_run_options(messages, None) # type: ignore + run_options, tool_results = await chat_client._prepare_options(messages, ChatOptions()) # type: ignore assert run_options is not None assert tool_results is None -async def test_azure_ai_chat_client_create_run_options_with_image_content(mock_agents_client: MagicMock) -> None: - """Test _create_run_options with image content.""" +async def test_azure_ai_chat_client_prepare_options_with_image_content(mock_agents_client: MagicMock) -> None: + """Test _prepare_options with image content.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") @@ -403,7 +403,7 @@ async def test_azure_ai_chat_client_create_run_options_with_image_content(mock_a image_content = UriContent(uri="https://example.com/image.jpg", media_type="image/jpeg") messages = [ChatMessage(role=Role.USER, contents=[image_content])] - run_options, _ = await chat_client._create_run_options(messages, None) # type: ignore + run_options, _ = await chat_client._prepare_options(messages, ChatOptions()) # type: ignore assert "additional_messages" in run_options assert len(run_options["additional_messages"]) == 1 @@ -412,11 +412,11 @@ async def test_azure_ai_chat_client_create_run_options_with_image_content(mock_a assert len(message.content) == 1 -def test_azure_ai_chat_client_convert_function_results_to_tool_output_none(mock_agents_client: MagicMock) -> None: - """Test _convert_required_action_to_tool_output with None input.""" +def test_azure_ai_chat_client_prepare_tool_outputs_for_azure_ai_none(mock_agents_client: MagicMock) -> None: + """Test _prepare_tool_outputs_for_azure_ai with None input.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) - run_id, tool_outputs, tool_approvals = chat_client._convert_required_action_to_tool_output(None) # type: ignore + run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai(None) # type: ignore assert run_id is None assert tool_outputs is None @@ -484,8 +484,8 @@ def test_azure_ai_chat_client_update_agent_name_and_description_with_none_input( assert chat_client.agent_description is None -async def test_azure_ai_chat_client_create_run_options_with_messages(mock_agents_client: MagicMock) -> None: - """Test _create_run_options with different message types.""" +async def test_azure_ai_chat_client_prepare_options_with_messages(mock_agents_client: MagicMock) -> None: + """Test _prepare_options with different message types.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) # Test with system message (becomes instruction) @@ -494,7 +494,7 @@ async def test_azure_ai_chat_client_create_run_options_with_messages(mock_agents ChatMessage(role=Role.USER, text="Hello"), ] - run_options, _ = await chat_client._create_run_options(messages, None) # type: ignore + run_options, _ = await chat_client._prepare_options(messages, ChatOptions()) # type: ignore assert "instructions" in run_options assert "You are a helpful assistant" in run_options["instructions"] @@ -565,8 +565,8 @@ async def test_azure_ai_chat_client_prepare_thread_cancels_active_run(mock_agent mock_agents_client.runs.cancel.assert_called_once_with("test-thread", "run_123") -def test_azure_ai_chat_client_create_function_call_contents_basic(mock_agents_client: MagicMock) -> None: - """Test _create_function_call_contents with basic function call.""" +def test_azure_ai_chat_client_parse_function_calls_from_azure_ai_basic(mock_agents_client: MagicMock) -> None: + """Test _parse_function_calls_from_azure_ai with basic function call.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) mock_tool_call = MagicMock(spec=RequiredFunctionToolCall) @@ -580,7 +580,7 @@ def test_azure_ai_chat_client_create_function_call_contents_basic(mock_agents_cl mock_event_data = MagicMock(spec=ThreadRun) mock_event_data.required_action = mock_submit_action - result = chat_client._create_function_call_contents(mock_event_data, "response_123") # type: ignore + result = chat_client._parse_function_calls_from_azure_ai(mock_event_data, "response_123") # type: ignore assert len(result) == 1 assert isinstance(result[0], FunctionCallContent) @@ -588,22 +588,24 @@ def test_azure_ai_chat_client_create_function_call_contents_basic(mock_agents_cl assert result[0].call_id == '["response_123", "call_123"]' -def test_azure_ai_chat_client_create_function_call_contents_no_submit_action(mock_agents_client: MagicMock) -> None: - """Test _create_function_call_contents when required_action is not SubmitToolOutputsAction.""" +def test_azure_ai_chat_client_parse_function_calls_from_azure_ai_no_submit_action( + mock_agents_client: MagicMock, +) -> None: + """Test _parse_function_calls_from_azure_ai when required_action is not SubmitToolOutputsAction.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) mock_event_data = MagicMock(spec=ThreadRun) mock_event_data.required_action = MagicMock() - result = chat_client._create_function_call_contents(mock_event_data, "response_123") # type: ignore + result = chat_client._parse_function_calls_from_azure_ai(mock_event_data, "response_123") # type: ignore assert result == [] -def test_azure_ai_chat_client_create_function_call_contents_non_function_tool_call( +def test_azure_ai_chat_client_parse_function_calls_from_azure_ai_non_function_tool_call( mock_agents_client: MagicMock, ) -> None: - """Test _create_function_call_contents with non-function tool call.""" + """Test _parse_function_calls_from_azure_ai with non-function tool call.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) mock_tool_call = MagicMock() @@ -614,37 +616,37 @@ def test_azure_ai_chat_client_create_function_call_contents_non_function_tool_ca mock_event_data = MagicMock(spec=ThreadRun) mock_event_data.required_action = mock_submit_action - result = chat_client._create_function_call_contents(mock_event_data, "response_123") # type: ignore + result = chat_client._parse_function_calls_from_azure_ai(mock_event_data, "response_123") # type: ignore assert result == [] -async def test_azure_ai_chat_client_create_run_options_with_none_tool_choice( +async def test_azure_ai_chat_client_prepare_options_with_none_tool_choice( mock_agents_client: MagicMock, ) -> None: - """Test _create_run_options with tool_choice set to 'none'.""" + """Test _prepare_options with tool_choice set to 'none'.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) chat_options = ChatOptions() chat_options.tool_choice = "none" - run_options, _ = await chat_client._create_run_options([], chat_options) # type: ignore + run_options, _ = await chat_client._prepare_options([], chat_options) # type: ignore from azure.ai.agents.models import AgentsToolChoiceOptionMode assert run_options["tool_choice"] == AgentsToolChoiceOptionMode.NONE -async def test_azure_ai_chat_client_create_run_options_with_auto_tool_choice( +async def test_azure_ai_chat_client_prepare_options_with_auto_tool_choice( mock_agents_client: MagicMock, ) -> None: - """Test _create_run_options with tool_choice set to 'auto'.""" + """Test _prepare_options with tool_choice set to 'auto'.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) chat_options = ChatOptions() chat_options.tool_choice = "auto" - run_options, _ = await chat_client._create_run_options([], chat_options) # type: ignore + run_options, _ = await chat_client._prepare_options([], chat_options) # type: ignore from azure.ai.agents.models import AgentsToolChoiceOptionMode @@ -669,10 +671,10 @@ async def test_azure_ai_chat_client_prepare_tool_choice_none_string( assert chat_options.tool_choice == ToolMode.NONE.mode -async def test_azure_ai_chat_client_create_run_options_tool_choice_required_specific_function( +async def test_azure_ai_chat_client_prepare_options_tool_choice_required_specific_function( mock_agents_client: MagicMock, ) -> None: - """Test _create_run_options with ToolMode.REQUIRED specifying a specific function name.""" + """Test _prepare_options with ToolMode.REQUIRED specifying a specific function name.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) required_tool_mode = ToolMode.REQUIRED("specific_function_name") @@ -682,7 +684,7 @@ async def test_azure_ai_chat_client_create_run_options_tool_choice_required_spec chat_options = ChatOptions(tools=[dict_tool], tool_choice=required_tool_mode) messages = [ChatMessage(role=Role.USER, text="Hello")] - run_options, _ = await chat_client._create_run_options(messages, chat_options) # type: ignore + run_options, _ = await chat_client._prepare_options(messages, chat_options) # type: ignore # Verify tool_choice is set to the specific named function assert "tool_choice" in run_options @@ -692,10 +694,10 @@ async def test_azure_ai_chat_client_create_run_options_tool_choice_required_spec assert tool_choice.function.name == "specific_function_name" # type: ignore -async def test_azure_ai_chat_client_create_run_options_with_response_format( +async def test_azure_ai_chat_client_prepare_options_with_response_format( mock_agents_client: MagicMock, ) -> None: - """Test _create_run_options with response_format configured.""" + """Test _prepare_options with response_format configured.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) class TestResponseModel(BaseModel): @@ -704,7 +706,7 @@ class TestResponseModel(BaseModel): chat_options = ChatOptions() chat_options.response_format = TestResponseModel - run_options, _ = await chat_client._create_run_options([], chat_options) # type: ignore + run_options, _ = await chat_client._prepare_options([], chat_options) # type: ignore assert "response_format" in run_options response_format = run_options["response_format"] @@ -720,8 +722,8 @@ def test_azure_ai_chat_client_service_url_method(mock_agents_client: MagicMock) assert url == "https://test-endpoint.com/" -async def test_azure_ai_chat_client_prep_tools_ai_function(mock_agents_client: MagicMock) -> None: - """Test _prep_tools with AIFunction tool.""" +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_ai_function(mock_agents_client: MagicMock) -> None: + """Test _prepare_tools_for_azure_ai with AIFunction tool.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") @@ -729,28 +731,28 @@ async def test_azure_ai_chat_client_prep_tools_ai_function(mock_agents_client: M mock_ai_function = MagicMock(spec=AIFunction) mock_ai_function.to_json_schema_spec.return_value = {"type": "function", "function": {"name": "test_function"}} - result = await chat_client._prep_tools([mock_ai_function]) # type: ignore + result = await chat_client._prepare_tools_for_azure_ai([mock_ai_function]) # type: ignore assert len(result) == 1 assert result[0] == {"type": "function", "function": {"name": "test_function"}} mock_ai_function.to_json_schema_spec.assert_called_once() -async def test_azure_ai_chat_client_prep_tools_code_interpreter(mock_agents_client: MagicMock) -> None: - """Test _prep_tools with HostedCodeInterpreterTool.""" +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_code_interpreter(mock_agents_client: MagicMock) -> None: + """Test _prepare_tools_for_azure_ai with HostedCodeInterpreterTool.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") code_interpreter_tool = HostedCodeInterpreterTool() - result = await chat_client._prep_tools([code_interpreter_tool]) # type: ignore + result = await chat_client._prepare_tools_for_azure_ai([code_interpreter_tool]) # type: ignore assert len(result) == 1 assert isinstance(result[0], CodeInterpreterToolDefinition) -async def test_azure_ai_chat_client_prep_tools_mcp_tool(mock_agents_client: MagicMock) -> None: - """Test _prep_tools with HostedMCPTool.""" +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_mcp_tool(mock_agents_client: MagicMock) -> None: + """Test _prepare_tools_for_azure_ai with HostedMCPTool.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") @@ -762,7 +764,7 @@ async def test_azure_ai_chat_client_prep_tools_mcp_tool(mock_agents_client: Magi mock_mcp_tool.definitions = [{"type": "mcp", "name": "test_mcp"}] mock_mcp_tool_class.return_value = mock_mcp_tool - result = await chat_client._prep_tools([mcp_tool]) # type: ignore + result = await chat_client._prepare_tools_for_azure_ai([mcp_tool]) # type: ignore assert len(result) == 1 assert result[0] == {"type": "mcp", "name": "test_mcp"} @@ -774,8 +776,8 @@ async def test_azure_ai_chat_client_prep_tools_mcp_tool(mock_agents_client: Magi assert set(call_args["allowed_tools"]) == {"tool1", "tool2"} -async def test_azure_ai_chat_client_create_run_options_mcp_never_require(mock_agents_client: MagicMock) -> None: - """Test _create_run_options with HostedMCPTool having never_require approval mode.""" +async def test_azure_ai_chat_client_prepare_options_mcp_never_require(mock_agents_client: MagicMock) -> None: + """Test _prepare_options with HostedMCPTool having never_require approval mode.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) mcp_tool = HostedMCPTool(name="Test MCP Tool", url="https://example.com/mcp", approval_mode="never_require") @@ -784,12 +786,12 @@ async def test_azure_ai_chat_client_create_run_options_mcp_never_require(mock_ag chat_options = ChatOptions(tools=[mcp_tool], tool_choice="auto") with patch("agent_framework_azure_ai._chat_client.McpTool") as mock_mcp_tool_class: - # Mock _prep_tools to avoid actual tool preparation + # Mock _prepare_tools_for_azure_ai to avoid actual tool preparation mock_mcp_tool_instance = MagicMock() mock_mcp_tool_instance.definitions = [{"type": "mcp", "name": "test_mcp"}] mock_mcp_tool_class.return_value = mock_mcp_tool_instance - run_options, _ = await chat_client._create_run_options(messages, chat_options) # type: ignore + run_options, _ = await chat_client._prepare_options(messages, chat_options) # type: ignore # Verify tool_resources is created with correct MCP approval structure assert "tool_resources" in run_options, ( @@ -803,8 +805,8 @@ async def test_azure_ai_chat_client_create_run_options_mcp_never_require(mock_ag assert mcp_resource["require_approval"] == "never" -async def test_azure_ai_chat_client_create_run_options_mcp_with_headers(mock_agents_client: MagicMock) -> None: - """Test _create_run_options with HostedMCPTool having headers.""" +async def test_azure_ai_chat_client_prepare_options_mcp_with_headers(mock_agents_client: MagicMock) -> None: + """Test _prepare_options with HostedMCPTool having headers.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client) # Test with headers @@ -817,12 +819,12 @@ async def test_azure_ai_chat_client_create_run_options_mcp_with_headers(mock_age chat_options = ChatOptions(tools=[mcp_tool], tool_choice="auto") with patch("agent_framework_azure_ai._chat_client.McpTool") as mock_mcp_tool_class: - # Mock _prep_tools to avoid actual tool preparation + # Mock _prepare_tools_for_azure_ai to avoid actual tool preparation mock_mcp_tool_instance = MagicMock() mock_mcp_tool_instance.definitions = [{"type": "mcp", "name": "test_mcp"}] mock_mcp_tool_class.return_value = mock_mcp_tool_instance - run_options, _ = await chat_client._create_run_options(messages, chat_options) # type: ignore + run_options, _ = await chat_client._prepare_options(messages, chat_options) # type: ignore # Verify tool_resources is created with headers assert "tool_resources" in run_options @@ -835,8 +837,10 @@ async def test_azure_ai_chat_client_create_run_options_mcp_with_headers(mock_age assert mcp_resource["headers"] == headers -async def test_azure_ai_chat_client_prep_tools_web_search_bing_grounding(mock_agents_client: MagicMock) -> None: - """Test _prep_tools with HostedWebSearchTool using Bing Grounding.""" +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_web_search_bing_grounding( + mock_agents_client: MagicMock, +) -> None: + """Test _prepare_tools_for_azure_ai with HostedWebSearchTool using Bing Grounding.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") @@ -856,7 +860,7 @@ async def test_azure_ai_chat_client_prep_tools_web_search_bing_grounding(mock_ag mock_bing_tool.definitions = [{"type": "bing_grounding"}] mock_bing_grounding.return_value = mock_bing_tool - result = await chat_client._prep_tools([web_search_tool]) # type: ignore + result = await chat_client._prepare_tools_for_azure_ai([web_search_tool]) # type: ignore assert len(result) == 1 assert result[0] == {"type": "bing_grounding"} @@ -868,10 +872,10 @@ async def test_azure_ai_chat_client_prep_tools_web_search_bing_grounding(mock_ag assert "connection_id" in call_args -async def test_azure_ai_chat_client_prep_tools_web_search_bing_grounding_with_connection_id( +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_web_search_bing_grounding_with_connection_id( mock_agents_client: MagicMock, ) -> None: - """Test _prep_tools with HostedWebSearchTool using Bing Grounding with connection_id (no HTTP call).""" + """Test _prepare_tools_... with HostedWebSearchTool using Bing Grounding with connection_id (no HTTP call).""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") @@ -888,15 +892,17 @@ async def test_azure_ai_chat_client_prep_tools_web_search_bing_grounding_with_co mock_bing_tool.definitions = [{"type": "bing_grounding"}] mock_bing_grounding.return_value = mock_bing_tool - result = await chat_client._prep_tools([web_search_tool]) # type: ignore + result = await chat_client._prepare_tools_for_azure_ai([web_search_tool]) # type: ignore assert len(result) == 1 assert result[0] == {"type": "bing_grounding"} mock_bing_grounding.assert_called_once_with(connection_id="direct-connection-id", count=3) -async def test_azure_ai_chat_client_prep_tools_web_search_custom_bing(mock_agents_client: MagicMock) -> None: - """Test _prep_tools with HostedWebSearchTool using Custom Bing Search.""" +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_web_search_custom_bing( + mock_agents_client: MagicMock, +) -> None: + """Test _prepare_tools_for_azure_ai with HostedWebSearchTool using Custom Bing Search.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") @@ -914,16 +920,16 @@ async def test_azure_ai_chat_client_prep_tools_web_search_custom_bing(mock_agent mock_custom_tool.definitions = [{"type": "bing_custom_search"}] mock_custom_bing.return_value = mock_custom_tool - result = await chat_client._prep_tools([web_search_tool]) # type: ignore + result = await chat_client._prepare_tools_for_azure_ai([web_search_tool]) # type: ignore assert len(result) == 1 assert result[0] == {"type": "bing_custom_search"} -async def test_azure_ai_chat_client_prep_tools_file_search_with_vector_stores( +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_file_search_with_vector_stores( mock_agents_client: MagicMock, ) -> None: - """Test _prep_tools with HostedFileSearchTool using vector stores.""" + """Test _prepare_tools_for_azure_ai with HostedFileSearchTool using vector stores.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") @@ -938,7 +944,7 @@ async def test_azure_ai_chat_client_prep_tools_file_search_with_vector_stores( mock_file_search.return_value = mock_file_tool run_options = {} - result = await chat_client._prep_tools([file_search_tool], run_options) # type: ignore + result = await chat_client._prepare_tools_for_azure_ai([file_search_tool], run_options) # type: ignore assert len(result) == 1 assert result[0] == {"type": "file_search"} @@ -973,7 +979,7 @@ async def test_azure_ai_chat_client_create_agent_stream_submit_tool_approvals( with patch("azure.ai.agents.models.AsyncAgentEventHandler", return_value=mock_handler): stream, final_thread_id = await chat_client._create_agent_stream( # type: ignore - "test-thread", "test-agent", {}, [approval_response] + "test-agent", {"thread_id": "test-thread"}, [approval_response] ) # Verify the approvals path was taken @@ -987,26 +993,26 @@ async def test_azure_ai_chat_client_create_agent_stream_submit_tool_approvals( assert call_args["tool_approvals"][0].approve is True -async def test_azure_ai_chat_client_prep_tools_dict_tool(mock_agents_client: MagicMock) -> None: - """Test _prep_tools with dictionary tool definition.""" +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_dict_tool(mock_agents_client: MagicMock) -> None: + """Test _prepare_tools_for_azure_ai with dictionary tool definition.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") dict_tool = {"type": "custom_tool", "config": {"param": "value"}} - result = await chat_client._prep_tools([dict_tool]) # type: ignore + result = await chat_client._prepare_tools_for_azure_ai([dict_tool]) # type: ignore assert len(result) == 1 assert result[0] == dict_tool -async def test_azure_ai_chat_client_prep_tools_unsupported_tool(mock_agents_client: MagicMock) -> None: - """Test _prep_tools with unsupported tool type.""" +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_unsupported_tool(mock_agents_client: MagicMock) -> None: + """Test _prepare_tools_for_azure_ai with unsupported tool type.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") unsupported_tool = "not_a_tool" with pytest.raises(ServiceInitializationError, match="Unsupported tool type: "): - await chat_client._prep_tools([unsupported_tool]) # type: ignore + await chat_client._prepare_tools_for_azure_ai([unsupported_tool]) # type: ignore async def test_azure_ai_chat_client_get_active_thread_run_with_active_run(mock_agents_client: MagicMock) -> None: @@ -1072,16 +1078,16 @@ async def test_azure_ai_chat_client_service_url(mock_agents_client: MagicMock) - assert result == "https://test-endpoint.com/" -async def test_azure_ai_chat_client_convert_required_action_to_tool_output_function_result( +async def test_azure_ai_chat_client_prepare_tool_outputs_for_azure_ai_function_result( mock_agents_client: MagicMock, ) -> None: - """Test _convert_required_action_to_tool_output with FunctionResultContent.""" + """Test _prepare_tool_outputs_for_azure_ai with FunctionResultContent.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") # Test with simple result function_result = FunctionResultContent(call_id='["run_123", "call_456"]', result="Simple result") - run_id, tool_outputs, tool_approvals = chat_client._convert_required_action_to_tool_output([function_result]) # type: ignore + run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore assert run_id == "run_123" assert tool_approvals is None @@ -1092,7 +1098,7 @@ async def test_azure_ai_chat_client_convert_required_action_to_tool_output_funct async def test_azure_ai_chat_client_convert_required_action_invalid_call_id(mock_agents_client: MagicMock) -> None: - """Test _convert_required_action_to_tool_output with invalid call_id format.""" + """Test _prepare_tool_outputs_for_azure_ai with invalid call_id format.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") @@ -1100,19 +1106,19 @@ async def test_azure_ai_chat_client_convert_required_action_invalid_call_id(mock function_result = FunctionResultContent(call_id="invalid_json", result="result") with pytest.raises(json.JSONDecodeError): - chat_client._convert_required_action_to_tool_output([function_result]) # type: ignore + chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore async def test_azure_ai_chat_client_convert_required_action_invalid_structure( mock_agents_client: MagicMock, ) -> None: - """Test _convert_required_action_to_tool_output with invalid call_id structure.""" + """Test _prepare_tool_outputs_for_azure_ai with invalid call_id structure.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") # Valid JSON but invalid structure (missing second element) function_result = FunctionResultContent(call_id='["run_123"]', result="result") - run_id, tool_outputs, tool_approvals = chat_client._convert_required_action_to_tool_output([function_result]) # type: ignore + run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore # Should return None values when structure is invalid assert run_id is None @@ -1123,7 +1129,7 @@ async def test_azure_ai_chat_client_convert_required_action_invalid_structure( async def test_azure_ai_chat_client_convert_required_action_serde_model_results( mock_agents_client: MagicMock, ) -> None: - """Test _convert_required_action_to_tool_output with BaseModel results.""" + """Test _prepare_tool_outputs_for_azure_ai with BaseModel results.""" class MockResult(SerializationMixin): def __init__(self, name: str, value: int): @@ -1136,7 +1142,7 @@ def __init__(self, name: str, value: int): mock_result = MockResult(name="test", value=42) function_result = FunctionResultContent(call_id='["run_123", "call_456"]', result=mock_result) - run_id, tool_outputs, tool_approvals = chat_client._convert_required_action_to_tool_output([function_result]) # type: ignore + run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore assert run_id == "run_123" assert tool_approvals is None @@ -1151,7 +1157,7 @@ def __init__(self, name: str, value: int): async def test_azure_ai_chat_client_convert_required_action_multiple_results( mock_agents_client: MagicMock, ) -> None: - """Test _convert_required_action_to_tool_output with multiple results.""" + """Test _prepare_tool_outputs_for_azure_ai with multiple results.""" class MockResult(SerializationMixin): def __init__(self, data: str): @@ -1164,7 +1170,7 @@ def __init__(self, data: str): results_list = [mock_basemodel, {"key": "value"}, "string_result"] function_result = FunctionResultContent(call_id='["run_123", "call_456"]', result=results_list) - run_id, tool_outputs, tool_approvals = chat_client._convert_required_action_to_tool_output([function_result]) # type: ignore + run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore assert run_id == "run_123" assert tool_outputs is not None @@ -1184,7 +1190,7 @@ def __init__(self, data: str): async def test_azure_ai_chat_client_convert_required_action_approval_response( mock_agents_client: MagicMock, ) -> None: - """Test _convert_required_action_to_tool_output with FunctionApprovalResponseContent.""" + """Test _prepare_tool_outputs_for_azure_ai with FunctionApprovalResponseContent.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") # Test with approval response - need to provide required fields @@ -1194,7 +1200,7 @@ async def test_azure_ai_chat_client_convert_required_action_approval_response( approved=True, ) - run_id, tool_outputs, tool_approvals = chat_client._convert_required_action_to_tool_output([approval_response]) # type: ignore + run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([approval_response]) # type: ignore assert run_id == "run_123" assert tool_outputs is None @@ -1204,10 +1210,10 @@ async def test_azure_ai_chat_client_convert_required_action_approval_response( assert tool_approvals[0].approve is True -async def test_azure_ai_chat_client_create_function_call_contents_approval_request( +async def test_azure_ai_chat_client_parse_function_calls_from_azure_ai_approval_request( mock_agents_client: MagicMock, ) -> None: - """Test _create_function_call_contents with approval action.""" + """Test _parse_function_calls_from_azure_ai with approval action.""" chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") # Mock SubmitToolApprovalAction with RequiredMcpToolCall @@ -1222,7 +1228,7 @@ async def test_azure_ai_chat_client_create_function_call_contents_approval_reque mock_event_data = MagicMock(spec=ThreadRun) mock_event_data.required_action = mock_approval_action - result = chat_client._create_function_call_contents(mock_event_data, "response_123") # type: ignore + result = chat_client._parse_function_calls_from_azure_ai(mock_event_data, "response_123") # type: ignore assert len(result) == 1 assert isinstance(result[0], FunctionApprovalRequestContent) @@ -1312,7 +1318,7 @@ async def test_azure_ai_chat_client_create_agent_stream_submit_tool_outputs( with patch("azure.ai.agents.models.AsyncAgentEventHandler", return_value=mock_handler): stream, final_thread_id = await chat_client._create_agent_stream( # type: ignore - thread_id="test-thread", agent_id="test-agent", run_options={}, required_action_results=[function_result] + agent_id="test-agent", run_options={"thread_id": "test-thread"}, required_action_results=[function_result] ) # Should call submit_tool_outputs_stream since we have matching run ID diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py index 2167c1340a..028e8fbdb8 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_client.py @@ -249,10 +249,10 @@ async def test_azure_ai_client_get_agent_reference_missing_model( await client._get_agent_reference_or_create({}, None) # type: ignore -async def test_azure_ai_client_prepare_input_with_system_messages( +async def test_azure_ai_client_prepare_messages_for_azure_ai_with_system_messages( mock_project_client: MagicMock, ) -> None: - """Test _prepare_input converts system/developer messages to instructions.""" + """Test _prepare_messages_for_azure_ai converts system/developer messages to instructions.""" client = create_test_azure_ai_client(mock_project_client) messages = [ @@ -261,7 +261,7 @@ async def test_azure_ai_client_prepare_input_with_system_messages( ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="System response")]), ] - result_messages, instructions = client._prepare_input(messages) # type: ignore + result_messages, instructions = client._prepare_messages_for_azure_ai(messages) # type: ignore assert len(result_messages) == 2 assert result_messages[0].role == Role.USER @@ -269,10 +269,10 @@ async def test_azure_ai_client_prepare_input_with_system_messages( assert instructions == "You are a helpful assistant." -async def test_azure_ai_client_prepare_input_no_system_messages( +async def test_azure_ai_client_prepare_messages_for_azure_ai_no_system_messages( mock_project_client: MagicMock, ) -> None: - """Test _prepare_input with no system/developer messages.""" + """Test _prepare_messages_for_azure_ai with no system/developer messages.""" client = create_test_azure_ai_client(mock_project_client) messages = [ @@ -280,7 +280,7 @@ async def test_azure_ai_client_prepare_input_no_system_messages( ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hi there!")]), ] - result_messages, instructions = client._prepare_input(messages) # type: ignore + result_messages, instructions = client._prepare_messages_for_azure_ai(messages) # type: ignore assert len(result_messages) == 2 assert instructions is None @@ -294,14 +294,14 @@ async def test_azure_ai_client_prepare_options_basic(mock_project_client: MagicM chat_options = ChatOptions() with ( - patch.object(client.__class__.__bases__[0], "prepare_options", return_value={"model": "test-model"}), + patch.object(client.__class__.__bases__[0], "_prepare_options", return_value={"model": "test-model"}), patch.object( client, "_get_agent_reference_or_create", return_value={"name": "test-agent", "version": "1.0", "type": "agent_reference"}, ), ): - run_options = await client.prepare_options(messages, chat_options) + run_options = await client._prepare_options(messages, chat_options) assert "extra_body" in run_options assert run_options["extra_body"]["agent"]["name"] == "test-agent" @@ -329,14 +329,14 @@ async def test_azure_ai_client_prepare_options_with_application_endpoint( chat_options = ChatOptions() with ( - patch.object(client.__class__.__bases__[0], "prepare_options", return_value={"model": "test-model"}), + patch.object(client.__class__.__bases__[0], "_prepare_options", return_value={"model": "test-model"}), patch.object( client, "_get_agent_reference_or_create", return_value={"name": "test-agent", "version": "1", "type": "agent_reference"}, ), ): - run_options = await client.prepare_options(messages, chat_options) + run_options = await client._prepare_options(messages, chat_options) if expects_agent: assert "extra_body" in run_options @@ -369,14 +369,14 @@ async def test_azure_ai_client_prepare_options_with_application_project_client( chat_options = ChatOptions() with ( - patch.object(client.__class__.__bases__[0], "prepare_options", return_value={"model": "test-model"}), + patch.object(client.__class__.__bases__[0], "_prepare_options", return_value={"model": "test-model"}), patch.object( client, "_get_agent_reference_or_create", return_value={"name": "test-agent", "version": "1", "type": "agent_reference"}, ), ): - run_options = await client.prepare_options(messages, chat_options) + run_options = await client._prepare_options(messages, chat_options) if expects_agent: assert "extra_body" in run_options @@ -386,13 +386,13 @@ async def test_azure_ai_client_prepare_options_with_application_project_client( async def test_azure_ai_client_initialize_client(mock_project_client: MagicMock) -> None: - """Test initialize_client method.""" + """Test _initialize_client method.""" client = create_test_azure_ai_client(mock_project_client) mock_openai_client = MagicMock() mock_project_client.get_openai_client = MagicMock(return_value=mock_openai_client) - await client.initialize_client() + await client._initialize_client() assert client.client is mock_openai_client mock_project_client.get_openai_client.assert_called_once() @@ -477,6 +477,30 @@ async def test_azure_ai_client_agent_creation_with_instructions( assert call_args[1]["definition"].instructions == "Message instructions. Option instructions. " +async def test_azure_ai_client_agent_creation_with_additional_args( + mock_project_client: MagicMock, +) -> None: + """Test agent creation with additional arguments.""" + client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent") + + # Mock agent creation response + mock_agent = MagicMock() + mock_agent.name = "test-agent" + mock_agent.version = "1.0" + mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent) + + run_options = {"model": "test-model", "temperature": 0.9, "top_p": 0.8} + messages_instructions = "Message instructions. " + + await client._get_agent_reference_or_create(run_options, messages_instructions) # type: ignore + + # Verify agent was created with provided arguments + call_args = mock_project_client.agents.create_version.call_args + definition = call_args[1]["definition"] + assert definition.temperature == 0.9 + assert definition.top_p == 0.8 + + async def test_azure_ai_client_agent_creation_with_tools( mock_project_client: MagicMock, ) -> None: @@ -703,7 +727,7 @@ async def test_azure_ai_client_prepare_options_excludes_response_format( with ( patch.object( client.__class__.__bases__[0], - "prepare_options", + "_prepare_options", return_value={"model": "test-model", "response_format": ResponseFormatModel}, ), patch.object( @@ -712,7 +736,7 @@ async def test_azure_ai_client_prepare_options_excludes_response_format( return_value={"name": "test-agent", "version": "1.0", "type": "agent_reference"}, ), ): - run_options = await client.prepare_options(messages, chat_options) + run_options = await client._prepare_options(messages, chat_options) # response_format should be excluded from final run options assert "response_format" not in run_options @@ -721,94 +745,8 @@ async def test_azure_ai_client_prepare_options_excludes_response_format( assert run_options["extra_body"]["agent"]["name"] == "test-agent" -async def test_azure_ai_client_prepare_options_with_resp_conversation_id( - mock_project_client: MagicMock, -) -> None: - """Test prepare_options with conversation ID starting with 'resp_'.""" - client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0") - - messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])] - chat_options = ChatOptions(conversation_id="resp_12345") - - with ( - patch.object( - client.__class__.__bases__[0], - "prepare_options", - return_value={"model": "test-model", "previous_response_id": "old_value", "conversation": "old_conv"}, - ), - patch.object( - client, - "_get_agent_reference_or_create", - return_value={"name": "test-agent", "version": "1.0", "type": "agent_reference"}, - ), - ): - run_options = await client.prepare_options(messages, chat_options) - - # Should set previous_response_id and remove conversation property - assert run_options["previous_response_id"] == "resp_12345" - assert "conversation" not in run_options - - -async def test_azure_ai_client_prepare_options_with_conv_conversation_id( - mock_project_client: MagicMock, -) -> None: - """Test prepare_options with conversation ID starting with 'conv_'.""" - client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0") - - messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])] - chat_options = ChatOptions(conversation_id="conv_67890") - - with ( - patch.object( - client.__class__.__bases__[0], - "prepare_options", - return_value={"model": "test-model", "previous_response_id": "old_value", "conversation": "old_conv"}, - ), - patch.object( - client, - "_get_agent_reference_or_create", - return_value={"name": "test-agent", "version": "1.0", "type": "agent_reference"}, - ), - ): - run_options = await client.prepare_options(messages, chat_options) - - # Should set conversation and remove previous_response_id property - assert run_options["conversation"] == "conv_67890" - assert "previous_response_id" not in run_options - - -async def test_azure_ai_client_prepare_options_with_client_conversation_id( - mock_project_client: MagicMock, -) -> None: - """Test prepare_options using client's default conversation ID when chat options don't have one.""" - client = create_test_azure_ai_client( - mock_project_client, agent_name="test-agent", agent_version="1.0", conversation_id="resp_client_default" - ) - - messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])] - chat_options = ChatOptions() # No conversation_id specified - - with ( - patch.object( - client.__class__.__bases__[0], - "prepare_options", - return_value={"model": "test-model", "previous_response_id": "old_value", "conversation": "old_conv"}, - ), - patch.object( - client, - "_get_agent_reference_or_create", - return_value={"name": "test-agent", "version": "1.0", "type": "agent_reference"}, - ), - ): - run_options = await client.prepare_options(messages, chat_options) - - # Should use client's default conversation_id and set previous_response_id - assert run_options["previous_response_id"] == "resp_client_default" - assert "conversation" not in run_options - - def test_get_conversation_id_with_store_true_and_conversation_id() -> None: - """Test get_conversation_id returns conversation ID when store is True and conversation exists.""" + """Test _get_conversation_id returns conversation ID when store is True and conversation exists.""" client = create_test_azure_ai_client(MagicMock()) # Mock OpenAI response with conversation @@ -818,13 +756,13 @@ def test_get_conversation_id_with_store_true_and_conversation_id() -> None: mock_conversation.id = "conv_67890" mock_response.conversation = mock_conversation - result = client.get_conversation_id(mock_response, store=True) + result = client._get_conversation_id(mock_response, store=True) assert result == "conv_67890" def test_get_conversation_id_with_store_true_and_no_conversation() -> None: - """Test get_conversation_id returns response ID when store is True and no conversation exists.""" + """Test _get_conversation_id returns response ID when store is True and no conversation exists.""" client = create_test_azure_ai_client(MagicMock()) # Mock OpenAI response without conversation @@ -832,13 +770,13 @@ def test_get_conversation_id_with_store_true_and_no_conversation() -> None: mock_response.id = "resp_12345" mock_response.conversation = None - result = client.get_conversation_id(mock_response, store=True) + result = client._get_conversation_id(mock_response, store=True) assert result == "resp_12345" def test_get_conversation_id_with_store_true_and_empty_conversation_id() -> None: - """Test get_conversation_id returns response ID when store is True and conversation ID is empty.""" + """Test _get_conversation_id returns response ID when store is True and conversation ID is empty.""" client = create_test_azure_ai_client(MagicMock()) # Mock OpenAI response with conversation but empty ID @@ -848,13 +786,13 @@ def test_get_conversation_id_with_store_true_and_empty_conversation_id() -> None mock_conversation.id = "" mock_response.conversation = mock_conversation - result = client.get_conversation_id(mock_response, store=True) + result = client._get_conversation_id(mock_response, store=True) assert result == "resp_12345" def test_get_conversation_id_with_store_false() -> None: - """Test get_conversation_id returns None when store is False.""" + """Test _get_conversation_id returns None when store is False.""" client = create_test_azure_ai_client(MagicMock()) # Mock OpenAI response with conversation @@ -864,13 +802,13 @@ def test_get_conversation_id_with_store_false() -> None: mock_conversation.id = "conv_67890" mock_response.conversation = mock_conversation - result = client.get_conversation_id(mock_response, store=False) + result = client._get_conversation_id(mock_response, store=False) assert result is None def test_get_conversation_id_with_parsed_response_and_store_true() -> None: - """Test get_conversation_id works with ParsedResponse when store is True.""" + """Test _get_conversation_id works with ParsedResponse when store is True.""" client = create_test_azure_ai_client(MagicMock()) # Mock ParsedResponse with conversation @@ -880,13 +818,13 @@ def test_get_conversation_id_with_parsed_response_and_store_true() -> None: mock_conversation.id = "conv_parsed_67890" mock_response.conversation = mock_conversation - result = client.get_conversation_id(mock_response, store=True) + result = client._get_conversation_id(mock_response, store=True) assert result == "conv_parsed_67890" def test_get_conversation_id_with_parsed_response_no_conversation() -> None: - """Test get_conversation_id returns response ID with ParsedResponse when no conversation exists.""" + """Test _get_conversation_id returns response ID with ParsedResponse when no conversation exists.""" client = create_test_azure_ai_client(MagicMock()) # Mock ParsedResponse without conversation @@ -894,7 +832,7 @@ def test_get_conversation_id_with_parsed_response_no_conversation() -> None: mock_response.id = "resp_parsed_12345" mock_response.conversation = None - result = client.get_conversation_id(mock_response, store=True) + result = client._get_conversation_id(mock_response, store=True) assert result == "resp_parsed_12345" diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 7d8ebe0264..74462c8441 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -414,7 +414,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien request_response_format, ) logger.debug("Signalling entity %s with request: %s", entity_instance_id, run_request) - await client.signal_entity(entity_instance_id, "run_agent", run_request) + await client.signal_entity(entity_instance_id, "run", run_request) logger.debug(f"[HTTP Trigger] Signal sent to entity {session_id}") @@ -495,7 +495,8 @@ def entity_function(context: df.DurableEntityContext) -> None: """Durable entity that manages agent execution and conversation state. Operations: - - run_agent: Execute the agent with a message + - run: Execute the agent with a message + - run_agent: (Deprecated) Execute the agent with a message - reset: Clear conversation history """ entity_handler = create_agent_entity(agent, callback) @@ -637,7 +638,7 @@ async def _handle_mcp_tool_invocation( logger.info("[MCP Tool] Invoking agent '%s' with query: %s", agent_name, query_preview) # Signal entity to run agent - await client.signal_entity(entity_instance_id, "run_agent", run_request) + await client.signal_entity(entity_instance_id, "run", run_request) # Poll for response (similar to HTTP handler) try: diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py index 45872ce1a1..2cc86c1b65 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py @@ -46,7 +46,8 @@ class AgentEntity: - Handles tool execution Operations: - - run_agent: Execute the agent with a message + - run: Execute the agent with a message + - run_agent: (Deprecated) Execute the agent with a message - reset: Clear conversation history Attributes: @@ -94,6 +95,22 @@ async def run_agent( self, context: df.DurableEntityContext, request: RunRequest | dict[str, Any] | str, + ) -> AgentRunResponse: + """(Deprecated) Execute the agent with a message directly in the entity. + + Args: + context: Entity context + request: RunRequest object, dict, or string message (for backward compatibility) + + Returns: + AgentRunResponse enriched with execution metadata. + """ + return await self.run(context, request) + + async def run( + self, + context: df.DurableEntityContext, + request: RunRequest | dict[str, Any] | str, ) -> AgentRunResponse: """Execute the agent with a message directly in the entity. @@ -124,7 +141,7 @@ async def run_agent( state_request = DurableAgentStateRequest.from_run_request(run_request) self.state.data.conversation_history.append(state_request) - logger.debug(f"[AgentEntity.run_agent] Received Message: {state_request}") + logger.debug(f"[AgentEntity.run] Received Message: {state_request}") try: # Build messages from conversation history, excluding error responses @@ -150,7 +167,7 @@ async def run_agent( ) logger.debug( - "[AgentEntity.run_agent] Agent invocation completed - response type: %s", + "[AgentEntity.run] Agent invocation completed - response type: %s", type(agent_run_response).__name__, ) @@ -167,12 +184,12 @@ async def run_agent( state_response = DurableAgentStateResponse.from_run_response(correlation_id, agent_run_response) self.state.data.conversation_history.append(state_response) - logger.debug("[AgentEntity.run_agent] AgentRunResponse stored in conversation history") + logger.debug("[AgentEntity.run] AgentRunResponse stored in conversation history") return agent_run_response except Exception as exc: - logger.exception("[AgentEntity.run_agent] Agent execution failed.") + logger.exception("[AgentEntity.run] Agent execution failed.") # Create error message error_message = ChatMessage( @@ -367,7 +384,7 @@ async def _entity_coroutine(context: df.DurableEntityContext) -> None: operation = context.operation_name - if operation == "run_agent": + if operation == "run" or operation == "run_agent": input_data: Any = context.get_input() request: str | dict[str, Any] @@ -377,7 +394,7 @@ async def _entity_coroutine(context: df.DurableEntityContext) -> None: # Fall back to treating input as message string request = "" if input_data is None else str(cast(object, input_data)) - result = await entity.run_agent(context, request) + result = await entity.run(context, request) context.set_result(result.to_dict()) elif operation == "reset": diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index d5adc68d74..4cef22e023 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -285,7 +285,7 @@ def my_orchestration(context): logger.debug("[DurableAIAgent] Calling entity %s with message: %s", entity_id, message_str[:100]) # Call the entity to get the underlying task - entity_task = self.context.call_entity(entity_id, "run_agent", run_request.to_dict()) + entity_task = self.context.call_entity(entity_id, "run", run_request.to_dict()) # Wrap it in an AgentTask that will convert the result to AgentRunResponse agent_task = AgentTask( diff --git a/python/packages/azurefunctions/pyproject.toml b/python/packages/azurefunctions/pyproject.toml index e6363a65f8..17320b92e4 100644 --- a/python/packages/azurefunctions/pyproject.toml +++ b/python/packages/azurefunctions/pyproject.toml @@ -4,7 +4,7 @@ description = "Azure Functions integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251211" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/azurefunctions/tests/integration_tests/README.md b/python/packages/azurefunctions/tests/integration_tests/README.md index d9ecb86234..a7f9fadc44 100644 --- a/python/packages/azurefunctions/tests/integration_tests/README.md +++ b/python/packages/azurefunctions/tests/integration_tests/README.md @@ -29,7 +29,7 @@ docker run -d -p 10000:10000 -p 10001:10001 -p 10002:10002 mcr.microsoft.com/azu **Durable Task Scheduler:** ```bash -docker run -d -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest +docker run -d -p 8080:8080 -p 8082:8082 -e DTS_USE_DYNAMIC_TASK_HUBS=true mcr.microsoft.com/dts/dts-emulator:latest ``` ## Running Tests diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 817a81e856..6937b3e0f5 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -338,7 +338,7 @@ async def test_entity_run_agent_operation(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, {"message": "Test message", "thread_id": "test-conv-123", "correlationId": "corr-app-entity-1"}, ) @@ -358,7 +358,7 @@ async def test_entity_stores_conversation_history(self) -> None: mock_context = Mock() # Send first message - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-app-entity-2"} ) @@ -367,7 +367,7 @@ async def test_entity_stores_conversation_history(self) -> None: assert len(history) == 1 # Just the user message # Send second message - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-2", "correlationId": "corr-app-entity-2b"} ) @@ -398,12 +398,12 @@ async def test_entity_increments_message_count(self) -> None: assert len(entity.state.data.conversation_history) == 0 - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-app-entity-3a"} ) assert len(entity.state.data.conversation_history) == 2 - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-app-entity-3b"} ) assert len(entity.state.data.conversation_history) == 4 @@ -433,8 +433,36 @@ def test_create_agent_entity_returns_function(self) -> None: assert callable(entity_function) + def test_entity_function_handles_run_operation(self) -> None: + """Test that the entity function handles the run operation.""" + mock_agent = Mock() + mock_agent.run = AsyncMock( + return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")]) + ) + + entity_function = create_agent_entity(mock_agent) + + # Mock context + mock_context = Mock() + mock_context.operation_name = "run" + mock_context.get_input.return_value = { + "message": "Test message", + "thread_id": "conv-123", + "correlationId": "corr-app-factory-1", + } + mock_context.get_state.return_value = None + + # Execute entity function + entity_function(mock_context) + + # Verify result was set + assert mock_context.set_result.called + assert mock_context.set_state.called + result_call = mock_context.set_result.call_args[0][0] + assert "error" not in result_call + def test_entity_function_handles_run_agent_operation(self) -> None: - """Test that the entity function handles the run_agent operation.""" + """Test that the entity function handles the deprecated run_agent operation for backward compatibility.""" mock_agent = Mock() mock_agent.run = AsyncMock( return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")]) @@ -458,6 +486,8 @@ def test_entity_function_handles_run_agent_operation(self) -> None: # Verify result was set assert mock_context.set_result.called assert mock_context.set_state.called + result_call = mock_context.set_result.call_args[0][0] + assert "error" not in result_call def test_entity_function_handles_reset_operation(self) -> None: """Test that the entity function handles the reset operation.""" @@ -585,7 +615,7 @@ async def test_entity_handles_agent_error(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, {"message": "Test message", "thread_id": "conv-1", "correlationId": "corr-app-error-1"} ) @@ -605,7 +635,7 @@ def test_entity_function_handles_exception(self) -> None: entity_function = create_agent_entity(mock_agent) mock_context = Mock() - mock_context.operation_name = "run_agent" + mock_context.operation_name = "run" mock_context.get_input.side_effect = Exception("Input error") mock_context.get_state.return_value = None diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py index 66f39861a1..dcea75aa1c 100644 --- a/python/packages/azurefunctions/tests/test_entities.py +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -108,6 +108,33 @@ def test_init_with_different_agent_types(self) -> None: class TestAgentEntityRunAgent: """Test suite for the run_agent operation.""" + async def test_run_executes_agent(self) -> None: + """Test that run executes the agent.""" + mock_agent = Mock() + mock_response = _agent_response("Test response") + mock_agent.run = AsyncMock(return_value=mock_response) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + result = await entity.run( + mock_context, {"message": "Test message", "thread_id": "conv-123", "correlationId": "corr-entity-1"} + ) + + # Verify agent.run was called + mock_agent.run.assert_called_once() + _, kwargs = mock_agent.run.call_args + sent_messages: list[Any] = kwargs.get("messages") + assert len(sent_messages) == 1 + sent_message = sent_messages[0] + assert isinstance(sent_message, ChatMessage) + assert getattr(sent_message, "text", None) == "Test message" + assert getattr(sent_message.role, "value", sent_message.role) == "user" + + # Verify result + assert isinstance(result, AgentRunResponse) + assert result.text == "Test response" + async def test_run_agent_executes_agent(self) -> None: """Test that run_agent executes the agent.""" mock_agent = Mock() @@ -156,7 +183,7 @@ async def update_generator() -> AsyncIterator[AgentRunResponseUpdate]: entity = AgentEntity(mock_agent, callback=callback) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, { "message": "Tell me something", @@ -203,7 +230,7 @@ async def test_run_agent_final_callback_without_streaming(self) -> None: entity = AgentEntity(mock_agent, callback=callback) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, { "message": "Hi", @@ -235,7 +262,7 @@ async def test_run_agent_updates_conversation_history(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - await entity.run_agent( + await entity.run( mock_context, {"message": "User message", "thread_id": "conv-1", "correlationId": "corr-entity-2"} ) @@ -263,17 +290,17 @@ async def test_run_agent_increments_message_count(self) -> None: assert len(entity.state.data.conversation_history) == 0 - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-3a"} ) assert len(entity.state.data.conversation_history) == 2 - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-3b"} ) assert len(entity.state.data.conversation_history) == 4 - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-3c"} ) assert len(entity.state.data.conversation_history) == 6 @@ -287,9 +314,7 @@ async def test_run_agent_with_none_thread_id(self) -> None: mock_context = Mock() with pytest.raises(ValueError, match="thread_id"): - await entity.run_agent( - mock_context, {"message": "Message", "thread_id": None, "correlationId": "corr-entity-5"} - ) + await entity.run(mock_context, {"message": "Message", "thread_id": None, "correlationId": "corr-entity-5"}) async def test_run_agent_multiple_conversations(self) -> None: """Test that run_agent maintains history across multiple messages.""" @@ -300,13 +325,13 @@ async def test_run_agent_multiple_conversations(self) -> None: mock_context = Mock() # Send multiple messages - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-8a"} ) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-8b"} ) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-8c"} ) @@ -374,10 +399,10 @@ async def test_reset_after_conversation(self) -> None: mock_context = Mock() # Have a conversation - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-10a"} ) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-10b"} ) @@ -413,7 +438,7 @@ def test_entity_function_handles_run_agent(self) -> None: # Mock context mock_context = Mock() - mock_context.operation_name = "run_agent" + mock_context.operation_name = "run" mock_context.get_input.return_value = { "message": "Test message", "thread_id": "conv-123", @@ -576,7 +601,7 @@ async def test_run_agent_handles_agent_exception(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-1"} ) @@ -595,7 +620,7 @@ async def test_run_agent_handles_value_error(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-2"} ) @@ -614,7 +639,7 @@ async def test_run_agent_handles_timeout_error(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-3"} ) @@ -631,7 +656,7 @@ def test_entity_function_handles_exception_in_operation(self) -> None: entity_function = create_agent_entity(mock_agent) mock_context = Mock() - mock_context.operation_name = "run_agent" + mock_context.operation_name = "run" mock_context.get_input.side_effect = Exception("Input error") mock_context.get_state.return_value = None @@ -651,7 +676,7 @@ async def test_run_agent_preserves_message_on_error(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, {"message": "Test message", "thread_id": "conv-123", "correlationId": "corr-entity-error-4"}, ) @@ -674,7 +699,7 @@ async def test_conversation_history_has_timestamps(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - await entity.run_agent( + await entity.run( mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-history-1"} ) @@ -694,19 +719,19 @@ async def test_conversation_history_ordering(self) -> None: # Send multiple messages with different responses mock_agent.run = AsyncMock(return_value=_agent_response("Response 1")) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-history-2a"}, ) mock_agent.run = AsyncMock(return_value=_agent_response("Response 2")) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-history-2b"}, ) mock_agent.run = AsyncMock(return_value=_agent_response("Response 3")) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-history-2c"}, ) @@ -729,11 +754,11 @@ async def test_conversation_history_role_alternation(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-history-3a"}, ) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-history-3b"}, ) @@ -766,7 +791,7 @@ async def test_run_agent_with_run_request_object(self) -> None: correlation_id="corr-runreq-1", ) - result = await entity.run_agent(mock_context, request) + result = await entity.run(mock_context, request) assert isinstance(result, AgentRunResponse) assert result.text == "Response" @@ -787,7 +812,7 @@ async def test_run_agent_with_dict_request(self) -> None: "correlationId": "corr-runreq-2", } - result = await entity.run_agent(mock_context, request_dict) + result = await entity.run(mock_context, request_dict) assert isinstance(result, AgentRunResponse) assert result.text == "Response" @@ -801,7 +826,7 @@ async def test_run_agent_with_string_raises_without_correlation(self) -> None: mock_context = Mock() with pytest.raises(ValueError): - await entity.run_agent(mock_context, "Simple message") + await entity.run(mock_context, "Simple message") async def test_run_agent_stores_role_in_history(self) -> None: """Test that run_agent stores the role in conversation history.""" @@ -819,7 +844,7 @@ async def test_run_agent_stores_role_in_history(self) -> None: correlation_id="corr-runreq-3", ) - await entity.run_agent(mock_context, request) + await entity.run(mock_context, request) # Check that system role was stored history = entity.state.data.conversation_history @@ -842,7 +867,7 @@ async def test_run_agent_with_response_format(self) -> None: correlation_id="corr-runreq-4", ) - result = await entity.run_agent(mock_context, request) + result = await entity.run(mock_context, request) assert isinstance(result, AgentRunResponse) assert result.text == '{"answer": 42}' @@ -860,7 +885,7 @@ async def test_run_agent_disable_tool_calls(self) -> None: message="Test", thread_id="conv-runreq-5", enable_tool_calls=False, correlation_id="corr-runreq-5" ) - result = await entity.run_agent(mock_context, request) + result = await entity.run(mock_context, request) assert isinstance(result, AgentRunResponse) # Agent should have been called (tool disabling is framework-dependent) @@ -874,7 +899,7 @@ async def test_entity_function_with_run_request_dict(self) -> None: entity_function = create_agent_entity(mock_agent) mock_context = Mock() - mock_context.operation_name = "run_agent" + mock_context.operation_name = "run" mock_context.get_input.return_value = { "message": "Test message", "thread_id": "conv-789", diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py index 0f845d4105..b0dd313b0b 100644 --- a/python/packages/azurefunctions/tests/test_orchestration.py +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -295,7 +295,7 @@ def test_run_creates_entity_call(self) -> None: call_args = mock_context.call_entity.call_args entity_id, operation, request = call_args[0] - assert operation == "run_agent" + assert operation == "run" assert request["message"] == "Test message" assert request["enable_tool_calls"] is True assert "correlationId" in request diff --git a/python/packages/chatkit/pyproject.toml b/python/packages/chatkit/pyproject.toml index 6cd6a16146..8c73a2ffb0 100644 --- a/python/packages/chatkit/pyproject.toml +++ b/python/packages/chatkit/pyproject.toml @@ -4,7 +4,7 @@ description = "OpenAI ChatKit integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251211" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/copilotstudio/pyproject.toml b/python/packages/copilotstudio/pyproject.toml index 955fb2cbcd..c8517c6eea 100644 --- a/python/packages/copilotstudio/pyproject.toml +++ b/python/packages/copilotstudio/pyproject.toml @@ -4,7 +4,7 @@ description = "Copilot Studio integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251211" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 3c40004362..aadd1be40a 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -34,7 +34,7 @@ ToolMode, ) from .exceptions import AgentExecutionException, AgentInitializationError -from .observability import use_agent_observability +from .observability import use_agent_instrumentation if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover @@ -516,8 +516,8 @@ def _prepare_context_providers( @use_agent_middleware -@use_agent_observability -class ChatAgent(BaseAgent): +@use_agent_instrumentation(capture_usage=False) # type: ignore[arg-type,misc] +class ChatAgent(BaseAgent): # type: ignore[misc] """A Chat Client Agent. This is the primary agent implementation that uses a chat client to interact @@ -583,7 +583,7 @@ def get_weather(location: str) -> str: print(update.text, end="") """ - AGENT_SYSTEM_NAME: ClassVar[str] = "microsoft.agent_framework" + AGENT_PROVIDER_NAME: ClassVar[str] = "microsoft.agent_framework" def __init__( self, diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 4d91492822..bfb2c3f7d4 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -8,7 +8,6 @@ from pydantic import BaseModel from ._logging import get_logger -from ._mcp import MCPTool from ._memory import AggregateContextProvider, ContextProvider from ._middleware import ( ChatMiddleware, @@ -426,6 +425,8 @@ async def _normalize_tools( else [tools] ) for tool in tools_list: # type: ignore[reportUnknownType] + from ._mcp import MCPTool + if isinstance(tool, MCPTool): if not tool.is_connected: await tool.connect() @@ -500,7 +501,7 @@ async def get_response( stop: str | Sequence[str] | None = None, store: bool | None = None, temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, + tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", tools: ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] @@ -534,6 +535,7 @@ async def get_response( store: Whether to store the response. temperature: The sampling temperature to use. tool_choice: The tool choice for the request. + Default is `auto`. tools: The tools to use for the request. top_p: The nucleus sampling probability to use. user: The user to associate with the request. @@ -594,7 +596,7 @@ async def get_streaming_response( stop: str | Sequence[str] | None = None, store: bool | None = None, temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, + tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", tools: ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] @@ -628,6 +630,7 @@ async def get_streaming_response( store: Whether to store the response. temperature: The sampling temperature to use. tool_choice: The tool choice for the request. + Default is `auto`. tools: The tools to use for the request. top_p: The nucleus sampling probability to use. user: The user to associate with the request. diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 721af6210c..a25f359a59 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -63,21 +63,21 @@ ] -def _mcp_prompt_message_to_chat_message( +def _parse_message_from_mcp( mcp_type: types.PromptMessage | types.SamplingMessage, ) -> ChatMessage: - """Convert a MCP container type to a Agent Framework type.""" + """Parse an MCP container type into an Agent Framework type.""" return ChatMessage( role=Role(value=mcp_type.role), - contents=_mcp_type_to_ai_content(mcp_type.content), + contents=_parse_content_from_mcp(mcp_type.content), raw_representation=mcp_type, ) -def _mcp_call_tool_result_to_ai_contents( +def _parse_contents_from_mcp_tool_result( mcp_type: types.CallToolResult, ) -> list[Contents]: - """Convert a MCP container type to a Agent Framework type. + """Parse an MCP CallToolResult into Agent Framework content types. This function extracts the complete _meta field from CallToolResult objects and merges all metadata into the additional_properties field of converted @@ -111,7 +111,7 @@ def _mcp_call_tool_result_to_ai_contents( # Convert each content item and merge metadata result_contents = [] for item in mcp_type.content: - contents = _mcp_type_to_ai_content(item) + contents = _parse_content_from_mcp(item) if merged_meta_props: for content in contents: @@ -124,7 +124,7 @@ def _mcp_call_tool_result_to_ai_contents( return result_contents -def _mcp_type_to_ai_content( +def _parse_content_from_mcp( mcp_type: types.ImageContent | types.TextContent | types.AudioContent @@ -142,7 +142,7 @@ def _mcp_type_to_ai_content( | types.ToolResultContent ], ) -> list[Contents]: - """Convert a MCP type to a Agent Framework type.""" + """Parse an MCP type into an Agent Framework type.""" mcp_types = mcp_type if isinstance(mcp_type, Sequence) else [mcp_type] return_types: list[Contents] = [] for mcp_type in mcp_types: @@ -152,7 +152,7 @@ def _mcp_type_to_ai_content( case types.ImageContent() | types.AudioContent(): return_types.append( DataContent( - uri=mcp_type.data, + data=mcp_type.data, media_type=mcp_type.mimeType, raw_representation=mcp_type, ) @@ -178,7 +178,7 @@ def _mcp_type_to_ai_content( return_types.append( FunctionResultContent( call_id=mcp_type.toolUseId, - result=_mcp_type_to_ai_content(mcp_type.content) + result=_parse_content_from_mcp(mcp_type.content) if mcp_type.content else mcp_type.structuredContent, exception=Exception() if mcp_type.isError else None, @@ -211,10 +211,10 @@ def _mcp_type_to_ai_content( return return_types -def _ai_content_to_mcp_types( +def _prepare_content_for_mcp( content: Contents, ) -> types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink | None: - """Convert a BaseContent type to a MCP type.""" + """Prepare an Agent Framework content type for MCP.""" match content: case TextContent(): return types.TextContent(type="text", text=content.text) @@ -253,15 +253,15 @@ def _ai_content_to_mcp_types( return None -def _chat_message_to_mcp_types( +def _prepare_message_for_mcp( content: ChatMessage, ) -> list[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink]: - """Convert a ChatMessage to a list of MCP types.""" + """Prepare a ChatMessage for MCP format.""" messages: list[ types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink ] = [] for item in content.contents: - mcp_content = _ai_content_to_mcp_types(item) + mcp_content = _prepare_content_for_mcp(item) if mcp_content: messages.append(mcp_content) return messages @@ -469,7 +469,7 @@ async def sampling_callback( logger.debug("Sampling callback called with params: %s", params) messages: list[ChatMessage] = [] for msg in params.messages: - messages.append(_mcp_prompt_message_to_chat_message(msg)) + messages.append(_parse_message_from_mcp(msg)) try: response = await self.chat_client.get_response( messages, @@ -487,7 +487,7 @@ async def sampling_callback( code=types.INTERNAL_ERROR, message="Failed to get chat message content.", ) - mcp_contents = _chat_message_to_mcp_types(response.messages[0]) + mcp_contents = _prepare_message_for_mcp(response.messages[0]) # grab the first content that is of type TextContent or ImageContent mcp_content = next( (content for content in mcp_contents if isinstance(content, (types.TextContent, types.ImageContent))), @@ -685,8 +685,16 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents]: raise ToolExecutionException( "Tools are not loaded for this server, please set load_tools=True in the constructor." ) + # Filter out framework kwargs that cannot be serialized by the MCP SDK. + # These are internal objects passed through the function invocation pipeline + # that should not be forwarded to external MCP servers. + filtered_kwargs = { + k: v for k, v in kwargs.items() if k not in {"chat_options", "tools", "tool_choice", "thread"} + } try: - return _mcp_call_tool_result_to_ai_contents(await self.session.call_tool(tool_name, arguments=kwargs)) + return _parse_contents_from_mcp_tool_result( + await self.session.call_tool(tool_name, arguments=filtered_kwargs) + ) except McpError as mcp_exc: raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc except Exception as ex: @@ -716,7 +724,7 @@ async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage] ) try: prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs) - return [_mcp_prompt_message_to_chat_message(message) for message in prompt_result.messages] + return [_parse_message_from_mcp(message) for message in prompt_result.messages] except McpError as mcp_exc: raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc except Exception as ex: diff --git a/python/packages/core/agent_framework/_memory.py b/python/packages/core/agent_framework/_memory.py index 4b2a01ad24..a5b53fc39f 100644 --- a/python/packages/core/agent_framework/_memory.py +++ b/python/packages/core/agent_framework/_memory.py @@ -6,11 +6,13 @@ from collections.abc import MutableSequence, Sequence from contextlib import AsyncExitStack from types import TracebackType -from typing import Any, Final, cast +from typing import TYPE_CHECKING, Any, Final, cast -from ._tools import ToolProtocol from ._types import ChatMessage +if TYPE_CHECKING: + from ._tools import ToolProtocol + if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: @@ -54,7 +56,7 @@ def __init__( self, instructions: str | None = None, messages: Sequence[ChatMessage] | None = None, - tools: Sequence[ToolProtocol] | None = None, + tools: Sequence["ToolProtocol"] | None = None, ): """Create a new Context object. @@ -65,7 +67,7 @@ def __init__( """ self.instructions = instructions self.messages: Sequence[ChatMessage] = messages or [] - self.tools: Sequence[ToolProtocol] = tools or [] + self.tools: Sequence["ToolProtocol"] = tools or [] # region ContextProvider @@ -247,7 +249,7 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], * contexts = await asyncio.gather(*[provider.invoking(messages, **kwargs) for provider in self.providers]) instructions: str = "" return_messages: list[ChatMessage] = [] - tools: list[ToolProtocol] = [] + tools: list["ToolProtocol"] = [] for ctx in contexts: if ctx.instructions: instructions += ctx.instructions diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 9bb730ba62..4e36cb764a 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -1405,13 +1405,17 @@ async def _stream_generator() -> Any: call_middleware = kwargs.pop("middleware", None) instance_middleware = getattr(self, "middleware", None) - # Merge middleware from both sources, filtering for chat middleware only - all_middleware: list[ChatMiddleware | ChatMiddlewareCallable] = _merge_and_filter_chat_middleware( - instance_middleware, call_middleware - ) + # Merge all middleware and separate by type + middleware = categorize_middleware(instance_middleware, call_middleware) + chat_middleware_list = middleware["chat"] + function_middleware_list = middleware["function"] + + # Pass function middleware to function invocation system if present + if function_middleware_list: + kwargs["_function_middleware_pipeline"] = FunctionMiddlewarePipeline(function_middleware_list) - # If no middleware, use original method - if not all_middleware: + # If no chat middleware, use original method + if not chat_middleware_list: async for update in original_get_streaming_response(self, messages, **kwargs): yield update return @@ -1422,7 +1426,7 @@ async def _stream_generator() -> Any: # Extract chat_options or create default chat_options = kwargs.pop("chat_options", ChatOptions()) - pipeline = ChatMiddlewarePipeline(all_middleware) # type: ignore[arg-type] + pipeline = ChatMiddlewarePipeline(chat_middleware_list) # type: ignore[arg-type] context = ChatContext( chat_client=self, messages=prepare_messages(messages), @@ -1536,27 +1540,40 @@ def _merge_and_filter_chat_middleware( return middleware["chat"] # type: ignore[return-value] -def extract_and_merge_function_middleware(chat_client: Any, **kwargs: Any) -> None: +def extract_and_merge_function_middleware( + chat_client: Any, kwargs: dict[str, Any] +) -> "FunctionMiddlewarePipeline | None": """Extract function middleware from chat client and merge with existing pipeline in kwargs. Args: chat_client: The chat client instance to extract middleware from. + kwargs: Dictionary containing middleware and pipeline information. - Keyword Args: - **kwargs: Dictionary containing middleware and pipeline information. + Returns: + A FunctionMiddlewarePipeline if function middleware is found, None otherwise. """ + # Check if a pipeline was already created by use_chat_middleware + existing_pipeline: FunctionMiddlewarePipeline | None = kwargs.get("_function_middleware_pipeline") + # Get middleware sources client_middleware = getattr(chat_client, "middleware", None) if hasattr(chat_client, "middleware") else None run_level_middleware = kwargs.get("middleware") - existing_pipeline = kwargs.get("_function_middleware_pipeline") - # Extract existing pipeline middlewares if present - existing_middlewares = existing_pipeline._middlewares if existing_pipeline else None + # If we have an existing pipeline but no additional middleware sources, return it directly + if existing_pipeline and not client_middleware and not run_level_middleware: + return existing_pipeline + + # If we have an existing pipeline with additional middleware, we need to merge + # Extract existing pipeline middlewares if present - cast to list[Middleware] for type compatibility + existing_middlewares: list[Middleware] | None = list(existing_pipeline._middlewares) if existing_pipeline else None # Create combined pipeline from all sources using existing helper combined_pipeline = create_function_middleware_pipeline( client_middleware, run_level_middleware, existing_middlewares ) - if combined_pipeline: - kwargs["_function_middleware_pipeline"] = combined_pipeline + # If we have an existing pipeline but combined is None (no new middlewares), return existing + if existing_pipeline and combined_pipeline is None: + return existing_pipeline + + return combined_pipeline diff --git a/python/packages/core/agent_framework/_serialization.py b/python/packages/core/agent_framework/_serialization.py index 1a38d9030a..cf28df2f4f 100644 --- a/python/packages/core/agent_framework/_serialization.py +++ b/python/packages/core/agent_framework/_serialization.py @@ -339,11 +339,17 @@ def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) continue # Handle dicts containing SerializationProtocol values if isinstance(value, dict): + from datetime import date, datetime, time + serialized_dict: dict[str, Any] = {} for k, v in value.items(): if isinstance(v, SerializationProtocol): serialized_dict[k] = v.to_dict(exclude=exclude, exclude_none=exclude_none) continue + # Convert datetime objects to strings + if isinstance(v, (datetime, date, time)): + serialized_dict[k] = str(v) + continue # Check if the value is JSON serializable if is_serializable(v): serialized_dict[k] = v diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index bc16d9edb9..dd14f0dec8 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -886,6 +886,8 @@ def _parse_annotation(annotation: Any) -> Any: If the second annotation (after the type) is a string, then we convert that to a Pydantic Field description. The rest are returned as-is, allowing for multiple annotations. + Literal types are returned as-is to preserve their enum-like values. + Args: annotation: The type annotation to parse. @@ -894,6 +896,12 @@ def _parse_annotation(annotation: Any) -> Any: """ origin = get_origin(annotation) if origin is not None: + # Literal types should be returned as-is - their args are the allowed values, + # not type annotations to be parsed. For example, Literal["Data", "Security"] + # has args ("Data", "Security") which are the valid string values. + if origin is Literal: + return annotation + args = get_args(annotation) # For other generics, return the origin type (e.g., list for List[int]) if len(args) > 1 and isinstance(args[1], str): @@ -1348,6 +1356,35 @@ def __init__( self.include_detailed_errors = include_detailed_errors +class FunctionExecutionResult: + """Internal wrapper pairing function output with loop control signals. + + Function execution produces two distinct concerns: the semantic result (returned to + the LLM as FunctionResultContent) and control flow decisions (whether middleware + requested early termination). This wrapper keeps control signals out of user-facing + content types while allowing _try_execute_function_calls to communicate both. + + Not exposed to users. + + Attributes: + content: The FunctionResultContent or other content from the function execution. + terminate: If True, the function invocation loop should exit immediately without + another LLM call. Set when middleware sets context.terminate=True. + """ + + __slots__ = ("content", "terminate") + + def __init__(self, content: "Contents", terminate: bool = False) -> None: + """Initialize FunctionExecutionResult. + + Args: + content: The content from the function execution. + terminate: Whether to terminate the function calling loop. + """ + self.content = content + self.terminate = terminate + + async def _auto_invoke_function( function_call_content: "FunctionCallContent | FunctionApprovalResponseContent", custom_args: dict[str, Any] | None = None, @@ -1357,7 +1394,7 @@ async def _auto_invoke_function( sequence_index: int | None = None, request_index: int | None = None, middleware_pipeline: Any = None, # Optional MiddlewarePipeline -) -> "Contents": +) -> "FunctionExecutionResult | Contents": """Invoke a function call requested by the agent, applying middleware that is defined. Args: @@ -1372,7 +1409,8 @@ async def _auto_invoke_function( middleware_pipeline: Optional middleware pipeline to apply during execution. Returns: - A FunctionResultContent containing the result or exception. + A FunctionExecutionResult wrapping the content and terminate signal, + or a Contents object for approval/hosted tool scenarios. Raises: KeyError: If the requested function is not found in the tool map. @@ -1392,10 +1430,12 @@ async def _auto_invoke_function( # Tool should exist because _try_execute_function_calls validates this if tool is None: exc = KeyError(f'Function "{function_call_content.name}" not found.') - return FunctionResultContent( - call_id=function_call_content.call_id, - result=f'Error: Requested function "{function_call_content.name}" not found.', - exception=exc, + return FunctionExecutionResult( + content=FunctionResultContent( + call_id=function_call_content.call_id, + result=f'Error: Requested function "{function_call_content.name}" not found.', + exception=exc, + ) ) else: # Note: Unapproved tools (approved=False) are handled in _replace_approval_contents_with_results @@ -1420,7 +1460,9 @@ async def _auto_invoke_function( message = "Error: Argument parsing failed." if config.include_detailed_errors: message = f"{message} Exception: {exc}" - return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + return FunctionExecutionResult( + content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + ) if not middleware_pipeline or ( not hasattr(middleware_pipeline, "has_middlewares") and not middleware_pipeline.has_middlewares @@ -1432,15 +1474,19 @@ async def _auto_invoke_function( tool_call_id=function_call_content.call_id, **runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {}, ) - return FunctionResultContent( - call_id=function_call_content.call_id, - result=function_result, + return FunctionExecutionResult( + content=FunctionResultContent( + call_id=function_call_content.call_id, + result=function_result, + ) ) except Exception as exc: message = "Error: Function failed." if config.include_detailed_errors: message = f"{message} Exception: {exc}" - return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + return FunctionExecutionResult( + content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + ) # Execute through middleware pipeline if available from ._middleware import FunctionInvocationContext @@ -1464,15 +1510,20 @@ async def final_function_handler(context_obj: Any) -> Any: context=middleware_context, final_handler=final_function_handler, ) - return FunctionResultContent( - call_id=function_call_content.call_id, - result=function_result, + return FunctionExecutionResult( + content=FunctionResultContent( + call_id=function_call_content.call_id, + result=function_result, + ), + terminate=middleware_context.terminate, ) except Exception as exc: message = "Error: Function failed." if config.include_detailed_errors: message = f"{message} Exception: {exc}" - return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + return FunctionExecutionResult( + content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + ) def _get_tool_map( @@ -1503,7 +1554,7 @@ async def _try_execute_function_calls( | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]", config: FunctionInvocationConfiguration, middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports -) -> Sequence["Contents"]: +) -> tuple[Sequence["Contents"], bool]: """Execute multiple function calls concurrently. Args: @@ -1515,9 +1566,11 @@ async def _try_execute_function_calls( middleware_pipeline: Optional middleware pipeline to apply during execution. Returns: - A list of Contents containing the results of each function call, - or the approval requests if any function requires approval, - or the original function calls if any are declaration only. + A tuple of: + - A list of Contents containing the results of each function call, + or the approval requests if any function requires approval, + or the original function calls if any are declaration only. + - A boolean indicating whether to terminate the function calling loop. """ from ._types import FunctionApprovalRequestContent, FunctionCallContent @@ -1540,17 +1593,20 @@ async def _try_execute_function_calls( raise KeyError(f'Error: Requested function "{fcc.name}" not found.') if approval_needed: # approval can only be needed for Function Call Contents, not Approval Responses. - return [ - FunctionApprovalRequestContent(id=fcc.call_id, function_call=fcc) - for fcc in function_calls - if isinstance(fcc, FunctionCallContent) - ] + return ( + [ + FunctionApprovalRequestContent(id=fcc.call_id, function_call=fcc) + for fcc in function_calls + if isinstance(fcc, FunctionCallContent) + ], + False, + ) if declaration_only_flag: # return the declaration only tools to the user, since we cannot execute them. - return [fcc for fcc in function_calls if isinstance(fcc, FunctionCallContent)] + return ([fcc for fcc in function_calls if isinstance(fcc, FunctionCallContent)], False) # Run all function calls concurrently - return await asyncio.gather(*[ + execution_results = await asyncio.gather(*[ _auto_invoke_function( function_call_content=function_call, # type: ignore[arg-type] custom_args=custom_args, @@ -1563,6 +1619,20 @@ async def _try_execute_function_calls( for seq_idx, function_call in enumerate(function_calls) ]) + # Unpack FunctionExecutionResult wrappers and check for terminate signal + contents: list[Contents] = [] + should_terminate = False + for result in execution_results: + if isinstance(result, FunctionExecutionResult): + contents.append(result.content) + if result.terminate: + should_terminate = True + else: + # Direct Contents (e.g., from hosted tools) + contents.append(result) + + return (contents, should_terminate) + def _update_conversation_id(kwargs: dict[str, Any], conversation_id: str | None) -> None: """Update kwargs with conversation id. @@ -1695,12 +1765,8 @@ async def function_invocation_wrapper( prepare_messages, ) - # Extract and merge function middleware from chat client with kwargs pipeline - extract_and_merge_function_middleware(self, **kwargs) - - # Extract the middleware pipeline before calling the underlying function - # because the underlying function may not preserve it in kwargs - stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline") + # Extract and merge function middleware from chat client with kwargs + stored_middleware_pipeline = extract_and_merge_function_middleware(self, kwargs) # Get the config for function invocation (not part of ChatClientProtocol, hence getattr) config: FunctionInvocationConfiguration | None = getattr(self, "function_invocation_configuration", None) @@ -1713,11 +1779,6 @@ async def function_invocation_wrapper( response: "ChatResponse | None" = None fcc_messages: "list[ChatMessage]" = [] - # If tools are provided but tool_choice is not set, default to "auto" for function invocation - tools = _extract_tools(kwargs) - if tools and kwargs.get("tool_choice") is None: - kwargs["tool_choice"] = "auto" - for attempt_idx in range(config.max_iterations if config.enabled else 0): fcc_todo = _collect_approval_responses(prepped_messages) if fcc_todo: @@ -1726,7 +1787,7 @@ async def function_invocation_wrapper( approved_responses = [resp for resp in fcc_todo.values() if resp.approved] approved_function_results: list[Contents] = [] if approved_responses: - approved_function_results = await _try_execute_function_calls( + results, _ = await _try_execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=approved_responses, @@ -1734,6 +1795,7 @@ async def function_invocation_wrapper( middleware_pipeline=stored_middleware_pipeline, config=config, ) + approved_function_results = list(results) if any( fcr.exception is not None for fcr in approved_function_results @@ -1773,7 +1835,7 @@ async def function_invocation_wrapper( if function_calls and tools: # Use the stored middleware pipeline instead of extracting from kwargs # because kwargs may have been modified by the underlying function - function_call_results: list[Contents] = await _try_execute_function_calls( + function_call_results, should_terminate = await _try_execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=function_calls, @@ -1798,6 +1860,17 @@ async def function_invocation_wrapper( # the function calls are already in the response, so we just continue return response + # Check if middleware signaled to terminate the loop (context.terminate=True) + # This allows middleware to short-circuit the tool loop without another LLM call + if should_terminate: + # Add tool results to response and return immediately without calling LLM again + result_message = ChatMessage(role="tool", contents=function_call_results) + response.messages.append(result_message) + if fcc_messages: + for msg in reversed(fcc_messages): + response.messages.insert(0, msg) + return response + if any( fcr.exception is not None for fcr in function_call_results @@ -1890,12 +1963,8 @@ async def streaming_function_invocation_wrapper( prepare_messages, ) - # Extract and merge function middleware from chat client with kwargs pipeline - extract_and_merge_function_middleware(self, **kwargs) - - # Extract the middleware pipeline before calling the underlying function - # because the underlying function may not preserve it in kwargs - stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline") + # Extract and merge function middleware from chat client with kwargs + stored_middleware_pipeline = extract_and_merge_function_middleware(self, kwargs) # Get the config for function invocation (not part of ChatClientProtocol, hence getattr) config: FunctionInvocationConfiguration | None = getattr(self, "function_invocation_configuration", None) @@ -1914,7 +1983,7 @@ async def streaming_function_invocation_wrapper( approved_responses = [resp for resp in fcc_todo.values() if resp.approved] approved_function_results: list[Contents] = [] if approved_responses: - approved_function_results = await _try_execute_function_calls( + results, _ = await _try_execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=approved_responses, @@ -1922,6 +1991,7 @@ async def streaming_function_invocation_wrapper( middleware_pipeline=stored_middleware_pipeline, config=config, ) + approved_function_results = list(results) if any( fcr.exception is not None for fcr in approved_function_results @@ -1976,7 +2046,7 @@ async def streaming_function_invocation_wrapper( if function_calls and tools: # Use the stored middleware pipeline instead of extracting from kwargs # because kwargs may have been modified by the underlying function - function_call_results: list[Contents] = await _try_execute_function_calls( + function_call_results, should_terminate = await _try_execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=function_calls, @@ -2005,6 +2075,13 @@ async def streaming_function_invocation_wrapper( # the function calls were already yielded. return + # Check if middleware signaled to terminate the loop (context.terminate=True) + # This allows middleware to short-circuit the tool loop without another LLM call + if should_terminate: + # Yield tool results and return immediately without calling LLM again + yield ChatResponseUpdate(contents=function_call_results, role="tool") + return + if any( fcr.exception is not None for fcr in function_call_results diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index f4662352a0..ab68382a83 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -925,6 +925,10 @@ class DataContent(BaseContent): image_data = b"raw image bytes" data_content = DataContent(data=image_data, media_type="image/png") + # Create from base64-encoded string + base64_string = "iVBORw0KGgoAAAANS..." + data_content = DataContent(data=base64_string, media_type="image/png") + # Create from data URI data_uri = "..." data_content = DataContent(uri=data_uri) @@ -986,11 +990,38 @@ def __init__( **kwargs: Any additional keyword arguments. """ + @overload + def __init__( + self, + *, + data: str, + media_type: str, + annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, + additional_properties: dict[str, Any] | None = None, + raw_representation: Any | None = None, + **kwargs: Any, + ) -> None: + """Initializes a DataContent instance with base64-encoded string data. + + Important: + This is for binary data that is represented as a data URI, not for online resources. + Use ``UriContent`` for online resources. + + Keyword Args: + data: The base64-encoded string data represented by this instance. + The data is used directly to construct a data URI. + media_type: The media type of the data. + annotations: Optional annotations associated with the content. + additional_properties: Optional additional properties associated with the content. + raw_representation: Optional raw representation of the content. + **kwargs: Any additional keyword arguments. + """ + def __init__( self, *, uri: str | None = None, - data: bytes | None = None, + data: bytes | str | None = None, media_type: str | None = None, annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, additional_properties: dict[str, Any] | None = None, @@ -1006,8 +1037,9 @@ def __init__( Keyword Args: uri: The URI of the data represented by this instance. Should be in the form: "data:{media_type};base64,{base64_data}". - data: The binary data represented by this instance. - The data is transformed into a base64-encoded data URI. + data: The binary data or base64-encoded string represented by this instance. + If bytes, the data is transformed into a base64-encoded data URI. + If str, it is assumed to be already base64-encoded and used directly. media_type: The media type of the data. annotations: Optional annotations associated with the content. additional_properties: Optional additional properties associated with the content. @@ -1017,7 +1049,9 @@ def __init__( if uri is None: if data is None or media_type is None: raise ValueError("Either 'data' and 'media_type' or 'uri' must be provided.") - uri = f"data:{media_type};base64,{base64.b64encode(data).decode('utf-8')}" + + base64_data: str = base64.b64encode(data).decode("utf-8") if isinstance(data, bytes) else data + uri = f"data:{media_type};base64,{base64_data}" # Validate URI format and extract media type if not provided validated_uri = self._validate_uri(uri) @@ -1816,13 +1850,14 @@ def prepare_function_call_results(content: Contents | Any | list[Contents | Any] """Prepare the values of the function call results.""" if isinstance(content, Contents): # For BaseContent objects, use to_dict and serialize to JSON - return json.dumps(content.to_dict(exclude={"raw_representation", "additional_properties"})) + # Use default=str to handle datetime and other non-JSON-serializable objects + return json.dumps(content.to_dict(exclude={"raw_representation", "additional_properties"}), default=str) dumpable = _prepare_function_call_results_as_dumpable(content) if isinstance(dumpable, str): return dumpable - # fallback - return json.dumps(dumpable) + # fallback - use default=str to handle datetime and other non-JSON-serializable objects + return json.dumps(dumpable, default=str) # region Chat Response constants diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index d1ff567f81..d4f6c1411d 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -13,12 +13,15 @@ AgentRunResponseUpdate, AgentThread, BaseAgent, + BaseContent, ChatMessage, + Contents, FunctionApprovalRequestContent, FunctionApprovalResponseContent, FunctionCallContent, FunctionResultContent, Role, + TextContent, UsageDetails, ) @@ -28,6 +31,7 @@ AgentRunUpdateEvent, RequestInfoEvent, WorkflowEvent, + WorkflowOutputEvent, ) from ._message_utils import normalize_messages_input from ._typing_utils import is_type_compatible @@ -280,9 +284,8 @@ def _convert_workflow_event_to_agent_update( ) -> AgentRunResponseUpdate | None: """Convert a workflow event to an AgentRunResponseUpdate. - Only AgentRunUpdateEvent and RequestInfoEvent are processed. - Other workflow events are ignored as they are workflow-internal and should - have corresponding AgentRunUpdateEvent emissions if relevant to agent consumers. + AgentRunUpdateEvent, RequestInfoEvent, and WorkflowOutputEvent are processed. + Other workflow events are ignored as they are workflow-internal. """ match event: case AgentRunUpdateEvent(data=update): @@ -291,6 +294,42 @@ def _convert_workflow_event_to_agent_update( return update return None + case WorkflowOutputEvent(data=data, source_executor_id=source_executor_id): + # Convert workflow output to an agent response update. + # Handle different data types appropriately. + if isinstance(data, AgentRunResponseUpdate): + # Already an update, pass through + return data + if isinstance(data, ChatMessage): + # Convert ChatMessage to update + return AgentRunResponseUpdate( + contents=list(data.contents), + role=data.role, + author_name=data.author_name or source_executor_id, + response_id=response_id, + message_id=str(uuid.uuid4()), + created_at=datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + raw_representation=data, + ) + # Determine contents based on data type + if isinstance(data, BaseContent): + # Already a content type (TextContent, ImageContent, etc.) + contents: list[Contents] = [cast(Contents, data)] + elif isinstance(data, str): + contents = [TextContent(text=data)] + else: + # Fallback: convert to string representation + contents = [TextContent(text=str(data))] + return AgentRunResponseUpdate( + contents=contents, + role=Role.ASSISTANT, + author_name=source_executor_id, + response_id=response_id, + message_id=str(uuid.uuid4()), + created_at=datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + raw_representation=data, + ) + case RequestInfoEvent(request_id=request_id): # Store the pending request for later correlation self.pending_requests[request_id] = event diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 358cee94dd..26300ad473 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -11,6 +11,7 @@ from .._threads import AgentThread from .._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value +from ._const import WORKFLOW_RUN_KWARGS_KEY from ._conversation_state import encode_chat_messages from ._events import ( AgentRunEvent, @@ -105,6 +106,11 @@ def workflow_output_types(self) -> list[type[Any]]: return [AgentRunResponse] return [] + @property + def description(self) -> str | None: + """Get the description of the underlying agent.""" + return self._agent.description + @handler async def run( self, request: AgentExecutorRequest, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse] @@ -304,9 +310,12 @@ async def _run_agent(self, ctx: WorkflowContext) -> AgentRunResponse | None: Returns: The complete AgentRunResponse, or None if waiting for user input. """ + run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + response = await self._agent.run( self._cache, thread=self._agent_thread, + **run_kwargs, ) await ctx.add_event(AgentRunEvent(self.id, response)) @@ -328,11 +337,14 @@ async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentRunResponse | Returns: The complete AgentRunResponse, or None if waiting for user input. """ + run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + updates: list[AgentRunResponseUpdate] = [] user_input_requests: list[FunctionApprovalRequestContent] = [] async for update in self._agent.run_stream( self._cache, thread=self._agent_thread, + **run_kwargs, ): updates.append(update) await ctx.add_event(AgentRunUpdateEvent(self.id, update)) diff --git a/python/packages/core/agent_framework/_workflows/_const.py b/python/packages/core/agent_framework/_workflows/_const.py index 6247be338a..34bde1da47 100644 --- a/python/packages/core/agent_framework/_workflows/_const.py +++ b/python/packages/core/agent_framework/_workflows/_const.py @@ -9,6 +9,11 @@ # Source identifier for internal workflow messages. INTERNAL_SOURCE_PREFIX = "internal" +# SharedState key for storing run kwargs that should be passed to agent invocations. +# Used by all orchestration patterns (Sequential, Concurrent, GroupChat, Handoff, Magentic) +# to pass kwargs from workflow.run_stream() through to agent.run_stream() and @ai_function tools. +WORKFLOW_RUN_KWARGS_KEY = "_workflow_run_kwargs" + def INTERNAL_SOURCE_ID(executor_id: str) -> str: """Generate an internal source ID for a given executor.""" diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index 8e0a7aec1e..9a99657902 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -14,7 +14,10 @@ Key properties: - The entire conversation is maintained and reused on every hop - The coordinator signals a handoff by invoking a tool call that names the specialist -- After a specialist responds, the workflow immediately requests new user input +- In human_in_loop mode (default), the workflow requests user input after each agent response + that doesn't trigger a handoff +- In autonomous mode, agents continue responding until they invoke a handoff tool or reach + a termination condition or turn limit """ import logging @@ -76,9 +79,9 @@ def _create_handoff_tool(alias: str, description: str | None = None) -> AIFuncti # Note: approval_mode is intentionally NOT set for handoff tools. # Handoff tools are framework-internal signals that trigger routing logic, - # not actual function executions. They are automatically intercepted and - # never actually execute, so approval is unnecessary and causes issues - # with tool_calls/responses pairing when cleaning conversations. + # not actual function executions. They are automatically intercepted by + # _AutoHandoffMiddleware which short-circuits execution and provides synthetic + # results, so the function body never actually runs in practice. @ai_function(name=tool_name, description=doc) def _handoff_tool(context: str | None = None) -> str: """Return a deterministic acknowledgement that encodes the target alias.""" @@ -109,6 +112,8 @@ def _clone_chat_agent(agent: ChatAgent) -> ChatAgent: chat_message_store_factory=agent.chat_message_store_factory, context_providers=agent.context_provider, middleware=middleware, + # Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once. + allow_multiple_tool_calls=False, frequency_penalty=options.frequency_penalty, logit_bias=dict(options.logit_bias) if options.logit_bias else None, max_tokens=options.max_tokens, @@ -130,19 +135,57 @@ def _clone_chat_agent(agent: ChatAgent) -> ChatAgent: @dataclass class HandoffUserInputRequest: - """Request message emitted when the workflow needs fresh user input.""" + """Request message emitted when the workflow needs fresh user input. + + Note: The conversation field is intentionally excluded from checkpoint serialization + to prevent duplication. The conversation is preserved in the coordinator's state + and will be reconstructed on restore. See issue #2667. + """ conversation: list[ChatMessage] awaiting_agent_id: str prompt: str source_executor_id: str + def to_dict(self) -> dict[str, Any]: + """Serialize to dict, excluding conversation to prevent checkpoint duplication. + + The conversation is already preserved in the workflow coordinator's state. + Including it here would cause duplicate messages when restoring from checkpoint. + """ + return { + "awaiting_agent_id": self.awaiting_agent_id, + "prompt": self.prompt, + "source_executor_id": self.source_executor_id, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "HandoffUserInputRequest": + """Deserialize from dict, initializing conversation as empty. + + The conversation will be reconstructed from the coordinator's state on restore. + """ + return cls( + conversation=[], + awaiting_agent_id=data["awaiting_agent_id"], + prompt=data["prompt"], + source_executor_id=data["source_executor_id"], + ) + @dataclass class _ConversationWithUserInput: - """Internal message carrying full conversation + new user messages from gateway to coordinator.""" + """Internal message carrying full conversation + new user messages from gateway to coordinator. + + Attributes: + full_conversation: The conversation messages to process. + is_post_restore: If True, indicates this message was created after a checkpoint restore. + The coordinator should append these messages to its existing conversation rather + than replacing it. This prevents duplicate messages (see issue #2667). + """ full_conversation: list[ChatMessage] = field(default_factory=lambda: []) # type: ignore[misc] + is_post_restore: bool = False @dataclass @@ -179,7 +222,7 @@ async def process( class _InputToConversation(Executor): - """Normalises initial workflow input into a list[ChatMessage].""" + """Normalizes initial workflow input into a list[ChatMessage].""" @handler async def from_str(self, prompt: str, ctx: WorkflowContext[list[ChatMessage]]) -> None: @@ -187,16 +230,12 @@ async def from_str(self, prompt: str, ctx: WorkflowContext[list[ChatMessage]]) - await ctx.send_message([ChatMessage(Role.USER, text=prompt)]) @handler - async def from_message(self, message: ChatMessage, ctx: WorkflowContext[list[ChatMessage]]) -> None: # type: ignore[name-defined] + async def from_message(self, message: ChatMessage, ctx: WorkflowContext[list[ChatMessage]]) -> None: """Pass through an existing chat message as the initial conversation.""" await ctx.send_message([message]) @handler - async def from_messages( - self, - messages: list[ChatMessage], - ctx: WorkflowContext[list[ChatMessage]], - ) -> None: # type: ignore[name-defined] + async def from_messages(self, messages: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: """Forward a list of chat messages as the starting conversation history.""" await ctx.send_message(list(messages)) @@ -362,7 +401,8 @@ async def handle_agent_response( self._conversation = list(full_conv) else: # Subsequent responses - append only new messages from this agent - # Keep ALL messages including tool calls to maintain complete history + # Keep ALL messages including tool calls to maintain complete history. + # This includes assistant messages with function calls and tool role messages with results. new_messages = response.agent_run_response.messages or [] self._conversation.extend(new_messages) @@ -439,9 +479,25 @@ async def handle_user_input( message: _ConversationWithUserInput, ctx: WorkflowContext[AgentExecutorRequest, list[ChatMessage]], ) -> None: - """Receive full conversation with new user input from gateway, update history, trim for agent.""" - # Update authoritative conversation - self._conversation = list(message.full_conversation) + """Receive user input from gateway, update history, and route to agent. + + The message.full_conversation may contain: + - Full conversation history + new user messages (normal flow) + - Only new user messages (post-checkpoint-restore flow, see issue #2667) + + The gateway sets message.is_post_restore=True when resuming after a checkpoint + restore. In that case, we append the new messages to the existing conversation + rather than replacing it. + """ + incoming = message.full_conversation + + if message.is_post_restore and self._conversation: + # Post-restore: append new user messages to existing conversation + # The coordinator already has its conversation restored from checkpoint + self._conversation.extend(incoming) + else: + # Normal flow: replace with full conversation + self._conversation = list(incoming) if incoming else self._conversation # Reset autonomous turn counter on new user input self._autonomous_turns = 0 @@ -462,9 +518,9 @@ async def handle_user_input( ) else: logger.info(f"Routing user input to coordinator '{target_agent_id}'") - # Note: Stack is only used for specialist-to-specialist handoffs, not user input routing - # Clean before sending to target agent + # Clean conversation before sending to target agent + # Removes tool-related messages that shouldn't be resent on every turn cleaned = clean_conversation_for_handoff(self._conversation) request = AgentExecutorRequest(messages=cleaned, should_respond=True) await ctx.send_message(request, target_id=target_agent_id) @@ -581,13 +637,7 @@ def _apply_response_metadata(self, conversation: list[ChatMessage], agent_respon class _UserInputGateway(Executor): """Bridges conversation context with the request & response cycle and re-enters the loop.""" - def __init__( - self, - *, - starting_agent_id: str, - prompt: str | None, - id: str, - ) -> None: + def __init__(self, *, starting_agent_id: str, prompt: str | None, id: str) -> None: """Initialise the gateway that requests user input and forwards responses.""" super().__init__(id) self._starting_agent_id = starting_agent_id @@ -626,20 +676,39 @@ async def resume_from_user( response: object, ctx: WorkflowContext[_ConversationWithUserInput], ) -> None: - """Convert user input responses back into chat messages and resume the workflow.""" - # Reconstruct full conversation with new user input - conversation = list(original_request.conversation) + """Convert user input responses back into chat messages and resume the workflow. + + After checkpoint restore, original_request.conversation will be empty (not serialized + to prevent duplication - see issue #2667). In this case, we send only the new user + messages and let the coordinator append them to its already-restored conversation. + """ user_messages = _as_user_messages(response) - conversation.extend(user_messages) - # Send full conversation back to coordinator (not trimmed) - # Coordinator will update its authoritative history and trim for agent - message = _ConversationWithUserInput(full_conversation=conversation) + if original_request.conversation: + # Normal flow: have conversation history from the original request + conversation = list(original_request.conversation) + conversation.extend(user_messages) + message = _ConversationWithUserInput(full_conversation=conversation, is_post_restore=False) + else: + # Post-restore flow: conversation was not serialized, send only new user messages + # The coordinator will append these to its already-restored conversation + message = _ConversationWithUserInput(full_conversation=user_messages, is_post_restore=True) + await ctx.send_message(message, target_id="handoff-coordinator") def _as_user_messages(payload: Any) -> list[ChatMessage]: - """Normalise arbitrary payloads into user-authored chat messages.""" + """Normalize arbitrary payloads into user-authored chat messages. + + Handles various input formats: + - ChatMessage instances (converted to USER role if needed) + - List of ChatMessage instances + - Mapping with 'text' or 'content' key + - Any other type (converted to string) + + Returns: + List of ChatMessage instances with USER role. + """ if isinstance(payload, ChatMessage): if payload.role == Role.USER: return [payload] @@ -735,7 +804,7 @@ class HandoffBuilder: name="customer_support", participants=[coordinator, refund, shipping], ) - .set_coordinator("coordinator_agent") + .set_coordinator(coordinator) .build() ) @@ -754,7 +823,7 @@ class HandoffBuilder: # Enable specialist-to-specialist handoffs with fluent API workflow = ( HandoffBuilder(participants=[coordinator, replacement, delivery, billing]) - .set_coordinator("coordinator_agent") + .set_coordinator(coordinator) .add_handoff(coordinator, [replacement, delivery, billing]) # Coordinator routes to all .add_handoff(replacement, [delivery, billing]) # Replacement delegates to delivery/billing .add_handoff(delivery, billing) # Delivery escalates to billing @@ -764,6 +833,35 @@ class HandoffBuilder: # Flow: User → Coordinator → Replacement → Delivery → Back to User # (Replacement hands off to Delivery without returning to user) + **Use Participant Factories for State Isolation:** + + .. code-block:: python + # Define factories that produce fresh agent instances per workflow run + def create_coordinator() -> AgentProtocol: + return chat_client.create_agent( + instructions="You are the coordinator agent...", + name="coordinator_agent", + ) + + + def create_specialist() -> AgentProtocol: + return chat_client.create_agent( + instructions="You are the specialist agent...", + name="specialist_agent", + ) + + + workflow = ( + HandoffBuilder( + participant_factories={ + "coordinator": create_coordinator, + "specialist": create_specialist, + } + ) + .set_coordinator("coordinator") + .build() + ) + **Custom Termination Condition:** .. code-block:: python @@ -771,7 +869,7 @@ class HandoffBuilder: # Terminate when user says goodbye or after 5 exchanges workflow = ( HandoffBuilder(participants=[coordinator, refund, shipping]) - .set_coordinator("coordinator_agent") + .set_coordinator(coordinator) .with_termination_condition( lambda conv: sum(1 for msg in conv if msg.role.value == "user") >= 5 or any("goodbye" in msg.text.lower() for msg in conv[-2:]) @@ -788,7 +886,7 @@ class HandoffBuilder: storage = InMemoryCheckpointStorage() workflow = ( HandoffBuilder(participants=[coordinator, refund, shipping]) - .set_coordinator("coordinator_agent") + .set_coordinator(coordinator) .with_checkpointing(storage) .build() ) @@ -797,6 +895,9 @@ class HandoffBuilder: name: Optional workflow name for identification and logging. participants: List of agents (AgentProtocol) or executors to participate in the handoff. The first agent you specify as coordinator becomes the orchestrating agent. + participant_factories: Mapping of factory names to callables that produce agents or + executors when invoked. This allows for lazy instantiation + and state isolation per workflow instance created by this builder. description: Optional human-readable description of the workflow. Raises: @@ -809,14 +910,16 @@ def __init__( *, name: str | None = None, participants: Sequence[AgentProtocol | Executor] | None = None, + participant_factories: Mapping[str, Callable[[], AgentProtocol | Executor]] | None = None, description: str | None = None, ) -> None: r"""Initialize a HandoffBuilder for creating conversational handoff workflows. The builder starts in an unconfigured state and requires you to call: 1. `.participants([...])` - Register agents - 2. `.set_coordinator(...)` - Designate which agent receives initial user input - 3. `.build()` - Construct the final Workflow + 2. or `.participant_factories({...})` - Register agent/executor factories + 3. `.set_coordinator(...)` - Designate which agent receives initial user input + 4. `.build()` - Construct the final Workflow Optional configuration methods allow you to customize context management, termination logic, and persistence. @@ -828,6 +931,9 @@ def __init__( participate in the handoff workflow. You can also call `.participants([...])` later. Each participant must have a unique identifier (name for agents, id for executors). + participant_factories: Optional mapping of factory names to callables that produce agents or + executors when invoked. This allows for lazy instantiation + and state isolation per workflow instance created by this builder. description: Optional human-readable description explaining the workflow's purpose. Useful for documentation and observability. @@ -848,7 +954,6 @@ def __init__( self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] = ( _default_termination_condition ) - self._auto_register_handoff_tools: bool = True self._handoff_config: dict[str, list[str]] = {} # Maps agent_id -> [target_agent_ids] self._return_to_previous: bool = False self._interaction_mode: Literal["human_in_loop", "autonomous"] = "human_in_loop" @@ -856,9 +961,79 @@ def __init__( self._request_info_enabled: bool = False self._request_info_filter: set[str] | None = None + self._participant_factories: dict[str, Callable[[], AgentProtocol | Executor]] = {} + if participant_factories: + self.participant_factories(participant_factories) + if participants: self.participants(participants) + # region Fluent Configuration Methods + + def participant_factories( + self, participant_factories: Mapping[str, Callable[[], AgentProtocol | Executor]] + ) -> "HandoffBuilder": + """Register factories that produce agents or executors for the handoff workflow. + + Each factory is a callable that returns an AgentProtocol or Executor instance. + Factories are invoked when building the workflow, allowing for lazy instantiation + and state isolation per workflow instance. + + Args: + participant_factories: Mapping of factory names to callables that return AgentProtocol or Executor + instances. Each produced participant must have a unique identifier (name for + agents, id for executors). + + Returns: + Self for method chaining. + + Raises: + ValueError: If participant_factories is empty or `.participants(...)` or `.participant_factories(...)` + has already been called. + + Example: + .. code-block:: python + + from agent_framework import ChatAgent, HandoffBuilder + + + def create_coordinator() -> ChatAgent: + return ... + + + def create_refund_agent() -> ChatAgent: + return ... + + + def create_billing_agent() -> ChatAgent: + return ... + + + factories = { + "coordinator": create_coordinator, + "refund": create_refund_agent, + "billing": create_billing_agent, + } + + builder = HandoffBuilder().participant_factories(factories) + # Use the factory IDs to create handoffs and set the coordinator + builder.add_handoff("coordinator", ["refund", "billing"]) + builder.set_coordinator("coordinator") + """ + if self._executors: + raise ValueError( + "Cannot mix .participants([...]) and .participant_factories() in the same builder instance." + ) + + if self._participant_factories: + raise ValueError("participant_factories() has already been called on this builder instance.") + + if not participant_factories: + raise ValueError("participant_factories cannot be empty") + + self._participant_factories = dict(participant_factories) + return self + def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "HandoffBuilder": """Register the agents or executors that will participate in the handoff workflow. @@ -875,7 +1050,8 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Han Self for method chaining. Raises: - ValueError: If participants is empty or contains duplicates. + ValueError: If participants is empty, contains duplicates, or `.participants(...)` or + `.participant_factories(...)` has already been called. TypeError: If participants are not AgentProtocol or Executor instances. Example: @@ -897,26 +1073,28 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Han This method resets any previously configured coordinator, so you must call `.set_coordinator(...)` again after changing participants. """ + if self._participant_factories: + raise ValueError( + "Cannot mix .participants([...]) and .participant_factories() in the same builder instance." + ) + + if self._executors: + raise ValueError("participants have already been assigned") + if not participants: raise ValueError("participants cannot be empty") named: dict[str, AgentProtocol | Executor] = {} for participant in participants: - identifier: str if isinstance(participant, Executor): identifier = participant.id elif isinstance(participant, AgentProtocol): - name_attr = getattr(participant, "name", None) - if not name_attr: - raise ValueError( - "Agents used in handoff workflows must have a stable name " - "so they can be addressed during routing." - ) - identifier = str(name_attr) + identifier = participant.display_name else: raise TypeError( f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." ) + if identifier in named: raise ValueError(f"Duplicate participant name '{identifier}' detected") named[identifier] = participant @@ -927,15 +1105,10 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Han ) wrapped = metadata["executors"] - seen_ids: set[str] = set() - for executor in wrapped.values(): - if executor.id in seen_ids: - raise ValueError(f"Duplicate participant with id '{executor.id}' detected") - seen_ids.add(executor.id) - self._executors = {executor.id: executor for executor in wrapped.values()} self._aliases = metadata["aliases"] self._starting_agent_id = None + return self def set_coordinator(self, agent: str | AgentProtocol | Executor) -> "HandoffBuilder": @@ -952,7 +1125,7 @@ def set_coordinator(self, agent: str | AgentProtocol | Executor) -> "HandoffBuil Args: agent: The agent to use as the coordinator. Can be: - - Agent name (str): e.g., "coordinator_agent" + - Factory name (str): If using participant factories - AgentProtocol instance: The actual agent object - Executor instance: A custom executor wrapping an agent @@ -960,15 +1133,26 @@ def set_coordinator(self, agent: str | AgentProtocol | Executor) -> "HandoffBuil Self for method chaining. Raises: - ValueError: If participants(...) hasn't been called yet, or if the specified - agent is not in the participants list. + ValueError: 1) If `agent` is an AgentProtocol or Executor instance but `.participants(...)` hasn't + been called yet, or if it is not in the participants list. + 2) If `agent` is a factory name (str) but `.participant_factories(...)` hasn't been + called yet, or if it is not in the participant_factories list. + TypeError: If `agent` is not a str, AgentProtocol, or Executor instance. Example: .. code-block:: python - # Use agent name - builder = HandoffBuilder().participants([coordinator, refund, billing]).set_coordinator("coordinator") + # Use factory name with `.participant_factories()` + builder = ( + HandoffBuilder() + .participant_factories({ + "coordinator": create_coordinator, + "refund": create_refund_agent, + "billing": create_billing_agent, + }) + .set_coordinator("coordinator") + ) # Or pass the agent object directly builder = HandoffBuilder().participants([coordinator, refund, billing]).set_coordinator(coordinator) @@ -979,12 +1163,29 @@ def set_coordinator(self, agent: str | AgentProtocol | Executor) -> "HandoffBuil Decorate the tool with `approval_mode="always_require"` to ensure the workflow intercepts the call before execution and can make the transition. """ - if not self._executors: - raise ValueError("Call participants(...) before coordinator(...)") - resolved = self._resolve_to_id(agent) - if resolved not in self._executors: - raise ValueError(f"coordinator '{resolved}' is not part of the participants list") - self._starting_agent_id = resolved + if isinstance(agent, (AgentProtocol, Executor)): + if not self._executors: + raise ValueError( + "Call participants(...) before coordinator(...). If using participant_factories, " + "pass the factory name (str) instead of the agent instance." + ) + resolved = self._resolve_to_id(agent) + if resolved not in self._executors: + raise ValueError(f"coordinator '{resolved}' is not part of the participants list") + self._starting_agent_id = resolved + elif isinstance(agent, str): + if agent not in self._participant_factories: + raise ValueError( + f"coordinator factory name '{agent}' is not part of the participant_factories list. If " + "you are using participant instances, call .participants(...) and pass the agent instance instead." + ) + self._starting_agent_id = agent + else: + raise TypeError( + "coordinator must be a factory name (str), AgentProtocol, or Executor instance. " + f"Got {type(agent).__name__}." + ) + return self def add_handoff( @@ -1004,33 +1205,42 @@ def add_handoff( Args: source: The agent that can initiate the handoff. Can be: - - Agent name (str): e.g., "triage_agent" + - Factory name (str): If using participant factories - AgentProtocol instance: The actual agent object - Executor instance: A custom executor wrapping an agent + - Cannot mix factory names and instances across source and targets targets: One or more target agents that the source can hand off to. Can be: - - Single agent: "billing_agent" or agent_instance - - Multiple agents: ["billing_agent", "support_agent"] or [agent1, agent2] - tool_name: Optional custom name for the handoff tool. If not provided, generates - "handoff_to_" for single targets or "handoff_to__agent" - for multiple targets based on target names. - tool_description: Optional custom description for the handoff tool. If not provided, - generates "Handoff to the agent." + - Factory name (str): If using participant factories + - AgentProtocol instance: The actual agent object + - Executor instance: A custom executor wrapping an agent + - Single target: "billing_agent" or agent_instance + - Multiple targets: ["billing_agent", "support_agent"] or [agent1, agent2] + - Cannot mix factory names and instances across source and targets + tool_name: Optional custom name for the handoff tool. Currently not used in the + implementation - tools are always auto-generated as "handoff_to_". + Reserved for future enhancement. + tool_description: Optional custom description for the handoff tool. Currently not used + in the implementation - descriptions are always auto-generated as + "Handoff to the agent.". Reserved for future enhancement. Returns: Self for method chaining. Raises: - ValueError: If source or targets are not in the participants list, or if + ValueError: 1) If source or targets are not in the participants list, or if participants(...) hasn't been called yet. + 2) If source or targets are factory names (str) but participant_factories(...) + hasn't been called yet, or if they are not in the participant_factories list. + TypeError: If mixing factory names (str) and AgentProtocol/Executor instances Examples: - Single target: + Single target (using factory name): .. code-block:: python builder.add_handoff("triage_agent", "billing_agent") - Multiple targets (using agent names): + Multiple targets (using factory names): .. code-block:: python @@ -1055,138 +1265,70 @@ def add_handoff( .build() ) - Custom tool names and descriptions: - - .. code-block:: python - - builder.add_handoff( - "support_agent", - "escalation_agent", - tool_name="escalate_to_l2", - tool_description="Escalate this issue to Level 2 support", - ) - Note: - Handoff tools are automatically registered for each source agent - If a source agent is configured multiple times via add_handoff, targets are merged """ - if not self._executors: - raise ValueError("Call participants(...) before add_handoff(...)") - - # Resolve source agent ID - source_id = self._resolve_to_id(source) - if source_id not in self._executors: - raise ValueError(f"Source agent '{source}' is not in the participants list") - - # Normalize targets to list - target_list = [targets] if isinstance(targets, (str, AgentProtocol, Executor)) else list(targets) - - # Resolve all target IDs - target_ids: list[str] = [] - for target in target_list: - target_id = self._resolve_to_id(target) - if target_id not in self._executors: - raise ValueError(f"Target agent '{target}' is not in the participants list") - target_ids.append(target_id) - - # Merge with existing handoff configuration for this source - if source_id in self._handoff_config: - # Add new targets to existing list, avoiding duplicates - existing = self._handoff_config[source_id] - for target_id in target_ids: - if target_id not in existing: - existing.append(target_id) - else: - self._handoff_config[source_id] = target_ids - - return self - - def auto_register_handoff_tools(self, enabled: bool) -> "HandoffBuilder": - """Configure whether the builder should synthesize handoff tools for the starting agent.""" - self._auto_register_handoff_tools = enabled - return self - - def _apply_auto_tools(self, agent: ChatAgent, specialists: Mapping[str, Executor]) -> dict[str, str]: - """Attach synthetic handoff tools to a chat agent and return the target lookup table.""" - chat_options = agent.chat_options - existing_tools = list(chat_options.tools or []) - existing_names = {getattr(tool, "name", "") for tool in existing_tools if hasattr(tool, "name")} - - tool_targets: dict[str, str] = {} - new_tools: list[Any] = [] - for exec_id in specialists: - alias = exec_id - sanitized = sanitize_identifier(alias) - tool = _create_handoff_tool(alias) - if tool.name not in existing_names: - new_tools.append(tool) - tool_targets[tool.name.lower()] = exec_id - tool_targets[sanitized] = exec_id - tool_targets[alias.lower()] = exec_id - - if new_tools: - chat_options.tools = existing_tools + new_tools - else: - chat_options.tools = existing_tools - - return tool_targets - - def _resolve_agent_id(self, agent_identifier: str) -> str: - """Resolve an agent identifier to an executor ID. - - Args: - agent_identifier: Can be agent name, display name, or executor ID - - Returns: - The executor ID - - Raises: - ValueError: If the identifier cannot be resolved - """ - # Check if it's already an executor ID - if agent_identifier in self._executors: - return agent_identifier - - # Check if it's an alias - if agent_identifier in self._aliases: - return self._aliases[agent_identifier] - - # Not found - raise ValueError(f"Agent identifier '{agent_identifier}' not found in participants") - - def _prepare_agent_with_handoffs( - self, - executor: AgentExecutor, - target_agents: Mapping[str, Executor], - ) -> tuple[AgentExecutor, dict[str, str]]: - """Prepare an agent by adding handoff tools for the specified target agents. - - Args: - executor: The agent executor to prepare - target_agents: Map of executor IDs to target executors this agent can hand off to - - Returns: - Tuple of (updated executor, tool_targets map) - """ - agent = getattr(executor, "_agent", None) - if not isinstance(agent, ChatAgent): - return executor, {} + if isinstance(source, str) and ( + isinstance(targets, str) or (isinstance(targets, Sequence) and all(isinstance(t, str) for t in targets)) + ): + # Both source and targets are factory names + if not self._participant_factories: + raise ValueError("Call participant_factories(...) before add_handoff(...)") + + if source not in self._participant_factories: + raise ValueError(f"Source factory name '{source}' is not in the participant_factories list") + + target_list: list[str] = [targets] if isinstance(targets, str) else list(targets) # type: ignore + for target in target_list: + if target not in self._participant_factories: + raise ValueError(f"Target factory name '{target}' is not in the participant_factories list") + + self._handoff_config[source] = target_list # type: ignore + return self + + if isinstance(source, (AgentProtocol, Executor)) and ( + isinstance(targets, (AgentProtocol, Executor)) + or all(isinstance(t, (AgentProtocol, Executor)) for t in targets) + ): + # Both source and targets are instances + if not self._executors: + raise ValueError("Call participants(...) before add_handoff(...)") + + # Resolve source agent ID + source_id = self._resolve_to_id(source) + if source_id not in self._executors: + raise ValueError(f"Source agent '{source}' is not in the participants list") + + # Normalize targets to list + target_list: list[AgentProtocol | Executor] = ( # type: ignore[no-redef] + [targets] if isinstance(targets, (AgentProtocol, Executor)) else list(targets) + ) # type: ignore + + # Resolve all target IDs + target_ids: list[str] = [] + for target in target_list: + target_id = self._resolve_to_id(target) + if target_id not in self._executors: + raise ValueError(f"Target agent '{target}' is not in the participants list") + target_ids.append(target_id) + + # Merge with existing handoff configuration for this source + if source_id in self._handoff_config: + # Add new targets to existing list, avoiding duplicates + existing = self._handoff_config[source_id] + for target_id in target_ids: + if target_id not in existing: + existing.append(target_id) + else: + self._handoff_config[source_id] = target_ids - cloned_agent = _clone_chat_agent(agent) - tool_targets = self._apply_auto_tools(cloned_agent, target_agents) - if tool_targets: - middleware = _AutoHandoffMiddleware(tool_targets) - existing_middleware = list(cloned_agent.middleware or []) - existing_middleware.append(middleware) - cloned_agent.middleware = existing_middleware + return self - new_executor = AgentExecutor( - cloned_agent, - agent_thread=getattr(executor, "_agent_thread", None), - output_response=getattr(executor, "_output_response", False), - id=executor.id, + raise TypeError( + "Cannot mix factory names (str) and AgentProtocol/Executor instances " + "across source and targets in add_handoff()" ) - return new_executor, tool_targets def request_prompt(self, prompt: str | None) -> "HandoffBuilder": """Set a custom prompt message displayed when requesting user input. @@ -1548,75 +1690,46 @@ def build(self) -> Workflow: After calling build(), the builder instance should not be reused. Create a new builder if you need to construct another workflow with different configuration. """ - if not self._executors: - raise ValueError("No participants provided. Call participants([...]) first.") - if self._starting_agent_id is None: - raise ValueError("coordinator must be defined before build().") + if not self._executors and not self._participant_factories: + raise ValueError( + "No participants or participant_factories have been configured. " + "Call participants(...) or participant_factories(...) first." + ) - starting_executor = self._executors[self._starting_agent_id] - specialists = { - exec_id: executor for exec_id, executor in self._executors.items() if exec_id != self._starting_agent_id - } + if self._starting_agent_id is None: + raise ValueError("Must call set_coordinator(...) before building the workflow.") - # Build handoff tool registry for all agents that need them - handoff_tool_targets: dict[str, str] = {} - if self._auto_register_handoff_tools: - # Determine which agents should have handoff tools - if self._handoff_config: - # Use explicit handoff configuration from add_handoff() calls - for source_exec_id, target_exec_ids in self._handoff_config.items(): - executor = self._executors.get(source_exec_id) - if not executor: - raise ValueError(f"Handoff source agent '{source_exec_id}' not found in participants") - - if isinstance(executor, AgentExecutor): - # Build targets map for this source agent - targets_map: dict[str, Executor] = {} - for target_exec_id in target_exec_ids: - target_executor = self._executors.get(target_exec_id) - if not target_executor: - raise ValueError(f"Handoff target agent '{target_exec_id}' not found in participants") - targets_map[target_exec_id] = target_executor - - # Register handoff tools for this agent - updated_executor, tool_targets = self._prepare_agent_with_handoffs(executor, targets_map) - self._executors[source_exec_id] = updated_executor - handoff_tool_targets.update(tool_targets) - else: - # Default behavior: only coordinator gets handoff tools to all specialists - if isinstance(starting_executor, AgentExecutor) and specialists: - starting_executor, tool_targets = self._prepare_agent_with_handoffs(starting_executor, specialists) - self._executors[self._starting_agent_id] = starting_executor - handoff_tool_targets.update(tool_targets) # Update references after potential agent modifications - starting_executor = self._executors[self._starting_agent_id] - specialists = { - exec_id: executor for exec_id, executor in self._executors.items() if exec_id != self._starting_agent_id - } + # Resolve executors, aliases, and handoff tool targets + # This will instantiate participants if using factories, and validate handoff config + start_executor_id, executors, aliases, handoff_tool_targets = self._resolve_executors_and_handoffs() + specialists = {exec_id: executor for exec_id, executor in executors.items() if exec_id != start_executor_id} if not specialists: logger.warning("Handoff workflow has no specialist agents; the coordinator will loop with the user.") descriptions = { - exec_id: getattr(executor, "description", None) or exec_id for exec_id, executor in self._executors.items() + exec_id: getattr(executor, "description", None) or exec_id for exec_id, executor in executors.items() } participant_specs = { exec_id: GroupChatParticipantSpec(name=exec_id, participant=executor, description=descriptions[exec_id]) - for exec_id, executor in self._executors.items() + for exec_id, executor in executors.items() } input_node = _InputToConversation(id="input-conversation") user_gateway = _UserInputGateway( - starting_agent_id=starting_executor.id, + starting_agent_id=start_executor_id, prompt=self._request_prompt, id="handoff-user-input", ) builder = WorkflowBuilder(name=self._name, description=self._description).set_start_executor(input_node) - specialist_aliases = {alias: exec_id for alias, exec_id in self._aliases.items() if exec_id in specialists} + specialist_aliases = { + alias: specialists[exec_id].id for alias, exec_id in aliases.items() if exec_id in specialists + } def _handoff_orchestrator_factory(_: _GroupChatConfig) -> Executor: return _HandoffCoordinator( - starting_agent_id=starting_executor.id, + starting_agent_id=start_executor_id, specialist_ids=specialist_aliases, input_gateway_id=user_gateway.id, termination_condition=self._termination_condition, @@ -1633,8 +1746,8 @@ def _handoff_orchestrator_factory(_: _GroupChatConfig) -> Executor: manager_name=self._starting_agent_id, participants=participant_specs, max_rounds=None, - participant_aliases=self._aliases, - participant_executors=self._executors, + participant_aliases=aliases, + participant_executors=executors, ) # Determine participant factory - wrap with request info interceptor if enabled @@ -1683,14 +1796,159 @@ def _factory_with_request_info( builder = builder.add_edge(input_node, starting_entry_executor) else: # Fallback to direct connection if interceptor not found - builder = builder.add_edge(input_node, starting_executor) + builder = builder.add_edge(input_node, executors[start_executor_id]) else: - builder = builder.add_edge(input_node, starting_executor) + builder = builder.add_edge(input_node, executors[start_executor_id]) builder = builder.add_edge(coordinator, user_gateway) builder = builder.add_edge(user_gateway, coordinator) return builder.build() + # endregion Fluent Configuration Methods + + # region Internal Helper Methods + + def _resolve_executors(self) -> tuple[dict[str, Executor], dict[str, str]]: + """Resolve participant factories into executor instances. + + If executors were provided directly via participants(...), those are returned as-is. + If participant factories were provided via participant_factories(...), those + are invoked to create executor instances and aliases. + + Returns: + Tuple of (executors map, aliases map) + """ + if self._executors and self._participant_factories: + raise ValueError("Cannot have both executors and participant_factories configured") + + if self._executors: + if self._aliases: + # Return existing executors and aliases + return self._executors, self._aliases + raise ValueError("Aliases is empty despite executors being provided") + + if self._participant_factories: + # Invoke each factory to create participant instances + executor_ids_to_executors: dict[str, AgentProtocol | Executor] = {} + factory_names_to_ids: dict[str, str] = {} + for factory_name, factory in self._participant_factories.items(): + instance: Executor | AgentProtocol = factory() + if isinstance(instance, Executor): + identifier = instance.id + elif isinstance(instance, AgentProtocol): + identifier = instance.display_name + else: + raise TypeError( + f"Participants must be AgentProtocol or Executor instances. Got {type(instance).__name__}." + ) + + if identifier in executor_ids_to_executors: + raise ValueError(f"Duplicate participant name '{identifier}' detected") + executor_ids_to_executors[identifier] = instance + factory_names_to_ids[factory_name] = identifier + + # Prepare metadata and wrap instances as needed + metadata = prepare_participant_metadata( + executor_ids_to_executors, + description_factory=lambda name, participant: getattr(participant, "description", None) or name, + ) + + wrapped = metadata["executors"] + # Map executors by factory name (not executor.id) because handoff configs reference factory names + # This allows users to configure handoffs using the factory names they provided + executors = { + factory_name: wrapped[executor_id] for factory_name, executor_id in factory_names_to_ids.items() + } + aliases = metadata["aliases"] + + return executors, aliases + + raise ValueError("No executors or participant_factories have been configured") + + def _resolve_handoffs(self, executors: Mapping[str, Executor]) -> tuple[dict[str, Executor], dict[str, str]]: + """Handoffs may be specified using factory names or instances; resolve to executor IDs. + + Args: + executors: Map of executor IDs or factory names to Executor instances + + Returns: + Tuple of (updated executors map, handoff configuration map) + The updated executors map may have modified agents with handoff tools added + and maps executor IDs to Executor instances. + The handoff configuration map maps executor IDs to lists of target executor IDs. + """ + handoff_tool_targets: dict[str, str] = {} + updated_executors = {executor.id: executor for executor in executors.values()} + # Determine which agents should have handoff tools + if self._handoff_config: + # Use explicit handoff configuration from add_handoff() calls + for source_id, target_ids in self._handoff_config.items(): + executor = executors.get(source_id) + if not executor: + raise ValueError( + f"Handoff source agent '{source_id}' not found. " + "Please make sure source has been added as either a participant or participant_factory." + ) + + if isinstance(executor, AgentExecutor): + # Build targets map for this source agent + targets_map: dict[str, Executor] = {} + for target_id in target_ids: + target_executor = executors.get(target_id) + if not target_executor: + raise ValueError( + f"Handoff target agent '{target_id}' not found. " + "Please make sure target has been added as either a participant or participant_factory." + ) + targets_map[target_executor.id] = target_executor + + # Register handoff tools for this agent + updated_executor, tool_targets = self._prepare_agent_with_handoffs(executor, targets_map) + updated_executors[updated_executor.id] = updated_executor + handoff_tool_targets.update(tool_targets) + else: + if self._starting_agent_id is None or self._starting_agent_id not in executors: + raise RuntimeError("Failed to resolve default handoff configuration due to missing starting agent.") + + # Default behavior: only coordinator gets handoff tools to all specialists + starting_executor = executors[self._starting_agent_id] + specialists = { + executor.id: executor for executor in executors.values() if executor.id != starting_executor.id + } + + if isinstance(starting_executor, AgentExecutor) and specialists: + starting_executor, tool_targets = self._prepare_agent_with_handoffs(starting_executor, specialists) + updated_executors[starting_executor.id] = starting_executor + handoff_tool_targets.update(tool_targets) # Update references after potential agent modifications + + return updated_executors, handoff_tool_targets + + def _resolve_executors_and_handoffs(self) -> tuple[str, dict[str, Executor], dict[str, str], dict[str, str]]: + """Resolve participant factories into executor instances and handoff configurations. + + If executors were provided directly via participants(...), those are returned as-is. + If participant factories were provided via participant_factories(...), those + are invoked to create executor instances and aliases. + + Returns: + Tuple of (executors map, aliases map, handoff configuration map) + """ + # Resolve the participant factories now. This doesn't break the factory pattern + # since the Handoff builder still creates new instances per workflow build. + executors, aliases = self._resolve_executors() + # `self._starting_agent_id` is either a factory name or executor ID at this point, + # resolve to executor ID + if self._starting_agent_id in executors: + start_executor_id = executors[self._starting_agent_id].id + else: + raise RuntimeError("Failed to resolve starting agent ID during build.") + + # Resolve handoffs + # This will update the `executors` dict to a map of executor IDs to executors + updated_executors, handoff_tool_targets = self._resolve_handoffs(executors) + + return start_executor_id, updated_executors, aliases, handoff_tool_targets + def _resolve_to_id(self, candidate: str | AgentProtocol | Executor) -> str: """Resolve a participant reference into a concrete executor identifier.""" if isinstance(candidate, Executor): @@ -1705,3 +1963,77 @@ def _resolve_to_id(self, candidate: str | AgentProtocol | Executor) -> str: return self._aliases[candidate] return candidate raise TypeError(f"Invalid starting agent reference: {type(candidate).__name__}") + + def _apply_auto_tools(self, agent: ChatAgent, specialists: Mapping[str, Executor]) -> dict[str, str]: + """Attach synthetic handoff tools to a chat agent and return the target lookup table. + + Creates handoff tools for each specialist agent that this agent can route to. + The tool_targets dict maps various name formats (tool name, sanitized name, alias) + to executor IDs to enable flexible handoff target resolution. + + Args: + agent: The ChatAgent to add handoff tools to + specialists: Map of executor IDs or factory names to specialist executors this agent can hand off to + + Returns: + Dict mapping tool names (in various formats) to executor IDs for handoff resolution + """ + chat_options = agent.chat_options + existing_tools = list(chat_options.tools or []) + existing_names = {getattr(tool, "name", "") for tool in existing_tools if hasattr(tool, "name")} + + tool_targets: dict[str, str] = {} + new_tools: list[Any] = [] + for executor in specialists.values(): + alias = executor.id + sanitized = sanitize_identifier(alias) + tool = _create_handoff_tool(alias, executor.description if isinstance(executor, AgentExecutor) else None) + if tool.name not in existing_names: + new_tools.append(tool) + # Map multiple name variations to the same executor ID for robust resolution + tool_targets[tool.name.lower()] = executor.id + tool_targets[sanitized] = executor.id + tool_targets[alias.lower()] = executor.id + + if new_tools: + chat_options.tools = existing_tools + new_tools + else: + chat_options.tools = existing_tools + + return tool_targets + + def _prepare_agent_with_handoffs( + self, + executor: AgentExecutor, + target_agents: Mapping[str, Executor], + ) -> tuple[AgentExecutor, dict[str, str]]: + """Prepare an agent by adding handoff tools for the specified target agents. + + Args: + executor: The agent executor to prepare + target_agents: Map of executor IDs to target executors this agent can hand off to + + Returns: + Tuple of (updated executor, tool_targets map) + """ + agent = getattr(executor, "_agent", None) + if not isinstance(agent, ChatAgent): + return executor, {} + + cloned_agent = _clone_chat_agent(agent) + tool_targets = self._apply_auto_tools(cloned_agent, target_agents) + if tool_targets: + middleware = _AutoHandoffMiddleware(tool_targets) + existing_middleware = list(cloned_agent.middleware or []) + existing_middleware.append(middleware) + cloned_agent.middleware = existing_middleware + + new_executor = AgentExecutor( + cloned_agent, + agent_thread=getattr(executor, "_agent_thread", None), + output_response=getattr(executor, "_output_response", False), + id=executor.id, + ) + return new_executor, tool_targets + + # endregion Internal Helper Methods diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index a24fd77b16..cdbc79e0c0 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -25,7 +25,7 @@ from ._base_group_chat_orchestrator import BaseGroupChatOrchestrator from ._checkpoint import CheckpointStorage, WorkflowCheckpoint -from ._const import EXECUTOR_STATE_KEY +from ._const import EXECUTOR_STATE_KEY, WORKFLOW_RUN_KWARGS_KEY from ._events import AgentRunUpdateEvent, WorkflowEvent from ._executor import Executor, handler from ._group_chat import ( @@ -286,12 +286,14 @@ class _MagenticStartMessage(DictConvertible): """Internal: A message to start a magentic workflow.""" messages: list[ChatMessage] = field(default_factory=_new_chat_message_list) + run_kwargs: dict[str, Any] = field(default_factory=dict) def __init__( self, messages: str | ChatMessage | Sequence[str] | Sequence[ChatMessage] | None = None, *, task: ChatMessage | None = None, + run_kwargs: dict[str, Any] | None = None, ) -> None: normalized = normalize_messages_input(messages) if task is not None: @@ -299,6 +301,7 @@ def __init__( if not normalized: raise ValueError("MagenticStartMessage requires at least one message input.") self.messages: list[ChatMessage] = normalized + self.run_kwargs: dict[str, Any] = run_kwargs or {} @property def task(self) -> ChatMessage: @@ -1179,6 +1182,10 @@ async def handle_start_message( return logger.info("Magentic Orchestrator: Received start message") + # Store run_kwargs in SharedState so agent executors can access them + # Always store (even empty dict) so retrieval is deterministic + await context.set_shared_state(WORKFLOW_RUN_KWARGS_KEY, message.run_kwargs or {}) + self._context = MagenticContext( task=message.task, participant_descriptions=self._participants, @@ -2004,10 +2011,12 @@ async def _invoke_agent( """ logger.debug(f"Agent {self._agent_id}: Running with {len(self._chat_history)} messages") + run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + updates: list[AgentRunResponseUpdate] = [] # The wrapped participant is guaranteed to be an BaseAgent when this is called. agent = cast("AgentProtocol", self._agent) - async for update in agent.run_stream(messages=self._chat_history): # type: ignore[attr-defined] + async for update in agent.run_stream(messages=self._chat_history, **run_kwargs): # type: ignore[attr-defined] updates.append(update) await self._emit_agent_delta_event(ctx, update) @@ -2604,38 +2613,48 @@ def workflow(self) -> Workflow: """Access the underlying workflow.""" return self._workflow - async def run_streaming_with_string(self, task_text: str) -> AsyncIterable[WorkflowEvent]: + async def run_streaming_with_string(self, task_text: str, **kwargs: Any) -> AsyncIterable[WorkflowEvent]: """Run the workflow with a task string. Args: task_text: The task description as a string. + **kwargs: Additional keyword arguments to pass through to agent invocations. + These kwargs will be available in @ai_function tools via **kwargs. Yields: WorkflowEvent: The events generated during the workflow execution. """ start_message = _MagenticStartMessage.from_string(task_text) + start_message.run_kwargs = kwargs async for event in self._workflow.run_stream(start_message): yield event - async def run_streaming_with_message(self, task_message: ChatMessage) -> AsyncIterable[WorkflowEvent]: + async def run_streaming_with_message( + self, task_message: ChatMessage, **kwargs: Any + ) -> AsyncIterable[WorkflowEvent]: """Run the workflow with a ChatMessage. Args: task_message: The task as a ChatMessage. + **kwargs: Additional keyword arguments to pass through to agent invocations. + These kwargs will be available in @ai_function tools via **kwargs. Yields: WorkflowEvent: The events generated during the workflow execution. """ - start_message = _MagenticStartMessage(task_message) + start_message = _MagenticStartMessage(task_message, run_kwargs=kwargs) async for event in self._workflow.run_stream(start_message): yield event - async def run_stream(self, message: Any | None = None) -> AsyncIterable[WorkflowEvent]: + async def run_stream(self, message: Any | None = None, **kwargs: Any) -> AsyncIterable[WorkflowEvent]: """Run the workflow with either a message object or the preset task string. Args: message: The message to send. If None and task_text was provided during construction, uses the preset task string. + **kwargs: Additional keyword arguments to pass through to agent invocations. + These kwargs will be available in @ai_function tools via **kwargs. + Example: workflow.run_stream("task", user_id="123", custom_data={...}) Yields: WorkflowEvent: The events generated during the workflow execution. @@ -2643,13 +2662,19 @@ async def run_stream(self, message: Any | None = None) -> AsyncIterable[Workflow if message is None: if self._task_text is None: raise ValueError("No message provided and no preset task text available") - message = _MagenticStartMessage.from_string(self._task_text) + start_message = _MagenticStartMessage.from_string(self._task_text) elif isinstance(message, str): - message = _MagenticStartMessage.from_string(message) + start_message = _MagenticStartMessage.from_string(message) elif isinstance(message, (ChatMessage, list)): - message = _MagenticStartMessage(message) # type: ignore[arg-type] + start_message = _MagenticStartMessage(message) # type: ignore[arg-type] + else: + start_message = message - async for event in self._workflow.run_stream(message): + # Attach kwargs to the start message + if isinstance(start_message, _MagenticStartMessage): + start_message.run_kwargs = kwargs + + async for event in self._workflow.run_stream(start_message): yield event async def _validate_checkpoint_participants( @@ -2730,46 +2755,49 @@ async def _validate_checkpoint_participants( f"Missing names: {missing}; unexpected names: {unexpected}." ) - async def run_with_string(self, task_text: str) -> WorkflowRunResult: + async def run_with_string(self, task_text: str, **kwargs: Any) -> WorkflowRunResult: """Run the workflow with a task string and return all events. Args: task_text: The task description as a string. + **kwargs: Additional keyword arguments to pass through to agent invocations. Returns: WorkflowRunResult: All events generated during the workflow execution. """ events: list[WorkflowEvent] = [] - async for event in self.run_streaming_with_string(task_text): + async for event in self.run_streaming_with_string(task_text, **kwargs): events.append(event) return WorkflowRunResult(events) - async def run_with_message(self, task_message: ChatMessage) -> WorkflowRunResult: + async def run_with_message(self, task_message: ChatMessage, **kwargs: Any) -> WorkflowRunResult: """Run the workflow with a ChatMessage and return all events. Args: task_message: The task as a ChatMessage. + **kwargs: Additional keyword arguments to pass through to agent invocations. Returns: WorkflowRunResult: All events generated during the workflow execution. """ events: list[WorkflowEvent] = [] - async for event in self.run_streaming_with_message(task_message): + async for event in self.run_streaming_with_message(task_message, **kwargs): events.append(event) return WorkflowRunResult(events) - async def run(self, message: Any | None = None) -> WorkflowRunResult: + async def run(self, message: Any | None = None, **kwargs: Any) -> WorkflowRunResult: """Run the workflow and return all events. Args: message: The message to send. If None and task_text was provided during construction, uses the preset task string. + **kwargs: Additional keyword arguments to pass through to agent invocations. Returns: WorkflowRunResult: All events generated during the workflow execution. """ events: list[WorkflowEvent] = [] - async for event in self.run_stream(message): + async for event in self.run_stream(message, **kwargs): events.append(event) return WorkflowRunResult(events) diff --git a/python/packages/core/agent_framework/_workflows/_participant_utils.py b/python/packages/core/agent_framework/_workflows/_participant_utils.py index ac632a917d..a6f1cf2a84 100644 --- a/python/packages/core/agent_framework/_workflows/_participant_utils.py +++ b/python/packages/core/agent_framework/_workflows/_participant_utils.py @@ -47,15 +47,13 @@ def wrap_participant(participant: AgentProtocol | Executor, *, executor_id: str """Represent `participant` as an `Executor`.""" if isinstance(participant, Executor): return participant + if not isinstance(participant, AgentProtocol): raise TypeError( f"Participants must implement AgentProtocol or be Executor instances. Got {type(participant).__name__}." ) - name = getattr(participant, "name", None) - if executor_id is None: - if not name: - raise ValueError("Agent participants must expose a stable 'name' attribute.") - executor_id = str(name) + + executor_id = executor_id or participant.display_name return AgentExecutor(participant, id=executor_id) diff --git a/python/packages/core/agent_framework/_workflows/_sequential.py b/python/packages/core/agent_framework/_workflows/_sequential.py index 0f849926b6..24ae4cda29 100644 --- a/python/packages/core/agent_framework/_workflows/_sequential.py +++ b/python/packages/core/agent_framework/_workflows/_sequential.py @@ -154,6 +154,9 @@ def register_participants( "Cannot mix .participants([...]) and .register_participants() in the same builder instance." ) + if self._participant_factories: + raise ValueError("register_participants() has already been called on this builder instance.") + if not participant_factories: raise ValueError("participant_factories cannot be empty") @@ -171,6 +174,9 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Seq "Cannot mix .participants([...]) and .register_participants() in the same builder instance." ) + if self._participants: + raise ValueError("participants() has already been called on this builder instance.") + if not participants: raise ValueError("participants cannot be empty") diff --git a/python/packages/core/agent_framework/_workflows/_viz.py b/python/packages/core/agent_framework/_workflows/_viz.py index 14011cb5a5..0fcf8af32d 100644 --- a/python/packages/core/agent_framework/_workflows/_viz.py +++ b/python/packages/core/agent_framework/_workflows/_viz.py @@ -7,16 +7,16 @@ from pathlib import Path from typing import Literal -from ._edge import FanInEdgeGroup +from ._edge import FanInEdgeGroup, InternalEdgeGroup from ._workflow import Workflow # Import of WorkflowExecutor is performed lazily inside methods to avoid cycles -"""Workflow visualization module using graphviz.""" +"""Workflow visualization module using graphviz and Mermaid.""" class WorkflowViz: - """A class for visualizing workflows using graphviz.""" + """A class for visualizing workflows using graphviz and Mermaid.""" def __init__(self, workflow: Workflow): """Initialize the WorkflowViz with a workflow. @@ -26,9 +26,13 @@ def __init__(self, workflow: Workflow): """ self._workflow = workflow - def to_digraph(self) -> str: + def to_digraph(self, include_internal_executors: bool = False) -> str: """Export the workflow as a DOT format digraph string. + Args: + include_internal_executors (bool): Whether to include internal executors in the visualization. + Default is False. + Returns: A string representation of the workflow in DOT format. """ @@ -39,20 +43,37 @@ def to_digraph(self) -> str: lines.append("") # Emit the top-level workflow nodes/edges - self._emit_workflow_digraph(self._workflow, lines, indent=" ") + self._emit_workflow_digraph( + self._workflow, + lines, + indent=" ", + include_internal_executors=include_internal_executors, + ) # Emit sub-workflows hosted by WorkflowExecutor as nested clusters - self._emit_sub_workflows_digraph(self._workflow, lines, indent=" ") + self._emit_sub_workflows_digraph( + self._workflow, + lines, + indent=" ", + include_internal_executors=include_internal_executors, + ) lines.append("}") return "\n".join(lines) - def export(self, format: Literal["svg", "png", "pdf", "dot"] = "svg", filename: str | None = None) -> str: + def export( + self, + format: Literal["svg", "png", "pdf", "dot"] = "svg", + filename: str | None = None, + include_internal_executors: bool = False, + ) -> str: """Export the workflow visualization to a file or return the file path. Args: format: The output format. Supported formats: 'svg', 'png', 'pdf', 'dot'. filename: Optional filename to save the output. If None, creates a temporary file. + include_internal_executors (bool): Whether to include internal executors in the visualization. + Default is False. Returns: The path to the saved file. @@ -66,7 +87,7 @@ def export(self, format: Literal["svg", "png", "pdf", "dot"] = "svg", filename: raise ValueError(f"Unsupported format: {format}. Supported formats: svg, png, pdf, dot") if format == "dot": - content = self.to_digraph() + content = self.to_digraph(include_internal_executors=include_internal_executors) if filename: with open(filename, "w", encoding="utf-8") as f: f.write(content) @@ -87,7 +108,7 @@ def export(self, format: Literal["svg", "png", "pdf", "dot"] = "svg", filename: ) from e # Create a temporary graphviz Source object - dot_content = self.to_digraph() + dot_content = self.to_digraph(include_internal_executors=include_internal_executors) source = graphviz.Source(dot_content) try: @@ -99,7 +120,7 @@ def export(self, format: Literal["svg", "png", "pdf", "dot"] = "svg", filename: # Remove extension if present since graphviz.render() adds it base_name = str(output_path.with_suffix("")) - source.render(base_name, format=format, cleanup=True) + source.render(base_name, format=format, cleanup=True) # type: ignore # Return the actual filename with extension return f"{base_name}.{format}" @@ -108,7 +129,7 @@ def export(self, format: Literal["svg", "png", "pdf", "dot"] = "svg", filename: temp_path = Path(temp_file.name) base_name = str(temp_path.with_suffix("")) - source.render(base_name, format=format, cleanup=True) + source.render(base_name, format=format, cleanup=True) # type: ignore return f"{base_name}.{format}" except graphviz.backend.execute.ExecutableNotFound as e: raise ImportError( @@ -118,60 +139,72 @@ def export(self, format: Literal["svg", "png", "pdf", "dot"] = "svg", filename: "brew install graphviz on macOS, or download from https://graphviz.org/download/ for other platforms." ) from e - def save_svg(self, filename: str) -> str: + def save_svg(self, filename: str, include_internal_executors: bool = False) -> str: """Convenience method to save as SVG. Args: filename: The filename to save the SVG file. + include_internal_executors (bool): Whether to include internal executors in the visualization. + Default is False. Returns: The path to the saved SVG file. """ - return self.export(format="svg", filename=filename) + return self.export(format="svg", filename=filename, include_internal_executors=include_internal_executors) - def save_png(self, filename: str) -> str: + def save_png(self, filename: str, include_internal_executors: bool = False) -> str: """Convenience method to save as PNG. Args: filename: The filename to save the PNG file. + include_internal_executors (bool): Whether to include internal executors in the visualization. + Default is False. Returns: The path to the saved PNG file. """ - return self.export(format="png", filename=filename) + return self.export(format="png", filename=filename, include_internal_executors=include_internal_executors) - def save_pdf(self, filename: str) -> str: + def save_pdf(self, filename: str, include_internal_executors: bool = False) -> str: """Convenience method to save as PDF. Args: filename: The filename to save the PDF file. + include_internal_executors (bool): Whether to include internal executors in the visualization. + Default is False. Returns: The path to the saved PDF file. """ - return self.export(format="pdf", filename=filename) + return self.export(format="pdf", filename=filename, include_internal_executors=include_internal_executors) - def to_mermaid(self) -> str: + def to_mermaid(self, include_internal_executors: bool = False) -> str: """Export the workflow as a Mermaid flowchart string. + Args: + include_internal_executors (bool): Whether to include internal executors in the visualization. + Default is False. + Returns: A string representation of the workflow in Mermaid flowchart syntax. """ - - def _san(s: str) -> str: - """Sanitize an ID for Mermaid (alphanumeric and underscore, start with letter).""" - s2 = re.sub(r"[^0-9A-Za-z_]", "_", s) - if not s2 or not s2[0].isalpha(): - s2 = f"n_{s2}" - return s2 - lines: list[str] = ["flowchart TD"] # Emit top-level workflow - self._emit_workflow_mermaid(self._workflow, lines, indent=" ") + self._emit_workflow_mermaid( + self._workflow, + lines, + indent=" ", + include_internal_executors=include_internal_executors, + ) # Emit sub-workflows as Mermaid subgraphs - self._emit_sub_workflows_mermaid(self._workflow, lines, indent=" ") + self._emit_sub_workflows_mermaid( + self._workflow, + lines, + indent=" ", + include_internal_executors=include_internal_executors, + ) return "\n".join(lines) @@ -181,13 +214,13 @@ def _fan_in_digest(self, target: str, sources: list[str]) -> str: sources_sorted = sorted(sources) return hashlib.sha256((target + "|" + "|".join(sources_sorted)).encode("utf-8")).hexdigest()[:8] - def _compute_fan_in_descriptors(self, wf: Workflow | None = None) -> list[tuple[str, list[str], str]]: + def _compute_fan_in_descriptors(self, workflow: Workflow | None = None) -> list[tuple[str, list[str], str]]: """Return list of (node_id, sources, target) for fan-in groups. node_id is DOT-oriented: fan_in::target::digest """ result: list[tuple[str, list[str], str]] = [] - workflow = wf or self._workflow + workflow = workflow or self._workflow for group in workflow.edge_groups: if isinstance(group, FanInEdgeGroup): target = group.target_executor_ids[0] @@ -197,13 +230,19 @@ def _compute_fan_in_descriptors(self, wf: Workflow | None = None) -> list[tuple[ result.append((node_id, sorted(sources), target)) return result - def _compute_normal_edges(self, wf: Workflow | None = None) -> list[tuple[str, str, bool]]: + def _compute_normal_edges( + self, + workflow: Workflow | None = None, + include_internal_executors: bool = False, + ) -> list[tuple[str, str, bool]]: """Return list of (source_id, target_id, is_conditional) for non-fan-in groups.""" edges: list[tuple[str, str, bool]] = [] - workflow = wf or self._workflow + workflow = workflow or self._workflow for group in workflow.edge_groups: if isinstance(group, FanInEdgeGroup): continue + if isinstance(group, InternalEdgeGroup) and not include_internal_executors: + continue for edge in group.edges: is_cond = getattr(edge, "_condition", None) is not None edges.append((edge.source_id, edge.target_id, is_cond)) @@ -213,7 +252,14 @@ def _compute_normal_edges(self, wf: Workflow | None = None) -> list[tuple[str, s # region Internal emitters (DOT) - def _emit_workflow_digraph(self, wf: Workflow, lines: list[str], indent: str, ns: str | None = None) -> None: + def _emit_workflow_digraph( + self, + workflow: Workflow, + lines: list[str], + indent: str, + ns: str | None = None, + include_internal_executors: bool = False, + ) -> None: """Emit DOT nodes/edges for the given workflow. If ns (namespace) is provided, node ids are prefixed with f"{ns}/" for uniqueness, @@ -224,16 +270,16 @@ def map_id(x: str) -> str: return f"{ns}/{x}" if ns else x # Nodes - start_executor_id = wf.start_executor_id + start_executor_id = workflow.start_executor_id lines.append( f'{indent}"{map_id(start_executor_id)}" [fillcolor=lightgreen, label="{start_executor_id}\\n(Start)"];' ) - for executor_id in wf.executors: + for executor_id in workflow.executors: if executor_id != start_executor_id: lines.append(f'{indent}"{map_id(executor_id)}" [label="{executor_id}"];') # Fan-in nodes - fan_in_nodes = self._compute_fan_in_descriptors(wf) + fan_in_nodes = self._compute_fan_in_descriptors(workflow) if fan_in_nodes: lines.append("") for node_id, _, _ in fan_in_nodes: @@ -246,11 +292,19 @@ def map_id(x: str) -> str: lines.append(f'{indent}"{map_id(node_id)}" -> "{map_id(target)}";') # Normal edges - for src, tgt, is_cond in self._compute_normal_edges(wf): + for src, tgt, is_cond in self._compute_normal_edges( + workflow, include_internal_executors=include_internal_executors + ): edge_attr = ' [style=dashed, label="conditional"]' if is_cond else "" lines.append(f'{indent}"{map_id(src)}" -> "{map_id(tgt)}"{edge_attr};') - def _emit_sub_workflows_digraph(self, wf: Workflow, lines: list[str], indent: str) -> None: + def _emit_sub_workflows_digraph( + self, + workflow: Workflow, + lines: list[str], + indent: str, + include_internal_executors: bool = False, + ) -> None: """Emit DOT subgraphs for any WorkflowExecutor instances found in the workflow.""" # Lazy import to avoid any potential import cycles try: @@ -258,7 +312,7 @@ def _emit_sub_workflows_digraph(self, wf: Workflow, lines: list[str], indent: st except ImportError: # pragma: no cover - best-effort; if unavailable, skip subgraphs return - for exec_id, exec_obj in wf.executors.items(): + for exec_id, exec_obj in workflow.executors.items(): if isinstance(exec_obj, WorkflowExecutor) and hasattr(exec_obj, "workflow") and exec_obj.workflow: subgraph_id = f"cluster_{uuid.uuid5(uuid.NAMESPACE_OID, exec_id).hex[:8]}" lines.append(f"{indent}subgraph {subgraph_id} {{") @@ -267,10 +321,21 @@ def _emit_sub_workflows_digraph(self, wf: Workflow, lines: list[str], indent: st # Emit the nested workflow inside this cluster using a namespace ns = exec_id - self._emit_workflow_digraph(exec_obj.workflow, lines, indent=f"{indent} ", ns=ns) + self._emit_workflow_digraph( + exec_obj.workflow, + lines, + indent=f"{indent} ", + ns=ns, + include_internal_executors=include_internal_executors, + ) # Recurse into deeper nested sub-workflows - self._emit_sub_workflows_digraph(exec_obj.workflow, lines, indent=f"{indent} ") + self._emit_sub_workflows_digraph( + exec_obj.workflow, + lines, + indent=f"{indent} ", + include_internal_executors=include_internal_executors, + ) lines.append(f"{indent}}}") @@ -278,7 +343,14 @@ def _emit_sub_workflows_digraph(self, wf: Workflow, lines: list[str], indent: st # region Internal emitters (Mermaid) - def _emit_workflow_mermaid(self, wf: Workflow, lines: list[str], indent: str, ns: str | None = None) -> None: + def _emit_workflow_mermaid( + self, + workflow: Workflow, + lines: list[str], + indent: str, + ns: str | None = None, + include_internal_executors: bool = False, + ) -> None: def _san(s: str) -> str: s2 = re.sub(r"[^0-9A-Za-z_]", "_", s) if not s2 or not s2[0].isalpha(): @@ -291,15 +363,15 @@ def map_id(x: str) -> str: return _san(x) # Nodes - start_executor_id = wf.start_executor_id + start_executor_id = workflow.start_executor_id lines.append(f'{indent}{map_id(start_executor_id)}["{start_executor_id} (Start)"];') - for executor_id in wf.executors: + for executor_id in workflow.executors: if executor_id == start_executor_id: continue lines.append(f'{indent}{map_id(executor_id)}["{executor_id}"];') # Fan-in nodes - fan_in_nodes_dot = self._compute_fan_in_descriptors(wf) + fan_in_nodes_dot = self._compute_fan_in_descriptors(workflow) fan_in_nodes: list[tuple[str, list[str], str]] = [] for dot_node_id, sources, target in fan_in_nodes_dot: digest = dot_node_id.split("::")[-1] @@ -318,7 +390,9 @@ def map_id(x: str) -> str: lines.append(f"{indent}{fan_node_id} --> {map_id(target)};") # Normal edges - for src, tgt, is_cond in self._compute_normal_edges(wf): + for src, tgt, is_cond in self._compute_normal_edges( + workflow, include_internal_executors=include_internal_executors + ): s = map_id(src) t = map_id(tgt) if is_cond: @@ -326,7 +400,13 @@ def map_id(x: str) -> str: else: lines.append(f"{indent}{s} --> {t};") - def _emit_sub_workflows_mermaid(self, wf: Workflow, lines: list[str], indent: str) -> None: + def _emit_sub_workflows_mermaid( + self, + workflow: Workflow, + lines: list[str], + indent: str, + include_internal_executors: bool = False, + ) -> None: try: from ._workflow_executor import WorkflowExecutor # type: ignore except ImportError: # pragma: no cover @@ -338,14 +418,25 @@ def _san(s: str) -> str: s2 = f"n_{s2}" return s2 - for exec_id, exec_obj in wf.executors.items(): + for exec_id, exec_obj in workflow.executors.items(): if isinstance(exec_obj, WorkflowExecutor) and hasattr(exec_obj, "workflow") and exec_obj.workflow: sg_id = _san(exec_id) lines.append(f"{indent}subgraph {sg_id}") # Render nested workflow within this subgraph using namespacing - self._emit_workflow_mermaid(exec_obj.workflow, lines, indent=f"{indent} ", ns=exec_id) + self._emit_workflow_mermaid( + exec_obj.workflow, + lines, + indent=f"{indent} ", + ns=exec_id, + include_internal_executors=include_internal_executors, + ) # Recurse into deeper sub-workflows - self._emit_sub_workflows_mermaid(exec_obj.workflow, lines, indent=f"{indent} ") + self._emit_sub_workflows_mermaid( + exec_obj.workflow, + lines, + indent=f"{indent} ", + include_internal_executors=include_internal_executors, + ) lines.append(f"{indent}end") # endregion diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index caa60fbef6..7b446926fc 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -13,7 +13,7 @@ from ..observability import OtelAttr, capture_exception, create_workflow_span from ._agent import WorkflowAgent from ._checkpoint import CheckpointStorage -from ._const import DEFAULT_MAX_ITERATIONS +from ._const import DEFAULT_MAX_ITERATIONS, WORKFLOW_RUN_KWARGS_KEY from ._edge import ( EdgeGroup, FanOutEdgeGroup, @@ -291,6 +291,7 @@ async def _run_workflow_with_tracing( initial_executor_fn: Callable[[], Awaitable[None]] | None = None, reset_context: bool = True, streaming: bool = False, + run_kwargs: dict[str, Any] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Private method to run workflow with proper tracing. @@ -301,6 +302,7 @@ async def _run_workflow_with_tracing( initial_executor_fn: Optional function to execute initial executor reset_context: Whether to reset the context for a new run streaming: Whether to enable streaming mode for agents + run_kwargs: Optional kwargs to store in SharedState for agent invocations Yields: WorkflowEvent: The events generated during the workflow execution. @@ -335,6 +337,10 @@ async def _run_workflow_with_tracing( self._runner.context.reset_for_new_run() await self._shared_state.clear() + # Store run kwargs in SharedState so executors can access them + # Always store (even empty dict) so retrieval is deterministic + await self._shared_state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs or {}) + # Set streaming mode after reset self._runner_context.set_streaming(streaming) @@ -442,6 +448,7 @@ async def run_stream( *, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, + **kwargs: Any, ) -> AsyncIterable[WorkflowEvent]: """Run the workflow and stream events. @@ -457,6 +464,9 @@ async def run_stream( - With checkpoint_id: Used to load and restore the specified checkpoint - Without checkpoint_id: Enables checkpointing for this run, overriding build-time configuration + **kwargs: Additional keyword arguments to pass through to agent invocations. + These are stored in SharedState and accessible in @ai_function tools + via the **kwargs parameter. Yields: WorkflowEvent: Events generated during workflow execution. @@ -475,6 +485,17 @@ async def run_stream( async for event in workflow.run_stream("start message"): process(event) + With custom context for ai_functions: + + .. code-block:: python + + async for event in workflow.run_stream( + "analyze data", + custom_data={"endpoint": "https://api.example.com"}, + user_token={"user": "alice"}, + ): + process(event) + Enable checkpointing at runtime: .. code-block:: python @@ -524,6 +545,7 @@ async def run_stream( ), reset_context=reset_context, streaming=True, + run_kwargs=kwargs if kwargs else None, ): yield event finally: @@ -559,6 +581,7 @@ async def run( checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, + **kwargs: Any, ) -> WorkflowRunResult: """Run the workflow to completion and return all events. @@ -575,6 +598,9 @@ async def run( - Without checkpoint_id: Enables checkpointing for this run, overriding build-time configuration include_status_events: Whether to include WorkflowStatusEvent instances in the result list. + **kwargs: Additional keyword arguments to pass through to agent invocations. + These are stored in SharedState and accessible in @ai_function tools + via the **kwargs parameter. Returns: A WorkflowRunResult instance containing events generated during workflow execution. @@ -593,6 +619,16 @@ async def run( result = await workflow.run("start message") outputs = result.get_outputs() + With custom context for ai_functions: + + .. code-block:: python + + result = await workflow.run( + "analyze data", + custom_data={"endpoint": "https://api.example.com"}, + user_token={"user": "alice"}, + ) + Enable checkpointing at runtime: .. code-block:: python @@ -637,6 +673,7 @@ async def run( self._execute_with_message_or_checkpoint, message, checkpoint_id, checkpoint_storage ), reset_context=reset_context, + run_kwargs=kwargs if kwargs else None, ) ] finally: diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index cc028f337c..dccd76403b 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -11,6 +11,7 @@ from ._workflow import Workflow from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value +from ._const import WORKFLOW_RUN_KWARGS_KEY from ._events import ( RequestInfoEvent, WorkflowErrorEvent, @@ -366,8 +367,11 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) logger.debug(f"WorkflowExecutor {self.id} starting sub-workflow {self.workflow.id} execution {execution_id}") try: - # Run the sub-workflow and collect all events - result = await self.workflow.run(input_data) + # Get kwargs from parent workflow's SharedState to propagate to subworkflow + parent_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) or {} + + # Run the sub-workflow and collect all events, passing parent kwargs + result = await self.workflow.run(input_data, **parent_kwargs) logger.debug( f"WorkflowExecutor {self.id} sub-workflow {self.workflow.id} " diff --git a/python/packages/core/agent_framework/azure/_chat_client.py b/python/packages/core/agent_framework/azure/_chat_client.py index 544c0fdf5b..59f74259a4 100644 --- a/python/packages/core/agent_framework/azure/_chat_client.py +++ b/python/packages/core/agent_framework/azure/_chat_client.py @@ -21,7 +21,7 @@ use_function_invocation, ) from agent_framework.exceptions import ServiceInitializationError -from agent_framework.observability import use_observability +from agent_framework.observability import use_instrumentation from agent_framework.openai._chat_client import OpenAIBaseChatClient from ._shared import ( @@ -41,7 +41,7 @@ @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class AzureOpenAIChatClient(AzureOpenAIConfigMixin, OpenAIBaseChatClient): """Azure OpenAI Chat completion class.""" @@ -154,7 +154,7 @@ def __init__( ) @override - def _parse_text_from_choice(self, choice: Choice | ChunkChoice) -> TextContent | None: + def _parse_text_from_openai(self, choice: Choice | ChunkChoice) -> TextContent | None: """Parse the choice into a TextContent object. Overwritten from OpenAIBaseChatClient to deal with Azure On Your Data function. diff --git a/python/packages/core/agent_framework/azure/_responses_client.py b/python/packages/core/agent_framework/azure/_responses_client.py index 1d88d51688..3f6140eeeb 100644 --- a/python/packages/core/agent_framework/azure/_responses_client.py +++ b/python/packages/core/agent_framework/azure/_responses_client.py @@ -10,7 +10,7 @@ from agent_framework import use_chat_middleware, use_function_invocation from agent_framework.exceptions import ServiceInitializationError -from agent_framework.observability import use_observability +from agent_framework.observability import use_instrumentation from agent_framework.openai._responses_client import OpenAIBaseResponsesClient from ._shared import ( @@ -22,7 +22,7 @@ @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class AzureOpenAIResponsesClient(AzureOpenAIConfigMixin, OpenAIBaseResponsesClient): """Azure Responses completion class.""" diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index f3e1d9bd68..38fca796c1 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -3,15 +3,19 @@ import contextlib import json import logging +import os from collections.abc import AsyncIterable, Awaitable, Callable, Generator, Mapping from enum import Enum from functools import wraps from time import perf_counter, time_ns from typing import TYPE_CHECKING, Any, ClassVar, Final, TypeVar +from dotenv import load_dotenv from opentelemetry import metrics, trace +from opentelemetry.sdk.resources import Resource +from opentelemetry.semconv.attributes import service_attributes from opentelemetry.semconv_ai import GenAISystem, Meters, SpanAttributes -from pydantic import BaseModel, PrivateAttr +from pydantic import PrivateAttr from . import __version__ as version_info from ._logging import get_logger @@ -19,10 +23,9 @@ from .exceptions import AgentInitializationError, ChatClientInitializationError if TYPE_CHECKING: # pragma: no cover - from azure.core.credentials import TokenCredential from opentelemetry.sdk._logs.export import LogRecordExporter from opentelemetry.sdk.metrics.export import MetricExporter - from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.metrics.view import View from opentelemetry.sdk.trace.export import SpanExporter from opentelemetry.trace import Tracer from opentelemetry.util._decorator import _AgnosticContextManager # type: ignore[reportPrivateUsage] @@ -44,11 +47,14 @@ __all__ = [ "OBSERVABILITY_SETTINGS", "OtelAttr", + "configure_otel_providers", + "create_metric_views", + "create_resource", + "enable_instrumentation", "get_meter", "get_tracer", - "setup_observability", - "use_agent_observability", - "use_observability", + "use_agent_instrumentation", + "use_instrumentation", ] @@ -259,89 +265,293 @@ def __str__(self) -> str: # region Telemetry utils -def _get_otlp_exporters(endpoints: list[str]) -> list["LogRecordExporter | SpanExporter | MetricExporter"]: - """Create standard OTLP Exporters for the supplied endpoints.""" - from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter - from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +# Parse headers helper +def _parse_headers(header_str: str) -> dict[str, str]: + """Parse header string like 'key1=value1,key2=value2' into dict.""" + headers: dict[str, str] = {} + if not header_str: + return headers + for pair in header_str.split(","): + if "=" in pair: + key, value = pair.split("=", 1) + headers[key.strip()] = value.strip() + return headers + + +def _create_otlp_exporters( + endpoint: str | None = None, + protocol: str = "grpc", + headers: dict[str, str] | None = None, + traces_endpoint: str | None = None, + traces_headers: dict[str, str] | None = None, + metrics_endpoint: str | None = None, + metrics_headers: dict[str, str] | None = None, + logs_endpoint: str | None = None, + logs_headers: dict[str, str] | None = None, +) -> list["LogRecordExporter | SpanExporter | MetricExporter"]: + """Create OTLP exporters for a given endpoint and protocol. + + Args: + endpoint: The OTLP endpoint URL (used for all exporters if individual endpoints not specified). + protocol: The protocol to use ("grpc" or "http"). Default is "grpc". + headers: Optional headers to include in requests (used for all exporters if individual headers not specified). + traces_endpoint: Optional specific endpoint for traces. Overrides endpoint parameter. + traces_headers: Optional specific headers for traces. Overrides headers parameter. + metrics_endpoint: Optional specific endpoint for metrics. Overrides endpoint parameter. + metrics_headers: Optional specific headers for metrics. Overrides headers parameter. + logs_endpoint: Optional specific endpoint for logs. Overrides endpoint parameter. + logs_headers: Optional specific headers for logs. Overrides headers parameter. + + Returns: + List containing OTLPLogExporter, OTLPSpanExporter, and OTLPMetricExporter. + + Raises: + ImportError: If the required OTLP exporter package is not installed. + """ + # Determine actual endpoints and headers to use + actual_traces_endpoint = traces_endpoint or endpoint + actual_metrics_endpoint = metrics_endpoint or endpoint + actual_logs_endpoint = logs_endpoint or endpoint + actual_traces_headers = traces_headers or headers + actual_metrics_headers = metrics_headers or headers + actual_logs_headers = logs_headers or headers exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] = [] - for endpoint in endpoints: - exporters.append(OTLPLogExporter(endpoint=endpoint)) - exporters.append(OTLPSpanExporter(endpoint=endpoint)) - exporters.append(OTLPMetricExporter(endpoint=endpoint)) - return exporters + if not actual_logs_endpoint and not actual_traces_endpoint and not actual_metrics_endpoint: + return exporters + if protocol in ("grpc", "http/protobuf"): + # Import all gRPC exporters + try: + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter as GRPCLogExporter + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( + OTLPMetricExporter as GRPCMetricExporter, + ) + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter + except ImportError as exc: + raise ImportError( + "opentelemetry-exporter-otlp-proto-grpc is required for OTLP gRPC exporters. " + "Install it with: pip install opentelemetry-exporter-otlp-proto-grpc" + ) from exc + + if actual_logs_endpoint: + exporters.append( + GRPCLogExporter( + endpoint=actual_logs_endpoint, + headers=actual_logs_headers if actual_logs_headers else None, + ) + ) + if actual_traces_endpoint: + exporters.append( + GRPCSpanExporter( + endpoint=actual_traces_endpoint, + headers=actual_traces_headers if actual_traces_headers else None, + ) + ) + if actual_metrics_endpoint: + exporters.append( + GRPCMetricExporter( + endpoint=actual_metrics_endpoint, + headers=actual_metrics_headers if actual_metrics_headers else None, + ) + ) -def _get_azure_monitor_exporters( - connection_strings: list[str], - credential: "TokenCredential | None" = None, -) -> list["LogRecordExporter | SpanExporter | MetricExporter"]: - """Create Azure Monitor Exporters, based on the connection strings and optionally the credential.""" - try: - from azure.monitor.opentelemetry.exporter import ( - AzureMonitorLogExporter, - AzureMonitorMetricExporter, - AzureMonitorTraceExporter, - ) - except ImportError as e: - raise ImportError( - "azure-monitor-opentelemetry-exporter is required for Azure Monitor exporters. " - "Install it with: pip install azure-monitor-opentelemetry-exporter>=1.0.0b41" - ) from e + elif protocol == "http": + # Import all HTTP exporters + try: + from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter as HTTPLogExporter + from opentelemetry.exporter.otlp.proto.http.metric_exporter import ( + OTLPMetricExporter as HTTPMetricExporter, + ) + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter + except ImportError as exc: + raise ImportError( + "opentelemetry-exporter-otlp-proto-http is required for OTLP HTTP exporters. " + "Install it with: pip install opentelemetry-exporter-otlp-proto-http" + ) from exc + + if actual_logs_endpoint: + exporters.append( + HTTPLogExporter( + endpoint=actual_logs_endpoint, + headers=actual_logs_headers if actual_logs_headers else None, + ) + ) + if actual_traces_endpoint: + exporters.append( + HTTPSpanExporter( + endpoint=actual_traces_endpoint, + headers=actual_traces_headers if actual_traces_headers else None, + ) + ) + if actual_metrics_endpoint: + exporters.append( + HTTPMetricExporter( + endpoint=actual_metrics_endpoint, + headers=actual_metrics_headers if actual_metrics_headers else None, + ) + ) - exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] = [] - for conn_string in connection_strings: - exporters.append(AzureMonitorLogExporter(connection_string=conn_string, credential=credential)) - exporters.append(AzureMonitorTraceExporter(connection_string=conn_string, credential=credential)) - exporters.append(AzureMonitorMetricExporter(connection_string=conn_string, credential=credential)) return exporters -def get_exporters( - otlp_endpoints: list[str] | None = None, - connection_strings: list[str] | None = None, - credential: "TokenCredential | None" = None, +def _get_exporters_from_env( + env_file_path: str | None = None, + env_file_encoding: str | None = None, ) -> list["LogRecordExporter | SpanExporter | MetricExporter"]: - """Add additional exporters to the existing configuration. + """Parse OpenTelemetry environment variables and create exporters. + + This function reads standard OpenTelemetry environment variables to configure + OTLP exporters for traces, logs, and metrics. + + The following environment variables are supported: + - OTEL_EXPORTER_OTLP_ENDPOINT: Base endpoint for all signals + - OTEL_EXPORTER_OTLP_TRACES_ENDPOINT: Endpoint specifically for traces + - OTEL_EXPORTER_OTLP_METRICS_ENDPOINT: Endpoint specifically for metrics + - OTEL_EXPORTER_OTLP_LOGS_ENDPOINT: Endpoint specifically for logs + - OTEL_EXPORTER_OTLP_PROTOCOL: Protocol to use (grpc, http/protobuf) + - OTEL_EXPORTER_OTLP_HEADERS: Headers for all signals + - OTEL_EXPORTER_OTLP_TRACES_HEADERS: Headers specifically for traces + - OTEL_EXPORTER_OTLP_METRICS_HEADERS: Headers specifically for metrics + - OTEL_EXPORTER_OTLP_LOGS_HEADERS: Headers specifically for logs - If you supply exporters, those will be added to the relevant providers directly. - If you supply endpoints or connection strings, new exporters will be created and added. - OTLP_endpoints will be used to create a `OTLPLogExporter`, `OTLPMetricExporter` and `OTLPSpanExporter` - Connection_strings will be used to create AzureMonitorExporters. + Args: + env_file_path: Path to a .env file to load environment variables from. + Default is None, which loads from '.env' if present. + env_file_encoding: Encoding to use when reading the .env file. + Default is None, which uses the system default encoding. - If a endpoint or connection string is already configured, through the environment variables, it will be skipped. - If you call this method twice with the same additional endpoint or connection string, it will be added twice. + Returns: + List of configured exporters (empty if no relevant env vars are set). - Args: - otlp_endpoints: A list of OpenTelemetry Protocol (OTLP) endpoints. Default is None. - connection_strings: A list of Azure Monitor connection strings. Default is None. - credential: The credential to use for Azure Monitor Entra ID authentication. Default is None. + References: + - https://opentelemetry.io/docs/languages/sdk-configuration/general/ + - https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter/ """ - new_exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] = [] - if otlp_endpoints: - new_exporters.extend(_get_otlp_exporters(endpoints=otlp_endpoints)) - - if connection_strings: - new_exporters.extend( - _get_azure_monitor_exporters( - connection_strings=connection_strings, - credential=credential, + # Load environment variables from .env file if present + load_dotenv(dotenv_path=env_file_path, encoding=env_file_encoding) + + # Get base endpoint + base_endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT") + + # Get signal-specific endpoints (these override base endpoint) + traces_endpoint = os.getenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT") or base_endpoint + metrics_endpoint = os.getenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT") or base_endpoint + logs_endpoint = os.getenv("OTEL_EXPORTER_OTLP_LOGS_ENDPOINT") or base_endpoint + + # Get protocol (default is grpc) + protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc").lower() + + # Get base headers + base_headers_str = os.getenv("OTEL_EXPORTER_OTLP_HEADERS", "") + base_headers = _parse_headers(base_headers_str) + + # Get signal-specific headers (these merge with base headers) + traces_headers_str = os.getenv("OTEL_EXPORTER_OTLP_TRACES_HEADERS", "") + metrics_headers_str = os.getenv("OTEL_EXPORTER_OTLP_METRICS_HEADERS", "") + logs_headers_str = os.getenv("OTEL_EXPORTER_OTLP_LOGS_HEADERS", "") + + traces_headers = {**base_headers, **_parse_headers(traces_headers_str)} + metrics_headers = {**base_headers, **_parse_headers(metrics_headers_str)} + logs_headers = {**base_headers, **_parse_headers(logs_headers_str)} + + # Create exporters using helper function + return _create_otlp_exporters( + protocol=protocol, + traces_endpoint=traces_endpoint, + traces_headers=traces_headers if traces_headers else None, + metrics_endpoint=metrics_endpoint, + metrics_headers=metrics_headers if metrics_headers else None, + logs_endpoint=logs_endpoint, + logs_headers=logs_headers if logs_headers else None, + ) + + +def create_resource( + service_name: str | None = None, + service_version: str | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + **attributes: Any, +) -> "Resource": + """Create an OpenTelemetry Resource from environment variables and parameters. + + This function reads standard OpenTelemetry environment variables to configure + the resource, which identifies your service in telemetry backends. + + The following environment variables are read: + - OTEL_SERVICE_NAME: The name of the service (defaults to "agent_framework") + - OTEL_SERVICE_VERSION: The version of the service (defaults to package version) + - OTEL_RESOURCE_ATTRIBUTES: Additional resource attributes as key=value pairs + + Args: + service_name: Override the service name. If not provided, reads from + OTEL_SERVICE_NAME environment variable or defaults to "agent_framework". + service_version: Override the service version. If not provided, reads from + OTEL_SERVICE_VERSION environment variable or defaults to the package version. + env_file_path: Path to a .env file to load environment variables from. + Default is None, which loads from '.env' if present. + env_file_encoding: Encoding to use when reading the .env file. + Default is None, which uses the system default encoding. + **attributes: Additional resource attributes to include. These will be merged + with attributes from OTEL_RESOURCE_ATTRIBUTES environment variable. + + Returns: + A configured OpenTelemetry Resource instance. + + Examples: + .. code-block:: python + + from agent_framework.observability import create_resource + + # Use defaults from environment variables + resource = create_resource() + + # Override service name + resource = create_resource(service_name="my_service") + + # Add custom attributes + resource = create_resource( + service_name="my_service", service_version="1.0.0", deployment_environment="production" ) - ) - return new_exporters + # Load from custom .env file + resource = create_resource(env_file_path="config/.env") + """ + # Load environment variables from .env file if present + load_dotenv(dotenv_path=env_file_path, encoding=env_file_encoding) + + # Start with provided attributes + resource_attributes: dict[str, Any] = dict(attributes) + + # Set service name + if service_name is None: + service_name = os.getenv("OTEL_SERVICE_NAME", "agent_framework") + resource_attributes[service_attributes.SERVICE_NAME] = service_name -def _create_resource() -> "Resource": - import os + # Set service version + if service_version is None: + service_version = os.getenv("OTEL_SERVICE_VERSION", version_info) + resource_attributes[service_attributes.SERVICE_VERSION] = service_version - from opentelemetry.sdk.resources import Resource - from opentelemetry.semconv.attributes import service_attributes + # Parse OTEL_RESOURCE_ATTRIBUTES environment variable + # Format: key1=value1,key2=value2 + if resource_attrs_env := os.getenv("OTEL_RESOURCE_ATTRIBUTES"): + resource_attributes.update(_parse_headers(resource_attrs_env)) + return Resource.create(resource_attributes) - service_name = os.getenv("OTEL_SERVICE_NAME", "agent_framework") - return Resource.create({service_attributes.SERVICE_NAME: service_name}) +def create_metric_views() -> list["View"]: + """Create the default OpenTelemetry metric views for Agent Framework.""" + from opentelemetry.sdk.metrics.view import DropAggregation, View + + return [ + # Dropping all enable_instrumentation names except for those starting with "agent_framework" + View(instrument_name="agent_framework*"), + View(instrument_name="gen_ai*"), + View(instrument_name="*", aggregation=DropAggregation()), + ] class ObservabilitySettings(AFBaseSettings): @@ -357,14 +567,12 @@ class ObservabilitySettings(AFBaseSettings): Sensitive events should only be enabled on test and development environments. Keyword Args: - enable_otel: Enable OpenTelemetry diagnostics. Default is False. - Can be set via environment variable ENABLE_OTEL. + enable_instrumentation: Enable OpenTelemetry diagnostics. Default is False. + Can be set via environment variable ENABLE_INSTRUMENTATION. enable_sensitive_data: Enable OpenTelemetry sensitive events. Default is False. Can be set via environment variable ENABLE_SENSITIVE_DATA. - applicationinsights_connection_string: The Azure Monitor connection string. Default is None. - Can be set via environment variable APPLICATIONINSIGHTS_CONNECTION_STRING. - otlp_endpoint: The OpenTelemetry Protocol (OTLP) endpoint. Default is None. - Can be set via environment variable OTLP_ENDPOINT. + enable_console_exporters: Enable console exporters for traces, logs, and metrics. + Default is False. Can be set via environment variable ENABLE_CONSOLE_EXPORTERS. vs_code_extension_port: The port the AI Toolkit or Azure AI Foundry VS Code extensions are listening on. Default is None. Can be set via environment variable VS_CODE_EXTENSION_PORT. @@ -375,33 +583,39 @@ class ObservabilitySettings(AFBaseSettings): from agent_framework import ObservabilitySettings # Using environment variables - # Set ENABLE_OTEL=true - # Set APPLICATIONINSIGHTS_CONNECTION_STRING=InstrumentationKey=... + # Set ENABLE_INSTRUMENTATION=true + # Set ENABLE_CONSOLE_EXPORTERS=true settings = ObservabilitySettings() # Or passing parameters directly - settings = ObservabilitySettings( - enable_otel=True, applicationinsights_connection_string="InstrumentationKey=..." - ) + settings = ObservabilitySettings(enable_instrumentation=True, enable_console_exporters=True) """ env_prefix: ClassVar[str] = "" - enable_otel: bool = False + enable_instrumentation: bool = False enable_sensitive_data: bool = False - applicationinsights_connection_string: str | list[str] | None = None - otlp_endpoint: str | list[str] | None = None + enable_console_exporters: bool = False vs_code_extension_port: int | None = None - _resource: "Resource" = PrivateAttr(default_factory=_create_resource) + _resource: "Resource" = PrivateAttr() _executed_setup: bool = PrivateAttr(default=False) + def __init__(self, **kwargs: Any) -> None: + """Initialize the settings and create the resource.""" + super().__init__(**kwargs) + # Create resource with env file settings + self._resource = create_resource( + env_file_path=self.env_file_path, + env_file_encoding=self.env_file_encoding, + ) + @property def ENABLED(self) -> bool: """Check if model diagnostics are enabled. Model diagnostics are enabled if either diagnostic is enabled or diagnostic with sensitive events is enabled. """ - return self.enable_otel or self.enable_sensitive_data + return self.enable_instrumentation @property def SENSITIVE_DATA_ENABLED(self) -> bool: @@ -409,27 +623,18 @@ def SENSITIVE_DATA_ENABLED(self) -> bool: Sensitive events are enabled if the diagnostic with sensitive events is enabled. """ - return self.enable_sensitive_data + return self.enable_instrumentation and self.enable_sensitive_data @property def is_setup(self) -> bool: """Check if the setup has been executed.""" return self._executed_setup - @property - def resource(self) -> "Resource": - """Get the resource.""" - return self._resource - - @resource.setter - def resource(self, value: "Resource") -> None: - """Set the resource.""" - self._resource = value - def _configure( self, - credential: "TokenCredential | None" = None, + *, additional_exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] | None = None, + views: list["View"] | None = None, ) -> None: """Configure application-wide observability based on the settings. @@ -438,120 +643,102 @@ def _configure( will have no effect. Args: - credential: The credential to use for Azure Monitor Entra ID authentication. Default is None. additional_exporters: A list of additional exporters to add to the configuration. Default is None. + views: Optional list of OpenTelemetry views for metrics. Default is None. """ if not self.ENABLED or self._executed_setup: return - exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] = additional_exporters or [] - if self.otlp_endpoint: - exporters.extend( - _get_otlp_exporters( - self.otlp_endpoint if isinstance(self.otlp_endpoint, list) else [self.otlp_endpoint] - ) - ) - if self.applicationinsights_connection_string: - exporters.extend( - _get_azure_monitor_exporters( - connection_strings=( - self.applicationinsights_connection_string - if isinstance(self.applicationinsights_connection_string, list) - else [self.applicationinsights_connection_string] - ), - credential=credential, - ) + exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] = [] + + # 1. Add exporters from standard OTEL environment variables + exporters.extend( + _get_exporters_from_env( + env_file_path=self.env_file_path, + env_file_encoding=self.env_file_encoding, ) - self._configure_providers(exporters) - self._executed_setup = True + ) - def check_endpoint_already_configured(self, otlp_endpoint: str) -> bool: - """Check if the endpoint is already configured. + # 2. Add passed-in exporters + if additional_exporters: + exporters.extend(additional_exporters) - Returns: - True if the endpoint is already configured, False otherwise. - """ - if not self.otlp_endpoint: - return False - return otlp_endpoint in (self.otlp_endpoint if isinstance(self.otlp_endpoint, list) else [self.otlp_endpoint]) + # 3. Add console exporters if explicitly enabled + if self.enable_console_exporters: + from opentelemetry.sdk._logs.export import ConsoleLogRecordExporter + from opentelemetry.sdk.metrics.export import ConsoleMetricExporter + from opentelemetry.sdk.trace.export import ConsoleSpanExporter - def check_connection_string_already_configured(self, connection_string: str) -> bool: - """Check if the connection string is already configured. + exporters.extend([ConsoleSpanExporter(), ConsoleLogRecordExporter(), ConsoleMetricExporter()]) - Returns: - True if the connection string is already configured, False otherwise. - """ - if not self.applicationinsights_connection_string: - return False - return connection_string in ( - self.applicationinsights_connection_string - if isinstance(self.applicationinsights_connection_string, list) - else [self.applicationinsights_connection_string] - ) + # 4. Add VS Code extension exporters if port is specified + if self.vs_code_extension_port: + endpoint = f"http://localhost:{self.vs_code_extension_port}" + exporters.extend(_create_otlp_exporters(endpoint=endpoint, protocol="grpc")) + + # 5. Configure providers + self._configure_providers(exporters, views=views) + self._executed_setup = True + + def _configure_providers( + self, + exporters: list["LogRecordExporter | MetricExporter | SpanExporter"], + views: list["View"] | None = None, + ) -> None: + """Configure tracing, logging, events and metrics with the provided exporters. - def _configure_providers(self, exporters: list["LogRecordExporter | MetricExporter | SpanExporter"]) -> None: - """Configure tracing, logging, events and metrics with the provided exporters.""" + Args: + exporters: A list of exporters for logs, metrics and/or spans. + views: Optional list of OpenTelemetry views for metrics. Default is empty list. + """ from opentelemetry._logs import set_logger_provider from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, LogRecordExporter from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics.export import MetricExporter, PeriodicExportingMetricReader - from opentelemetry.sdk.metrics.view import DropAggregation, View from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter + span_exporters: list[SpanExporter] = [] + log_exporters: list[LogRecordExporter] = [] + metric_exporters: list[MetricExporter] = [] + for exp in exporters: + if isinstance(exp, SpanExporter): + span_exporters.append(exp) + if isinstance(exp, LogRecordExporter): + log_exporters.append(exp) + if isinstance(exp, MetricExporter): + metric_exporters.append(exp) + # Tracing - tracer_provider = TracerProvider(resource=self.resource) - trace.set_tracer_provider(tracer_provider) - should_add_console_exporter = True - for exporter in exporters: - if isinstance(exporter, SpanExporter): + if span_exporters: + tracer_provider = TracerProvider(resource=self._resource) + trace.set_tracer_provider(tracer_provider) + for exporter in span_exporters: tracer_provider.add_span_processor(BatchSpanProcessor(exporter)) - should_add_console_exporter = False - if should_add_console_exporter: - from opentelemetry.sdk.trace.export import ConsoleSpanExporter - - tracer_provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter())) # Logging - logger_provider = LoggerProvider(resource=self.resource) - should_add_console_exporter = True - for exporter in exporters: - if isinstance(exporter, LogRecordExporter): - logger_provider.add_log_record_processor(BatchLogRecordProcessor(exporter)) - should_add_console_exporter = False - if should_add_console_exporter: - from opentelemetry.sdk._logs.export import ConsoleLogRecordExporter - - logger_provider.add_log_record_processor(BatchLogRecordProcessor(ConsoleLogRecordExporter())) - - # Attach a handler with the provider to the root logger - logger = logging.getLogger() - handler = LoggingHandler(logger_provider=logger_provider) - logger.addHandler(handler) - set_logger_provider(logger_provider) + if log_exporters: + logger_provider = LoggerProvider(resource=self._resource) + for log_exporter in log_exporters: + logger_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter)) + # Attach a handler with the provider to the root logger + logger = logging.getLogger() + handler = LoggingHandler(logger_provider=logger_provider) + logger.addHandler(handler) + set_logger_provider(logger_provider) # metrics - metric_readers = [ - PeriodicExportingMetricReader(exporter, export_interval_millis=5000) - for exporter in exporters - if isinstance(exporter, MetricExporter) - ] - if not metric_readers: - from opentelemetry.sdk.metrics.export import ConsoleMetricExporter - - metric_readers = [PeriodicExportingMetricReader(ConsoleMetricExporter(), export_interval_millis=5000)] - meter_provider = MeterProvider( - metric_readers=metric_readers, - resource=self.resource, - views=[ - # Dropping all instrument names except for those starting with "agent_framework" - View(instrument_name="*", aggregation=DropAggregation()), - View(instrument_name="agent_framework*"), - View(instrument_name="gen_ai*"), - ], - ) - metrics.set_meter_provider(meter_provider) + if metric_exporters: + meter_provider = MeterProvider( + metric_readers=[ + PeriodicExportingMetricReader(exporter, export_interval_millis=5000) + for exporter in metric_exporters + ], + resource=self._resource, + views=views or [], + ) + metrics.set_meter_provider(meter_provider) def get_tracer( @@ -661,125 +848,174 @@ def get_meter( OBSERVABILITY_SETTINGS: ObservabilitySettings = ObservabilitySettings() -def setup_observability( +def enable_instrumentation( + *, + enable_sensitive_data: bool | None = None, +) -> None: + """Enable instrumentation for your application. + + Calling this method implies you want to enable observability in your application. + + This method does not configure exporters or providers. + It only updates the global variables that trigger the instrumentation code. + If you have already set the environment variable ENABLE_INSTRUMENTATION=true, + calling this method has no effect, unless you want to enable or disable sensitive data events. + + Keyword Args: + enable_sensitive_data: Enable OpenTelemetry sensitive events. Overrides + the environment variable ENABLE_SENSITIVE_DATA if set. Default is None. + """ + global OBSERVABILITY_SETTINGS + OBSERVABILITY_SETTINGS.enable_instrumentation = True + if enable_sensitive_data is not None: + OBSERVABILITY_SETTINGS.enable_sensitive_data = enable_sensitive_data + + +def configure_otel_providers( + *, enable_sensitive_data: bool | None = None, - otlp_endpoint: str | list[str] | None = None, - applicationinsights_connection_string: str | list[str] | None = None, - credential: "TokenCredential | None" = None, exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] | None = None, + views: list["View"] | None = None, vs_code_extension_port: int | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, ) -> None: - """Setup observability for the application with OpenTelemetry. + """Configure otel providers and enable instrumentation for the application with OpenTelemetry. This method creates the exporters and providers for the application based on - the provided values and environment variables. + the provided values and environment variables and enables instrumentation. Call this method once during application startup, before any telemetry is captured. DO NOT call this method multiple times, as it may lead to unexpected behavior. + The function automatically reads standard OpenTelemetry environment variables: + - OTEL_EXPORTER_OTLP_ENDPOINT: Base OTLP endpoint for all signals + - OTEL_EXPORTER_OTLP_TRACES_ENDPOINT: OTLP endpoint for traces + - OTEL_EXPORTER_OTLP_METRICS_ENDPOINT: OTLP endpoint for metrics + - OTEL_EXPORTER_OTLP_LOGS_ENDPOINT: OTLP endpoint for logs + - OTEL_EXPORTER_OTLP_PROTOCOL: Protocol (grpc/http) + - OTEL_EXPORTER_OTLP_HEADERS: Headers for all signals + - ENABLE_CONSOLE_EXPORTERS: Enable console output for telemetry + Note: - If you have configured the providers manually, calling this method will not - have any effect. The reverse is also true - if you call this method first, - subsequent provider configurations will not take effect. + Since you can only setup one provider per signal type (logs, traces, metrics), + you can choose to use this method and take the exporter and provider that we created. + Alternatively, you can setup the providers yourself, or through another library + (e.g., Azure Monitor) and just call `enable_instrumentation()` to enable instrumentation. - Args: + Note: + By default, the Agent Framework emits metrics with the prefixes `agent_framework` + and `gen_ai` (OpenTelemetry GenAI semantic conventions). You can use the `views` + parameter to filter which metrics are collected and exported. You can also use + the `create_metric_views()` helper function to get default views. + + Keyword Args: enable_sensitive_data: Enable OpenTelemetry sensitive events. Overrides - the environment variable if set. Default is None. - otlp_endpoint: The OpenTelemetry Protocol (OTLP) endpoint. Will be used - to create OTLPLogExporter, OTLPMetricExporter and OTLPSpanExporter. - Default is None. - applicationinsights_connection_string: The Azure Monitor connection string. - Will be used to create AzureMonitorExporters. Default is None. - credential: The credential to use for Azure Monitor Entra ID authentication. + the environment variable ENABLE_SENSITIVE_DATA if set. Default is None. + exporters: A list of custom exporters for logs, metrics or spans, or any combination. + These will be added in addition to exporters configured via environment variables. Default is None. - exporters: A list of exporters for logs, metrics or spans, or any combination. - These will be added directly, allowing complete customization. Default is None. - vs_code_extension_port: The port the AI Toolkit or AzureAI Foundry VS Code + views: Optional list of OpenTelemetry views for metrics configuration. + Views allow filtering and customizing which metrics are collected. + Default is None (empty list). + vs_code_extension_port: The port the AI Toolkit or Azure AI Foundry VS Code extensions are listening on. When set, additional OTEL exporters will be - created with endpoint `http://localhost:{vs_code_extension_port}` unless - already configured. Overrides the environment variable if set. Default is None. + created with endpoint `http://localhost:{vs_code_extension_port}`. + Overrides the environment variable VS_CODE_EXTENSION_PORT if set. Default is None. + env_file_path: An optional path to a .env file to load environment variables from. + Default is None. + env_file_encoding: The encoding to use when loading the .env file. Default is None + which uses the system default encoding. Examples: .. code-block:: python - from agent_framework import setup_observability - - # With environment variables - # Set ENABLE_OTEL=true, OTLP_ENDPOINT=http://localhost:4317 - setup_observability() + from agent_framework.observability import configure_otel_providers - # With parameters (no environment variables) - setup_observability( - enable_sensitive_data=True, - otlp_endpoint="http://localhost:4317", - ) + # Using environment variables (recommended) + # Set ENABLE_INSTRUMENTATION=true + # Set OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4317 + configure_otel_providers() - # With Azure Monitor - setup_observability( - applicationinsights_connection_string="InstrumentationKey=...", - ) + # Enable console output for debugging + # Set ENABLE_CONSOLE_EXPORTERS=true + configure_otel_providers() # With custom exporters - from opentelemetry.sdk.trace.export import ConsoleSpanExporter + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter - setup_observability( - exporters=[ConsoleSpanExporter()], - ) - - # Mixed: combine environment variables and parameters - # Environment: OTLP_ENDPOINT=http://localhost:7431 - # Code adds additional endpoint - setup_observability( - enable_sensitive_data=True, - otlp_endpoint="http://localhost:4317", # Both endpoints will be used + configure_otel_providers( + exporters=[ + OTLPSpanExporter(endpoint="http://custom:4317"), + OTLPLogExporter(endpoint="http://custom:4317"), + ], ) # VS Code extension integration - setup_observability( + configure_otel_providers( vs_code_extension_port=4317, # Connects to AI Toolkit ) - """ - global OBSERVABILITY_SETTINGS - # Update the observability settings with the provided values - OBSERVABILITY_SETTINGS.enable_otel = True - if enable_sensitive_data is not None: - OBSERVABILITY_SETTINGS.enable_sensitive_data = enable_sensitive_data - if vs_code_extension_port is not None: - OBSERVABILITY_SETTINGS.vs_code_extension_port = vs_code_extension_port - - # Create exporters, after checking if they are already configured through the env. - new_exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] = exporters or [] - if otlp_endpoint: - if isinstance(otlp_endpoint, str): - otlp_endpoint = [otlp_endpoint] - new_exporters.extend( - _get_otlp_exporters( - endpoints=[ - endpoint - for endpoint in otlp_endpoint - if not OBSERVABILITY_SETTINGS.check_endpoint_already_configured(endpoint) - ] + + # Enable sensitive data logging (development only) + configure_otel_providers( + enable_sensitive_data=True, ) - ) - if applicationinsights_connection_string: - if isinstance(applicationinsights_connection_string, str): - applicationinsights_connection_string = [applicationinsights_connection_string] - new_exporters.extend( - _get_azure_monitor_exporters( - connection_strings=[ - conn_str - for conn_str in applicationinsights_connection_string - if not OBSERVABILITY_SETTINGS.check_connection_string_already_configured(conn_str) + + # With custom metrics views + from opentelemetry.sdk.metrics.view import View + + configure_otel_providers( + views=[ + View(instrument_name="agent_framework*"), + View(instrument_name="gen_ai*"), ], - credential=credential, ) - ) - if OBSERVABILITY_SETTINGS.vs_code_extension_port: - endpoint = f"http://localhost:{OBSERVABILITY_SETTINGS.vs_code_extension_port}" - if not OBSERVABILITY_SETTINGS.check_endpoint_already_configured(endpoint): - new_exporters.extend(_get_otlp_exporters(endpoints=[endpoint])) - OBSERVABILITY_SETTINGS._configure(credential=credential, additional_exporters=new_exporters) # pyright: ignore[reportPrivateUsage] + This example shows how to first setup your providers, + and then ensure Agent Framework emits traces, logs and metrics + + .. code-block:: python + + # when azure monitor is installed + from agent_framework.observability import enable_instrumentation + from azure.monitor.opentelemetry import configure_azure_monitor + + connection_string = "InstrumentationKey=your_instrumentation_key_here;..." + configure_azure_monitor(connection_string=connection_string) + enable_instrumentation() + + References: + - https://opentelemetry.io/docs/languages/sdk-configuration/general/ + - https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter/ + """ + global OBSERVABILITY_SETTINGS + if env_file_path: + # Build kwargs, excluding None values + settings_kwargs: dict[str, Any] = { + "enable_instrumentation": True, + "env_file_path": env_file_path, + } + if env_file_encoding is not None: + settings_kwargs["env_file_encoding"] = env_file_encoding + if enable_sensitive_data is not None: + settings_kwargs["enable_sensitive_data"] = enable_sensitive_data + if vs_code_extension_port is not None: + settings_kwargs["vs_code_extension_port"] = vs_code_extension_port + + OBSERVABILITY_SETTINGS = ObservabilitySettings(**settings_kwargs) + else: + # Update the observability settings with the provided values + OBSERVABILITY_SETTINGS.enable_instrumentation = True + if enable_sensitive_data is not None: + OBSERVABILITY_SETTINGS.enable_sensitive_data = enable_sensitive_data + if vs_code_extension_port is not None: + OBSERVABILITY_SETTINGS.vs_code_extension_port = vs_code_extension_port + + OBSERVABILITY_SETTINGS._configure( # type: ignore[reportPrivateUsage] + additional_exporters=exporters, + views=views, + ) # region Chat Client Telemetry @@ -993,7 +1229,7 @@ async def trace_get_streaming_response( return decorator(func) -def use_observability( +def use_instrumentation( chat_client: type[TChatClient], ) -> type[TChatClient]: """Class decorator that enables OpenTelemetry observability for a chat client. @@ -1019,12 +1255,12 @@ def use_observability( Examples: .. code-block:: python - from agent_framework import use_observability, setup_observability + from agent_framework import use_instrumentation, configure_otel_providers from agent_framework import ChatClientProtocol # Decorate a custom chat client class - @use_observability + @use_instrumentation class MyCustomChatClient: OTEL_PROVIDER_NAME = "my_provider" @@ -1038,7 +1274,7 @@ async def get_streaming_response(self, messages, **kwargs): # Setup observability - setup_observability(otlp_endpoint="http://localhost:4317") + configure_otel_providers(otlp_endpoint="http://localhost:4317") # Now all calls will be traced client = MyCustomChatClient() @@ -1082,12 +1318,14 @@ async def get_streaming_response(self, messages, **kwargs): def _trace_agent_run( run_func: Callable[..., Awaitable["AgentRunResponse"]], provider_name: str, + capture_usage: bool = True, ) -> Callable[..., Awaitable["AgentRunResponse"]]: """Decorator to trace chat completion activities. Args: run_func: The function to trace. provider_name: The system name used for Open Telemetry. + capture_usage: Whether to capture token usage as a span attribute. """ @wraps(run_func) @@ -1128,7 +1366,7 @@ async def trace_run( capture_exception(span=span, exception=exception, timestamp=time_ns()) raise else: - attributes = _get_response_attributes(attributes, response) + attributes = _get_response_attributes(attributes, response, capture_usage=capture_usage) _capture_response(span=span, attributes=attributes) if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages: _capture_messages( @@ -1145,12 +1383,14 @@ async def trace_run( def _trace_agent_run_stream( run_streaming_func: Callable[..., AsyncIterable["AgentRunResponseUpdate"]], provider_name: str, + capture_usage: bool, ) -> Callable[..., AsyncIterable["AgentRunResponseUpdate"]]: """Decorator to trace streaming agent run activities. Args: run_streaming_func: The function to trace. provider_name: The system name used for Open Telemetry. + capture_usage: Whether to capture token usage as a span attribute. """ @wraps(run_streaming_func) @@ -1201,7 +1441,7 @@ async def trace_run_streaming( raise else: response = AgentRunResponse.from_agent_run_response_updates(all_updates) - attributes = _get_response_attributes(attributes, response) + attributes = _get_response_attributes(attributes, response, capture_usage=capture_usage) _capture_response(span=span, attributes=attributes) if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages: _capture_messages( @@ -1214,9 +1454,11 @@ async def trace_run_streaming( return trace_run_streaming -def use_agent_observability( - agent: type[TAgent], -) -> type[TAgent]: +def use_agent_instrumentation( + agent: type[TAgent] | None = None, + *, + capture_usage: bool = True, +) -> type[TAgent] | Callable[[type[TAgent]], type[TAgent]]: """Class decorator that enables OpenTelemetry observability for an agent. This decorator automatically traces agent run requests, captures events, @@ -1224,12 +1466,17 @@ def use_agent_observability( Note: This decorator must be applied to the agent class itself, not an instance. - The agent class should have a class variable AGENT_SYSTEM_NAME to set the + The agent class should have a class variable AGENT_PROVIDER_NAME to set the proper system name for telemetry. Args: agent: The agent class to enable observability for. + Keyword Args: + capture_usage: Whether to capture token usage as a span attribute. + Defaults to True, set to False when the agent has underlying traces + that already capture token usage to avoid double counting. + Returns: The decorated agent class with observability enabled. @@ -1240,14 +1487,14 @@ def use_agent_observability( Examples: .. code-block:: python - from agent_framework import use_agent_observability, setup_observability + from agent_framework import use_agent_instrumentation, configure_otel_providers from agent_framework._agents import AgentProtocol # Decorate a custom agent class - @use_agent_observability + @use_agent_instrumentation class MyCustomAgent: - AGENT_SYSTEM_NAME = "my_agent_system" + AGENT_PROVIDER_NAME = "my_agent_system" async def run(self, messages=None, *, thread=None, **kwargs): # Your implementation @@ -1259,23 +1506,31 @@ async def run_stream(self, messages=None, *, thread=None, **kwargs): # Setup observability - setup_observability(otlp_endpoint="http://localhost:4317") + configure_otel_providers(otlp_endpoint="http://localhost:4317") # Now all agent runs will be traced agent = MyCustomAgent() response = await agent.run("Perform a task") """ - provider_name = str(getattr(agent, "AGENT_SYSTEM_NAME", "Unknown")) - try: - agent.run = _trace_agent_run(agent.run, provider_name) # type: ignore - except AttributeError as exc: - raise AgentInitializationError(f"The agent {agent.__name__} does not have a run method.", exc) from exc - try: - agent.run_stream = _trace_agent_run_stream(agent.run_stream, provider_name) # type: ignore - except AttributeError as exc: - raise AgentInitializationError(f"The agent {agent.__name__} does not have a run_stream method.", exc) from exc - setattr(agent, OPEN_TELEMETRY_AGENT_MARKER, True) - return agent + + def decorator(agent: type[TAgent]) -> type[TAgent]: + provider_name = str(getattr(agent, "AGENT_PROVIDER_NAME", "Unknown")) + try: + agent.run = _trace_agent_run(agent.run, provider_name, capture_usage=capture_usage) # type: ignore + except AttributeError as exc: + raise AgentInitializationError(f"The agent {agent.__name__} does not have a run method.", exc) from exc + try: + agent.run_stream = _trace_agent_run_stream(agent.run_stream, provider_name, capture_usage=capture_usage) # type: ignore + except AttributeError as exc: + raise AgentInitializationError( + f"The agent {agent.__name__} does not have a run_stream method.", exc + ) from exc + setattr(agent, OPEN_TELEMETRY_AGENT_MARKER, True) + return agent + + if agent is None: + return decorator + return decorator(agent) # region Otel Helpers @@ -1458,26 +1713,32 @@ def _to_otel_part(content: "Contents") -> dict[str, Any] | None: match content.type: case "text": return {"type": "text", "content": content.text} + case "text_reasoning": + return {"type": "reasoning", "content": content.text} + case "uri": + return { + "type": "uri", + "uri": content.uri, + "mime_type": content.media_type, + "modality": content.media_type.split("/")[0] if content.media_type else None, + } + case "data": + return { + "type": "blob", + "content": content.get_data_bytes_as_str(), + "mime_type": content.media_type, + "modality": content.media_type.split("/")[0] if content.media_type else None, + } case "function_call": return {"type": "tool_call", "id": content.call_id, "name": content.name, "arguments": content.arguments} case "function_result": - response: Any | None = None - if content.result: - if isinstance(content.result, list): - res: list[Any] = [] - for item in content.result: # type: ignore - from ._types import BaseContent - - if isinstance(item, BaseContent): - res.append(_to_otel_part(item)) # type: ignore - elif isinstance(item, BaseModel): - res.append(item.model_dump(exclude_none=True)) - else: - res.append(json.dumps(item, default=str)) - response = json.dumps(res, default=str) - else: - response = json.dumps(content.result, default=str) - return {"type": "tool_call_response", "id": content.call_id, "response": response} + from ._types import prepare_function_call_results + + return { + "type": "tool_call_response", + "id": content.call_id, + "response": prepare_function_call_results(content), + } case _: # GenericPart in otel output messages json spec. # just required type, and arbitrary other fields. @@ -1489,6 +1750,8 @@ def _get_response_attributes( attributes: dict[str, Any], response: "ChatResponse | AgentRunResponse", duration: float | None = None, + *, + capture_usage: bool = True, ) -> dict[str, Any]: """Get the response attributes from a response.""" if response.response_id: @@ -1502,7 +1765,7 @@ def _get_response_attributes( attributes[OtelAttr.FINISH_REASONS] = json.dumps([finish_reason.value]) if model_id := getattr(response, "model_id", None): attributes[SpanAttributes.LLM_RESPONSE_MODEL] = model_id - if usage := response.usage_details: + if capture_usage and (usage := response.usage_details): if usage.input_token_count: attributes[OtelAttr.INPUT_TOKENS] = usage.input_token_count if usage.output_token_count: diff --git a/python/packages/core/agent_framework/ollama/__init__.py b/python/packages/core/agent_framework/ollama/__init__.py new file mode 100644 index 0000000000..eae73853c2 --- /dev/null +++ b/python/packages/core/agent_framework/ollama/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft. All rights reserved. + +import importlib +from typing import Any + +IMPORT_PATH = "agent_framework_ollama" +PACKAGE_NAME = "agent-framework-ollama" +_IMPORTS = ["__version__", "OllamaChatClient", "OllamaSettings"] + + +def __getattr__(name: str) -> Any: + if name in _IMPORTS: + try: + return getattr(importlib.import_module(IMPORT_PATH), name) + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + f"The '{PACKAGE_NAME}' package is not installed, please do `pip install {PACKAGE_NAME}`" + ) from exc + raise AttributeError(f"Module {IMPORT_PATH} has no attribute {name}.") + + +def __dir__() -> list[str]: + return _IMPORTS diff --git a/python/packages/core/agent_framework/ollama/__init__.pyi b/python/packages/core/agent_framework/ollama/__init__.pyi new file mode 100644 index 0000000000..3a1e7824d6 --- /dev/null +++ b/python/packages/core/agent_framework/ollama/__init__.pyi @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft. All rights reserved. + +from agent_framework_ollama import ( + OllamaChatClient, + OllamaSettings, + __version__, +) + +__all__ = [ + "OllamaChatClient", + "OllamaSettings", + "__version__", +] diff --git a/python/packages/core/agent_framework/openai/_assistants_client.py b/python/packages/core/agent_framework/openai/_assistants_client.py index 0f3bb3de63..e790a44940 100644 --- a/python/packages/core/agent_framework/openai/_assistants_client.py +++ b/python/packages/core/agent_framework/openai/_assistants_client.py @@ -40,7 +40,7 @@ prepare_function_call_results, ) from ..exceptions import ServiceInitializationError -from ..observability import use_observability +from ..observability import use_instrumentation from ._shared import OpenAIConfigMixin, OpenAISettings if sys.version_info >= (3, 11): @@ -53,7 +53,7 @@ @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient): """OpenAI Assistants client.""" @@ -164,7 +164,7 @@ async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseExc async def close(self) -> None: """Clean up any assistants we created.""" if self._should_delete_assistant and self.assistant_id is not None: - client = await self.ensure_client() + client = await self._ensure_client() await client.beta.assistants.delete(self.assistant_id) object.__setattr__(self, "assistant_id", None) object.__setattr__(self, "_should_delete_assistant", False) @@ -188,7 +188,7 @@ async def _inner_get_streaming_response( chat_options: ChatOptions, **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: - # Extract necessary state from messages and options + # prepare run_options, tool_results = self._prepare_options(messages, chat_options, **kwargs) # Get the thread ID @@ -204,10 +204,10 @@ async def _inner_get_streaming_response( # Determine which assistant to use and create if needed assistant_id = await self._get_assistant_id_or_create() - # Create the streaming response + # execute stream, thread_id = await self._create_assistant_stream(thread_id, assistant_id, run_options, tool_results) - # Process and yield each update from the stream + # process async for update in self._process_stream_events(stream, thread_id): yield update @@ -222,7 +222,7 @@ async def _get_assistant_id_or_create(self) -> str: if not self.model_id: raise ServiceInitializationError("Parameter 'model_id' is required for assistant creation.") - client = await self.ensure_client() + client = await self._ensure_client() created_assistant = await client.beta.assistants.create( model=self.model_id, description=self.assistant_description, @@ -245,11 +245,11 @@ async def _create_assistant_stream( Returns: tuple: (stream, final_thread_id) """ - client = await self.ensure_client() + client = await self._ensure_client() # Get any active run for this thread thread_run = await self._get_active_thread_run(thread_id) - tool_run_id, tool_outputs = self._convert_function_results_to_tool_output(tool_results) + tool_run_id, tool_outputs = self._prepare_tool_outputs_for_assistants(tool_results) if thread_run is not None and tool_run_id is not None and tool_run_id == thread_run.id and tool_outputs: # There's an active run and we have tool results to submit, so submit the results. @@ -270,7 +270,7 @@ async def _create_assistant_stream( async def _get_active_thread_run(self, thread_id: str | None) -> Run | None: """Get any active run for the given thread.""" - client = await self.ensure_client() + client = await self._ensure_client() if thread_id is None: return None @@ -281,7 +281,7 @@ async def _get_active_thread_run(self, thread_id: str | None) -> Run | None: async def _prepare_thread(self, thread_id: str | None, thread_run: Run | None, run_options: dict[str, Any]) -> str: """Prepare the thread for a new run, creating or cleaning up as needed.""" - client = await self.ensure_client() + client = await self._ensure_client() if thread_id is None: # No thread ID was provided, so create a new thread. thread = await client.beta.threads.create( # type: ignore[reportDeprecated] @@ -330,7 +330,7 @@ async def _process_stream_events(self, stream: Any, thread_id: str) -> AsyncIter response_id=response_id, ) elif response.event == "thread.run.requires_action" and isinstance(response.data, Run): - contents = self._create_function_call_contents(response.data, response_id) + contents = self._parse_function_calls_from_assistants(response.data, response_id) if contents: yield ChatResponseUpdate( role=Role.ASSISTANT, @@ -371,8 +371,8 @@ async def _process_stream_events(self, stream: Any, thread_id: str) -> AsyncIter role=Role.ASSISTANT, ) - def _create_function_call_contents(self, event_data: Run, response_id: str | None) -> list[Contents]: - """Create function call contents from a tool action event.""" + def _parse_function_calls_from_assistants(self, event_data: Run, response_id: str | None) -> list[Contents]: + """Parse function call contents from an assistants tool action event.""" contents: list[Contents] = [] if event_data.required_action is not None: @@ -437,7 +437,10 @@ def _prepare_options( if chat_options.response_format is not None: run_options["response_format"] = { "type": "json_schema", - "json_schema": chat_options.response_format.model_json_schema(), + "json_schema": { + "name": chat_options.response_format.__name__, + "schema": chat_options.response_format.model_json_schema(), + }, } instructions: list[str] = [] @@ -487,10 +490,11 @@ def _prepare_options( return run_options, tool_results - def _convert_function_results_to_tool_output( + def _prepare_tool_outputs_for_assistants( self, tool_results: list[FunctionResultContent] | None, ) -> tuple[str | None, list[ToolOutput] | None]: + """Prepare function results for submission to the assistants API.""" run_id: str | None = None tool_outputs: list[ToolOutput] | None = None diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index 73605fadef..940858b670 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -14,7 +14,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice from openai.types.chat.chat_completion_message_custom_tool_call import ChatCompletionMessageCustomToolCall -from pydantic import BaseModel, ValidationError +from pydantic import ValidationError from .._clients import BaseChatClient from .._logging import get_logger @@ -44,7 +44,7 @@ ServiceInvalidRequestError, ServiceResponseException, ) -from ..observability import use_observability +from ..observability import use_instrumentation from ._exceptions import OpenAIContentFilterException from ._shared import OpenAIBase, OpenAIConfigMixin, OpenAISettings @@ -69,10 +69,12 @@ async def _inner_get_response( chat_options: ChatOptions, **kwargs: Any, ) -> ChatResponse: - client = await self.ensure_client() + client = await self._ensure_client() + # prepare options_dict = self._prepare_options(messages, chat_options) try: - return self._create_chat_response( + # execute and process + return self._parse_response_from_openai( await client.chat.completions.create(stream=False, **options_dict), chat_options ) except BadRequestError as ex: @@ -98,14 +100,16 @@ async def _inner_get_streaming_response( chat_options: ChatOptions, **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: - client = await self.ensure_client() + client = await self._ensure_client() + # prepare options_dict = self._prepare_options(messages, chat_options) options_dict["stream_options"] = {"include_usage": True} try: + # execute and process async for chunk in await client.chat.completions.create(stream=True, **options_dict): if len(chunk.choices) == 0 and chunk.usage is None: continue - yield self._create_chat_response_update(chunk) + yield self._parse_response_update_from_openai(chunk) except BadRequestError as ex: if ex.code == "content_filter": raise OpenAIContentFilterException( @@ -124,7 +128,9 @@ async def _inner_get_streaming_response( # region content creation - def _chat_to_tool_spec(self, tools: Sequence[ToolProtocol | MutableMapping[str, Any]]) -> list[dict[str, Any]]: + def _prepare_tools_for_openai( + self, tools: Sequence[ToolProtocol | MutableMapping[str, Any]] + ) -> list[dict[str, Any]]: chat_tools: list[dict[str, Any]] = [] for tool in tools: if isinstance(tool, ToolProtocol): @@ -157,48 +163,65 @@ def _process_web_search_tool( return None def _prepare_options(self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions) -> dict[str, Any]: - # Preprocess web search tool if it exists - options_dict = chat_options.to_dict( + run_options = chat_options.to_dict( exclude={ "type", "instructions", # included as system message + "response_format", # handled separately + "additional_properties", # handled separately } ) - if messages and "messages" not in options_dict: - options_dict["messages"] = self._prepare_chat_history_for_request(messages) - if "messages" not in options_dict: + # messages + if messages and "messages" not in run_options: + run_options["messages"] = self._prepare_messages_for_openai(messages) + if "messages" not in run_options: raise ServiceInvalidRequestError("Messages are required for chat completions") + + # Translation between ChatOptions and Chat Completion API + translations = { + "model_id": "model", + "allow_multiple_tool_calls": "parallel_tool_calls", + "max_tokens": "max_output_tokens", + } + for old_key, new_key in translations.items(): + if old_key in run_options and old_key != new_key: + run_options[new_key] = run_options.pop(old_key) + + # model id + if not run_options.get("model"): + if not self.model_id: + raise ValueError("model_id must be a non-empty string") + run_options["model"] = self.model_id + + # tools if chat_options.tools is not None: - web_search_options = self._process_web_search_tool(chat_options.tools) - if web_search_options: - options_dict["web_search_options"] = web_search_options - options_dict["tools"] = self._chat_to_tool_spec(chat_options.tools) - if not options_dict.get("tools", None): - options_dict.pop("tools", None) - options_dict.pop("parallel_tool_calls", None) - options_dict.pop("tool_choice", None) - - if "model_id" not in options_dict: - options_dict["model"] = self.model_id - else: - options_dict["model"] = options_dict.pop("model_id") - if ( - chat_options.response_format - and isinstance(chat_options.response_format, type) - and issubclass(chat_options.response_format, BaseModel) - ): - options_dict["response_format"] = type_to_response_format_param(chat_options.response_format) - if additional_properties := options_dict.pop("additional_properties", None): - for key, value in additional_properties.items(): - if value is not None: - options_dict[key] = value - if (tool_choice := options_dict.get("tool_choice")) and len(tool_choice.keys()) == 1: - options_dict["tool_choice"] = tool_choice["mode"] - return options_dict - - def _create_chat_response(self, response: ChatCompletion, chat_options: ChatOptions) -> "ChatResponse": - """Create a chat message content object from a choice.""" + # Preprocess web search tool if it exists + if web_search_options := self._process_web_search_tool(chat_options.tools): + run_options["web_search_options"] = web_search_options + run_options["tools"] = self._prepare_tools_for_openai(chat_options.tools) + if not run_options.get("tools", None): + run_options.pop("tools", None) + run_options.pop("parallel_tool_calls", None) + run_options.pop("tool_choice", None) + # tool choice when `tool_choice` is a dict with single key `mode`, extract the mode value + if (tool_choice := run_options.get("tool_choice")) and len(tool_choice.keys()) == 1: + run_options["tool_choice"] = tool_choice["mode"] + + # response format + if chat_options.response_format: + run_options["response_format"] = type_to_response_format_param(chat_options.response_format) + + # additional properties + additional_options = { + key: value for key, value in chat_options.additional_properties.items() if value is not None + } + if additional_options: + run_options.update(additional_options) + return run_options + + def _parse_response_from_openai(self, response: ChatCompletion, chat_options: ChatOptions) -> "ChatResponse": + """Parse a response from OpenAI into a ChatResponse.""" response_metadata = self._get_metadata_from_chat_response(response) messages: list[ChatMessage] = [] finish_reason: FinishReason | None = None @@ -207,15 +230,15 @@ def _create_chat_response(self, response: ChatCompletion, chat_options: ChatOpti if choice.finish_reason: finish_reason = FinishReason(value=choice.finish_reason) contents: list[Contents] = [] - if text_content := self._parse_text_from_choice(choice): + if text_content := self._parse_text_from_openai(choice): contents.append(text_content) - if parsed_tool_calls := [tool for tool in self._get_tool_calls_from_chat_choice(choice)]: + if parsed_tool_calls := [tool for tool in self._parse_tool_calls_from_openai(choice)]: contents.extend(parsed_tool_calls) messages.append(ChatMessage(role="assistant", contents=contents)) return ChatResponse( response_id=response.id, created_at=datetime.fromtimestamp(response.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"), - usage_details=self._usage_details_from_openai(response.usage) if response.usage else None, + usage_details=self._parse_usage_from_openai(response.usage) if response.usage else None, messages=messages, model_id=response.model, additional_properties=response_metadata, @@ -223,16 +246,16 @@ def _create_chat_response(self, response: ChatCompletion, chat_options: ChatOpti response_format=chat_options.response_format, ) - def _create_chat_response_update( + def _parse_response_update_from_openai( self, chunk: ChatCompletionChunk, ) -> ChatResponseUpdate: - """Create a streaming chat message content object from a choice.""" + """Parse a streaming response update from OpenAI.""" chunk_metadata = self._get_metadata_from_streaming_chat_response(chunk) if chunk.usage: return ChatResponseUpdate( role=Role.ASSISTANT, - contents=[UsageContent(details=self._usage_details_from_openai(chunk.usage), raw_representation=chunk)], + contents=[UsageContent(details=self._parse_usage_from_openai(chunk.usage), raw_representation=chunk)], model_id=chunk.model, additional_properties=chunk_metadata, response_id=chunk.id, @@ -242,11 +265,11 @@ def _create_chat_response_update( finish_reason: FinishReason | None = None for choice in chunk.choices: chunk_metadata.update(self._get_metadata_from_chat_choice(choice)) - contents.extend(self._get_tool_calls_from_chat_choice(choice)) + contents.extend(self._parse_tool_calls_from_openai(choice)) if choice.finish_reason: finish_reason = FinishReason(value=choice.finish_reason) - if text_content := self._parse_text_from_choice(choice): + if text_content := self._parse_text_from_openai(choice): contents.append(text_content) return ChatResponseUpdate( created_at=datetime.fromtimestamp(chunk.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"), @@ -260,7 +283,7 @@ def _create_chat_response_update( message_id=chunk.id, ) - def _usage_details_from_openai(self, usage: CompletionUsage) -> UsageDetails: + def _parse_usage_from_openai(self, usage: CompletionUsage) -> UsageDetails: details = UsageDetails( input_token_count=usage.prompt_tokens, output_token_count=usage.completion_tokens, @@ -282,7 +305,7 @@ def _usage_details_from_openai(self, usage: CompletionUsage) -> UsageDetails: details["prompt/cached_tokens"] = tokens return details - def _parse_text_from_choice(self, choice: Choice | ChunkChoice) -> TextContent | None: + def _parse_text_from_openai(self, choice: Choice | ChunkChoice) -> TextContent | None: """Parse the choice into a TextContent object.""" message = choice.message if isinstance(choice, Choice) else choice.delta if message.content: @@ -309,8 +332,8 @@ def _get_metadata_from_chat_choice(self, choice: Choice | ChunkChoice) -> dict[s "logprobs": getattr(choice, "logprobs", None), } - def _get_tool_calls_from_chat_choice(self, choice: Choice | ChunkChoice) -> list[Contents]: - """Get tool calls from a chat choice.""" + def _parse_tool_calls_from_openai(self, choice: Choice | ChunkChoice) -> list[Contents]: + """Parse tool calls from an OpenAI response choice.""" resp: list[Contents] = [] content = choice.message if isinstance(choice, Choice) else choice.delta if content and content.tool_calls: @@ -328,13 +351,13 @@ def _get_tool_calls_from_chat_choice(self, choice: Choice | ChunkChoice) -> list # When you enable asynchronous content filtering in Azure OpenAI, you may receive empty deltas return resp - def _prepare_chat_history_for_request( + def _prepare_messages_for_openai( self, chat_messages: Sequence[ChatMessage], role_key: str = "role", content_key: str = "content", ) -> list[dict[str, Any]]: - """Prepare the chat history for a request. + """Prepare the chat history for an OpenAI request. Allowing customization of the key names for role/author, and optionally overriding the role. @@ -352,14 +375,14 @@ def _prepare_chat_history_for_request( Returns: prepared_chat_history (Any): The prepared chat history for a request. """ - list_of_list = [self._openai_chat_message_parser(message) for message in chat_messages] + list_of_list = [self._prepare_message_for_openai(message) for message in chat_messages] # Flatten the list of lists into a single list return list(chain.from_iterable(list_of_list)) # region Parsers - def _openai_chat_message_parser(self, message: ChatMessage) -> list[dict[str, Any]]: - """Parse a chat message into the openai format.""" + def _prepare_message_for_openai(self, message: ChatMessage) -> list[dict[str, Any]]: + """Prepare a chat message for OpenAI.""" all_messages: list[dict[str, Any]] = [] for content in message.contents: # Skip approval content - it's internal framework state, not for the LLM @@ -369,13 +392,15 @@ def _openai_chat_message_parser(self, message: ChatMessage) -> list[dict[str, An args: dict[str, Any] = { "role": message.role.value if isinstance(message.role, Role) else message.role, } + if message.author_name and message.role != Role.TOOL: + args["name"] = message.author_name match content: case FunctionCallContent(): if all_messages and "tool_calls" in all_messages[-1]: # If the last message already has tool calls, append to it - all_messages[-1]["tool_calls"].append(self._openai_content_parser(content)) + all_messages[-1]["tool_calls"].append(self._prepare_content_for_openai(content)) else: - args["tool_calls"] = [self._openai_content_parser(content)] # type: ignore + args["tool_calls"] = [self._prepare_content_for_openai(content)] # type: ignore case FunctionResultContent(): args["tool_call_id"] = content.call_id if content.result is not None: @@ -384,13 +409,13 @@ def _openai_chat_message_parser(self, message: ChatMessage) -> list[dict[str, An if "content" not in args: args["content"] = [] # this is a list to allow multi-modal content - args["content"].append(self._openai_content_parser(content)) # type: ignore + args["content"].append(self._prepare_content_for_openai(content)) # type: ignore if "content" in args or "tool_calls" in args: all_messages.append(args) return all_messages - def _openai_content_parser(self, content: Contents) -> dict[str, Any]: - """Parse contents into the openai format.""" + def _prepare_content_for_openai(self, content: Contents) -> dict[str, Any]: + """Prepare content for OpenAI.""" match content: case FunctionCallContent(): args = json.dumps(content.arguments) if isinstance(content.arguments, Mapping) else content.arguments @@ -467,7 +492,7 @@ def service_url(self) -> str: @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient): """OpenAI Chat completion class.""" diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index a537884ba4..746e50150e 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -64,7 +64,7 @@ ServiceInvalidRequestError, ServiceResponseException, ) -from ..observability import use_observability +from ..observability import use_instrumentation from ._exceptions import OpenAIContentFilterException from ._shared import OpenAIBase, OpenAIConfigMixin, OpenAISettings @@ -89,28 +89,16 @@ async def _inner_get_response( chat_options: ChatOptions, **kwargs: Any, ) -> ChatResponse: - client = await self.ensure_client() - run_options = await self.prepare_options(messages, chat_options, **kwargs) - response_format = run_options.pop("response_format", None) - text_config = run_options.pop("text", None) - text_format, text_config = self._prepare_text_config(response_format=response_format, text_config=text_config) - if text_config: - run_options["text"] = text_config + client = await self._ensure_client() + # prepare + run_options = await self._prepare_options(messages, chat_options, **kwargs) try: - if not text_format: - response = await client.responses.create( - stream=False, - **run_options, - ) - chat_options.conversation_id = self.get_conversation_id(response, chat_options.store) - return self._create_response_content(response, chat_options=chat_options) - parsed_response: ParsedResponse[BaseModel] = await client.responses.parse( - text_format=text_format, - stream=False, - **run_options, - ) - chat_options.conversation_id = self.get_conversation_id(parsed_response, chat_options.store) - return self._create_response_content(parsed_response, chat_options=chat_options) + # execute and process + if "text_format" in run_options: + response = await client.responses.parse(stream=False, **run_options) + else: + response = await client.responses.create(stream=False, **run_options) + return self._parse_response_from_openai(response, chat_options=chat_options) except BadRequestError as ex: if ex.code == "content_filter": raise OpenAIContentFilterException( @@ -134,35 +122,23 @@ async def _inner_get_streaming_response( chat_options: ChatOptions, **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: - client = await self.ensure_client() - run_options = await self.prepare_options(messages, chat_options, **kwargs) + client = await self._ensure_client() + # prepare + run_options = await self._prepare_options(messages, chat_options, **kwargs) function_call_ids: dict[int, tuple[str, str]] = {} # output_index: (call_id, name) - response_format = run_options.pop("response_format", None) - text_config = run_options.pop("text", None) - text_format, text_config = self._prepare_text_config(response_format=response_format, text_config=text_config) - if text_config: - run_options["text"] = text_config try: - if not text_format: - response = await client.responses.create( - stream=True, - **run_options, - ) - async for chunk in response: - update = self._create_streaming_response_content( + # execute and process + if "text_format" not in run_options: + async for chunk in await client.responses.create(stream=True, **run_options): + yield self._parse_chunk_from_openai( chunk, chat_options=chat_options, function_call_ids=function_call_ids ) - yield update return - async with client.responses.stream( - text_format=text_format, - **run_options, - ) as response: + async with client.responses.stream(**run_options) as response: async for chunk in response: - update = self._create_streaming_response_content( + yield self._parse_chunk_from_openai( chunk, chat_options=chat_options, function_call_ids=function_call_ids ) - yield update except BadRequestError as ex: if ex.code == "content_filter": raise OpenAIContentFilterException( @@ -179,33 +155,33 @@ async def _inner_get_streaming_response( inner_exception=ex, ) from ex - def _prepare_text_config( + def _prepare_response_and_text_format( self, *, response_format: Any, text_config: MutableMapping[str, Any] | None, ) -> tuple[type[BaseModel] | None, dict[str, Any] | None]: """Normalize response_format into Responses text configuration and parse target.""" - prepared_text = dict(text_config) if isinstance(text_config, MutableMapping) else None if text_config is not None and not isinstance(text_config, MutableMapping): raise ServiceInvalidRequestError("text must be a mapping when provided.") + text_config = cast(dict[str, Any], text_config) if isinstance(text_config, MutableMapping) else None if response_format is None: - return None, prepared_text + return None, text_config if isinstance(response_format, type) and issubclass(response_format, BaseModel): - if prepared_text and "format" in prepared_text: + if text_config and "format" in text_config: raise ServiceInvalidRequestError("response_format cannot be combined with explicit text.format.") - return response_format, prepared_text + return response_format, text_config if isinstance(response_format, Mapping): format_config = self._convert_response_format(cast("Mapping[str, Any]", response_format)) - if prepared_text is None: - prepared_text = {} - elif "format" in prepared_text and prepared_text["format"] != format_config: + if text_config is None: + text_config = {} + elif "format" in text_config and text_config["format"] != format_config: raise ServiceInvalidRequestError("Conflicting response_format definitions detected.") - prepared_text["format"] = format_config - return None, prepared_text + text_config["format"] = format_config + return None, text_config raise ServiceInvalidRequestError("response_format must be a Pydantic model or mapping.") @@ -245,23 +221,33 @@ def _convert_response_format(self, response_format: Mapping[str, Any]) -> dict[s raise ServiceInvalidRequestError("Unsupported response_format provided for Responses client.") - def get_conversation_id( + def _get_conversation_id( self, response: OpenAIResponse | ParsedResponse[BaseModel], store: bool | None ) -> str | None: """Get the conversation ID from the response if store is True.""" - return None if store is False else response.id + if store is False: + return None + # If conversation ID exists, it means that we operate with conversation + # so we use conversation ID as input and output. + if response.conversation and response.conversation.id: + return response.conversation.id + # If conversation ID doesn't exist, we operate with responses + # so we use response ID as input and output. + return response.id # region Prep methods - def _tools_to_response_tools( - self, tools: Sequence[ToolProtocol | MutableMapping[str, Any]] + def _prepare_tools_for_openai( + self, tools: Sequence[ToolProtocol | MutableMapping[str, Any]] | None ) -> list[ToolParam | dict[str, Any]]: response_tools: list[ToolParam | dict[str, Any]] = [] + if not tools: + return response_tools for tool in tools: if isinstance(tool, ToolProtocol): match tool: case HostedMCPTool(): - response_tools.append(self.get_mcp_tool(tool)) + response_tools.append(self._prepare_mcp_tool(tool)) case HostedCodeInterpreterTool(): tool_args: CodeInterpreterContainerCodeInterpreterToolAuto = {"type": "auto"} if tool.inputs: @@ -363,7 +349,8 @@ def _tools_to_response_tools( response_tools.append(tool_dict) return response_tools - def get_mcp_tool(self, tool: HostedMCPTool) -> Any: + @staticmethod + def _prepare_mcp_tool(tool: HostedMCPTool) -> Mcp: """Get MCP tool from HostedMCPTool.""" mcp: Mcp = { "type": "mcp", @@ -386,18 +373,13 @@ def get_mcp_tool(self, tool: HostedMCPTool) -> Any: return mcp - async def prepare_options( + async def _prepare_options( self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any, ) -> dict[str, Any]: """Take ChatOptions and create the specific options for Responses API.""" - conversation_id = kwargs.pop("conversation_id", None) - - if conversation_id: - chat_options.conversation_id = conversation_id - run_options: dict[str, Any] = chat_options.to_dict( exclude={ "type", @@ -407,12 +389,24 @@ async def prepare_options( "seed", # not supported "stop", # not supported "instructions", # already added as system message + "response_format", # handled separately + "conversation_id", # handled separately + "additional_properties", # handled separately } ) + # messages + request_input = self._prepare_messages_for_openai(messages) + if not request_input: + raise ServiceInvalidRequestError("Messages are required for chat completions") + run_options["input"] = request_input - if chat_options.response_format: - run_options["response_format"] = chat_options.response_format + # model id + if not run_options.get("model"): + if not self.model_id: + raise ValueError("model_id must be a non-empty string") + run_options["model"] = self.model_id + # translations between ChatOptions and Responses API translations = { "model_id": "model", "allow_multiple_tool_calls": "parallel_tool_calls", @@ -423,34 +417,53 @@ async def prepare_options( if old_key in run_options and old_key != new_key: run_options[new_key] = run_options.pop(old_key) + # Handle different conversation ID formats + if conversation_id := self._get_current_conversation_id(chat_options, **kwargs): + if conversation_id.startswith("resp_"): + # For response IDs, set previous_response_id and remove conversation property + run_options["previous_response_id"] = conversation_id + elif conversation_id.startswith("conv_"): + # For conversation IDs, set conversation and remove previous_response_id property + run_options["conversation"] = conversation_id + else: + # If the format is unrecognized, default to previous_response_id + run_options["previous_response_id"] = conversation_id + # tools - if chat_options.tools is None: - run_options.pop("parallel_tool_calls", None) + if tools := self._prepare_tools_for_openai(chat_options.tools): + run_options["tools"] = tools else: - run_options["tools"] = self._tools_to_response_tools(chat_options.tools) + run_options.pop("parallel_tool_calls", None) + run_options.pop("tool_choice", None) + # tool choice when `tool_choice` is a dict with single key `mode`, extract the mode value + if (tool_choice := run_options.get("tool_choice")) and len(tool_choice.keys()) == 1: + run_options["tool_choice"] = tool_choice["mode"] - # model id - if not run_options.get("model"): - if not self.model_id: - raise ValueError("model_id must be a non-empty string") - run_options["model"] = self.model_id + # additional properties + additional_options = { + key: value for key, value in chat_options.additional_properties.items() if value is not None + } + if additional_options: + run_options.update(additional_options) - # messages - request_input = self._prepare_chat_messages_for_request(messages) - if not request_input: - raise ServiceInvalidRequestError("Messages are required for chat completions") - run_options["input"] = request_input + # response format and text config (after additional_properties so user can pass text via additional_properties) + response_format = chat_options.response_format + text_config = run_options.pop("text", None) + response_format, text_config = self._prepare_response_and_text_format( + response_format=response_format, text_config=text_config + ) + if text_config: + run_options["text"] = text_config + if response_format: + run_options["text_format"] = response_format - # additional provider specific settings - if additional_properties := run_options.pop("additional_properties", None): - for key, value in additional_properties.items(): - if value is not None: - run_options[key] = value - if (tool_choice := run_options.get("tool_choice")) and len(tool_choice.keys()) == 1: - run_options["tool_choice"] = tool_choice["mode"] return run_options - def _prepare_chat_messages_for_request(self, chat_messages: Sequence[ChatMessage]) -> list[dict[str, Any]]: + def _get_current_conversation_id(self, chat_options: ChatOptions, **kwargs: Any) -> str | None: + """Get the current conversation ID from chat options or kwargs.""" + return chat_options.conversation_id or kwargs.get("conversation_id") + + def _prepare_messages_for_openai(self, chat_messages: Sequence[ChatMessage]) -> list[dict[str, Any]]: """Prepare the chat messages for a request. Allowing customization of the key names for role/author, and optionally overriding the role. @@ -476,16 +489,16 @@ def _prepare_chat_messages_for_request(self, chat_messages: Sequence[ChatMessage and "fc_id" in content.additional_properties ): call_id_to_id[content.call_id] = content.additional_properties["fc_id"] - list_of_list = [self._openai_chat_message_parser(message, call_id_to_id) for message in chat_messages] + list_of_list = [self._prepare_message_for_openai(message, call_id_to_id) for message in chat_messages] # Flatten the list of lists into a single list return list(chain.from_iterable(list_of_list)) - def _openai_chat_message_parser( + def _prepare_message_for_openai( self, message: ChatMessage, call_id_to_id: dict[str, str], ) -> list[dict[str, Any]]: - """Parse a chat message into the openai format.""" + """Prepare a chat message for the OpenAI Responses API format.""" all_messages: list[dict[str, Any]] = [] args: dict[str, Any] = { "role": message.role.value if isinstance(message.role, Role) else message.role, @@ -497,28 +510,28 @@ def _openai_chat_message_parser( continue case FunctionResultContent(): new_args: dict[str, Any] = {} - new_args.update(self._openai_content_parser(message.role, content, call_id_to_id)) + new_args.update(self._prepare_content_for_openai(message.role, content, call_id_to_id)) all_messages.append(new_args) case FunctionCallContent(): - function_call = self._openai_content_parser(message.role, content, call_id_to_id) + function_call = self._prepare_content_for_openai(message.role, content, call_id_to_id) all_messages.append(function_call) # type: ignore case FunctionApprovalResponseContent() | FunctionApprovalRequestContent(): - all_messages.append(self._openai_content_parser(message.role, content, call_id_to_id)) # type: ignore + all_messages.append(self._prepare_content_for_openai(message.role, content, call_id_to_id)) # type: ignore case _: if "content" not in args: args["content"] = [] - args["content"].append(self._openai_content_parser(message.role, content, call_id_to_id)) # type: ignore + args["content"].append(self._prepare_content_for_openai(message.role, content, call_id_to_id)) # type: ignore if "content" in args or "tool_calls" in args: all_messages.append(args) return all_messages - def _openai_content_parser( + def _prepare_content_for_openai( self, role: Role, content: Contents, call_id_to_id: dict[str, str], ) -> dict[str, Any]: - """Parse contents into the openai format.""" + """Prepare content for the OpenAI Responses API format.""" match content: case TextContent(): return { @@ -625,14 +638,13 @@ def _openai_content_parser( logger.debug("Unsupported content type passed (type: %s)", type(content)) return {} - # region Response creation methods - - def _create_response_content( + # region Parse methods + def _parse_response_from_openai( self, response: OpenAIResponse | ParsedResponse[BaseModel], chat_options: ChatOptions, ) -> "ChatResponse": - """Create a chat message content object from a choice.""" + """Parse an OpenAI Responses API response into a ChatResponse.""" structured_response: BaseModel | None = response.output_parsed if isinstance(response, ParsedResponse) else None # type: ignore[reportUnknownMemberType] metadata: dict[str, Any] = response.metadata or {} @@ -826,11 +838,9 @@ def _create_response_content( "raw_representation": response, } - conversation_id = self.get_conversation_id(response, chat_options.store) # type: ignore[reportArgumentType] - - if conversation_id: + if conversation_id := self._get_conversation_id(response, chat_options.store): args["conversation_id"] = conversation_id - if response.usage and (usage_details := self._usage_details_from_openai(response.usage)): + if response.usage and (usage_details := self._parse_usage_from_openai(response.usage)): args["usage_details"] = usage_details if structured_response: args["value"] = structured_response @@ -838,13 +848,13 @@ def _create_response_content( args["response_format"] = chat_options.response_format return ChatResponse(**args) - def _create_streaming_response_content( + def _parse_chunk_from_openai( self, event: OpenAIResponseStreamEvent, chat_options: ChatOptions, function_call_ids: dict[int, tuple[str, str]], ) -> ChatResponseUpdate: - """Create a streaming chat message content object from a choice.""" + """Parse an OpenAI Responses API streaming event into a ChatResponseUpdate.""" metadata: dict[str, Any] = {} contents: list[Contents] = [] conversation_id: str | None = None @@ -931,10 +941,10 @@ def _create_streaming_response_content( contents.append(TextReasoningContent(text=event.text, raw_representation=event)) metadata.update(self._get_metadata_from_response(event)) case "response.completed": - conversation_id = self.get_conversation_id(event.response, chat_options.store) + conversation_id = self._get_conversation_id(event.response, chat_options.store) model = event.response.model if event.response.usage: - usage = self._usage_details_from_openai(event.response.usage) + usage = self._parse_usage_from_openai(event.response.usage) if usage: contents.append(UsageContent(details=usage, raw_representation=event)) case "response.output_item.added": @@ -1102,7 +1112,7 @@ def _get_ann_value(key: str) -> Any: raw_representation=event, ) - def _usage_details_from_openai(self, usage: ResponseUsage) -> UsageDetails | None: + def _parse_usage_from_openai(self, usage: ResponseUsage) -> UsageDetails | None: details = UsageDetails( input_token_count=usage.input_tokens, output_token_count=usage.output_tokens, @@ -1127,7 +1137,7 @@ def _get_metadata_from_response(self, output: Any) -> dict[str, Any]: @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient): """OpenAI Responses client class.""" diff --git a/python/packages/core/agent_framework/openai/_shared.py b/python/packages/core/agent_framework/openai/_shared.py index 511c1f3379..77189168f1 100644 --- a/python/packages/core/agent_framework/openai/_shared.py +++ b/python/packages/core/agent_framework/openai/_shared.py @@ -63,7 +63,7 @@ def _check_openai_version_for_callable_api_key() -> None: raise ServiceInitializationError( f"Callable API keys require OpenAI SDK >= 1.106.0, but you have {openai.__version__}. " f"Please upgrade with 'pip install openai>=1.106.0' or provide a string API key instead. " - f"Note: If you're using mem0ai, you may need to upgrade to mem0ai>=0.1.118 " + f"Note: If you're using mem0ai, you may need to upgrade to mem0ai>=1.0.0 " f"to allow newer OpenAI versions." ) except ServiceInitializationError: @@ -160,16 +160,16 @@ def __init__(self, *, model_id: str | None = None, client: AsyncOpenAI | None = for key, value in kwargs.items(): setattr(self, key, value) - async def initialize_client(self) -> None: + async def _initialize_client(self) -> None: """Initialize OpenAI client asynchronously. Override in subclasses to initialize the OpenAI client asynchronously. """ pass - async def ensure_client(self) -> AsyncOpenAI: + async def _ensure_client(self) -> AsyncOpenAI: """Ensure OpenAI client is initialized.""" - await self.initialize_client() + await self._initialize_client() if self.client is None: raise ServiceInitializationError("OpenAI client is not initialized") diff --git a/python/packages/core/pyproject.toml b/python/packages/core/pyproject.toml index ba55b9498d..17430d4e65 100644 --- a/python/packages/core/pyproject.toml +++ b/python/packages/core/pyproject.toml @@ -4,7 +4,7 @@ description = "Microsoft Agent Framework for building AI Agents with Python. Thi authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251211" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" @@ -30,7 +30,6 @@ dependencies = [ # telemetry "opentelemetry-api>=1.39.0", "opentelemetry-sdk>=1.39.0", - "opentelemetry-exporter-otlp-proto-grpc>=1.39.0", "opentelemetry-semantic-conventions-ai>=0.4.13", # connectors and functions "openai>=1.99.0", @@ -53,6 +52,7 @@ all = [ "agent-framework-devui", "agent-framework-lab", "agent-framework-mem0", + "agent-framework-ollama", "agent-framework-purview", "agent-framework-redis", ] diff --git a/python/packages/core/tests/azure/test_azure_chat_client.py b/python/packages/core/tests/azure/test_azure_chat_client.py index 1b7dbb904b..7da838529f 100644 --- a/python/packages/core/tests/azure/test_azure_chat_client.py +++ b/python/packages/core/tests/azure/test_azure_chat_client.py @@ -193,7 +193,7 @@ async def test_cmc( mock_create.assert_awaited_once_with( model=azure_openai_unit_test_env["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], stream=False, - messages=azure_chat_client._prepare_chat_history_for_request(chat_history), # type: ignore + messages=azure_chat_client._prepare_messages_for_openai(chat_history), # type: ignore ) @@ -216,7 +216,7 @@ async def test_cmc_with_logit_bias( mock_create.assert_awaited_once_with( model=azure_openai_unit_test_env["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], - messages=azure_chat_client._prepare_chat_history_for_request(chat_history), # type: ignore + messages=azure_chat_client._prepare_messages_for_openai(chat_history), # type: ignore stream=False, logit_bias=token_bias, ) @@ -241,7 +241,7 @@ async def test_cmc_with_stop( mock_create.assert_awaited_once_with( model=azure_openai_unit_test_env["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], - messages=azure_chat_client._prepare_chat_history_for_request(chat_history), # type: ignore + messages=azure_chat_client._prepare_messages_for_openai(chat_history), # type: ignore stream=False, stop=stop, ) @@ -311,7 +311,7 @@ async def test_azure_on_your_data( mock_create.assert_awaited_once_with( model=azure_openai_unit_test_env["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], - messages=azure_chat_client._prepare_chat_history_for_request(messages_out), # type: ignore + messages=azure_chat_client._prepare_messages_for_openai(messages_out), # type: ignore stream=False, extra_body=expected_data_settings, ) @@ -381,7 +381,7 @@ async def test_azure_on_your_data_string( mock_create.assert_awaited_once_with( model=azure_openai_unit_test_env["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], - messages=azure_chat_client._prepare_chat_history_for_request(messages_out), # type: ignore + messages=azure_chat_client._prepare_messages_for_openai(messages_out), # type: ignore stream=False, extra_body=expected_data_settings, ) @@ -438,7 +438,7 @@ async def test_azure_on_your_data_fail( mock_create.assert_awaited_once_with( model=azure_openai_unit_test_env["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], - messages=azure_chat_client._prepare_chat_history_for_request(messages_out), # type: ignore + messages=azure_chat_client._prepare_messages_for_openai(messages_out), # type: ignore stream=False, extra_body=expected_data_settings, ) @@ -584,7 +584,7 @@ async def test_get_streaming( mock_create.assert_awaited_once_with( model=azure_openai_unit_test_env["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], stream=True, - messages=azure_chat_client._prepare_chat_history_for_request(chat_history), # type: ignore + messages=azure_chat_client._prepare_messages_for_openai(chat_history), # type: ignore # NOTE: The `stream_options={"include_usage": True}` is explicitly enforced in # `OpenAIChatCompletionBase._inner_get_streaming_response`. # To ensure consistency, we align the arguments here accordingly. diff --git a/python/packages/core/tests/conftest.py b/python/packages/core/tests/conftest.py index d356e300bb..fd8b93ebc2 100644 --- a/python/packages/core/tests/conftest.py +++ b/python/packages/core/tests/conftest.py @@ -10,7 +10,7 @@ @fixture -def enable_otel(request: Any) -> bool: +def enable_instrumentation(request: Any) -> bool: """Fixture that returns a boolean indicating if Otel is enabled.""" return request.param if hasattr(request, "param") else True @@ -22,20 +22,31 @@ def enable_sensitive_data(request: Any) -> bool: @fixture -def span_exporter(monkeypatch, enable_otel: bool, enable_sensitive_data: bool) -> Generator[SpanExporter]: +def span_exporter(monkeypatch, enable_instrumentation: bool, enable_sensitive_data: bool) -> Generator[SpanExporter]: """Fixture to remove environment variables for ObservabilitySettings.""" env_vars = [ - "ENABLE_OTEL", + "ENABLE_INSTRUMENTATION", "ENABLE_SENSITIVE_DATA", - "OTLP_ENDPOINT", - "APPLICATIONINSIGHTS_CONNECTION_STRING", + "ENABLE_CONSOLE_EXPORTERS", + "OTEL_EXPORTER_OTLP_ENDPOINT", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", + "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", + "OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", + "OTEL_EXPORTER_OTLP_PROTOCOL", + "OTEL_EXPORTER_OTLP_HEADERS", + "OTEL_EXPORTER_OTLP_TRACES_HEADERS", + "OTEL_EXPORTER_OTLP_METRICS_HEADERS", + "OTEL_EXPORTER_OTLP_LOGS_HEADERS", + "OTEL_SERVICE_NAME", + "OTEL_SERVICE_VERSION", + "OTEL_RESOURCE_ATTRIBUTES", ] for key in env_vars: monkeypatch.delenv(key, raising=False) # type: ignore - monkeypatch.setenv("ENABLE_OTEL", str(enable_otel)) # type: ignore - if not enable_otel: + monkeypatch.setenv("ENABLE_INSTRUMENTATION", str(enable_instrumentation)) # type: ignore + if not enable_instrumentation: # we overwrite sensitive data for tests enable_sensitive_data = False monkeypatch.setenv("ENABLE_SENSITIVE_DATA", str(enable_sensitive_data)) # type: ignore @@ -51,15 +62,22 @@ def span_exporter(monkeypatch, enable_otel: bool, enable_sensitive_data: bool) - # recreate observability settings with values from above and no file. observability_settings = observability.ObservabilitySettings(env_file_path="test.env") - observability_settings._configure() # pyright: ignore[reportPrivateUsage] + + # Configure providers manually without calling _configure() to avoid OTLP imports + if enable_instrumentation or enable_sensitive_data: + from opentelemetry.sdk.trace import TracerProvider + + tracer_provider = TracerProvider(resource=observability_settings._resource) + trace.set_tracer_provider(tracer_provider) + monkeypatch.setattr(observability, "OBSERVABILITY_SETTINGS", observability_settings, raising=False) # type: ignore with ( patch("agent_framework.observability.OBSERVABILITY_SETTINGS", observability_settings), - patch("agent_framework.observability.setup_observability"), + patch("agent_framework.observability.configure_otel_providers"), ): exporter = InMemorySpanExporter() - if enable_otel or enable_sensitive_data: + if enable_instrumentation or enable_sensitive_data: tracer_provider = trace.get_tracer_provider() if not hasattr(tracer_provider, "add_span_processor"): raise RuntimeError("Tracer provider does not support adding span processors.") diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 5a0ec5a773..bc96ddcc35 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +from collections.abc import Awaitable, Callable + import pytest from agent_framework import ( @@ -16,6 +18,7 @@ TextContent, ai_function, ) +from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware async def test_base_client_with_function_calling(chat_client_base: ChatClientProtocol): @@ -2206,3 +2209,175 @@ def sometimes_fails(arg1: str) -> str: assert len(error_results) >= 1 assert len(success_results) >= 1 assert call_count == 2 # Both calls executed + + +class TerminateLoopMiddleware(FunctionMiddleware): + """Middleware that sets terminate=True to exit the function calling loop.""" + + async def process( + self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + # Set result to a simple value - the framework will wrap it in FunctionResultContent + context.result = "terminated by middleware" + context.terminate = True + + +async def test_terminate_loop_single_function_call(chat_client_base: ChatClientProtocol): + """Test that terminate_loop=True exits the function calling loop after single function call.""" + exec_counter = 0 + + @ai_function(name="test_function") + def ai_func(arg1: str) -> str: + nonlocal exec_counter + exec_counter += 1 + return f"Processed {arg1}" + + # Queue up two responses: function call, then final text + # If terminate_loop works, only the first response should be consumed + chat_client_base.run_responses = [ + ChatResponse( + messages=ChatMessage( + role="assistant", + contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + ) + ), + ChatResponse(messages=ChatMessage(role="assistant", text="done")), + ] + + response = await chat_client_base.get_response( + "hello", + tool_choice="auto", + tools=[ai_func], + middleware=[TerminateLoopMiddleware()], + ) + + # Function should NOT have been executed - middleware intercepted it + assert exec_counter == 0 + + # There should be 2 messages: assistant with function call, tool result from middleware + # The loop should NOT have continued to call the LLM again + assert len(response.messages) == 2 + assert response.messages[0].role == Role.ASSISTANT + assert isinstance(response.messages[0].contents[0], FunctionCallContent) + assert response.messages[1].role == Role.TOOL + assert isinstance(response.messages[1].contents[0], FunctionResultContent) + assert response.messages[1].contents[0].result == "terminated by middleware" + + # Verify the second response is still in the queue (wasn't consumed) + assert len(chat_client_base.run_responses) == 1 + + +class SelectiveTerminateMiddleware(FunctionMiddleware): + """Only terminates for terminating_function.""" + + async def process( + self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + if context.function.name == "terminating_function": + # Set result to a simple value - the framework will wrap it in FunctionResultContent + context.result = "terminated by middleware" + context.terminate = True + else: + await next_handler(context) + + +async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client_base: ChatClientProtocol): + """Test that any(terminate_loop=True) exits loop even with multiple function calls.""" + normal_call_count = 0 + terminating_call_count = 0 + + @ai_function(name="normal_function") + def normal_func(arg1: str) -> str: + nonlocal normal_call_count + normal_call_count += 1 + return f"Normal {arg1}" + + @ai_function(name="terminating_function") + def terminating_func(arg1: str) -> str: + nonlocal terminating_call_count + terminating_call_count += 1 + return f"Terminating {arg1}" + + # Queue up two responses: parallel function calls, then final text + chat_client_base.run_responses = [ + ChatResponse( + messages=ChatMessage( + role="assistant", + contents=[ + FunctionCallContent(call_id="1", name="normal_function", arguments='{"arg1": "value1"}'), + FunctionCallContent(call_id="2", name="terminating_function", arguments='{"arg1": "value2"}'), + ], + ) + ), + ChatResponse(messages=ChatMessage(role="assistant", text="done")), + ] + + response = await chat_client_base.get_response( + "hello", + tool_choice="auto", + tools=[normal_func, terminating_func], + middleware=[SelectiveTerminateMiddleware()], + ) + + # normal_function should have executed (middleware calls next_handler) + # terminating_function should NOT have executed (middleware intercepts it) + assert normal_call_count == 1 + assert terminating_call_count == 0 + + # There should be 2 messages: assistant with function calls, tool results + # The loop should NOT have continued to call the LLM again + assert len(response.messages) == 2 + assert response.messages[0].role == Role.ASSISTANT + assert len(response.messages[0].contents) == 2 + assert response.messages[1].role == Role.TOOL + # Both function results should be present + assert len(response.messages[1].contents) == 2 + + # Verify the second response is still in the queue (wasn't consumed) + assert len(chat_client_base.run_responses) == 1 + + +async def test_terminate_loop_streaming_single_function_call(chat_client_base: ChatClientProtocol): + """Test that terminate_loop=True exits the streaming function calling loop.""" + exec_counter = 0 + + @ai_function(name="test_function") + def ai_func(arg1: str) -> str: + nonlocal exec_counter + exec_counter += 1 + return f"Processed {arg1}" + + # Queue up two streaming responses + chat_client_base.streaming_responses = [ + [ + ChatResponseUpdate( + contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + role="assistant", + ), + ], + [ + ChatResponseUpdate( + contents=[TextContent(text="done")], + role="assistant", + ) + ], + ] + + updates = [] + async for update in chat_client_base.get_streaming_response( + "hello", + tool_choice="auto", + tools=[ai_func], + middleware=[TerminateLoopMiddleware()], + ): + updates.append(update) + + # Function should NOT have been executed - middleware intercepted it + assert exec_counter == 0 + + # Should have function call update and function result update + # The loop should NOT have continued to call the LLM again + assert len(updates) == 2 + + # Verify the second streaming response is still in the queue (wasn't consumed) + assert len(chat_client_base.streaming_responses) == 1 diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 813667bb7a..18c90d64b3 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -24,14 +24,14 @@ ) from agent_framework._mcp import ( MCPTool, - _ai_content_to_mcp_types, - _chat_message_to_mcp_types, _get_input_model_from_mcp_prompt, _get_input_model_from_mcp_tool, - _mcp_call_tool_result_to_ai_contents, - _mcp_prompt_message_to_chat_message, - _mcp_type_to_ai_content, _normalize_mcp_name, + _parse_content_from_mcp, + _parse_contents_from_mcp_tool_result, + _parse_message_from_mcp, + _prepare_content_for_mcp, + _prepare_message_for_mcp, ) from agent_framework.exceptions import ToolException, ToolExecutionException @@ -60,7 +60,7 @@ def test_normalize_mcp_name(): def test_mcp_prompt_message_to_ai_content(): """Test conversion from MCP prompt message to AI content.""" mcp_message = types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hello, world!")) - ai_content = _mcp_prompt_message_to_chat_message(mcp_message) + ai_content = _parse_message_from_mcp(mcp_message) assert isinstance(ai_content, ChatMessage) assert ai_content.role.value == "user" @@ -70,22 +70,26 @@ def test_mcp_prompt_message_to_ai_content(): assert ai_content.raw_representation == mcp_message -def test_mcp_call_tool_result_to_ai_contents(): +def test_parse_contents_from_mcp_tool_result(): """Test conversion from MCP tool result to AI contents.""" mcp_result = types.CallToolResult( content=[ types.TextContent(type="text", text="Result text"), - types.ImageContent(type="image", data="", mimeType="image/png"), + types.ImageContent(type="image", data="xyz", mimeType="image/png"), + types.ImageContent(type="image", data=b"abc", mimeType="image/webp"), ] ) - ai_contents = _mcp_call_tool_result_to_ai_contents(mcp_result) + ai_contents = _parse_contents_from_mcp_tool_result(mcp_result) - assert len(ai_contents) == 2 + assert len(ai_contents) == 3 assert isinstance(ai_contents[0], TextContent) assert ai_contents[0].text == "Result text" assert isinstance(ai_contents[1], DataContent) assert ai_contents[1].uri == "" assert ai_contents[1].media_type == "image/png" + assert isinstance(ai_contents[2], DataContent) + assert ai_contents[2].uri == "" + assert ai_contents[2].media_type == "image/webp" def test_mcp_call_tool_result_with_meta_error(): @@ -96,7 +100,7 @@ def test_mcp_call_tool_result_with_meta_error(): _meta={"isError": True, "errorCode": "TOOL_ERROR", "errorMessage": "Tool execution failed"}, ) - ai_contents = _mcp_call_tool_result_to_ai_contents(mcp_result) + ai_contents = _parse_contents_from_mcp_tool_result(mcp_result) assert len(ai_contents) == 1 assert isinstance(ai_contents[0], TextContent) @@ -127,7 +131,7 @@ def test_mcp_call_tool_result_with_meta_arbitrary_data(): }, ) - ai_contents = _mcp_call_tool_result_to_ai_contents(mcp_result) + ai_contents = _parse_contents_from_mcp_tool_result(mcp_result) assert len(ai_contents) == 1 assert isinstance(ai_contents[0], TextContent) @@ -149,7 +153,7 @@ def test_mcp_call_tool_result_with_meta_merging_existing_properties(): text_content = types.TextContent(type="text", text="Test content") mcp_result = types.CallToolResult(content=[text_content], _meta={"newField": "newValue", "isError": False}) - ai_contents = _mcp_call_tool_result_to_ai_contents(mcp_result) + ai_contents = _parse_contents_from_mcp_tool_result(mcp_result) assert len(ai_contents) == 1 content = ai_contents[0] @@ -165,7 +169,7 @@ def test_mcp_call_tool_result_with_meta_none(): mcp_result = types.CallToolResult(content=[types.TextContent(type="text", text="No meta test")]) # No _meta field set - ai_contents = _mcp_call_tool_result_to_ai_contents(mcp_result) + ai_contents = _parse_contents_from_mcp_tool_result(mcp_result) assert len(ai_contents) == 1 assert isinstance(ai_contents[0], TextContent) @@ -183,11 +187,11 @@ def test_mcp_call_tool_result_regression_successful_workflow(): mcp_result = types.CallToolResult( content=[ types.TextContent(type="text", text="Success message"), - types.ImageContent(type="image", data="", mimeType="image/jpeg"), + types.ImageContent(type="image", data="abc123", mimeType="image/jpeg"), ] ) - ai_contents = _mcp_call_tool_result_to_ai_contents(mcp_result) + ai_contents = _parse_contents_from_mcp_tool_result(mcp_result) # Verify basic conversion still works correctly assert len(ai_contents) == 2 @@ -209,7 +213,7 @@ def test_mcp_call_tool_result_regression_successful_workflow(): def test_mcp_content_types_to_ai_content_text(): """Test conversion of MCP text content to AI content.""" mcp_content = types.TextContent(type="text", text="Sample text") - ai_content = _mcp_type_to_ai_content(mcp_content)[0] + ai_content = _parse_content_from_mcp(mcp_content)[0] assert isinstance(ai_content, TextContent) assert ai_content.text == "Sample text" @@ -218,8 +222,9 @@ def test_mcp_content_types_to_ai_content_text(): def test_mcp_content_types_to_ai_content_image(): """Test conversion of MCP image content to AI content.""" - mcp_content = types.ImageContent(type="image", data="", mimeType="image/jpeg") - ai_content = _mcp_type_to_ai_content(mcp_content)[0] + mcp_content = types.ImageContent(type="image", data="abc", mimeType="image/jpeg") + mcp_content = types.ImageContent(type="image", data=b"abc", mimeType="image/jpeg") + ai_content = _parse_content_from_mcp(mcp_content)[0] assert isinstance(ai_content, DataContent) assert ai_content.uri == "" @@ -229,8 +234,8 @@ def test_mcp_content_types_to_ai_content_image(): def test_mcp_content_types_to_ai_content_audio(): """Test conversion of MCP audio content to AI content.""" - mcp_content = types.AudioContent(type="audio", data="data:audio/wav;base64,def", mimeType="audio/wav") - ai_content = _mcp_type_to_ai_content(mcp_content)[0] + mcp_content = types.AudioContent(type="audio", data="def", mimeType="audio/wav") + ai_content = _parse_content_from_mcp(mcp_content)[0] assert isinstance(ai_content, DataContent) assert ai_content.uri == "data:audio/wav;base64,def" @@ -246,7 +251,7 @@ def test_mcp_content_types_to_ai_content_resource_link(): name="test_resource", mimeType="application/json", ) - ai_content = _mcp_type_to_ai_content(mcp_content)[0] + ai_content = _parse_content_from_mcp(mcp_content)[0] assert isinstance(ai_content, UriContent) assert ai_content.uri == "https://example.com/resource" @@ -262,7 +267,7 @@ def test_mcp_content_types_to_ai_content_embedded_resource_text(): text="Embedded text content", ) mcp_content = types.EmbeddedResource(type="resource", resource=text_resource) - ai_content = _mcp_type_to_ai_content(mcp_content)[0] + ai_content = _parse_content_from_mcp(mcp_content)[0] assert isinstance(ai_content, TextContent) assert ai_content.text == "Embedded text content" @@ -278,7 +283,7 @@ def test_mcp_content_types_to_ai_content_embedded_resource_blob(): blob="data:application/octet-stream;base64,dGVzdCBkYXRh", ) mcp_content = types.EmbeddedResource(type="resource", resource=blob_resource) - ai_content = _mcp_type_to_ai_content(mcp_content)[0] + ai_content = _parse_content_from_mcp(mcp_content)[0] assert isinstance(ai_content, DataContent) assert ai_content.uri == "data:application/octet-stream;base64,dGVzdCBkYXRh" @@ -289,7 +294,7 @@ def test_mcp_content_types_to_ai_content_embedded_resource_blob(): def test_ai_content_to_mcp_content_types_text(): """Test conversion of AI text content to MCP content.""" ai_content = TextContent(text="Sample text") - mcp_content = _ai_content_to_mcp_types(ai_content) + mcp_content = _prepare_content_for_mcp(ai_content) assert isinstance(mcp_content, types.TextContent) assert mcp_content.type == "text" @@ -299,7 +304,7 @@ def test_ai_content_to_mcp_content_types_text(): def test_ai_content_to_mcp_content_types_data_image(): """Test conversion of AI data content to MCP content.""" ai_content = DataContent(uri="", media_type="image/png") - mcp_content = _ai_content_to_mcp_types(ai_content) + mcp_content = _prepare_content_for_mcp(ai_content) assert isinstance(mcp_content, types.ImageContent) assert mcp_content.type == "image" @@ -310,7 +315,7 @@ def test_ai_content_to_mcp_content_types_data_image(): def test_ai_content_to_mcp_content_types_data_audio(): """Test conversion of AI data content to MCP content.""" ai_content = DataContent(uri="data:audio/mpeg;base64,xyz", media_type="audio/mpeg") - mcp_content = _ai_content_to_mcp_types(ai_content) + mcp_content = _prepare_content_for_mcp(ai_content) assert isinstance(mcp_content, types.AudioContent) assert mcp_content.type == "audio" @@ -324,7 +329,7 @@ def test_ai_content_to_mcp_content_types_data_binary(): uri="data:application/octet-stream;base64,xyz", media_type="application/octet-stream", ) - mcp_content = _ai_content_to_mcp_types(ai_content) + mcp_content = _prepare_content_for_mcp(ai_content) assert isinstance(mcp_content, types.EmbeddedResource) assert mcp_content.type == "resource" @@ -335,7 +340,7 @@ def test_ai_content_to_mcp_content_types_data_binary(): def test_ai_content_to_mcp_content_types_uri(): """Test conversion of AI URI content to MCP content.""" ai_content = UriContent(uri="https://example.com/resource", media_type="application/json") - mcp_content = _ai_content_to_mcp_types(ai_content) + mcp_content = _prepare_content_for_mcp(ai_content) assert isinstance(mcp_content, types.ResourceLink) assert mcp_content.type == "resource_link" @@ -343,7 +348,7 @@ def test_ai_content_to_mcp_content_types_uri(): assert mcp_content.mimeType == "application/json" -def test_chat_message_to_mcp_types(): +def test_prepare_message_for_mcp(): message = ChatMessage( role="user", contents=[ @@ -351,7 +356,7 @@ def test_chat_message_to_mcp_types(): DataContent(uri="", media_type="image/png"), ], ) - mcp_contents = _chat_message_to_mcp_types(message) + mcp_contents = _prepare_message_for_mcp(message) assert len(mcp_contents) == 2 assert isinstance(mcp_contents[0], types.TextContent) assert isinstance(mcp_contents[1], types.ImageContent) diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 7b280da123..6cb41f674b 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -193,7 +193,8 @@ async def process( # Create a message to start the conversation messages = [ChatMessage(role=Role.USER, text="test message")] - # Set up chat client to return a function call + # Set up chat client to return a function call, then a final response + # If terminate works correctly, only the first response should be consumed chat_client.responses = [ ChatResponse( messages=[ @@ -204,7 +205,8 @@ async def process( ], ) ] - ) + ), + ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="this should not be consumed")]), ] # Create the test function with the expected signature @@ -222,7 +224,11 @@ def test_function(text: str) -> str: # Verify that function was not called and only middleware executed assert execution_order == ["middleware_before", "middleware_after"] assert "function_called" not in execution_order - assert execution_order == ["middleware_before", "middleware_after"] + + # Verify the chat client was only called once (no extra LLM call after termination) + assert chat_client.call_count == 1 + # Verify the second response is still in the queue (wasn't consumed) + assert len(chat_client.responses) == 1 async def test_function_middleware_with_post_termination(self, chat_client: "MockChatClient") -> None: """Test that function middleware can terminate execution after calling next().""" @@ -242,7 +248,8 @@ async def process( # Create a message to start the conversation messages = [ChatMessage(role=Role.USER, text="test message")] - # Set up chat client to return a function call + # Set up chat client to return a function call, then a final response + # If terminate works correctly, only the first response should be consumed chat_client.responses = [ ChatResponse( messages=[ @@ -253,7 +260,8 @@ async def process( ], ) ] - ) + ), + ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="this should not be consumed")]), ] # Create the test function with the expected signature @@ -273,6 +281,11 @@ def test_function(text: str) -> str: assert "function_called" in execution_order assert execution_order == ["middleware_before", "function_called", "middleware_after"] + # Verify the chat client was only called once (no extra LLM call after termination) + assert chat_client.call_count == 1 + # Verify the second response is still in the queue (wasn't consumed) + assert len(chat_client.responses) == 1 + async def test_function_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None: """Test function-based agent middleware with ChatAgent.""" execution_order: list[str] = [] diff --git a/python/packages/core/tests/core/test_observability.py b/python/packages/core/tests/core/test_observability.py index abdc5184be..8528295406 100644 --- a/python/packages/core/tests/core/test_observability.py +++ b/python/packages/core/tests/core/test_observability.py @@ -33,8 +33,8 @@ ChatMessageListTimestampFilter, OtelAttr, get_function_span, - use_agent_observability, - use_observability, + use_agent_instrumentation, + use_instrumentation, ) # region Test constants @@ -157,7 +157,7 @@ def test_start_span_with_tool_call_id(span_exporter: InMemorySpanExporter): assert span.attributes[OtelAttr.TOOL_TYPE] == "function" -# region Test use_observability decorator +# region Test use_instrumentation decorator def test_decorator_with_valid_class(): @@ -175,7 +175,7 @@ async def gen(): return gen() # Apply the decorator - decorated_class = use_observability(MockChatClient) + decorated_class = use_instrumentation(MockChatClient) assert hasattr(decorated_class, OPEN_TELEMETRY_CHAT_CLIENT_MARKER) @@ -187,7 +187,7 @@ class MockChatClient: # Apply the decorator - should not raise an error with pytest.raises(ChatClientInitializationError): - use_observability(MockChatClient) + use_instrumentation(MockChatClient) def test_decorator_with_partial_methods(): @@ -200,7 +200,7 @@ async def get_response(self, messages, **kwargs): return Mock() with pytest.raises(ChatClientInitializationError): - use_observability(MockChatClient) + use_instrumentation(MockChatClient) # region Test telemetry decorator with mock client @@ -235,7 +235,7 @@ async def _inner_get_streaming_response( @pytest.mark.parametrize("enable_sensitive_data", [True, False], indirect=True) async def test_chat_client_observability(mock_chat_client, span_exporter: InMemorySpanExporter, enable_sensitive_data): """Test that when diagnostics are enabled, telemetry is applied.""" - client = use_observability(mock_chat_client)() + client = use_instrumentation(mock_chat_client)() messages = [ChatMessage(role=Role.USER, text="Test message")] span_exporter.clear() @@ -258,8 +258,8 @@ async def test_chat_client_observability(mock_chat_client, span_exporter: InMemo async def test_chat_client_streaming_observability( mock_chat_client, span_exporter: InMemorySpanExporter, enable_sensitive_data ): - """Test streaming telemetry through the use_observability decorator.""" - client = use_observability(mock_chat_client)() + """Test streaming telemetry through the use_instrumentation decorator.""" + client = use_instrumentation(mock_chat_client)() messages = [ChatMessage(role=Role.USER, text="Test")] span_exporter.clear() # Collect all yielded updates @@ -282,7 +282,7 @@ async def test_chat_client_streaming_observability( async def test_chat_client_without_model_id_observability(mock_chat_client, span_exporter: InMemorySpanExporter): """Test telemetry shouldn't fail when the model_id is not provided for unknown reason.""" - client = use_observability(mock_chat_client)() + client = use_instrumentation(mock_chat_client)() messages = [ChatMessage(role=Role.USER, text="Test")] span_exporter.clear() response = await client.get_response(messages=messages) @@ -301,7 +301,7 @@ async def test_chat_client_streaming_without_model_id_observability( mock_chat_client, span_exporter: InMemorySpanExporter ): """Test streaming telemetry shouldn't fail when the model_id is not provided for unknown reason.""" - client = use_observability(mock_chat_client)() + client = use_instrumentation(mock_chat_client)() messages = [ChatMessage(role=Role.USER, text="Test")] span_exporter.clear() # Collect all yielded updates @@ -329,7 +329,7 @@ def test_prepend_user_agent_with_none_value(): assert AGENT_FRAMEWORK_USER_AGENT in str(result["User-Agent"]) -# region Test use_agent_observability decorator +# region Test use_agent_instrumentation decorator def test_agent_decorator_with_valid_class(): @@ -337,7 +337,7 @@ def test_agent_decorator_with_valid_class(): # Create a mock class with the required methods class MockChatClientAgent: - AGENT_SYSTEM_NAME = "test_agent_system" + AGENT_PROVIDER_NAME = "test_agent_system" def __init__(self): self.id = "test_agent_id" @@ -358,7 +358,7 @@ def get_new_thread(self) -> AgentThread: return AgentThread() # Apply the decorator - decorated_class = use_agent_observability(MockChatClientAgent) + decorated_class = use_agent_instrumentation(MockChatClientAgent) assert hasattr(decorated_class, OPEN_TELEMETRY_AGENT_MARKER) @@ -367,19 +367,19 @@ def test_agent_decorator_with_missing_methods(): """Test that agent decorator handles classes missing required methods gracefully.""" class MockAgent: - AGENT_SYSTEM_NAME = "test_agent_system" + AGENT_PROVIDER_NAME = "test_agent_system" # Apply the decorator - should not raise an error with pytest.raises(AgentInitializationError): - use_agent_observability(MockAgent) + use_agent_instrumentation(MockAgent) def test_agent_decorator_with_partial_methods(): """Test agent decorator when only one method is present.""" - from agent_framework.observability import use_agent_observability + from agent_framework.observability import use_agent_instrumentation class MockAgent: - AGENT_SYSTEM_NAME = "test_agent_system" + AGENT_PROVIDER_NAME = "test_agent_system" def __init__(self): self.id = "test_agent_id" @@ -390,7 +390,7 @@ async def run(self, messages=None, *, thread=None, **kwargs): return Mock() with pytest.raises(AgentInitializationError): - use_agent_observability(MockAgent) + use_agent_instrumentation(MockAgent) # region Test agent telemetry decorator with mock agent @@ -401,7 +401,7 @@ def mock_chat_agent(): """Create a mock chat client agent for testing.""" class MockChatClientAgent: - AGENT_SYSTEM_NAME = "test_agent_system" + AGENT_PROVIDER_NAME = "test_agent_system" def __init__(self): self.id = "test_agent_id" @@ -433,7 +433,7 @@ async def test_agent_instrumentation_enabled( ): """Test that when agent diagnostics are enabled, telemetry is applied.""" - agent = use_agent_observability(mock_chat_agent)() + agent = use_agent_instrumentation(mock_chat_agent)() span_exporter.clear() response = await agent.run("Test message") @@ -457,8 +457,8 @@ async def test_agent_instrumentation_enabled( async def test_agent_streaming_response_with_diagnostics_enabled_via_decorator( mock_chat_agent: AgentProtocol, span_exporter: InMemorySpanExporter, enable_sensitive_data ): - """Test agent streaming telemetry through the use_agent_observability decorator.""" - agent = use_agent_observability(mock_chat_agent)() + """Test agent streaming telemetry through the use_agent_instrumentation decorator.""" + agent = use_agent_instrumentation(mock_chat_agent)() span_exporter.clear() updates = [] async for update in agent.run_stream("Test message"): @@ -522,3 +522,393 @@ async def failing_function(param: str) -> str: exception_message = exception_event.attributes["exception.message"] assert isinstance(exception_message, str) assert "Function execution failed" in exception_message + + +# region Test OTEL environment variable parsing + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_with_grpc_endpoint(monkeypatch): + """Test _get_exporters_from_env with OTEL_EXPORTER_OTLP_ENDPOINT (gRPC).""" + from agent_framework.observability import _get_exporters_from_env + + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc") + + exporters = _get_exporters_from_env() + + # Should return 3 exporters (trace, metrics, logs) + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_with_http_endpoint(monkeypatch): + """Test _get_exporters_from_env with OTEL_EXPORTER_OTLP_ENDPOINT (HTTP).""" + from agent_framework.observability import _get_exporters_from_env + + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http") + + exporters = _get_exporters_from_env() + + # Should return 3 exporters (trace, metrics, logs) + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_with_individual_endpoints(monkeypatch): + """Test _get_exporters_from_env with individual signal endpoints.""" + from agent_framework.observability import _get_exporters_from_env + + monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", "http://localhost:4318") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", "http://localhost:4319") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc") + + exporters = _get_exporters_from_env() + + # Should return 3 exporters (trace, metrics, logs) + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_with_headers(monkeypatch): + """Test _get_exporters_from_env with OTEL_EXPORTER_OTLP_HEADERS.""" + from agent_framework.observability import _get_exporters_from_env + + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_HEADERS", "key1=value1,key2=value2") + + exporters = _get_exporters_from_env() + + # Should return 3 exporters with headers + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_with_signal_specific_headers(monkeypatch): + """Test _get_exporters_from_env with signal-specific headers.""" + from agent_framework.observability import _get_exporters_from_env + + monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_HEADERS", "trace-key=trace-value") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc") + + exporters = _get_exporters_from_env() + + # Should have at least the traces exporter + assert len(exporters) >= 1 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_without_env_vars(monkeypatch): + """Test _get_exporters_from_env returns empty list when no env vars set.""" + from agent_framework.observability import _get_exporters_from_env + + # Clear all OTEL env vars + for key in [ + "OTEL_EXPORTER_OTLP_ENDPOINT", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", + "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", + "OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", + ]: + monkeypatch.delenv(key, raising=False) + + exporters = _get_exporters_from_env() + + # Should return empty list + assert len(exporters) == 0 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_missing_grpc_dependency(monkeypatch): + """Test _get_exporters_from_env raises ImportError when gRPC exporters not installed.""" + + from agent_framework.observability import _get_exporters_from_env + + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc") + + # Mock the import to raise ImportError + original_import = __builtins__.__import__ + + def mock_import(name, *args, **kwargs): + if "opentelemetry.exporter.otlp.proto.grpc" in name: + raise ImportError("No module named 'opentelemetry.exporter.otlp.proto.grpc'") + return original_import(name, *args, **kwargs) + + monkeypatch.setattr(__builtins__, "__import__", mock_import) + + with pytest.raises(ImportError, match="opentelemetry-exporter-otlp-proto-grpc"): + _get_exporters_from_env() + + +# region Test create_resource + + +def test_create_resource_from_env(monkeypatch): + """Test create_resource reads OTEL environment variables.""" + from agent_framework.observability import create_resource + + monkeypatch.setenv("OTEL_SERVICE_NAME", "test-service") + monkeypatch.setenv("OTEL_SERVICE_VERSION", "1.0.0") + monkeypatch.setenv("OTEL_RESOURCE_ATTRIBUTES", "deployment.environment=production,host.name=server1") + + resource = create_resource() + + assert resource.attributes["service.name"] == "test-service" + assert resource.attributes["service.version"] == "1.0.0" + assert resource.attributes["deployment.environment"] == "production" + assert resource.attributes["host.name"] == "server1" + + +def test_create_resource_with_parameters_override_env(monkeypatch): + """Test create_resource parameters override environment variables.""" + from agent_framework.observability import create_resource + + monkeypatch.setenv("OTEL_SERVICE_NAME", "env-service") + monkeypatch.setenv("OTEL_SERVICE_VERSION", "0.1.0") + + resource = create_resource(service_name="param-service", service_version="2.0.0") + + # Parameters should override env vars + assert resource.attributes["service.name"] == "param-service" + assert resource.attributes["service.version"] == "2.0.0" + + +def test_create_resource_with_custom_attributes(monkeypatch): + """Test create_resource accepts custom attributes.""" + from agent_framework.observability import create_resource + + resource = create_resource(custom_attr="custom_value", another_attr=123) + + assert resource.attributes["custom_attr"] == "custom_value" + assert resource.attributes["another_attr"] == 123 + + +# region Test _create_otlp_exporters + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_create_otlp_exporters_grpc_with_single_endpoint(): + """Test _create_otlp_exporters creates gRPC exporters with single endpoint.""" + from agent_framework.observability import _create_otlp_exporters + + exporters = _create_otlp_exporters(endpoint="http://localhost:4317", protocol="grpc") + + # Should return 3 exporters (trace, metrics, logs) + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_create_otlp_exporters_http_with_single_endpoint(): + """Test _create_otlp_exporters creates HTTP exporters with single endpoint.""" + from agent_framework.observability import _create_otlp_exporters + + exporters = _create_otlp_exporters(endpoint="http://localhost:4318", protocol="http") + + # Should return 3 exporters (trace, metrics, logs) + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_create_otlp_exporters_with_individual_endpoints(): + """Test _create_otlp_exporters with individual signal endpoints.""" + from agent_framework.observability import _create_otlp_exporters + + exporters = _create_otlp_exporters( + protocol="grpc", + traces_endpoint="http://localhost:4317", + metrics_endpoint="http://localhost:4318", + logs_endpoint="http://localhost:4319", + ) + + # Should return 3 exporters + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_create_otlp_exporters_with_headers(): + """Test _create_otlp_exporters with headers.""" + from agent_framework.observability import _create_otlp_exporters + + exporters = _create_otlp_exporters( + endpoint="http://localhost:4317", protocol="grpc", headers={"Authorization": "Bearer token"} + ) + + # Should return 3 exporters with headers + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_create_otlp_exporters_grpc_missing_dependency(): + """Test _create_otlp_exporters raises ImportError when gRPC exporters not installed.""" + import sys + from unittest.mock import patch + + from agent_framework.observability import _create_otlp_exporters + + # Mock the import to raise ImportError + with ( + patch.dict(sys.modules, {"opentelemetry.exporter.otlp.proto.grpc.trace_exporter": None}), + pytest.raises(ImportError, match="opentelemetry-exporter-otlp-proto-grpc"), + ): + _create_otlp_exporters(endpoint="http://localhost:4317", protocol="grpc") + + +# region Test configure_otel_providers with views + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_configure_otel_providers_with_views(monkeypatch): + """Test configure_otel_providers accepts views parameter.""" + from opentelemetry.sdk.metrics import View + from opentelemetry.sdk.metrics.view import DropAggregation + + from agent_framework.observability import configure_otel_providers + + # Clear all OTEL env vars + for key in [ + "OTEL_EXPORTER_OTLP_ENDPOINT", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", + "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", + "OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", + ]: + monkeypatch.delenv(key, raising=False) + + # Create a view that drops all metrics + views = [View(instrument_name="*", aggregation=DropAggregation())] + + # Should not raise an error + configure_otel_providers(views=views) + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_configure_otel_providers_without_views(monkeypatch): + """Test configure_otel_providers works without views parameter.""" + from agent_framework.observability import configure_otel_providers + + # Clear all OTEL env vars + for key in [ + "OTEL_EXPORTER_OTLP_ENDPOINT", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", + "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", + "OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", + ]: + monkeypatch.delenv(key, raising=False) + + # Should not raise an error with default empty views + configure_otel_providers() + + +# region Test console exporters opt-in + + +def test_console_exporters_opt_in_false(monkeypatch): + """Test console exporters are not added when ENABLE_CONSOLE_EXPORTERS is false.""" + from agent_framework.observability import ObservabilitySettings + + monkeypatch.setenv("ENABLE_CONSOLE_EXPORTERS", "false") + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + + settings = ObservabilitySettings(env_file_path="test.env") + assert settings.enable_console_exporters is False + + +def test_console_exporters_opt_in_true(monkeypatch): + """Test console exporters are added when ENABLE_CONSOLE_EXPORTERS is true.""" + from agent_framework.observability import ObservabilitySettings + + monkeypatch.setenv("ENABLE_CONSOLE_EXPORTERS", "true") + + settings = ObservabilitySettings(env_file_path="test.env") + assert settings.enable_console_exporters is True + + +def test_console_exporters_default_false(monkeypatch): + """Test console exporters default to False when not set.""" + from agent_framework.observability import ObservabilitySettings + + monkeypatch.delenv("ENABLE_CONSOLE_EXPORTERS", raising=False) + + settings = ObservabilitySettings(env_file_path="test.env") + assert settings.enable_console_exporters is False + + +# region Test _parse_headers helper + + +def test_parse_headers_valid(): + """Test _parse_headers with valid header string.""" + from agent_framework.observability import _parse_headers + + headers = _parse_headers("key1=value1,key2=value2") + assert headers == {"key1": "value1", "key2": "value2"} + + +def test_parse_headers_with_spaces(): + """Test _parse_headers handles spaces around keys and values.""" + from agent_framework.observability import _parse_headers + + headers = _parse_headers("key1 = value1 , key2 = value2 ") + assert headers == {"key1": "value1", "key2": "value2"} + + +def test_parse_headers_empty_string(): + """Test _parse_headers with empty string.""" + from agent_framework.observability import _parse_headers + + headers = _parse_headers("") + assert headers == {} + + +def test_parse_headers_invalid_format(): + """Test _parse_headers ignores invalid pairs.""" + from agent_framework.observability import _parse_headers + + headers = _parse_headers("key1=value1,invalid,key2=value2") + # Should only include valid pairs + assert headers == {"key1": "value1", "key2": "value2"} diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index 4beee1fb7d..88c34dc3e8 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -1,5 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Any +from typing import Annotated, Any, Literal from unittest.mock import Mock import pytest @@ -14,7 +14,7 @@ ToolProtocol, ai_function, ) -from agent_framework._tools import _parse_inputs +from agent_framework._tools import _parse_annotation, _parse_inputs from agent_framework.exceptions import ToolException from agent_framework.observability import OtelAttr @@ -128,6 +128,95 @@ def test_tool(self, x: int, y: int) -> int: assert test_tool(1, 2) == 3 +def test_ai_function_with_literal_type_parameter(): + """Test ai_function decorator with Literal type parameter (issue #2891).""" + + @ai_function + def search_flows(category: Literal["Data", "Security", "Network"], issue: str) -> str: + """Search flows by category.""" + return f"{category}: {issue}" + + assert isinstance(search_flows, AIFunction) + schema = search_flows.parameters() + assert schema == { + "properties": { + "category": {"enum": ["Data", "Security", "Network"], "title": "Category", "type": "string"}, + "issue": {"title": "Issue", "type": "string"}, + }, + "required": ["category", "issue"], + "title": "search_flows_input", + "type": "object", + } + # Verify invocation works + assert search_flows("Data", "test issue") == "Data: test issue" + + +def test_ai_function_with_literal_type_in_class_method(): + """Test ai_function decorator with Literal type parameter in a class method (issue #2891).""" + + class MyTools: + @ai_function + def search_flows(self, category: Literal["Data", "Security", "Network"], issue: str) -> str: + """Search flows by category.""" + return f"{category}: {issue}" + + tools = MyTools() + search_tool = tools.search_flows + assert isinstance(search_tool, AIFunction) + schema = search_tool.parameters() + assert schema == { + "properties": { + "category": {"enum": ["Data", "Security", "Network"], "title": "Category", "type": "string"}, + "issue": {"title": "Issue", "type": "string"}, + }, + "required": ["category", "issue"], + "title": "search_flows_input", + "type": "object", + } + # Verify invocation works + assert search_tool("Security", "test issue") == "Security: test issue" + + +def test_ai_function_with_literal_int_type(): + """Test ai_function decorator with Literal int type parameter.""" + + @ai_function + def set_priority(priority: Literal[1, 2, 3], task: str) -> str: + """Set priority for a task.""" + return f"Priority {priority}: {task}" + + assert isinstance(set_priority, AIFunction) + schema = set_priority.parameters() + assert schema == { + "properties": { + "priority": {"enum": [1, 2, 3], "title": "Priority", "type": "integer"}, + "task": {"title": "Task", "type": "string"}, + }, + "required": ["priority", "task"], + "title": "set_priority_input", + "type": "object", + } + assert set_priority(1, "important task") == "Priority 1: important task" + + +def test_ai_function_with_literal_and_annotated(): + """Test ai_function decorator with Literal type combined with Annotated for description.""" + + @ai_function + def categorize( + category: Annotated[Literal["A", "B", "C"], "The category to assign"], + name: str, + ) -> str: + """Categorize an item.""" + return f"{category}: {name}" + + assert isinstance(categorize, AIFunction) + schema = categorize.parameters() + # Literal type inside Annotated should preserve enum values + assert schema["properties"]["category"]["enum"] == ["A", "B", "C"] + assert categorize("A", "test") == "A: test" + + async def test_ai_function_decorator_shared_state(): """Test that decorated methods maintain shared state across multiple calls and tool usage.""" @@ -1368,3 +1457,70 @@ def tool_with_kwargs(x: int, **kwargs: Any) -> str: arguments=tool_with_kwargs.input_model(x=10), ) assert result_default == "x=10, user=unknown" + + +# region _parse_annotation tests + + +def test_parse_annotation_with_literal_type(): + """Test that _parse_annotation returns Literal types unchanged (issue #2891).""" + from typing import get_args, get_origin + + # Literal with string values + literal_annotation = Literal["Data", "Security", "Network"] + result = _parse_annotation(literal_annotation) + assert result is literal_annotation + assert get_origin(result) is Literal + assert get_args(result) == ("Data", "Security", "Network") + + +def test_parse_annotation_with_literal_int_type(): + """Test that _parse_annotation returns Literal int types unchanged.""" + from typing import get_args, get_origin + + literal_annotation = Literal[1, 2, 3] + result = _parse_annotation(literal_annotation) + assert result is literal_annotation + assert get_origin(result) is Literal + assert get_args(result) == (1, 2, 3) + + +def test_parse_annotation_with_literal_bool_type(): + """Test that _parse_annotation returns Literal bool types unchanged.""" + from typing import get_args, get_origin + + literal_annotation = Literal[True, False] + result = _parse_annotation(literal_annotation) + assert result is literal_annotation + assert get_origin(result) is Literal + assert get_args(result) == (True, False) + + +def test_parse_annotation_with_simple_types(): + """Test that _parse_annotation returns simple types unchanged.""" + assert _parse_annotation(str) is str + assert _parse_annotation(int) is int + assert _parse_annotation(float) is float + assert _parse_annotation(bool) is bool + + +def test_parse_annotation_with_annotated_and_literal(): + """Test that Annotated[Literal[...], description] works correctly.""" + from typing import get_args, get_origin + + # When Literal is inside Annotated, it should still be preserved + annotated_literal = Annotated[Literal["A", "B", "C"], "The category"] + result = _parse_annotation(annotated_literal) + + # The Annotated type should be preserved + origin = get_origin(result) + assert origin is Annotated + + args = get_args(result) + # First arg is the Literal type + literal_type = args[0] + assert get_origin(literal_type) is Literal + assert get_args(literal_type) == ("A", "B", "C") + + +# endregion diff --git a/python/packages/core/tests/openai/test_openai_assistants_client.py b/python/packages/core/tests/openai/test_openai_assistants_client.py index b9c32b14b5..861ccc73d1 100644 --- a/python/packages/core/tests/openai/test_openai_assistants_client.py +++ b/python/packages/core/tests/openai/test_openai_assistants_client.py @@ -463,9 +463,9 @@ async def test_openai_assistants_client_process_stream_events_requires_action(mo """Test _process_stream_events with thread.run.requires_action event.""" chat_client = create_test_openai_assistants_client(mock_async_openai) - # Mock the _create_function_call_contents method to return test content + # Mock the _parse_function_calls_from_assistants method to return test content test_function_content = FunctionCallContent(call_id="call-123", name="test_func", arguments={"arg": "value"}) - chat_client._create_function_call_contents = MagicMock(return_value=[test_function_content]) # type: ignore + chat_client._parse_function_calls_from_assistants = MagicMock(return_value=[test_function_content]) # type: ignore # Create a mock Run object mock_run = MagicMock(spec=Run) @@ -498,8 +498,8 @@ async def async_iterator() -> Any: assert update.contents[0] == test_function_content assert update.raw_representation == mock_run - # Verify _create_function_call_contents was called correctly - chat_client._create_function_call_contents.assert_called_once_with(mock_run, None) # type: ignore + # Verify _parse_function_calls_from_assistants was called correctly + chat_client._parse_function_calls_from_assistants.assert_called_once_with(mock_run, None) # type: ignore async def test_openai_assistants_client_process_stream_events_run_step_created(mock_async_openai: MagicMock) -> None: @@ -585,8 +585,8 @@ async def async_iterator() -> Any: assert update.raw_representation == mock_run -def test_openai_assistants_client_create_function_call_contents_basic(mock_async_openai: MagicMock) -> None: - """Test _create_function_call_contents with a simple function call.""" +def test_openai_assistants_client_parse_function_calls_from_assistants_basic(mock_async_openai: MagicMock) -> None: + """Test _parse_function_calls_from_assistants with a simple function call.""" chat_client = create_test_openai_assistants_client(mock_async_openai) @@ -605,7 +605,7 @@ def test_openai_assistants_client_create_function_call_contents_basic(mock_async # Call the method response_id = "response_456" - contents = chat_client._create_function_call_contents(mock_run, response_id) # type: ignore + contents = chat_client._parse_function_calls_from_assistants(mock_run, response_id) # type: ignore # Test that one function call content was created assert len(contents) == 1 @@ -825,24 +825,24 @@ def test_openai_assistants_client_prepare_options_with_image_content(mock_async_ assert message["content"][0]["image_url"]["url"] == "https://example.com/image.jpg" -def test_openai_assistants_client_convert_function_results_to_tool_output_empty(mock_async_openai: MagicMock) -> None: - """Test _convert_function_results_to_tool_output with empty list.""" +def test_openai_assistants_client_prepare_tool_outputs_for_assistants_empty(mock_async_openai: MagicMock) -> None: + """Test _prepare_tool_outputs_for_assistants with empty list.""" chat_client = create_test_openai_assistants_client(mock_async_openai) - run_id, tool_outputs = chat_client._convert_function_results_to_tool_output([]) # type: ignore + run_id, tool_outputs = chat_client._prepare_tool_outputs_for_assistants([]) # type: ignore assert run_id is None assert tool_outputs is None -def test_openai_assistants_client_convert_function_results_to_tool_output_valid(mock_async_openai: MagicMock) -> None: - """Test _convert_function_results_to_tool_output with valid function results.""" +def test_openai_assistants_client_prepare_tool_outputs_for_assistants_valid(mock_async_openai: MagicMock) -> None: + """Test _prepare_tool_outputs_for_assistants with valid function results.""" chat_client = create_test_openai_assistants_client(mock_async_openai) call_id = json.dumps(["run-123", "call-456"]) function_result = FunctionResultContent(call_id=call_id, result="Function executed successfully") - run_id, tool_outputs = chat_client._convert_function_results_to_tool_output([function_result]) # type: ignore + run_id, tool_outputs = chat_client._prepare_tool_outputs_for_assistants([function_result]) # type: ignore assert run_id == "run-123" assert tool_outputs is not None @@ -851,10 +851,10 @@ def test_openai_assistants_client_convert_function_results_to_tool_output_valid( assert tool_outputs[0].get("output") == "Function executed successfully" -def test_openai_assistants_client_convert_function_results_to_tool_output_mismatched_run_ids( +def test_openai_assistants_client_prepare_tool_outputs_for_assistants_mismatched_run_ids( mock_async_openai: MagicMock, ) -> None: - """Test _convert_function_results_to_tool_output with mismatched run IDs.""" + """Test _prepare_tool_outputs_for_assistants with mismatched run IDs.""" chat_client = create_test_openai_assistants_client(mock_async_openai) # Create function results with different run IDs @@ -863,7 +863,7 @@ def test_openai_assistants_client_convert_function_results_to_tool_output_mismat function_result1 = FunctionResultContent(call_id=call_id1, result="Result 1") function_result2 = FunctionResultContent(call_id=call_id2, result="Result 2") - run_id, tool_outputs = chat_client._convert_function_results_to_tool_output([function_result1, function_result2]) # type: ignore + run_id, tool_outputs = chat_client._prepare_tool_outputs_for_assistants([function_result1, function_result2]) # type: ignore # Should only process the first one since run IDs don't match assert run_id == "run-123" diff --git a/python/packages/core/tests/openai/test_openai_chat_client.py b/python/packages/core/tests/openai/test_openai_chat_client.py index 8af3ed61aa..d2ddc1fb02 100644 --- a/python/packages/core/tests/openai/test_openai_chat_client.py +++ b/python/packages/core/tests/openai/test_openai_chat_client.py @@ -182,12 +182,12 @@ def test_unsupported_tool_handling(openai_unit_test_env: dict[str, str]) -> None unsupported_tool.__class__.__name__ = "UnsupportedAITool" # This should ignore the unsupported ToolProtocol and return empty list - result = client._chat_to_tool_spec([unsupported_tool]) # type: ignore + result = client._prepare_tools_for_openai([unsupported_tool]) # type: ignore assert result == [] # Also test with a non-ToolProtocol that should be converted to dict dict_tool = {"type": "function", "name": "test"} - result = client._chat_to_tool_spec([dict_tool]) # type: ignore + result = client._prepare_tools_for_openai([dict_tool]) # type: ignore assert result == [dict_tool] @@ -637,7 +637,7 @@ def test_chat_response_content_order_text_before_tool_calls(openai_unit_test_env ) client = OpenAIChatClient() - response = client._create_chat_response(mock_response, ChatOptions()) + response = client._parse_response_from_openai(mock_response, ChatOptions()) # Verify we have both text and tool call content assert len(response.messages) == 1 @@ -658,7 +658,7 @@ def test_function_result_falsy_values_handling(openai_unit_test_env: dict[str, s # Test with empty list (falsy but not None) message_with_empty_list = ChatMessage(role="tool", contents=[FunctionResultContent(call_id="call-123", result=[])]) - openai_messages = client._openai_chat_message_parser(message_with_empty_list) + openai_messages = client._prepare_message_for_openai(message_with_empty_list) assert len(openai_messages) == 1 assert openai_messages[0]["content"] == "[]" # Empty list should be JSON serialized @@ -667,14 +667,14 @@ def test_function_result_falsy_values_handling(openai_unit_test_env: dict[str, s role="tool", contents=[FunctionResultContent(call_id="call-456", result="")] ) - openai_messages = client._openai_chat_message_parser(message_with_empty_string) + openai_messages = client._prepare_message_for_openai(message_with_empty_string) assert len(openai_messages) == 1 assert openai_messages[0]["content"] == "" # Empty string should be preserved # Test with False (falsy but not None) message_with_false = ChatMessage(role="tool", contents=[FunctionResultContent(call_id="call-789", result=False)]) - openai_messages = client._openai_chat_message_parser(message_with_false) + openai_messages = client._prepare_message_for_openai(message_with_false) assert len(openai_messages) == 1 assert openai_messages[0]["content"] == "false" # False should be JSON serialized @@ -695,7 +695,7 @@ def test_function_result_exception_handling(openai_unit_test_env: dict[str, str] ], ) - openai_messages = client._openai_chat_message_parser(message_with_exception) + openai_messages = client._prepare_message_for_openai(message_with_exception) assert len(openai_messages) == 1 assert openai_messages[0]["content"] == "Error: Function failed." assert openai_messages[0]["tool_call_id"] == "call-123" @@ -708,8 +708,8 @@ def test_prepare_function_call_results_string_passthrough(): assert isinstance(result, str) -def test_openai_content_parser_data_content_image(openai_unit_test_env: dict[str, str]) -> None: - """Test _openai_content_parser converts DataContent with image media type to OpenAI format.""" +def test_prepare_content_for_openai_data_content_image(openai_unit_test_env: dict[str, str]) -> None: + """Test _prepare_content_for_openai converts DataContent with image media type to OpenAI format.""" client = OpenAIChatClient() # Test DataContent with image media type @@ -718,7 +718,7 @@ def test_openai_content_parser_data_content_image(openai_unit_test_env: dict[str media_type="image/png", ) - result = client._openai_content_parser(image_data_content) # type: ignore + result = client._prepare_content_for_openai(image_data_content) # type: ignore # Should convert to OpenAI image_url format assert result["type"] == "image_url" @@ -727,7 +727,7 @@ def test_openai_content_parser_data_content_image(openai_unit_test_env: dict[str # Test DataContent with non-image media type should use default model_dump text_data_content = DataContent(uri="data:text/plain;base64,SGVsbG8gV29ybGQ=", media_type="text/plain") - result = client._openai_content_parser(text_data_content) # type: ignore + result = client._prepare_content_for_openai(text_data_content) # type: ignore # Should use default model_dump format assert result["type"] == "data" @@ -740,7 +740,7 @@ def test_openai_content_parser_data_content_image(openai_unit_test_env: dict[str media_type="audio/wav", ) - result = client._openai_content_parser(audio_data_content) # type: ignore + result = client._prepare_content_for_openai(audio_data_content) # type: ignore # Should convert to OpenAI input_audio format assert result["type"] == "input_audio" @@ -751,7 +751,7 @@ def test_openai_content_parser_data_content_image(openai_unit_test_env: dict[str # Test DataContent with MP3 audio mp3_data_content = DataContent(uri="data:audio/mp3;base64,//uQAAAAWGluZwAAAA8AAAACAAACcQ==", media_type="audio/mp3") - result = client._openai_content_parser(mp3_data_content) # type: ignore + result = client._prepare_content_for_openai(mp3_data_content) # type: ignore # Should convert to OpenAI input_audio format with mp3 assert result["type"] == "input_audio" @@ -760,8 +760,8 @@ def test_openai_content_parser_data_content_image(openai_unit_test_env: dict[str assert result["input_audio"]["format"] == "mp3" -def test_openai_content_parser_document_file_mapping(openai_unit_test_env: dict[str, str]) -> None: - """Test _openai_content_parser converts document files (PDF, DOCX, etc.) to OpenAI file format.""" +def test_prepare_content_for_openai_document_file_mapping(openai_unit_test_env: dict[str, str]) -> None: + """Test _prepare_content_for_openai converts document files (PDF, DOCX, etc.) to OpenAI file format.""" client = OpenAIChatClient() # Test PDF without filename - should omit filename in OpenAI payload @@ -770,7 +770,7 @@ def test_openai_content_parser_document_file_mapping(openai_unit_test_env: dict[ media_type="application/pdf", ) - result = client._openai_content_parser(pdf_data_content) # type: ignore + result = client._prepare_content_for_openai(pdf_data_content) # type: ignore # Should convert to OpenAI file format without filename assert result["type"] == "file" @@ -787,7 +787,7 @@ def test_openai_content_parser_document_file_mapping(openai_unit_test_env: dict[ additional_properties={"filename": "report.pdf"}, ) - result = client._openai_content_parser(pdf_with_filename) # type: ignore + result = client._prepare_content_for_openai(pdf_with_filename) # type: ignore # Should use custom filename assert result["type"] == "file" @@ -820,7 +820,7 @@ def test_openai_content_parser_document_file_mapping(openai_unit_test_env: dict[ media_type=case["media_type"], ) - result = client._openai_content_parser(doc_content) # type: ignore + result = client._prepare_content_for_openai(doc_content) # type: ignore # All application/* types should now be mapped to file format assert result["type"] == "file" @@ -834,7 +834,7 @@ def test_openai_content_parser_document_file_mapping(openai_unit_test_env: dict[ additional_properties={"filename": case["filename"]}, ) - result = client._openai_content_parser(doc_with_filename) # type: ignore + result = client._prepare_content_for_openai(doc_with_filename) # type: ignore # Should now use file format with filename assert result["type"] == "file" @@ -848,7 +848,7 @@ def test_openai_content_parser_document_file_mapping(openai_unit_test_env: dict[ additional_properties={}, ) - result = client._openai_content_parser(pdf_empty_props) # type: ignore + result = client._prepare_content_for_openai(pdf_empty_props) # type: ignore assert result["type"] == "file" assert "filename" not in result["file"] @@ -860,7 +860,7 @@ def test_openai_content_parser_document_file_mapping(openai_unit_test_env: dict[ additional_properties={"filename": None}, ) - result = client._openai_content_parser(pdf_none_filename) # type: ignore + result = client._prepare_content_for_openai(pdf_none_filename) # type: ignore assert result["type"] == "file" assert "filename" not in result["file"] # None filename should be omitted diff --git a/python/packages/core/tests/openai/test_openai_chat_client_base.py b/python/packages/core/tests/openai/test_openai_chat_client_base.py index b146bad613..3e48899509 100644 --- a/python/packages/core/tests/openai/test_openai_chat_client_base.py +++ b/python/packages/core/tests/openai/test_openai_chat_client_base.py @@ -76,7 +76,7 @@ async def test_cmc( mock_create.assert_awaited_once_with( model=openai_unit_test_env["OPENAI_CHAT_MODEL_ID"], stream=False, - messages=openai_chat_completion._prepare_chat_history_for_request(chat_history), # type: ignore + messages=openai_chat_completion._prepare_messages_for_openai(chat_history), # type: ignore ) @@ -97,7 +97,7 @@ async def test_cmc_chat_options( mock_create.assert_awaited_once_with( model=openai_unit_test_env["OPENAI_CHAT_MODEL_ID"], stream=False, - messages=openai_chat_completion._prepare_chat_history_for_request(chat_history), # type: ignore + messages=openai_chat_completion._prepare_messages_for_openai(chat_history), # type: ignore ) @@ -120,7 +120,7 @@ async def test_cmc_no_fcc_in_response( mock_create.assert_awaited_once_with( model=openai_unit_test_env["OPENAI_CHAT_MODEL_ID"], stream=False, - messages=openai_chat_completion._prepare_chat_history_for_request(orig_chat_history), # type: ignore + messages=openai_chat_completion._prepare_messages_for_openai(orig_chat_history), # type: ignore ) @@ -167,7 +167,7 @@ async def test_scmc_chat_options( model=openai_unit_test_env["OPENAI_CHAT_MODEL_ID"], stream=True, stream_options={"include_usage": True}, - messages=openai_chat_completion._prepare_chat_history_for_request(chat_history), # type: ignore + messages=openai_chat_completion._prepare_messages_for_openai(chat_history), # type: ignore ) @@ -203,7 +203,7 @@ async def test_cmc_additional_properties( mock_create.assert_awaited_once_with( model=openai_unit_test_env["OPENAI_CHAT_MODEL_ID"], stream=False, - messages=openai_chat_completion._prepare_chat_history_for_request(chat_history), # type: ignore + messages=openai_chat_completion._prepare_messages_for_openai(chat_history), # type: ignore reasoning_effort="low", ) @@ -246,7 +246,7 @@ async def test_get_streaming( model=openai_unit_test_env["OPENAI_CHAT_MODEL_ID"], stream=True, stream_options={"include_usage": True}, - messages=openai_chat_completion._prepare_chat_history_for_request(orig_chat_history), # type: ignore + messages=openai_chat_completion._prepare_messages_for_openai(orig_chat_history), # type: ignore ) @@ -285,7 +285,7 @@ async def test_get_streaming_singular( model=openai_unit_test_env["OPENAI_CHAT_MODEL_ID"], stream=True, stream_options={"include_usage": True}, - messages=openai_chat_completion._prepare_chat_history_for_request(orig_chat_history), # type: ignore + messages=openai_chat_completion._prepare_messages_for_openai(orig_chat_history), # type: ignore ) @@ -349,7 +349,7 @@ async def test_get_streaming_no_fcc_in_response( model=openai_unit_test_env["OPENAI_CHAT_MODEL_ID"], stream=True, stream_options={"include_usage": True}, - messages=openai_chat_completion._prepare_chat_history_for_request(orig_chat_history), # type: ignore + messages=openai_chat_completion._prepare_messages_for_openai(orig_chat_history), # type: ignore ) @@ -399,7 +399,7 @@ def test_chat_response_created_at_uses_utc(openai_unit_test_env: dict[str, str]) ) client = OpenAIChatClient() - response = client._create_chat_response(mock_response, ChatOptions()) + response = client._parse_response_from_openai(mock_response, ChatOptions()) # Verify that created_at is correctly formatted as UTC assert response.created_at is not None @@ -431,7 +431,7 @@ def test_chat_response_update_created_at_uses_utc(openai_unit_test_env: dict[str ) client = OpenAIChatClient() - response_update = client._create_chat_response_update(mock_chunk) + response_update = client._parse_response_update_from_openai(mock_chunk) # Verify that created_at is correctly formatted as UTC assert response_update.created_at is not None diff --git a/python/packages/core/tests/openai/test_openai_responses_client.py b/python/packages/core/tests/openai/test_openai_responses_client.py index cc187e01f2..451dfd9b06 100644 --- a/python/packages/core/tests/openai/test_openai_responses_client.py +++ b/python/packages/core/tests/openai/test_openai_responses_client.py @@ -368,6 +368,7 @@ async def test_response_format_parse_path() -> None: mock_parsed_response.output_parsed = None mock_parsed_response.usage = None mock_parsed_response.finish_reason = None + mock_parsed_response.conversation = None # No conversation object with patch.object(client.client.responses, "parse", return_value=mock_parsed_response): response = await client.get_response( @@ -454,7 +455,7 @@ async def test_get_streaming_response_with_all_parameters() -> None: def test_response_content_creation_with_annotations() -> None: - """Test _create_response_content with different annotation types.""" + """Test _parse_response_from_openai with different annotation types.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Create a mock response with annotated text content @@ -485,7 +486,7 @@ def test_response_content_creation_with_annotations() -> None: mock_response.output = [mock_message_item] with patch.object(client, "_get_metadata_from_response", return_value={}): - response = client._create_response_content(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore assert len(response.messages[0].contents) >= 1 assert isinstance(response.messages[0].contents[0], TextContent) @@ -494,7 +495,7 @@ def test_response_content_creation_with_annotations() -> None: def test_response_content_creation_with_refusal() -> None: - """Test _create_response_content with refusal content.""" + """Test _parse_response_from_openai with refusal content.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Create a mock response with refusal content @@ -516,7 +517,7 @@ def test_response_content_creation_with_refusal() -> None: mock_response.output = [mock_message_item] - response = client._create_response_content(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore assert len(response.messages[0].contents) == 1 assert isinstance(response.messages[0].contents[0], TextContent) @@ -524,7 +525,7 @@ def test_response_content_creation_with_refusal() -> None: def test_response_content_creation_with_reasoning() -> None: - """Test _create_response_content with reasoning content.""" + """Test _parse_response_from_openai with reasoning content.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Create a mock response with reasoning content @@ -546,7 +547,7 @@ def test_response_content_creation_with_reasoning() -> None: mock_response.output = [mock_reasoning_item] - response = client._create_response_content(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore assert len(response.messages[0].contents) == 2 assert isinstance(response.messages[0].contents[0], TextReasoningContent) @@ -554,7 +555,7 @@ def test_response_content_creation_with_reasoning() -> None: def test_response_content_creation_with_code_interpreter() -> None: - """Test _create_response_content with code interpreter outputs.""" + """Test _parse_response_from_openai with code interpreter outputs.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") @@ -582,7 +583,7 @@ def test_response_content_creation_with_code_interpreter() -> None: mock_response.output = [mock_code_interpreter_item] - response = client._create_response_content(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore assert len(response.messages[0].contents) == 2 assert isinstance(response.messages[0].contents[0], TextContent) @@ -593,7 +594,7 @@ def test_response_content_creation_with_code_interpreter() -> None: def test_response_content_creation_with_function_call() -> None: - """Test _create_response_content with function call content.""" + """Test _parse_response_from_openai with function call content.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Create a mock response with function call @@ -614,7 +615,7 @@ def test_response_content_creation_with_function_call() -> None: mock_response.output = [mock_function_call_item] - response = client._create_response_content(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore assert len(response.messages[0].contents) == 1 assert isinstance(response.messages[0].contents[0], FunctionCallContent) @@ -624,7 +625,7 @@ def test_response_content_creation_with_function_call() -> None: assert function_call.arguments == '{"location": "Seattle"}' -def test_tools_to_response_tools_with_hosted_mcp() -> None: +def test_prepare_tools_for_openai_with_hosted_mcp() -> None: """Test that HostedMCPTool is converted to the correct response tool dict.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") @@ -638,7 +639,7 @@ def test_tools_to_response_tools_with_hosted_mcp() -> None: additional_properties={"custom": "value"}, ) - resp_tools = client._tools_to_response_tools([tool]) + resp_tools = client._prepare_tools_for_openai([tool]) assert isinstance(resp_tools, list) assert len(resp_tools) == 1 mcp = resp_tools[0] @@ -654,7 +655,7 @@ def test_tools_to_response_tools_with_hosted_mcp() -> None: assert "require_approval" in mcp -def test_create_response_content_with_mcp_approval_request() -> None: +def test_parse_response_from_openai_with_mcp_approval_request() -> None: """Test that a non-streaming mcp_approval_request is parsed into FunctionApprovalRequestContent.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") @@ -675,7 +676,7 @@ def test_create_response_content_with_mcp_approval_request() -> None: mock_response.output = [mock_item] - response = client._create_response_content(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore assert isinstance(response.messages[0].contents[0], FunctionApprovalRequestContent) req = response.messages[0].contents[0] @@ -716,7 +717,7 @@ def test_responses_client_created_at_uses_utc(openai_unit_test_env: dict[str, st mock_response.output = [mock_message_item] with patch.object(client, "_get_metadata_from_response", return_value={}): - response = client._create_response_content(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore # Verify that created_at is correctly formatted as UTC assert response.created_at is not None @@ -730,7 +731,7 @@ def test_responses_client_created_at_uses_utc(openai_unit_test_env: dict[str, st ) -def test_tools_to_response_tools_with_raw_image_generation() -> None: +def test_prepare_tools_for_openai_with_raw_image_generation() -> None: """Test that raw image_generation tool dict is handled correctly with parameter mapping.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") @@ -744,7 +745,7 @@ def test_tools_to_response_tools_with_raw_image_generation() -> None: "background": "transparent", } - resp_tools = client._tools_to_response_tools([tool]) + resp_tools = client._prepare_tools_for_openai([tool]) assert isinstance(resp_tools, list) assert len(resp_tools) == 1 @@ -759,7 +760,7 @@ def test_tools_to_response_tools_with_raw_image_generation() -> None: assert image_tool["output_compression"] == 75 -def test_tools_to_response_tools_with_raw_image_generation_openai_responses_params() -> None: +def test_prepare_tools_for_openai_with_raw_image_generation_openai_responses_params() -> None: """Test raw image_generation tool with OpenAI-specific parameters.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") @@ -773,7 +774,7 @@ def test_tools_to_response_tools_with_raw_image_generation_openai_responses_para "partial_images": 2, # Should be integer 0-3 } - resp_tools = client._tools_to_response_tools([tool]) + resp_tools = client._prepare_tools_for_openai([tool]) assert isinstance(resp_tools, list) assert len(resp_tools) == 1 @@ -791,14 +792,14 @@ def test_tools_to_response_tools_with_raw_image_generation_openai_responses_para assert tool_dict["partial_images"] == 2 -def test_tools_to_response_tools_with_raw_image_generation_minimal() -> None: +def test_prepare_tools_for_openai_with_raw_image_generation_minimal() -> None: """Test raw image_generation tool with minimal configuration.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Test with minimal parameters (just type) tool = {"type": "image_generation"} - resp_tools = client._tools_to_response_tools([tool]) + resp_tools = client._prepare_tools_for_openai([tool]) assert isinstance(resp_tools, list) assert len(resp_tools) == 1 @@ -809,7 +810,7 @@ def test_tools_to_response_tools_with_raw_image_generation_minimal() -> None: assert len(image_tool) == 1 -def test_create_streaming_response_content_with_mcp_approval_request() -> None: +def test_parse_chunk_from_openai_with_mcp_approval_request() -> None: """Test that a streaming mcp_approval_request event is parsed into FunctionApprovalRequestContent.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") chat_options = ChatOptions() @@ -825,14 +826,14 @@ def test_create_streaming_response_content_with_mcp_approval_request() -> None: mock_item.server_label = "My_MCP" mock_event.item = mock_item - update = client._create_streaming_response_content(mock_event, chat_options, function_call_ids) + update = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids) assert any(isinstance(c, FunctionApprovalRequestContent) for c in update.contents) fa = next(c for c in update.contents if isinstance(c, FunctionApprovalRequestContent)) assert fa.id == "approval-stream-1" assert fa.function_call.name == "do_stream_action" -@pytest.mark.parametrize("enable_otel", [False], indirect=True) +@pytest.mark.parametrize("enable_instrumentation", [False], indirect=True) @pytest.mark.parametrize("enable_sensitive_data", [False], indirect=True) async def test_end_to_end_mcp_approval_flow(span_exporter) -> None: """End-to-end mocked test: @@ -901,7 +902,7 @@ async def test_end_to_end_mcp_approval_flow(span_exporter) -> None: def test_usage_details_basic() -> None: - """Test _usage_details_from_openai without cached or reasoning tokens.""" + """Test _parse_usage_from_openai without cached or reasoning tokens.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") mock_usage = MagicMock() @@ -911,7 +912,7 @@ def test_usage_details_basic() -> None: mock_usage.input_tokens_details = None mock_usage.output_tokens_details = None - details = client._usage_details_from_openai(mock_usage) # type: ignore + details = client._parse_usage_from_openai(mock_usage) # type: ignore assert details is not None assert details.input_token_count == 100 assert details.output_token_count == 50 @@ -919,7 +920,7 @@ def test_usage_details_basic() -> None: def test_usage_details_with_cached_tokens() -> None: - """Test _usage_details_from_openai with cached input tokens.""" + """Test _parse_usage_from_openai with cached input tokens.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") mock_usage = MagicMock() @@ -930,14 +931,14 @@ def test_usage_details_with_cached_tokens() -> None: mock_usage.input_tokens_details.cached_tokens = 25 mock_usage.output_tokens_details = None - details = client._usage_details_from_openai(mock_usage) # type: ignore + details = client._parse_usage_from_openai(mock_usage) # type: ignore assert details is not None assert details.input_token_count == 200 assert details.additional_counts["openai.cached_input_tokens"] == 25 def test_usage_details_with_reasoning_tokens() -> None: - """Test _usage_details_from_openai with reasoning tokens.""" + """Test _parse_usage_from_openai with reasoning tokens.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") mock_usage = MagicMock() @@ -948,7 +949,7 @@ def test_usage_details_with_reasoning_tokens() -> None: mock_usage.output_tokens_details = MagicMock() mock_usage.output_tokens_details.reasoning_tokens = 30 - details = client._usage_details_from_openai(mock_usage) # type: ignore + details = client._parse_usage_from_openai(mock_usage) # type: ignore assert details is not None assert details.output_token_count == 80 assert details.additional_counts["openai.reasoning_tokens"] == 30 @@ -975,7 +976,7 @@ def test_get_metadata_from_response() -> None: def test_streaming_response_basic_structure() -> None: - """Test that _create_streaming_response_content returns proper structure.""" + """Test that _parse_chunk_from_openai returns proper structure.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") chat_options = ChatOptions(store=True) function_call_ids: dict[int, tuple[str, str]] = {} @@ -983,7 +984,7 @@ def test_streaming_response_basic_structure() -> None: # Test with a basic mock event to ensure the method returns proper structure mock_event = MagicMock() - response = client._create_streaming_response_content(mock_event, chat_options, function_call_ids) # type: ignore + response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids) # type: ignore # Should get a valid ChatResponseUpdate structure assert isinstance(response, ChatResponseUpdate) @@ -1008,7 +1009,7 @@ def test_streaming_annotation_added_with_file_path() -> None: "index": 42, } - response = client._create_streaming_response_content(mock_event, chat_options, function_call_ids) + response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids) assert len(response.contents) == 1 content = response.contents[0] @@ -1035,7 +1036,7 @@ def test_streaming_annotation_added_with_file_citation() -> None: "index": 15, } - response = client._create_streaming_response_content(mock_event, chat_options, function_call_ids) + response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids) assert len(response.contents) == 1 content = response.contents[0] @@ -1064,7 +1065,7 @@ def test_streaming_annotation_added_with_container_file_citation() -> None: "end_index": 50, } - response = client._create_streaming_response_content(mock_event, chat_options, function_call_ids) + response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids) assert len(response.contents) == 1 content = response.contents[0] @@ -1091,7 +1092,7 @@ def test_streaming_annotation_added_with_unknown_type() -> None: "url": "https://example.com", } - response = client._create_streaming_response_content(mock_event, chat_options, function_call_ids) + response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids) # url_citation should not produce HostedFileContent assert len(response.contents) == 0 @@ -1137,8 +1138,8 @@ async def run_streaming(): asyncio.run(run_streaming()) -def test_openai_content_parser_image_content() -> None: - """Test _openai_content_parser with image content variations.""" +def test_prepare_content_for_openai_image_content() -> None: + """Test _prepare_content_for_openai with image content variations.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Test image content with detail parameter and file_id @@ -1147,7 +1148,7 @@ def test_openai_content_parser_image_content() -> None: media_type="image/jpeg", additional_properties={"detail": "high", "file_id": "file_123"}, ) - result = client._openai_content_parser(Role.USER, image_content_with_detail, {}) # type: ignore + result = client._prepare_content_for_openai(Role.USER, image_content_with_detail, {}) # type: ignore assert result["type"] == "input_image" assert result["image_url"] == "https://example.com/image.jpg" assert result["detail"] == "high" @@ -1155,47 +1156,47 @@ def test_openai_content_parser_image_content() -> None: # Test image content without additional properties (defaults) image_content_basic = UriContent(uri="https://example.com/basic.png", media_type="image/png") - result = client._openai_content_parser(Role.USER, image_content_basic, {}) # type: ignore + result = client._prepare_content_for_openai(Role.USER, image_content_basic, {}) # type: ignore assert result["type"] == "input_image" assert result["detail"] == "auto" assert result["file_id"] is None -def test_openai_content_parser_audio_content() -> None: - """Test _openai_content_parser with audio content variations.""" +def test_prepare_content_for_openai_audio_content() -> None: + """Test _prepare_content_for_openai with audio content variations.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Test WAV audio content wav_content = UriContent(uri="data:audio/wav;base64,abc123", media_type="audio/wav") - result = client._openai_content_parser(Role.USER, wav_content, {}) # type: ignore + result = client._prepare_content_for_openai(Role.USER, wav_content, {}) # type: ignore assert result["type"] == "input_audio" assert result["input_audio"]["data"] == "data:audio/wav;base64,abc123" assert result["input_audio"]["format"] == "wav" # Test MP3 audio content mp3_content = UriContent(uri="data:audio/mp3;base64,def456", media_type="audio/mp3") - result = client._openai_content_parser(Role.USER, mp3_content, {}) # type: ignore + result = client._prepare_content_for_openai(Role.USER, mp3_content, {}) # type: ignore assert result["type"] == "input_audio" assert result["input_audio"]["format"] == "mp3" -def test_openai_content_parser_unsupported_content() -> None: - """Test _openai_content_parser with unsupported content types.""" +def test_prepare_content_for_openai_unsupported_content() -> None: + """Test _prepare_content_for_openai with unsupported content types.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Test unsupported audio format unsupported_audio = UriContent(uri="data:audio/ogg;base64,ghi789", media_type="audio/ogg") - result = client._openai_content_parser(Role.USER, unsupported_audio, {}) # type: ignore + result = client._prepare_content_for_openai(Role.USER, unsupported_audio, {}) # type: ignore assert result == {} # Test non-media content text_uri_content = UriContent(uri="https://example.com/document.txt", media_type="text/plain") - result = client._openai_content_parser(Role.USER, text_uri_content, {}) # type: ignore + result = client._prepare_content_for_openai(Role.USER, text_uri_content, {}) # type: ignore assert result == {} -def test_create_streaming_response_content_code_interpreter() -> None: - """Test _create_streaming_response_content with code_interpreter_call.""" +def test_parse_chunk_from_openai_code_interpreter() -> None: + """Test _parse_chunk_from_openai with code_interpreter_call.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") chat_options = ChatOptions() function_call_ids: dict[int, tuple[str, str]] = {} @@ -1211,15 +1212,15 @@ def test_create_streaming_response_content_code_interpreter() -> None: mock_item_image.code = None mock_event_image.item = mock_item_image - result = client._create_streaming_response_content(mock_event_image, chat_options, function_call_ids) # type: ignore + result = client._parse_chunk_from_openai(mock_event_image, chat_options, function_call_ids) # type: ignore assert len(result.contents) == 1 assert isinstance(result.contents[0], UriContent) assert result.contents[0].uri == "https://example.com/plot.png" assert result.contents[0].media_type == "image" -def test_create_streaming_response_content_reasoning() -> None: - """Test _create_streaming_response_content with reasoning content.""" +def test_parse_chunk_from_openai_reasoning() -> None: + """Test _parse_chunk_from_openai with reasoning content.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") chat_options = ChatOptions() function_call_ids: dict[int, tuple[str, str]] = {} @@ -1234,7 +1235,7 @@ def test_create_streaming_response_content_reasoning() -> None: mock_item_reasoning.summary = ["Problem analysis summary"] mock_event_reasoning.item = mock_item_reasoning - result = client._create_streaming_response_content(mock_event_reasoning, chat_options, function_call_ids) # type: ignore + result = client._parse_chunk_from_openai(mock_event_reasoning, chat_options, function_call_ids) # type: ignore assert len(result.contents) == 1 assert isinstance(result.contents[0], TextReasoningContent) assert result.contents[0].text == "Analyzing the problem step by step..." @@ -1242,8 +1243,8 @@ def test_create_streaming_response_content_reasoning() -> None: assert result.contents[0].additional_properties["summary"] == "Problem analysis summary" -def test_openai_content_parser_text_reasoning_comprehensive() -> None: - """Test _openai_content_parser with TextReasoningContent all additional properties.""" +def test_prepare_content_for_openai_text_reasoning_comprehensive() -> None: + """Test _prepare_content_for_openai with TextReasoningContent all additional properties.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Test TextReasoningContent with all additional properties @@ -1255,7 +1256,7 @@ def test_openai_content_parser_text_reasoning_comprehensive() -> None: "encrypted_content": "secure_data_456", }, ) - result = client._openai_content_parser(Role.ASSISTANT, comprehensive_reasoning, {}) # type: ignore + result = client._prepare_content_for_openai(Role.ASSISTANT, comprehensive_reasoning, {}) # type: ignore assert result["type"] == "reasoning" assert result["summary"]["text"] == "Comprehensive reasoning summary" assert result["status"] == "in_progress" @@ -1280,7 +1281,7 @@ def test_streaming_reasoning_text_delta_event() -> None: ) with patch.object(client, "_get_metadata_from_response", return_value={}) as mock_metadata: - response = client._create_streaming_response_content(event, chat_options, function_call_ids) # type: ignore + response = client._parse_chunk_from_openai(event, chat_options, function_call_ids) # type: ignore assert len(response.contents) == 1 assert isinstance(response.contents[0], TextReasoningContent) @@ -1305,7 +1306,7 @@ def test_streaming_reasoning_text_done_event() -> None: ) with patch.object(client, "_get_metadata_from_response", return_value={"test": "data"}) as mock_metadata: - response = client._create_streaming_response_content(event, chat_options, function_call_ids) # type: ignore + response = client._parse_chunk_from_openai(event, chat_options, function_call_ids) # type: ignore assert len(response.contents) == 1 assert isinstance(response.contents[0], TextReasoningContent) @@ -1331,7 +1332,7 @@ def test_streaming_reasoning_summary_text_delta_event() -> None: ) with patch.object(client, "_get_metadata_from_response", return_value={}) as mock_metadata: - response = client._create_streaming_response_content(event, chat_options, function_call_ids) # type: ignore + response = client._parse_chunk_from_openai(event, chat_options, function_call_ids) # type: ignore assert len(response.contents) == 1 assert isinstance(response.contents[0], TextReasoningContent) @@ -1356,7 +1357,7 @@ def test_streaming_reasoning_summary_text_done_event() -> None: ) with patch.object(client, "_get_metadata_from_response", return_value={"custom": "meta"}) as mock_metadata: - response = client._create_streaming_response_content(event, chat_options, function_call_ids) # type: ignore + response = client._parse_chunk_from_openai(event, chat_options, function_call_ids) # type: ignore assert len(response.contents) == 1 assert isinstance(response.contents[0], TextReasoningContent) @@ -1392,8 +1393,8 @@ def test_streaming_reasoning_events_preserve_metadata() -> None: ) with patch.object(client, "_get_metadata_from_response", return_value={"test": "metadata"}): - text_response = client._create_streaming_response_content(text_event, chat_options, function_call_ids) # type: ignore - reasoning_response = client._create_streaming_response_content(reasoning_event, chat_options, function_call_ids) # type: ignore + text_response = client._parse_chunk_from_openai(text_event, chat_options, function_call_ids) # type: ignore + reasoning_response = client._parse_chunk_from_openai(reasoning_event, chat_options, function_call_ids) # type: ignore # Both should preserve metadata assert text_response.additional_properties == {"test": "metadata"} @@ -1404,7 +1405,7 @@ def test_streaming_reasoning_events_preserve_metadata() -> None: assert isinstance(reasoning_response.contents[0], TextReasoningContent) -def test_create_response_content_image_generation_raw_base64(): +def test_parse_response_from_openai_image_generation_raw_base64(): """Test image generation response parsing with raw base64 string.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") @@ -1428,7 +1429,7 @@ def test_create_response_content_image_generation_raw_base64(): mock_response.output = [mock_item] with patch.object(client, "_get_metadata_from_response", return_value={}): - response = client._create_response_content(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore # Verify the response contains DataContent with proper URI and media_type assert len(response.messages[0].contents) == 1 @@ -1438,7 +1439,7 @@ def test_create_response_content_image_generation_raw_base64(): assert content.media_type == "image/png" -def test_create_response_content_image_generation_existing_data_uri(): +def test_parse_response_from_openai_image_generation_existing_data_uri(): """Test image generation response parsing with existing data URI.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") @@ -1461,7 +1462,7 @@ def test_create_response_content_image_generation_existing_data_uri(): mock_response.output = [mock_item] with patch.object(client, "_get_metadata_from_response", return_value={}): - response = client._create_response_content(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore # Verify the response contains DataContent with proper media_type parsed from URI assert len(response.messages[0].contents) == 1 @@ -1471,7 +1472,7 @@ def test_create_response_content_image_generation_existing_data_uri(): assert content.media_type == "image/webp" -def test_create_response_content_image_generation_format_detection(): +def test_parse_response_from_openai_image_generation_format_detection(): """Test different image format detection from base64 data.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") @@ -1493,7 +1494,7 @@ def test_create_response_content_image_generation_format_detection(): mock_response_jpeg.output = [mock_item_jpeg] with patch.object(client, "_get_metadata_from_response", return_value={}): - response_jpeg = client._create_response_content(mock_response_jpeg, chat_options=ChatOptions()) # type: ignore + response_jpeg = client._parse_response_from_openai(mock_response_jpeg, chat_options=ChatOptions()) # type: ignore content_jpeg = response_jpeg.messages[0].contents[0] assert isinstance(content_jpeg, DataContent) assert content_jpeg.media_type == "image/jpeg" @@ -1517,14 +1518,14 @@ def test_create_response_content_image_generation_format_detection(): mock_response_webp.output = [mock_item_webp] with patch.object(client, "_get_metadata_from_response", return_value={}): - response_webp = client._create_response_content(mock_response_webp, chat_options=ChatOptions()) # type: ignore + response_webp = client._parse_response_from_openai(mock_response_webp, chat_options=ChatOptions()) # type: ignore content_webp = response_webp.messages[0].contents[0] assert isinstance(content_webp, DataContent) assert content_webp.media_type == "image/webp" assert "data:image/webp;base64," in content_webp.uri -def test_create_response_content_image_generation_fallback(): +def test_parse_response_from_openai_image_generation_fallback(): """Test image generation with invalid base64 falls back to PNG.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") @@ -1547,7 +1548,7 @@ def test_create_response_content_image_generation_fallback(): mock_response.output = [mock_item] with patch.object(client, "_get_metadata_from_response", return_value={}): - response = client._create_response_content(mock_response, chat_options=ChatOptions()) # type: ignore + response = client._parse_response_from_openai(mock_response, chat_options=ChatOptions()) # type: ignore # Verify it falls back to PNG format for unrecognized binary data assert len(response.messages[0].contents) == 1 @@ -1563,21 +1564,21 @@ async def test_prepare_options_store_parameter_handling() -> None: test_conversation_id = "test-conversation-123" chat_options = ChatOptions(store=True, conversation_id=test_conversation_id) - options = await client.prepare_options(messages, chat_options) + options = await client._prepare_options(messages, chat_options) # type: ignore assert options["store"] is True assert options["previous_response_id"] == test_conversation_id chat_options = ChatOptions(store=False, conversation_id="") - options = await client.prepare_options(messages, chat_options) + options = await client._prepare_options(messages, chat_options) # type: ignore assert options["store"] is False chat_options = ChatOptions(store=None, conversation_id=None) - options = await client.prepare_options(messages, chat_options) + options = await client._prepare_options(messages, chat_options) # type: ignore assert "store" not in options assert "previous_response_id" not in options chat_options = ChatOptions() - options = await client.prepare_options(messages, chat_options) + options = await client._prepare_options(messages, chat_options) # type: ignore assert "store" not in options assert "previous_response_id" not in options diff --git a/python/packages/core/tests/test_observability_datetime.py b/python/packages/core/tests/test_observability_datetime.py index 05efdc1a5e..6ad3d77e1a 100644 --- a/python/packages/core/tests/test_observability_datetime.py +++ b/python/packages/core/tests/test_observability_datetime.py @@ -22,5 +22,5 @@ def test_datetime_in_tool_results() -> None: result = _to_otel_part(content) parsed = json.loads(result["response"]) - # Datetime should be converted to string - assert isinstance(parsed["timestamp"], str) + # Datetime should be converted to string in the result field + assert isinstance(parsed["result"]["timestamp"], str) diff --git a/python/packages/core/tests/workflow/test_handoff.py b/python/packages/core/tests/workflow/test_handoff.py index 0ceccfaf15..d0d5092323 100644 --- a/python/packages/core/tests/workflow/test_handoff.py +++ b/python/packages/core/tests/workflow/test_handoff.py @@ -25,7 +25,12 @@ from agent_framework._mcp import MCPTool from agent_framework._workflows import AgentRunEvent from agent_framework._workflows import _handoff as handoff_module # type: ignore -from agent_framework._workflows._handoff import _clone_chat_agent # type: ignore[reportPrivateUsage] +from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value +from agent_framework._workflows._handoff import ( + _clone_chat_agent, # type: ignore[reportPrivateUsage] + _ConversationWithUserInput, + _UserInputGateway, +) from agent_framework._workflows._workflow_builder import WorkflowBuilder @@ -218,7 +223,7 @@ async def test_handoff_preserves_complex_additional_properties(complex_metadata: workflow = ( HandoffBuilder(participants=[triage, specialist]) - .set_coordinator("triage") + .set_coordinator(triage) .with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role == Role.USER) >= 2) .build() ) @@ -281,7 +286,7 @@ async def test_tool_call_handoff_detection_with_text_hint(): triage = _RecordingAgent(name="triage", handoff_to="specialist", text_handoff=True) specialist = _RecordingAgent(name="specialist") - workflow = HandoffBuilder(participants=[triage, specialist]).set_coordinator("triage").build() + workflow = HandoffBuilder(participants=[triage, specialist]).set_coordinator(triage).build() await _drain(workflow.run_stream("Package arrived broken")) @@ -296,7 +301,7 @@ async def test_autonomous_interaction_mode_yields_output_without_user_request(): workflow = ( HandoffBuilder(participants=[triage, specialist]) - .set_coordinator("triage") + .set_coordinator(triage) .with_interaction_mode("autonomous", autonomous_turn_limit=1) .build() ) @@ -428,13 +433,13 @@ def test_build_fails_without_coordinator(): triage = _RecordingAgent(name="triage") specialist = _RecordingAgent(name="specialist") - with pytest.raises(ValueError, match="coordinator must be defined before build"): + with pytest.raises(ValueError, match=r"Must call set_coordinator\(...\) before building the workflow."): HandoffBuilder(participants=[triage, specialist]).build() def test_build_fails_without_participants(): """Verify that build() raises ValueError when no participants are provided.""" - with pytest.raises(ValueError, match="No participants provided"): + with pytest.raises(ValueError, match="No participants or participant_factories have been configured."): HandoffBuilder().build() @@ -605,7 +610,7 @@ async def test_return_to_previous_enabled(): workflow = ( HandoffBuilder(participants=[triage, specialist_a, specialist_b]) - .set_coordinator("triage") + .set_coordinator(triage) .enable_return_to_previous(True) .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3) .build() @@ -638,7 +643,7 @@ def test_handoff_builder_sets_start_executor_once(monkeypatch: pytest.MonkeyPatc workflow = ( HandoffBuilder(participants=[coordinator, specialist]) - .set_coordinator("coordinator") + .set_coordinator(coordinator) .with_termination_condition(lambda conv: len(conv) > 0) .build() ) @@ -698,7 +703,7 @@ async def test_handoff_builder_with_request_info(): # Build workflow with request info enabled workflow = ( HandoffBuilder(participants=[coordinator, specialist]) - .set_coordinator("coordinator") + .set_coordinator(coordinator) .with_termination_condition(lambda conv: len([m for m in conv if m.role == Role.USER]) >= 1) .with_request_info() .build() @@ -775,3 +780,893 @@ async def test_return_to_previous_state_serialization(): # Verify current_agent_id was restored assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint" # type: ignore[reportPrivateUsage] + + +# region Participant Factory Tests + + +def test_handoff_builder_rejects_empty_participant_factories(): + """Test that HandoffBuilder rejects empty participant_factories dictionary.""" + # Empty factories are rejected immediately when calling participant_factories() + with pytest.raises(ValueError, match=r"participant_factories cannot be empty"): + HandoffBuilder().participant_factories({}) + + with pytest.raises(ValueError, match=r"No participants or participant_factories have been configured"): + HandoffBuilder(participant_factories={}).build() + + +def test_handoff_builder_rejects_mixing_participants_and_factories(): + """Test that mixing participants and participant_factories in __init__ raises an error.""" + triage = _RecordingAgent(name="triage") + with pytest.raises(ValueError, match="Cannot mix .participants"): + HandoffBuilder(participants=[triage], participant_factories={"triage": lambda: triage}) + + +def test_handoff_builder_rejects_mixing_participants_and_participant_factories_methods(): + """Test that mixing .participants() and .participant_factories() raises an error.""" + triage = _RecordingAgent(name="triage") + + # Case 1: participants first, then participant_factories + with pytest.raises(ValueError, match="Cannot mix .participants"): + HandoffBuilder(participants=[triage]).participant_factories({ + "specialist": lambda: _RecordingAgent(name="specialist") + }) + + # Case 2: participant_factories first, then participants + with pytest.raises(ValueError, match="Cannot mix .participants"): + HandoffBuilder(participant_factories={"triage": lambda: triage}).participants([ + _RecordingAgent(name="specialist") + ]) + + # Case 3: participants(), then participant_factories() + with pytest.raises(ValueError, match="Cannot mix .participants"): + HandoffBuilder().participants([triage]).participant_factories({ + "specialist": lambda: _RecordingAgent(name="specialist") + }) + + # Case 4: participant_factories(), then participants() + with pytest.raises(ValueError, match="Cannot mix .participants"): + HandoffBuilder().participant_factories({"triage": lambda: triage}).participants([ + _RecordingAgent(name="specialist") + ]) + + # Case 5: mix during initialization + with pytest.raises(ValueError, match="Cannot mix .participants"): + HandoffBuilder( + participants=[triage], participant_factories={"specialist": lambda: _RecordingAgent(name="specialist")} + ) + + +def test_handoff_builder_rejects_multiple_calls_to_participant_factories(): + """Test that multiple calls to .participant_factories() raises an error.""" + with pytest.raises(ValueError, match=r"participant_factories\(\) has already been called"): + ( + HandoffBuilder() + .participant_factories({"agent1": lambda: _RecordingAgent(name="agent1")}) + .participant_factories({"agent2": lambda: _RecordingAgent(name="agent2")}) + ) + + +def test_handoff_builder_rejects_multiple_calls_to_participants(): + """Test that multiple calls to .participants() raises an error.""" + with pytest.raises(ValueError, match="participants have already been assigned"): + (HandoffBuilder().participants([_RecordingAgent(name="agent1")]).participants([_RecordingAgent(name="agent2")])) + + +def test_handoff_builder_rejects_duplicate_factories(): + """Test that multiple calls to participant_factories are rejected.""" + factories = { + "triage": lambda: _RecordingAgent(name="triage"), + "specialist": lambda: _RecordingAgent(name="specialist"), + } + + # Multiple calls to participant_factories should fail + builder = HandoffBuilder(participant_factories=factories) + with pytest.raises(ValueError, match=r"participant_factories\(\) has already been called"): + builder.participant_factories({"triage": lambda: _RecordingAgent(name="triage2")}) + + +def test_handoff_builder_rejects_instance_coordinator_with_factories(): + """Test that using an agent instance for set_coordinator when using factories raises an error.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage") + + def create_specialist() -> _RecordingAgent: + return _RecordingAgent(name="specialist") + + # Create an agent instance + coordinator_instance = _RecordingAgent(name="coordinator") + + with pytest.raises(ValueError, match=r"Call participants\(\.\.\.\) before coordinator\(\.\.\.\)"): + ( + HandoffBuilder( + participant_factories={"triage": create_triage, "specialist": create_specialist} + ).set_coordinator(coordinator_instance) # Instance, not factory name + ) + + +def test_handoff_builder_rejects_factory_name_coordinator_with_instances(): + """Test that using a factory name for set_coordinator when using instances raises an error.""" + triage = _RecordingAgent(name="triage") + specialist = _RecordingAgent(name="specialist") + + with pytest.raises( + ValueError, match="coordinator factory name 'triage' is not part of the participant_factories list" + ): + ( + HandoffBuilder(participants=[triage, specialist]).set_coordinator( + "triage" + ) # String factory name, not instance + ) + + +def test_handoff_builder_rejects_mixed_types_in_add_handoff_source(): + """Test that add_handoff rejects factory name source with instance-based participants.""" + triage = _RecordingAgent(name="triage") + specialist = _RecordingAgent(name="specialist") + + with pytest.raises(TypeError, match="Cannot mix factory names \\(str\\) and AgentProtocol/Executor instances"): + ( + HandoffBuilder(participants=[triage, specialist]) + .set_coordinator(triage) + .add_handoff("triage", specialist) # String source with instance participants + ) + + +def test_handoff_builder_accepts_all_factory_names_in_add_handoff(): + """Test that add_handoff accepts all factory names when using participant_factories.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage") + + def create_specialist_a() -> _RecordingAgent: + return _RecordingAgent(name="specialist_a") + + def create_specialist_b() -> _RecordingAgent: + return _RecordingAgent(name="specialist_b") + + # This should work - all strings with participant_factories + builder = ( + HandoffBuilder( + participant_factories={ + "triage": create_triage, + "specialist_a": create_specialist_a, + "specialist_b": create_specialist_b, + } + ) + .set_coordinator("triage") + .add_handoff("triage", ["specialist_a", "specialist_b"]) + ) + + workflow = builder.build() + assert "triage" in workflow.executors + assert "specialist_a" in workflow.executors + assert "specialist_b" in workflow.executors + + +def test_handoff_builder_accepts_all_instances_in_add_handoff(): + """Test that add_handoff accepts all instances when using participants.""" + triage = _RecordingAgent(name="triage", handoff_to="specialist_a") + specialist_a = _RecordingAgent(name="specialist_a") + specialist_b = _RecordingAgent(name="specialist_b") + + # This should work - all instances with participants + builder = ( + HandoffBuilder(participants=[triage, specialist_a, specialist_b]) + .set_coordinator(triage) + .add_handoff(triage, [specialist_a, specialist_b]) + ) + + workflow = builder.build() + assert "triage" in workflow.executors + assert "specialist_a" in workflow.executors + assert "specialist_b" in workflow.executors + + +async def test_handoff_with_participant_factories(): + """Test workflow creation using participant_factories.""" + call_count = 0 + + def create_triage() -> _RecordingAgent: + nonlocal call_count + call_count += 1 + return _RecordingAgent(name="triage", handoff_to="specialist") + + def create_specialist() -> _RecordingAgent: + nonlocal call_count + call_count += 1 + return _RecordingAgent(name="specialist") + + workflow = ( + HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist}) + .set_coordinator("triage") + .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2) + .build() + ) + + # Factories should be called during build + assert call_count == 2 + + events = await _drain(workflow.run_stream("Need help")) + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert requests + + # Follow-up message + events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "More details"})) + outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] + assert outputs + + +async def test_handoff_participant_factories_reusable_builder(): + """Test that the builder can be reused to build multiple workflows with factories.""" + call_count = 0 + + def create_triage() -> _RecordingAgent: + nonlocal call_count + call_count += 1 + return _RecordingAgent(name="triage", handoff_to="specialist") + + def create_specialist() -> _RecordingAgent: + nonlocal call_count + call_count += 1 + return _RecordingAgent(name="specialist") + + builder = HandoffBuilder( + participant_factories={"triage": create_triage, "specialist": create_specialist} + ).set_coordinator("triage") + + # Build first workflow + wf1 = builder.build() + assert call_count == 2 + + # Build second workflow + wf2 = builder.build() + assert call_count == 4 + + # Verify that the two workflows have different agent instances + assert wf1.executors["triage"] is not wf2.executors["triage"] + assert wf1.executors["specialist"] is not wf2.executors["specialist"] + + +async def test_handoff_with_participant_factories_and_add_handoff(): + """Test that .add_handoff() works correctly with participant_factories.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage", handoff_to="specialist_a") + + def create_specialist_a() -> _RecordingAgent: + return _RecordingAgent(name="specialist_a", handoff_to="specialist_b") + + def create_specialist_b() -> _RecordingAgent: + return _RecordingAgent(name="specialist_b") + + workflow = ( + HandoffBuilder( + participant_factories={ + "triage": create_triage, + "specialist_a": create_specialist_a, + "specialist_b": create_specialist_b, + } + ) + .set_coordinator("triage") + .add_handoff("triage", ["specialist_a", "specialist_b"]) + .add_handoff("specialist_a", "specialist_b") + .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3) + .build() + ) + + # Start conversation - triage hands off to specialist_a + events = await _drain(workflow.run_stream("Initial request")) + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert requests + + # Verify specialist_a executor exists and was called + assert "specialist_a" in workflow.executors + + # Second user message - specialist_a hands off to specialist_b + events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need escalation"})) + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert requests + + # Verify specialist_b executor exists + assert "specialist_b" in workflow.executors + + +async def test_handoff_participant_factories_with_checkpointing(): + """Test checkpointing with participant_factories.""" + from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage + + storage = InMemoryCheckpointStorage() + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage", handoff_to="specialist") + + def create_specialist() -> _RecordingAgent: + return _RecordingAgent(name="specialist") + + workflow = ( + HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist}) + .set_coordinator("triage") + .with_checkpointing(storage) + .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2) + .build() + ) + + # Run workflow and capture output + events = await _drain(workflow.run_stream("checkpoint test")) + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert requests + + events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "follow up"})) + outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] + assert outputs, "Should have workflow output after termination condition is met" + + # List checkpoints - just verify they were created + checkpoints = await storage.list_checkpoints() + assert checkpoints, "Checkpoints should be created during workflow execution" + + +def test_handoff_set_coordinator_with_factory_name(): + """Test that set_coordinator accepts factory name as string.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage") + + def create_specialist() -> _RecordingAgent: + return _RecordingAgent(name="specialist") + + builder = HandoffBuilder( + participant_factories={"triage": create_triage, "specialist": create_specialist} + ).set_coordinator("triage") + + workflow = builder.build() + assert "triage" in workflow.executors + + +def test_handoff_add_handoff_with_factory_names(): + """Test that add_handoff accepts factory names as strings.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage", handoff_to="specialist_a") + + def create_specialist_a() -> _RecordingAgent: + return _RecordingAgent(name="specialist_a") + + def create_specialist_b() -> _RecordingAgent: + return _RecordingAgent(name="specialist_b") + + builder = ( + HandoffBuilder( + participant_factories={ + "triage": create_triage, + "specialist_a": create_specialist_a, + "specialist_b": create_specialist_b, + } + ) + .set_coordinator("triage") + .add_handoff("triage", ["specialist_a", "specialist_b"]) + ) + + workflow = builder.build() + assert "triage" in workflow.executors + assert "specialist_a" in workflow.executors + assert "specialist_b" in workflow.executors + + +async def test_handoff_participant_factories_autonomous_mode(): + """Test autonomous mode with participant_factories.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage", handoff_to="specialist") + + def create_specialist() -> _RecordingAgent: + return _RecordingAgent(name="specialist") + + workflow = ( + HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist}) + .set_coordinator("triage") + .with_interaction_mode("autonomous", autonomous_turn_limit=2) + .build() + ) + + events = await _drain(workflow.run_stream("Issue")) + outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] + assert outputs, "Autonomous mode should yield output" + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert not requests, "Autonomous mode should not request user input" + + +async def test_handoff_participant_factories_with_request_info(): + """Test that .with_request_info() works with participant_factories.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage") + + def create_specialist() -> _RecordingAgent: + return _RecordingAgent(name="specialist") + + builder = ( + HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist}) + .set_coordinator("triage") + .with_request_info(agents=["triage"]) + ) + + workflow = builder.build() + assert "triage" in workflow.executors + + +def test_handoff_participant_factories_invalid_coordinator_name(): + """Test that set_coordinator raises error for non-existent factory name.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage") + + with pytest.raises( + ValueError, match="coordinator factory name 'nonexistent' is not part of the participant_factories list" + ): + (HandoffBuilder(participant_factories={"triage": create_triage}).set_coordinator("nonexistent").build()) + + +def test_handoff_participant_factories_invalid_handoff_target(): + """Test that add_handoff raises error for non-existent target factory name.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage") + + def create_specialist() -> _RecordingAgent: + return _RecordingAgent(name="specialist") + + with pytest.raises(ValueError, match="Target factory name 'nonexistent' is not in the participant_factories list"): + ( + HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist}) + .set_coordinator("triage") + .add_handoff("triage", "nonexistent") + .build() + ) + + +async def test_handoff_participant_factories_enable_return_to_previous(): + """Test return_to_previous works with participant_factories.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage", handoff_to="specialist_a") + + def create_specialist_a() -> _RecordingAgent: + return _RecordingAgent(name="specialist_a", handoff_to="specialist_b") + + def create_specialist_b() -> _RecordingAgent: + return _RecordingAgent(name="specialist_b") + + workflow = ( + HandoffBuilder( + participant_factories={ + "triage": create_triage, + "specialist_a": create_specialist_a, + "specialist_b": create_specialist_b, + } + ) + .set_coordinator("triage") + .add_handoff("triage", ["specialist_a", "specialist_b"]) + .add_handoff("specialist_a", "specialist_b") + .enable_return_to_previous(True) + .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3) + .build() + ) + + # Start conversation - triage hands off to specialist_a + events = await _drain(workflow.run_stream("Initial request")) + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert requests + + # Second user message - specialist_a hands off to specialist_b + events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need escalation"})) + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert requests + + # Third user message - should route back to specialist_b (return to previous) + events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up"})) + outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] + assert outputs or [ev for ev in events if isinstance(ev, RequestInfoEvent)] + + +# endregion Participant Factory Tests + + +async def test_handoff_user_input_request_checkpoint_excludes_conversation(): + """Test that HandoffUserInputRequest serialization excludes conversation to prevent duplication. + + Issue #2667: When checkpointing a workflow with a pending HandoffUserInputRequest, + the conversation field gets serialized twice: once in the RequestInfoEvent's data + and once in the coordinator's conversation state. On restore, this causes duplicate + messages. + + The fix is to exclude the conversation field during checkpoint serialization since + the conversation is already preserved in the coordinator's state. + """ + # Create a conversation history + conversation = [ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="Hi there!"), + ChatMessage(role=Role.USER, text="Help me"), + ] + + # Create a HandoffUserInputRequest with the conversation + request = HandoffUserInputRequest( + conversation=conversation, + awaiting_agent_id="specialist_agent", + prompt="Please provide your input", + source_executor_id="gateway", + ) + + # Encode the request (simulating checkpoint save) + encoded = encode_checkpoint_value(request) + + # Verify conversation is NOT in the encoded output + # The fix should exclude conversation from serialization + assert isinstance(encoded, dict) + + # If using MODEL_MARKER strategy (to_dict/from_dict) + if "__af_model__" in encoded or "__af_dataclass__" in encoded: + value = encoded.get("value", {}) + assert "conversation" not in value, "conversation should be excluded from checkpoint serialization" + + # Decode the request (simulating checkpoint restore) + decoded = decode_checkpoint_value(encoded) + + # Verify the decoded request is a HandoffUserInputRequest + assert isinstance(decoded, HandoffUserInputRequest) + + # Verify other fields are preserved + assert decoded.awaiting_agent_id == "specialist_agent" + assert decoded.prompt == "Please provide your input" + assert decoded.source_executor_id == "gateway" + + # Conversation should be an empty list after deserialization + # (will be reconstructed from coordinator state on restore) + assert decoded.conversation == [] + + +async def test_handoff_user_input_request_roundtrip_preserves_metadata(): + """Test that non-conversation fields survive checkpoint roundtrip.""" + request = HandoffUserInputRequest( + conversation=[ChatMessage(role=Role.USER, text="test")], + awaiting_agent_id="test_agent", + prompt="Enter your response", + source_executor_id="test_gateway", + ) + + # Roundtrip through checkpoint encoding + encoded = encode_checkpoint_value(request) + decoded = decode_checkpoint_value(encoded) + + assert isinstance(decoded, HandoffUserInputRequest) + assert decoded.awaiting_agent_id == request.awaiting_agent_id + assert decoded.prompt == request.prompt + assert decoded.source_executor_id == request.source_executor_id + + +async def test_request_info_event_with_handoff_user_input_request(): + """Test RequestInfoEvent serialization with HandoffUserInputRequest data.""" + conversation = [ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="How can I help?"), + ] + + request = HandoffUserInputRequest( + conversation=conversation, + awaiting_agent_id="specialist", + prompt="Provide input", + source_executor_id="gateway", + ) + + # Create a RequestInfoEvent wrapping the request + event = RequestInfoEvent( + request_id="test-request-123", + source_executor_id="gateway", + request_data=request, + response_type=object, + ) + + # Serialize the event + event_dict = event.to_dict() + + # Verify the data field doesn't contain conversation + data_encoded = event_dict["data"] + if isinstance(data_encoded, dict) and ("__af_model__" in data_encoded or "__af_dataclass__" in data_encoded): + value = data_encoded.get("value", {}) + assert "conversation" not in value + + # Deserialize and verify + restored_event = RequestInfoEvent.from_dict(event_dict) + assert isinstance(restored_event.data, HandoffUserInputRequest) + assert restored_event.data.awaiting_agent_id == "specialist" + assert restored_event.data.conversation == [] + + +async def test_handoff_user_input_request_to_dict_excludes_conversation(): + """Test that to_dict() method excludes conversation field.""" + conversation = [ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="Hi!"), + ] + + request = HandoffUserInputRequest( + conversation=conversation, + awaiting_agent_id="agent1", + prompt="Enter input", + source_executor_id="gateway", + ) + + # Call to_dict directly + data = request.to_dict() + + # Verify conversation is excluded + assert "conversation" not in data + assert data["awaiting_agent_id"] == "agent1" + assert data["prompt"] == "Enter input" + assert data["source_executor_id"] == "gateway" + + +async def test_handoff_user_input_request_from_dict_creates_empty_conversation(): + """Test that from_dict() creates an instance with empty conversation.""" + data = { + "awaiting_agent_id": "agent1", + "prompt": "Enter input", + "source_executor_id": "gateway", + } + + request = HandoffUserInputRequest.from_dict(data) + + assert request.conversation == [] + assert request.awaiting_agent_id == "agent1" + assert request.prompt == "Enter input" + assert request.source_executor_id == "gateway" + + +async def test_user_input_gateway_resume_handles_empty_conversation(): + """Test that _UserInputGateway.resume_from_user handles post-restore scenario. + + After checkpoint restore, the HandoffUserInputRequest will have an empty + conversation. The gateway should handle this by sending only the new user + messages to the coordinator. + """ + from unittest.mock import AsyncMock + + # Create a gateway + gateway = _UserInputGateway( + starting_agent_id="coordinator", + prompt="Enter input", + id="test-gateway", + ) + + # Simulate post-restore: request with empty conversation + restored_request = HandoffUserInputRequest( + conversation=[], # Empty after restore + awaiting_agent_id="specialist", + prompt="Enter input", + source_executor_id="test-gateway", + ) + + # Create mock context + mock_ctx = MagicMock() + mock_ctx.send_message = AsyncMock() + + # Call resume_from_user with a user response + await gateway.resume_from_user(restored_request, "New user message", mock_ctx) + + # Verify send_message was called + mock_ctx.send_message.assert_called_once() + + # Get the message that was sent + call_args = mock_ctx.send_message.call_args + sent_message = call_args[0][0] + + # Verify it's a _ConversationWithUserInput + assert isinstance(sent_message, _ConversationWithUserInput) + + # Verify it contains only the new user message (not any history) + assert len(sent_message.full_conversation) == 1 + assert sent_message.full_conversation[0].role == Role.USER + assert sent_message.full_conversation[0].text == "New user message" + + +async def test_user_input_gateway_resume_with_full_conversation(): + """Test that _UserInputGateway.resume_from_user handles normal flow correctly. + + In normal flow (no checkpoint restore), the HandoffUserInputRequest has + the full conversation. The gateway should send the full conversation + plus the new user messages. + """ + from unittest.mock import AsyncMock + + # Create a gateway + gateway = _UserInputGateway( + starting_agent_id="coordinator", + prompt="Enter input", + id="test-gateway", + ) + + # Normal flow: request with full conversation + normal_request = HandoffUserInputRequest( + conversation=[ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="Hi!"), + ], + awaiting_agent_id="specialist", + prompt="Enter input", + source_executor_id="test-gateway", + ) + + # Create mock context + mock_ctx = MagicMock() + mock_ctx.send_message = AsyncMock() + + # Call resume_from_user with a user response + await gateway.resume_from_user(normal_request, "Follow up message", mock_ctx) + + # Verify send_message was called + mock_ctx.send_message.assert_called_once() + + # Get the message that was sent + call_args = mock_ctx.send_message.call_args + sent_message = call_args[0][0] + + # Verify it's a _ConversationWithUserInput + assert isinstance(sent_message, _ConversationWithUserInput) + + # Verify it contains the full conversation plus new user message + assert len(sent_message.full_conversation) == 3 + assert sent_message.full_conversation[0].text == "Hello" + assert sent_message.full_conversation[1].text == "Hi!" + assert sent_message.full_conversation[2].text == "Follow up message" + + +async def test_coordinator_handle_user_input_post_restore(): + """Test that _HandoffCoordinator.handle_user_input handles post-restore correctly. + + After checkpoint restore, the coordinator has its conversation restored, + and the gateway sends only the new user messages. The coordinator should + append these to its existing conversation rather than replacing. + """ + from unittest.mock import AsyncMock + + from agent_framework._workflows._handoff import _HandoffCoordinator + + # Create a coordinator with pre-existing conversation (simulating restored state) + coordinator = _HandoffCoordinator( + starting_agent_id="triage", + specialist_ids={"specialist_a": "specialist_a"}, + input_gateway_id="gateway", + termination_condition=lambda conv: False, + id="test-coordinator", + ) + + # Simulate restored conversation + coordinator._conversation = [ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="Hi there!"), + ChatMessage(role=Role.USER, text="Help me"), + ChatMessage(role=Role.ASSISTANT, text="Sure, what do you need?"), + ] + + # Create mock context + mock_ctx = MagicMock() + mock_ctx.send_message = AsyncMock() + + # Simulate post-restore: only new user message with explicit flag + incoming = _ConversationWithUserInput( + full_conversation=[ChatMessage(role=Role.USER, text="I need shipping help")], + is_post_restore=True, + ) + + # Handle the user input + await coordinator.handle_user_input(incoming, mock_ctx) + + # Verify conversation was appended, not replaced + assert len(coordinator._conversation) == 5 + assert coordinator._conversation[0].text == "Hello" + assert coordinator._conversation[1].text == "Hi there!" + assert coordinator._conversation[2].text == "Help me" + assert coordinator._conversation[3].text == "Sure, what do you need?" + assert coordinator._conversation[4].text == "I need shipping help" + + +async def test_coordinator_handle_user_input_normal_flow(): + """Test that _HandoffCoordinator.handle_user_input handles normal flow correctly. + + In normal flow (no restore), the gateway sends the full conversation. + The coordinator should replace its conversation with the incoming one. + """ + from unittest.mock import AsyncMock + + from agent_framework._workflows._handoff import _HandoffCoordinator + + # Create a coordinator + coordinator = _HandoffCoordinator( + starting_agent_id="triage", + specialist_ids={"specialist_a": "specialist_a"}, + input_gateway_id="gateway", + termination_condition=lambda conv: False, + id="test-coordinator", + ) + + # Set some initial conversation + coordinator._conversation = [ + ChatMessage(role=Role.USER, text="Old message"), + ] + + # Create mock context + mock_ctx = MagicMock() + mock_ctx.send_message = AsyncMock() + + # Normal flow: full conversation including new user message (is_post_restore=False by default) + incoming = _ConversationWithUserInput( + full_conversation=[ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="Hi!"), + ChatMessage(role=Role.USER, text="New message"), + ], + is_post_restore=False, + ) + + # Handle the user input + await coordinator.handle_user_input(incoming, mock_ctx) + + # Verify conversation was replaced (normal flow with full history) + assert len(coordinator._conversation) == 3 + assert coordinator._conversation[0].text == "Hello" + assert coordinator._conversation[1].text == "Hi!" + assert coordinator._conversation[2].text == "New message" + + +async def test_coordinator_handle_user_input_multiple_consecutive_user_messages(): + """Test that multiple consecutive USER messages in normal flow are handled correctly. + + This is a regression test for the edge case where a user submits multiple consecutive + USER messages. The explicit is_post_restore flag ensures this doesn't get incorrectly + detected as a post-restore scenario. + """ + from unittest.mock import AsyncMock + + from agent_framework._workflows._handoff import _HandoffCoordinator + + # Create a coordinator with existing conversation + coordinator = _HandoffCoordinator( + starting_agent_id="triage", + specialist_ids={"specialist_a": "specialist_a"}, + input_gateway_id="gateway", + termination_condition=lambda conv: False, + id="test-coordinator", + ) + + # Set existing conversation with 4 messages + coordinator._conversation = [ + ChatMessage(role=Role.USER, text="Original message 1"), + ChatMessage(role=Role.ASSISTANT, text="Response 1"), + ChatMessage(role=Role.USER, text="Original message 2"), + ChatMessage(role=Role.ASSISTANT, text="Response 2"), + ] + + # Create mock context + mock_ctx = MagicMock() + mock_ctx.send_message = AsyncMock() + + # Normal flow: User sends multiple consecutive USER messages + # This should REPLACE the conversation, not append to it + incoming = _ConversationWithUserInput( + full_conversation=[ + ChatMessage(role=Role.USER, text="New user message 1"), + ChatMessage(role=Role.USER, text="New user message 2"), + ], + is_post_restore=False, # Explicit flag - this is normal flow + ) + + # Handle the user input + await coordinator.handle_user_input(incoming, mock_ctx) + + # Verify conversation was REPLACED (not appended) + # Without the explicit flag, the old heuristic might incorrectly append + assert len(coordinator._conversation) == 2 + assert coordinator._conversation[0].text == "New user message 1" + assert coordinator._conversation[1].text == "New user message 2" diff --git a/python/packages/core/tests/workflow/test_magentic.py b/python/packages/core/tests/workflow/test_magentic.py index 71cfc6752a..4ee16ddb5f 100644 --- a/python/packages/core/tests/workflow/test_magentic.py +++ b/python/packages/core/tests/workflow/test_magentic.py @@ -876,3 +876,204 @@ def test_magentic_builder_does_not_have_human_input_hook(): "MagenticBuilder should not have with_human_input_hook - " "use with_plan_review() or with_human_input_on_stall() instead" ) + + +# region Message Deduplication Tests + + +async def test_magentic_no_duplicate_messages_with_conversation_history(): + """Test that passing list[ChatMessage] does not create duplicate messages in chat_history. + + When a frontend passes conversation history as list[ChatMessage], the last message + (task) should not be duplicated in the orchestrator's chat_history. + """ + manager = FakeManager(max_round_count=10) + manager.satisfied_after_signoff = True # Complete immediately after first agent response + + wf = MagenticBuilder().participants(agentA=_DummyExec("agentA")).with_standard_manager(manager).build() + + # Simulate frontend passing conversation history + conversation: list[ChatMessage] = [ + ChatMessage(role=Role.USER, text="previous question"), + ChatMessage(role=Role.ASSISTANT, text="previous answer"), + ChatMessage(role=Role.USER, text="current task"), + ] + + # Get orchestrator to inspect chat_history after run + orchestrator = None + for executor in wf.executors.values(): + if isinstance(executor, MagenticOrchestratorExecutor): + orchestrator = executor + break + + events: list[WorkflowEvent] = [] + async for event in wf.run_stream(conversation): + events.append(event) + if isinstance(event, WorkflowStatusEvent) and event.state in ( + WorkflowRunState.IDLE, + WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, + ): + break + + assert orchestrator is not None + assert orchestrator._context is not None # type: ignore[reportPrivateUsage] + + # Count occurrences of each message text in chat_history + history = orchestrator._context.chat_history # type: ignore[reportPrivateUsage] + user_task_count = sum(1 for msg in history if msg.text == "current task") + prev_question_count = sum(1 for msg in history if msg.text == "previous question") + prev_answer_count = sum(1 for msg in history if msg.text == "previous answer") + + # Each input message should appear exactly once (no duplicates) + assert prev_question_count == 1, f"Expected 1 'previous question', got {prev_question_count}" + assert prev_answer_count == 1, f"Expected 1 'previous answer', got {prev_answer_count}" + assert user_task_count == 1, f"Expected 1 'current task', got {user_task_count}" + + +async def test_magentic_agent_executor_no_duplicate_messages_on_broadcast(): + """Test that MagenticAgentExecutor does not duplicate messages from broadcasts. + + When the orchestrator broadcasts the task ledger to all agents, each agent + should receive it exactly once, not multiple times. + """ + backing_executor = _DummyExec("backing") + agent_exec = MagenticAgentExecutor(backing_executor, "agentA") + + # Simulate orchestrator sending a broadcast message + broadcast_msg = ChatMessage( + role=Role.ASSISTANT, + text="Task ledger content", + author_name="magentic_manager", + ) + + # Simulate the same message being received multiple times (e.g., from checkpoint restore + live) + from agent_framework._workflows._magentic import _MagenticResponseMessage + + response1 = _MagenticResponseMessage(body=broadcast_msg, broadcast=True) + response2 = _MagenticResponseMessage(body=broadcast_msg, broadcast=True) + + # Create a mock context + from unittest.mock import AsyncMock, MagicMock + + mock_context = MagicMock() + mock_context.send_message = AsyncMock() + + # Call the handler twice with the same message + await agent_exec.handle_response_message(response1, mock_context) # type: ignore[arg-type] + await agent_exec.handle_response_message(response2, mock_context) # type: ignore[arg-type] + + # Count how many times the broadcast message appears + history = agent_exec._chat_history # type: ignore[reportPrivateUsage] + broadcast_count = sum(1 for msg in history if msg.text == "Task ledger content") + + # Each broadcast should be recorded (this is expected behavior - broadcasts are additive) + # The test documents current behavior. If dedup is needed, this assertion would change. + assert broadcast_count == 2, ( + f"Expected 2 broadcasts (current behavior is additive), got {broadcast_count}. " + "If deduplication is required, update the handler logic." + ) + + +async def test_magentic_context_no_duplicate_on_reset(): + """Test that MagenticContext.reset() clears chat_history without leaving duplicates.""" + ctx = MagenticContext( + task=ChatMessage(role=Role.USER, text="task"), + participant_descriptions={"Alice": "Researcher"}, + ) + + # Add some history + ctx.chat_history.append(ChatMessage(role=Role.ASSISTANT, text="response1")) + ctx.chat_history.append(ChatMessage(role=Role.ASSISTANT, text="response2")) + assert len(ctx.chat_history) == 2 + + # Reset + ctx.reset() + + # Verify clean slate + assert len(ctx.chat_history) == 0, "chat_history should be empty after reset" + + # Add new history + ctx.chat_history.append(ChatMessage(role=Role.ASSISTANT, text="new_response")) + assert len(ctx.chat_history) == 1, "Should have exactly 1 message after adding to reset context" + + +async def test_magentic_start_message_messages_list_integrity(): + """Test that _MagenticStartMessage preserves message list without internal duplication.""" + conversation: list[ChatMessage] = [ + ChatMessage(role=Role.USER, text="msg1"), + ChatMessage(role=Role.ASSISTANT, text="msg2"), + ChatMessage(role=Role.USER, text="msg3"), + ] + + start_msg = _MagenticStartMessage(conversation) + + # Verify messages list is preserved + assert len(start_msg.messages) == 3, f"Expected 3 messages, got {len(start_msg.messages)}" + + # Verify task is the last message (not a copy) + assert start_msg.task is start_msg.messages[-1], "task should be the same object as messages[-1]" + assert start_msg.task.text == "msg3" + + +async def test_magentic_checkpoint_restore_no_duplicate_history(): + """Test that checkpoint restore does not create duplicate messages in chat_history.""" + manager = FakeManager(max_round_count=10) + storage = InMemoryCheckpointStorage() + + wf = ( + MagenticBuilder() + .participants(agentA=_DummyExec("agentA")) + .with_standard_manager(manager) + .with_checkpointing(storage) + .build() + ) + + # Run with conversation history to create initial checkpoint + conversation: list[ChatMessage] = [ + ChatMessage(role=Role.USER, text="history_msg"), + ChatMessage(role=Role.USER, text="task_msg"), + ] + + async for event in wf.run_stream(conversation): + if isinstance(event, WorkflowStatusEvent) and event.state in ( + WorkflowRunState.IDLE, + WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, + ): + break + + # Get checkpoint + checkpoints = await storage.list_checkpoints() + assert len(checkpoints) > 0, "Should have created checkpoints" + + latest_checkpoint = checkpoints[-1] + + # Load checkpoint and verify no duplicates in shared state + checkpoint_data = await storage.load_checkpoint(latest_checkpoint.checkpoint_id) + assert checkpoint_data is not None + + # Check the magentic_context in the checkpoint + for _, executor_state in checkpoint_data.metadata.items(): + if isinstance(executor_state, dict) and "magentic_context" in executor_state: + ctx_data = executor_state["magentic_context"] + chat_history = ctx_data.get("chat_history", []) + + # Count unique messages by text + texts = [ + msg.get("text") or (msg.get("contents", [{}])[0].get("text") if msg.get("contents") else None) + for msg in chat_history + ] + text_counts: dict[str, int] = {} + for text in texts: + if text: + text_counts[text] = text_counts.get(text, 0) + 1 + + # Input messages should not be duplicated + assert text_counts.get("history_msg", 0) <= 1, ( + f"'history_msg' appears {text_counts.get('history_msg', 0)} times in checkpoint - expected <= 1" + ) + assert text_counts.get("task_msg", 0) <= 1, ( + f"'task_msg' appears {text_counts.get('task_msg', 0)} times in checkpoint - expected <= 1" + ) + + +# endregion diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index c005a0f9f9..51b3544b22 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -12,17 +12,20 @@ AgentThread, ChatMessage, ChatMessageStore, + DataContent, Executor, FunctionApprovalRequestContent, FunctionApprovalResponseContent, FunctionCallContent, Role, TextContent, + UriContent, UsageContent, UsageDetails, WorkflowAgent, WorkflowBuilder, WorkflowContext, + executor, handler, response_handler, ) @@ -284,6 +287,141 @@ async def handle_bool(self, message: bool, context: WorkflowContext[Any]) -> Non with pytest.raises(ValueError, match="Workflow's start executor cannot handle list\\[ChatMessage\\]"): workflow.as_agent() + async def test_workflow_as_agent_yield_output_surfaces_as_agent_response(self) -> None: + """Test that ctx.yield_output() in a workflow executor surfaces as agent output when using .as_agent(). + + This validates the fix for issue #2813: WorkflowOutputEvent should be converted to + AgentRunResponseUpdate when the workflow is wrapped via .as_agent(). + """ + + @executor + async def yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: + # Extract text from input for demonstration + input_text = messages[0].text if messages else "no input" + await ctx.yield_output(f"processed: {input_text}") + + workflow = WorkflowBuilder().set_start_executor(yielding_executor).build() + + # Run directly - should return WorkflowOutputEvent in result + direct_result = await workflow.run([ChatMessage(role=Role.USER, contents=[TextContent(text="hello")])]) + direct_outputs = direct_result.get_outputs() + assert len(direct_outputs) == 1 + assert direct_outputs[0] == "processed: hello" + + # Run as agent - yield_output should surface as agent response message + agent = workflow.as_agent("test-agent") + agent_result = await agent.run("hello") + + assert isinstance(agent_result, AgentRunResponse) + assert len(agent_result.messages) == 1 + assert agent_result.messages[0].text == "processed: hello" + + async def test_workflow_as_agent_yield_output_surfaces_in_run_stream(self) -> None: + """Test that ctx.yield_output() surfaces as AgentRunResponseUpdate when streaming.""" + + @executor + async def yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: + await ctx.yield_output("first output") + await ctx.yield_output("second output") + + workflow = WorkflowBuilder().set_start_executor(yielding_executor).build() + agent = workflow.as_agent("test-agent") + + updates: list[AgentRunResponseUpdate] = [] + async for update in agent.run_stream("hello"): + updates.append(update) + + # Should have received updates for both yield_output calls + texts = [u.text for u in updates if u.text] + assert "first output" in texts + assert "second output" in texts + + async def test_workflow_as_agent_yield_output_with_content_types(self) -> None: + """Test that yield_output preserves different content types (TextContent, DataContent, etc.).""" + + @executor + async def content_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: + # Yield different content types + await ctx.yield_output(TextContent(text="text content")) + await ctx.yield_output(DataContent(data=b"binary data", media_type="application/octet-stream")) + await ctx.yield_output(UriContent(uri="https://example.com/image.png", media_type="image/png")) + + workflow = WorkflowBuilder().set_start_executor(content_yielding_executor).build() + agent = workflow.as_agent("content-test-agent") + + result = await agent.run("test") + + assert isinstance(result, AgentRunResponse) + assert len(result.messages) == 3 + + # Verify each content type is preserved + assert isinstance(result.messages[0].contents[0], TextContent) + assert result.messages[0].contents[0].text == "text content" + + assert isinstance(result.messages[1].contents[0], DataContent) + assert result.messages[1].contents[0].media_type == "application/octet-stream" + + assert isinstance(result.messages[2].contents[0], UriContent) + assert result.messages[2].contents[0].uri == "https://example.com/image.png" + + async def test_workflow_as_agent_yield_output_with_chat_message(self) -> None: + """Test that yield_output with ChatMessage preserves the message structure.""" + + @executor + async def chat_message_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: + msg = ChatMessage( + role=Role.ASSISTANT, + contents=[TextContent(text="response text")], + author_name="custom-author", + ) + await ctx.yield_output(msg) + + workflow = WorkflowBuilder().set_start_executor(chat_message_executor).build() + agent = workflow.as_agent("chat-msg-agent") + + result = await agent.run("test") + + assert len(result.messages) == 1 + assert result.messages[0].role == Role.ASSISTANT + assert result.messages[0].text == "response text" + assert result.messages[0].author_name == "custom-author" + + async def test_workflow_as_agent_yield_output_sets_raw_representation(self) -> None: + """Test that yield_output sets raw_representation with the original data.""" + + # A custom object to verify raw_representation preserves the original data + class CustomData: + def __init__(self, value: int): + self.value = value + + def __str__(self) -> str: + return f"CustomData({self.value})" + + @executor + async def raw_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: + # Yield different types of data + await ctx.yield_output("simple string") + await ctx.yield_output(TextContent(text="text content")) + custom = CustomData(42) + await ctx.yield_output(custom) + + workflow = WorkflowBuilder().set_start_executor(raw_yielding_executor).build() + agent = workflow.as_agent("raw-test-agent") + + updates: list[AgentRunResponseUpdate] = [] + async for update in agent.run_stream("test"): + updates.append(update) + + # Should have 3 updates + assert len(updates) == 3 + + # Verify raw_representation is set for each update + assert updates[0].raw_representation == "simple string" + assert isinstance(updates[1].raw_representation, TextContent) + assert updates[1].raw_representation.text == "text content" + assert isinstance(updates[2].raw_representation, CustomData) + assert updates[2].raw_representation.value == 42 + async def test_thread_conversation_history_included_in_workflow_run(self) -> None: """Test that conversation history from thread is included when running WorkflowAgent. diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py new file mode 100644 index 0000000000..0e5a5a8d4f --- /dev/null +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -0,0 +1,644 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import AsyncIterable +from typing import Annotated, Any + +from agent_framework import ( + AgentRunResponse, + AgentRunResponseUpdate, + AgentThread, + BaseAgent, + ChatMessage, + ConcurrentBuilder, + GroupChatBuilder, + GroupChatStateSnapshot, + HandoffBuilder, + Role, + SequentialBuilder, + TextContent, + WorkflowRunState, + WorkflowStatusEvent, + ai_function, +) +from agent_framework._workflows._const import WORKFLOW_RUN_KWARGS_KEY + +# Track kwargs received by tools during test execution +_received_kwargs: list[dict[str, Any]] = [] + + +def _reset_received_kwargs() -> None: + """Reset the kwargs tracker before each test.""" + _received_kwargs.clear() + + +@ai_function +def tool_with_kwargs( + action: Annotated[str, "The action to perform"], + **kwargs: Any, +) -> str: + """A test tool that captures kwargs for verification.""" + _received_kwargs.append(dict(kwargs)) + custom_data = kwargs.get("custom_data", {}) + user_token = kwargs.get("user_token", {}) + return f"Executed {action} with custom_data={custom_data}, user={user_token.get('user_name', 'unknown')}" + + +class _KwargsCapturingAgent(BaseAgent): + """Test agent that captures kwargs passed to run/run_stream.""" + + captured_kwargs: list[dict[str, Any]] + + def __init__(self, name: str = "test_agent") -> None: + super().__init__(name=name, description="Test agent for kwargs capture") + self.captured_kwargs = [] + + async def run( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentRunResponse: + self.captured_kwargs.append(dict(kwargs)) + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.display_name} response")]) + + async def run_stream( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentRunResponseUpdate]: + self.captured_kwargs.append(dict(kwargs)) + yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.display_name} response")]) + + +class _EchoAgent(BaseAgent): + """Simple agent that echoes back for workflow completion.""" + + async def run( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentRunResponse: + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.display_name} reply")]) + + async def run_stream( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentRunResponseUpdate]: + yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.display_name} reply")]) + + +# region Sequential Builder Tests + + +async def test_sequential_kwargs_flow_to_agent() -> None: + """Test that kwargs passed to SequentialBuilder workflow flow through to agent.""" + agent = _KwargsCapturingAgent(name="seq_agent") + workflow = SequentialBuilder().participants([agent]).build() + + custom_data = {"endpoint": "https://api.example.com", "version": "v1"} + user_token = {"user_name": "alice", "access_level": "admin"} + + async for event in workflow.run_stream( + "test message", + custom_data=custom_data, + user_token=user_token, + ): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Verify agent received kwargs + assert len(agent.captured_kwargs) >= 1, "Agent should have been invoked at least once" + received = agent.captured_kwargs[0] + assert "custom_data" in received, "Agent should receive custom_data kwarg" + assert "user_token" in received, "Agent should receive user_token kwarg" + assert received["custom_data"] == custom_data + assert received["user_token"] == user_token + + +async def test_sequential_kwargs_flow_to_multiple_agents() -> None: + """Test that kwargs flow to all agents in a sequential workflow.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder().participants([agent1, agent2]).build() + + custom_data = {"key": "value"} + + async for event in workflow.run_stream("test", custom_data=custom_data): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Both agents should have received kwargs + assert len(agent1.captured_kwargs) >= 1, "First agent should be invoked" + assert len(agent2.captured_kwargs) >= 1, "Second agent should be invoked" + assert agent1.captured_kwargs[0].get("custom_data") == custom_data + assert agent2.captured_kwargs[0].get("custom_data") == custom_data + + +async def test_sequential_run_kwargs_flow() -> None: + """Test that kwargs flow through workflow.run() (non-streaming).""" + agent = _KwargsCapturingAgent(name="run_agent") + workflow = SequentialBuilder().participants([agent]).build() + + _ = await workflow.run("test message", custom_data={"test": True}) + + assert len(agent.captured_kwargs) >= 1 + assert agent.captured_kwargs[0].get("custom_data") == {"test": True} + + +# endregion + + +# region Concurrent Builder Tests + + +async def test_concurrent_kwargs_flow_to_agents() -> None: + """Test that kwargs flow to all agents in a concurrent workflow.""" + agent1 = _KwargsCapturingAgent(name="concurrent1") + agent2 = _KwargsCapturingAgent(name="concurrent2") + workflow = ConcurrentBuilder().participants([agent1, agent2]).build() + + custom_data = {"batch_id": "123"} + user_token = {"user_name": "bob"} + + async for event in workflow.run_stream( + "concurrent test", + custom_data=custom_data, + user_token=user_token, + ): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Both agents should have received kwargs + assert len(agent1.captured_kwargs) >= 1, "First concurrent agent should be invoked" + assert len(agent2.captured_kwargs) >= 1, "Second concurrent agent should be invoked" + + for agent in [agent1, agent2]: + received = agent.captured_kwargs[0] + assert received.get("custom_data") == custom_data + assert received.get("user_token") == user_token + + +# endregion + + +# region GroupChat Builder Tests + + +async def test_groupchat_kwargs_flow_to_agents() -> None: + """Test that kwargs flow to agents in a group chat workflow.""" + agent1 = _KwargsCapturingAgent(name="chat1") + agent2 = _KwargsCapturingAgent(name="chat2") + + # Simple selector that takes GroupChatStateSnapshot + turn_count = 0 + + def simple_selector(state: GroupChatStateSnapshot) -> str | None: + nonlocal turn_count + turn_count += 1 + if turn_count > 2: # Stop after 2 turns + return None + # state is a Mapping - access via dict syntax + names = list(state["participants"].keys()) + return names[(turn_count - 1) % len(names)] + + workflow = ( + GroupChatBuilder().participants(chat1=agent1, chat2=agent2).set_select_speakers_func(simple_selector).build() + ) + + custom_data = {"session_id": "group123"} + + async for event in workflow.run_stream("group chat test", custom_data=custom_data): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # At least one agent should have received kwargs + all_kwargs = agent1.captured_kwargs + agent2.captured_kwargs + assert len(all_kwargs) >= 1, "At least one agent should be invoked in group chat" + + for received in all_kwargs: + assert received.get("custom_data") == custom_data + + +# endregion + + +# region SharedState Verification Tests + + +async def test_kwargs_stored_in_shared_state() -> None: + """Test that kwargs are stored in SharedState with the correct key.""" + from agent_framework import Executor, WorkflowContext, handler + + stored_kwargs: dict[str, Any] | None = None + + class _SharedStateInspector(Executor): + @handler + async def inspect(self, msgs: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: + nonlocal stored_kwargs + stored_kwargs = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + await ctx.send_message(msgs) + + inspector = _SharedStateInspector(id="inspector") + workflow = SequentialBuilder().participants([inspector]).build() + + async for event in workflow.run_stream("test", my_kwarg="my_value", another=123): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + assert stored_kwargs is not None, "kwargs should be stored in SharedState" + assert stored_kwargs.get("my_kwarg") == "my_value" + assert stored_kwargs.get("another") == 123 + + +async def test_empty_kwargs_stored_as_empty_dict() -> None: + """Test that empty kwargs are stored as empty dict in SharedState.""" + from agent_framework import Executor, WorkflowContext, handler + + stored_kwargs: Any = "NOT_CHECKED" + + class _SharedStateChecker(Executor): + @handler + async def check(self, msgs: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: + nonlocal stored_kwargs + stored_kwargs = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + await ctx.send_message(msgs) + + checker = _SharedStateChecker(id="checker") + workflow = SequentialBuilder().participants([checker]).build() + + # Run without any kwargs + async for event in workflow.run_stream("test"): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # SharedState should have empty dict when no kwargs provided + assert stored_kwargs == {}, f"Expected empty dict, got: {stored_kwargs}" + + +# endregion + + +# region Edge Cases + + +async def test_kwargs_with_none_values() -> None: + """Test that kwargs with None values are passed through correctly.""" + agent = _KwargsCapturingAgent(name="none_test") + workflow = SequentialBuilder().participants([agent]).build() + + async for event in workflow.run_stream("test", optional_param=None, other_param="value"): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + assert len(agent.captured_kwargs) >= 1 + received = agent.captured_kwargs[0] + assert "optional_param" in received + assert received["optional_param"] is None + assert received["other_param"] == "value" + + +async def test_kwargs_with_complex_nested_data() -> None: + """Test that complex nested data structures flow through correctly.""" + agent = _KwargsCapturingAgent(name="nested_test") + workflow = SequentialBuilder().participants([agent]).build() + + complex_data = { + "level1": { + "level2": { + "level3": ["a", "b", "c"], + "number": 42, + }, + "list": [1, 2, {"nested": True}], + }, + "tuple_like": [1, 2, 3], + } + + async for event in workflow.run_stream("test", complex_data=complex_data): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + assert len(agent.captured_kwargs) >= 1 + received = agent.captured_kwargs[0] + assert received.get("complex_data") == complex_data + + +async def test_kwargs_preserved_across_workflow_reruns() -> None: + """Test that kwargs are correctly isolated between workflow runs.""" + agent = _KwargsCapturingAgent(name="rerun_test") + + # Build separate workflows for each run to avoid "already running" error + workflow1 = SequentialBuilder().participants([agent]).build() + workflow2 = SequentialBuilder().participants([agent]).build() + + # First run + async for event in workflow1.run_stream("run1", run_id="first"): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Second run with different kwargs (using fresh workflow) + async for event in workflow2.run_stream("run2", run_id="second"): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + assert len(agent.captured_kwargs) >= 2 + assert agent.captured_kwargs[0].get("run_id") == "first" + assert agent.captured_kwargs[1].get("run_id") == "second" + + +# endregion + + +# region Handoff Builder Tests + + +async def test_handoff_kwargs_flow_to_agents() -> None: + """Test that kwargs flow to agents in a handoff workflow.""" + agent1 = _KwargsCapturingAgent(name="coordinator") + agent2 = _KwargsCapturingAgent(name="specialist") + + workflow = ( + HandoffBuilder() + .participants([agent1, agent2]) + .set_coordinator(agent1) + .with_interaction_mode("autonomous") + .build() + ) + + custom_data = {"session_id": "handoff123"} + + async for event in workflow.run_stream("handoff test", custom_data=custom_data): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Coordinator agent should have received kwargs + assert len(agent1.captured_kwargs) >= 1, "Coordinator should be invoked in handoff" + assert agent1.captured_kwargs[0].get("custom_data") == custom_data + + +# endregion + + +# region Magentic Builder Tests + + +async def test_magentic_kwargs_flow_to_agents() -> None: + """Test that kwargs flow to agents in a magentic workflow via MagenticAgentExecutor.""" + from agent_framework import MagenticBuilder + from agent_framework._workflows._magentic import ( + MagenticContext, + MagenticManagerBase, + _MagenticProgressLedger, + _MagenticProgressLedgerItem, + ) + + # Create a mock manager that completes after one round + class _MockManager(MagenticManagerBase): + def __init__(self) -> None: + super().__init__(max_stall_count=3, max_reset_count=None, max_round_count=2) + self.task_ledger = None + + async def plan(self, context: MagenticContext) -> ChatMessage: + return ChatMessage(role=Role.ASSISTANT, text="Plan: Test task", author_name="manager") + + async def replan(self, context: MagenticContext) -> ChatMessage: + return ChatMessage(role=Role.ASSISTANT, text="Replan: Test task", author_name="manager") + + async def create_progress_ledger(self, context: MagenticContext) -> _MagenticProgressLedger: + # Return completed on first call + return _MagenticProgressLedger( + is_request_satisfied=_MagenticProgressLedgerItem(answer=True, reason="Done"), + is_progress_being_made=_MagenticProgressLedgerItem(answer=True, reason="Progress"), + is_in_loop=_MagenticProgressLedgerItem(answer=False, reason="Not looping"), + instruction_or_question=_MagenticProgressLedgerItem(answer="Complete", reason="Done"), + next_speaker=_MagenticProgressLedgerItem(answer="agent1", reason="First"), + ) + + async def prepare_final_answer(self, context: MagenticContext) -> ChatMessage: + return ChatMessage(role=Role.ASSISTANT, text="Final answer", author_name="manager") + + agent = _KwargsCapturingAgent(name="agent1") + manager = _MockManager() + + workflow = MagenticBuilder().participants(agent1=agent).with_standard_manager(manager=manager).build() + + custom_data = {"session_id": "magentic123"} + + async for event in workflow.run_stream("magentic test", custom_data=custom_data): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # The workflow completes immediately via prepare_final_answer without invoking agents + # because is_request_satisfied=True. This test verifies the kwargs storage path works. + # A more comprehensive integration test would require the manager to select an agent. + + +async def test_magentic_kwargs_stored_in_shared_state() -> None: + """Test that kwargs are stored in SharedState when using MagenticWorkflow.run_stream().""" + from agent_framework import MagenticBuilder + from agent_framework._workflows._magentic import ( + MagenticContext, + MagenticManagerBase, + _MagenticProgressLedger, + _MagenticProgressLedgerItem, + ) + + class _MockManager(MagenticManagerBase): + def __init__(self) -> None: + super().__init__(max_stall_count=3, max_reset_count=None, max_round_count=1) + self.task_ledger = None + + async def plan(self, context: MagenticContext) -> ChatMessage: + return ChatMessage(role=Role.ASSISTANT, text="Plan", author_name="manager") + + async def replan(self, context: MagenticContext) -> ChatMessage: + return ChatMessage(role=Role.ASSISTANT, text="Replan", author_name="manager") + + async def create_progress_ledger(self, context: MagenticContext) -> _MagenticProgressLedger: + return _MagenticProgressLedger( + is_request_satisfied=_MagenticProgressLedgerItem(answer=True, reason="Done"), + is_progress_being_made=_MagenticProgressLedgerItem(answer=True, reason="Progress"), + is_in_loop=_MagenticProgressLedgerItem(answer=False, reason="Not looping"), + instruction_or_question=_MagenticProgressLedgerItem(answer="Done", reason="Done"), + next_speaker=_MagenticProgressLedgerItem(answer="agent1", reason="First"), + ) + + async def prepare_final_answer(self, context: MagenticContext) -> ChatMessage: + return ChatMessage(role=Role.ASSISTANT, text="Final", author_name="manager") + + agent = _KwargsCapturingAgent(name="agent1") + manager = _MockManager() + + magentic_workflow = MagenticBuilder().participants(agent1=agent).with_standard_manager(manager=manager).build() + + # Use MagenticWorkflow.run_stream() which goes through the kwargs attachment path + custom_data = {"magentic_key": "magentic_value"} + + async for event in magentic_workflow.run_stream("test task", custom_data=custom_data): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Verify the workflow completed (kwargs were stored, even if agent wasn't invoked) + # The test validates the code path through MagenticWorkflow.run_stream -> _MagenticStartMessage + + +# endregion + + +# region SubWorkflow (WorkflowExecutor) Tests + + +async def test_subworkflow_kwargs_propagation() -> None: + """Test that kwargs are propagated to subworkflows. + + Verifies kwargs passed to parent workflow.run_stream() flow through to agents + in subworkflows wrapped by WorkflowExecutor. + """ + from agent_framework._workflows._workflow_executor import WorkflowExecutor + + # Create an agent inside the subworkflow that captures kwargs + inner_agent = _KwargsCapturingAgent(name="inner_agent") + + # Build the inner (sub) workflow with the agent + inner_workflow = SequentialBuilder().participants([inner_agent]).build() + + # Wrap the inner workflow in a WorkflowExecutor so it can be used as a subworkflow + subworkflow_executor = WorkflowExecutor(workflow=inner_workflow, id="subworkflow_executor") + + # Build the outer (parent) workflow containing the subworkflow + outer_workflow = SequentialBuilder().participants([subworkflow_executor]).build() + + # Define kwargs that should propagate to subworkflow + custom_data = {"api_key": "secret123", "endpoint": "https://api.example.com"} + user_token = {"user_name": "alice", "access_level": "admin"} + + # Run the outer workflow with kwargs + async for event in outer_workflow.run_stream( + "test message for subworkflow", + custom_data=custom_data, + user_token=user_token, + ): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Verify that the inner agent was called + assert len(inner_agent.captured_kwargs) >= 1, "Inner agent in subworkflow should have been invoked" + + received_kwargs = inner_agent.captured_kwargs[0] + + # Verify kwargs were propagated from parent workflow to subworkflow agent + assert "custom_data" in received_kwargs, ( + f"Subworkflow agent should receive 'custom_data' kwarg. Received keys: {list(received_kwargs.keys())}" + ) + assert "user_token" in received_kwargs, ( + f"Subworkflow agent should receive 'user_token' kwarg. Received keys: {list(received_kwargs.keys())}" + ) + assert received_kwargs.get("custom_data") == custom_data, ( + f"Expected custom_data={custom_data}, got {received_kwargs.get('custom_data')}" + ) + assert received_kwargs.get("user_token") == user_token, ( + f"Expected user_token={user_token}, got {received_kwargs.get('user_token')}" + ) + + +async def test_subworkflow_kwargs_accessible_via_shared_state() -> None: + """Test that kwargs are accessible via SharedState within subworkflow. + + Verifies that WORKFLOW_RUN_KWARGS_KEY is populated in the subworkflow's SharedState + with kwargs from the parent workflow. + """ + from agent_framework import Executor, WorkflowContext, handler + from agent_framework._workflows._workflow_executor import WorkflowExecutor + + captured_kwargs_from_state: list[dict[str, Any]] = [] + + class _SharedStateReader(Executor): + """Executor that reads kwargs from SharedState for verification.""" + + @handler + async def read_kwargs(self, msgs: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: + kwargs_from_state = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + captured_kwargs_from_state.append(kwargs_from_state or {}) + await ctx.send_message(msgs) + + # Build inner workflow with SharedState reader + state_reader = _SharedStateReader(id="state_reader") + inner_workflow = SequentialBuilder().participants([state_reader]).build() + + # Wrap as subworkflow + subworkflow_executor = WorkflowExecutor(workflow=inner_workflow, id="subworkflow") + + # Build outer workflow + outer_workflow = SequentialBuilder().participants([subworkflow_executor]).build() + + # Run with kwargs + async for event in outer_workflow.run_stream( + "test", + my_custom_kwarg="should_be_propagated", + another_kwarg=42, + ): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Verify the state reader was invoked + assert len(captured_kwargs_from_state) >= 1, "SharedState reader should have been invoked" + + kwargs_in_subworkflow = captured_kwargs_from_state[0] + + assert kwargs_in_subworkflow.get("my_custom_kwarg") == "should_be_propagated", ( + f"Expected 'my_custom_kwarg' in subworkflow SharedState, got: {kwargs_in_subworkflow}" + ) + assert kwargs_in_subworkflow.get("another_kwarg") == 42, ( + f"Expected 'another_kwarg'=42 in subworkflow SharedState, got: {kwargs_in_subworkflow}" + ) + + +async def test_nested_subworkflow_kwargs_propagation() -> None: + """Test kwargs propagation through multiple levels of nested subworkflows. + + Verifies kwargs flow through 3 levels: + - Outer workflow + - Middle subworkflow (WorkflowExecutor) + - Inner subworkflow (WorkflowExecutor) with agent + """ + from agent_framework._workflows._workflow_executor import WorkflowExecutor + + # Innermost agent + inner_agent = _KwargsCapturingAgent(name="deeply_nested_agent") + + # Build inner workflow + inner_workflow = SequentialBuilder().participants([inner_agent]).build() + inner_executor = WorkflowExecutor(workflow=inner_workflow, id="inner_executor") + + # Build middle workflow containing inner + middle_workflow = SequentialBuilder().participants([inner_executor]).build() + middle_executor = WorkflowExecutor(workflow=middle_workflow, id="middle_executor") + + # Build outer workflow containing middle + outer_workflow = SequentialBuilder().participants([middle_executor]).build() + + # Run with kwargs + async for event in outer_workflow.run_stream( + "deeply nested test", + deep_kwarg="should_reach_inner", + ): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Verify inner agent was called + assert len(inner_agent.captured_kwargs) >= 1, "Deeply nested agent should be invoked" + + received = inner_agent.captured_kwargs[0] + assert received.get("deep_kwarg") == "should_reach_inner", ( + f"Deeply nested agent should receive 'deep_kwarg'. Got: {received}" + ) + + +# endregion diff --git a/python/packages/core/tests/workflow/test_workflow_observability.py b/python/packages/core/tests/workflow/test_workflow_observability.py index 1760361f1a..4c97b850b8 100644 --- a/python/packages/core/tests/workflow/test_workflow_observability.py +++ b/python/packages/core/tests/workflow/test_workflow_observability.py @@ -229,8 +229,10 @@ async def test_trace_context_handling(span_exporter: InMemorySpanExporter) -> No assert processing_span.attributes.get("message.payload_type") == "str" -@pytest.mark.parametrize("enable_otel", [False], indirect=True) -async def test_trace_context_disabled_when_tracing_disabled(enable_otel, span_exporter: InMemorySpanExporter) -> None: +@pytest.mark.parametrize("enable_instrumentation", [False], indirect=True) +async def test_trace_context_disabled_when_tracing_disabled( + enable_instrumentation, span_exporter: InMemorySpanExporter +) -> None: """Test that no trace context is added when tracing is disabled.""" # Tracing should be disabled by default executor = MockExecutor("test-executor") @@ -433,7 +435,7 @@ async def handle_message(self, message: str, ctx: WorkflowContext) -> None: assert workflow_span.status.status_code.name == "ERROR" -@pytest.mark.parametrize("enable_otel", [False], indirect=True) +@pytest.mark.parametrize("enable_instrumentation", [False], indirect=True) async def test_message_trace_context_serialization(span_exporter: InMemorySpanExporter) -> None: """Test that message trace context is properly serialized/deserialized.""" ctx = InProcRunnerContext(InMemoryCheckpointStorage()) diff --git a/python/packages/declarative/pyproject.toml b/python/packages/declarative/pyproject.toml index d36c2b76b5..5a82f3a1e5 100644 --- a/python/packages/declarative/pyproject.toml +++ b/python/packages/declarative/pyproject.toml @@ -4,7 +4,7 @@ description = "Declarative specification support for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251211" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/devui/agent_framework_devui/__init__.py b/python/packages/devui/agent_framework_devui/__init__.py index 45d1ea8c2d..9a480d170e 100644 --- a/python/packages/devui/agent_framework_devui/__init__.py +++ b/python/packages/devui/agent_framework_devui/__init__.py @@ -177,9 +177,9 @@ def serve( import os # Only set if not already configured by user - if not os.environ.get("ENABLE_OTEL"): - os.environ["ENABLE_OTEL"] = "true" - logger.info("Set ENABLE_OTEL=true for tracing") + if not os.environ.get("ENABLE_INSTRUMENTATION"): + os.environ["ENABLE_INSTRUMENTATION"] = "true" + logger.info("Set ENABLE_INSTRUMENTATION=true for tracing") if not os.environ.get("ENABLE_SENSITIVE_DATA"): os.environ["ENABLE_SENSITIVE_DATA"] = "true" diff --git a/python/packages/devui/agent_framework_devui/_executor.py b/python/packages/devui/agent_framework_devui/_executor.py index 813bc4d4cc..1f28c8772c 100644 --- a/python/packages/devui/agent_framework_devui/_executor.py +++ b/python/packages/devui/agent_framework_devui/_executor.py @@ -82,27 +82,23 @@ def _setup_tracing_provider(self) -> None: def _setup_agent_framework_tracing(self) -> None: """Set up Agent Framework's built-in tracing.""" - # Configure Agent Framework tracing only if ENABLE_OTEL is set - if os.environ.get("ENABLE_OTEL"): + # Configure Agent Framework tracing only if ENABLE_INSTRUMENTATION is set + if os.environ.get("ENABLE_INSTRUMENTATION"): try: - from agent_framework.observability import OBSERVABILITY_SETTINGS, setup_observability + from agent_framework.observability import OBSERVABILITY_SETTINGS, configure_otel_providers # Only configure if not already executed if not OBSERVABILITY_SETTINGS._executed_setup: - # Get OTLP endpoint from either custom or standard env var - # This handles the case where env vars are set after ObservabilitySettings was imported - otlp_endpoint = os.environ.get("OTLP_ENDPOINT") or os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") - - # Pass the endpoint explicitly to setup_observability + # Run the configure_otel_providers # This ensures OTLP exporters are created even if env vars were set late - setup_observability(enable_sensitive_data=True, otlp_endpoint=otlp_endpoint) + configure_otel_providers(enable_sensitive_data=True) logger.info("Enabled Agent Framework observability") else: logger.debug("Agent Framework observability already configured") except Exception as e: logger.warning(f"Failed to enable Agent Framework observability: {e}") else: - logger.debug("ENABLE_OTEL not set, skipping observability setup") + logger.debug("ENABLE_INSTRUMENTATION not set, skipping observability setup") async def discover_entities(self) -> list[EntityInfo]: """Discover all available entities. @@ -252,7 +248,7 @@ async def _execute_agent( # Get thread from conversation parameter (OpenAI standard!) thread = None - conversation_id = request.get_conversation_id() + conversation_id = request._get_conversation_id() if conversation_id: thread = self.conversation_store.get_thread(conversation_id) if thread: @@ -328,7 +324,7 @@ async def _execute_workflow( entity_id = request.get_entity_id() or "unknown" # Get or create session conversation for checkpoint storage - conversation_id = request.get_conversation_id() + conversation_id = request._get_conversation_id() if not conversation_id: # Create default session if not provided import time diff --git a/python/packages/devui/agent_framework_devui/_server.py b/python/packages/devui/agent_framework_devui/_server.py index 26630945cb..b3a7c751b6 100644 --- a/python/packages/devui/agent_framework_devui/_server.py +++ b/python/packages/devui/agent_framework_devui/_server.py @@ -407,7 +407,7 @@ async def get_meta() -> MetaResponse: framework="agent_framework", runtime="python", # Python DevUI backend capabilities={ - "tracing": os.getenv("ENABLE_OTEL") == "true", + "tracing": os.getenv("ENABLE_INSTRUMENTATION") == "true", "openai_proxy": openai_executor.is_configured, "deployment": True, # Deployment feature is available }, diff --git a/python/packages/devui/agent_framework_devui/models/_openai_custom.py b/python/packages/devui/agent_framework_devui/models/_openai_custom.py index f82ef90b72..ac0e74034a 100644 --- a/python/packages/devui/agent_framework_devui/models/_openai_custom.py +++ b/python/packages/devui/agent_framework_devui/models/_openai_custom.py @@ -324,7 +324,7 @@ def get_entity_id(self) -> str | None: return self.metadata.get("entity_id") return None - def get_conversation_id(self) -> str | None: + def _get_conversation_id(self) -> str | None: """Extract conversation_id from conversation parameter. Supports both string and object forms: diff --git a/python/packages/devui/agent_framework_devui/ui/assets/index.js b/python/packages/devui/agent_framework_devui/ui/assets/index.js index d12b71a838..1b05f27842 100644 --- a/python/packages/devui/agent_framework_devui/ui/assets/index.js +++ b/python/packages/devui/agent_framework_devui/ui/assets/index.js @@ -1,4 +1,4 @@ -function yE(e,n){for(var r=0;ra[l]})}}}return Object.freeze(Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}))}(function(){const n=document.createElement("link").relList;if(n&&n.supports&&n.supports("modulepreload"))return;for(const l of document.querySelectorAll('link[rel="modulepreload"]'))a(l);new MutationObserver(l=>{for(const c of l)if(c.type==="childList")for(const d of c.addedNodes)d.tagName==="LINK"&&d.rel==="modulepreload"&&a(d)}).observe(document,{childList:!0,subtree:!0});function r(l){const c={};return l.integrity&&(c.integrity=l.integrity),l.referrerPolicy&&(c.referrerPolicy=l.referrerPolicy),l.crossOrigin==="use-credentials"?c.credentials="include":l.crossOrigin==="anonymous"?c.credentials="omit":c.credentials="same-origin",c}function a(l){if(l.ep)return;l.ep=!0;const c=r(l);fetch(l.href,c)}})();function yp(e){return e&&e.__esModule&&Object.prototype.hasOwnProperty.call(e,"default")?e.default:e}var Gm={exports:{}},Bi={};/** +function yE(e, n) { for (var r = 0; r < n.length; r++) { const a = n[r]; if (typeof a != "string" && !Array.isArray(a)) { for (const l in a) if (l !== "default" && !(l in e)) { const c = Object.getOwnPropertyDescriptor(a, l); c && Object.defineProperty(e, l, c.get ? c : { enumerable: !0, get: () => a[l] }) } } } return Object.freeze(Object.defineProperty(e, Symbol.toStringTag, { value: "Module" })) } (function () { const n = document.createElement("link").relList; if (n && n.supports && n.supports("modulepreload")) return; for (const l of document.querySelectorAll('link[rel="modulepreload"]')) a(l); new MutationObserver(l => { for (const c of l) if (c.type === "childList") for (const d of c.addedNodes) d.tagName === "LINK" && d.rel === "modulepreload" && a(d) }).observe(document, { childList: !0, subtree: !0 }); function r(l) { const c = {}; return l.integrity && (c.integrity = l.integrity), l.referrerPolicy && (c.referrerPolicy = l.referrerPolicy), l.crossOrigin === "use-credentials" ? c.credentials = "include" : l.crossOrigin === "anonymous" ? c.credentials = "omit" : c.credentials = "same-origin", c } function a(l) { if (l.ep) return; l.ep = !0; const c = r(l); fetch(l.href, c) } })(); function yp(e) { return e && e.__esModule && Object.prototype.hasOwnProperty.call(e, "default") ? e.default : e } var Gm = { exports: {} }, Bi = {};/** * @license React * react-jsx-runtime.production.js * @@ -6,7 +6,7 @@ function yE(e,n){for(var r=0;r>>1,C=k[H];if(0>>1;H<$;){var Y=2*(H+1)-1,V=k[Y],W=Y+1,fe=k[W];if(0>l(V,I))Wl(fe,V)?(k[H]=fe,k[W]=I,H=W):(k[H]=V,k[Y]=I,H=Y);else if(Wl(fe,I))k[H]=fe,k[W]=I,H=W;else break e}}return L}function l(k,L){var I=k.sortIndex-L.sortIndex;return I!==0?I:k.id-L.id}if(e.unstable_now=void 0,typeof performance=="object"&&typeof performance.now=="function"){var c=performance;e.unstable_now=function(){return c.now()}}else{var d=Date,f=d.now();e.unstable_now=function(){return d.now()-f}}var m=[],h=[],g=1,y=null,x=3,b=!1,S=!1,N=!1,j=!1,_=typeof setTimeout=="function"?setTimeout:null,M=typeof clearTimeout=="function"?clearTimeout:null,E=typeof setImmediate<"u"?setImmediate:null;function T(k){for(var L=r(h);L!==null;){if(L.callback===null)a(h);else if(L.startTime<=k)a(h),L.sortIndex=L.expirationTime,n(m,L);else break;L=r(h)}}function R(k){if(N=!1,T(k),!S)if(r(m)!==null)S=!0,D||(D=!0,G());else{var L=r(h);L!==null&&U(R,L.startTime-k)}}var D=!1,O=-1,B=5,q=-1;function K(){return j?!0:!(e.unstable_now()-qk&&K());){var H=y.callback;if(typeof H=="function"){y.callback=null,x=y.priorityLevel;var C=H(y.expirationTime<=k);if(k=e.unstable_now(),typeof C=="function"){y.callback=C,T(k),L=!0;break t}y===r(m)&&a(m),T(k)}else a(m);y=r(m)}if(y!==null)L=!0;else{var $=r(h);$!==null&&U(R,$.startTime-k),L=!1}}break e}finally{y=null,x=I,b=!1}L=void 0}}finally{L?G():D=!1}}}var G;if(typeof E=="function")G=function(){E(J)};else if(typeof MessageChannel<"u"){var Z=new MessageChannel,P=Z.port2;Z.port1.onmessage=J,G=function(){P.postMessage(null)}}else G=function(){_(J,0)};function U(k,L){O=_(function(){k(e.unstable_now())},L)}e.unstable_IdlePriority=5,e.unstable_ImmediatePriority=1,e.unstable_LowPriority=4,e.unstable_NormalPriority=3,e.unstable_Profiling=null,e.unstable_UserBlockingPriority=2,e.unstable_cancelCallback=function(k){k.callback=null},e.unstable_forceFrameRate=function(k){0>k||125H?(k.sortIndex=I,n(h,k),r(m)===null&&k===r(h)&&(N?(M(O),O=-1):N=!0,U(R,I-H))):(k.sortIndex=C,n(m,k),S||b||(S=!0,D||(D=!0,G()))),k},e.unstable_shouldYield=K,e.unstable_wrapCallback=function(k){var L=x;return function(){var I=x;x=L;try{return k.apply(this,arguments)}finally{x=I}}}})(Km)),Km}var tv;function SE(){return tv||(tv=1,Wm.exports=NE()),Wm.exports}var Qm={exports:{}},Wt={};/** + */var ev; function NE() { return ev || (ev = 1, (function (e) { function n(k, L) { var I = k.length; k.push(L); e: for (; 0 < I;) { var H = I - 1 >>> 1, C = k[H]; if (0 < l(C, L)) k[H] = L, k[I] = C, I = H; else break e } } function r(k) { return k.length === 0 ? null : k[0] } function a(k) { if (k.length === 0) return null; var L = k[0], I = k.pop(); if (I !== L) { k[0] = I; e: for (var H = 0, C = k.length, $ = C >>> 1; H < $;) { var Y = 2 * (H + 1) - 1, V = k[Y], W = Y + 1, fe = k[W]; if (0 > l(V, I)) W < C && 0 > l(fe, V) ? (k[H] = fe, k[W] = I, H = W) : (k[H] = V, k[Y] = I, H = Y); else if (W < C && 0 > l(fe, I)) k[H] = fe, k[W] = I, H = W; else break e } } return L } function l(k, L) { var I = k.sortIndex - L.sortIndex; return I !== 0 ? I : k.id - L.id } if (e.unstable_now = void 0, typeof performance == "object" && typeof performance.now == "function") { var c = performance; e.unstable_now = function () { return c.now() } } else { var d = Date, f = d.now(); e.unstable_now = function () { return d.now() - f } } var m = [], h = [], g = 1, y = null, x = 3, b = !1, S = !1, N = !1, j = !1, _ = typeof setTimeout == "function" ? setTimeout : null, M = typeof clearTimeout == "function" ? clearTimeout : null, E = typeof setImmediate < "u" ? setImmediate : null; function T(k) { for (var L = r(h); L !== null;) { if (L.callback === null) a(h); else if (L.startTime <= k) a(h), L.sortIndex = L.expirationTime, n(m, L); else break; L = r(h) } } function R(k) { if (N = !1, T(k), !S) if (r(m) !== null) S = !0, D || (D = !0, G()); else { var L = r(h); L !== null && U(R, L.startTime - k) } } var D = !1, O = -1, B = 5, q = -1; function K() { return j ? !0 : !(e.unstable_now() - q < B) } function J() { if (j = !1, D) { var k = e.unstable_now(); q = k; var L = !0; try { e: { S = !1, N && (N = !1, M(O), O = -1), b = !0; var I = x; try { t: { for (T(k), y = r(m); y !== null && !(y.expirationTime > k && K());) { var H = y.callback; if (typeof H == "function") { y.callback = null, x = y.priorityLevel; var C = H(y.expirationTime <= k); if (k = e.unstable_now(), typeof C == "function") { y.callback = C, T(k), L = !0; break t } y === r(m) && a(m), T(k) } else a(m); y = r(m) } if (y !== null) L = !0; else { var $ = r(h); $ !== null && U(R, $.startTime - k), L = !1 } } break e } finally { y = null, x = I, b = !1 } L = void 0 } } finally { L ? G() : D = !1 } } } var G; if (typeof E == "function") G = function () { E(J) }; else if (typeof MessageChannel < "u") { var Z = new MessageChannel, P = Z.port2; Z.port1.onmessage = J, G = function () { P.postMessage(null) } } else G = function () { _(J, 0) }; function U(k, L) { O = _(function () { k(e.unstable_now()) }, L) } e.unstable_IdlePriority = 5, e.unstable_ImmediatePriority = 1, e.unstable_LowPriority = 4, e.unstable_NormalPriority = 3, e.unstable_Profiling = null, e.unstable_UserBlockingPriority = 2, e.unstable_cancelCallback = function (k) { k.callback = null }, e.unstable_forceFrameRate = function (k) { 0 > k || 125 < k ? console.error("forceFrameRate takes a positive int between 0 and 125, forcing frame rates higher than 125 fps is not supported") : B = 0 < k ? Math.floor(1e3 / k) : 5 }, e.unstable_getCurrentPriorityLevel = function () { return x }, e.unstable_next = function (k) { switch (x) { case 1: case 2: case 3: var L = 3; break; default: L = x }var I = x; x = L; try { return k() } finally { x = I } }, e.unstable_requestPaint = function () { j = !0 }, e.unstable_runWithPriority = function (k, L) { switch (k) { case 1: case 2: case 3: case 4: case 5: break; default: k = 3 }var I = x; x = k; try { return L() } finally { x = I } }, e.unstable_scheduleCallback = function (k, L, I) { var H = e.unstable_now(); switch (typeof I == "object" && I !== null ? (I = I.delay, I = typeof I == "number" && 0 < I ? H + I : H) : I = H, k) { case 1: var C = -1; break; case 2: C = 250; break; case 5: C = 1073741823; break; case 4: C = 1e4; break; default: C = 5e3 }return C = I + C, k = { id: g++, callback: L, priorityLevel: k, startTime: I, expirationTime: C, sortIndex: -1 }, I > H ? (k.sortIndex = I, n(h, k), r(m) === null && k === r(h) && (N ? (M(O), O = -1) : N = !0, U(R, I - H))) : (k.sortIndex = C, n(m, k), S || b || (S = !0, D || (D = !0, G()))), k }, e.unstable_shouldYield = K, e.unstable_wrapCallback = function (k) { var L = x; return function () { var I = x; x = L; try { return k.apply(this, arguments) } finally { x = I } } } })(Km)), Km } var tv; function SE() { return tv || (tv = 1, Wm.exports = NE()), Wm.exports } var Qm = { exports: {} }, Wt = {};/** * @license React * react-dom.production.js * @@ -30,7 +30,7 @@ function yE(e,n){for(var r=0;r"u"||typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.checkDCE!="function"))try{__REACT_DEVTOOLS_GLOBAL_HOOK__.checkDCE(e)}catch(n){console.error(n)}}return e(),Qm.exports=jE(),Qm.exports}/** + */var nv; function jE() { if (nv) return Wt; nv = 1; var e = pl(); function n(m) { var h = "https://react.dev/errors/" + m; if (1 < arguments.length) { h += "?args[]=" + encodeURIComponent(arguments[1]); for (var g = 2; g < arguments.length; g++)h += "&args[]=" + encodeURIComponent(arguments[g]) } return "Minified React error #" + m + "; visit " + h + " for the full message or use the non-minified dev environment for full errors and additional helpful warnings." } function r() { } var a = { d: { f: r, r: function () { throw Error(n(522)) }, D: r, C: r, L: r, m: r, X: r, S: r, M: r }, p: 0, findDOMNode: null }, l = Symbol.for("react.portal"); function c(m, h, g) { var y = 3 < arguments.length && arguments[3] !== void 0 ? arguments[3] : null; return { $$typeof: l, key: y == null ? null : "" + y, children: m, containerInfo: h, implementation: g } } var d = e.__CLIENT_INTERNALS_DO_NOT_USE_OR_WARN_USERS_THEY_CANNOT_UPGRADE; function f(m, h) { if (m === "font") return ""; if (typeof h == "string") return h === "use-credentials" ? h : "" } return Wt.__DOM_INTERNALS_DO_NOT_USE_OR_WARN_USERS_THEY_CANNOT_UPGRADE = a, Wt.createPortal = function (m, h) { var g = 2 < arguments.length && arguments[2] !== void 0 ? arguments[2] : null; if (!h || h.nodeType !== 1 && h.nodeType !== 9 && h.nodeType !== 11) throw Error(n(299)); return c(m, h, null, g) }, Wt.flushSync = function (m) { var h = d.T, g = a.p; try { if (d.T = null, a.p = 2, m) return m() } finally { d.T = h, a.p = g, a.d.f() } }, Wt.preconnect = function (m, h) { typeof m == "string" && (h ? (h = h.crossOrigin, h = typeof h == "string" ? h === "use-credentials" ? h : "" : void 0) : h = null, a.d.C(m, h)) }, Wt.prefetchDNS = function (m) { typeof m == "string" && a.d.D(m) }, Wt.preinit = function (m, h) { if (typeof m == "string" && h && typeof h.as == "string") { var g = h.as, y = f(g, h.crossOrigin), x = typeof h.integrity == "string" ? h.integrity : void 0, b = typeof h.fetchPriority == "string" ? h.fetchPriority : void 0; g === "style" ? a.d.S(m, typeof h.precedence == "string" ? h.precedence : void 0, { crossOrigin: y, integrity: x, fetchPriority: b }) : g === "script" && a.d.X(m, { crossOrigin: y, integrity: x, fetchPriority: b, nonce: typeof h.nonce == "string" ? h.nonce : void 0 }) } }, Wt.preinitModule = function (m, h) { if (typeof m == "string") if (typeof h == "object" && h !== null) { if (h.as == null || h.as === "script") { var g = f(h.as, h.crossOrigin); a.d.M(m, { crossOrigin: g, integrity: typeof h.integrity == "string" ? h.integrity : void 0, nonce: typeof h.nonce == "string" ? h.nonce : void 0 }) } } else h == null && a.d.M(m) }, Wt.preload = function (m, h) { if (typeof m == "string" && typeof h == "object" && h !== null && typeof h.as == "string") { var g = h.as, y = f(g, h.crossOrigin); a.d.L(m, g, { crossOrigin: y, integrity: typeof h.integrity == "string" ? h.integrity : void 0, nonce: typeof h.nonce == "string" ? h.nonce : void 0, type: typeof h.type == "string" ? h.type : void 0, fetchPriority: typeof h.fetchPriority == "string" ? h.fetchPriority : void 0, referrerPolicy: typeof h.referrerPolicy == "string" ? h.referrerPolicy : void 0, imageSrcSet: typeof h.imageSrcSet == "string" ? h.imageSrcSet : void 0, imageSizes: typeof h.imageSizes == "string" ? h.imageSizes : void 0, media: typeof h.media == "string" ? h.media : void 0 }) } }, Wt.preloadModule = function (m, h) { if (typeof m == "string") if (h) { var g = f(h.as, h.crossOrigin); a.d.m(m, { as: typeof h.as == "string" && h.as !== "script" ? h.as : void 0, crossOrigin: g, integrity: typeof h.integrity == "string" ? h.integrity : void 0 }) } else a.d.m(m) }, Wt.requestFormReset = function (m) { a.d.r(m) }, Wt.unstable_batchedUpdates = function (m, h) { return m(h) }, Wt.useFormState = function (m, h, g) { return d.H.useFormState(m, h, g) }, Wt.useFormStatus = function () { return d.H.useHostTransitionStatus() }, Wt.version = "19.1.1", Wt } var sv; function ew() { if (sv) return Qm.exports; sv = 1; function e() { if (!(typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ > "u" || typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.checkDCE != "function")) try { __REACT_DEVTOOLS_GLOBAL_HOOK__.checkDCE(e) } catch (n) { console.error(n) } } return e(), Qm.exports = jE(), Qm.exports }/** * @license React * react-dom-client.production.js * @@ -38,414 +38,475 @@ function yE(e,n){for(var r=0;rC||(t.current=H[C],H[C]=null,C--)}function V(t,s){C++,H[C]=t.current,t.current=s}var W=$(null),fe=$(null),ue=$(null),te=$(null);function ie(t,s){switch(V(ue,s),V(fe,t),V(W,null),s.nodeType){case 9:case 11:t=(t=s.documentElement)&&(t=t.namespaceURI)?jy(t):0;break;default:if(t=s.tagName,s=s.namespaceURI)s=jy(s),t=_y(s,t);else switch(t){case"svg":t=1;break;case"math":t=2;break;default:t=0}}Y(W),V(W,t)}function ge(){Y(W),Y(fe),Y(ue)}function be(t){t.memoizedState!==null&&V(te,t);var s=W.current,i=_y(s,t.type);s!==i&&(V(fe,t),V(W,i))}function we(t){fe.current===t&&(Y(W),Y(fe)),te.current===t&&(Y(te),zi._currentValue=I)}var ne=Object.prototype.hasOwnProperty,pe=e.unstable_scheduleCallback,he=e.unstable_cancelCallback,ee=e.unstable_shouldYield,ve=e.unstable_requestPaint,ye=e.unstable_now,Te=e.unstable_getCurrentPriorityLevel,je=e.unstable_ImmediatePriority,$e=e.unstable_UserBlockingPriority,it=e.unstable_NormalPriority,ze=e.unstable_LowPriority,Se=e.unstable_IdlePriority,Pe=e.log,Ee=e.unstable_setDisableYieldValue,He=null,Fe=null;function Nt(t){if(typeof Pe=="function"&&Ee(t),Fe&&typeof Fe.setStrictMode=="function")try{Fe.setStrictMode(He,t)}catch{}}var yt=Math.clz32?Math.clz32:xe,hs=Math.log,wo=Math.LN2;function xe(t){return t>>>=0,t===0?32:31-(hs(t)/wo|0)|0}var Re=256,Ue=4194304;function Et(t){var s=t&42;if(s!==0)return s;switch(t&-t){case 1:return 1;case 2:return 2;case 4:return 4;case 8:return 8;case 16:return 16;case 32:return 32;case 64:return 64;case 128:return 128;case 256:case 512:case 1024:case 2048:case 4096:case 8192:case 16384:case 32768:case 65536:case 131072:case 262144:case 524288:case 1048576:case 2097152:return t&4194048;case 4194304:case 8388608:case 16777216:case 33554432:return t&62914560;case 67108864:return 67108864;case 134217728:return 134217728;case 268435456:return 268435456;case 536870912:return 536870912;case 1073741824:return 0;default:return t}}function Dn(t,s,i){var u=t.pendingLanes;if(u===0)return 0;var p=0,v=t.suspendedLanes,A=t.pingedLanes;t=t.warmLanes;var z=u&134217727;return z!==0?(u=z&~v,u!==0?p=Et(u):(A&=z,A!==0?p=Et(A):i||(i=z&~t,i!==0&&(p=Et(i))))):(z=u&~v,z!==0?p=Et(z):A!==0?p=Et(A):i||(i=u&~t,i!==0&&(p=Et(i)))),p===0?0:s!==0&&s!==p&&(s&v)===0&&(v=p&-p,i=s&-s,v>=i||v===32&&(i&4194048)!==0)?s:p}function Le(t,s){return(t.pendingLanes&~(t.suspendedLanes&~t.pingedLanes)&s)===0}function Ne(t,s){switch(t){case 1:case 2:case 4:case 8:case 64:return s+250;case 16:case 32:case 128:case 256:case 512:case 1024:case 2048:case 4096:case 8192:case 16384:case 32768:case 65536:case 131072:case 262144:case 524288:case 1048576:case 2097152:return s+5e3;case 4194304:case 8388608:case 16777216:case 33554432:return-1;case 67108864:case 134217728:case 268435456:case 536870912:case 1073741824:return-1;default:return-1}}function lt(){var t=Re;return Re<<=1,(Re&4194048)===0&&(Re=256),t}function ot(){var t=Ue;return Ue<<=1,(Ue&62914560)===0&&(Ue=4194304),t}function At(t){for(var s=[],i=0;31>i;i++)s.push(t);return s}function en(t,s){t.pendingLanes|=s,s!==268435456&&(t.suspendedLanes=0,t.pingedLanes=0,t.warmLanes=0)}function On(t,s,i,u,p,v){var A=t.pendingLanes;t.pendingLanes=i,t.suspendedLanes=0,t.pingedLanes=0,t.warmLanes=0,t.expiredLanes&=i,t.entangledLanes&=i,t.errorRecoveryDisabledLanes&=i,t.shellSuspendCounter=0;var z=t.entanglements,F=t.expirationTimes,re=t.hiddenUpdates;for(i=A&~i;0)":-1p||F[u]!==re[p]){var le=` -`+F[u].replace(" at new "," at ");return t.displayName&&le.includes("")&&(le=le.replace("",t.displayName)),le}while(1<=u&&0<=p);break}}}finally{qa=!1,Error.prepareStackTrace=i}return(i=t?t.displayName||t.name:"")?ys(i):""}function Ud(t){switch(t.tag){case 26:case 27:case 5:return ys(t.type);case 16:return ys("Lazy");case 13:return ys("Suspense");case 19:return ys("SuspenseList");case 0:case 15:return Fa(t.type,!1);case 11:return Fa(t.type.render,!1);case 1:return Fa(t.type,!0);case 31:return ys("Activity");default:return""}}function Il(t){try{var s="";do s+=Ud(t),t=t.return;while(t);return s}catch(i){return` -Error generating stack: `+i.message+` -`+i.stack}}function tn(t){switch(typeof t){case"bigint":case"boolean":case"number":case"string":case"undefined":return t;case"object":return t;default:return""}}function Ll(t){var s=t.type;return(t=t.nodeName)&&t.toLowerCase()==="input"&&(s==="checkbox"||s==="radio")}function Vd(t){var s=Ll(t)?"checked":"value",i=Object.getOwnPropertyDescriptor(t.constructor.prototype,s),u=""+t[s];if(!t.hasOwnProperty(s)&&typeof i<"u"&&typeof i.get=="function"&&typeof i.set=="function"){var p=i.get,v=i.set;return Object.defineProperty(t,s,{configurable:!0,get:function(){return p.call(this)},set:function(A){u=""+A,v.call(this,A)}}),Object.defineProperty(t,s,{enumerable:i.enumerable}),{getValue:function(){return u},setValue:function(A){u=""+A},stopTracking:function(){t._valueTracker=null,delete t[s]}}}}function jo(t){t._valueTracker||(t._valueTracker=Vd(t))}function Ya(t){if(!t)return!1;var s=t._valueTracker;if(!s)return!0;var i=s.getValue(),u="";return t&&(u=Ll(t)?t.checked?"true":"false":t.value),t=u,t!==i?(s.setValue(t),!0):!1}function _o(t){if(t=t||(typeof document<"u"?document:void 0),typeof t>"u")return null;try{return t.activeElement||t.body}catch{return t.body}}var qd=/[\n"\\]/g;function nn(t){return t.replace(qd,function(s){return"\\"+s.charCodeAt(0).toString(16)+" "})}function zr(t,s,i,u,p,v,A,z){t.name="",A!=null&&typeof A!="function"&&typeof A!="symbol"&&typeof A!="boolean"?t.type=A:t.removeAttribute("type"),s!=null?A==="number"?(s===0&&t.value===""||t.value!=s)&&(t.value=""+tn(s)):t.value!==""+tn(s)&&(t.value=""+tn(s)):A!=="submit"&&A!=="reset"||t.removeAttribute("value"),s!=null?Ga(t,A,tn(s)):i!=null?Ga(t,A,tn(i)):u!=null&&t.removeAttribute("value"),p==null&&v!=null&&(t.defaultChecked=!!v),p!=null&&(t.checked=p&&typeof p!="function"&&typeof p!="symbol"),z!=null&&typeof z!="function"&&typeof z!="symbol"&&typeof z!="boolean"?t.name=""+tn(z):t.removeAttribute("name")}function Hl(t,s,i,u,p,v,A,z){if(v!=null&&typeof v!="function"&&typeof v!="symbol"&&typeof v!="boolean"&&(t.type=v),s!=null||i!=null){if(!(v!=="submit"&&v!=="reset"||s!=null))return;i=i!=null?""+tn(i):"",s=s!=null?""+tn(s):i,z||s===t.value||(t.value=s),t.defaultValue=s}u=u??p,u=typeof u!="function"&&typeof u!="symbol"&&!!u,t.checked=z?t.checked:!!u,t.defaultChecked=!!u,A!=null&&typeof A!="function"&&typeof A!="symbol"&&typeof A!="boolean"&&(t.name=A)}function Ga(t,s,i){s==="number"&&_o(t.ownerDocument)===t||t.defaultValue===""+i||(t.defaultValue=""+i)}function vs(t,s,i,u){if(t=t.options,s){s={};for(var p=0;p"u"||typeof window.document>"u"||typeof window.document.createElement>"u"),Zd=!1;if(bs)try{var Za={};Object.defineProperty(Za,"passive",{get:function(){Zd=!0}}),window.addEventListener("test",Za,Za),window.removeEventListener("test",Za,Za)}catch{Zd=!1}var Zs=null,Wd=null,Bl=null;function kg(){if(Bl)return Bl;var t,s=Wd,i=s.length,u,p="value"in Zs?Zs.value:Zs.textContent,v=p.length;for(t=0;t=Qa),Og=" ",zg=!1;function Ig(t,s){switch(t){case"keyup":return Uj.indexOf(s.keyCode)!==-1;case"keydown":return s.keyCode!==229;case"keypress":case"mousedown":case"focusout":return!0;default:return!1}}function Lg(t){return t=t.detail,typeof t=="object"&&"data"in t?t.data:null}var Ao=!1;function qj(t,s){switch(t){case"compositionend":return Lg(s);case"keypress":return s.which!==32?null:(zg=!0,Og);case"textInput":return t=s.data,t===Og&&zg?null:t;default:return null}}function Fj(t,s){if(Ao)return t==="compositionend"||!tf&&Ig(t,s)?(t=kg(),Bl=Wd=Zs=null,Ao=!1,t):null;switch(t){case"paste":return null;case"keypress":if(!(s.ctrlKey||s.altKey||s.metaKey)||s.ctrlKey&&s.altKey){if(s.char&&1=s)return{node:i,offset:s-t};t=u}e:{for(;i;){if(i.nextSibling){i=i.nextSibling;break e}i=i.parentNode}i=void 0}i=Fg(i)}}function Gg(t,s){return t&&s?t===s?!0:t&&t.nodeType===3?!1:s&&s.nodeType===3?Gg(t,s.parentNode):"contains"in t?t.contains(s):t.compareDocumentPosition?!!(t.compareDocumentPosition(s)&16):!1:!1}function Xg(t){t=t!=null&&t.ownerDocument!=null&&t.ownerDocument.defaultView!=null?t.ownerDocument.defaultView:window;for(var s=_o(t.document);s instanceof t.HTMLIFrameElement;){try{var i=typeof s.contentWindow.location.href=="string"}catch{i=!1}if(i)t=s.contentWindow;else break;s=_o(t.document)}return s}function rf(t){var s=t&&t.nodeName&&t.nodeName.toLowerCase();return s&&(s==="input"&&(t.type==="text"||t.type==="search"||t.type==="tel"||t.type==="url"||t.type==="password")||s==="textarea"||t.contentEditable==="true")}var Jj=bs&&"documentMode"in document&&11>=document.documentMode,Mo=null,of=null,ni=null,af=!1;function Zg(t,s,i){var u=i.window===i?i.document:i.nodeType===9?i:i.ownerDocument;af||Mo==null||Mo!==_o(u)||(u=Mo,"selectionStart"in u&&rf(u)?u={start:u.selectionStart,end:u.selectionEnd}:(u=(u.ownerDocument&&u.ownerDocument.defaultView||window).getSelection(),u={anchorNode:u.anchorNode,anchorOffset:u.anchorOffset,focusNode:u.focusNode,focusOffset:u.focusOffset}),ni&&ti(ni,u)||(ni=u,u=Mc(of,"onSelect"),0>=A,p-=A,Ns=1<<32-yt(s)+p|i<v?v:8;var A=k.T,z={};k.T=z,Yf(t,!1,s,i);try{var F=p(),re=k.S;if(re!==null&&re(z,F),F!==null&&typeof F=="object"&&typeof F.then=="function"){var le=l_(F,u);xi(t,s,le,hn(t))}else xi(t,s,u,hn(t))}catch(me){xi(t,s,{then:function(){},status:"rejected",reason:me},hn())}finally{L.p=v,k.T=A}}function m_(){}function qf(t,s,i,u){if(t.tag!==5)throw Error(a(476));var p=Wx(t).queue;Zx(t,p,s,I,i===null?m_:function(){return Kx(t),i(u)})}function Wx(t){var s=t.memoizedState;if(s!==null)return s;s={memoizedState:I,baseState:I,baseQueue:null,queue:{pending:null,lanes:0,dispatch:null,lastRenderedReducer:Es,lastRenderedState:I},next:null};var i={};return s.next={memoizedState:i,baseState:i,baseQueue:null,queue:{pending:null,lanes:0,dispatch:null,lastRenderedReducer:Es,lastRenderedState:i},next:null},t.memoizedState=s,t=t.alternate,t!==null&&(t.memoizedState=s),s}function Kx(t){var s=Wx(t).next.queue;xi(t,s,{},hn())}function Ff(){return Zt(zi)}function Qx(){return Rt().memoizedState}function Jx(){return Rt().memoizedState}function h_(t){for(var s=t.return;s!==null;){switch(s.tag){case 24:case 3:var i=hn();t=Qs(i);var u=Js(s,t,i);u!==null&&(pn(u,s,i),di(u,s,i)),s={cache:wf()},t.payload=s;return}s=s.return}}function p_(t,s,i){var u=hn();i={lane:u,revertLane:0,action:i,hasEagerState:!1,eagerState:null,next:null},uc(t)?t0(s,i):(i=df(t,s,i,u),i!==null&&(pn(i,t,u),n0(i,s,u)))}function e0(t,s,i){var u=hn();xi(t,s,i,u)}function xi(t,s,i,u){var p={lane:u,revertLane:0,action:i,hasEagerState:!1,eagerState:null,next:null};if(uc(t))t0(s,p);else{var v=t.alternate;if(t.lanes===0&&(v===null||v.lanes===0)&&(v=s.lastRenderedReducer,v!==null))try{var A=s.lastRenderedState,z=v(A,i);if(p.hasEagerState=!0,p.eagerState=z,cn(z,A))return Gl(t,s,p,0),gt===null&&Yl(),!1}catch{}finally{}if(i=df(t,s,p,u),i!==null)return pn(i,t,u),n0(i,s,u),!0}return!1}function Yf(t,s,i,u){if(u={lane:2,revertLane:jm(),action:u,hasEagerState:!1,eagerState:null,next:null},uc(t)){if(s)throw Error(a(479))}else s=df(t,i,u,2),s!==null&&pn(s,t,2)}function uc(t){var s=t.alternate;return t===Ze||s!==null&&s===Ze}function t0(t,s){Bo=rc=!0;var i=t.pending;i===null?s.next=s:(s.next=i.next,i.next=s),t.pending=s}function n0(t,s,i){if((i&4194048)!==0){var u=s.lanes;u&=t.pendingLanes,i|=u,s.lanes=i,La(t,i)}}var dc={readContext:Zt,use:ac,useCallback:Ct,useContext:Ct,useEffect:Ct,useImperativeHandle:Ct,useLayoutEffect:Ct,useInsertionEffect:Ct,useMemo:Ct,useReducer:Ct,useRef:Ct,useState:Ct,useDebugValue:Ct,useDeferredValue:Ct,useTransition:Ct,useSyncExternalStore:Ct,useId:Ct,useHostTransitionStatus:Ct,useFormState:Ct,useActionState:Ct,useOptimistic:Ct,useMemoCache:Ct,useCacheRefresh:Ct},s0={readContext:Zt,use:ac,useCallback:function(t,s){return rn().memoizedState=[t,s===void 0?null:s],t},useContext:Zt,useEffect:Bx,useImperativeHandle:function(t,s,i){i=i!=null?i.concat([t]):null,cc(4194308,4,qx.bind(null,s,t),i)},useLayoutEffect:function(t,s){return cc(4194308,4,t,s)},useInsertionEffect:function(t,s){cc(4,2,t,s)},useMemo:function(t,s){var i=rn();s=s===void 0?null:s;var u=t();if(Gr){Nt(!0);try{t()}finally{Nt(!1)}}return i.memoizedState=[u,s],u},useReducer:function(t,s,i){var u=rn();if(i!==void 0){var p=i(s);if(Gr){Nt(!0);try{i(s)}finally{Nt(!1)}}}else p=s;return u.memoizedState=u.baseState=p,t={pending:null,lanes:0,dispatch:null,lastRenderedReducer:t,lastRenderedState:p},u.queue=t,t=t.dispatch=p_.bind(null,Ze,t),[u.memoizedState,t]},useRef:function(t){var s=rn();return t={current:t},s.memoizedState=t},useState:function(t){t=Bf(t);var s=t.queue,i=e0.bind(null,Ze,s);return s.dispatch=i,[t.memoizedState,i]},useDebugValue:Uf,useDeferredValue:function(t,s){var i=rn();return Vf(i,t,s)},useTransition:function(){var t=Bf(!1);return t=Zx.bind(null,Ze,t.queue,!0,!1),rn().memoizedState=t,[!1,t]},useSyncExternalStore:function(t,s,i){var u=Ze,p=rn();if(at){if(i===void 0)throw Error(a(407));i=i()}else{if(i=s(),gt===null)throw Error(a(349));(nt&124)!==0||jx(u,s,i)}p.memoizedState=i;var v={value:i,getSnapshot:s};return p.queue=v,Bx(Ex.bind(null,u,v,t),[t]),u.flags|=2048,Uo(9,lc(),_x.bind(null,u,v,i,s),null),i},useId:function(){var t=rn(),s=gt.identifierPrefix;if(at){var i=Ss,u=Ns;i=(u&~(1<<32-yt(u)-1)).toString(32)+i,s="«"+s+"R"+i,i=oc++,0Ve?(Pt=Oe,Oe=null):Pt=Oe.sibling;var rt=oe(Q,Oe,se[Ve],ce);if(rt===null){Oe===null&&(Oe=Pt);break}t&&Oe&&rt.alternate===null&&s(Q,Oe),X=v(rt,X,Ve),Ke===null?Ce=rt:Ke.sibling=rt,Ke=rt,Oe=Pt}if(Ve===se.length)return i(Q,Oe),at&&Pr(Q,Ve),Ce;if(Oe===null){for(;VeVe?(Pt=Oe,Oe=null):Pt=Oe.sibling;var gr=oe(Q,Oe,rt.value,ce);if(gr===null){Oe===null&&(Oe=Pt);break}t&&Oe&&gr.alternate===null&&s(Q,Oe),X=v(gr,X,Ve),Ke===null?Ce=gr:Ke.sibling=gr,Ke=gr,Oe=Pt}if(rt.done)return i(Q,Oe),at&&Pr(Q,Ve),Ce;if(Oe===null){for(;!rt.done;Ve++,rt=se.next())rt=me(Q,rt.value,ce),rt!==null&&(X=v(rt,X,Ve),Ke===null?Ce=rt:Ke.sibling=rt,Ke=rt);return at&&Pr(Q,Ve),Ce}for(Oe=u(Oe);!rt.done;Ve++,rt=se.next())rt=ae(Oe,Q,Ve,rt.value,ce),rt!==null&&(t&&rt.alternate!==null&&Oe.delete(rt.key===null?Ve:rt.key),X=v(rt,X,Ve),Ke===null?Ce=rt:Ke.sibling=rt,Ke=rt);return t&&Oe.forEach(function(xE){return s(Q,xE)}),at&&Pr(Q,Ve),Ce}function mt(Q,X,se,ce){if(typeof se=="object"&&se!==null&&se.type===S&&se.key===null&&(se=se.props.children),typeof se=="object"&&se!==null){switch(se.$$typeof){case x:e:{for(var Ce=se.key;X!==null;){if(X.key===Ce){if(Ce=se.type,Ce===S){if(X.tag===7){i(Q,X.sibling),ce=p(X,se.props.children),ce.return=Q,Q=ce;break e}}else if(X.elementType===Ce||typeof Ce=="object"&&Ce!==null&&Ce.$$typeof===B&&o0(Ce)===X.type){i(Q,X.sibling),ce=p(X,se.props),vi(ce,se),ce.return=Q,Q=ce;break e}i(Q,X);break}else s(Q,X);X=X.sibling}se.type===S?(ce=$r(se.props.children,Q.mode,ce,se.key),ce.return=Q,Q=ce):(ce=Zl(se.type,se.key,se.props,null,Q.mode,ce),vi(ce,se),ce.return=Q,Q=ce)}return A(Q);case b:e:{for(Ce=se.key;X!==null;){if(X.key===Ce)if(X.tag===4&&X.stateNode.containerInfo===se.containerInfo&&X.stateNode.implementation===se.implementation){i(Q,X.sibling),ce=p(X,se.children||[]),ce.return=Q,Q=ce;break e}else{i(Q,X);break}else s(Q,X);X=X.sibling}ce=hf(se,Q.mode,ce),ce.return=Q,Q=ce}return A(Q);case B:return Ce=se._init,se=Ce(se._payload),mt(Q,X,se,ce)}if(U(se))return qe(Q,X,se,ce);if(G(se)){if(Ce=G(se),typeof Ce!="function")throw Error(a(150));return se=Ce.call(se),Be(Q,X,se,ce)}if(typeof se.then=="function")return mt(Q,X,fc(se),ce);if(se.$$typeof===E)return mt(Q,X,Jl(Q,se),ce);mc(Q,se)}return typeof se=="string"&&se!==""||typeof se=="number"||typeof se=="bigint"?(se=""+se,X!==null&&X.tag===6?(i(Q,X.sibling),ce=p(X,se),ce.return=Q,Q=ce):(i(Q,X),ce=mf(se,Q.mode,ce),ce.return=Q,Q=ce),A(Q)):i(Q,X)}return function(Q,X,se,ce){try{yi=0;var Ce=mt(Q,X,se,ce);return Vo=null,Ce}catch(Oe){if(Oe===ci||Oe===tc)throw Oe;var Ke=un(29,Oe,null,Q.mode);return Ke.lanes=ce,Ke.return=Q,Ke}finally{}}}var qo=a0(!0),i0=a0(!1),En=$(null),Wn=null;function tr(t){var s=t.alternate;V(zt,zt.current&1),V(En,t),Wn===null&&(s===null||$o.current!==null||s.memoizedState!==null)&&(Wn=t)}function l0(t){if(t.tag===22){if(V(zt,zt.current),V(En,t),Wn===null){var s=t.alternate;s!==null&&s.memoizedState!==null&&(Wn=t)}}else nr()}function nr(){V(zt,zt.current),V(En,En.current)}function Cs(t){Y(En),Wn===t&&(Wn=null),Y(zt)}var zt=$(0);function hc(t){for(var s=t;s!==null;){if(s.tag===13){var i=s.memoizedState;if(i!==null&&(i=i.dehydrated,i===null||i.data==="$?"||Im(i)))return s}else if(s.tag===19&&s.memoizedProps.revealOrder!==void 0){if((s.flags&128)!==0)return s}else if(s.child!==null){s.child.return=s,s=s.child;continue}if(s===t)break;for(;s.sibling===null;){if(s.return===null||s.return===t)return null;s=s.return}s.sibling.return=s.return,s=s.sibling}return null}function Gf(t,s,i,u){s=t.memoizedState,i=i(u,s),i=i==null?s:g({},s,i),t.memoizedState=i,t.lanes===0&&(t.updateQueue.baseState=i)}var Xf={enqueueSetState:function(t,s,i){t=t._reactInternals;var u=hn(),p=Qs(u);p.payload=s,i!=null&&(p.callback=i),s=Js(t,p,u),s!==null&&(pn(s,t,u),di(s,t,u))},enqueueReplaceState:function(t,s,i){t=t._reactInternals;var u=hn(),p=Qs(u);p.tag=1,p.payload=s,i!=null&&(p.callback=i),s=Js(t,p,u),s!==null&&(pn(s,t,u),di(s,t,u))},enqueueForceUpdate:function(t,s){t=t._reactInternals;var i=hn(),u=Qs(i);u.tag=2,s!=null&&(u.callback=s),s=Js(t,u,i),s!==null&&(pn(s,t,i),di(s,t,i))}};function c0(t,s,i,u,p,v,A){return t=t.stateNode,typeof t.shouldComponentUpdate=="function"?t.shouldComponentUpdate(u,v,A):s.prototype&&s.prototype.isPureReactComponent?!ti(i,u)||!ti(p,v):!0}function u0(t,s,i,u){t=s.state,typeof s.componentWillReceiveProps=="function"&&s.componentWillReceiveProps(i,u),typeof s.UNSAFE_componentWillReceiveProps=="function"&&s.UNSAFE_componentWillReceiveProps(i,u),s.state!==t&&Xf.enqueueReplaceState(s,s.state,null)}function Xr(t,s){var i=s;if("ref"in s){i={};for(var u in s)u!=="ref"&&(i[u]=s[u])}if(t=t.defaultProps){i===s&&(i=g({},i));for(var p in t)i[p]===void 0&&(i[p]=t[p])}return i}var pc=typeof reportError=="function"?reportError:function(t){if(typeof window=="object"&&typeof window.ErrorEvent=="function"){var s=new window.ErrorEvent("error",{bubbles:!0,cancelable:!0,message:typeof t=="object"&&t!==null&&typeof t.message=="string"?String(t.message):String(t),error:t});if(!window.dispatchEvent(s))return}else if(typeof process=="object"&&typeof process.emit=="function"){process.emit("uncaughtException",t);return}console.error(t)};function d0(t){pc(t)}function f0(t){console.error(t)}function m0(t){pc(t)}function gc(t,s){try{var i=t.onUncaughtError;i(s.value,{componentStack:s.stack})}catch(u){setTimeout(function(){throw u})}}function h0(t,s,i){try{var u=t.onCaughtError;u(i.value,{componentStack:i.stack,errorBoundary:s.tag===1?s.stateNode:null})}catch(p){setTimeout(function(){throw p})}}function Zf(t,s,i){return i=Qs(i),i.tag=3,i.payload={element:null},i.callback=function(){gc(t,s)},i}function p0(t){return t=Qs(t),t.tag=3,t}function g0(t,s,i,u){var p=i.type.getDerivedStateFromError;if(typeof p=="function"){var v=u.value;t.payload=function(){return p(v)},t.callback=function(){h0(s,i,u)}}var A=i.stateNode;A!==null&&typeof A.componentDidCatch=="function"&&(t.callback=function(){h0(s,i,u),typeof p!="function"&&(lr===null?lr=new Set([this]):lr.add(this));var z=u.stack;this.componentDidCatch(u.value,{componentStack:z!==null?z:""})})}function x_(t,s,i,u,p){if(i.flags|=32768,u!==null&&typeof u=="object"&&typeof u.then=="function"){if(s=i.alternate,s!==null&&ai(s,i,p,!0),i=En.current,i!==null){switch(i.tag){case 13:return Wn===null?vm():i.alternate===null&&_t===0&&(_t=3),i.flags&=-257,i.flags|=65536,i.lanes=p,u===jf?i.flags|=16384:(s=i.updateQueue,s===null?i.updateQueue=new Set([u]):s.add(u),wm(t,u,p)),!1;case 22:return i.flags|=65536,u===jf?i.flags|=16384:(s=i.updateQueue,s===null?(s={transitions:null,markerInstances:null,retryQueue:new Set([u])},i.updateQueue=s):(i=s.retryQueue,i===null?s.retryQueue=new Set([u]):i.add(u)),wm(t,u,p)),!1}throw Error(a(435,i.tag))}return wm(t,u,p),vm(),!1}if(at)return s=En.current,s!==null?((s.flags&65536)===0&&(s.flags|=256),s.flags|=65536,s.lanes=p,u!==xf&&(t=Error(a(422),{cause:u}),oi(Nn(t,i)))):(u!==xf&&(s=Error(a(423),{cause:u}),oi(Nn(s,i))),t=t.current.alternate,t.flags|=65536,p&=-p,t.lanes|=p,u=Nn(u,i),p=Zf(t.stateNode,u,p),Cf(t,p),_t!==4&&(_t=2)),!1;var v=Error(a(520),{cause:u});if(v=Nn(v,i),Ei===null?Ei=[v]:Ei.push(v),_t!==4&&(_t=2),s===null)return!0;u=Nn(u,i),i=s;do{switch(i.tag){case 3:return i.flags|=65536,t=p&-p,i.lanes|=t,t=Zf(i.stateNode,u,t),Cf(i,t),!1;case 1:if(s=i.type,v=i.stateNode,(i.flags&128)===0&&(typeof s.getDerivedStateFromError=="function"||v!==null&&typeof v.componentDidCatch=="function"&&(lr===null||!lr.has(v))))return i.flags|=65536,p&=-p,i.lanes|=p,p=p0(p),g0(p,t,i,u),Cf(i,p),!1}i=i.return}while(i!==null);return!1}var x0=Error(a(461)),$t=!1;function Vt(t,s,i,u){s.child=t===null?i0(s,null,i,u):qo(s,t.child,i,u)}function y0(t,s,i,u,p){i=i.render;var v=s.ref;if("ref"in u){var A={};for(var z in u)z!=="ref"&&(A[z]=u[z])}else A=u;return Fr(s),u=Rf(t,s,i,A,v,p),z=Df(),t!==null&&!$t?(Of(t,s,p),ks(t,s,p)):(at&&z&&pf(s),s.flags|=1,Vt(t,s,u,p),s.child)}function v0(t,s,i,u,p){if(t===null){var v=i.type;return typeof v=="function"&&!ff(v)&&v.defaultProps===void 0&&i.compare===null?(s.tag=15,s.type=v,b0(t,s,v,u,p)):(t=Zl(i.type,null,u,s,s.mode,p),t.ref=s.ref,t.return=s,s.child=t)}if(v=t.child,!sm(t,p)){var A=v.memoizedProps;if(i=i.compare,i=i!==null?i:ti,i(A,u)&&t.ref===s.ref)return ks(t,s,p)}return s.flags|=1,t=ws(v,u),t.ref=s.ref,t.return=s,s.child=t}function b0(t,s,i,u,p){if(t!==null){var v=t.memoizedProps;if(ti(v,u)&&t.ref===s.ref)if($t=!1,s.pendingProps=u=v,sm(t,p))(t.flags&131072)!==0&&($t=!0);else return s.lanes=t.lanes,ks(t,s,p)}return Wf(t,s,i,u,p)}function w0(t,s,i){var u=s.pendingProps,p=u.children,v=t!==null?t.memoizedState:null;if(u.mode==="hidden"){if((s.flags&128)!==0){if(u=v!==null?v.baseLanes|i:i,t!==null){for(p=s.child=t.child,v=0;p!==null;)v=v|p.lanes|p.childLanes,p=p.sibling;s.childLanes=v&~u}else s.childLanes=0,s.child=null;return N0(t,s,u,i)}if((i&536870912)!==0)s.memoizedState={baseLanes:0,cachePool:null},t!==null&&ec(s,v!==null?v.cachePool:null),v!==null?bx(s,v):Af(),l0(s);else return s.lanes=s.childLanes=536870912,N0(t,s,v!==null?v.baseLanes|i:i,i)}else v!==null?(ec(s,v.cachePool),bx(s,v),nr(),s.memoizedState=null):(t!==null&&ec(s,null),Af(),nr());return Vt(t,s,p,i),s.child}function N0(t,s,i,u){var p=Sf();return p=p===null?null:{parent:Ot._currentValue,pool:p},s.memoizedState={baseLanes:i,cachePool:p},t!==null&&ec(s,null),Af(),l0(s),t!==null&&ai(t,s,u,!0),null}function xc(t,s){var i=s.ref;if(i===null)t!==null&&t.ref!==null&&(s.flags|=4194816);else{if(typeof i!="function"&&typeof i!="object")throw Error(a(284));(t===null||t.ref!==i)&&(s.flags|=4194816)}}function Wf(t,s,i,u,p){return Fr(s),i=Rf(t,s,i,u,void 0,p),u=Df(),t!==null&&!$t?(Of(t,s,p),ks(t,s,p)):(at&&u&&pf(s),s.flags|=1,Vt(t,s,i,p),s.child)}function S0(t,s,i,u,p,v){return Fr(s),s.updateQueue=null,i=Nx(s,u,i,p),wx(t),u=Df(),t!==null&&!$t?(Of(t,s,v),ks(t,s,v)):(at&&u&&pf(s),s.flags|=1,Vt(t,s,i,v),s.child)}function j0(t,s,i,u,p){if(Fr(s),s.stateNode===null){var v=Oo,A=i.contextType;typeof A=="object"&&A!==null&&(v=Zt(A)),v=new i(u,v),s.memoizedState=v.state!==null&&v.state!==void 0?v.state:null,v.updater=Xf,s.stateNode=v,v._reactInternals=s,v=s.stateNode,v.props=u,v.state=s.memoizedState,v.refs={},_f(s),A=i.contextType,v.context=typeof A=="object"&&A!==null?Zt(A):Oo,v.state=s.memoizedState,A=i.getDerivedStateFromProps,typeof A=="function"&&(Gf(s,i,A,u),v.state=s.memoizedState),typeof i.getDerivedStateFromProps=="function"||typeof v.getSnapshotBeforeUpdate=="function"||typeof v.UNSAFE_componentWillMount!="function"&&typeof v.componentWillMount!="function"||(A=v.state,typeof v.componentWillMount=="function"&&v.componentWillMount(),typeof v.UNSAFE_componentWillMount=="function"&&v.UNSAFE_componentWillMount(),A!==v.state&&Xf.enqueueReplaceState(v,v.state,null),mi(s,u,v,p),fi(),v.state=s.memoizedState),typeof v.componentDidMount=="function"&&(s.flags|=4194308),u=!0}else if(t===null){v=s.stateNode;var z=s.memoizedProps,F=Xr(i,z);v.props=F;var re=v.context,le=i.contextType;A=Oo,typeof le=="object"&&le!==null&&(A=Zt(le));var me=i.getDerivedStateFromProps;le=typeof me=="function"||typeof v.getSnapshotBeforeUpdate=="function",z=s.pendingProps!==z,le||typeof v.UNSAFE_componentWillReceiveProps!="function"&&typeof v.componentWillReceiveProps!="function"||(z||re!==A)&&u0(s,v,u,A),Ks=!1;var oe=s.memoizedState;v.state=oe,mi(s,u,v,p),fi(),re=s.memoizedState,z||oe!==re||Ks?(typeof me=="function"&&(Gf(s,i,me,u),re=s.memoizedState),(F=Ks||c0(s,i,F,u,oe,re,A))?(le||typeof v.UNSAFE_componentWillMount!="function"&&typeof v.componentWillMount!="function"||(typeof v.componentWillMount=="function"&&v.componentWillMount(),typeof v.UNSAFE_componentWillMount=="function"&&v.UNSAFE_componentWillMount()),typeof v.componentDidMount=="function"&&(s.flags|=4194308)):(typeof v.componentDidMount=="function"&&(s.flags|=4194308),s.memoizedProps=u,s.memoizedState=re),v.props=u,v.state=re,v.context=A,u=F):(typeof v.componentDidMount=="function"&&(s.flags|=4194308),u=!1)}else{v=s.stateNode,Ef(t,s),A=s.memoizedProps,le=Xr(i,A),v.props=le,me=s.pendingProps,oe=v.context,re=i.contextType,F=Oo,typeof re=="object"&&re!==null&&(F=Zt(re)),z=i.getDerivedStateFromProps,(re=typeof z=="function"||typeof v.getSnapshotBeforeUpdate=="function")||typeof v.UNSAFE_componentWillReceiveProps!="function"&&typeof v.componentWillReceiveProps!="function"||(A!==me||oe!==F)&&u0(s,v,u,F),Ks=!1,oe=s.memoizedState,v.state=oe,mi(s,u,v,p),fi();var ae=s.memoizedState;A!==me||oe!==ae||Ks||t!==null&&t.dependencies!==null&&Ql(t.dependencies)?(typeof z=="function"&&(Gf(s,i,z,u),ae=s.memoizedState),(le=Ks||c0(s,i,le,u,oe,ae,F)||t!==null&&t.dependencies!==null&&Ql(t.dependencies))?(re||typeof v.UNSAFE_componentWillUpdate!="function"&&typeof v.componentWillUpdate!="function"||(typeof v.componentWillUpdate=="function"&&v.componentWillUpdate(u,ae,F),typeof v.UNSAFE_componentWillUpdate=="function"&&v.UNSAFE_componentWillUpdate(u,ae,F)),typeof v.componentDidUpdate=="function"&&(s.flags|=4),typeof v.getSnapshotBeforeUpdate=="function"&&(s.flags|=1024)):(typeof v.componentDidUpdate!="function"||A===t.memoizedProps&&oe===t.memoizedState||(s.flags|=4),typeof v.getSnapshotBeforeUpdate!="function"||A===t.memoizedProps&&oe===t.memoizedState||(s.flags|=1024),s.memoizedProps=u,s.memoizedState=ae),v.props=u,v.state=ae,v.context=F,u=le):(typeof v.componentDidUpdate!="function"||A===t.memoizedProps&&oe===t.memoizedState||(s.flags|=4),typeof v.getSnapshotBeforeUpdate!="function"||A===t.memoizedProps&&oe===t.memoizedState||(s.flags|=1024),u=!1)}return v=u,xc(t,s),u=(s.flags&128)!==0,v||u?(v=s.stateNode,i=u&&typeof i.getDerivedStateFromError!="function"?null:v.render(),s.flags|=1,t!==null&&u?(s.child=qo(s,t.child,null,p),s.child=qo(s,null,i,p)):Vt(t,s,i,p),s.memoizedState=v.state,t=s.child):t=ks(t,s,p),t}function _0(t,s,i,u){return ri(),s.flags|=256,Vt(t,s,i,u),s.child}var Kf={dehydrated:null,treeContext:null,retryLane:0,hydrationErrors:null};function Qf(t){return{baseLanes:t,cachePool:fx()}}function Jf(t,s,i){return t=t!==null?t.childLanes&~i:0,s&&(t|=Cn),t}function E0(t,s,i){var u=s.pendingProps,p=!1,v=(s.flags&128)!==0,A;if((A=v)||(A=t!==null&&t.memoizedState===null?!1:(zt.current&2)!==0),A&&(p=!0,s.flags&=-129),A=(s.flags&32)!==0,s.flags&=-33,t===null){if(at){if(p?tr(s):nr(),at){var z=jt,F;if(F=z){e:{for(F=z,z=Zn;F.nodeType!==8;){if(!z){z=null;break e}if(F=Hn(F.nextSibling),F===null){z=null;break e}}z=F}z!==null?(s.memoizedState={dehydrated:z,treeContext:Br!==null?{id:Ns,overflow:Ss}:null,retryLane:536870912,hydrationErrors:null},F=un(18,null,null,0),F.stateNode=z,F.return=s,s.child=F,Kt=s,jt=null,F=!0):F=!1}F||Vr(s)}if(z=s.memoizedState,z!==null&&(z=z.dehydrated,z!==null))return Im(z)?s.lanes=32:s.lanes=536870912,null;Cs(s)}return z=u.children,u=u.fallback,p?(nr(),p=s.mode,z=yc({mode:"hidden",children:z},p),u=$r(u,p,i,null),z.return=s,u.return=s,z.sibling=u,s.child=z,p=s.child,p.memoizedState=Qf(i),p.childLanes=Jf(t,A,i),s.memoizedState=Kf,u):(tr(s),em(s,z))}if(F=t.memoizedState,F!==null&&(z=F.dehydrated,z!==null)){if(v)s.flags&256?(tr(s),s.flags&=-257,s=tm(t,s,i)):s.memoizedState!==null?(nr(),s.child=t.child,s.flags|=128,s=null):(nr(),p=u.fallback,z=s.mode,u=yc({mode:"visible",children:u.children},z),p=$r(p,z,i,null),p.flags|=2,u.return=s,p.return=s,u.sibling=p,s.child=u,qo(s,t.child,null,i),u=s.child,u.memoizedState=Qf(i),u.childLanes=Jf(t,A,i),s.memoizedState=Kf,s=p);else if(tr(s),Im(z)){if(A=z.nextSibling&&z.nextSibling.dataset,A)var re=A.dgst;A=re,u=Error(a(419)),u.stack="",u.digest=A,oi({value:u,source:null,stack:null}),s=tm(t,s,i)}else if($t||ai(t,s,i,!1),A=(i&t.childLanes)!==0,$t||A){if(A=gt,A!==null&&(u=i&-i,u=(u&42)!==0?1:Ha(u),u=(u&(A.suspendedLanes|i))!==0?0:u,u!==0&&u!==F.retryLane))throw F.retryLane=u,Do(t,u),pn(A,t,u),x0;z.data==="$?"||vm(),s=tm(t,s,i)}else z.data==="$?"?(s.flags|=192,s.child=t.child,s=null):(t=F.treeContext,jt=Hn(z.nextSibling),Kt=s,at=!0,Ur=null,Zn=!1,t!==null&&(jn[_n++]=Ns,jn[_n++]=Ss,jn[_n++]=Br,Ns=t.id,Ss=t.overflow,Br=s),s=em(s,u.children),s.flags|=4096);return s}return p?(nr(),p=u.fallback,z=s.mode,F=t.child,re=F.sibling,u=ws(F,{mode:"hidden",children:u.children}),u.subtreeFlags=F.subtreeFlags&65011712,re!==null?p=ws(re,p):(p=$r(p,z,i,null),p.flags|=2),p.return=s,u.return=s,u.sibling=p,s.child=u,u=p,p=s.child,z=t.child.memoizedState,z===null?z=Qf(i):(F=z.cachePool,F!==null?(re=Ot._currentValue,F=F.parent!==re?{parent:re,pool:re}:F):F=fx(),z={baseLanes:z.baseLanes|i,cachePool:F}),p.memoizedState=z,p.childLanes=Jf(t,A,i),s.memoizedState=Kf,u):(tr(s),i=t.child,t=i.sibling,i=ws(i,{mode:"visible",children:u.children}),i.return=s,i.sibling=null,t!==null&&(A=s.deletions,A===null?(s.deletions=[t],s.flags|=16):A.push(t)),s.child=i,s.memoizedState=null,i)}function em(t,s){return s=yc({mode:"visible",children:s},t.mode),s.return=t,t.child=s}function yc(t,s){return t=un(22,t,null,s),t.lanes=0,t.stateNode={_visibility:1,_pendingMarkers:null,_retryCache:null,_transitions:null},t}function tm(t,s,i){return qo(s,t.child,null,i),t=em(s,s.pendingProps.children),t.flags|=2,s.memoizedState=null,t}function C0(t,s,i){t.lanes|=s;var u=t.alternate;u!==null&&(u.lanes|=s),vf(t.return,s,i)}function nm(t,s,i,u,p){var v=t.memoizedState;v===null?t.memoizedState={isBackwards:s,rendering:null,renderingStartTime:0,last:u,tail:i,tailMode:p}:(v.isBackwards=s,v.rendering=null,v.renderingStartTime=0,v.last=u,v.tail=i,v.tailMode=p)}function k0(t,s,i){var u=s.pendingProps,p=u.revealOrder,v=u.tail;if(Vt(t,s,u.children,i),u=zt.current,(u&2)!==0)u=u&1|2,s.flags|=128;else{if(t!==null&&(t.flags&128)!==0)e:for(t=s.child;t!==null;){if(t.tag===13)t.memoizedState!==null&&C0(t,i,s);else if(t.tag===19)C0(t,i,s);else if(t.child!==null){t.child.return=t,t=t.child;continue}if(t===s)break e;for(;t.sibling===null;){if(t.return===null||t.return===s)break e;t=t.return}t.sibling.return=t.return,t=t.sibling}u&=1}switch(V(zt,u),p){case"forwards":for(i=s.child,p=null;i!==null;)t=i.alternate,t!==null&&hc(t)===null&&(p=i),i=i.sibling;i=p,i===null?(p=s.child,s.child=null):(p=i.sibling,i.sibling=null),nm(s,!1,p,i,v);break;case"backwards":for(i=null,p=s.child,s.child=null;p!==null;){if(t=p.alternate,t!==null&&hc(t)===null){s.child=p;break}t=p.sibling,p.sibling=i,i=p,p=t}nm(s,!0,i,null,v);break;case"together":nm(s,!1,null,null,void 0);break;default:s.memoizedState=null}return s.child}function ks(t,s,i){if(t!==null&&(s.dependencies=t.dependencies),ir|=s.lanes,(i&s.childLanes)===0)if(t!==null){if(ai(t,s,i,!1),(i&s.childLanes)===0)return null}else return null;if(t!==null&&s.child!==t.child)throw Error(a(153));if(s.child!==null){for(t=s.child,i=ws(t,t.pendingProps),s.child=i,i.return=s;t.sibling!==null;)t=t.sibling,i=i.sibling=ws(t,t.pendingProps),i.return=s;i.sibling=null}return s.child}function sm(t,s){return(t.lanes&s)!==0?!0:(t=t.dependencies,!!(t!==null&&Ql(t)))}function y_(t,s,i){switch(s.tag){case 3:ie(s,s.stateNode.containerInfo),Ws(s,Ot,t.memoizedState.cache),ri();break;case 27:case 5:be(s);break;case 4:ie(s,s.stateNode.containerInfo);break;case 10:Ws(s,s.type,s.memoizedProps.value);break;case 13:var u=s.memoizedState;if(u!==null)return u.dehydrated!==null?(tr(s),s.flags|=128,null):(i&s.child.childLanes)!==0?E0(t,s,i):(tr(s),t=ks(t,s,i),t!==null?t.sibling:null);tr(s);break;case 19:var p=(t.flags&128)!==0;if(u=(i&s.childLanes)!==0,u||(ai(t,s,i,!1),u=(i&s.childLanes)!==0),p){if(u)return k0(t,s,i);s.flags|=128}if(p=s.memoizedState,p!==null&&(p.rendering=null,p.tail=null,p.lastEffect=null),V(zt,zt.current),u)break;return null;case 22:case 23:return s.lanes=0,w0(t,s,i);case 24:Ws(s,Ot,t.memoizedState.cache)}return ks(t,s,i)}function A0(t,s,i){if(t!==null)if(t.memoizedProps!==s.pendingProps)$t=!0;else{if(!sm(t,i)&&(s.flags&128)===0)return $t=!1,y_(t,s,i);$t=(t.flags&131072)!==0}else $t=!1,at&&(s.flags&1048576)!==0&&ox(s,Kl,s.index);switch(s.lanes=0,s.tag){case 16:e:{t=s.pendingProps;var u=s.elementType,p=u._init;if(u=p(u._payload),s.type=u,typeof u=="function")ff(u)?(t=Xr(u,t),s.tag=1,s=j0(null,s,u,t,i)):(s.tag=0,s=Wf(null,s,u,t,i));else{if(u!=null){if(p=u.$$typeof,p===T){s.tag=11,s=y0(null,s,u,t,i);break e}else if(p===O){s.tag=14,s=v0(null,s,u,t,i);break e}}throw s=P(u)||u,Error(a(306,s,""))}}return s;case 0:return Wf(t,s,s.type,s.pendingProps,i);case 1:return u=s.type,p=Xr(u,s.pendingProps),j0(t,s,u,p,i);case 3:e:{if(ie(s,s.stateNode.containerInfo),t===null)throw Error(a(387));u=s.pendingProps;var v=s.memoizedState;p=v.element,Ef(t,s),mi(s,u,null,i);var A=s.memoizedState;if(u=A.cache,Ws(s,Ot,u),u!==v.cache&&bf(s,[Ot],i,!0),fi(),u=A.element,v.isDehydrated)if(v={element:u,isDehydrated:!1,cache:A.cache},s.updateQueue.baseState=v,s.memoizedState=v,s.flags&256){s=_0(t,s,u,i);break e}else if(u!==p){p=Nn(Error(a(424)),s),oi(p),s=_0(t,s,u,i);break e}else{switch(t=s.stateNode.containerInfo,t.nodeType){case 9:t=t.body;break;default:t=t.nodeName==="HTML"?t.ownerDocument.body:t}for(jt=Hn(t.firstChild),Kt=s,at=!0,Ur=null,Zn=!0,i=i0(s,null,u,i),s.child=i;i;)i.flags=i.flags&-3|4096,i=i.sibling}else{if(ri(),u===p){s=ks(t,s,i);break e}Vt(t,s,u,i)}s=s.child}return s;case 26:return xc(t,s),t===null?(i=Dy(s.type,null,s.pendingProps,null))?s.memoizedState=i:at||(i=s.type,t=s.pendingProps,u=Rc(ue.current).createElement(i),u[Ht]=s,u[Xt]=t,Ft(u,i,t),Mt(u),s.stateNode=u):s.memoizedState=Dy(s.type,t.memoizedProps,s.pendingProps,t.memoizedState),null;case 27:return be(s),t===null&&at&&(u=s.stateNode=My(s.type,s.pendingProps,ue.current),Kt=s,Zn=!0,p=jt,dr(s.type)?(Lm=p,jt=Hn(u.firstChild)):jt=p),Vt(t,s,s.pendingProps.children,i),xc(t,s),t===null&&(s.flags|=4194304),s.child;case 5:return t===null&&at&&((p=u=jt)&&(u=Y_(u,s.type,s.pendingProps,Zn),u!==null?(s.stateNode=u,Kt=s,jt=Hn(u.firstChild),Zn=!1,p=!0):p=!1),p||Vr(s)),be(s),p=s.type,v=s.pendingProps,A=t!==null?t.memoizedProps:null,u=v.children,Dm(p,v)?u=null:A!==null&&Dm(p,A)&&(s.flags|=32),s.memoizedState!==null&&(p=Rf(t,s,u_,null,null,i),zi._currentValue=p),xc(t,s),Vt(t,s,u,i),s.child;case 6:return t===null&&at&&((t=i=jt)&&(i=G_(i,s.pendingProps,Zn),i!==null?(s.stateNode=i,Kt=s,jt=null,t=!0):t=!1),t||Vr(s)),null;case 13:return E0(t,s,i);case 4:return ie(s,s.stateNode.containerInfo),u=s.pendingProps,t===null?s.child=qo(s,null,u,i):Vt(t,s,u,i),s.child;case 11:return y0(t,s,s.type,s.pendingProps,i);case 7:return Vt(t,s,s.pendingProps,i),s.child;case 8:return Vt(t,s,s.pendingProps.children,i),s.child;case 12:return Vt(t,s,s.pendingProps.children,i),s.child;case 10:return u=s.pendingProps,Ws(s,s.type,u.value),Vt(t,s,u.children,i),s.child;case 9:return p=s.type._context,u=s.pendingProps.children,Fr(s),p=Zt(p),u=u(p),s.flags|=1,Vt(t,s,u,i),s.child;case 14:return v0(t,s,s.type,s.pendingProps,i);case 15:return b0(t,s,s.type,s.pendingProps,i);case 19:return k0(t,s,i);case 31:return u=s.pendingProps,i=s.mode,u={mode:u.mode,children:u.children},t===null?(i=yc(u,i),i.ref=s.ref,s.child=i,i.return=s,s=i):(i=ws(t.child,u),i.ref=s.ref,s.child=i,i.return=s,s=i),s;case 22:return w0(t,s,i);case 24:return Fr(s),u=Zt(Ot),t===null?(p=Sf(),p===null&&(p=gt,v=wf(),p.pooledCache=v,v.refCount++,v!==null&&(p.pooledCacheLanes|=i),p=v),s.memoizedState={parent:u,cache:p},_f(s),Ws(s,Ot,p)):((t.lanes&i)!==0&&(Ef(t,s),mi(s,null,null,i),fi()),p=t.memoizedState,v=s.memoizedState,p.parent!==u?(p={parent:u,cache:u},s.memoizedState=p,s.lanes===0&&(s.memoizedState=s.updateQueue.baseState=p),Ws(s,Ot,u)):(u=v.cache,Ws(s,Ot,u),u!==p.cache&&bf(s,[Ot],i,!0))),Vt(t,s,s.pendingProps.children,i),s.child;case 29:throw s.pendingProps}throw Error(a(156,s.tag))}function As(t){t.flags|=4}function M0(t,s){if(s.type!=="stylesheet"||(s.state.loading&4)!==0)t.flags&=-16777217;else if(t.flags|=16777216,!Hy(s)){if(s=En.current,s!==null&&((nt&4194048)===nt?Wn!==null:(nt&62914560)!==nt&&(nt&536870912)===0||s!==Wn))throw ui=jf,mx;t.flags|=8192}}function vc(t,s){s!==null&&(t.flags|=4),t.flags&16384&&(s=t.tag!==22?ot():536870912,t.lanes|=s,Xo|=s)}function bi(t,s){if(!at)switch(t.tailMode){case"hidden":s=t.tail;for(var i=null;s!==null;)s.alternate!==null&&(i=s),s=s.sibling;i===null?t.tail=null:i.sibling=null;break;case"collapsed":i=t.tail;for(var u=null;i!==null;)i.alternate!==null&&(u=i),i=i.sibling;u===null?s||t.tail===null?t.tail=null:t.tail.sibling=null:u.sibling=null}}function St(t){var s=t.alternate!==null&&t.alternate.child===t.child,i=0,u=0;if(s)for(var p=t.child;p!==null;)i|=p.lanes|p.childLanes,u|=p.subtreeFlags&65011712,u|=p.flags&65011712,p.return=t,p=p.sibling;else for(p=t.child;p!==null;)i|=p.lanes|p.childLanes,u|=p.subtreeFlags,u|=p.flags,p.return=t,p=p.sibling;return t.subtreeFlags|=u,t.childLanes=i,s}function v_(t,s,i){var u=s.pendingProps;switch(gf(s),s.tag){case 31:case 16:case 15:case 0:case 11:case 7:case 8:case 12:case 9:case 14:return St(s),null;case 1:return St(s),null;case 3:return i=s.stateNode,u=null,t!==null&&(u=t.memoizedState.cache),s.memoizedState.cache!==u&&(s.flags|=2048),_s(Ot),ge(),i.pendingContext&&(i.context=i.pendingContext,i.pendingContext=null),(t===null||t.child===null)&&(si(s)?As(s):t===null||t.memoizedState.isDehydrated&&(s.flags&256)===0||(s.flags|=1024,lx())),St(s),null;case 26:return i=s.memoizedState,t===null?(As(s),i!==null?(St(s),M0(s,i)):(St(s),s.flags&=-16777217)):i?i!==t.memoizedState?(As(s),St(s),M0(s,i)):(St(s),s.flags&=-16777217):(t.memoizedProps!==u&&As(s),St(s),s.flags&=-16777217),null;case 27:we(s),i=ue.current;var p=s.type;if(t!==null&&s.stateNode!=null)t.memoizedProps!==u&&As(s);else{if(!u){if(s.stateNode===null)throw Error(a(166));return St(s),null}t=W.current,si(s)?ax(s):(t=My(p,u,i),s.stateNode=t,As(s))}return St(s),null;case 5:if(we(s),i=s.type,t!==null&&s.stateNode!=null)t.memoizedProps!==u&&As(s);else{if(!u){if(s.stateNode===null)throw Error(a(166));return St(s),null}if(t=W.current,si(s))ax(s);else{switch(p=Rc(ue.current),t){case 1:t=p.createElementNS("http://www.w3.org/2000/svg",i);break;case 2:t=p.createElementNS("http://www.w3.org/1998/Math/MathML",i);break;default:switch(i){case"svg":t=p.createElementNS("http://www.w3.org/2000/svg",i);break;case"math":t=p.createElementNS("http://www.w3.org/1998/Math/MathML",i);break;case"script":t=p.createElement("div"),t.innerHTML="