diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs index 836d5a4110b..7c506a7845b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs @@ -14,7 +14,7 @@ namespace Microsoft.Extensions.AI; /// Represents a function call request. /// [DebuggerDisplay("{DebuggerDisplay,nq}")] -public sealed class FunctionCallContent : AIContent +public class FunctionCallContent : AIContent { /// /// Initializes a new instance of the class. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionResultContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionResultContent.cs index 46401347b40..d5eb4884709 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionResultContent.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionResultContent.cs @@ -13,7 +13,7 @@ namespace Microsoft.Extensions.AI; /// Represents the result of a function call. /// [DebuggerDisplay("{DebuggerDisplay,nq}")] -public sealed class FunctionResultContent : AIContent +public class FunctionResultContent : AIContent { /// /// Initializes a new instance of the class. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json index 6362fcaa00e..ef1f2ebb496 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json @@ -1792,7 +1792,7 @@ ] }, { - "Type": "sealed class Microsoft.Extensions.AI.FunctionCallContent : Microsoft.Extensions.AI.AIContent", + "Type": "class Microsoft.Extensions.AI.FunctionCallContent : Microsoft.Extensions.AI.AIContent", "Stage": "Stable", "Methods": [ { @@ -1824,7 +1824,7 @@ ] }, { - "Type": "sealed class Microsoft.Extensions.AI.FunctionResultContent : Microsoft.Extensions.AI.AIContent", + "Type": "class Microsoft.Extensions.AI.FunctionResultContent : Microsoft.Extensions.AI.AIContent", "Stage": "Stable", "Methods": [ { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index fdc7d57b1ea..c88fe91fc17 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -1208,6 +1208,13 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul object? functionResult; if (result.Status == FunctionInvocationStatus.RanToCompletion) { + // If the result is already a FunctionResultContent with a matching CallId, use it directly. + if (result.Result is FunctionResultContent frc && + frc.CallId == result.CallContent.CallId) + { + return frc; + } + functionResult = result.Result ?? "Success: Function completed."; } else diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index e9672275871..066fb5e3427 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -354,6 +354,264 @@ public async Task FunctionInvokerDelegateOverridesHandlingAsync() await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure); } + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task FunctionReturningFunctionResultContentWithMatchingCallId_UsesItDirectly(bool streaming) + { + FunctionResultContent? returnedFrc = null; + + var options = new ChatOptions + { + Tools = + [ + AIFunctionFactory.Create(() => "Result 1", "Func1"), + ] + }; + + using var innerClient = new TestChatClient + { + GetResponseAsyncCallback = (msgs, opts, ct) => + { + var toolMessage = msgs.FirstOrDefault(m => m.Role == ChatRole.Tool); + if (toolMessage is null) + { + return Task.FromResult(new ChatResponse( + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]))); + } + else + { + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "done"))); + } + }, + GetStreamingResponseAsyncCallback = (msgs, opts, ct) => + { + var toolMessage = msgs.FirstOrDefault(m => m.Role == ChatRole.Tool); + if (toolMessage is null) + { + return YieldAsync(new ChatResponse( + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")])).ToChatResponseUpdates()); + } + else + { + return YieldAsync(new ChatResponse(new ChatMessage(ChatRole.Assistant, "done")).ToChatResponseUpdates()); + } + } + }; + + using var client = new FunctionInvokingChatClient(innerClient) + { + FunctionInvoker = (ctx, cancellationToken) => + { + returnedFrc = new FunctionResultContent(ctx.CallContent.CallId, "Custom result from function") + { + RawRepresentation = "CustomRaw" + }; + return new ValueTask(returnedFrc); + } + }; + + var messages = new List + { + new ChatMessage(ChatRole.User, "hello"), + }; + + ChatResponse response; + if (streaming) + { + response = await client.GetStreamingResponseAsync(messages, options).ToChatResponseAsync(); + } + else + { + response = await client.GetResponseAsync(messages, options); + } + + // Verify that the FunctionResultContent was used directly (same reference) + var toolMessage = response.Messages.First(m => m.Role == ChatRole.Tool); + var capturedFrc = Assert.Single(toolMessage.Contents.OfType()); + Assert.Same(returnedFrc, capturedFrc); + Assert.Equal("Custom result from function", capturedFrc.Result); + Assert.Equal("CustomRaw", capturedFrc.RawRepresentation); + Assert.Equal("callId1", capturedFrc.CallId); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task FunctionReturningFunctionResultContentWithMismatchedCallId_WrapsIt(bool streaming) + { + FunctionResultContent? returnedFrc = null; + + var options = new ChatOptions + { + Tools = + [ + AIFunctionFactory.Create(() => "Result 1", "Func1"), + ] + }; + + using var innerClient = new TestChatClient + { + GetResponseAsyncCallback = (msgs, opts, ct) => + { + var toolMessage = msgs.FirstOrDefault(m => m.Role == ChatRole.Tool); + if (toolMessage is null) + { + return Task.FromResult(new ChatResponse( + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]))); + } + else + { + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "done"))); + } + }, + GetStreamingResponseAsyncCallback = (msgs, opts, ct) => + { + var toolMessage = msgs.FirstOrDefault(m => m.Role == ChatRole.Tool); + if (toolMessage is null) + { + return YieldAsync(new ChatResponse( + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")])).ToChatResponseUpdates()); + } + else + { + return YieldAsync(new ChatResponse(new ChatMessage(ChatRole.Assistant, "done")).ToChatResponseUpdates()); + } + } + }; + + using var client = new FunctionInvokingChatClient(innerClient) + { + FunctionInvoker = (ctx, cancellationToken) => + { + // Return a FunctionResultContent with a different CallId + returnedFrc = new FunctionResultContent("differentCallId", "Result from function"); + return new ValueTask(returnedFrc); + } + }; + + var messages = new List + { + new ChatMessage(ChatRole.User, "hello"), + }; + + ChatResponse response; + if (streaming) + { + response = await client.GetStreamingResponseAsync(messages, options).ToChatResponseAsync(); + } + else + { + response = await client.GetResponseAsync(messages, options); + } + + // Verify the result is wrapped - the outer FunctionResultContent has the correct CallId + // and the inner one is reference-equal to what was returned + var toolMessage = response.Messages.First(m => m.Role == ChatRole.Tool); + var frc = Assert.Single(toolMessage.Contents.OfType()); + Assert.Equal("callId1", frc.CallId); + Assert.Same(returnedFrc, frc.Result); + var innerFrc = (FunctionResultContent)frc.Result!; + Assert.Equal("differentCallId", innerFrc.CallId); + Assert.Equal("Result from function", innerFrc.Result); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task FunctionReturningDerivedFunctionResultContent_PropagatesInstanceToInnerClient(bool streaming) + { + DerivedFunctionResultContent? returnedFrc = null; + + var options = new ChatOptions + { + Tools = + [ + AIFunctionFactory.Create(() => "Result 1", "Func1"), + ] + }; + + using var innerClient = new TestChatClient + { + GetResponseAsyncCallback = (msgs, opts, ct) => + { + var toolMessage = msgs.FirstOrDefault(m => m.Role == ChatRole.Tool); + if (toolMessage is null) + { + return Task.FromResult(new ChatResponse( + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]))); + } + else + { + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "done"))); + } + }, + GetStreamingResponseAsyncCallback = (msgs, opts, ct) => + { + var toolMessage = msgs.FirstOrDefault(m => m.Role == ChatRole.Tool); + if (toolMessage is null) + { + return YieldAsync(new ChatResponse( + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")])).ToChatResponseUpdates()); + } + else + { + return YieldAsync(new ChatResponse(new ChatMessage(ChatRole.Assistant, "done")).ToChatResponseUpdates()); + } + } + }; + + using var client = new FunctionInvokingChatClient(innerClient) + { + FunctionInvoker = (ctx, cancellationToken) => + { + // Return a derived FunctionResultContent + returnedFrc = new DerivedFunctionResultContent(ctx.CallContent.CallId, "Derived result") + { + CustomProperty = "CustomValue" + }; + return new ValueTask(returnedFrc); + } + }; + + var messages = new List + { + new ChatMessage(ChatRole.User, "hello"), + }; + + ChatResponse response; + if (streaming) + { + response = await client.GetStreamingResponseAsync(messages, options).ToChatResponseAsync(); + } + else + { + response = await client.GetResponseAsync(messages, options); + } + + // Verify that the derived FunctionResultContent instance was propagated to the inner client + // and is reference-equal to what was returned + var toolMessage = response.Messages.First(m => m.Role == ChatRole.Tool); + var capturedFrc = Assert.Single(toolMessage.Contents.OfType()); + Assert.Same(returnedFrc, capturedFrc); + Assert.IsType(capturedFrc); + var derivedFrc = (DerivedFunctionResultContent)capturedFrc; + Assert.Equal("callId1", derivedFrc.CallId); + Assert.Equal("Derived result", derivedFrc.Result); + Assert.Equal("CustomValue", derivedFrc.CustomProperty); + } + + /// A derived FunctionResultContent for testing purposes. + private sealed class DerivedFunctionResultContent : FunctionResultContent + { + public DerivedFunctionResultContent(string callId, object? result) + : base(callId, result) + { + } + + public string? CustomProperty { get; set; } + } + [Fact] public async Task ContinuesWithSuccessfulCallsUntilMaximumIterations() {