Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
Expand All @@ -25,6 +28,27 @@ namespace Microsoft.Extensions.AI;
/// <summary>Represents an <see cref="IChatClient"/> for an OpenAI <see cref="OpenAIClient"/> or <see cref="ChatClient"/>.</summary>
internal sealed class OpenAIChatClient : IChatClient
{
// These delegate instances are used to call the internal overloads of CompleteChatAsync and CompleteChatStreamingAsync that accept
// a RequestOptions. These should be replaced once a better way to pass RequestOptions is available.
private static readonly Func<ChatClient, IEnumerable<OpenAI.Chat.ChatMessage>, ChatCompletionOptions, RequestOptions, Task<ClientResult<ChatCompletion>>>?
_completeChatAsync =
(Func<ChatClient, IEnumerable<OpenAI.Chat.ChatMessage>, ChatCompletionOptions, RequestOptions, Task<ClientResult<ChatCompletion>>>?)
typeof(ChatClient)
.GetMethod(
nameof(ChatClient.CompleteChatAsync), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance,
null, [typeof(IEnumerable<OpenAI.Chat.ChatMessage>), typeof(ChatCompletionOptions), typeof(RequestOptions)], null)
?.CreateDelegate(
typeof(Func<ChatClient, IEnumerable<OpenAI.Chat.ChatMessage>, ChatCompletionOptions, RequestOptions, Task<ClientResult<ChatCompletion>>>));
private static readonly Func<ChatClient, IEnumerable<OpenAI.Chat.ChatMessage>, ChatCompletionOptions, RequestOptions, AsyncCollectionResult<StreamingChatCompletionUpdate>>?
_completeChatStreamingAsync =
(Func<ChatClient, IEnumerable<OpenAI.Chat.ChatMessage>, ChatCompletionOptions, RequestOptions, AsyncCollectionResult<StreamingChatCompletionUpdate>>?)
typeof(ChatClient)
.GetMethod(
nameof(ChatClient.CompleteChatStreamingAsync), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance,
null, [typeof(IEnumerable<OpenAI.Chat.ChatMessage>), typeof(ChatCompletionOptions), typeof(RequestOptions)], null)
?.CreateDelegate(
typeof(Func<ChatClient, IEnumerable<OpenAI.Chat.ChatMessage>, ChatCompletionOptions, RequestOptions, AsyncCollectionResult<StreamingChatCompletionUpdate>>));

/// <summary>Metadata about the client.</summary>
private readonly ChatClientMetadata _metadata;

Expand Down Expand Up @@ -64,7 +88,10 @@ public async Task<ChatResponse> GetResponseAsync(
var openAIOptions = ToOpenAIOptions(options);

// Make the call to OpenAI.
var response = await _chatClient.CompleteChatAsync(openAIChatMessages, openAIOptions, cancellationToken).ConfigureAwait(false);
var task = _completeChatAsync is not null ?
_completeChatAsync(_chatClient, openAIChatMessages, openAIOptions, cancellationToken.ToRequestOptions(streaming: false)) :
_chatClient.CompleteChatAsync(openAIChatMessages, openAIOptions, cancellationToken);
var response = await task.ConfigureAwait(false);

return FromOpenAIChatCompletion(response.Value, openAIOptions);
}
Expand All @@ -79,7 +106,9 @@ public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
var openAIOptions = ToOpenAIOptions(options);

// Make the call to OpenAI.
var chatCompletionUpdates = _chatClient.CompleteChatStreamingAsync(openAIChatMessages, openAIOptions, cancellationToken);
var chatCompletionUpdates = _completeChatStreamingAsync is not null ?
_completeChatStreamingAsync(_chatClient, openAIChatMessages, openAIOptions, cancellationToken.ToRequestOptions(streaming: true)) :
_chatClient.CompleteChatStreamingAsync(openAIChatMessages, openAIOptions, cancellationToken);

return FromOpenAIStreamingChatCompletionAsync(chatCompletionUpdates, openAIOptions, cancellationToken);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
Expand Down Expand Up @@ -32,6 +33,23 @@ internal sealed class OpenAIResponsesChatClient : IChatClient
private static readonly Type? _internalResponseReasoningSummaryTextDeltaEventType = Type.GetType("OpenAI.Responses.InternalResponseReasoningSummaryTextDeltaEvent, OpenAI");
private static readonly PropertyInfo? _summaryTextDeltaProperty = _internalResponseReasoningSummaryTextDeltaEventType?.GetProperty("Delta");

// These delegate instances are used to call the internal overloads of CreateResponseAsync and CreateResponseStreamingAsync that accept
// a RequestOptions. These should be replaced once a better way to pass RequestOptions is available.
private static readonly Func<OpenAIResponseClient, IEnumerable<ResponseItem>, ResponseCreationOptions, RequestOptions, Task<ClientResult<OpenAIResponse>>>?
_createResponseAsync =
(Func<OpenAIResponseClient, IEnumerable<ResponseItem>, ResponseCreationOptions, RequestOptions, Task<ClientResult<OpenAIResponse>>>?)
typeof(OpenAIResponseClient).GetMethod(
nameof(OpenAIResponseClient.CreateResponseAsync), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance,
null, [typeof(IEnumerable<ResponseItem>), typeof(ResponseCreationOptions), typeof(RequestOptions)], null)
?.CreateDelegate(typeof(Func<OpenAIResponseClient, IEnumerable<ResponseItem>, ResponseCreationOptions, RequestOptions, Task<ClientResult<OpenAIResponse>>>));
private static readonly Func<OpenAIResponseClient, IEnumerable<ResponseItem>, ResponseCreationOptions, RequestOptions, AsyncCollectionResult<StreamingResponseUpdate>>?
_createResponseStreamingAsync =
(Func<OpenAIResponseClient, IEnumerable<ResponseItem>, ResponseCreationOptions, RequestOptions, AsyncCollectionResult<StreamingResponseUpdate>>?)
typeof(OpenAIResponseClient).GetMethod(
nameof(OpenAIResponseClient.CreateResponseStreamingAsync), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance,
null, [typeof(IEnumerable<ResponseItem>), typeof(ResponseCreationOptions), typeof(RequestOptions)], null)
?.CreateDelegate(typeof(Func<OpenAIResponseClient, IEnumerable<ResponseItem>, ResponseCreationOptions, RequestOptions, AsyncCollectionResult<StreamingResponseUpdate>>));

/// <summary>Metadata about the client.</summary>
private readonly ChatClientMetadata _metadata;

Expand Down Expand Up @@ -79,7 +97,10 @@ public async Task<ChatResponse> GetResponseAsync(
var openAIOptions = ToOpenAIResponseCreationOptions(options);

// Make the call to the OpenAIResponseClient.
var openAIResponse = (await _responseClient.CreateResponseAsync(openAIResponseItems, openAIOptions, cancellationToken).ConfigureAwait(false)).Value;
var task = _createResponseAsync is not null ?
_createResponseAsync(_responseClient, openAIResponseItems, openAIOptions, cancellationToken.ToRequestOptions(streaming: false)) :
_responseClient.CreateResponseAsync(openAIResponseItems, openAIOptions, cancellationToken);
var openAIResponse = (await task.ConfigureAwait(false)).Value;

// Convert the response to a ChatResponse.
return FromOpenAIResponse(openAIResponse, openAIOptions);
Expand Down Expand Up @@ -208,7 +229,9 @@ public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
var openAIResponseItems = ToOpenAIResponseItems(messages, options);
var openAIOptions = ToOpenAIResponseCreationOptions(options);

var streamingUpdates = _responseClient.CreateResponseStreamingAsync(openAIResponseItems, openAIOptions, cancellationToken);
var streamingUpdates = _createResponseStreamingAsync is not null ?
_createResponseStreamingAsync(_responseClient, openAIResponseItems, openAIOptions, cancellationToken.ToRequestOptions(streaming: true)) :
_responseClient.CreateResponseStreamingAsync(openAIResponseItems, openAIOptions, cancellationToken);

return FromOpenAIStreamingResponseUpdatesAsync(streamingUpdates, openAIOptions, cancellationToken);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;

#pragma warning disable CA1307 // Specify StringComparison

namespace Microsoft.Extensions.AI;

/// <summary>Provides utility methods for creating <see cref="RequestOptions"/>.</summary>
internal static class RequestOptionsExtensions
{
/// <summary>Creates a <see cref="RequestOptions"/> configured for use with OpenAI.</summary>
public static RequestOptions ToRequestOptions(this CancellationToken cancellationToken, bool streaming)
{
RequestOptions requestOptions = new()
{
CancellationToken = cancellationToken,
BufferResponse = !streaming
};

requestOptions.AddPolicy(MeaiUserAgentPolicy.Instance, PipelinePosition.PerCall);

return requestOptions;
}

/// <summary>Provides a pipeline policy that adds a "MEAI/x.y.z" user-agent header.</summary>
private sealed class MeaiUserAgentPolicy : PipelinePolicy
{
public static MeaiUserAgentPolicy Instance { get; } = new MeaiUserAgentPolicy();

private static readonly string _userAgentValue = CreateUserAgentValue();

public override void Process(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
{
AddUserAgentHeader(message);
ProcessNext(message, pipeline, currentIndex);
}

public override ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
{
AddUserAgentHeader(message);
return ProcessNextAsync(message, pipeline, currentIndex);
}

private static void AddUserAgentHeader(PipelineMessage message) =>
message.Request.Headers.Add("User-Agent", _userAgentValue);

private static string CreateUserAgentValue()
{
const string Name = "MEAI";

if (typeof(MeaiUserAgentPolicy).Assembly.GetCustomAttribute<AssemblyInformationalVersionAttribute>()?.InformationalVersion is string version)
{
int pos = version.IndexOf('+');
if (pos >= 0)
{
version = version.Substring(0, pos);
}

if (version.Length > 0)
{
return $"{Name}/{version}";
}
}

return Name;
}
}
}
Loading