Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 75 additions & 44 deletions src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -157,30 +157,54 @@ void IDisposable.Dispose()
}
else if (input.Role == ChatRole.Assistant)
{
AssistantChatMessage message = new(ToOpenAIChatContent(input.Contents))
{
ParticipantName = input.AuthorName
};

List<ChatMessageContentPart>? contentParts = null;
List<ChatToolCall>? toolCalls = null;
string? refusal = null;
foreach (var content in input.Contents)
{
switch (content)
{
case ErrorContent errorContent when errorContent.ErrorCode is nameof(message.Refusal):
message.Refusal = errorContent.Message;
case ErrorContent ec when ec.ErrorCode == nameof(AssistantChatMessage.Refusal):
refusal = ec.Message;
break;

case FunctionCallContent callRequest:
message.ToolCalls.Add(
ChatToolCall.CreateFunctionToolCall(
callRequest.CallId,
callRequest.Name,
new(JsonSerializer.SerializeToUtf8Bytes(
callRequest.Arguments,
options.GetTypeInfo(typeof(IDictionary<string, object?>))))));
case FunctionCallContent fc:
(toolCalls ??= []).Add(
ChatToolCall.CreateFunctionToolCall(fc.CallId, fc.Name, new(JsonSerializer.SerializeToUtf8Bytes(
fc.Arguments, options.GetTypeInfo(typeof(IDictionary<string, object?>))))));
break;

default:
if (ToChatMessageContentPart(content) is { } part)
{
(contentParts ??= []).Add(part);
}

break;
}
}

AssistantChatMessage message;
if (contentParts is not null)
{
message = new(contentParts);
if (toolCalls is not null)
{
foreach (var toolCall in toolCalls)
{
message.ToolCalls.Add(toolCall);
}
}
}
else
{
message = toolCalls is not null ?
new(toolCalls) :
new(ChatMessageContentPart.CreateTextPart(string.Empty));
}

message.ParticipantName = input.AuthorName;
message.Refusal = refusal;

yield return message;
}
Expand All @@ -191,38 +215,12 @@ void IDisposable.Dispose()
private static List<ChatMessageContentPart> ToOpenAIChatContent(IList<AIContent> contents)
{
List<ChatMessageContentPart> parts = [];

foreach (var content in contents)
{
switch (content)
if (ToChatMessageContentPart(content) is { } part)
{
case TextContent textContent:
parts.Add(ChatMessageContentPart.CreateTextPart(textContent.Text));
break;

case UriContent uriContent when uriContent.HasTopLevelMediaType("image"):
parts.Add(ChatMessageContentPart.CreateImagePart(uriContent.Uri, GetImageDetail(content)));
break;

case DataContent dataContent when dataContent.HasTopLevelMediaType("image"):
parts.Add(ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType, GetImageDetail(content)));
break;

case DataContent dataContent when dataContent.HasTopLevelMediaType("audio"):
var audioData = BinaryData.FromBytes(dataContent.Data);
if (dataContent.MediaType.Equals("audio/mpeg", StringComparison.OrdinalIgnoreCase))
{
parts.Add(ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Mp3));
}
else if (dataContent.MediaType.Equals("audio/wav", StringComparison.OrdinalIgnoreCase))
{
parts.Add(ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Wav));
}

break;

case DataContent dataContent when dataContent.MediaType.StartsWith("application/pdf", StringComparison.OrdinalIgnoreCase):
parts.Add(ChatMessageContentPart.CreateFilePart(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType, $"{Guid.NewGuid():N}.pdf"));
break;
parts.Add(part);
}
}

Expand All @@ -234,6 +232,39 @@ private static List<ChatMessageContentPart> ToOpenAIChatContent(IList<AIContent>
return parts;
}

private static ChatMessageContentPart? ToChatMessageContentPart(AIContent content)
{
switch (content)
{
case TextContent textContent:
return ChatMessageContentPart.CreateTextPart(textContent.Text);

case UriContent uriContent when uriContent.HasTopLevelMediaType("image"):
return ChatMessageContentPart.CreateImagePart(uriContent.Uri, GetImageDetail(content));

case DataContent dataContent when dataContent.HasTopLevelMediaType("image"):
return ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType, GetImageDetail(content));

case DataContent dataContent when dataContent.HasTopLevelMediaType("audio"):
var audioData = BinaryData.FromBytes(dataContent.Data);
if (dataContent.MediaType.Equals("audio/mpeg", StringComparison.OrdinalIgnoreCase))
{
return ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Mp3);
}
else if (dataContent.MediaType.Equals("audio/wav", StringComparison.OrdinalIgnoreCase))
{
return ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Wav);
}

break;

case DataContent dataContent when dataContent.MediaType.StartsWith("application/pdf", StringComparison.OrdinalIgnoreCase):
return ChatMessageContentPart.CreateFilePart(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType, $"{Guid.NewGuid():N}.pdf");
}

return null;
}

private static ChatImageDetailLevel? GetImageDetail(AIContent content)
{
if (content.AdditionalProperties?.TryGetValue("detail", out object? value) is true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ public virtual async Task GetResponseAsync_WithEmptyMessage()

var response = await _chatClient.GetResponseAsync(
[
new(ChatRole.System, []),
new(ChatRole.User, []),
new(ChatRole.Assistant, []),
new(ChatRole.User, "What is 1 + 2? Reply with a single number."),
]);

Expand Down Expand Up @@ -618,9 +620,11 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputUnchange
var secondResponse = await chatClient.GetResponseAsync([message]);
Assert.Equal(response.Text, secondResponse.Text);
Assert.Equal(2, functionCallCount);
Assert.Equal(2, llmCallCount!.CallCount);
Assert.Equal(FunctionInvokingChatClientSetsConversationId ? 3 : 2, llmCallCount!.CallCount);
}

public virtual bool FunctionInvokingChatClientSetsConversationId => false;

[ConditionalFact]
public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputChangedAsync()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ public class OpenAIResponseClientIntegrationTests : ChatClientIntegrationTests
IntegrationTestHelpers.GetOpenAIClient()
?.GetOpenAIResponseClient(TestRunnerConfiguration.Instance["OpenAI:ChatModel"] ?? "gpt-4o-mini")
.AsIChatClient();

public override bool FunctionInvokingChatClientSetsConversationId => true;
}
Loading