Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.Shared.Diagnostics;
Expand Down Expand Up @@ -56,6 +57,24 @@ public FunctionCallContent(string callId, string name, IDictionary<string, objec
[JsonIgnore]
public Exception? Exception { get; set; }

/// <summary>
/// Gets or sets a value indicating whether this function call requires invocation.
/// </summary>
/// <remarks>
/// <para>
/// This property defaults to <see langword="true"/>, indicating that the function call should be processed.
/// When set to <see langword="false"/>, it indicates that the function has already been processed and
/// should be ignored by components that process function calls, such as FunctionInvokingChatClient.
/// </para>
/// <para>
/// This property is not serialized when it has the value <see langword="false"/> (the JSON default for bool).
/// When deserialized, if the property is not present in the JSON, it will default to <see langword="true"/>.
/// </para>
/// </remarks>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
[Experimental("MEAI001")]
public bool InvocationRequired { get; set; } = true;

/// <summary>
/// Creates a new instance of <see cref="FunctionCallContent"/> parsing arguments using a specified encoding and parser.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ public override async Task<ChatResponse> GetResponseAsync(
responseMessages.AddRange(modeAndMessages.MessagesAdded);
consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;

// Mark the function calls that were just processed as no longer requiring invocation
MarkFunctionCallsAsProcessed(responseMessages, modeAndMessages.MessagesAdded);

if (modeAndMessages.ShouldTerminate)
{
break;
Expand Down Expand Up @@ -608,6 +611,9 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
responseMessages.AddRange(modeAndMessages.MessagesAdded);
consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;

// Mark the function calls that were just processed as no longer requiring invocation
MarkFunctionCallsAsProcessed(responseMessages, modeAndMessages.MessagesAdded);

// Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages
// includes all activities, including generated function results.
foreach (var message in modeAndMessages.MessagesAdded)
Expand Down Expand Up @@ -775,7 +781,7 @@ private static bool CopyFunctionCalls(
int count = content.Count;
for (int i = 0; i < count; i++)
{
if (content[i] is FunctionCallContent functionCall)
if (content[i] is FunctionCallContent functionCall && functionCall.InvocationRequired)
{
(functionCalls ??= []).Add(functionCall);
any = true;
Expand Down Expand Up @@ -1619,6 +1625,43 @@ private static IList<ChatMessage> ReplaceFunctionCallsWithApprovalRequests(
return outputMessages;
}

/// <summary>
/// Marks FunctionCallContent objects in allMessages as processed (InvocationRequired = false)
/// if they have corresponding FunctionResultContent in the newly added messages.
/// </summary>
/// <param name="allMessages">All messages accumulated so far.</param>
/// <param name="newlyAddedMessages">The messages that were just added containing FunctionResultContent.</param>
private static void MarkFunctionCallsAsProcessed(List<ChatMessage> allMessages, IList<ChatMessage> newlyAddedMessages)
{
// Build a set of call IDs from the newly added FunctionResultContent
HashSet<string>? processedCallIds = null;
foreach (var message in newlyAddedMessages)
{
foreach (var content in message.Contents)
{
if (content is FunctionResultContent frc)
{
_ = (processedCallIds ??= []).Add(frc.CallId);
}
}
}

// Mark FunctionCallContent with matching call IDs as processed
if (processedCallIds is not null)
{
foreach (var message in allMessages)
{
foreach (var content in message.Contents)
{
if (content is FunctionCallContent fcc && processedCallIds.Contains(fcc.CallId))
{
fcc.InvocationRequired = false;
}
}
}
}
}

private static TimeSpan GetElapsedTime(long startingTimestamp) =>
#if NET
Stopwatch.GetElapsedTime(startingTimestamp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,96 @@ public void Constructor_PropsRoundtrip()
Exception e = new();
c.Exception = e;
Assert.Same(e, c.Exception);

Assert.True(c.InvocationRequired);
c.InvocationRequired = false;
Assert.False(c.InvocationRequired);
}

[Fact]
public void InvocationRequired_DefaultsToTrue()
{
FunctionCallContent c = new("callId1", "name");
Assert.True(c.InvocationRequired);
}

[Fact]
public void InvocationRequired_CanBeSetToFalse()
{
FunctionCallContent c = new("callId1", "name") { InvocationRequired = false };
Assert.False(c.InvocationRequired);
}

[Fact]
public void InvocationRequired_NotSerializedWhenFalse()
{
// Arrange - Set InvocationRequired to false (the JSON default value for bool)
var sut = new FunctionCallContent("callId1", "functionName", new Dictionary<string, object?> { ["key"] = "value" })
{
InvocationRequired = false
};

// Act
var json = JsonSerializer.SerializeToNode(sut, TestJsonSerializerContext.Default.Options);

// Assert - InvocationRequired should not be in the JSON when it's false (default for bool)
Assert.NotNull(json);
Assert.False(json!.AsObject().ContainsKey("invocationRequired"));
Assert.False(json!.AsObject().ContainsKey("InvocationRequired"));
}

[Fact]
public void InvocationRequired_SerializedWhenTrue()
{
// Arrange - InvocationRequired defaults to true
var sut = new FunctionCallContent("callId1", "functionName", new Dictionary<string, object?> { ["key"] = "value" });

// Act
var json = JsonSerializer.SerializeToNode(sut, TestJsonSerializerContext.Default.Options);

// Assert - InvocationRequired should be in the JSON when it's true
Assert.NotNull(json);
var jsonObj = json!.AsObject();
Assert.True(jsonObj.ContainsKey("invocationRequired") || jsonObj.ContainsKey("InvocationRequired"));

JsonNode? invocationRequiredValue = null;
if (jsonObj.TryGetPropertyValue("invocationRequired", out var value1))
{
invocationRequiredValue = value1;
}
else if (jsonObj.TryGetPropertyValue("InvocationRequired", out var value2))
{
invocationRequiredValue = value2;
}

Assert.NotNull(invocationRequiredValue);
Assert.True(invocationRequiredValue!.GetValue<bool>());
}

[Fact]
public void InvocationRequired_DeserializedCorrectly()
{
// Test deserialization when InvocationRequired is true
var json = """{"callId":"callId1","name":"functionName","invocationRequired":true}""";
var deserialized = JsonSerializer.Deserialize<FunctionCallContent>(json, TestJsonSerializerContext.Default.Options);

Assert.NotNull(deserialized);
Assert.Equal("callId1", deserialized.CallId);
Assert.Equal("functionName", deserialized.Name);
Assert.True(deserialized.InvocationRequired);
}

[Fact]
public void InvocationRequired_DeserializedToTrueWhenMissing()
{
// Test deserialization when InvocationRequired is not in JSON (should default to true from field initializer)
var json = """{"callId":"callId1","name":"functionName"}""";
var deserialized = JsonSerializer.Deserialize<FunctionCallContent>(json, TestJsonSerializerContext.Default.Options);

Assert.NotNull(deserialized);
Assert.Equal("callId1", deserialized.CallId);
Assert.Equal("functionName", deserialized.Name);
Assert.True(deserialized.InvocationRequired);
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1447,6 +1447,159 @@ public async Task CreatesOrchestrateToolsSpanWhenNoInvokeAgentParent(bool stream
}
}

[Fact]
public async Task InvocationRequired_SetToFalseAfterProcessing()
{
var options = new ChatOptions
{
Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")]
};

List<ChatMessage> plan =
[
new ChatMessage(ChatRole.User, "hello"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
new ChatMessage(ChatRole.Assistant, "world"),
];

var chat = await InvokeAndAssertAsync(options, plan);

// Find the FunctionCallContent in the chat history
var functionCallMessage = chat.First(m => m.Contents.Any(c => c is FunctionCallContent));
var functionCallContent = functionCallMessage.Contents.OfType<FunctionCallContent>().First();

// Verify InvocationRequired was set to false after processing
Assert.False(functionCallContent.InvocationRequired);
}

[Fact]
public async Task InvocationRequired_IgnoresFunctionCallsWithInvocationRequiredFalse()
{
var functionInvokedCount = 0;
var options = new ChatOptions
{
Tools = [AIFunctionFactory.Create(() => { functionInvokedCount++; return "Result 1"; }, "Func1")]
};

// Create a function call that has already been processed
var alreadyProcessedFunctionCall = new FunctionCallContent("callId1", "Func1") { InvocationRequired = false };

using var innerClient = new TestChatClient
{
GetResponseAsyncCallback = async (contents, actualOptions, actualCancellationToken) =>
{
await Task.Yield();

// Return a response with a FunctionCallContent that has InvocationRequired = false
var message = new ChatMessage(ChatRole.Assistant, [alreadyProcessedFunctionCall]);
return new ChatResponse(message);
}
};

using var client = new FunctionInvokingChatClient(innerClient);

var response = await client.GetResponseAsync([new ChatMessage(ChatRole.User, "hello")], options);

// The function should not have been invoked since InvocationRequired was false
Assert.Equal(0, functionInvokedCount);

// The response should contain the FunctionCallContent but no FunctionResultContent
Assert.Contains(response.Messages, m => m.Contents.Any(c => c is FunctionCallContent fcc && !fcc.InvocationRequired));
Assert.DoesNotContain(response.Messages, m => m.Contents.Any(c => c is FunctionResultContent));
}

[Fact]
public async Task InvocationRequired_IgnoresFunctionCallsWithInvocationRequiredFalse_Streaming()
{
var functionInvokedCount = 0;
var options = new ChatOptions
{
Tools = [AIFunctionFactory.Create(() => { functionInvokedCount++; return "Result 1"; }, "Func1")]
};

// Create a function call that has already been processed
var alreadyProcessedFunctionCall = new FunctionCallContent("callId1", "Func1") { InvocationRequired = false };

using var innerClient = new TestChatClient
{
GetStreamingResponseAsyncCallback = (contents, actualOptions, actualCancellationToken) =>
{
// Return a response with a FunctionCallContent that has InvocationRequired = false
var message = new ChatMessage(ChatRole.Assistant, [alreadyProcessedFunctionCall]);
return YieldAsync(new ChatResponse(message).ToChatResponseUpdates());
}
};

using var client = new FunctionInvokingChatClient(innerClient);

var updates = new List<ChatResponseUpdate>();
await foreach (var update in client.GetStreamingResponseAsync([new ChatMessage(ChatRole.User, "hello")], options))
{
updates.Add(update);
}

// The function should not have been invoked since InvocationRequired was false
Assert.Equal(0, functionInvokedCount);

// The updates should contain the FunctionCallContent but no FunctionResultContent
Assert.Contains(updates, u => u.Contents.Any(c => c is FunctionCallContent fcc && !fcc.InvocationRequired));
Assert.DoesNotContain(updates, u => u.Contents.Any(c => c is FunctionResultContent));
}

[Fact]
public async Task InvocationRequired_ProcessesMixedFunctionCalls()
{
var func1InvokedCount = 0;
var func2InvokedCount = 0;

var options = new ChatOptions
{
Tools =
[
AIFunctionFactory.Create(() => { func1InvokedCount++; return "Result 1"; }, "Func1"),
AIFunctionFactory.Create(() => { func2InvokedCount++; return "Result 2"; }, "Func2"),
]
};

// Create one function call that needs processing and one that doesn't
var needsProcessing = new FunctionCallContent("callId1", "Func1") { InvocationRequired = true };
var alreadyProcessed = new FunctionCallContent("callId2", "Func2") { InvocationRequired = false };

using var innerClient = new TestChatClient
{
GetResponseAsyncCallback = async (contents, actualOptions, actualCancellationToken) =>
{
await Task.Yield();

if (contents.Count() == 1)
{
// First call - return both function calls
var message = new ChatMessage(ChatRole.Assistant, [needsProcessing, alreadyProcessed]);
return new ChatResponse(message);
}
else
{
// Second call - return final response after processing
var message = new ChatMessage(ChatRole.Assistant, "done");
return new ChatResponse(message);
}
}
};

using var client = new FunctionInvokingChatClient(innerClient);

var response = await client.GetResponseAsync([new ChatMessage(ChatRole.User, "hello")], options);

// Only Func1 should have been invoked (the one with InvocationRequired = true)
Assert.Equal(1, func1InvokedCount);
Assert.Equal(0, func2InvokedCount);

// The response should contain FunctionResultContent for Func1 but not Func2
Assert.Contains(response.Messages, m => m.Contents.Any(c => c is FunctionResultContent frc && frc.CallId == "callId1"));
Assert.DoesNotContain(response.Messages, m => m.Contents.Any(c => c is FunctionResultContent frc && frc.CallId == "callId2"));
}

private sealed class CustomSynchronizationContext : SynchronizationContext
{
public override void Post(SendOrPostCallback d, object? state)
Expand Down
Loading