Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"Name": "Microsoft.Extensions.AI.OpenAI, Version=10.5.0.0, Culture=neutral, PublicKeyToken=31bf3856ad364e35",
"Name": "Microsoft.Extensions.AI.OpenAI, Version=10.6.0.0, Culture=neutral, PublicKeyToken=31bf3856ad364e35",
"Types": [
{
"Type": "static class OpenAI.Assistants.MicrosoftExtensionsAIAssistantsExtensions",
Expand Down Expand Up @@ -208,6 +208,20 @@
"Stage": "Experimental"
}
]
},
{
"Type": "sealed class Microsoft.Extensions.AI.OpenAIRequestPolicies",
"Stage": "Experimental",
"Methods": [
{
"Member": "Microsoft.Extensions.AI.OpenAIRequestPolicies.OpenAIRequestPolicies();",
"Stage": "Experimental"
},
{
"Member": "void Microsoft.Extensions.AI.OpenAIRequestPolicies.AddPolicy(System.ClientModel.Primitives.PipelinePolicy policy, System.ClientModel.Primitives.PipelinePosition position = System.ClientModel.Primitives.PipelinePosition.PerCall);",
"Stage": "Experimental"
}
]
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#pragma warning disable CA1308 // Normalize strings to uppercase
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
#pragma warning disable SA1204 // Static elements should appear before instance elements
#pragma warning disable MEAI001 // OpenAIRequestPolicies is experimental

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -55,6 +56,9 @@ internal sealed partial class OpenAIChatClient : IChatClient
/// <summary>The underlying <see cref="ChatClient" />.</summary>
private readonly ChatClient _chatClient;

/// <summary>Caller-registered policies applied to every <see cref="RequestOptions"/>.</summary>
private readonly OpenAIRequestPolicies _requestPolicies = new();

/// <summary>Initializes a new instance of the <see cref="OpenAIChatClient"/> class for the specified <see cref="ChatClient"/>.</summary>
/// <param name="chatClient">The underlying client.</param>
/// <exception cref="ArgumentNullException"><paramref name="chatClient"/> is <see langword="null"/>.</exception>
Expand All @@ -76,6 +80,7 @@ public OpenAIChatClient(ChatClient chatClient)
serviceKey is not null ? null :
serviceType == typeof(ChatClientMetadata) ? _metadata :
serviceType == typeof(ChatClient) ? _chatClient :
serviceType == typeof(OpenAIRequestPolicies) ? _requestPolicies :
serviceType.IsInstanceOfType(this) ? this :
null;
}
Expand All @@ -94,7 +99,7 @@ public async Task<ChatResponse> GetResponseAsync(

// Make the call to OpenAI.
var task = _completeChatAsync is not null ?
_completeChatAsync(_chatClient, openAIChatMessages, openAIOptions, cancellationToken.ToRequestOptions(streaming: false)) :
_completeChatAsync(_chatClient, openAIChatMessages, openAIOptions, cancellationToken.ToRequestOptions(streaming: false, _requestPolicies)) :
_chatClient.CompleteChatAsync(openAIChatMessages, openAIOptions, cancellationToken);
var response = await task.ConfigureAwait(false);

Expand All @@ -115,7 +120,7 @@ public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(

// Make the call to OpenAI.
var chatCompletionUpdates = _completeChatStreamingAsync is not null ?
_completeChatStreamingAsync(_chatClient, openAIChatMessages, openAIOptions, cancellationToken.ToRequestOptions(streaming: true)) :
_completeChatStreamingAsync(_chatClient, openAIChatMessages, openAIOptions, cancellationToken.ToRequestOptions(streaming: true, _requestPolicies)) :
_chatClient.CompleteChatStreamingAsync(openAIChatMessages, openAIOptions, cancellationToken);

return FromOpenAIStreamingChatCompletionAsync(chatCompletionUpdates, openAIOptions, cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using OpenAI.Embeddings;

#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
#pragma warning disable MEAI001 // OpenAIRequestPolicies is experimental

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -40,6 +41,9 @@ internal sealed class OpenAIEmbeddingGenerator : IEmbeddingGenerator<string, Emb
/// <summary>The number of dimensions produced by the generator.</summary>
private readonly int? _dimensions;

/// <summary>Caller-registered policies applied to every <see cref="RequestOptions"/>.</summary>
private readonly OpenAIRequestPolicies _requestPolicies = new();

/// <summary>Initializes a new instance of the <see cref="OpenAIEmbeddingGenerator"/> class.</summary>
/// <param name="embeddingClient">The underlying client.</param>
/// <param name="defaultModelDimensions">The number of dimensions to generate in each embedding.</param>
Expand All @@ -66,7 +70,7 @@ public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(IEnumerab
OpenAI.Embeddings.EmbeddingGenerationOptions? openAIOptions = ToOpenAIOptions(options);

var t = _generateEmbeddingsAsync is not null ?
_generateEmbeddingsAsync(_embeddingClient, values, openAIOptions, cancellationToken.ToRequestOptions(streaming: false)) :
_generateEmbeddingsAsync(_embeddingClient, values, openAIOptions, cancellationToken.ToRequestOptions(streaming: false, _requestPolicies)) :
_embeddingClient.GenerateEmbeddingsAsync(values, openAIOptions, cancellationToken);
var embeddings = (await t.ConfigureAwait(false)).Value;

Expand Down Expand Up @@ -104,6 +108,7 @@ void IDisposable.Dispose()
serviceKey is not null ? null :
serviceType == typeof(EmbeddingGeneratorMetadata) ? _metadata :
serviceType == typeof(EmbeddingClient) ? _embeddingClient :
serviceType == typeof(OpenAIRequestPolicies) ? _requestPolicies :
serviceType.IsInstanceOfType(this) ? this :
null;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.ClientModel.Primitives;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using Microsoft.Shared.DiagnosticIds;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.AI;

/// <summary>
/// Provides an extension hook for adding <see cref="PipelinePolicy"/> instances to the
/// <see cref="RequestOptions"/> built by Microsoft.Extensions.AI for every outbound OpenAI request
/// made through the owning <c>IChatClient</c> or <c>IEmbeddingGenerator</c>.
/// </summary>
/// <remarks>
/// <para>
/// Retrieve the instance via <see cref="IChatClient.GetService(System.Type, object?)"/>
/// (or the equivalent on other Microsoft.Extensions.AI client interfaces) using
/// <see cref="OpenAIRequestPolicies"/> as the service type. The instance is per-client and
/// reachable through any <c>ChatClientBuilder</c> decorator chain.
/// </para>
/// <para>
/// Caller-registered policies are appended <em>after</em> Microsoft.Extensions.AI's own internal
/// policies, in the order in which they were registered, so a policy that calls
/// <c>message.Request.Headers.Set("User-Agent", ...)</c> replaces the existing value, while one
/// that calls <c>Headers.Add(...)</c> stacks an additional value.
/// </para>
/// <para>
/// Registration is intended for one-time configuration at startup, but is safe to call
/// concurrently with in-flight requests. A newly registered policy only affects
/// <see cref="RequestOptions"/> instances built after the registration completes; options that
/// were already built for an in-flight request are not modified.
/// </para>
/// </remarks>
[Experimental(DiagnosticIds.Experiments.AIOpenAIRequestPolicies, UrlFormat = DiagnosticIds.UrlFormat)]
public sealed class OpenAIRequestPolicies
{
    /// <summary>Shared empty array used as the initial state so that an unused instance allocates nothing.</summary>
    private static readonly Entry[] _empty = Array.Empty<Entry>();

    /// <summary>
    /// The registered entries. The array is never mutated after publication; it is replaced
    /// wholesale by <see cref="AddPolicy"/>, so readers can iterate a snapshot without locking.
    /// </summary>
    private Entry[] _entries = _empty;

    /// <summary>Initializes a new instance of the <see cref="OpenAIRequestPolicies"/> class.</summary>
    public OpenAIRequestPolicies()
    {
    }

    /// <summary>
    /// Adds a <see cref="PipelinePolicy"/> to be applied to every <see cref="RequestOptions"/>
    /// produced for outbound OpenAI requests by the owning Microsoft.Extensions.AI client.
    /// </summary>
    /// <param name="policy">The pipeline policy to register. Must not be <see langword="null"/>.</param>
    /// <param name="position">
    /// The position in the pipeline at which to place the policy. Defaults to
    /// <see cref="PipelinePosition.PerCall"/>, which runs the policy once per logical request
    /// (for example, to stamp a User-Agent or correlation header).
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="policy"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="position"/> is not a defined <see cref="PipelinePosition"/> value.
    /// </exception>
    public void AddPolicy(PipelinePolicy policy, PipelinePosition position = PipelinePosition.PerCall)
    {
        _ = Throw.IfNull(policy);

        // Validate eagerly: an undefined position would otherwise be stored silently and handed to
        // RequestOptions.AddPolicy by ApplyTo for every subsequent request, far away from the
        // registration call that introduced it.
        if (!Enum.IsDefined(typeof(PipelinePosition), position))
        {
            throw new ArgumentOutOfRangeException(nameof(position), position, "The value is not a defined pipeline position.");
        }

        var newEntry = new Entry(policy, position);

        // Lock-free append: build a new array containing the existing entries plus the new one and
        // publish it with a compare-and-swap, retrying if another thread published first.
        while (true)
        {
            var current = Volatile.Read(ref _entries);
            var updated = new Entry[current.Length + 1];
            Array.Copy(current, updated, current.Length);
            updated[current.Length] = newEntry;

            if (Interlocked.CompareExchange(ref _entries, updated, current) == current)
            {
                return;
            }
        }
    }

    /// <summary>
    /// Applies all registered policies, in registration order, to the supplied <see cref="RequestOptions"/>.
    /// Called by the Microsoft.Extensions.AI OpenAI clients after their own internal policies
    /// have been registered.
    /// </summary>
    /// <param name="requestOptions">The options instance to which the registered policies are added.</param>
    internal void ApplyTo(RequestOptions requestOptions)
    {
        // Read the array reference once; registrations racing with this call replace the field
        // rather than mutating the array, so the snapshot stays stable while it is enumerated.
        var snapshot = Volatile.Read(ref _entries);
        for (int i = 0; i < snapshot.Length; i++)
        {
            var entry = snapshot[i];
            requestOptions.AddPolicy(entry.Policy, entry.Position);
        }
    }

    /// <summary>Immutable pairing of a registered policy with the pipeline position it was registered for.</summary>
    private readonly struct Entry
    {
        public Entry(PipelinePolicy policy, PipelinePosition position)
        {
            Policy = policy;
            Position = position;
        }

        /// <summary>Gets the registered policy.</summary>
        public PipelinePolicy Policy { get; }

        /// <summary>Gets the pipeline position at which <see cref="Policy"/> is added.</summary>
        public PipelinePosition Position { get; }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
#pragma warning disable S3254 // Default parameter values should not be passed as arguments
#pragma warning disable SA1204 // Static elements should appear before instance elements
#pragma warning disable MEAI001 // OpenAIRequestPolicies is experimental

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -59,6 +60,9 @@ private static readonly Func<ResponsesClient, GetResponseOptions, RequestOptions
/// <summary>The default model ID to use for the chat client.</summary>
private readonly string? _defaultModelId;

/// <summary>Caller-registered policies applied to every <see cref="RequestOptions"/>.</summary>
private readonly OpenAIRequestPolicies _requestPolicies = new();

/// <summary>Initializes a new instance of the <see cref="OpenAIResponsesChatClient"/> class for the specified <see cref="ResponsesClient"/>.</summary>
/// <param name="responseClient">The underlying client.</param>
/// <param name="defaultModelId">The default model ID to use for the chat client.</param>
Expand All @@ -82,6 +86,7 @@ public OpenAIResponsesChatClient(ResponsesClient responseClient, string? default
serviceKey is not null ? null :
serviceType == typeof(ChatClientMetadata) ? _metadata :
serviceType == typeof(ResponsesClient) ? _responseClient :
serviceType == typeof(OpenAIRequestPolicies) ? _requestPolicies :
serviceType.IsInstanceOfType(this) ? this :
null;
}
Expand All @@ -100,7 +105,7 @@ public async Task<ChatResponse> GetResponseAsync(
// Provided continuation token signals that an existing background response should be fetched.
if (GetContinuationToken(messages, options) is { } token)
{
var getTask = _responseClient.GetResponseAsync(token.ResponseId, include: null, stream: null, startingAfter: null, includeObfuscation: null, cancellationToken.ToRequestOptions(streaming: false));
var getTask = _responseClient.GetResponseAsync(token.ResponseId, include: null, stream: null, startingAfter: null, includeObfuscation: null, cancellationToken.ToRequestOptions(streaming: false, _requestPolicies));
var response = (ResponseResult)await getTask.ConfigureAwait(false);
return FromOpenAIResponse(response, openAIOptions, openAIConversationId);
}
Expand All @@ -111,7 +116,7 @@ public async Task<ChatResponse> GetResponseAsync(
}

// Make the call to the ResponsesClient.
var createTask = _responseClient.CreateResponseAsync((BinaryContent)openAIOptions, cancellationToken.ToRequestOptions(streaming: false));
var createTask = _responseClient.CreateResponseAsync((BinaryContent)openAIOptions, cancellationToken.ToRequestOptions(streaming: false, _requestPolicies));
var openAIResponsesResult = (ResponseResult)await createTask.ConfigureAwait(false);

// Convert the response to a ChatResponse.
Expand Down Expand Up @@ -330,7 +335,7 @@ public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(

Debug.Assert(_getResponseStreamingAsync is not null, $"Unable to find {nameof(_getResponseStreamingAsync)} method");
IAsyncEnumerable<StreamingResponseUpdate> getUpdates = _getResponseStreamingAsync is not null ?
_getResponseStreamingAsync(_responseClient, getOptions, cancellationToken.ToRequestOptions(streaming: true)) :
_getResponseStreamingAsync(_responseClient, getOptions, cancellationToken.ToRequestOptions(streaming: true, _requestPolicies)) :
_responseClient.GetResponseStreamingAsync(getOptions, cancellationToken);

return FromOpenAIStreamingResponseUpdatesAsync(getUpdates, openAIOptions, openAIConversationId, token.ResponseId, cancellationToken);
Expand All @@ -343,7 +348,7 @@ public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(

Debug.Assert(_createResponseStreamingAsync is not null, $"Unable to find {nameof(_createResponseStreamingAsync)} method");
AsyncCollectionResult<StreamingResponseUpdate> createUpdates = _createResponseStreamingAsync is not null ?
_createResponseStreamingAsync(_responseClient, openAIOptions, cancellationToken.ToRequestOptions(streaming: true)) :
_createResponseStreamingAsync(_responseClient, openAIOptions, cancellationToken.ToRequestOptions(streaming: true, _requestPolicies)) :
_responseClient.CreateResponseStreamingAsync(openAIOptions, cancellationToken);

return FromOpenAIStreamingResponseUpdatesAsync(createUpdates, openAIOptions, openAIConversationId, cancellationToken: cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,23 @@
using System.Threading.Tasks;

#pragma warning disable CA1307 // Specify StringComparison
#pragma warning disable MEAI001 // OpenAIRequestPolicies is experimental

namespace Microsoft.Extensions.AI;

/// <summary>Provides utility methods for creating <see cref="RequestOptions"/>.</summary>
internal static class RequestOptionsExtensions
{
/// <summary>Creates a <see cref="RequestOptions"/> configured for use with OpenAI.</summary>
public static RequestOptions ToRequestOptions(this CancellationToken cancellationToken, bool streaming)
public static RequestOptions ToRequestOptions(this CancellationToken cancellationToken, bool streaming) =>
ToRequestOptions(cancellationToken, streaming, policies: null);

/// <summary>
/// Creates a <see cref="RequestOptions"/> configured for use with OpenAI, applying any
/// caller-registered <see cref="OpenAIRequestPolicies"/> after Microsoft.Extensions.AI's own
/// internal policies.
/// </summary>
public static RequestOptions ToRequestOptions(this CancellationToken cancellationToken, bool streaming, OpenAIRequestPolicies? policies)
{
RequestOptions requestOptions = new()
{
Expand All @@ -25,6 +34,8 @@ public static RequestOptions ToRequestOptions(this CancellationToken cancellatio

requestOptions.AddPolicy(MeaiUserAgentPolicy.Instance, PipelinePosition.PerCall);

policies?.ApplyTo(requestOptions);

return requestOptions;
}

Expand Down
1 change: 1 addition & 0 deletions src/Shared/DiagnosticIds/DiagnosticIds.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ internal static class Experiments
internal const string AIToolSearch = AIExperiments;
internal const string AIRealTime = AIExperiments;
internal const string AIFiles = AIExperiments;
internal const string AIOpenAIRequestPolicies = AIExperiments;

// These diagnostic IDs are defined by the OpenAI package for its experimental APIs.
// We use the same IDs so consumers do not need to suppress additional diagnostics
Expand Down
Loading
Loading