diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md index 34c83f701c1..37ffe79f892 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md @@ -4,6 +4,7 @@ - Added non-invocable `AIFunctionDeclaration` (base class for `AIFunction`), `AIFunctionFactory.CreateDeclaration`, and `AIFunction.AsDeclarationOnly`. - Added `[Experimental]` support for user approval of function invocations via `ApprovalRequiredAIFunction`, `FunctionApprovalRequestContent`, and friends. +- Added `[Experimental]` support for MCP server-hosted tools via `HostedMcpServerTool`, `HostedMcpServerToolApprovalMode`, and friends. ## 9.8.0 diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/HostedMcpServerToolRequireSpecificApprovalMode.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/HostedMcpServerToolRequireSpecificApprovalMode.cs index a5e870af7e7..42a39f35c50 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/HostedMcpServerToolRequireSpecificApprovalMode.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/HostedMcpServerToolRequireSpecificApprovalMode.cs @@ -6,6 +6,9 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; +#pragma warning disable S109 // Magic numbers should not be used +#pragma warning disable EA0011 // Consider removing unnecessary conditional access operator (?) + namespace Microsoft.Extensions.AI; /// @@ -42,7 +45,7 @@ public override bool Equals(object? obj) => obj is HostedMcpServerToolRequireSpe /// public override int GetHashCode() => - HashCode.Combine(GetListHashCode(AlwaysRequireApprovalToolNames), GetListHashCode(NeverRequireApprovalToolNames)); + Combine(GetListHashCode(AlwaysRequireApprovalToolNames), GetListHashCode(NeverRequireApprovalToolNames)); private static bool ListEquals(IList? list1, IList? list2) => ReferenceEquals(list1, list2) || @@ -55,12 +58,32 @@ private static int GetListHashCode(IList? list) return 0; } +#if NET HashCode hc = default; - foreach (string item in list) + for (int i = 0; i < list.Count; i++) { - hc.Add(item); + hc.Add(list[i]); } return hc.ToHashCode(); +#else + int hash = 0; + for (int i = 0; i < list.Count; i++) + { + hash = Combine(hash, list[i]?.GetHashCode() ?? 0); + } + + return hash; +#endif + } + + private static int Combine(int h1, int h2) + { +#if NET + return HashCode.Combine(h1, h2); +#else + uint rol5 = ((uint)h1 << 5) | ((uint)h1 >> 27); + return ((int)rol5 + h1) ^ h2; +#endif } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj index 25aa2f4c49d..f5472854def 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj @@ -32,7 +32,6 @@ - diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI.OpenAI/CHANGELOG.md index 15859a5744d..ded1e85533f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/CHANGELOG.md +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/CHANGELOG.md @@ -2,9 +2,9 @@ ## NOT YET RELEASED +- Updated to depend on OpenAI 2.4.0 - Updated tool mappings to recognize any `AIFunctionDeclaration`. - Updated to accommodate the additions in `Microsoft.Extensions.AI.Abstractions`. -- Updated to depend on OpenAI 2.4.0 ## 9.8.0-preview.1.25412.6 diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedMcpServerToolApprovalModeTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedMcpServerToolApprovalModeTests.cs index 3ad4690130a..9f72a194f06 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedMcpServerToolApprovalModeTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedMcpServerToolApprovalModeTests.cs @@ -45,4 +45,51 @@ public void Serialization_RequireSpecific_Roundtrips() HostedMcpServerToolApprovalMode? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.HostedMcpServerToolApprovalMode); Assert.Equal(requireSpecific, result); } + + [Fact] + public void Equality_RequireSpecific_WorksAsExpected() + { + var mode1 = HostedMcpServerToolApprovalMode.RequireSpecific(["ToolA", "ToolB"], ["ToolC"]); + var mode2 = HostedMcpServerToolApprovalMode.RequireSpecific(["ToolA", "ToolB"], ["ToolC"]); + Assert.Equal(mode1, mode2); + Assert.Equal(mode1.GetHashCode(), mode2.GetHashCode()); + + Assert.NotNull(mode1.AlwaysRequireApprovalToolNames); + mode1.AlwaysRequireApprovalToolNames.Add("ToolD"); + Assert.NotEqual(mode1, mode2); + Assert.NotEqual(mode1.GetHashCode(), mode2.GetHashCode()); + + Assert.NotNull(mode2.AlwaysRequireApprovalToolNames); + mode2.AlwaysRequireApprovalToolNames.Add("ToolD"); + Assert.Equal(mode1, mode2); + Assert.Equal(mode1.GetHashCode(), mode2.GetHashCode()); + + Assert.NotNull(mode2.NeverRequireApprovalToolNames); + mode2.NeverRequireApprovalToolNames.Add("ToolE"); + Assert.NotEqual(mode1, mode2); + Assert.NotEqual(mode1.GetHashCode(), mode2.GetHashCode()); + + Assert.NotNull(mode1.NeverRequireApprovalToolNames); + mode1.NeverRequireApprovalToolNames.Add("ToolE"); + Assert.Equal(mode1, mode2); + Assert.Equal(mode1.GetHashCode(), mode2.GetHashCode()); + + var mode3 = HostedMcpServerToolApprovalMode.RequireSpecific(null, null); + Assert.Equal(mode3.GetHashCode(), mode3.GetHashCode()); + var mode4 = HostedMcpServerToolApprovalMode.RequireSpecific(["a"], null); + Assert.Equal(mode4.GetHashCode(), mode4.GetHashCode()); + Assert.NotEqual(mode3, mode4); + Assert.NotEqual(mode3.GetHashCode(), mode4.GetHashCode()); + + var mode5 = HostedMcpServerToolApprovalMode.RequireSpecific(null, ["b"]); + Assert.Equal(mode5.GetHashCode(), mode5.GetHashCode()); + Assert.NotEqual(mode3, mode5); + Assert.NotEqual(mode3.GetHashCode(), mode5.GetHashCode()); + Assert.NotEqual(mode4, mode5); + Assert.NotEqual(mode4.GetHashCode(), mode5.GetHashCode()); + + var mode6 = HostedMcpServerToolApprovalMode.RequireSpecific([], []); + Assert.Equal(mode6.GetHashCode(), mode6.GetHashCode()); + Assert.NotEqual(mode3, mode6); + } }