Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public static class OpenAIResponseClientExtensions
/// <param name="tools">Optional collection of AI tools that the agent can use during conversations.</param>
/// <param name="clientFactory">Provides a way to customize the creation of the underlying <see cref="IChatClient"/> used by the agent.</param>
/// <param name="loggerFactory">Optional logger factory for enabling logging within the agent.</param>
/// <param name="services">An optional <see cref="IServiceProvider"/> to use for resolving services required by the <see cref="AIFunction"/> instances being invoked.</param>
/// <returns>An <see cref="ChatClientAgent"/> instance backed by the OpenAI Response service.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="client"/> is <see langword="null"/>.</exception>
public static ChatClientAgent CreateAIAgent(
Expand All @@ -39,7 +40,8 @@ public static ChatClientAgent CreateAIAgent(
string? description = null,
IList<AITool>? tools = null,
Func<IChatClient, IChatClient>? clientFactory = null,
ILoggerFactory? loggerFactory = null)
ILoggerFactory? loggerFactory = null,
IServiceProvider? services = null)
{
Throw.IfNull(client);

Expand All @@ -55,7 +57,8 @@ public static ChatClientAgent CreateAIAgent(
}
},
clientFactory,
loggerFactory);
loggerFactory,
services);
}

/// <summary>
Expand All @@ -65,13 +68,15 @@ public static ChatClientAgent CreateAIAgent(
/// <param name="options">Full set of options to configure the agent.</param>
/// <param name="clientFactory">Provides a way to customize the creation of the underlying <see cref="IChatClient"/> used by the agent.</param>
/// <param name="loggerFactory">Optional logger factory for enabling logging within the agent.</param>
/// <param name="services">An optional <see cref="IServiceProvider"/> to use for resolving services required by the <see cref="AIFunction"/> instances being invoked.</param>
/// <returns>An <see cref="ChatClientAgent"/> instance backed by the OpenAI Response service.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="client"/> or <paramref name="options"/> is <see langword="null"/>.</exception>
public static ChatClientAgent CreateAIAgent(
this OpenAIResponseClient client,
ChatClientAgentOptions options,
Func<IChatClient, IChatClient>? clientFactory = null,
ILoggerFactory? loggerFactory = null)
ILoggerFactory? loggerFactory = null,
IServiceProvider? services = null)
{
Throw.IfNull(client);
Throw.IfNull(options);
Expand All @@ -83,6 +88,6 @@ public static ChatClientAgent CreateAIAgent(
chatClient = clientFactory(chatClient);
}

return new ChatClientAgent(chatClient, options, loggerFactory);
return new ChatClientAgent(chatClient, options, loggerFactory, services);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Azure;
Expand Down Expand Up @@ -726,6 +727,159 @@ public async Task CreateAIAgentAsync_WithEmptyModel_ThrowsArgumentExceptionAsync
Assert.Equal("model", exception.ParamName);
}

/// <summary>
/// Verify that CreateAIAgent with services parameter correctly passes it through to the ChatClientAgent.
/// </summary>
[Fact]
public void CreateAIAgent_WithServices_PassesServicesToAgent()
{
// Arrange
var client = CreateFakePersistentAgentsClient();
var serviceProvider = new TestServiceProvider();
const string Model = "test-model";

// Act
var agent = client.CreateAIAgent(
Model,
instructions: "Test instructions",
name: "Test Agent",
services: serviceProvider);

// Assert
Assert.NotNull(agent);

// Verify the IServiceProvider was passed through to the FunctionInvokingChatClient
var chatClient = agent.GetService<IChatClient>();
Assert.NotNull(chatClient);
var functionInvokingClient = chatClient.GetService<FunctionInvokingChatClient>();
Assert.NotNull(functionInvokingClient);
Assert.Same(serviceProvider, GetFunctionInvocationServices(functionInvokingClient));
}

/// <summary>
/// Verify that CreateAIAgentAsync with services parameter correctly passes it through to the ChatClientAgent.
/// </summary>
[Fact]
public async Task CreateAIAgentAsync_WithServices_PassesServicesToAgentAsync()
{
// Arrange
var client = CreateFakePersistentAgentsClient();
var serviceProvider = new TestServiceProvider();
const string Model = "test-model";

// Act
var agent = await client.CreateAIAgentAsync(
Model,
instructions: "Test instructions",
name: "Test Agent",
services: serviceProvider);

// Assert
Assert.NotNull(agent);

// Verify the IServiceProvider was passed through to the FunctionInvokingChatClient
var chatClient = agent.GetService<IChatClient>();
Assert.NotNull(chatClient);
var functionInvokingClient = chatClient.GetService<FunctionInvokingChatClient>();
Assert.NotNull(functionInvokingClient);
Assert.Same(serviceProvider, GetFunctionInvocationServices(functionInvokingClient));
}

/// <summary>
/// Verify that GetAIAgent with services parameter correctly passes it through to the ChatClientAgent.
/// </summary>
[Fact]
public void GetAIAgent_WithServices_PassesServicesToAgent()
{
// Arrange
var client = CreateFakePersistentAgentsClient();
var serviceProvider = new TestServiceProvider();

// Act
var agent = client.GetAIAgent("agent_abc123", services: serviceProvider);

// Assert
Assert.NotNull(agent);

// Verify the IServiceProvider was passed through to the FunctionInvokingChatClient
var chatClient = agent.GetService<IChatClient>();
Assert.NotNull(chatClient);
var functionInvokingClient = chatClient.GetService<FunctionInvokingChatClient>();
Assert.NotNull(functionInvokingClient);
Assert.Same(serviceProvider, GetFunctionInvocationServices(functionInvokingClient));
}

/// <summary>
/// Verify that GetAIAgentAsync with services parameter correctly passes it through to the ChatClientAgent.
/// </summary>
[Fact]
public async Task GetAIAgentAsync_WithServices_PassesServicesToAgentAsync()
{
// Arrange
var client = CreateFakePersistentAgentsClient();
var serviceProvider = new TestServiceProvider();

// Act
var agent = await client.GetAIAgentAsync("agent_abc123", services: serviceProvider);

// Assert
Assert.NotNull(agent);

// Verify the IServiceProvider was passed through to the FunctionInvokingChatClient
var chatClient = agent.GetService<IChatClient>();
Assert.NotNull(chatClient);
var functionInvokingClient = chatClient.GetService<FunctionInvokingChatClient>();
Assert.NotNull(functionInvokingClient);
Assert.Same(serviceProvider, GetFunctionInvocationServices(functionInvokingClient));
}

/// <summary>
/// Verify that CreateAIAgent with both clientFactory and services works correctly.
/// </summary>
[Fact]
public void CreateAIAgent_WithClientFactoryAndServices_AppliesBothCorrectly()
{
// Arrange
var client = CreateFakePersistentAgentsClient();
var serviceProvider = new TestServiceProvider();
TestChatClient? testChatClient = null;
const string Model = "test-model";

// Act
var agent = client.CreateAIAgent(
Model,
instructions: "Test instructions",
name: "Test Agent",
clientFactory: (innerClient) => testChatClient = new TestChatClient(innerClient),
services: serviceProvider);

// Assert
Assert.NotNull(agent);

// Verify the custom chat client was applied
var retrievedTestClient = agent.GetService<TestChatClient>();
Assert.NotNull(retrievedTestClient);
Assert.Same(testChatClient, retrievedTestClient);

// Verify the IServiceProvider was passed through
var chatClient = agent.GetService<IChatClient>();
Assert.NotNull(chatClient);
var functionInvokingClient = chatClient.GetService<FunctionInvokingChatClient>();
Assert.NotNull(functionInvokingClient);
Assert.Same(serviceProvider, GetFunctionInvocationServices(functionInvokingClient));
}

/// <summary>
/// Uses reflection to access the FunctionInvocationServices property which is not public.
/// </summary>
private static IServiceProvider? GetFunctionInvocationServices(FunctionInvokingChatClient client)
{
var property = typeof(FunctionInvokingChatClient).GetProperty(
"FunctionInvocationServices",
BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
return property?.GetValue(client) as IServiceProvider;
}

/// <summary>
/// Test custom chat client that can be used to verify clientFactory functionality.
/// </summary>
Expand All @@ -736,6 +890,14 @@ public TestChatClient(IChatClient innerClient) : base(innerClient)
}
}

/// <summary>
/// A simple test IServiceProvider implementation for testing.
/// </summary>
private sealed class TestServiceProvider : IServiceProvider
{
public object? GetService(Type serviceType) => null;
}

public sealed class FakePersistentAgentsAdministrationClient : PersistentAgentsAdministrationClient
{
public FakePersistentAgentsAdministrationClient()
Expand All @@ -761,7 +923,7 @@ private static PersistentAgentsClient CreateFakePersistentAgentsClient()
{
var client = new PersistentAgentsClient("https://any.com", DelegatedTokenCredential.Create((_, _) => new AccessToken()));

((System.Reflection.TypeInfo)typeof(PersistentAgentsClient)).DeclaredFields.First(f => f.Name == "_client")
((TypeInfo)typeof(PersistentAgentsClient)).DeclaredFields.First(f => f.Name == "_client")
.SetValue(client, new FakePersistentAgentsAdministrationClient());
return client;
}
Expand Down
Loading
Loading