Skip to content

Commit fab7fe0

Browse files
stephentoubjeffhandley
authored andcommitted
Fix handling of tool calls with some OpenAI endpoints (#6405)
* Fix handling of tool calls with some endpoints Most assistant messages containing tool calls don't contain text as well (though some can). In such a case, we were still creating the assistant with empty text. While OpenAI's service permits that, some other endpoints are more finicky about it. This avoids doing so. * Reduce to single iteration through assistant content
1 parent 6100a72 commit fab7fe0

File tree

1 file changed

+75
-44
lines changed

1 file changed

+75
-44
lines changed

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs

Lines changed: 75 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -157,30 +157,54 @@ void IDisposable.Dispose()
157157
}
158158
else if (input.Role == ChatRole.Assistant)
159159
{
160-
AssistantChatMessage message = new(ToOpenAIChatContent(input.Contents))
161-
{
162-
ParticipantName = input.AuthorName
163-
};
164-
160+
List<ChatMessageContentPart>? contentParts = null;
161+
List<ChatToolCall>? toolCalls = null;
162+
string? refusal = null;
165163
foreach (var content in input.Contents)
166164
{
167165
switch (content)
168166
{
169-
case ErrorContent errorContent when errorContent.ErrorCode is nameof(message.Refusal):
170-
message.Refusal = errorContent.Message;
167+
case ErrorContent ec when ec.ErrorCode == nameof(AssistantChatMessage.Refusal):
168+
refusal = ec.Message;
171169
break;
172170

173-
case FunctionCallContent callRequest:
174-
message.ToolCalls.Add(
175-
ChatToolCall.CreateFunctionToolCall(
176-
callRequest.CallId,
177-
callRequest.Name,
178-
new(JsonSerializer.SerializeToUtf8Bytes(
179-
callRequest.Arguments,
180-
options.GetTypeInfo(typeof(IDictionary<string, object?>))))));
171+
case FunctionCallContent fc:
172+
(toolCalls ??= []).Add(
173+
ChatToolCall.CreateFunctionToolCall(fc.CallId, fc.Name, new(JsonSerializer.SerializeToUtf8Bytes(
174+
fc.Arguments, options.GetTypeInfo(typeof(IDictionary<string, object?>))))));
181175
break;
176+
177+
default:
178+
if (ToChatMessageContentPart(content) is { } part)
179+
{
180+
(contentParts ??= []).Add(part);
181+
}
182+
183+
break;
184+
}
185+
}
186+
187+
AssistantChatMessage message;
188+
if (contentParts is not null)
189+
{
190+
message = new(contentParts);
191+
if (toolCalls is not null)
192+
{
193+
foreach (var toolCall in toolCalls)
194+
{
195+
message.ToolCalls.Add(toolCall);
196+
}
182197
}
183198
}
199+
else
200+
{
201+
message = toolCalls is not null ?
202+
new(toolCalls) :
203+
new(ChatMessageContentPart.CreateTextPart(string.Empty));
204+
}
205+
206+
message.ParticipantName = input.AuthorName;
207+
message.Refusal = refusal;
184208

185209
yield return message;
186210
}
@@ -191,38 +215,12 @@ void IDisposable.Dispose()
191215
private static List<ChatMessageContentPart> ToOpenAIChatContent(IList<AIContent> contents)
192216
{
193217
List<ChatMessageContentPart> parts = [];
218+
194219
foreach (var content in contents)
195220
{
196-
switch (content)
221+
if (ToChatMessageContentPart(content) is { } part)
197222
{
198-
case TextContent textContent:
199-
parts.Add(ChatMessageContentPart.CreateTextPart(textContent.Text));
200-
break;
201-
202-
case UriContent uriContent when uriContent.HasTopLevelMediaType("image"):
203-
parts.Add(ChatMessageContentPart.CreateImagePart(uriContent.Uri, GetImageDetail(content)));
204-
break;
205-
206-
case DataContent dataContent when dataContent.HasTopLevelMediaType("image"):
207-
parts.Add(ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType, GetImageDetail(content)));
208-
break;
209-
210-
case DataContent dataContent when dataContent.HasTopLevelMediaType("audio"):
211-
var audioData = BinaryData.FromBytes(dataContent.Data);
212-
if (dataContent.MediaType.Equals("audio/mpeg", StringComparison.OrdinalIgnoreCase))
213-
{
214-
parts.Add(ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Mp3));
215-
}
216-
else if (dataContent.MediaType.Equals("audio/wav", StringComparison.OrdinalIgnoreCase))
217-
{
218-
parts.Add(ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Wav));
219-
}
220-
221-
break;
222-
223-
case DataContent dataContent when dataContent.MediaType.StartsWith("application/pdf", StringComparison.OrdinalIgnoreCase):
224-
parts.Add(ChatMessageContentPart.CreateFilePart(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType, $"{Guid.NewGuid():N}.pdf"));
225-
break;
223+
parts.Add(part);
226224
}
227225
}
228226

@@ -234,6 +232,39 @@ private static List<ChatMessageContentPart> ToOpenAIChatContent(IList<AIContent>
234232
return parts;
235233
}
236234

235+
private static ChatMessageContentPart? ToChatMessageContentPart(AIContent content)
236+
{
237+
switch (content)
238+
{
239+
case TextContent textContent:
240+
return ChatMessageContentPart.CreateTextPart(textContent.Text);
241+
242+
case UriContent uriContent when uriContent.HasTopLevelMediaType("image"):
243+
return ChatMessageContentPart.CreateImagePart(uriContent.Uri, GetImageDetail(content));
244+
245+
case DataContent dataContent when dataContent.HasTopLevelMediaType("image"):
246+
return ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType, GetImageDetail(content));
247+
248+
case DataContent dataContent when dataContent.HasTopLevelMediaType("audio"):
249+
var audioData = BinaryData.FromBytes(dataContent.Data);
250+
if (dataContent.MediaType.Equals("audio/mpeg", StringComparison.OrdinalIgnoreCase))
251+
{
252+
return ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Mp3);
253+
}
254+
else if (dataContent.MediaType.Equals("audio/wav", StringComparison.OrdinalIgnoreCase))
255+
{
256+
return ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Wav);
257+
}
258+
259+
break;
260+
261+
case DataContent dataContent when dataContent.MediaType.StartsWith("application/pdf", StringComparison.OrdinalIgnoreCase):
262+
return ChatMessageContentPart.CreateFilePart(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType, $"{Guid.NewGuid():N}.pdf");
263+
}
264+
265+
return null;
266+
}
267+
237268
private static ChatImageDetailLevel? GetImageDetail(AIContent content)
238269
{
239270
if (content.AdditionalProperties?.TryGetValue("detail", out object? value) is true)

0 commit comments

Comments
 (0)