Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using Microsoft.Agents.AI.Hosting.Local;
using System.Linq;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Shared.Diagnostics;
Expand All @@ -29,7 +28,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
return services.AddAIAgent(name, (sp, key) =>
{
var chatClient = sp.GetRequiredService<IChatClient>();
var tools = GetRegisteredToolsForAgent(sp, name);
var tools = sp.GetKeyedServices<AITool>(name).ToList();
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
});
}
Expand All @@ -49,7 +48,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
Throw.IfNullOrEmpty(name);
return services.AddAIAgent(name, (sp, key) =>
{
var tools = GetRegisteredToolsForAgent(sp, name);
var tools = sp.GetKeyedServices<AITool>(name).ToList();
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
});
}
Expand All @@ -70,7 +69,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
return services.AddAIAgent(name, (sp, key) =>
{
var chatClient = chatClientServiceKey is null ? sp.GetRequiredService<IChatClient>() : sp.GetRequiredKeyedService<IChatClient>(chatClientServiceKey);
var tools = GetRegisteredToolsForAgent(sp, name);
var tools = sp.GetKeyedServices<AITool>(name).ToList();
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
});
}
Expand All @@ -92,7 +91,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s
return services.AddAIAgent(name, (sp, key) =>
{
var chatClient = chatClientServiceKey is null ? sp.GetRequiredService<IChatClient>() : sp.GetRequiredKeyedService<IChatClient>(chatClientServiceKey);
var tools = GetRegisteredToolsForAgent(sp, name);
var tools = sp.GetKeyedServices<AITool>(name).ToList();
return new ChatClientAgent(chatClient, instructions: instructions, name: key, description: description, tools: tools);
});
}
Expand Down Expand Up @@ -127,10 +126,4 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s

return new HostedAgentBuilder(name, services);
}

private static IList<AITool> GetRegisteredToolsForAgent(IServiceProvider serviceProvider, string agentName)
{
var registry = serviceProvider.GetService<LocalAgentToolRegistry>();
return registry?.GetTools(agentName) ?? [];
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Linq;
using Microsoft.Agents.AI.Hosting.Local;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Shared.Diagnostics;
Expand Down Expand Up @@ -70,18 +68,7 @@ public static IHostedAgentBuilder WithAITool(this IHostedAgentBuilder builder, A
Throw.IfNull(builder);
Throw.IfNull(tool);

var agentName = builder.Name;
var services = builder.ServiceCollection;

// Get or create the agent tool registry
var descriptor = services.FirstOrDefault(sd => !sd.IsKeyedService && sd.ServiceType.Equals(typeof(LocalAgentToolRegistry)));
if (descriptor?.ImplementationInstance is not LocalAgentToolRegistry toolRegistry)
{
toolRegistry = new();
services.Add(ServiceDescriptor.Singleton(toolRegistry));
}

toolRegistry.AddTool(agentName, tool);
builder.ServiceCollection.AddKeyedSingleton(builder.Name, tool);

return builder;
}
Expand All @@ -105,4 +92,19 @@ public static IHostedAgentBuilder WithAITools(this IHostedAgentBuilder builder,

return builder;
}

/// <summary>
/// Adds AI tool to an agent being configured with the service collection.
/// </summary>
/// <param name="builder">The hosted agent builder.</param>
/// <param name="factory">A factory function that creates a AI tool using the provided service provider.</param>
public static IHostedAgentBuilder WithAITool(this IHostedAgentBuilder builder, Func<IServiceProvider, AITool> factory)
{
Throw.IfNull(builder);
Throw.IfNull(factory);

builder.ServiceCollection.AddKeyedSingleton(builder.Name, (sp, name) => factory(sp));

return builder;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
Expand All @@ -17,49 +18,40 @@ public sealed class HostedAgentBuilderToolsExtensionsTests
[Fact]
public void WithAITool_ThrowsWhenBuilderIsNull()
{
// Arrange
var tool = new DummyAITool();

// Act & Assert
Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITool(null!, tool));
}

[Fact]
public void WithAITool_ThrowsWhenToolIsNull()
{
// Arrange
var services = new ServiceCollection();
var builder = services.AddAIAgent("test-agent", "Test instructions");

// Act & Assert
Assert.Throws<ArgumentNullException>(() => builder.WithAITool(null!));
Assert.Throws<ArgumentNullException>(() => builder.WithAITool(tool: null!));
}

[Fact]
public void WithAITools_ThrowsWhenBuilderIsNull()
{
// Arrange
var tools = new[] { new DummyAITool() };

// Act & Assert
Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITools(null!, tools));
}

[Fact]
public void WithAITools_ThrowsWhenToolsArrayIsNull()
{
// Arrange
var services = new ServiceCollection();
var builder = services.AddAIAgent("test-agent", "Test instructions");

// Act & Assert
Assert.Throws<ArgumentNullException>(() => builder.WithAITools(null!));
}

[Fact]
public void RegisteredTools_ResolvesAllToolsForAgent()
{
// Arrange
var services = new ServiceCollection();
services.AddSingleton<IChatClient>(new MockChatClient());

Expand All @@ -73,9 +65,13 @@ public void RegisteredTools_ResolvesAllToolsForAgent()

var serviceProvider = services.BuildServiceProvider();

var agent1Tools = ResolveAgentTools(serviceProvider, "test-agent");
var agent1Tools = ResolveToolsFromAgent(serviceProvider, "test-agent");
Assert.Contains(tool1, agent1Tools);
Assert.Contains(tool2, agent1Tools);

var agent1ToolsDI = ResolveToolsFromDI(serviceProvider, "test-agent");
Assert.Contains(tool1, agent1ToolsDI);
Assert.Contains(tool2, agent1ToolsDI);
}

[Fact]
Expand All @@ -100,21 +96,160 @@ public void RegisteredTools_IsolatedPerAgent()

var serviceProvider = services.BuildServiceProvider();

var agent1Tools = ResolveAgentTools(serviceProvider, "agent1");
var agent2Tools = ResolveAgentTools(serviceProvider, "agent2");
var agent1Tools = ResolveToolsFromAgent(serviceProvider, "agent1");
var agent2Tools = ResolveToolsFromAgent(serviceProvider, "agent2");

var agent1ToolsDI = ResolveToolsFromDI(serviceProvider, "agent1");
var agent2ToolsDI = ResolveToolsFromDI(serviceProvider, "agent2");

Assert.Contains(tool1, agent1Tools);
Assert.Contains(tool2, agent1Tools);
Assert.Contains(tool1, agent1ToolsDI);
Assert.Contains(tool2, agent1ToolsDI);

Assert.Contains(tool3, agent2Tools);
Assert.Contains(tool3, agent2ToolsDI);
}

private static IList<AITool> ResolveAgentTools(IServiceProvider serviceProvider, string name)
private static IList<AITool> ResolveToolsFromAgent(IServiceProvider serviceProvider, string name)
{
var agent = serviceProvider.GetRequiredKeyedService<AIAgent>(name) as ChatClientAgent;
Assert.NotNull(agent?.ChatOptions?.Tools);
return agent.ChatOptions.Tools;
}

private static List<AITool> ResolveToolsFromDI(IServiceProvider serviceProvider, string name)
{
var tools = serviceProvider.GetKeyedServices<AITool>(name);
Assert.NotNull(tools);
return tools.ToList();
}

[Fact]
public void WithAIToolFactory_ThrowsWhenBuilderIsNull()
{
Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITool(null!, CreateTool));

static AITool CreateTool(IServiceProvider _) => new DummyAITool();
}

[Fact]
public void WithAIToolFactory_ThrowsWhenFactoryIsNull()
{
var services = new ServiceCollection();
var builder = services.AddAIAgent("test-agent", "Test instructions");

Assert.Throws<ArgumentNullException>(() => builder.WithAITool(factory: null!));
}

[Fact]
public void WithAIToolFactory_RegistersToolFromFactory()
{
var services = new ServiceCollection();
services.AddSingleton<IChatClient>(new MockChatClient());

DummyAITool? createdTool = null;
var builder = services.AddAIAgent("test-agent", "Test instructions");
builder.WithAITool(sp =>
{
createdTool = new DummyAITool();
return createdTool;
});

var serviceProvider = services.BuildServiceProvider();
var tools = ResolveToolsFromDI(serviceProvider, "test-agent");

Assert.Single(tools);
Assert.Same(createdTool, tools[0]);
}

[Fact]
public void WithAIToolFactory_CanAccessServicesFromFactory()
{
var services = new ServiceCollection();
var mockChatClient = new MockChatClient();
services.AddSingleton<IChatClient>(mockChatClient);

IChatClient? resolvedChatClient = null;
var builder = services.AddAIAgent("test-agent", "Test instructions");
builder.WithAITool(sp =>
{
resolvedChatClient = sp.GetService<IChatClient>();
return new DummyAITool();
});

var serviceProvider = services.BuildServiceProvider();
_ = ResolveToolsFromDI(serviceProvider, "test-agent");

Assert.Same(mockChatClient, resolvedChatClient);
}

[Fact]
public void WithAIToolFactory_ToolsAreIsolatedPerAgent()
{
var services = new ServiceCollection();
services.AddSingleton<IChatClient>(new MockChatClient());

var tool1 = new DummyAITool();
var tool2 = new DummyAITool();

var builder1 = services.AddAIAgent("agent1", "Agent 1 instructions");
var builder2 = services.AddAIAgent("agent2", "Agent 2 instructions");

builder1.WithAITool(_ => tool1);
builder2.WithAITool(_ => tool2);

var serviceProvider = services.BuildServiceProvider();
var agent1Tools = ResolveToolsFromDI(serviceProvider, "agent1");
var agent2Tools = ResolveToolsFromDI(serviceProvider, "agent2");

Assert.Single(agent1Tools);
Assert.Contains(tool1, agent1Tools);
Assert.DoesNotContain(tool2, agent1Tools);

Assert.Single(agent2Tools);
Assert.Contains(tool2, agent2Tools);
Assert.DoesNotContain(tool1, agent2Tools);
}

[Fact]
public void WithAIToolFactory_CanCombineWithDirectToolRegistration()
{
var services = new ServiceCollection();
services.AddSingleton<IChatClient>(new MockChatClient());

var directTool = new DummyAITool();
var factoryTool = new DummyAITool();

var builder = services.AddAIAgent("test-agent", "Test instructions");
builder
.WithAITool(directTool)
.WithAITool(_ => factoryTool);

var serviceProvider = services.BuildServiceProvider();
var tools = ResolveToolsFromDI(serviceProvider, "test-agent");

Assert.Equal(2, tools.Count);
Assert.Contains(directTool, tools);
Assert.Contains(factoryTool, tools);
}

[Fact]
public void WithAIToolFactory_ToolsAvailableOnAgent()
{
var services = new ServiceCollection();
services.AddSingleton<IChatClient>(new MockChatClient());

var factoryTool = new DummyAITool();
var builder = services.AddAIAgent("test-agent", "Test instructions");
builder.WithAITool(_ => factoryTool);

var serviceProvider = services.BuildServiceProvider();
var agentTools = ResolveToolsFromAgent(serviceProvider, "test-agent");

Assert.Contains(factoryTool, agentTools);
}

/// <summary>
/// Dummy AITool implementation for testing.
/// </summary>
Expand Down
Loading