diff --git a/.dotnet.azure/.tests.staging/AoaiTestBase.cs b/.dotnet.azure/.tests.staging/AoaiTestBase.cs deleted file mode 100644 index d6a229f20..000000000 --- a/.dotnet.azure/.tests.staging/AoaiTestBase.cs +++ /dev/null @@ -1,706 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.ClientModel; -using System.ClientModel.Primitives; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; -using Azure.Identity; - -//using Azure.AI.OpenAI.Tests.Models; -//using Azure.AI.OpenAI.Tests.Utils; -//using Azure.AI.OpenAI.Tests.Utils.Config; -//using Azure.Core.TestFramework; -//using Azure.Core.TestFramework.Models; -using NUnit.Framework; -using NUnit.Framework.Interfaces; -using OpenAI.Assistants; -using OpenAI.Audio; -using OpenAI.Batch; -using OpenAI.Chat; -using OpenAI.Embeddings; -using OpenAI.Files; -using OpenAI.FineTuning; -using OpenAI.Images; -//using OpenAI.Tests; -using OpenAI.VectorStores; -using RetryMode = Azure.Core.RetryMode; -using RetryOptions = Azure.Core.RetryOptions; -using TokenCredential = Azure.Core.TokenCredential; - -namespace Azure.AI.OpenAI.Tests; - -#pragma warning disable OPENAI001 - -public class AoaiTestBase // : RecordedTestBase -{ - private const string AZURE_URI_SANITIZER_PATTERN = @"(?<=/(subscriptions|resourceGroups|accounts)/)([^/]+?)(?=(/|$))"; - private const string SMALL_1x1_PNG = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAJcEhZcwAAFiQAABYkAZsVxhQAAAAMSURBVBhXY2BgYAAAAAQAAVzN/2kAAAAASUVORK5CYII="; - - public static readonly DateTimeOffset START_2024 = new DateTimeOffset(2024, 01, 01, 00, 00, 00, TimeSpan.Zero); - public static readonly DateTimeOffset UNIX_EPOCH = -#if NETFRAMEWORK - DateTimeOffset.Parse("1970-01-01T00:00:00.0000000+00:00"); -#else - DateTimeOffset.UnixEpoch; -#endif - - public AoaiTestBase(bool isAsync) { } - - public TClient GetTestClient() - { - AzureOpenAIClient topLevelClient = GetTestTopLevelClient(null); - string getDeployment() => "gpt-4"; - - object clientObject = null; - - switch (typeof(TClient).Name) - { - case nameof(AssistantClient): - clientObject = topLevelClient.GetAssistantClient(); - break; - case nameof(AudioClient): - clientObject = topLevelClient.GetAudioClient(getDeployment()); - break; - case nameof(BatchClient): - clientObject = topLevelClient.GetBatchClient(getDeployment()); - break; - case nameof(ChatClient): - clientObject = topLevelClient.GetChatClient(getDeployment()); - break; - case nameof(EmbeddingClient): - clientObject = topLevelClient.GetEmbeddingClient(getDeployment()); - break; - case nameof(FileClient): - clientObject = topLevelClient.GetFileClient(); - break; - case nameof(FineTuningClient): - clientObject = topLevelClient.GetFineTuningClient(); - break; - case nameof(ImageClient): - clientObject = topLevelClient.GetImageClient(getDeployment()); - break; - case nameof(VectorStoreClient): - clientObject = topLevelClient.GetVectorStoreClient(); - break; - } - - return (TClient)clientObject; - } - - // internal TestConfig TestConfig { get; } - // internal Assets Assets { get; } - // internal DisableRecordingInterceptor RecordingDisabler { get; } - - // protected AoaiTestBase(bool isAsync, RecordedTestMode? mode = null) - // : base(isAsync, mode) - // { - // TestConfig = new TestConfig(Mode); - // Assets = new Assets(TestEnvironment); - - // // Disable additional fluff that is causing issues - // TestDiagnostics = false; - - // // Add sanitizers to prevent resource names from leaking into recordings - // UriRegexSanitizers.Add(new UriRegexSanitizer(SanitizedJsonConfig.HOST_SUBDOMAIN_PATTERN) - // { - // Value = SanitizedJsonConfig.MASK_STRING - // }); - // UriRegexSanitizers.Add(new UriRegexSanitizer(AZURE_URI_SANITIZER_PATTERN) - // { - // Value = SanitizedJsonConfig.MASK_STRING - // }); - // HeaderRegexSanitizers.Add(new HeaderRegexSanitizer("Azure-AsyncOperation") - // { - // Regex = AZURE_URI_SANITIZER_PATTERN, - // Value = SanitizedJsonConfig.MASK_STRING - // }); - // HeaderRegexSanitizers.Add(new HeaderRegexSanitizer("Location") - // { - // Regex = AZURE_URI_SANITIZER_PATTERN, - // Value = SanitizedJsonConfig.MASK_STRING - // }); - // BodyKeySanitizers.Add(new BodyKeySanitizer("$..endpoint") - // { - // Regex = SanitizedJsonConfig.HOST_SUBDOMAIN_PATTERN, - // Value = SanitizedJsonConfig.MASK_STRING - // }); - // BodyKeySanitizers.Add(new BodyKeySanitizer("$..id") - // { - // Regex = AZURE_URI_SANITIZER_PATTERN, - // Value = SanitizedJsonConfig.MASK_STRING - // }); - - // // Add sanitizers to prevent our keys from leaking into the recordings - // JsonPathSanitizers.Add("*..key"); - // JsonPathSanitizers.Add("*..api_key"); - - // // Multi-part form data gives the test-proxy that is used for recording and playback indigestion (it always thinks it needs - // // to re-record the test on playback). So let's add an interceptor that will automatically disable body recording for specific - // // client methods calls, and then re-enable it afterwards. - // RecordingDisabler = new(() => Recording); - // RecordingDisabler.DisableBodyRecordingFor(nameof(FileClient.UploadFileAsync)); - - // IgnoredHeaders.Add("x-ms-client-request-id"); - - // // Data URIs trimmed to prevent the recording from being too large - // BodyKeySanitizers.Add(new BodyKeySanitizer("$..url") - // { - // Regex = @"(?<=data:image/png;base64,)(.+)", - // Value = SMALL_1x1_PNG - // }); - // // Base64 encoded images in the response are replaced with a 1x1 black pixel PNG image to ensure valid data - // BodyKeySanitizers.Add(new BodyKeySanitizer($"..b64_json") - // { - // Value = SMALL_1x1_PNG - // }); - // } - - /// - /// Gets the top level test client to use for testing. - /// - /// The test configuration to use - /// (Optional) The client options to use. - /// (Optional) The token credential to use. If this is null, an API key will be read from the - /// test configuration. - /// (Optional) The key credential to use instead of the one from the configuration. - public virtual AzureOpenAIClient GetTestTopLevelClient( - object config, // IConfiguration? config, - object options = null, // TestClientOptions? options = null, - TokenCredential tokenCredential = null, - ApiKeyCredential keyCredential = null) - { - string rawEndpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); - Uri endpoint = new(rawEndpoint); - TokenCredential credential = new DefaultAzureCredential(); - - return new AzureOpenAIClient(endpoint, credential); - // // First validate that the config has the parameters we need - // if (config == null) - // { - // throw CreateKeyNotFoundEx("any configuration"); - // } - // else if (config.Endpoint is null) - // { - // throw CreateKeyNotFoundEx("endpoint"); - // } - // else if (tokenCredential == null && keyCredential == null && string.IsNullOrEmpty(config.Key)) - // { - // throw CreateKeyNotFoundEx("API key"); - // } - - // // Configure the test options as needed - // options ??= new(); - // Action? requestAction = options.ShouldOutputRequests ? DumpRequest : null; - // Action? responseAction = options.ShouldOutputResponses ? DumpResponse : null; - // options.AddPolicy(new TestPipelinePolicy(requestAction, responseAction), PipelinePosition.PerCall); - - // // If we are in playback, or record mode we should set the transport to the test proxy transport, except - // // in the case where we've explicitly specified the transport ourselves. There are cases where we use a - // // mock pipeline and we don't want those to go to the test proxy. - // if (options.Transport == null) - // { - // // TODO FIXME update once test framework code is updated - // /* NOTE: - // * Normally we would call the base class RecordedTestBase.InstrumentClientOptions. Unfortunately this doesn't - // * currently work since the test framework still relies on a version of Azure.Core that has not been updated - // * to use the new System.ClientModel types. Thus InstrumentClientOptions expects a type that inherits from - // * Azure.Core.ClientOptions, whereas we inherit from System.ClientModel.Primitives.ClientPipelineOptions. For - // * now we duplicate the code from InstrumentClientOptions here - // */ - - // if (Mode == RecordedTestMode.Playback) - // { - // // You guessed it: the constructor for RetryOptions is internal only. So plan B: - // RetryOptions retryOpt = (RetryOptions)Activator.CreateInstance(typeof(RetryOptions), true)!; - - // // Not making the timeout zero so retry code still goes async - // retryOpt.Delay = TimeSpan.FromMilliseconds(10); - // retryOpt.Mode = RetryMode.Fixed; - - // options.RetryPolicy = new Utils.Pipeline.ClientRetryPolicyAdapter(retryOpt); - // } - - // // No need to set the transport if we are in Live mode - // if (Mode != RecordedTestMode.Live) - // { - // // Wait what's this? More private or internal only things I need access to? - // var proxyAccess = NonPublic.FromField("_proxy"); - // var disableRecordingAccess = NonPublic.FromField>("_disableRecording"); - - // options.Transport = new Utils.Pipeline.ProxyTransport( - // proxyAccess.Get(this), - // Recording, - // () => disableRecordingAccess.Get(Recording).Value); - // } - // } - - // AzureOpenAIClient topLevelClient; - // if (tokenCredential != null) - // { - // topLevelClient = new AzureOpenAIClient(config.Endpoint, tokenCredential, options); - // } - // else - // { - // topLevelClient = new AzureOpenAIClient(config.Endpoint, keyCredential ?? new ApiKeyCredential(config.Key!), options); - // } - - // return topLevelClient; - } - - // /// - // /// Gets the properly instrumented client to use for testing. This have proper support for automatic sync/async method testing, - // /// as well as recording, and playback support. - // /// - // /// (Optional) The client options to use. - // /// (Optional) The token credential to use. If this is null, an API key will be read from the - // /// test configuration. - // /// (Optional) The key credential to use instead of the one from the configuration. - // /// The test client instance. - // public virtual TClient GetTestClient(TestClientOptions? options = null, TokenCredential? tokenCredential = null, ApiKeyCredential? keyCredential = null) - // => GetTestClient(TestConfig.GetConfig(), options, tokenCredential, keyCredential); - - // /// - // /// Gets the properly instrumented client to use for testing. This have proper support for automatic sync/async method testing, - // /// as well as recording, and playback support. - // /// - // /// - // /// (Optional) The client options to use. - // /// (Optional) The token credential to use. If this is null, an API key will be read from the - // /// test configuration. - // /// (Optional) The key credential to use instead of the one from the configuration. - // /// The test client instance. - // public virtual TClient GetTestClient(string configName, TestClientOptions? options = null, TokenCredential? tokenCredential = null, ApiKeyCredential? keyCredential = null) - // => GetTestClient(TestConfig.GetConfig(configName), options, tokenCredential, keyCredential); - - // /// - // /// Gets a different type of client using the same configuration as the specified client. - // /// - // /// The type of other client to create. - // /// The client instance whose configuration we want to use. - // /// (Optional) The specific deployment to use instead of the one from the config. - // /// - // /// The client instance passed was not instrumented - // public virtual TExplicitClient GetTestClientFrom(TClient client, string? deploymentName = null) - // { - // AzureOpenAiInstrumented? instrumented = _clientToTopLevel.FirstOrDefault(e => ReferenceEquals(client, e.Client)); - // if (instrumented?.TopLevelClient != null - // && instrumented?.Config != null) - // { - // return GetTestClient(instrumented.TopLevelClient, instrumented.Config, deploymentName); - // } - - // throw new NotSupportedException("The client provided was not properly instrumented. Please make sure to get your test client " + - // "instances using the GetTestClient() methods"); - // } - - // /// - // /// Disables the recording of request bodies for the specified method in the current client. - // /// - // /// The method name. - // public virtual void DisableRequestBodyRecording(string methodName) - // => RecordingDisabler.DisableBodyRecordingFor(methodName); - - // /// - // /// Polls until a condition has been met with a maximum wait time. The function will always return the last value even - // /// if the condition was not met. - // /// - // /// The value in the . - // /// The initial value. - // /// The asynchronous function to get the latest state of the value. - // /// When we should stop waiting. - // /// (Optional) The amount of time to wait between retries. This will be ignored in playback - // /// mode. Default is 2 seconds. - // /// (Optional) The maximum amount of time to wait until the condition becomes true. This will be ignored in - // /// playback mode. The default is 2 minutes. - // /// The final state. This will return when the conditions have been met or we timed out. - // protected virtual Task WaitUntilReturnLast(T initialValue, Func>> getAsync, Predicate stopCondition, TimeSpan? waitTimeBetweenRequests = null, TimeSpan? maxWait = null) - // => WaitUntilReturnLast(initialValue, new Func>(async () => await getAsync().ConfigureAwait(false)), stopCondition, waitTimeBetweenRequests, maxWait); - - // /// - // /// Polls until a condition has been met with a maximum wait time. The function will always return the last value even - // /// if the condition was not met. - // /// - // /// The return value. - // /// The initial value. - // /// The asynchronous function to get the latest state of the value. - // /// When we should stop waiting. - // /// (Optional) The amount of time to wait between retries. This will be ignored in playback - // /// mode. Default is 2 seconds. - // /// (Optional) The maximum amount of time to wait until the condition becomes true. This will be ignored in - // /// playback mode. The default is 2 minutes. - // /// The final state. This will return when the conditions have been met or we timed out. - // protected virtual async Task WaitUntilReturnLast(T initialValue, Func> getAsync, Predicate stopCondition, TimeSpan? waitTimeBetweenRequests = null, TimeSpan? maxWait = null) - // { - // TimeSpan delay, max; - // if (Mode == RecordedTestMode.Playback) - // { - // delay = TimeSpan.FromMilliseconds(10); - // max = TimeSpan.FromSeconds(30); - // } - // else - // { - // delay = waitTimeBetweenRequests ?? TimeSpan.FromSeconds(2); - // max = maxWait ?? TimeSpan.FromMinutes(2); - // } - - // DateTimeOffset stopTime = DateTimeOffset.Now + max; - // T result = initialValue; - - // while (!stopCondition(result) && DateTimeOffset.Now < stopTime) - // { - // await Task.Delay(delay).ConfigureAwait(false); - // result = await getAsync().ConfigureAwait(false); - // } - - // return result; - // } - - // /// - // /// Gets the properly instrumented client to use for testing. This have proper support for automatic sync/async method testing, - // /// as well as recording, and playback support. - // /// - // /// The test configuration to use - // /// (Optional) The client options to use. - // /// (Optional) The token credential to use. If this is null, an API key will be read from the - // /// test configuration. - // /// (Optional) The key credential to use instead of the one from the configuration. - // /// The test client instance. - // protected virtual TClient GetTestClient(IConfiguration? config, TestClientOptions? options = null, TokenCredential? tokenCredential = null, ApiKeyCredential? keyCredential = null) - // { - // AzureOpenAIClient topLevelClient = GetTestTopLevelClient(config, options, tokenCredential, keyCredential); - // return GetTestClient(topLevelClient, config!); - // } - - // /// - // /// Gets the properly instrumented client to use for testing. This have proper support for automatic sync/async method testing, - // /// as well as recording, and playback support. - // /// - // /// The type of test client to get. - // /// The top level client to use. - // /// The configuration to use to get the deployment information (if needed). - // /// The instrumented client instance to use. - // /// Support for the type of client being requested has not been implemented yet. - // protected virtual TExplicitClient GetTestClient(AzureOpenAIClient topLevelClient, IConfiguration config, string? deploymentName = null) - // { - // Func getDeployment = () => deploymentName ?? config?.Deployment ?? throw CreateKeyNotFoundEx("deployment"); - // object clientObject; - - // switch (typeof(TExplicitClient).Name) - // { - // case nameof(AssistantClient): - // clientObject = topLevelClient.GetAssistantClient(); - // break; - // case nameof(AudioClient): - // clientObject = topLevelClient.GetAudioClient(getDeployment()); - // break; - // case nameof(BatchClient): - // clientObject = topLevelClient.GetBatchClient(getDeployment()); - // break; - // case nameof(ChatClient): - // clientObject = topLevelClient.GetChatClient(getDeployment()); - // break; - // case nameof(EmbeddingClient): - // clientObject = topLevelClient.GetEmbeddingClient(getDeployment()); - // break; - // case nameof(FileClient): - // clientObject = topLevelClient.GetFileClient(); - // break; - // case nameof(FineTuningClient): - // clientObject = topLevelClient.GetFineTuningClient(); - // break; - // case nameof(ImageClient): - // clientObject = topLevelClient.GetImageClient(getDeployment()); - // break; - // case nameof(VectorStoreClient): - // clientObject = topLevelClient.GetVectorStoreClient(); - // break; - // case nameof(AzureDeploymentClient): - // var accessor = NonPublic.FromField("_transport"); - // clientObject = new AzureDeploymentClient( - // config, - // TestEnvironment.Credential, - // transport: accessor.Get(topLevelClient.Pipeline)); - // break; - // default: - // throw new NotImplementedException($"Test client helpers not yet implemented for {typeof(TExplicitClient)}"); - // }; - - // object instrumented = InstrumentClient(typeof(TExplicitClient), clientObject, [RecordingDisabler]); - - // // Keep track of the corresponding top level client and config - // _clientToTopLevel.Add(new AzureOpenAiInstrumented - // { - // Client = instrumented, - // TopLevelClient = topLevelClient, - // Config = config, - // }); - - // return (TExplicitClient)instrumented; - // } - - // private Exception CreateKeyNotFoundEx(string whatIsMissing) - // { - // return new KeyNotFoundException($"Could not find any {whatIsMissing} to use. Please make sure you have the necessary" + - // $" {TestConfig.AssetsJson} config file, or have the needed environment variables set"); - // } - - // private static void DumpRequest(PipelineRequest request) - // { - // Console.WriteLine($"--- New request ---"); - // string headers = request.Headers - // .Select(header => $"{header.Key}={(header.Key.ToLower().Contains("auth") ? "***" : header.Value)}") - // .Aggregate(string.Empty, (current, next) => string.Format("{0},{1}", current, next)); - // Console.WriteLine($"Headers: {headers}"); - // Console.WriteLine($"{request.Method} URI: {request?.Uri}"); - // if (request!.Content is not null) - // { - // using MemoryStream stream = new(); - // request.Content.WriteTo(stream, default); - // stream.Position = 0; - // using StreamReader reader = new(stream); - // Console.WriteLine(reader.ReadToEnd()); - // } - // } - - // private static void DumpResponse(PipelineResponse response) - // { - // Console.WriteLine($"--- Response --- "); - // } - - protected void ValidateById(string id) - { - Assert.That(id, Is.Not.Null.Or.Empty); - switch (typeof(T).Name) - { - case nameof(Assistant): _assistantIdsToDelete.Add(id); break; - case nameof(AssistantThread): _threadIdsToDelete.Add(id); break; - case nameof(OpenAIFileInfo): _fileIdsToDelete.Add(id); break; - case nameof(ThreadRun): break; - case nameof(VectorStore): _vectorStoreIdsToDelete.Add(id); break; - default: throw new NotImplementedException(); - } - } - - protected void ValidateById(string id, string parentId) - { - Assert.That(id, Is.Not.Null.Or.Empty); - Assert.That(parentId, Is.Not.Null.Or.Empty); - switch (typeof(T).Name) - { - case nameof(ThreadMessage): - _threadIdsWithMessageIdsToDelete.Add((parentId, id)); - break; - case nameof(VectorStoreFileAssociation): - _vectorStoreFileAssociationsToRemove.Add((parentId, id)); - break; - default: - throw new NotImplementedException(); - } - } - - /// - /// Performs basic, invariant validation of a target that was just instantiated from its corresponding origination - /// mechanism. If applicable, the instance is recorded into the test run for cleanup of persistent resources. - /// - /// Instance type being validated. - /// The instance to validate. - /// The provided instance type isn't supported. - protected void Validate(T target) - { - if (target is ThreadMessage message) - { - ValidateById(message.Id, message.ThreadId); - } - else if (target is VectorStoreFileAssociation fileAssociation) - { - ValidateById(fileAssociation.VectorStoreId, fileAssociation.FileId); - } - else - { - ValidateById(target switch - { - Assistant assistant => assistant.Id, - AssistantThread thread => thread.Id, - OpenAIFileInfo file => file.Id, - ThreadRun run => run.Id, - VectorStore store => store.Id, - _ => throw new NotImplementedException(), - }); - } - } - - [TearDown] - protected void Cleanup() - { - AzureOpenAIClient topLevelCleanupClient = GetTestTopLevelClient(null); - //AzureOpenAIClient topLevelCleanupClient = GetTestTopLevelClient(TestConfig.GetConfig(), new() - //{ - // ShouldOutputRequests = false, - // ShouldOutputResponses = false, - //}); - AssistantClient client = topLevelCleanupClient.GetAssistantClient(); - VectorStoreClient vectorStoreClient = topLevelCleanupClient.GetVectorStoreClient(); - FileClient fileClient = topLevelCleanupClient.GetFileClient(); - RequestOptions requestOptions = new() { ErrorOptions = ClientErrorBehaviors.NoThrow, }; - foreach ((string threadId, string messageId) in _threadIdsWithMessageIdsToDelete) - { - Console.WriteLine($"Cleanup: {messageId} -> {client.DeleteMessage(threadId, messageId, requestOptions)?.GetRawResponse().Status}"); - } - foreach (string assistantId in _assistantIdsToDelete) - { - Console.WriteLine($"Cleanup: {assistantId} -> {client.DeleteAssistant(assistantId, requestOptions)?.GetRawResponse().Status}"); - } - foreach (string threadId in _threadIdsToDelete) - { - Console.WriteLine($"Cleanup: {threadId} -> {client.DeleteThread(threadId, requestOptions)?.GetRawResponse().Status}"); - } - foreach ((string vectorStoreId, string fileId) in _vectorStoreFileAssociationsToRemove) - { - Console.WriteLine($"Cleanup: {vectorStoreId}<->{fileId} => {vectorStoreClient.RemoveFileFromStore(vectorStoreId, fileId, requestOptions)?.GetRawResponse().Status}"); - } - foreach (string vectorStoreId in _vectorStoreIdsToDelete) - { - Console.WriteLine($"Cleanup: {vectorStoreId} => {vectorStoreClient.DeleteVectorStore(vectorStoreId, requestOptions)?.GetRawResponse().Status}"); - } - foreach (string fileId in _fileIdsToDelete) - { - Console.WriteLine($"Cleanup: {fileId} -> {fileClient.DeleteFile(fileId, requestOptions)?.GetRawResponse().Status}"); - } - _threadIdsWithMessageIdsToDelete.Clear(); - _assistantIdsToDelete.Clear(); - _threadIdsToDelete.Clear(); - _vectorStoreFileAssociationsToRemove.Clear(); - _vectorStoreIdsToDelete.Clear(); - _fileIdsToDelete.Clear(); - - // If we are in recording mode, update the recorded playback configuration as well - //if (Mode == RecordedTestMode.Record - // && TestContext.CurrentContext.Result.Outcome == ResultState.Success) - //{ - // TestConfig.SavePlaybackConfig(); - //} - } - - protected static void ValidateClientResult(ClientResult result) - { - Assert.That(result, Is.Not.Null); - Assert.That(result.GetRawResponse(), Is.Not.Null); - } - - protected static PipelineResponse ValidateClientResultResponse(ClientResult result) - { - ValidateClientResult(result); - - PipelineResponse response = result.GetRawResponse(); - Assert.That(response.Status, Is.GreaterThanOrEqualTo(200).And.LessThan(300)); - Assert.That(response.Headers, Is.Not.Null); - // Assert.That(response.Headers.GetFirstValueOrDefault("Content-Type"), Does.StartWith("application/json")); - Assert.That(response.Content, Is.Not.Null); - - return response; - } - - protected virtual TModel ValidateAndParse(ClientResult result) where TModel : IJsonModel - { - var response = ValidateClientResultResponse(result); - - TModel model = ModelReaderWriter.Read(response.Content, ModelReaderWriterOptions.Json); - Assert.That(model, Is.Not.Null); - return model!; - } - - //protected virtual TModel ValidateAndParse(ClientResult result, JsonSerializerOptions? options = null) - //{ - // var response = ValidateClientResultResponse(result); - - // using Stream stream = response.Content.ToStream(); - // Assert.That(stream, Is.Not.Null); - - // TModel? model = JsonHelpers.Deserialize(stream, options ?? JsonHelpers.OpenAIJsonOptions); - // Assert.That(model, Is.Not.Null); - // return model!; - //} - - // protected AsyncCollectionResult SyncOrAsync(TClient client, Func> sync, Func> async) - // { - // // TODO FIXME HACK Since the test framework doesn't currently support async result collection, this methods provides - // // a simplified way to make explicit calls to the right methods in tests - // TClient rawClient = GetOriginal(client); - - // if (IsAsync) - // { - // return async(rawClient); - // } - // else - // { - // CollectionResult syncCollection = sync(rawClient); - // return new SyncToAsyncCollectionResult(syncCollection); - // } - // } - - // protected AsyncPageCollection SyncOrAsync(TClient client, Func> sync, Func> async) - // { - // // TODO FIXME HACK Since the test framework doesn't currently support async result collection, this methods provides - // // a simplified way to make explicit calls to the right methods in tests - // TClient rawClient = GetOriginal(client); - - // if (IsAsync) - // { - // return async(rawClient); - // } - // else - // { - // PageCollection syncCollection = sync(rawClient); - // return new SyncToAsyncPageCollection(syncCollection); - // } - // } - - // protected Task> SyncOrAsyncList(TClient client, Func> sync, Func> async) - // { - // // TODO FIXME HACK Since the test framework doesn't currently support async result collection, this methods provides - // // a simplified way to make explicit calls to the right methods in tests - // TClient rawClient = GetOriginal(client); - - // if (IsAsync) - // { - // return async(rawClient).ToEnumerableAsync(); - // } - // else - // { - // return Task.FromResult(sync(rawClient).ToList()); - // } - // } - - // internal class AzureOpenAiInstrumented - // { - // required public object Client { get; init; } - // required public AzureOpenAIClient TopLevelClient { get; init; } - // required public IConfiguration Config { get; init; } - // } - - private readonly List _assistantIdsToDelete = []; - private readonly List _threadIdsToDelete = []; - private readonly List<(string, string)> _threadIdsWithMessageIdsToDelete = []; - private readonly List _fileIdsToDelete = []; - private readonly List<(string, string)> _vectorStoreFileAssociationsToRemove = []; - private readonly List _vectorStoreIdsToDelete = []; - // internal readonly List _clientToTopLevel = new(); - //} - - //public class TestClientOptions : AzureOpenAIClientOptions - //{ - // public TestClientOptions() : base() - // { } - - // public TestClientOptions(ServiceVersion version) : base(version) - // { } - - // public bool ShouldOutputRequests { get; set; } = true; - // public bool ShouldOutputResponses { get; set; } = true; -} diff --git a/.dotnet.azure/.tests.staging/AssistantTests.cs b/.dotnet.azure/.tests.staging/AssistantTests.cs deleted file mode 100644 index 115428724..000000000 --- a/.dotnet.azure/.tests.staging/AssistantTests.cs +++ /dev/null @@ -1,674 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#nullable disable - -using System; -using System.ClientModel; -using System.ClientModel.Primitives; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Azure.AI.OpenAI.Assistants; -using Azure.Core; -using Azure.Identity; -using NUnit.Framework; -using OpenAI; -using OpenAI.Assistants; -using OpenAI.Audio; -using OpenAI.Batch; -using OpenAI.Chat; -using OpenAI.Embeddings; -using OpenAI.Files; -using OpenAI.FineTuning; -using OpenAI.Images; -using OpenAI.VectorStores; - -namespace Azure.AI.OpenAI.Tests; - -#pragma warning disable OPENAI001 -#pragma warning disable AOAI001 - -public class AssistantTests : AoaiTestBase -{ - public AssistantTests() : base(isAsync: true) { } - public AssistantTests(bool isAsync) : base(isAsync) { } - - [Test] - [Category("Smoke")] - public void CanCreateClient() => Assert.That(GetTestClient(), Is.InstanceOf()); - - [Test] - public async Task BasicAssistantOperationsWork() - { - AssistantClient client = GetTestClient(); - string modelName = "gpt-4"; // client.DeploymentOrThrow(); - Assistant assistant = await client.CreateAssistantAsync(modelName); - Validate(assistant); - Assert.That(assistant.Name, Is.Null.Or.Empty); - assistant = await client.ModifyAssistantAsync(assistant.Id, new AssistantModificationOptions() - { - Name = "test assistant name", - }); - Assert.That(assistant.Name, Is.EqualTo("test assistant name")); - bool deleted = await client.DeleteAssistantAsync(assistant.Id); - Assert.That(deleted, Is.True); - assistant = await client.CreateAssistantAsync(modelName, new AssistantCreationOptions() - { - Metadata = - { - ["testkey"] = "hello!" - }, - }); - Validate(assistant); - Assistant retrievedAssistant = await client.GetAssistantAsync(assistant.Id); - Assert.That(retrievedAssistant.Id, Is.EqualTo(assistant.Id)); - Assert.That(retrievedAssistant.Metadata.TryGetValue("testkey", out string metadataValue) && metadataValue == "hello!"); - Assistant modifiedAssistant = await client.ModifyAssistantAsync(assistant.Id, new AssistantModificationOptions() - { - Metadata = - { - ["testkey"] = "goodbye!", - }, - }); - Assert.That(modifiedAssistant.Id, Is.EqualTo(assistant.Id)); - IAsyncEnumerable recentAssistants = client.GetAssistantsAsync().GetAllValuesAsync(); - //SyncOrAsync( - // client, c => c.GetAssistants(), c => c.GetAssistantsAsync()); - Assistant recentAssistant = null; - await foreach (Assistant asyncAssistant in recentAssistants) - { - recentAssistant = asyncAssistant; - break; - } - Assert.That(recentAssistant, Is.Not.Null); - Assert.That(recentAssistant.Metadata.TryGetValue("testkey", out string newMetadataValue) && newMetadataValue == "goodbye!"); - } - - // [RecordedTest] - // public async Task BasicThreadOperationsWork() - // { - // AssistantClient client = GetTestClient(); - // AssistantThread thread = await client.CreateThreadAsync(); - // Validate(thread); - // Assert.That(thread.CreatedAt, Is.GreaterThan(s_2024)); - // bool deleted = await client.DeleteThreadAsync(thread.Id); - // Assert.That(deleted, Is.True); - - // ThreadCreationOptions options = new() - // { - // Metadata = - // { - // ["threadMetadata"] = "threadMetadataValue", - // } - // }; - // thread = await client.CreateThreadAsync(options); - // Validate(thread); - // Assert.That(thread.Metadata.TryGetValue("threadMetadata", out string threadMetadataValue) && threadMetadataValue == "threadMetadataValue"); - // AssistantThread retrievedThread = await client.GetThreadAsync(thread.Id); - // Assert.That(retrievedThread.Id, Is.EqualTo(thread.Id)); - // thread = await client.ModifyThreadAsync(thread, new ThreadModificationOptions() - // { - // Metadata = - // { - // ["threadMetadata"] = "newThreadMetadataValue", - // }, - // }); - // Assert.That(thread.Metadata.TryGetValue("threadMetadata", out threadMetadataValue) && threadMetadataValue == "newThreadMetadataValue"); - // } - - // [RecordedTest] - // public async Task SettingResponseFormatWorks() - // { - // AssistantClient client = GetTestClient(); - // string modelName = client.DeploymentOrThrow(); - - // Assistant assistant = await client.CreateAssistantAsync(modelName, new() - // { - // ResponseFormat = AssistantResponseFormat.JsonObject, - // }); - // Validate(assistant); - // Assert.That(assistant.ResponseFormat, Is.EqualTo(AssistantResponseFormat.JsonObject)); - // assistant = await client.ModifyAssistantAsync(assistant, new() - // { - // ResponseFormat = AssistantResponseFormat.Text, - // }); - // Assert.That(assistant.ResponseFormat, Is.EqualTo(AssistantResponseFormat.Text)); - // AssistantThread thread = await client.CreateThreadAsync(); - // Validate(thread); - // ThreadMessage message = await client.CreateMessageAsync(thread.Id, MessageRole.User, ["Write some JSON for me!"]); - // Validate(message); - // ThreadRun run = await client.CreateRunAsync(thread, assistant, new() - // { - // ResponseFormat = AssistantResponseFormat.JsonObject, - // }); - // Validate(run); - // Assert.That(run.ResponseFormat, Is.EqualTo(AssistantResponseFormat.JsonObject)); - // } - - // [RecordedTest] - // public async Task StreamingToolCall() - // { - // AssistantClient client = GetTestClient(); - // string modelName = client.DeploymentOrThrow(); - // FunctionToolDefinition getWeatherTool = new("get_current_weather", "Gets the user's current weather"); - // Assistant assistant = await client.CreateAssistantAsync(modelName, new() - // { - // Tools = { getWeatherTool } - // }); - // Validate(assistant); - - // Stopwatch stopwatch = Stopwatch.StartNew(); - // void Print(string message) => Console.WriteLine($"[{stopwatch.ElapsedMilliseconds,6}] {message}"); - - // Print(" >>> Beginning call ... "); - - // ThreadCreationOptions thrdOpt = new() - // { - // InitialMessages = { new(MessageRole.User, ["What should I wear outside right now?"]), }, - // }; - // AsyncResultCollection asyncResults = SyncOrAsync(client, - // c => c.CreateThreadAndRunStreaming(assistant, thrdOpt), - // c => c.CreateThreadAndRunStreamingAsync(assistant, thrdOpt)); - - // Print(" >>> Starting enumeration ..."); - - // ThreadRun run = null; - - // do - // { - // run = null; - // List toolOutputs = new(); - // await foreach (StreamingUpdate update in asyncResults) - // { - // string message = update.UpdateKind.ToString(); - - // if (update is RunUpdate runUpdate) - // { - // message += $" run_id:{runUpdate.Value.Id}"; - // run = runUpdate.Value; - // } - // if (update is RequiredActionUpdate requiredActionUpdate) - // { - // Assert.That(requiredActionUpdate.FunctionName, Is.EqualTo(getWeatherTool.FunctionName)); - // Assert.That(requiredActionUpdate.GetThreadRun().Status, Is.EqualTo(RunStatus.RequiresAction)); - // message += $" {requiredActionUpdate.FunctionName}"; - // toolOutputs.Add(new(requiredActionUpdate.ToolCallId, "warm and sunny")); - // } - // if (update is MessageContentUpdate contentUpdate) - // { - // message += $" {contentUpdate.Text}"; - // } - // Print(message); - // } - // if (toolOutputs.Count > 0) - // { - // asyncResults = SyncOrAsync(client, - // c => c.SubmitToolOutputsToRunStreaming(run, toolOutputs), - // c => c.SubmitToolOutputsToRunStreamingAsync(run, toolOutputs)); - // } - // } while (run?.Status.IsTerminal == false); - // } - - // [RecordedTest] - // public async Task BasicMessageOperationsWork() - // { - // // TODO FIXME Can't currently delete messages on AOAI - // bool aoaiDeleteBugFixed = false; - - // AssistantClient client = GetTestClient(); - // AssistantThread thread = await client.CreateThreadAsync(); - // Validate(thread); - // ThreadMessage message = await client.CreateMessageAsync(thread.Id, MessageRole.User, ["Hello, world!"]); - // Validate(message); - // Assert.That(message.CreatedAt, Is.GreaterThan(s_2024)); - // Assert.That(message.Content?.Count, Is.EqualTo(1)); - // Assert.That(message.Content[0], Is.Not.Null); - // Assert.That(message.Content[0].Text, Is.EqualTo("Hello, world!")); - - // if (aoaiDeleteBugFixed) - // { - // bool deleted = await client.DeleteMessageAsync(message); - // Assert.That(deleted, Is.True); - // } - - // message = await client.CreateMessageAsync(thread.Id, MessageRole.User, ["Goodbye, world!"], new MessageCreationOptions() - // { - // Metadata = - // { - // ["messageMetadata"] = "messageMetadataValue", - // }, - // }); - // Validate(message); - // Assert.That(message.Metadata.TryGetValue("messageMetadata", out string metadataValue) && metadataValue == "messageMetadataValue"); - - // ThreadMessage retrievedMessage = await client.GetMessageAsync(thread.Id, message.Id); - // Assert.That(retrievedMessage.Id, Is.EqualTo(message.Id)); - - // message = await client.ModifyMessageAsync(message, new MessageModificationOptions() - // { - // Metadata = - // { - // ["messageMetadata"] = "newValue", - // } - // }); - // Assert.That(message.Metadata.TryGetValue("messageMetadata", out metadataValue) && metadataValue == "newValue"); - - // var messagePage = await SyncOrAsyncList(client, - // c => c.GetMessages(thread), - // c => c.GetMessagesAsync(thread)); - // if (aoaiDeleteBugFixed) - // { - // Assert.That(messagePage.Count, Is.EqualTo(1)); - // } - // else - // { - // Assert.That(messagePage.Count, Is.EqualTo(2)); - // } - - // Assert.That(messagePage.ElementAt(0).Id, Is.EqualTo(message.Id)); - // Assert.That(messagePage.ElementAt(0).Metadata.TryGetValue("messageMetadata", out metadataValue) && metadataValue == "newValue"); - // } - - // [RecordedTest] - // public async Task ThreadWithInitialMessagesWorks() - // { - // AssistantClient client = GetTestClient(); - // ThreadCreationOptions options = new() - // { - // InitialMessages = - // { - // new ThreadInitializationMessage(MessageRole.User, ["Hello, world!"]), - // new ThreadInitializationMessage(MessageRole.User, - // [ - // "Can you describe this image for me?", - // MessageContent.FromImageUrl(new Uri("https://test.openai.com/image.png")) - // ]) - // { - // Metadata = - // { - // ["messageMetadata"] = "messageMetadataValue", - // }, - // }, - // }, - // }; - // AssistantThread thread = await client.CreateThreadAsync (options); - // Validate(thread); - // List messageList = await SyncOrAsyncList(client, - // c => c.GetMessages(thread, resultOrder: ListOrder.OldestFirst), - // c => c.GetMessagesAsync(thread, resultOrder: ListOrder.OldestFirst)); - // Assert.That(messageList.Count, Is.EqualTo(2)); - // Assert.That(messageList[0].Role, Is.EqualTo(MessageRole.User)); - // Assert.That(messageList[0].Content?.Count, Is.EqualTo(1)); - // Assert.That(messageList[0].Content[0].Text, Is.EqualTo("Hello, world!")); - // Assert.That(messageList[1].Content?.Count, Is.EqualTo(2)); - // Assert.That(messageList[1].Content[0], Is.Not.Null); - // Assert.That(messageList[1].Content[0].Text, Is.EqualTo("Can you describe this image for me?")); - // Assert.That(messageList[1].Content[1], Is.Not.Null); - // Assert.That(messageList[1].Content[1].ImageUrl.AbsoluteUri, Is.EqualTo("https://test.openai.com/image.png")); - // } - - // [RecordedTest] - // public async Task BasicRunOperationsWork() - // { - // AssistantClient client = GetTestClient(); - // string modelName = client.DeploymentOrThrow(); - // Assistant assistant = await client.CreateAssistantAsync(modelName); - // Validate(assistant); - // AssistantThread thread = await client.CreateThreadAsync(); - // Validate(thread); - // List runPage = await SyncOrAsyncList(client, - // c => c.GetRuns(thread.Id), - // c => c.GetRunsAsync(thread.Id)); - // Assert.That(runPage.Count, Is.EqualTo(0)); - // ThreadMessage message = await client.CreateMessageAsync(thread.Id, MessageRole.User, ["Hello, assistant!"]); - // Validate(message); - // ThreadRun run = await client.CreateRunAsync(thread.Id, assistant.Id); - // Validate(run); - // Assert.That(run.Status, Is.EqualTo(RunStatus.Queued)); - // Assert.That(run.CreatedAt, Is.GreaterThan(s_2024)); - // ThreadRun retrievedRun = await client.GetRunAsync(thread.Id, run.Id); - // Assert.That(retrievedRun.Id, Is.EqualTo(run.Id)); - // runPage = await SyncOrAsyncList(client, - // c => c.GetRuns(thread.Id), - // c => c.GetRunsAsync(thread.Id)); - // Assert.That(runPage.Count, Is.EqualTo(1)); - // Assert.That(runPage.ElementAt(0).Id, Is.EqualTo(run.Id)); - - // List messages = await SyncOrAsyncList(client, - // c => c.GetMessages(thread), - // c => c.GetMessagesAsync(thread)); - // Assert.That(messages.Count, Is.GreaterThanOrEqualTo(1)); - - // run = await WaitUntilReturnLast( - // run, - // () => client.GetRunAsync(run), - // r => r.Status.IsTerminal); - // Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); - - // Assert.Multiple(() => - // { - // Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); - // Assert.That(run.CompletedAt, Is.GreaterThan(s_2024)); - // Assert.That(run.RequiredActions, Is.Empty); - // Assert.That(run.AssistantId, Is.EqualTo(assistant.Id)); - // Assert.That(run.FailedAt, Is.Null); - // Assert.That(run.IncompleteDetails, Is.Null); - // }); - // messages = await SyncOrAsyncList(client, - // c => c.GetMessages(thread), - // c => c.GetMessagesAsync(thread)); - // Assert.That(messages.Count, Is.EqualTo(2)); - - // Assert.That(messages.ElementAt(0).Role, Is.EqualTo(MessageRole.Assistant)); - // Assert.That(messages.ElementAt(1).Role, Is.EqualTo(MessageRole.User)); - // Assert.That(messages.ElementAt(1).Id, Is.EqualTo(message.Id)); - // } - - // [RecordedTest] - // public async Task BasicRunStepFunctionalityWorks() - // { - // AssistantClient client = GetTestClient(); - // string modelName = client.DeploymentOrThrow(); - // Assistant assistant = await client.CreateAssistantAsync(modelName, new AssistantCreationOptions() - // { - // Tools = { new CodeInterpreterToolDefinition() }, - // Instructions = "Call the code interpreter tool when asked to visualize mathematical concepts.", - // }); - // Validate(assistant); - - // AssistantThread thread = await client.CreateThreadAsync(new ThreadCreationOptions() - // { - // InitialMessages = { new(MessageRole.User, ["Please graph the equation y = 3x + 4"]), }, - // }); - // Validate(thread); - - // ThreadRun run = await client.CreateRunAsync(thread, assistant); - // Validate(run); - - // run = await WaitUntilReturnLast( - // run, - // () => client.GetRunAsync(run), - // r => r.Status.IsTerminal); - // Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); - // Assert.That(run.Usage?.TotalTokens, Is.GreaterThan(0)); - - // List runSteps = await SyncOrAsyncList(client, - // c => c.GetRunSteps(run), - // c => c.GetRunStepsAsync(run)); - // Assert.That(runSteps.Count(), Is.GreaterThan(1)); - // Assert.Multiple(() => - // { - // Assert.That(runSteps.ElementAt(0).AssistantId, Is.EqualTo(assistant.Id)); - // Assert.That(runSteps.ElementAt(0).ThreadId, Is.EqualTo(thread.Id)); - // Assert.That(runSteps.ElementAt(0).RunId, Is.EqualTo(run.Id)); - // Assert.That(runSteps.ElementAt(0).CreatedAt, Is.GreaterThan(s_2024)); - // Assert.That(runSteps.ElementAt(0).CompletedAt, Is.GreaterThan(s_2024)); - // }); - // RunStepDetails details = runSteps.ElementAt(0).Details; - // Assert.That(details?.CreatedMessageId, Is.Not.Null.Or.Empty); - - // details = runSteps.ElementAt(1).Details; - // Assert.Multiple(() => - // { - // Assert.That(details?.ToolCalls.Count, Is.GreaterThan(0)); - // Assert.That(details.ToolCalls[0].ToolKind, Is.EqualTo(RunStepToolCallKind.CodeInterpreter)); - // Assert.That(details.ToolCalls[0].ToolCallId, Is.Not.Null.Or.Empty); - // Assert.That(details.ToolCalls[0].CodeInterpreterInput, Is.Not.Null.Or.Empty); - // Assert.That(details.ToolCalls[0].CodeInterpreterOutputs?.Count, Is.GreaterThan(0)); - // Assert.That(details.ToolCalls[0].CodeInterpreterOutputs[0].ImageFileId, Is.Not.Null.Or.Empty); - // }); - // } - - // [RecordedTest] - // public async Task FunctionToolsWork() - // { - // AssistantClient client = GetTestClient(); - // string modelName = client.DeploymentOrThrow(); - // Assistant assistant = await client.CreateAssistantAsync(modelName, new AssistantCreationOptions() - // { - // Tools = - // { - // new FunctionToolDefinition() - // { - // FunctionName = "get_favorite_food_for_day_of_week", - // Description = "gets the user's favorite food for a given day of the week, like Tuesday", - // Parameters = BinaryData.FromObjectAsJson(new - // { - // type = "object", - // properties = new - // { - // day_of_week = new - // { - // type = "string", - // description = "a day of the week, like Tuesday or Saturday", - // } - // } - // }), - // }, - // }, - // }); - // Validate(assistant); - // Assert.That(assistant.Tools?.Count, Is.EqualTo(1)); - - // FunctionToolDefinition responseToolDefinition = assistant.Tools[0] as FunctionToolDefinition; - // Assert.That(responseToolDefinition?.FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); - // Assert.That(responseToolDefinition?.Parameters, Is.Not.Null); - - // ThreadRun run = await client.CreateThreadAndRunAsync( - // assistant, - // new ThreadCreationOptions() - // { - // InitialMessages = { new(MessageRole.User, ["What should I eat on Thursday?"]) }, - // }, - // new RunCreationOptions() - // { - // AdditionalInstructions = "Call provided tools when appropriate.", - // }); - // Validate(run); - // Console.WriteLine($" Run status right after creation: {run.Status}"); - - // // TODO FIXME: The underlying OpenAI code doesn't consider the "requires_action" status to be terminal even though it is. - // // Work around this here - // run = await WaitUntilReturnLast( - // run, - // () => client.GetRunAsync(run), - // r => r.Status.IsTerminal || r.Status.Equals(RunStatus.RequiresAction)); - - // Assert.That(run.Status, Is.EqualTo(RunStatus.RequiresAction)); - // Assert.That(run.RequiredActions?.Count, Is.EqualTo(1)); - // Assert.That(run.RequiredActions[0].ToolCallId, Is.Not.Null.Or.Empty); - // Assert.That(run.RequiredActions[0].FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); - // Assert.That(run.RequiredActions[0].FunctionArguments, Is.Not.Null.Or.Empty); - - // run = await client.SubmitToolOutputsToRunAsync(run, [new(run.RequiredActions[0].ToolCallId, "tacos")]); - // Assert.That(run.Status.IsTerminal, Is.False); - - // run = await WaitUntilReturnLast( - // run, - // () => client.GetRunAsync(run), - // r => r.Status.IsTerminal); - // Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); - - // List messages = await SyncOrAsyncList(client, - // c => c.GetMessages(run.ThreadId, resultOrder: ListOrder.NewestFirst), - // c => c.GetMessagesAsync(run.ThreadId, resultOrder: ListOrder.NewestFirst)); - // Assert.That(messages.Count, Is.GreaterThan(1)); - // Assert.That(messages.ElementAt(0).Role, Is.EqualTo(MessageRole.Assistant)); - // Assert.That(messages.ElementAt(0).Content?[0], Is.Not.Null); - // Assert.That(messages.ElementAt(0).Content?[0].Text, Does.Contain("tacos")); - // } - - // [RecordedTest] - // public async Task BasicFileSearchWorks() - // { - // // First, we need to upload a simple test file. - // AssistantClient client = GetTestClient(); - // string modelName = client.DeploymentOrThrow(); - // FileClient fileClient = GetTestClientFrom(client); - - // OpenAIFileInfo testFile = await fileClient.UploadFileAsync( - // BinaryData.FromString(""" - // This file describes the favorite foods of several people. - - // Summanus Ferdinand: tacos - // Tekakwitha Effie: pizza - // Filip Carola: cake - // """), - // "favorite_foods.txt", - // FileUploadPurpose.Assistants); - // Validate(testFile); - - // // Create an assistant, using the creation helper to make a new vector store - // Assistant assistant = await client.CreateAssistantAsync(modelName, new() - // { - // Tools = { new FileSearchToolDefinition() }, - // ToolResources = new() - // { - // FileSearch = new() - // { - // NewVectorStores = - // { - // new VectorStoreCreationHelper([testFile]), - // } - // } - // } - // }); - // Validate(assistant); - // Assert.That(assistant.ToolResources?.FileSearch?.VectorStoreIds, Has.Count.EqualTo(1)); - // string createdVectorStoreId = assistant.ToolResources.FileSearch.VectorStoreIds[0]; - // ValidateById(createdVectorStoreId); - - // // Modify an assistant to use the existing vector store - // assistant = await client.ModifyAssistantAsync(assistant, new AssistantModificationOptions() - // { - // ToolResources = new() - // { - // FileSearch = new() - // { - // VectorStoreIds = { assistant.ToolResources.FileSearch.VectorStoreIds[0] }, - // }, - // }, - // }); - // Assert.That(assistant.ToolResources?.FileSearch?.VectorStoreIds, Has.Count.EqualTo(1)); - // Assert.That(assistant.ToolResources.FileSearch.VectorStoreIds[0], Is.EqualTo(createdVectorStoreId)); - - // // Create a thread with an override vector store - // AssistantThread thread = await client.CreateThreadAsync(new ThreadCreationOptions() - // { - // InitialMessages = { new(MessageRole.User, ["Using the files you have available, what's Filip's favorite food?"]) }, - // ToolResources = new() - // { - // FileSearch = new() - // { - // NewVectorStores = - // { - // new VectorStoreCreationHelper([testFile.Id]) - // } - // } - // } - // }); - // Validate(thread); - // Assert.That(thread.ToolResources?.FileSearch?.VectorStoreIds, Has.Count.EqualTo(1)); - // createdVectorStoreId = thread.ToolResources.FileSearch.VectorStoreIds[0]; - // ValidateById(createdVectorStoreId); - - // // Ensure that modifying the thread with an existing vector store works - // thread = await client.ModifyThreadAsync(thread, new ThreadModificationOptions() - // { - // ToolResources = new() - // { - // FileSearch = new() - // { - // VectorStoreIds = { createdVectorStoreId }, - // } - // } - // }); - // Assert.That(thread.ToolResources?.FileSearch?.VectorStoreIds, Has.Count.EqualTo(1)); - // Assert.That(thread.ToolResources.FileSearch.VectorStoreIds[0], Is.EqualTo(createdVectorStoreId)); - - // ThreadRun run = await client.CreateRunAsync(thread, assistant); - // Validate(run); - // run = await WaitUntilReturnLast( - // run, - // () => client.GetRunAsync(run), - // r => r.Status.IsTerminal); - // Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); - - // AsyncPageableCollection messages = SyncOrAsync(client, - // c => c.GetMessages(thread, resultOrder: ListOrder.NewestFirst), - // c => c.GetMessagesAsync(thread, resultOrder: ListOrder.NewestFirst)); - // bool hasAtLeastOne = false; - // bool hasCake = false; - // await foreach (ThreadMessage message in messages) - // { - // hasAtLeastOne = true; - // foreach (MessageContent content in message.Content) - // { - // Console.WriteLine(content.Text); - // hasCake |= content.Text?.ToLowerInvariant().Contains("cake") == true; - // foreach (TextAnnotation annotation in content.TextAnnotations) - // { - // Console.WriteLine($" --> From file: {annotation.InputFileId}, replacement: {annotation.TextToReplace}"); - // } - // } - // } - // Assert.That(hasAtLeastOne, Is.True); - // Assert.That(hasCake, Is.True); - // } - - // [RecordedTest] - // public async Task StreamingRunWorks() - // { - // AssistantClient client = GetTestClient(); - // string modelName = client.DeploymentOrThrow(); - // Assistant assistant = await client.CreateAssistantAsync(modelName); - // Validate(assistant); - - // AssistantThread thread = await client.CreateThreadAsync(new ThreadCreationOptions() - // { - // InitialMessages = { new(MessageRole.User, ["Hello there, assistant! How are you today?"]), }, - // }); - // Validate(thread); - - // AsyncResultCollection streamingResult = SyncOrAsync(client, - // c => c.CreateRunStreaming(thread.Id, assistant.Id), - // c => c.CreateRunStreamingAsync(thread.Id, assistant.Id)); - - // StringBuilder content = new(); - // DateTimeOffset? lastUpdate = null; - // StreamingUpdateReason? lastUpdateReason = null; - - // await foreach (StreamingUpdate update in streamingResult) - // { - // if (update is RunUpdate runUpdate) - // { - // lastUpdateReason = runUpdate.UpdateKind; - // lastUpdate = update.UpdateKind switch - // { - // StreamingUpdateReason.RunCreated => runUpdate.Value.CreatedAt, - // StreamingUpdateReason.RunQueued => runUpdate.Value.StartedAt, - // StreamingUpdateReason.RunInProgress => runUpdate.Value.StartedAt, - // StreamingUpdateReason.RunCompleted => runUpdate.Value.CompletedAt, - // _ => null, - // }; - // } - // if (update is MessageContentUpdate contentUpdate) - // { - // // TODO FIXME: The OpenAI library code is currently incorrectly returning a MessageRole.User value here. - // // It should instead be null or at least Assistant - // //Assert.That(contentUpdate.Role, Is.Null.Or.EqualTo(MessageRole.Assistant)); - // Assert.That(contentUpdate.Text, Is.Not.Null); // can be empty string - // content.Append(contentUpdate.Text); - // } - // } - - // Assert.That(lastUpdateReason, Is.EqualTo(StreamingUpdateReason.RunCompleted)); - // Assert.That(lastUpdate, Is.Not.Null.And.GreaterThan(s_2024)); - // Assert.That(content, Has.Length.GreaterThan(0)); - // } - - private static readonly DateTimeOffset s_2024 = new(2024, 1, 1, 0, 0, 0, TimeSpan.Zero); -} diff --git a/.dotnet.azure/.tests.staging/Azure.AI.OpenAI.Tests.csproj b/.dotnet.azure/.tests.staging/Azure.AI.OpenAI.Tests.csproj deleted file mode 100644 index 3cb4fedeb..000000000 --- a/.dotnet.azure/.tests.staging/Azure.AI.OpenAI.Tests.csproj +++ /dev/null @@ -1,21 +0,0 @@ - - - net7.0 - - $(NoWarn);CS1591 - Unsigned - latest - - - - - - - - - - - - - - diff --git a/.dotnet.azure/CHANGELOG.md b/.dotnet.azure/CHANGELOG.md index 972422632..fc64493a3 100644 --- a/.dotnet.azure/CHANGELOG.md +++ b/.dotnet.azure/CHANGELOG.md @@ -1,5 +1,35 @@ # Release History +## 2.0.0-beta.5 (2024-09-03) + +This update increments library compatibility to `OpenAI 2.0.0-beta.11`, including several breaking changes. + +### Features Added + +- Added the `OpenAIChatModelFactory` in the `OpenAI.Chat` namespace (a static class that can be used to instantiate OpenAI models for mocking in non-live test scenarios). ([79014ab](https://github.com/openai/openai-dotnet/commit/79014abc01a00e13d5a334d3f6529ed590b8ee98)) + +### Breaking Changes + +- Updated fine-tuning pagination methods `GetJobs`, `GetEvents`, and `GetJobCheckpoints` to return `IEnumerable` instead of `ClientResult`. ([5773292](https://github.com/openai/openai-dotnet/commit/57732927575c6c48f30bded0afb9f5b16d4f30da)) +- Updated the batching pagination method `GetBatches` to return `IEnumerable` instead of `ClientResult`. ([5773292](https://github.com/openai/openai-dotnet/commit/57732927575c6c48f30bded0afb9f5b16d4f30da)) +- Changed `GeneratedSpeechVoice` from an enum to an "extensible enum". ([79014ab](https://github.com/openai/openai-dotnet/commit/79014abc01a00e13d5a334d3f6529ed590b8ee98)) +- Changed `GeneratedSpeechFormat` from an enum to an "extensible enum". ([cc9169a](https://github.com/openai/openai-dotnet/commit/cc9169ad2ff92bb7312eed3b7e64e45da5da1d18)) +- Renamed `SpeechGenerationOptions`'s `Speed` property to `SpeedRatio`. ([cc9169a](https://github.com/openai/openai-dotnet/commit/cc9169ad2ff92bb7312eed3b7e64e45da5da1d18)) + +### Bugs Fixed + +- Corrected an internal deserialization issue that caused recent updates to Assistants `file_search` to fail when streaming a run. Strongly typed support for `ranking_options` is not included but will arrive soon. ([cc9169a](https://github.com/openai/openai-dotnet/commit/cc9169ad2ff92bb7312eed3b7e64e45da5da1d18)) +- Mitigated a .NET runtime issue that prevented `ChatResponseFormat` from serializing correct on targets including Unity. ([cc9169a](https://github.com/openai/openai-dotnet/commit/cc9169ad2ff92bb7312eed3b7e64e45da5da1d18)) + +## 2.0.0-beta.4 (2024-08-30) + +This small release increments library compatibility to the latest `OpenAI 2.0.0-beta.10`. Prior to this update, interactions with the two breaking changes described below prevented full interoperability. + +### Breaking Changes + +- `AudioClient`'s `GenerateSpeechFromText()` method is renamed to `GenerateSpeech()` +- `OpenAIFileInfo`'s `SizeInBytes` is now of type `int?` (previously `long?`) + ## 2.0.0-beta.3 (2024-08-23) This change updates the library for compatibility with the latest `2.0.0-beta.9` of the `OpenAI` package and the `2024-07-01-preview` Azure OpenAI service API version label, as published on 8/5. diff --git a/.dotnet.azure/Directory.Build.props b/.dotnet.azure/Directory.Build.props new file mode 100644 index 000000000..0d1b0a607 --- /dev/null +++ b/.dotnet.azure/Directory.Build.props @@ -0,0 +1,43 @@ + + + $(MSBuildThisFileDirectory) + $(MSBuildThisFileDirectory)eng + $(RepoRoot)src + $(RepoRoot)src/SDKs + true + true + true + + + + + Debug + AnyCPU + $(Platform) + + + + + $(RepoRoot)artifacts\ + $(ArtifactsDir)obj\ + $(ArtifactsDir)bin\ + $(ArtifactsDir)packages\$(Configuration)\ + + $(MSBuildProjectName) + + $([System.IO.Path]::GetFullPath('$(ArtifactsBinDir)$(OutDirName)\')) + $(BaseOutputPath)$(Configuration)\ + $(BaseOutputPath)$(PlatformName)\$(Configuration)\ + + $([System.IO.Path]::GetFullPath('$(ArtifactsObjDir)$(OutDirName)\')) + $(BaseIntermediateOutputPath)$(Configuration)\ + $(BaseIntermediateOutputPath)$(PlatformName)\$(Configuration)\ + + $(ArtifactsPackagesDir)/$(MSBuildProjectName) + + + import-required-properties + + + + diff --git a/.dotnet.azure/Directory.Build.targets b/.dotnet.azure/Directory.Build.targets new file mode 100644 index 000000000..3d4330c2e --- /dev/null +++ b/.dotnet.azure/Directory.Build.targets @@ -0,0 +1,10 @@ + + + + <_Parameter1>SourcePath + <_Parameter2>$(MSBuildProjectDirectory) + + + + + diff --git a/.dotnet.azure/assets.json b/.dotnet.azure/assets.json deleted file mode 100644 index 39d86235b..000000000 --- a/.dotnet.azure/assets.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "AssetsRepo": "Azure/azure-sdk-assets", - "AssetsRepoPrefixPath": "net", - "TagPrefix": "net/openai/Azure.AI.OpenAI", - "Tag": "net/openai/Azure.AI.OpenAI_23ae923738" -} diff --git a/.dotnet.azure/eng/CodeAnalysis.ruleset b/.dotnet.azure/eng/CodeAnalysis.ruleset new file mode 100644 index 000000000..d6ade187a --- /dev/null +++ b/.dotnet.azure/eng/CodeAnalysis.ruleset @@ -0,0 +1,404 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.dotnet.azure/eng/Directory.Build.Common.props b/.dotnet.azure/eng/Directory.Build.Common.props new file mode 100644 index 000000000..59a538ae8 --- /dev/null +++ b/.dotnet.azure/eng/Directory.Build.Common.props @@ -0,0 +1,56 @@ + + + + true + true + true + 11.0 + true + + $(NoWarn); + NU5105; + + CA1812; + CA1716; + CA1308; + CA1819; + CA1710; + CA1028; + CA1032; + CA1063; + CA1066; + CA1815; + CA2007; + CA2231; + CA2225; + CA1714; + CA1062; + CA1031; + CA2000; + CA2012; + + MSB3245; + AZPROVISION001; + + true + + + + + + netstandard2.0 + $(WarningsNotAsErrors);NU1901;NU1902;NU1903;NU1904 + $(RepoEngPath)\CodeAnalysis.ruleset + + + + + + net8.0;net6.0 + $(RequiredTargetFrameworks);net462 + + + + + + diff --git a/.dotnet.azure/eng/Directory.Build.Common.targets b/.dotnet.azure/eng/Directory.Build.Common.targets new file mode 100644 index 000000000..e3e5db31c --- /dev/null +++ b/.dotnet.azure/eng/Directory.Build.Common.targets @@ -0,0 +1,3 @@ + + + diff --git a/.dotnet.azure/eng/Packages.Data.props b/.dotnet.azure/eng/Packages.Data.props new file mode 100644 index 000000000..f868bc705 --- /dev/null +++ b/.dotnet.azure/eng/Packages.Data.props @@ -0,0 +1,395 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1.0.0-dev.20240806.1 + + + diff --git a/.dotnet.azure/eng/common/testproxy/dotnet-devcert.pfx b/.dotnet.azure/eng/common/testproxy/dotnet-devcert.pfx new file mode 100644 index 000000000..1d59ff89a Binary files /dev/null and b/.dotnet.azure/eng/common/testproxy/dotnet-devcert.pfx differ diff --git a/.dotnet.azure/nuget.config b/.dotnet.azure/nuget.config new file mode 100644 index 000000000..1f889a235 --- /dev/null +++ b/.dotnet.azure/nuget.config @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/.dotnet.azure/sdk/openai/.gitignore b/.dotnet.azure/sdk/openai/.gitignore new file mode 100644 index 000000000..5e5364f16 --- /dev/null +++ b/.dotnet.azure/sdk/openai/.gitignore @@ -0,0 +1 @@ +#Azure.AI.OpenAI/Directory.Build.props \ No newline at end of file diff --git a/.dotnet.azure/Azure.AI.OpenAI.sln b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI.sln similarity index 64% rename from .dotnet.azure/Azure.AI.OpenAI.sln rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI.sln index 7962a56eb..20dcf6c40 100644 --- a/.dotnet.azure/Azure.AI.OpenAI.sln +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI.sln @@ -3,23 +3,27 @@ Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio Version 17 VisualStudioVersion = 17.10.35004.147 MinimumVisualStudioVersion = 10.0.40219.1 -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Azure.AI.OpenAI", "src\Azure.AI.OpenAI.csproj", "{A80B9566-84A5-4AE4-AA0A-72B18646F1EC}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Azure.AI.OpenAI", "Azure.AI.OpenAI\src\Azure.AI.OpenAI.csproj", "{A80B9566-84A5-4AE4-AA0A-72B18646F1EC}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenAI", "..\.dotnet\src\OpenAI.csproj", "{8BEE571B-DB25-4BE5-B9EB-2CA81D12EBC6}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenAI", "..\..\..\.dotnet\src\OpenAI.csproj", "{8BEE571B-DB25-4BE5-B9EB-2CA81D12EBC6}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Azure.AI.OpenAI.Tests", ".tests.staging\Azure.AI.OpenAI.Tests.csproj", "{23DAB09E-3986-4248-AC80-2273C20FCD90}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Azure.AI.OpenAI.Tests", "Azure.AI.OpenAI\tests\Azure.AI.OpenAI.Tests.csproj", "{23DAB09E-3986-4248-AC80-2273C20FCD90}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenAI.TestFramework", "tools\TestFramework\src\OpenAI.TestFramework.csproj", "{D1E3E196-BAA8-47C2-905A-B1C20733AEA8}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Unsigned|Any CPU = Unsigned|Any CPU EndGlobalSection GlobalSection(ProjectConfigurationPlatforms) = postSolution - {A80B9566-84A5-4AE4-AA0A-72B18646F1EC}.Unsigned|Any CPU.ActiveCfg = Unsigned|Any CPU - {A80B9566-84A5-4AE4-AA0A-72B18646F1EC}.Unsigned|Any CPU.Build.0 = Unsigned|Any CPU + {A80B9566-84A5-4AE4-AA0A-72B18646F1EC}.Unsigned|Any CPU.ActiveCfg = Debug|Any CPU + {A80B9566-84A5-4AE4-AA0A-72B18646F1EC}.Unsigned|Any CPU.Build.0 = Debug|Any CPU {8BEE571B-DB25-4BE5-B9EB-2CA81D12EBC6}.Unsigned|Any CPU.ActiveCfg = Unsigned|Any CPU {8BEE571B-DB25-4BE5-B9EB-2CA81D12EBC6}.Unsigned|Any CPU.Build.0 = Unsigned|Any CPU - {23DAB09E-3986-4248-AC80-2273C20FCD90}.Unsigned|Any CPU.ActiveCfg = Unsigned|Any CPU - {23DAB09E-3986-4248-AC80-2273C20FCD90}.Unsigned|Any CPU.Build.0 = Unsigned|Any CPU + {23DAB09E-3986-4248-AC80-2273C20FCD90}.Unsigned|Any CPU.ActiveCfg = Debug|Any CPU + {23DAB09E-3986-4248-AC80-2273C20FCD90}.Unsigned|Any CPU.Build.0 = Debug|Any CPU + {D1E3E196-BAA8-47C2-905A-B1C20733AEA8}.Unsigned|Any CPU.ActiveCfg = Debug|Any CPU + {D1E3E196-BAA8-47C2-905A-B1C20733AEA8}.Unsigned|Any CPU.Build.0 = Debug|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/Directory.Build.props b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/Directory.Build.props new file mode 100644 index 000000000..924ecfa8f --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/Directory.Build.props @@ -0,0 +1,20 @@ + + + + + + $(RepoRoot)/../.dotnet/src/OpenAI.csproj + 1.1.0-beta.5 + + + + + + + + + diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/assets.json b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/assets.json new file mode 100644 index 000000000..78d850850 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/assets.json @@ -0,0 +1,6 @@ +{ + "AssetsRepo": "Azure/azure-sdk-assets", + "AssetsRepoPrefixPath": "net", + "TagPrefix": "dotnet.azure/openai/Azure.AI.OpenAI", + "Tag": "dotnet.azure/openai/Azure.AI.OpenAI_9a2f5cd1c9" +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Azure.AI.OpenAI.csproj b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Azure.AI.OpenAI.csproj new file mode 100644 index 000000000..b1440d99b --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Azure.AI.OpenAI.csproj @@ -0,0 +1,65 @@ + + + + + Azure OpenAI's official extension package for using OpenAI's .NET library with the Azure OpenAI Service. + + Azure.AI.OpenAI Client Library + 2.0.0 + beta.5 + Microsoft Azure OpenAI + true + $(RequiredTargetFrameworks) + true + $(NoWarn);CS1591;AZC0012;AZC0102;CS8002;CS0436;AZC0112;OPENAI001 + enable + preview + disable + + + + + 0024000004800000940000000602000000240000525341310004000001000100d15ddcb29688295338af4b7686603fe614abd555e09efba8fb88ee09e1f7b1ccaeed2e8f823fa9eef3fdd60217fc012ea67d2479751a0b8c087a4185541b851bd8b16f8d91b840e51b1cb0ba6fe647997e57429265e85ef62d565db50a69ae1647d54d7bd855e4db3d8a91510e5bcbd0edfbbecaa20a7bd9ae74593daa7b11b4 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/.dotnet.azure/src/Custom/Assistants/AzureAssistantClient.Protocol.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Assistants/AzureAssistantClient.Protocol.cs similarity index 100% rename from .dotnet.azure/src/Custom/Assistants/AzureAssistantClient.Protocol.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Assistants/AzureAssistantClient.Protocol.cs diff --git a/.dotnet.azure/src/Custom/Assistants/AzureAssistantClient.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Assistants/AzureAssistantClient.cs similarity index 100% rename from .dotnet.azure/src/Custom/Assistants/AzureAssistantClient.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Assistants/AzureAssistantClient.cs diff --git a/.dotnet.azure/src/Custom/Assistants/Internal/Pagination/AzureAssistantsPageEnumerator.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Assistants/Internal/Pagination/AzureAssistantsPageEnumerator.cs similarity index 100% rename from .dotnet.azure/src/Custom/Assistants/Internal/Pagination/AzureAssistantsPageEnumerator.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Assistants/Internal/Pagination/AzureAssistantsPageEnumerator.cs diff --git a/.dotnet.azure/src/Custom/Assistants/Internal/Pagination/AzureMessagesPageEnumerator.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Assistants/Internal/Pagination/AzureMessagesPageEnumerator.cs similarity index 100% rename from .dotnet.azure/src/Custom/Assistants/Internal/Pagination/AzureMessagesPageEnumerator.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Assistants/Internal/Pagination/AzureMessagesPageEnumerator.cs diff --git a/.dotnet.azure/src/Custom/Assistants/Internal/Pagination/AzureRunStepsPageEnumerator.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Assistants/Internal/Pagination/AzureRunStepsPageEnumerator.cs similarity index 100% rename from .dotnet.azure/src/Custom/Assistants/Internal/Pagination/AzureRunStepsPageEnumerator.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Assistants/Internal/Pagination/AzureRunStepsPageEnumerator.cs diff --git a/.dotnet.azure/src/Custom/Assistants/Internal/Pagination/AzureRunsPageEnumerator.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Assistants/Internal/Pagination/AzureRunsPageEnumerator.cs similarity index 100% rename from .dotnet.azure/src/Custom/Assistants/Internal/Pagination/AzureRunsPageEnumerator.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Assistants/Internal/Pagination/AzureRunsPageEnumerator.cs diff --git a/.dotnet.azure/src/Custom/Audio/AzureAudioClient.Protocol.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Audio/AzureAudioClient.Protocol.cs similarity index 100% rename from .dotnet.azure/src/Custom/Audio/AzureAudioClient.Protocol.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Audio/AzureAudioClient.Protocol.cs diff --git a/.dotnet.azure/src/Custom/Audio/AzureAudioClient.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Audio/AzureAudioClient.cs similarity index 100% rename from .dotnet.azure/src/Custom/Audio/AzureAudioClient.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Audio/AzureAudioClient.cs diff --git a/.dotnet.azure/src/Custom/AzureOpenAIAudience.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIAudience.cs similarity index 100% rename from .dotnet.azure/src/Custom/AzureOpenAIAudience.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIAudience.cs diff --git a/.dotnet.azure/src/Custom/AzureOpenAIClient.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClient.cs similarity index 100% rename from .dotnet.azure/src/Custom/AzureOpenAIClient.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClient.cs diff --git a/.dotnet.azure/src/Custom/AzureOpenAIClientOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClientOptions.cs similarity index 100% rename from .dotnet.azure/src/Custom/AzureOpenAIClientOptions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClientOptions.cs diff --git a/.dotnet.azure/src/Custom/AzureTokenAuthenticationPolicy.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureTokenAuthenticationPolicy.cs similarity index 100% rename from .dotnet.azure/src/Custom/AzureTokenAuthenticationPolicy.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureTokenAuthenticationPolicy.cs diff --git a/.dotnet.azure/src/Custom/Batch/AzureBatchClient.Protocol.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Batch/AzureBatchClient.Protocol.cs similarity index 100% rename from .dotnet.azure/src/Custom/Batch/AzureBatchClient.Protocol.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Batch/AzureBatchClient.Protocol.cs diff --git a/.dotnet.azure/src/Custom/Batch/AzureBatchClient.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Batch/AzureBatchClient.cs similarity index 100% rename from .dotnet.azure/src/Custom/Batch/AzureBatchClient.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Batch/AzureBatchClient.cs diff --git a/.dotnet.azure/src/Custom/Chat/AzureChatClient.Protocol.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.Protocol.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/AzureChatClient.Protocol.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.Protocol.cs diff --git a/.dotnet.azure/src/Custom/Chat/AzureChatClient.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/AzureChatClient.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.cs index 73ee82a03..86ca7d229 100644 --- a/.dotnet.azure/src/Custom/Chat/AzureChatClient.cs +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.cs @@ -30,8 +30,8 @@ internal AzureChatClient(ClientPipeline pipeline, string deploymentName, Uri end options ??= new(); _deploymentName = deploymentName; - _apiVersion = options.Version; _endpoint = endpoint; + _apiVersion = options.Version; } protected AzureChatClient() diff --git a/.dotnet.azure/src/Custom/Chat/AzureChatCompletion.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatCompletion.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/AzureChatCompletion.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatCompletion.cs diff --git a/.dotnet.azure/src/Custom/Chat/AzureChatCompletionOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatCompletionOptions.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/AzureChatCompletionOptions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatCompletionOptions.cs diff --git a/.dotnet.azure/src/Custom/Chat/AzureStreamingChatCompletionUpdate.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureStreamingChatCompletionUpdate.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/AzureStreamingChatCompletionUpdate.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureStreamingChatCompletionUpdate.cs diff --git a/.dotnet.azure/src/Custom/Chat/GeneratorStubs.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/GeneratorStubs.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/GeneratorStubs.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/GeneratorStubs.cs diff --git a/.dotnet.azure/src/Custom/Chat/Internal/GeneratorStubs.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/GeneratorStubs.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/Internal/GeneratorStubs.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/GeneratorStubs.cs diff --git a/.dotnet.azure/src/Custom/Chat/Internal/InternalAzureChatDataSourceEndpointVectorizationSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/InternalAzureChatDataSourceEndpointVectorizationSource.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/Internal/InternalAzureChatDataSourceEndpointVectorizationSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/InternalAzureChatDataSourceEndpointVectorizationSource.cs diff --git a/.dotnet.azure/src/Custom/Chat/Internal/InternalAzureCosmosDBChatDataSourceParameters.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/InternalAzureCosmosDBChatDataSourceParameters.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/Internal/InternalAzureCosmosDBChatDataSourceParameters.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/InternalAzureCosmosDBChatDataSourceParameters.cs diff --git a/.dotnet.azure/src/Custom/Chat/Internal/InternalAzureMachineLearningIndexDataSourceParameters.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/InternalAzureMachineLearningIndexDataSourceParameters.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/Internal/InternalAzureMachineLearningIndexDataSourceParameters.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/InternalAzureMachineLearningIndexDataSourceParameters.cs diff --git a/.dotnet.azure/src/Custom/Chat/Internal/InternalAzureSearchChatDataSourceParameters.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/InternalAzureSearchChatDataSourceParameters.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/Internal/InternalAzureSearchChatDataSourceParameters.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/InternalAzureSearchChatDataSourceParameters.cs diff --git a/.dotnet.azure/src/Custom/Chat/Internal/InternalElasticsearchChatDataSourceParameters.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/InternalElasticsearchChatDataSourceParameters.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/Internal/InternalElasticsearchChatDataSourceParameters.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/InternalElasticsearchChatDataSourceParameters.cs diff --git a/.dotnet.azure/src/Custom/Chat/Internal/InternalPineconeChatDataSourceParameters.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/InternalPineconeChatDataSourceParameters.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/Internal/InternalPineconeChatDataSourceParameters.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/Internal/InternalPineconeChatDataSourceParameters.cs diff --git a/.dotnet.azure/src/Custom/Chat/OnYourData/AzureCosmosDBChatDataSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/AzureCosmosDBChatDataSource.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/OnYourData/AzureCosmosDBChatDataSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/AzureCosmosDBChatDataSource.cs diff --git a/.dotnet.azure/src/Custom/Chat/OnYourData/AzureMachineLearningIndexChatDataSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/AzureMachineLearningIndexChatDataSource.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/OnYourData/AzureMachineLearningIndexChatDataSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/AzureMachineLearningIndexChatDataSource.cs diff --git a/.dotnet.azure/src/Custom/Chat/OnYourData/AzureSearchChatDataSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/AzureSearchChatDataSource.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/OnYourData/AzureSearchChatDataSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/AzureSearchChatDataSource.cs diff --git a/.dotnet.azure/src/Custom/Chat/OnYourData/DataSourceAuthentication.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/DataSourceAuthentication.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/OnYourData/DataSourceAuthentication.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/DataSourceAuthentication.cs diff --git a/.dotnet.azure/src/Custom/Chat/OnYourData/DataSourceFieldMappings.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/DataSourceFieldMappings.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/OnYourData/DataSourceFieldMappings.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/DataSourceFieldMappings.cs diff --git a/.dotnet.azure/src/Custom/Chat/OnYourData/DataSourceOutputContextFlags.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/DataSourceOutputContextFlags.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/OnYourData/DataSourceOutputContextFlags.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/DataSourceOutputContextFlags.Serialization.cs diff --git a/.dotnet.azure/src/Custom/Chat/OnYourData/DataSourceOutputContextFlags.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/DataSourceOutputContextFlags.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/OnYourData/DataSourceOutputContextFlags.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/DataSourceOutputContextFlags.cs diff --git a/.dotnet.azure/src/Custom/Chat/OnYourData/DataSourceVectorizer.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/DataSourceVectorizer.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/OnYourData/DataSourceVectorizer.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/DataSourceVectorizer.cs diff --git a/.dotnet.azure/src/Custom/Chat/OnYourData/ElasticsearchChatDataSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ElasticsearchChatDataSource.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/OnYourData/ElasticsearchChatDataSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ElasticsearchChatDataSource.cs diff --git a/.dotnet.azure/src/Custom/Chat/OnYourData/PineconeChatDataSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/PineconeChatDataSource.cs similarity index 100% rename from .dotnet.azure/src/Custom/Chat/OnYourData/PineconeChatDataSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/PineconeChatDataSource.cs diff --git a/.dotnet.azure/src/Custom/Common/AdditionalPropertyHelpers.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/AdditionalPropertyHelpers.cs similarity index 100% rename from .dotnet.azure/src/Custom/Common/AdditionalPropertyHelpers.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/AdditionalPropertyHelpers.cs diff --git a/.dotnet.azure/src/Custom/Common/ContentFilterBlocklistResult.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterBlocklistResult.cs similarity index 100% rename from .dotnet.azure/src/Custom/Common/ContentFilterBlocklistResult.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterBlocklistResult.cs diff --git a/.dotnet.azure/src/Custom/Common/ContentFilterResultForPrompt.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterResultForPrompt.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Custom/Common/ContentFilterResultForPrompt.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterResultForPrompt.Serialization.cs diff --git a/.dotnet.azure/src/Custom/Common/ContentFilterResultForPrompt.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterResultForPrompt.cs similarity index 100% rename from .dotnet.azure/src/Custom/Common/ContentFilterResultForPrompt.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterResultForPrompt.cs diff --git a/.dotnet.azure/src/Custom/Common/ContentFilterResultForResponse.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterResultForResponse.cs similarity index 100% rename from .dotnet.azure/src/Custom/Common/ContentFilterResultForResponse.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterResultForResponse.cs diff --git a/.dotnet.azure/src/Custom/Common/GeneratorStubs.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/GeneratorStubs.cs similarity index 100% rename from .dotnet.azure/src/Custom/Common/GeneratorStubs.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/GeneratorStubs.cs diff --git a/.dotnet.azure/src/Custom/Embeddings/AzureEmbeddingClient.Protocol.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Embeddings/AzureEmbeddingClient.Protocol.cs similarity index 100% rename from .dotnet.azure/src/Custom/Embeddings/AzureEmbeddingClient.Protocol.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Embeddings/AzureEmbeddingClient.Protocol.cs diff --git a/.dotnet.azure/src/Custom/Embeddings/AzureEmbeddingClient.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Embeddings/AzureEmbeddingClient.cs similarity index 100% rename from .dotnet.azure/src/Custom/Embeddings/AzureEmbeddingClient.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Embeddings/AzureEmbeddingClient.cs diff --git a/.dotnet.azure/src/Custom/Files/AzureFileClient.Protocol.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Files/AzureFileClient.Protocol.cs similarity index 100% rename from .dotnet.azure/src/Custom/Files/AzureFileClient.Protocol.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Files/AzureFileClient.Protocol.cs diff --git a/.dotnet.azure/src/Custom/Files/AzureFileClient.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Files/AzureFileClient.cs similarity index 100% rename from .dotnet.azure/src/Custom/Files/AzureFileClient.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Files/AzureFileClient.cs diff --git a/.dotnet.azure/src/Custom/FineTuning/AzureFineTuningClient.Extensions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/AzureFineTuningClient.Extensions.cs similarity index 100% rename from .dotnet.azure/src/Custom/FineTuning/AzureFineTuningClient.Extensions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/AzureFineTuningClient.Extensions.cs diff --git a/.dotnet.azure/src/Custom/FineTuning/AzureFineTuningClient.Protocol.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/AzureFineTuningClient.Protocol.cs similarity index 72% rename from .dotnet.azure/src/Custom/FineTuning/AzureFineTuningClient.Protocol.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/AzureFineTuningClient.Protocol.cs index 4332e1042..479f12cde 100644 --- a/.dotnet.azure/src/Custom/FineTuning/AzureFineTuningClient.Protocol.cs +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/AzureFineTuningClient.Protocol.cs @@ -39,37 +39,37 @@ public override async Task GetJobAsync(string fineTuningJobId, Req public override IEnumerable GetJobs(string after, int? limit, RequestOptions options) { - FineTuningJobsPageEnumerator enumerator = new(Pipeline, _endpoint, after, limit, options); + AzureFineTuningJobsPageEnumerator enumerator = new(Pipeline, _endpoint, after, limit, _apiVersion, options); return PageCollectionHelpers.Create(enumerator); } public override IAsyncEnumerable GetJobsAsync(string after, int? limit, RequestOptions options) { - FineTuningJobsPageEnumerator enumerator = new(Pipeline, _endpoint, after, limit, options); + AzureFineTuningJobsPageEnumerator enumerator = new(Pipeline, _endpoint, after, limit, _apiVersion, options); return PageCollectionHelpers.CreateAsync(enumerator); } public override IEnumerable GetJobEvents(string fineTuningJobId, string after, int? limit, RequestOptions options) { - FineTuningJobEventsPageEnumerator enumerator = new(Pipeline, _endpoint, fineTuningJobId, after, limit, options); + AzureFineTuningJobEventsPageEnumerator enumerator = new(Pipeline, _endpoint, fineTuningJobId, after, limit, _apiVersion, options); return PageCollectionHelpers.Create(enumerator); } public override IAsyncEnumerable GetJobEventsAsync(string fineTuningJobId, string after, int? limit, RequestOptions options) { - FineTuningJobEventsPageEnumerator enumerator = new(Pipeline, _endpoint, fineTuningJobId, after, limit, options); + AzureFineTuningJobEventsPageEnumerator enumerator = new(Pipeline, _endpoint, fineTuningJobId, after, limit, _apiVersion, options); return PageCollectionHelpers.CreateAsync(enumerator); } public override IEnumerable GetJobCheckpoints(string fineTuningJobId, string after, int? limit, RequestOptions options) { - FineTuningJobCheckpointsPageEnumerator enumerator = new(Pipeline, _endpoint, fineTuningJobId, after, limit, options); + AzureFineTuningJobCheckpointsPageEnumerator enumerator = new(Pipeline, _endpoint, fineTuningJobId, after, limit, _apiVersion, options); return PageCollectionHelpers.Create(enumerator); } public override IAsyncEnumerable GetJobCheckpointsAsync(string fineTuningJobId, string after, int? limit, RequestOptions options) { - FineTuningJobCheckpointsPageEnumerator enumerator = new(Pipeline, _endpoint, fineTuningJobId, after, limit, options); + AzureFineTuningJobCheckpointsPageEnumerator enumerator = new(Pipeline, _endpoint, fineTuningJobId, after, limit, _apiVersion, options); return PageCollectionHelpers.CreateAsync(enumerator); } @@ -110,16 +110,6 @@ private PipelineMessage CreateCreateJobRequestMessage(BinaryContent content, Req .WithOptions(options) .Build(); - private PipelineMessage CreateGetJobsRequestMessage(string after, int? limit, RequestOptions options) - => new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion) - .WithMethod("GET") - .WithPath("fine_tuning", "jobs") - .WithOptionalQueryParameter("after", after) - .WithOptionalQueryParameter("limit", limit) - .WithAccept("application/json") - .WithOptions(options) - .Build(); - private PipelineMessage CreateGetJobRequestMessage(string jobId, RequestOptions options) => new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion) .WithMethod("GET") @@ -128,26 +118,6 @@ private PipelineMessage CreateGetJobRequestMessage(string jobId, RequestOptions .WithOptions(options) .Build(); - private PipelineMessage CreateGetJobEventsRequestMessage(string jobId, string after, int? limit, RequestOptions options) - => new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion) - .WithMethod("GET") - .WithPath("fine_tuning", "jobs", jobId, "events") - .WithOptionalQueryParameter("after", after) - .WithOptionalQueryParameter("limit", limit) - .WithAccept("application/json") - .WithOptions(options) - .Build(); - - private PipelineMessage CreateGetJobCheckpointsRequestMessage(string jobId, string after, int? limit, RequestOptions options) - => new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion) - .WithMethod("GET") - .WithPath("fine_tuning", "jobs", jobId, "checkpoints") - .WithOptionalQueryParameter("after", after) - .WithOptionalQueryParameter("limit", limit) - .WithAccept("application/json") - .WithOptions(options) - .Build(); - private PipelineMessage CreateCancelJobRequestMessage(string jobId, RequestOptions options) => new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion) .WithMethod("POST") diff --git a/.dotnet.azure/src/Custom/FineTuning/AzureFineTuningClient.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/AzureFineTuningClient.cs similarity index 100% rename from .dotnet.azure/src/Custom/FineTuning/AzureFineTuningClient.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/AzureFineTuningClient.cs diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/Pagination/AzureFineTuningJobCheckpointsPageEnumerator.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/Pagination/AzureFineTuningJobCheckpointsPageEnumerator.cs new file mode 100644 index 000000000..f002a2bf5 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/Pagination/AzureFineTuningJobCheckpointsPageEnumerator.cs @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Text.Json; + +namespace Azure.AI.OpenAI.FineTuning; + +internal class AzureFineTuningJobCheckpointsPageEnumerator : FineTuningJobCheckpointsPageEnumerator +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + private readonly string _apiVersion; + + private readonly int? _limit; + private readonly string _jobId; + private readonly RequestOptions _options; + private string? _after; + + public AzureFineTuningJobCheckpointsPageEnumerator( + ClientPipeline pipeline, + Uri endpoint, + string jobId, string after, int? limit, + string apiVersion, + RequestOptions options) + : base(pipeline, endpoint, jobId, after!, limit, options) + { + _pipeline = pipeline; + _endpoint = endpoint; + _apiVersion = apiVersion; + + _jobId = jobId; + _after = after; + _limit = limit; + _options = options; + } + + public override async Task GetNextAsync(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response?.Content); + + if (doc?.RootElement.TryGetProperty("data", out JsonElement dataElement) == true + && dataElement.EnumerateArray().LastOrDefault().TryGetProperty("id", out JsonElement idElement) == true) + { + _after = idElement.GetString(); + } + + return await GetJobCheckpointsAsync(_jobId, _after!, _limit, _options).ConfigureAwait(false); + } + + public override ClientResult GetNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response?.Content); + + if (doc?.RootElement.TryGetProperty("data", out JsonElement dataElement) == true + && dataElement.EnumerateArray().LastOrDefault().TryGetProperty("id", out JsonElement idElement) == true) + { + _after = idElement.GetString(); + } + + return GetJobCheckpoints(_jobId, _after!, _limit, _options); + } + + internal override async Task GetJobCheckpointsAsync(string jobId, string after, int? limit, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateGetFineTuningJobCheckpointsRequest(jobId, after, limit, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + internal override ClientResult GetJobCheckpoints(string jobId, string after, int? limit, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateGetFineTuningJobCheckpointsRequest(jobId, after, limit, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal new PipelineMessage CreateGetFineTuningJobCheckpointsRequest(string fineTuningJobId, string after, int? limit, RequestOptions options) + => new AzureOpenAIPipelineMessageBuilder(_pipeline, _endpoint, _apiVersion) + .WithMethod("GET") + .WithPath("fine_tuning", "jobs", fineTuningJobId, "checkpoints") + .WithOptionalQueryParameter("after", after) + .WithOptionalQueryParameter("limit", limit) + .WithAccept("application/json") + .WithOptions(options) + .Build(); +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/Pagination/AzureFineTuningJobEventsPageEnumerator.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/Pagination/AzureFineTuningJobEventsPageEnumerator.cs new file mode 100644 index 000000000..a18cbe777 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/Pagination/AzureFineTuningJobEventsPageEnumerator.cs @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Text.Json; + +namespace Azure.AI.OpenAI.FineTuning; + +internal class AzureFineTuningJobEventsPageEnumerator : FineTuningJobEventsPageEnumerator +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + private readonly string _apiVersion; + + private readonly int? _limit; + private readonly string _jobId; + private readonly RequestOptions _options; + private string? _after; + + public AzureFineTuningJobEventsPageEnumerator( + ClientPipeline pipeline, + Uri endpoint, + string jobId, string? after, int? limit, + string apiVersion, + RequestOptions options) + : base(pipeline, endpoint, jobId, after!, limit, options) + { + _pipeline = pipeline; + _endpoint = endpoint; + _apiVersion = apiVersion; + + _jobId = jobId; + _after = after; + _limit = limit; + _options = options; + } + + public override async Task GetNextAsync(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response?.Content); + + if (doc?.RootElement.TryGetProperty("data", out JsonElement dataElement) == true + && dataElement.EnumerateArray().LastOrDefault().TryGetProperty("id", out JsonElement idElement) == true) + { + _after = idElement.GetString(); + } + + return await GetJobEventsAsync(_jobId, _after!, _limit, _options).ConfigureAwait(false); + } + + public override ClientResult GetNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response?.Content); + + if (doc?.RootElement.TryGetProperty("data", out JsonElement dataElement) == true + && dataElement.EnumerateArray().LastOrDefault().TryGetProperty("id", out JsonElement idElement) == true) + { + _after = idElement.GetString(); + } + + return GetJobEvents(_jobId, _after!, _limit, _options); + } + + internal override async Task GetJobEventsAsync(string jobId, string after, int? limit, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateGetFineTuningEventsRequest(jobId, after, limit, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + internal override ClientResult GetJobEvents(string jobId, string after, int? limit, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateGetFineTuningEventsRequest(jobId, after, limit, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal new PipelineMessage CreateGetFineTuningEventsRequest(string fineTuningJobId, string after, int? limit, RequestOptions options) + => new AzureOpenAIPipelineMessageBuilder(_pipeline, _endpoint, _apiVersion) + .WithMethod("GET") + .WithPath("fine_tuning", "jobs", fineTuningJobId, "events") + .WithOptionalQueryParameter("after", after) + .WithOptionalQueryParameter("limit", limit) + .WithAccept("application/json") + .WithOptions(options) + .Build(); +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/Pagination/AzureFineTuningJobsPageEnumerator.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/Pagination/AzureFineTuningJobsPageEnumerator.cs new file mode 100644 index 000000000..ef8a57075 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/FineTuning/Pagination/AzureFineTuningJobsPageEnumerator.cs @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Text.Json; + +namespace Azure.AI.OpenAI.FineTuning; + +internal class AzureFineTuningJobsPageEnumerator : FineTuningJobsPageEnumerator +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + private readonly string _apiVersion; + + private readonly int? _limit; + private readonly RequestOptions _options; + private string? _after; + + public AzureFineTuningJobsPageEnumerator( + ClientPipeline pipeline, + Uri endpoint, + string? after, + int? limit, + string apiVersion, + RequestOptions options) + : base(pipeline, endpoint, after!, limit, options) + { + _pipeline = pipeline; + _endpoint = endpoint; + _apiVersion = apiVersion; + + _after = after; + _limit = limit; + _options = options; + } + + public override async Task GetNextAsync(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response?.Content); + + if (doc?.RootElement.TryGetProperty("data", out JsonElement dataElement) == true + && dataElement.EnumerateArray().LastOrDefault().TryGetProperty("id", out JsonElement idElement) == true) + { + _after = idElement.GetString(); + } + + return await GetJobsAsync(_after!, _limit, _options).ConfigureAwait(false); + } + + public override ClientResult GetNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response?.Content); + + if (doc?.RootElement.TryGetProperty("data", out JsonElement dataElement) == true + && dataElement.EnumerateArray().LastOrDefault().TryGetProperty("id", out JsonElement idElement) == true) + { + _after = idElement.GetString(); + } + + return GetJobs(_after!, _limit, _options); + } + + internal override async Task GetJobsAsync(string after, int? limit, RequestOptions options) + { + using PipelineMessage message = CreateGetFineTuningJobsRequest(after, limit, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + internal override ClientResult GetJobs(string after, int? limit, RequestOptions options) + { + using PipelineMessage message = CreateGetFineTuningJobsRequest(after, limit, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal new PipelineMessage CreateGetFineTuningJobsRequest(string? after, int? limit, RequestOptions options) + => new AzureOpenAIPipelineMessageBuilder(_pipeline, _endpoint, _apiVersion) + .WithMethod("GET") + .WithPath("fine_tuning", "jobs") + .WithOptionalQueryParameter("after", after) + .WithOptionalQueryParameter("limit", limit) + .WithAccept("application/json") + .WithOptions(options) + .Build(); +} diff --git a/.dotnet.azure/src/Custom/Images/AzureGeneratedImage.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Images/AzureGeneratedImage.cs similarity index 100% rename from .dotnet.azure/src/Custom/Images/AzureGeneratedImage.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Images/AzureGeneratedImage.cs diff --git a/.dotnet.azure/src/Custom/Images/AzureImageClient.Protocol.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Images/AzureImageClient.Protocol.cs similarity index 100% rename from .dotnet.azure/src/Custom/Images/AzureImageClient.Protocol.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Images/AzureImageClient.Protocol.cs diff --git a/.dotnet.azure/src/Custom/Images/AzureImageClient.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Images/AzureImageClient.cs similarity index 100% rename from .dotnet.azure/src/Custom/Images/AzureImageClient.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Images/AzureImageClient.cs diff --git a/.dotnet.azure/src/Custom/Images/GeneratorStubs.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Images/GeneratorStubs.cs similarity index 100% rename from .dotnet.azure/src/Custom/Images/GeneratorStubs.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Images/GeneratorStubs.cs diff --git a/.dotnet.azure/src/Custom/Internal/AzureOpenAIChatError.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Internal/AzureOpenAIChatError.cs similarity index 100% rename from .dotnet.azure/src/Custom/Internal/AzureOpenAIChatError.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Internal/AzureOpenAIChatError.cs diff --git a/.dotnet.azure/src/Custom/Internal/AzureOpenAIDalleError.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Internal/AzureOpenAIDalleError.cs similarity index 100% rename from .dotnet.azure/src/Custom/Internal/AzureOpenAIDalleError.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Internal/AzureOpenAIDalleError.cs diff --git a/.dotnet.azure/src/Custom/Internal/ClientPipelineExtensions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Internal/ClientPipelineExtensions.cs similarity index 100% rename from .dotnet.azure/src/Custom/Internal/ClientPipelineExtensions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Internal/ClientPipelineExtensions.cs diff --git a/.dotnet.azure/src/Custom/Internal/ClientUriBuilder.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Internal/ClientUriBuilder.cs similarity index 100% rename from .dotnet.azure/src/Custom/Internal/ClientUriBuilder.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Internal/ClientUriBuilder.cs diff --git a/.dotnet.azure/src/Custom/Internal/GeneratorStubs.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Internal/GeneratorStubs.cs similarity index 100% rename from .dotnet.azure/src/Custom/Internal/GeneratorStubs.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/Internal/GeneratorStubs.cs diff --git a/.dotnet.azure/src/Custom/VectorStores/AzureVectorStoreClient.Protocol.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/VectorStores/AzureVectorStoreClient.Protocol.cs similarity index 100% rename from .dotnet.azure/src/Custom/VectorStores/AzureVectorStoreClient.Protocol.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/VectorStores/AzureVectorStoreClient.Protocol.cs diff --git a/.dotnet.azure/src/Custom/VectorStores/AzureVectorStoreClient.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/VectorStores/AzureVectorStoreClient.cs similarity index 100% rename from .dotnet.azure/src/Custom/VectorStores/AzureVectorStoreClient.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/VectorStores/AzureVectorStoreClient.cs diff --git a/.dotnet.azure/src/Custom/VectorStores/Internal/Pagination/AzureVectorStoreFileBatchesPageEnumerator.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/VectorStores/Internal/Pagination/AzureVectorStoreFileBatchesPageEnumerator.cs similarity index 100% rename from .dotnet.azure/src/Custom/VectorStores/Internal/Pagination/AzureVectorStoreFileBatchesPageEnumerator.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/VectorStores/Internal/Pagination/AzureVectorStoreFileBatchesPageEnumerator.cs diff --git a/.dotnet.azure/src/Custom/VectorStores/Internal/Pagination/AzureVectorStoreFilesPageEnumerator.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/VectorStores/Internal/Pagination/AzureVectorStoreFilesPageEnumerator.cs similarity index 100% rename from .dotnet.azure/src/Custom/VectorStores/Internal/Pagination/AzureVectorStoreFilesPageEnumerator.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/VectorStores/Internal/Pagination/AzureVectorStoreFilesPageEnumerator.cs diff --git a/.dotnet.azure/src/Custom/VectorStores/Internal/Pagination/AzureVectorStoresPageEnumerator.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/VectorStores/Internal/Pagination/AzureVectorStoresPageEnumerator.cs similarity index 100% rename from .dotnet.azure/src/Custom/VectorStores/Internal/Pagination/AzureVectorStoresPageEnumerator.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Custom/VectorStores/Internal/Pagination/AzureVectorStoresPageEnumerator.cs diff --git a/.dotnet.azure/src/Generated/AzureChatCitation.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatCitation.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureChatCitation.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatCitation.Serialization.cs diff --git a/.dotnet.azure/src/Generated/AzureChatCitation.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatCitation.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureChatCitation.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatCitation.cs diff --git a/.dotnet.azure/src/Generated/AzureChatDataSource.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatDataSource.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureChatDataSource.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatDataSource.Serialization.cs diff --git a/.dotnet.azure/src/Generated/AzureChatDataSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatDataSource.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureChatDataSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatDataSource.cs diff --git a/.dotnet.azure/src/Generated/AzureChatMessageContext.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatMessageContext.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureChatMessageContext.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatMessageContext.Serialization.cs diff --git a/.dotnet.azure/src/Generated/AzureChatMessageContext.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatMessageContext.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureChatMessageContext.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatMessageContext.cs diff --git a/.dotnet.azure/src/Generated/AzureChatRetrievedDocument.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatRetrievedDocument.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureChatRetrievedDocument.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatRetrievedDocument.Serialization.cs diff --git a/.dotnet.azure/src/Generated/AzureChatRetrievedDocument.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatRetrievedDocument.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureChatRetrievedDocument.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatRetrievedDocument.cs diff --git a/.dotnet.azure/src/Generated/AzureChatRetrievedDocumentFilterReason.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatRetrievedDocumentFilterReason.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureChatRetrievedDocumentFilterReason.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureChatRetrievedDocumentFilterReason.cs diff --git a/.dotnet.azure/src/Generated/AzureCosmosDBChatDataSource.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureCosmosDBChatDataSource.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureCosmosDBChatDataSource.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureCosmosDBChatDataSource.Serialization.cs diff --git a/.dotnet.azure/src/Generated/AzureCosmosDBChatDataSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureCosmosDBChatDataSource.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureCosmosDBChatDataSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureCosmosDBChatDataSource.cs diff --git a/.dotnet.azure/src/Generated/AzureMachineLearningIndexChatDataSource.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureMachineLearningIndexChatDataSource.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureMachineLearningIndexChatDataSource.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureMachineLearningIndexChatDataSource.Serialization.cs diff --git a/.dotnet.azure/src/Generated/AzureMachineLearningIndexChatDataSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureMachineLearningIndexChatDataSource.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureMachineLearningIndexChatDataSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureMachineLearningIndexChatDataSource.cs diff --git a/.dotnet.azure/src/Generated/AzureOpenAIChatError.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIChatError.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureOpenAIChatError.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIChatError.Serialization.cs diff --git a/.dotnet.azure/src/Generated/AzureOpenAIChatError.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIChatError.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureOpenAIChatError.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIChatError.cs diff --git a/.dotnet.azure/src/Generated/AzureOpenAIChatErrorResponse.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIChatErrorResponse.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureOpenAIChatErrorResponse.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIChatErrorResponse.Serialization.cs diff --git a/.dotnet.azure/src/Generated/AzureOpenAIChatErrorResponse.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIChatErrorResponse.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureOpenAIChatErrorResponse.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIChatErrorResponse.cs diff --git a/.dotnet.azure/src/Generated/AzureOpenAIDalleError.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIDalleError.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureOpenAIDalleError.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIDalleError.Serialization.cs diff --git a/.dotnet.azure/src/Generated/AzureOpenAIDalleError.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIDalleError.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureOpenAIDalleError.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIDalleError.cs diff --git a/.dotnet.azure/src/Generated/AzureOpenAIDalleErrorResponse.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIDalleErrorResponse.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureOpenAIDalleErrorResponse.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIDalleErrorResponse.Serialization.cs diff --git a/.dotnet.azure/src/Generated/AzureOpenAIDalleErrorResponse.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIDalleErrorResponse.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureOpenAIDalleErrorResponse.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIDalleErrorResponse.cs diff --git a/.dotnet.azure/src/Generated/AzureSearchChatDataSource.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureSearchChatDataSource.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureSearchChatDataSource.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureSearchChatDataSource.Serialization.cs diff --git a/.dotnet.azure/src/Generated/AzureSearchChatDataSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureSearchChatDataSource.cs similarity index 100% rename from .dotnet.azure/src/Generated/AzureSearchChatDataSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureSearchChatDataSource.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterBlocklistResult.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterBlocklistResult.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterBlocklistResult.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterBlocklistResult.Serialization.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterBlocklistResult.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterBlocklistResult.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterBlocklistResult.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterBlocklistResult.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterDetectionResult.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterDetectionResult.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterDetectionResult.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterDetectionResult.Serialization.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterDetectionResult.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterDetectionResult.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterDetectionResult.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterDetectionResult.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterProtectedMaterialCitedResult.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterProtectedMaterialCitedResult.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterProtectedMaterialCitedResult.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterProtectedMaterialCitedResult.Serialization.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterProtectedMaterialCitedResult.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterProtectedMaterialCitedResult.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterProtectedMaterialCitedResult.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterProtectedMaterialCitedResult.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterProtectedMaterialResult.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterProtectedMaterialResult.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterProtectedMaterialResult.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterProtectedMaterialResult.Serialization.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterProtectedMaterialResult.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterProtectedMaterialResult.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterProtectedMaterialResult.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterProtectedMaterialResult.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterResultForPrompt.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterResultForPrompt.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterResultForPrompt.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterResultForPrompt.Serialization.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterResultForPrompt.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterResultForPrompt.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterResultForPrompt.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterResultForPrompt.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterResultForResponse.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterResultForResponse.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterResultForResponse.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterResultForResponse.Serialization.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterResultForResponse.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterResultForResponse.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterResultForResponse.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterResultForResponse.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterSeverity.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterSeverity.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterSeverity.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterSeverity.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterSeverityResult.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterSeverityResult.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterSeverityResult.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterSeverityResult.Serialization.cs diff --git a/.dotnet.azure/src/Generated/ContentFilterSeverityResult.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterSeverityResult.cs similarity index 100% rename from .dotnet.azure/src/Generated/ContentFilterSeverityResult.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterSeverityResult.cs diff --git a/.dotnet.azure/src/Generated/DataSourceAuthentication.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceAuthentication.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/DataSourceAuthentication.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceAuthentication.Serialization.cs diff --git a/.dotnet.azure/src/Generated/DataSourceAuthentication.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceAuthentication.cs similarity index 100% rename from .dotnet.azure/src/Generated/DataSourceAuthentication.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceAuthentication.cs diff --git a/.dotnet.azure/src/Generated/DataSourceFieldMappings.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceFieldMappings.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/DataSourceFieldMappings.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceFieldMappings.Serialization.cs diff --git a/.dotnet.azure/src/Generated/DataSourceFieldMappings.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceFieldMappings.cs similarity index 100% rename from .dotnet.azure/src/Generated/DataSourceFieldMappings.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceFieldMappings.cs diff --git a/.dotnet.azure/src/Generated/DataSourceQueryType.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceQueryType.cs similarity index 100% rename from .dotnet.azure/src/Generated/DataSourceQueryType.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceQueryType.cs diff --git a/.dotnet.azure/src/Generated/DataSourceVectorizer.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceVectorizer.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/DataSourceVectorizer.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceVectorizer.Serialization.cs diff --git a/.dotnet.azure/src/Generated/DataSourceVectorizer.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceVectorizer.cs similarity index 100% rename from .dotnet.azure/src/Generated/DataSourceVectorizer.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/DataSourceVectorizer.cs diff --git a/.dotnet.azure/src/Generated/ElasticsearchChatDataSource.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ElasticsearchChatDataSource.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/ElasticsearchChatDataSource.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ElasticsearchChatDataSource.Serialization.cs diff --git a/.dotnet.azure/src/Generated/ElasticsearchChatDataSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ElasticsearchChatDataSource.cs similarity index 100% rename from .dotnet.azure/src/Generated/ElasticsearchChatDataSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ElasticsearchChatDataSource.cs diff --git a/.dotnet.azure/src/Generated/ImageContentFilterResultForPrompt.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ImageContentFilterResultForPrompt.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/ImageContentFilterResultForPrompt.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ImageContentFilterResultForPrompt.Serialization.cs diff --git a/.dotnet.azure/src/Generated/ImageContentFilterResultForPrompt.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ImageContentFilterResultForPrompt.cs similarity index 100% rename from .dotnet.azure/src/Generated/ImageContentFilterResultForPrompt.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ImageContentFilterResultForPrompt.cs diff --git a/.dotnet.azure/src/Generated/ImageContentFilterResultForResponse.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ImageContentFilterResultForResponse.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/ImageContentFilterResultForResponse.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ImageContentFilterResultForResponse.Serialization.cs diff --git a/.dotnet.azure/src/Generated/ImageContentFilterResultForResponse.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ImageContentFilterResultForResponse.cs similarity index 100% rename from .dotnet.azure/src/Generated/ImageContentFilterResultForResponse.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/ImageContentFilterResultForResponse.cs diff --git a/.dotnet.azure/src/Generated/Internal/Argument.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/Argument.cs similarity index 100% rename from .dotnet.azure/src/Generated/Internal/Argument.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/Argument.cs diff --git a/.dotnet.azure/src/Generated/Internal/BinaryContentHelper.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/BinaryContentHelper.cs similarity index 100% rename from .dotnet.azure/src/Generated/Internal/BinaryContentHelper.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/BinaryContentHelper.cs diff --git a/.dotnet.azure/src/Generated/Internal/ChangeTrackingDictionary.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/ChangeTrackingDictionary.cs similarity index 100% rename from .dotnet.azure/src/Generated/Internal/ChangeTrackingDictionary.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/ChangeTrackingDictionary.cs diff --git a/.dotnet.azure/src/Generated/Internal/ChangeTrackingList.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/ChangeTrackingList.cs similarity index 100% rename from .dotnet.azure/src/Generated/Internal/ChangeTrackingList.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/ChangeTrackingList.cs diff --git a/.dotnet.azure/src/Generated/Internal/ClientPipelineExtensions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/ClientPipelineExtensions.cs similarity index 100% rename from .dotnet.azure/src/Generated/Internal/ClientPipelineExtensions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/ClientPipelineExtensions.cs diff --git a/.dotnet.azure/src/Generated/Internal/ClientUriBuilder.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/ClientUriBuilder.cs similarity index 100% rename from .dotnet.azure/src/Generated/Internal/ClientUriBuilder.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/ClientUriBuilder.cs diff --git a/.dotnet.azure/src/Generated/Internal/ErrorResult.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/ErrorResult.cs similarity index 100% rename from .dotnet.azure/src/Generated/Internal/ErrorResult.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/ErrorResult.cs diff --git a/.dotnet.azure/src/Generated/Internal/ModelSerializationExtensions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/ModelSerializationExtensions.cs similarity index 100% rename from .dotnet.azure/src/Generated/Internal/ModelSerializationExtensions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/ModelSerializationExtensions.cs diff --git a/.dotnet.azure/src/Generated/Internal/MultipartFormDataBinaryContent.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/MultipartFormDataBinaryContent.cs similarity index 100% rename from .dotnet.azure/src/Generated/Internal/MultipartFormDataBinaryContent.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/MultipartFormDataBinaryContent.cs diff --git a/.dotnet.azure/src/Generated/Internal/Optional.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/Optional.cs similarity index 100% rename from .dotnet.azure/src/Generated/Internal/Optional.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/Optional.cs diff --git a/.dotnet.azure/src/Generated/Internal/Utf8JsonBinaryContent.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/Utf8JsonBinaryContent.cs similarity index 100% rename from .dotnet.azure/src/Generated/Internal/Utf8JsonBinaryContent.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/Internal/Utf8JsonBinaryContent.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceAccessTokenAuthenticationOptions.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceAccessTokenAuthenticationOptions.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceAccessTokenAuthenticationOptions.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceAccessTokenAuthenticationOptions.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceAccessTokenAuthenticationOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceAccessTokenAuthenticationOptions.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceAccessTokenAuthenticationOptions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceAccessTokenAuthenticationOptions.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceApiKeyAuthenticationOptions.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceApiKeyAuthenticationOptions.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceApiKeyAuthenticationOptions.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceApiKeyAuthenticationOptions.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceApiKeyAuthenticationOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceApiKeyAuthenticationOptions.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceApiKeyAuthenticationOptions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceApiKeyAuthenticationOptions.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceConnectionStringAuthenticationOptions.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceConnectionStringAuthenticationOptions.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceConnectionStringAuthenticationOptions.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceConnectionStringAuthenticationOptions.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceConnectionStringAuthenticationOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceConnectionStringAuthenticationOptions.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceConnectionStringAuthenticationOptions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceConnectionStringAuthenticationOptions.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceDeploymentNameVectorizationSource.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceDeploymentNameVectorizationSource.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceDeploymentNameVectorizationSource.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceDeploymentNameVectorizationSource.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceDeploymentNameVectorizationSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceDeploymentNameVectorizationSource.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceDeploymentNameVectorizationSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceDeploymentNameVectorizationSource.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceEncodedApiKeyAuthenticationOptions.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceEncodedApiKeyAuthenticationOptions.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceEncodedApiKeyAuthenticationOptions.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceEncodedApiKeyAuthenticationOptions.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceEncodedApiKeyAuthenticationOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceEncodedApiKeyAuthenticationOptions.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceEncodedApiKeyAuthenticationOptions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceEncodedApiKeyAuthenticationOptions.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceEndpointVectorizationSource.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceEndpointVectorizationSource.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceEndpointVectorizationSource.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceEndpointVectorizationSource.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceEndpointVectorizationSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceEndpointVectorizationSource.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceEndpointVectorizationSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceEndpointVectorizationSource.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceKeyAndKeyIdAuthenticationOptions.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceKeyAndKeyIdAuthenticationOptions.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceKeyAndKeyIdAuthenticationOptions.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceKeyAndKeyIdAuthenticationOptions.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceKeyAndKeyIdAuthenticationOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceKeyAndKeyIdAuthenticationOptions.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceKeyAndKeyIdAuthenticationOptions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceKeyAndKeyIdAuthenticationOptions.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceModelIdVectorizationSource.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceModelIdVectorizationSource.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceModelIdVectorizationSource.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceModelIdVectorizationSource.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceModelIdVectorizationSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceModelIdVectorizationSource.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceModelIdVectorizationSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceModelIdVectorizationSource.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceSystemAssignedManagedIdentityAuthenticationOptions.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceSystemAssignedManagedIdentityAuthenticationOptions.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceSystemAssignedManagedIdentityAuthenticationOptions.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceSystemAssignedManagedIdentityAuthenticationOptions.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceSystemAssignedManagedIdentityAuthenticationOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceSystemAssignedManagedIdentityAuthenticationOptions.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceSystemAssignedManagedIdentityAuthenticationOptions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceSystemAssignedManagedIdentityAuthenticationOptions.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceUserAssignedManagedIdentityAuthenticationOptions.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceUserAssignedManagedIdentityAuthenticationOptions.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceUserAssignedManagedIdentityAuthenticationOptions.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceUserAssignedManagedIdentityAuthenticationOptions.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureChatDataSourceUserAssignedManagedIdentityAuthenticationOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceUserAssignedManagedIdentityAuthenticationOptions.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureChatDataSourceUserAssignedManagedIdentityAuthenticationOptions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureChatDataSourceUserAssignedManagedIdentityAuthenticationOptions.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureContentFilterBlocklistIdResult.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterBlocklistIdResult.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureContentFilterBlocklistIdResult.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterBlocklistIdResult.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureContentFilterBlocklistIdResult.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterBlocklistIdResult.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureContentFilterBlocklistIdResult.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterBlocklistIdResult.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureContentFilterBlocklistResultDetail.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterBlocklistResultDetail.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureContentFilterBlocklistResultDetail.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterBlocklistResultDetail.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureContentFilterBlocklistResultDetail.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterBlocklistResultDetail.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureContentFilterBlocklistResultDetail.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterBlocklistResultDetail.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResults.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResults.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResults.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResults.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResults.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResults.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResults.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResults.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResultsError.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResultsError.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResultsError.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResultsError.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResultsError.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResultsError.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResultsError.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureContentFilterResultForPromptContentFilterResultsError.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureCosmosDBChatDataSourceParameters.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureCosmosDBChatDataSourceParameters.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureCosmosDBChatDataSourceParameters.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureCosmosDBChatDataSourceParameters.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureCosmosDBChatDataSourceParameters.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureCosmosDBChatDataSourceParameters.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureCosmosDBChatDataSourceParameters.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureCosmosDBChatDataSourceParameters.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureMachineLearningIndexChatDataSourceParameters.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureMachineLearningIndexChatDataSourceParameters.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureMachineLearningIndexChatDataSourceParameters.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureMachineLearningIndexChatDataSourceParameters.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureMachineLearningIndexChatDataSourceParameters.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureMachineLearningIndexChatDataSourceParameters.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureMachineLearningIndexChatDataSourceParameters.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureMachineLearningIndexChatDataSourceParameters.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureOpenAIChatErrorInnerError.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureOpenAIChatErrorInnerError.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureOpenAIChatErrorInnerError.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureOpenAIChatErrorInnerError.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureOpenAIChatErrorInnerError.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureOpenAIChatErrorInnerError.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureOpenAIChatErrorInnerError.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureOpenAIChatErrorInnerError.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureOpenAIChatErrorInnerErrorCode.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureOpenAIChatErrorInnerErrorCode.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureOpenAIChatErrorInnerErrorCode.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureOpenAIChatErrorInnerErrorCode.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureOpenAIDalleErrorInnerError.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureOpenAIDalleErrorInnerError.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureOpenAIDalleErrorInnerError.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureOpenAIDalleErrorInnerError.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureOpenAIDalleErrorInnerError.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureOpenAIDalleErrorInnerError.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureOpenAIDalleErrorInnerError.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureOpenAIDalleErrorInnerError.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureOpenAIDalleErrorInnerErrorCode.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureOpenAIDalleErrorInnerErrorCode.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureOpenAIDalleErrorInnerErrorCode.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureOpenAIDalleErrorInnerErrorCode.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureSearchChatDataSourceParameters.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureSearchChatDataSourceParameters.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureSearchChatDataSourceParameters.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureSearchChatDataSourceParameters.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureSearchChatDataSourceParameters.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureSearchChatDataSourceParameters.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureSearchChatDataSourceParameters.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureSearchChatDataSourceParameters.cs diff --git a/.dotnet.azure/src/Generated/InternalAzureSearchChatDataSourceParametersIncludeContext.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureSearchChatDataSourceParametersIncludeContext.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalAzureSearchChatDataSourceParametersIncludeContext.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalAzureSearchChatDataSourceParametersIncludeContext.cs diff --git a/.dotnet.azure/src/Generated/InternalElasticsearchChatDataSourceParameters.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalElasticsearchChatDataSourceParameters.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalElasticsearchChatDataSourceParameters.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalElasticsearchChatDataSourceParameters.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalElasticsearchChatDataSourceParameters.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalElasticsearchChatDataSourceParameters.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalElasticsearchChatDataSourceParameters.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalElasticsearchChatDataSourceParameters.cs diff --git a/.dotnet.azure/src/Generated/InternalPineconeChatDataSourceParameters.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalPineconeChatDataSourceParameters.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalPineconeChatDataSourceParameters.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalPineconeChatDataSourceParameters.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalPineconeChatDataSourceParameters.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalPineconeChatDataSourceParameters.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalPineconeChatDataSourceParameters.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalPineconeChatDataSourceParameters.cs diff --git a/.dotnet.azure/src/Generated/InternalUnknownAzureChatDataSource.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalUnknownAzureChatDataSource.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalUnknownAzureChatDataSource.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalUnknownAzureChatDataSource.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalUnknownAzureChatDataSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalUnknownAzureChatDataSource.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalUnknownAzureChatDataSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalUnknownAzureChatDataSource.cs diff --git a/.dotnet.azure/src/Generated/InternalUnknownAzureChatDataSourceAuthenticationOptions.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalUnknownAzureChatDataSourceAuthenticationOptions.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalUnknownAzureChatDataSourceAuthenticationOptions.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalUnknownAzureChatDataSourceAuthenticationOptions.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalUnknownAzureChatDataSourceAuthenticationOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalUnknownAzureChatDataSourceAuthenticationOptions.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalUnknownAzureChatDataSourceAuthenticationOptions.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalUnknownAzureChatDataSourceAuthenticationOptions.cs diff --git a/.dotnet.azure/src/Generated/InternalUnknownAzureChatDataSourceVectorizationSource.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalUnknownAzureChatDataSourceVectorizationSource.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalUnknownAzureChatDataSourceVectorizationSource.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalUnknownAzureChatDataSourceVectorizationSource.Serialization.cs diff --git a/.dotnet.azure/src/Generated/InternalUnknownAzureChatDataSourceVectorizationSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalUnknownAzureChatDataSourceVectorizationSource.cs similarity index 100% rename from .dotnet.azure/src/Generated/InternalUnknownAzureChatDataSourceVectorizationSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/InternalUnknownAzureChatDataSourceVectorizationSource.cs diff --git a/.dotnet.azure/src/Generated/PineconeChatDataSource.Serialization.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/PineconeChatDataSource.Serialization.cs similarity index 100% rename from .dotnet.azure/src/Generated/PineconeChatDataSource.Serialization.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/PineconeChatDataSource.Serialization.cs diff --git a/.dotnet.azure/src/Generated/PineconeChatDataSource.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/PineconeChatDataSource.cs similarity index 100% rename from .dotnet.azure/src/Generated/PineconeChatDataSource.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Generated/PineconeChatDataSource.cs diff --git a/.dotnet.azure/src/Utility/AzureOpenAIPipelineMessageBuilder.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/AzureOpenAIPipelineMessageBuilder.cs similarity index 100% rename from .dotnet.azure/src/Utility/AzureOpenAIPipelineMessageBuilder.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/AzureOpenAIPipelineMessageBuilder.cs diff --git a/.dotnet.azure/src/Utility/CustomSerializationHelpers.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/CustomSerializationHelpers.cs similarity index 100% rename from .dotnet.azure/src/Utility/CustomSerializationHelpers.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/CustomSerializationHelpers.cs diff --git a/.dotnet.azure/src/Utility/Generator/CodeGenClientAttribute.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Generator/CodeGenClientAttribute.cs similarity index 100% rename from .dotnet.azure/src/Utility/Generator/CodeGenClientAttribute.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Generator/CodeGenClientAttribute.cs diff --git a/.dotnet.azure/src/Utility/Generator/CodeGenMemberAttribute.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Generator/CodeGenMemberAttribute.cs similarity index 100% rename from .dotnet.azure/src/Utility/Generator/CodeGenMemberAttribute.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Generator/CodeGenMemberAttribute.cs diff --git a/.dotnet.azure/src/Utility/Generator/CodeGenModelAttribute.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Generator/CodeGenModelAttribute.cs similarity index 100% rename from .dotnet.azure/src/Utility/Generator/CodeGenModelAttribute.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Generator/CodeGenModelAttribute.cs diff --git a/.dotnet.azure/src/Utility/Generator/CodeGenSerializationAttribute.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Generator/CodeGenSerializationAttribute.cs similarity index 100% rename from .dotnet.azure/src/Utility/Generator/CodeGenSerializationAttribute.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Generator/CodeGenSerializationAttribute.cs diff --git a/.dotnet.azure/src/Utility/Generator/CodeGenSuppressAttribute.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Generator/CodeGenSuppressAttribute.cs similarity index 100% rename from .dotnet.azure/src/Utility/Generator/CodeGenSuppressAttribute.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Generator/CodeGenSuppressAttribute.cs diff --git a/.dotnet.azure/src/Utility/Generator/CodeGenTypeAttribute.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Generator/CodeGenTypeAttribute.cs similarity index 100% rename from .dotnet.azure/src/Utility/Generator/CodeGenTypeAttribute.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Generator/CodeGenTypeAttribute.cs diff --git a/.dotnet.azure/src/Utility/GenericActionPipelinePolicy.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/GenericActionPipelinePolicy.cs similarity index 100% rename from .dotnet.azure/src/Utility/GenericActionPipelinePolicy.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/GenericActionPipelinePolicy.cs diff --git a/.dotnet.azure/src/Utility/Polyfill/System.Diagnostics.CodeAnalysis.ExperimentalAttribute.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Polyfill/System.Diagnostics.CodeAnalysis.ExperimentalAttribute.cs similarity index 100% rename from .dotnet.azure/src/Utility/Polyfill/System.Diagnostics.CodeAnalysis.ExperimentalAttribute.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Polyfill/System.Diagnostics.CodeAnalysis.ExperimentalAttribute.cs diff --git a/.dotnet.azure/src/Utility/Polyfill/System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Polyfill/System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute.cs similarity index 100% rename from .dotnet.azure/src/Utility/Polyfill/System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Polyfill/System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute.cs diff --git a/.dotnet.azure/src/Utility/Polyfill/System.Runtime.CompilerServices.CompilerFeatureRequiredAttribute.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Polyfill/System.Runtime.CompilerServices.CompilerFeatureRequiredAttribute.cs similarity index 100% rename from .dotnet.azure/src/Utility/Polyfill/System.Runtime.CompilerServices.CompilerFeatureRequiredAttribute.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Polyfill/System.Runtime.CompilerServices.CompilerFeatureRequiredAttribute.cs diff --git a/.dotnet.azure/src/Utility/Polyfill/System.Runtime.CompilerServices.IsExternalInit.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Polyfill/System.Runtime.CompilerServices.IsExternalInit.cs similarity index 100% rename from .dotnet.azure/src/Utility/Polyfill/System.Runtime.CompilerServices.IsExternalInit.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Polyfill/System.Runtime.CompilerServices.IsExternalInit.cs diff --git a/.dotnet.azure/src/Utility/Polyfill/System.Runtime.CompilerServices.RequiredMemberAttribute.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Polyfill/System.Runtime.CompilerServices.RequiredMemberAttribute.cs similarity index 100% rename from .dotnet.azure/src/Utility/Polyfill/System.Runtime.CompilerServices.RequiredMemberAttribute.cs rename to .dotnet.azure/sdk/openai/Azure.AI.OpenAI/src/Utility/Polyfill/System.Runtime.CompilerServices.RequiredMemberAttribute.cs diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/edit_sample_image.png b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/edit_sample_image.png new file mode 100644 index 000000000..869bb1e04 Binary files /dev/null and b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/edit_sample_image.png differ diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/edit_sample_mask.png b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/edit_sample_mask.png new file mode 100644 index 000000000..98b9c237c Binary files /dev/null and b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/edit_sample_mask.png differ diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/fine_tuning.jsonl b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/fine_tuning.jsonl new file mode 100644 index 000000000..2c345d259 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/fine_tuning.jsonl @@ -0,0 +1,11 @@ +{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Sources: Colts grant RB Taylor OK to seek trade"},{"role":"assistant","content":"{\n \"player\": \"Jonathan Taylor\",\n \"team\": \"Indianapolis Colts\",\n \"sport\": \"football\",\n \"gender\": \"male\"\n}"}]} +{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"OSU 'split down middle' on starting QB battle"},{"role":"assistant","content":"{\n \"team\": \"Ohio State University\",\n \"sport\": \"football\"\n}"}]} +{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Christian McCaffrey given 2 year extension by 49ers"},{"role":"assistant","content":"{\n \"player\": \"Christian McCaffrey\",\n \"team\": \"San Francisco 49ers\",\n \"sport\": \"football\",\n \"gender\": \"male\"\n}"}]} +{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Tucupita Marcano banned for life by MLB for betting on baseball"},{"role":"assistant","content":"{\n \"player\": \"Tucupita Marcano\",\n \"sport\": \"baseball\",\n \"gender\": \"male\"\n}"}]} +{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Who will win the 2024 NBA finals? Predictions about the Celtics vs. Mavericks matchup"},{"role":"assistant","content":"[\n {\n \"team\": \"Celtics\",\n \"sport\": \"basketball\"\n },\n {\n \"team\": \"Mavericks\",\n \"sport\": \"basketball\"\n }\n]"}]} +{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Pavleski will not play in 2024-2025 season"},{"role":"assistant","content":"{\n \"player\": \"Pavelski\",\n \"sport\": \"hockey\",\n \"gender\": \"male\"\n}"}]} +{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Charges against Scottie Scheffler have been dropped after arrest"},{"role":"assistant","content":"{\n \"player\": \"Scottie Scheffler\",\n \"sport\": \"golf\",\n \"gender\": \"male\"\n}"}]} +{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Perez picked by Red Bull for F1 2025 deal"},{"role":"assistant","content":"{\n \"player\": \"Perez\",\n \"team\": \"Red Bull\",\n \"sport\": \"F1\",\n \"gender\": \"male\"\n}"}]} +{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"DL Johnson III waived by 49ers"},{"role":"assistant","content":"{\n \"player\": \"DL Johnson III\",\n \"team\": \"49ers\",\n \"sport\": \"football\",\n \"gender\": \"male\"\n}"}]} +{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Trevor Williams placed on injured list"},{"role":"assistant","content":"{\n \"player\": \"Trevor Williams\",\n \"sport\": \"baseball\",\n \"gender\": \"male\"\n}"}]} +{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Coco Gauff, and Iga Swiatek will meet in French Open semis"},{"role":"assistant","content":"[\n {\n \"player\": \"Coco Gauff\",\n \"sport\": \"tennis\",\n \"gender\": \"female\"\n },\n {\n \"player\": \"Iga Swiatek\",\n \"sport\": \"tennis\",\n \"gender\": \"female\"\n }\n]"}]} \ No newline at end of file diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/french.wav b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/french.wav new file mode 100644 index 000000000..847f3463a Binary files /dev/null and b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/french.wav differ diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/hello_world.m4a b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/hello_world.m4a new file mode 100644 index 000000000..ed8e09c8f Binary files /dev/null and b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/hello_world.m4a differ diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/playback_test_config.json b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/playback_test_config.json new file mode 100644 index 000000000..d55e5d96e --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/playback_test_config.json @@ -0,0 +1,38 @@ +{ + "default": { + "endpoint": "https://Sanitized.openai.azure.com/", + "key": "Sanitized", + "deployment": "gpt-4-turbo", + "resource_group": "Sanitized", + "subscription_id": "Sanitized" + }, + "audio": { + "deployment": "whisper" + }, + "embedding": { + "deployment": "text-embedding-ada-002" + }, + "fine_tuning": { + "deployment": "gpt-35-turbo-0613", + "fine_tuned_model": "gpt-35-turbo-0613.ft-53f9c10199f84dfea3ec772341862ff5-azure-ai-openai-integration-test" + }, + "image": { + "deployment": "dall-e-3" + }, + "rate_limited_chat": { + "endpoint": "https://Sanitized.openai.azure.com/", + "key": "Sanitized", + "deployment": "gpt-35-turbo-low-quota" + }, + "search": { + "endpoint": "https://Sanitized.search.windows.net/", + "key": "Sanitized", + "index": "openaiwikisearchindex" + }, + "tts": { + "deployment": "tts" + }, + "vision": { + "deployment": "gpt-4-vision-preview" + } +} \ No newline at end of file diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/speed-talking.wav b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/speed-talking.wav new file mode 100644 index 000000000..2a09e2737 Binary files /dev/null and b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/speed-talking.wav differ diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/stop_sign.png b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/stop_sign.png new file mode 100644 index 000000000..002b3ae1a Binary files /dev/null and b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/stop_sign.png differ diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/variation_sample_image.png b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/variation_sample_image.png new file mode 100644 index 000000000..119a13e8f Binary files /dev/null and b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Assets/variation_sample_image.png differ diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/AssistantTests.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/AssistantTests.cs new file mode 100644 index 000000000..67b4c354f --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/AssistantTests.cs @@ -0,0 +1,653 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable disable + +using System; +using System.ClientModel; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Azure.AI.OpenAI.Tests.Utils.Config; +using OpenAI; +using OpenAI.Assistants; +using OpenAI.Files; +using OpenAI.TestFramework; +using OpenAI.TestFramework.Utils; +using OpenAI.VectorStores; + +namespace Azure.AI.OpenAI.Tests; + +public class AssistantTests(bool isAsync) : AoaiTestBase(isAsync) +{ + [Test] + [Category("Smoke")] + public void CanCreateClient() => Assert.That(GetTestClient(), Is.InstanceOf()); + + [Test] + [Category("Smoke")] + public void VerifyClientOptionMutability() + { + AzureOpenAIClientOptions options = null; + Assert.DoesNotThrow(() => + options = new AzureOpenAIClientOptions() + { + ApplicationId = "init does not throw", + }); + Assert.DoesNotThrow(() => + options.ApplicationId = "set before freeze OK"); + AzureOpenAIClient azureClient = new( + new Uri("https://www.microsoft.com/placeholder"), + new ApiKeyCredential("placeholder"), + options); + Assert.Throws(() => + options.ApplicationId = "set after freeze throws"); + } + + [RecordedTest] + public async Task BasicAssistantOperationsWork() + { + AssistantClient client = GetTestClient(); + string modelName = client.DeploymentOrThrow(); + Assistant assistant = await client.CreateAssistantAsync(modelName); + Validate(assistant); + Assert.That(assistant.Name, Is.Null.Or.Empty); + assistant = await client.ModifyAssistantAsync(assistant.Id, new AssistantModificationOptions() + { + Name = "test assistant name", + }); + Assert.That(assistant.Name, Is.EqualTo("test assistant name")); + bool deleted = await client.DeleteAssistantAsync(assistant.Id); + Assert.That(deleted, Is.True); + assistant = await client.CreateAssistantAsync(modelName, new AssistantCreationOptions() + { + Metadata = + { + ["testkey"] = "hello!" + }, + }); + Validate(assistant); + Assistant retrievedAssistant = await client.GetAssistantAsync(assistant.Id); + Assert.That(retrievedAssistant.Id, Is.EqualTo(assistant.Id)); + Assert.That(retrievedAssistant.Metadata.TryGetValue("testkey", out string metadataValue) && metadataValue == "hello!"); + Assistant modifiedAssistant = await client.ModifyAssistantAsync(assistant.Id, new AssistantModificationOptions() + { + Metadata = + { + ["testkey"] = "goodbye!", + }, + }); + Assert.That(modifiedAssistant.Id, Is.EqualTo(assistant.Id)); + AsyncPageCollection recentAssistants = client.GetAssistantsAsync(); + Assistant firstAssistant = await recentAssistants.GetAllValuesAsync().FirstOrDefaultAsync(); + Assert.That(firstAssistant, Is.Not.Null); + Assert.That(firstAssistant.Metadata.TryGetValue("testkey", out string newMetadataValue) && newMetadataValue == "goodbye!"); + } + + [RecordedTest] + public async Task BasicThreadOperationsWork() + { + AssistantClient client = GetTestClient(); + AssistantThread thread = await client.CreateThreadAsync(); + Validate(thread); + Assert.That(thread.CreatedAt, Is.GreaterThan(s_2024)); + bool deleted = await client.DeleteThreadAsync(thread.Id); + Assert.That(deleted, Is.True); + + ThreadCreationOptions options = new() + { + Metadata = + { + ["threadMetadata"] = "threadMetadataValue", + } + }; + thread = await client.CreateThreadAsync(options); + Validate(thread); + Assert.That(thread.Metadata.TryGetValue("threadMetadata", out string threadMetadataValue) && threadMetadataValue == "threadMetadataValue"); + AssistantThread retrievedThread = await client.GetThreadAsync(thread.Id); + Assert.That(retrievedThread.Id, Is.EqualTo(thread.Id)); + thread = await client.ModifyThreadAsync(thread, new ThreadModificationOptions() + { + Metadata = + { + ["threadMetadata"] = "newThreadMetadataValue", + }, + }); + Assert.That(thread.Metadata.TryGetValue("threadMetadata", out threadMetadataValue) && threadMetadataValue == "newThreadMetadataValue"); + } + + [RecordedTest] + public async Task SettingResponseFormatWorks() + { + AssistantClient client = GetTestClient(); + string modelName = client.DeploymentOrThrow(); + + Assistant assistant = await client.CreateAssistantAsync(modelName, new() + { + ResponseFormat = AssistantResponseFormat.JsonObject, + }); + Validate(assistant); + Assert.That(assistant.ResponseFormat, Is.EqualTo(AssistantResponseFormat.JsonObject)); + assistant = await client.ModifyAssistantAsync(assistant, new() + { + ResponseFormat = AssistantResponseFormat.Text, + }); + Assert.That(assistant.ResponseFormat, Is.EqualTo(AssistantResponseFormat.Text)); + AssistantThread thread = await client.CreateThreadAsync(); + Validate(thread); + ThreadMessage message = await client.CreateMessageAsync(thread.Id, MessageRole.User, ["Write some JSON for me!"]); + Validate(message); + ThreadRun run = await client.CreateRunAsync(thread, assistant, new() + { + ResponseFormat = AssistantResponseFormat.JsonObject, + }); + Validate(run); + Assert.That(run.ResponseFormat, Is.EqualTo(AssistantResponseFormat.JsonObject)); + } + + [RecordedTest] + public async Task StreamingToolCall() + { + AssistantClient client = GetTestClient(); + string modelName = client.DeploymentOrThrow(); + FunctionToolDefinition getWeatherTool = new("get_current_weather") { Description = "Gets the user's current weather" }; + Assistant assistant = await client.CreateAssistantAsync(modelName, new() + { + Tools = { getWeatherTool } + }); + Validate(assistant); + + Stopwatch stopwatch = Stopwatch.StartNew(); + void Print(string message) => Console.WriteLine($"[{stopwatch.ElapsedMilliseconds,6}] {message}"); + + Print(" >>> Beginning call ... "); + + ThreadCreationOptions thrdOpt = new() + { + InitialMessages = { new(MessageRole.User, ["What should I wear outside right now?"]), }, + }; + AsyncCollectionResult asyncResults = client.CreateThreadAndRunStreamingAsync(assistant, thrdOpt); + + Print(" >>> Starting enumeration ..."); + + ThreadRun run = null; + + do + { + run = null; + List toolOutputs = new(); + await foreach (StreamingUpdate update in asyncResults) + { + string message = update.UpdateKind.ToString(); + + if (update is RunUpdate runUpdate) + { + message += $" run_id:{runUpdate.Value.Id}"; + run = runUpdate.Value; + } + if (update is RequiredActionUpdate requiredActionUpdate) + { + Assert.That(requiredActionUpdate.FunctionName, Is.EqualTo(getWeatherTool.FunctionName)); + Assert.That(requiredActionUpdate.GetThreadRun().Status, Is.EqualTo(RunStatus.RequiresAction)); + message += $" {requiredActionUpdate.FunctionName}"; + toolOutputs.Add(new(requiredActionUpdate.ToolCallId, "warm and sunny")); + } + if (update is MessageContentUpdate contentUpdate) + { + message += $" {contentUpdate.Text}"; + } + Print(message); + } + if (toolOutputs.Count > 0) + { + asyncResults = client.SubmitToolOutputsToRunStreamingAsync(run, toolOutputs); + } + } while (run?.Status.IsTerminal == false); + } + + [RecordedTest] + public async Task BasicMessageOperationsWork() + { + // TODO FIXME Can't currently delete messages on AOAI + bool aoaiDeleteBugFixed = false; + + AssistantClient client = GetTestClient(); + AssistantThread thread = await client.CreateThreadAsync(); + Validate(thread); + ThreadMessage message = await client.CreateMessageAsync(thread.Id, MessageRole.User, ["Hello, world!"]); + Validate(message); + Assert.That(message.CreatedAt, Is.GreaterThan(s_2024)); + Assert.That(message.Content?.Count, Is.EqualTo(1)); + Assert.That(message.Content[0], Is.Not.Null); + Assert.That(message.Content[0].Text, Is.EqualTo("Hello, world!")); + + if (aoaiDeleteBugFixed) + { + bool deleted = await client.DeleteMessageAsync(message); + Assert.That(deleted, Is.True); + } + + message = await client.CreateMessageAsync(thread.Id, MessageRole.User, ["Goodbye, world!"], new MessageCreationOptions() + { + Metadata = + { + ["messageMetadata"] = "messageMetadataValue", + }, + }); + Validate(message); + Assert.That(message.Metadata.TryGetValue("messageMetadata", out string metadataValue) && metadataValue == "messageMetadataValue"); + + ThreadMessage retrievedMessage = await client.GetMessageAsync(thread.Id, message.Id); + Assert.That(retrievedMessage.Id, Is.EqualTo(message.Id)); + + message = await client.ModifyMessageAsync(message, new MessageModificationOptions() + { + Metadata = + { + ["messageMetadata"] = "newValue", + } + }); + Assert.That(message.Metadata.TryGetValue("messageMetadata", out metadataValue) && metadataValue == "newValue"); + + var messagePage = await client.GetMessagesAsync(thread).ToListAsync(); + if (aoaiDeleteBugFixed) + { + Assert.That(messagePage.Count, Is.EqualTo(1)); + } + else + { + Assert.That(messagePage.Count, Is.EqualTo(2)); + } + + Assert.That(messagePage.ElementAt(0).Id, Is.EqualTo(message.Id)); + Assert.That(messagePage.ElementAt(0).Metadata.TryGetValue("messageMetadata", out metadataValue) && metadataValue == "newValue"); + } + + [RecordedTest] + public async Task ThreadWithInitialMessagesWorks() + { + const string userGreeting = "Hello, world!"; + const string userQuestion = "Can you describe why stop signs are the shape and color that they are?"; + + AssistantClient client = GetTestClient(); + ThreadCreationOptions options = new() + { + InitialMessages = + { + new ThreadInitializationMessage(MessageRole.User, [userGreeting]), + new ThreadInitializationMessage(MessageRole.User, [ userQuestion ]) + { + Metadata = + { + ["messageMetadata"] = "messageMetadataValue", + }, + }, + }, + }; + AssistantThread thread = await client.CreateThreadAsync(options); + Validate(thread); + List messageList = await client.GetMessagesAsync(thread, new() { Order = ListOrder.OldestFirst }).ToListAsync(); + Assert.That(messageList.Count, Is.EqualTo(2)); + Assert.That(messageList[0].Role, Is.EqualTo(MessageRole.User)); + Assert.That(messageList[0].Content?.Count, Is.EqualTo(1)); + Assert.That(messageList[0].Content[0].Text, Is.EqualTo(userGreeting)); + Assert.That(messageList[1].Content[0], Is.Not.Null); + Assert.That(messageList[1].Content[0].Text, Is.EqualTo(userQuestion)); + } + + [RecordedTest] + public async Task BasicRunOperationsWork() + { + AssistantClient client = GetTestClient(); + string modelName = client.DeploymentOrThrow(); + Assistant assistant = await client.CreateAssistantAsync(modelName); + Validate(assistant); + AssistantThread thread = await client.CreateThreadAsync(); + Validate(thread); + List runPage = await client.GetRunsAsync(thread.Id).ToListAsync(); + Assert.That(runPage.Count, Is.EqualTo(0)); + ThreadMessage message = await client.CreateMessageAsync(thread.Id, MessageRole.User, ["Hello, assistant!"]); + Validate(message); + ThreadRun run = await client.CreateRunAsync(thread.Id, assistant.Id); + Validate(run); + Assert.That(run.Status, Is.EqualTo(RunStatus.Queued)); + Assert.That(run.CreatedAt, Is.GreaterThan(s_2024)); + ThreadRun retrievedRun = await client.GetRunAsync(thread.Id, run.Id); + Assert.That(retrievedRun.Id, Is.EqualTo(run.Id)); + runPage = await client.GetRunsAsync(thread.Id).ToListAsync(); + Assert.That(runPage.Count, Is.EqualTo(1)); + Assert.That(runPage.ElementAt(0).Id, Is.EqualTo(run.Id)); + + List messages = await client.GetMessagesAsync(thread).ToListAsync(); + Assert.That(messages.Count, Is.GreaterThanOrEqualTo(1)); + + run = await WaitUntilReturnLast( + run, + () => client.GetRunAsync(run), + r => r.Status.IsTerminal); + Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); + + Assert.Multiple(() => + { + Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); + Assert.That(run.CompletedAt, Is.GreaterThan(s_2024)); + Assert.That(run.RequiredActions, Is.Empty); + Assert.That(run.AssistantId, Is.EqualTo(assistant.Id)); + Assert.That(run.FailedAt, Is.Null); + Assert.That(run.IncompleteDetails, Is.Null); + }); + messages = await client.GetMessagesAsync(thread).ToListAsync(); + Assert.That(messages.Count, Is.EqualTo(2)); + + Assert.That(messages.ElementAt(0).Role, Is.EqualTo(MessageRole.Assistant)); + Assert.That(messages.ElementAt(1).Role, Is.EqualTo(MessageRole.User)); + Assert.That(messages.ElementAt(1).Id, Is.EqualTo(message.Id)); + } + + [RecordedTest] + public async Task BasicRunStepFunctionalityWorks() + { + AssistantClient client = GetTestClient(); + string modelName = client.DeploymentOrThrow(); + Assistant assistant = await client.CreateAssistantAsync(modelName, new AssistantCreationOptions() + { + Tools = { new CodeInterpreterToolDefinition() }, + Instructions = "Call the code interpreter tool when asked to visualize mathematical concepts.", + }); + Validate(assistant); + + AssistantThread thread = await client.CreateThreadAsync(new ThreadCreationOptions() + { + InitialMessages = { new(MessageRole.User, ["Please graph the equation y = 3x + 4"]), }, + }); + Validate(thread); + + ThreadRun run = await client.CreateRunAsync(thread, assistant); + Validate(run); + + run = await WaitUntilReturnLast( + run, + () => client.GetRunAsync(run), + r => r.Status.IsTerminal); + Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); + Assert.That(run.Usage?.TotalTokens, Is.GreaterThan(0)); + + List runSteps = await client.GetRunStepsAsync(run).ToListAsync(); + Assert.That(runSteps.Count(), Is.GreaterThan(1)); + Assert.Multiple(() => + { + Assert.That(runSteps.ElementAt(0).AssistantId, Is.EqualTo(assistant.Id)); + Assert.That(runSteps.ElementAt(0).ThreadId, Is.EqualTo(thread.Id)); + Assert.That(runSteps.ElementAt(0).RunId, Is.EqualTo(run.Id)); + Assert.That(runSteps.ElementAt(0).CreatedAt, Is.GreaterThan(s_2024)); + Assert.That(runSteps.ElementAt(0).CompletedAt, Is.GreaterThan(s_2024)); + }); + RunStepDetails details = runSteps.ElementAt(0).Details; + Assert.That(details?.CreatedMessageId, Is.Not.Null.Or.Empty); + + details = runSteps.ElementAt(1).Details; + Assert.Multiple(() => + { + Assert.That(details?.ToolCalls.Count, Is.GreaterThan(0)); + Assert.That(details.ToolCalls[0].ToolKind, Is.EqualTo(RunStepToolCallKind.CodeInterpreter)); + Assert.That(details.ToolCalls[0].ToolCallId, Is.Not.Null.Or.Empty); + Assert.That(details.ToolCalls[0].CodeInterpreterInput, Is.Not.Null.Or.Empty); + Assert.That(details.ToolCalls[0].CodeInterpreterOutputs?.Count, Is.GreaterThan(0)); + Assert.That(details.ToolCalls[0].CodeInterpreterOutputs[0].ImageFileId, Is.Not.Null.Or.Empty); + }); + } + + [RecordedTest] + public async Task FunctionToolsWork() + { + AssistantClient client = GetTestClient(); + string modelName = client.DeploymentOrThrow(); + Assistant assistant = await client.CreateAssistantAsync(modelName, new AssistantCreationOptions() + { + Tools = + { + new FunctionToolDefinition() + { + FunctionName = "get_favorite_food_for_day_of_week", + Description = "gets the user's favorite food for a given day of the week, like Tuesday", + Parameters = BinaryData.FromObjectAsJson(new + { + type = "object", + properties = new + { + day_of_week = new + { + type = "string", + description = "a day of the week, like Tuesday or Saturday", + } + } + }), + }, + }, + }); + Validate(assistant); + Assert.That(assistant.Tools?.Count, Is.EqualTo(1)); + + FunctionToolDefinition responseToolDefinition = assistant.Tools[0] as FunctionToolDefinition; + Assert.That(responseToolDefinition?.FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); + Assert.That(responseToolDefinition?.Parameters, Is.Not.Null); + + ThreadRun run = await client.CreateThreadAndRunAsync( + assistant, + new ThreadCreationOptions() + { + InitialMessages = { new(MessageRole.User, ["What should I eat on Thursday?"]) }, + }, + new RunCreationOptions() + { + AdditionalInstructions = "Call provided tools when appropriate.", + }); + Validate(run); + Console.WriteLine($" Run status right after creation: {run.Status}"); + + // TODO FIXME: The underlying OpenAI code doesn't consider the "requires_action" status to be terminal even though it is. + // Work around this here + run = await WaitUntilReturnLast( + run, + () => client.GetRunAsync(run), + r => r.Status.IsTerminal || r.Status.Equals(RunStatus.RequiresAction)); + + Assert.That(run.Status, Is.EqualTo(RunStatus.RequiresAction)); + Assert.That(run.RequiredActions?.Count, Is.EqualTo(1)); + Assert.That(run.RequiredActions[0].ToolCallId, Is.Not.Null.Or.Empty); + Assert.That(run.RequiredActions[0].FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); + Assert.That(run.RequiredActions[0].FunctionArguments, Is.Not.Null.Or.Empty); + + run = await client.SubmitToolOutputsToRunAsync(run, [new(run.RequiredActions[0].ToolCallId, "tacos")]); + Assert.That(run.Status.IsTerminal, Is.False); + + run = await WaitUntilReturnLast( + run, + () => client.GetRunAsync(run), + r => r.Status.IsTerminal); + Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); + + List messages = await client.GetMessagesAsync(run.ThreadId, new() { Order = ListOrder.NewestFirst }) + .ToListAsync(); + Assert.That(messages.Count, Is.GreaterThan(1)); + Assert.That(messages.ElementAt(0).Role, Is.EqualTo(MessageRole.Assistant)); + Assert.That(messages.ElementAt(0).Content?[0], Is.Not.Null); + Assert.That(messages.ElementAt(0).Content?[0].Text, Does.Contain("tacos")); + } + + [RecordedTest] + public async Task BasicFileSearchWorks() + { + // First, we need to upload a simple test file. + AssistantClient client = GetTestClient(); + string modelName = client.DeploymentOrThrow(); + FileClient fileClient = GetTestClientFrom(client); + + OpenAIFileInfo testFile = await fileClient.UploadFileAsync( + BinaryData.FromString(""" + This file describes the favorite foods of several people. + + Summanus Ferdinand: tacos + Tekakwitha Effie: pizza + Filip Carola: cake + """), + "favorite_foods.txt", + FileUploadPurpose.Assistants); + Validate(testFile); + + // Create an assistant, using the creation helper to make a new vector store + Assistant assistant = await client.CreateAssistantAsync(modelName, new() + { + Tools = { new FileSearchToolDefinition() }, + ToolResources = new() + { + FileSearch = new() + { + NewVectorStores = + { + new VectorStoreCreationHelper([testFile]), + } + } + } + }); + Validate(assistant); + Assert.That(assistant.ToolResources?.FileSearch?.VectorStoreIds, Has.Count.EqualTo(1)); + string createdVectorStoreId = assistant.ToolResources.FileSearch.VectorStoreIds[0]; + ValidateById(createdVectorStoreId); + + // Modify an assistant to use the existing vector store + assistant = await client.ModifyAssistantAsync(assistant, new AssistantModificationOptions() + { + ToolResources = new() + { + FileSearch = new() + { + VectorStoreIds = { assistant.ToolResources.FileSearch.VectorStoreIds[0] }, + }, + }, + }); + Assert.That(assistant.ToolResources?.FileSearch?.VectorStoreIds, Has.Count.EqualTo(1)); + Assert.That(assistant.ToolResources.FileSearch.VectorStoreIds[0], Is.EqualTo(createdVectorStoreId)); + + // Create a thread with an override vector store + AssistantThread thread = await client.CreateThreadAsync(new ThreadCreationOptions() + { + InitialMessages = { new(MessageRole.User, ["Using the files you have available, what's Filip's favorite food?"]) }, + ToolResources = new() + { + FileSearch = new() + { + NewVectorStores = + { + new VectorStoreCreationHelper([testFile.Id]) + } + } + } + }); + Validate(thread); + Assert.That(thread.ToolResources?.FileSearch?.VectorStoreIds, Has.Count.EqualTo(1)); + createdVectorStoreId = thread.ToolResources.FileSearch.VectorStoreIds[0]; + ValidateById(createdVectorStoreId); + + // Ensure that modifying the thread with an existing vector store works + thread = await client.ModifyThreadAsync(thread, new ThreadModificationOptions() + { + ToolResources = new() + { + FileSearch = new() + { + VectorStoreIds = { createdVectorStoreId }, + } + } + }); + Assert.That(thread.ToolResources?.FileSearch?.VectorStoreIds, Has.Count.EqualTo(1)); + Assert.That(thread.ToolResources.FileSearch.VectorStoreIds[0], Is.EqualTo(createdVectorStoreId)); + + ThreadRun run = await client.CreateRunAsync(thread, assistant); + Validate(run); + run = await WaitUntilReturnLast( + run, + () => client.GetRunAsync(run), + r => r.Status.IsTerminal); + Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); + + AsyncPageCollection messages = client.GetMessagesAsync(thread, new() { Order = ListOrder.NewestFirst }); + int numPages = 0; + int numThreads = 0; + bool hasCake = false; + await foreach (PageResult page in messages) + { + numPages++; + foreach (ThreadMessage message in page.Values) + { + numThreads++; + foreach (MessageContent content in message.Content) + { + Console.WriteLine(content.Text); + hasCake |= content.Text?.ToLowerInvariant().Contains("cake") == true; + foreach (TextAnnotation annotation in content.TextAnnotations) + { + Console.WriteLine($" --> From file: {annotation.InputFileId}, replacement: {annotation.TextToReplace}"); + } + } + } + } + + Assert.That(numPages, Is.GreaterThan(0)); + Assert.That(numThreads, Is.GreaterThan(0)); + Assert.That(hasCake, Is.True); + } + + [RecordedTest] + public async Task StreamingRunWorks() + { + AssistantClient client = GetTestClient(); + string modelName = client.DeploymentOrThrow(); + Assistant assistant = await client.CreateAssistantAsync(modelName); + Validate(assistant); + + AssistantThread thread = await client.CreateThreadAsync(new ThreadCreationOptions() + { + InitialMessages = { new(MessageRole.User, ["Hello there, assistant! How are you today?"]), }, + }); + Validate(thread); + + AsyncCollectionResult streamingResult = client.CreateRunStreamingAsync(thread.Id, assistant.Id); + + StringBuilder content = new(); + DateTimeOffset? lastUpdate = null; + StreamingUpdateReason? lastUpdateReason = null; + + await foreach (StreamingUpdate update in streamingResult) + { + if (update is RunUpdate runUpdate) + { + lastUpdateReason = runUpdate.UpdateKind; + lastUpdate = update.UpdateKind switch + { + StreamingUpdateReason.RunCreated => runUpdate.Value.CreatedAt, + StreamingUpdateReason.RunQueued => runUpdate.Value.StartedAt, + StreamingUpdateReason.RunInProgress => runUpdate.Value.StartedAt, + StreamingUpdateReason.RunCompleted => runUpdate.Value.CompletedAt, + _ => null, + }; + } + if (update is MessageContentUpdate contentUpdate) + { + // TODO FIXME: The OpenAI library code is currently incorrectly returning a MessageRole.User value here. + // It should instead be null or at least Assistant + //Assert.That(contentUpdate.Role, Is.Null.Or.EqualTo(MessageRole.Assistant)); + Assert.That(contentUpdate.Text, Is.Not.Null); // can be empty string + content.Append(contentUpdate.Text); + } + } + + Assert.That(lastUpdateReason, Is.EqualTo(StreamingUpdateReason.RunCompleted)); + Assert.That(lastUpdate, Is.Not.Null.And.GreaterThan(s_2024)); + Assert.That(content, Has.Length.GreaterThan(0)); + } + + private static readonly DateTimeOffset s_2024 = new(2024, 1, 1, 0, 0, 0, TimeSpan.Zero); +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/AudioTests.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/AudioTests.cs new file mode 100644 index 000000000..5312b8083 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/AudioTests.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.IO; +using System.Threading.Tasks; +using OpenAI.Audio; +using OpenAI.TestFramework; + +namespace Azure.AI.OpenAI.Tests; + +public class AudioTests(bool isAsync) : AoaiTestBase(isAsync) +{ + [Test] + [Category("Smoke")] + public void CanCreateClient() => Assert.That(GetTestClient(), Is.InstanceOf()); + + [RecordedTest] + public async Task TranscriptionWorks() + { + AudioClient audioClient = GetTestClient(); + AudioTranscription transcription = await audioClient.TranscribeAudioAsync(Assets.HelloWorld.RelativePath); + Assert.That(transcription?.Text, Is.Not.Null.Or.Empty); + } + + [RecordedTest] + public async Task TranslationWorks() + { + AudioClient audioClient = GetTestClient(); + AudioTranslation translation = await audioClient.TranslateAudioAsync(Assets.WhisperFrenchDescription.RelativePath); + Assert.That(translation?.Text, Is.Not.Null.Or.Empty); + } + + [RecordedTest] + public async Task TextToSpeechWorks() + { + AudioClient audioClient = GetTestClient("tts"); + BinaryData ttsData = await audioClient.GenerateSpeechAsync( + "hello, world!", + GeneratedSpeechVoice.Alloy); + Assert.That(ttsData, Is.Not.Null); + } + + [RecordedTest] + [TestCase(AudioTranscriptionFormat.Simple)] + [TestCase(AudioTranscriptionFormat.Verbose)] + [TestCase(AudioTranscriptionFormat.Srt)] + [TestCase(AudioTranscriptionFormat.Vtt)] + [TestCase(null)] + public async Task TranscriptionWorksWithFormat(AudioTranscriptionFormat? format) + { + AudioClient client = GetTestClient(); + + var audioInfo = Assets.HelloWorld; + using Stream audioFileStream = File.OpenRead(audioInfo.RelativePath); + AudioTranscriptionOptions options = new() + { + Temperature = 0.25f, + ResponseFormat = format, + }; + + AudioTranscription transcription = await client.TranscribeAudioAsync( + audioFileStream, audioInfo.Name, options); + + Assert.That(transcription, Is.Not.Null); + Assert.That(transcription.Text, Is.Not.Null.Or.Empty); + + if (format == AudioTranscriptionFormat.Simple) + { + Assert.That(transcription.Duration, Is.Null); + Assert.That(transcription.Language, Is.Null); + Assert.That(transcription.Segments, Is.Null.Or.Empty); + } + else if (format == AudioTranscriptionFormat.Verbose) + { + Assert.That(transcription.Duration, Is.GreaterThan(TimeSpan.FromSeconds(0))); + Assert.That(transcription.Language, Is.Not.Null.Or.Empty); + Assert.That(transcription.Segments, Is.Not.Null.Or.Empty); + + TranscribedSegment firstSegment = transcription.Segments[0]; + Assert.That(firstSegment, Is.Not.Null); + Assert.That(firstSegment.Id, Is.EqualTo(0)); + Assert.That(firstSegment.Start, Is.GreaterThanOrEqualTo(TimeSpan.FromSeconds(0))); + Assert.That(firstSegment.End, Is.GreaterThan(firstSegment.Start)); + Assert.That(firstSegment.Text, Is.Not.Null.Or.Empty); + } + } + + [RecordedTest] + [TestCase(AudioTimestampGranularities.Default)] + [TestCase(AudioTimestampGranularities.Word)] + [TestCase(AudioTimestampGranularities.Segment)] + [TestCase(AudioTimestampGranularities.Word | AudioTimestampGranularities.Segment)] + public async Task TranscriptionTimestampGranularitiesWork(AudioTimestampGranularities granularityFlags) + { + AudioClient client = GetTestClient(); + var audioInfo = Assets.HelloWorld; + using Stream audioFileStream = File.OpenRead(audioInfo.RelativePath); + AudioTranscriptionOptions options = new() + { + Granularities = granularityFlags, + ResponseFormat = AudioTranscriptionFormat.Verbose, + }; + ClientResult transcriptionResult = await client.TranscribeAudioAsync( + audioFileStream, + audioInfo.Name, + options); + PipelineResponse response = transcriptionResult.GetRawResponse(); + Assert.That(response, Is.Not.Null); + AudioTranscription transcription = transcriptionResult.Value; + Assert.That(transcription.Text, Is.Not.Null.Or.Empty); + Assert.That( + transcription.Words?.Count > 0, + Is.EqualTo(granularityFlags.HasFlag(AudioTimestampGranularities.Word)), + "Word-level information should appear (and only appear) when requested"); + Assert.That( + transcription.Segments?.Count > 0, + Is.EqualTo(granularityFlags.HasFlag(AudioTimestampGranularities.Segment) || granularityFlags == AudioTimestampGranularities.Default), + "Segment-level information should appear (and only appear) when requested or when no flags were provided"); + } + + [RecordedTest] + [TestCase(AudioTranslationFormat.Simple)] + [TestCase(AudioTranslationFormat.Verbose)] + [TestCase(AudioTranslationFormat.Srt)] + [TestCase(AudioTranslationFormat.Vtt)] + [TestCase(null)] + public async Task TranslationWorksWithFormat(AudioTranslationFormat? format) + { + AudioClient client = GetTestClient(); + + var audioInfo = Assets.WhisperFrenchDescription; + using Stream audioFileStream = File.OpenRead(audioInfo.RelativePath); + AudioTranslationOptions options = new() + { + ResponseFormat = format, + }; + + AudioTranslation translation = await client.TranslateAudioAsync( + audioFileStream, audioInfo.Name, options); + + Assert.That(translation, Is.Not.Null); + Assert.That(translation.Text, Is.Not.Null.Or.Empty); + + if (format == AudioTranslationFormat.Simple) + { + Assert.That(translation.Duration, Is.Null); + Assert.That(translation.Language, Is.Null); + Assert.That(translation.Segments, Is.Null.Or.Empty); + } + else if (format == AudioTranslationFormat.Verbose) + { + Assert.That(translation.Duration, Is.GreaterThan(TimeSpan.FromSeconds(0))); + Assert.That(translation.Language, Is.Not.Null.Or.Empty); + Assert.That(translation.Segments, Is.Not.Null.Or.Empty); + + TranscribedSegment firstSegment = translation.Segments[0]; + Assert.That(firstSegment, Is.Not.Null); + Assert.That(firstSegment.Id, Is.EqualTo(0)); + Assert.That(firstSegment.Start, Is.GreaterThanOrEqualTo(TimeSpan.FromSeconds(0))); + Assert.That(firstSegment.End, Is.GreaterThan(firstSegment.Start)); + Assert.That(firstSegment.Text, Is.Not.Null.Or.Empty); + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Azure.AI.OpenAI.Tests.csproj b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Azure.AI.OpenAI.Tests.csproj new file mode 100644 index 000000000..5cc3cb985 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Azure.AI.OpenAI.Tests.csproj @@ -0,0 +1,61 @@ + + + + $(RequiredTargetFrameworks) + + + $(NoWarn);CS1591;CS8002;SA1402;SA1507;SA1508;SA1633;SA1028;SA1505;OPENAI001;AOAI001 + preview + enable + + + + + + + + + + + + + + + + + + + + + + + + + + + PreserveNewest + + + Never + + + + + + Utils\Polyfill\%(RecursiveDir)\%(Filename).cs + + + + + + + <_Parameter1>TestProjectSourceBasePath + <_Parameter2>$(MSBuildThisFileDirectory) + + + + diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/BatchTests.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/BatchTests.cs new file mode 100644 index 000000000..357c022d1 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/BatchTests.cs @@ -0,0 +1,222 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.AI.OpenAI.Tests.Models; +using Azure.AI.OpenAI.Tests.Utils; +using Azure.AI.OpenAI.Tests.Utils.Config; +using OpenAI.Batch; +using OpenAI.Chat; +using OpenAI.Embeddings; +using OpenAI.Files; +using OpenAI.TestFramework; +using OpenAI.TestFramework.Mocks; +using OpenAI.TestFramework.Utils; + +namespace Azure.AI.OpenAI.Tests; + +public class BatchTests : AoaiTestBase +{ + public BatchTests(bool isAsync) : base(isAsync) + { } + + [Test] + [Category("Smoke")] + public void CanCreateClient() => Assert.That(GetTestClient(), Is.InstanceOf()); + + [RecordedTest] + [Ignore("Azure OpenAI does not yet support batch file uploads")] + public async Task SimpleBatchCompletionsTest() + { + BatchClient batchClient = GetTestClient(new TestClientOptions(AzureOpenAIClientOptions.ServiceVersion.V2024_06_01)); + await using BatchOperations ops = new(this, batchClient); + + // Create the batch operations to send and upload them + ops.ChatClient.CompleteChat([new SystemChatMessage("You are a saccharine AI"), new UserChatMessage("Tell me about yourself")]); + ops.ChatClient.CompleteChat([new UserChatMessage("Give me a large random number")]); + Assert.That(ops.Operations, Has.Count.EqualTo(2)); + string inputFileId = await ops.UploadBatchFileAsync(); + + // Create the batch operation + using var requestContent = new BatchOptions() + { + InputFileId = inputFileId, + Endpoint = ops.Operations.Select(o => o.Url).Distinct().First(), + Metadata = + { + [ "description" ] = "Azure OpenAI .Net SDK integration test framework " + nameof(SimpleBatchCompletionsTest), + } + }.ToBinaryContent(); + + ClientResult response = await batchClient.CreateBatchAsync(requestContent); + BatchObject batchObj = ExtractAndValidateBatchObj(response); + + // Poll until we've completed, failed, or were canceled + while ("completed" != batchObj.Status) + { + response = await batchClient.GetBatchAsync(batchObj.Id, new()); + batchObj = ExtractAndValidateBatchObj(response); + } + + Assert.That(batchObj.OutputFileID, Is.Not.Null.Or.Empty); + BinaryData outputData = await ops.DownloadAndValidateResultAsync(batchObj.OutputFileID!); + var parsedOutput = BatchResult.From(outputData); + Assert.That(parsedOutput, Is.Not.Null); + Assert.That(parsedOutput, Has.Count.EqualTo(ops.Operations.Count)); + for (int i = 0; i < parsedOutput.Count; i++) + { + Assert.That(parsedOutput[i].CustomId, Is.EqualTo(ops.Operations[i].CustomId), "Wrong custom ID at index {0}", i); + var completion = parsedOutput[i].Response!; + Assert.That(completion, Is.Not.Null); + Assert.That(completion.Role, Is.EqualTo(ChatMessageRole.Assistant)); + Assert.That(completion.Content, Has.Count.EqualTo(1)); + Assert.That(completion.Content[0].Kind, Is.EqualTo(ChatMessageContentPartKind.Text)); + Assert.That(completion.Content[0].Text, Is.Not.Null.Or.Empty); + } + + } + + #region helper methods + + private BinaryData ValidateHasRawJsonResponse(ClientResult result) + { + Assert.That(result, Is.Not.Null); + PipelineResponse response = result.GetRawResponse(); + Assert.That(response, Is.Not.Null); + Assert.That(response.Status, Is.GreaterThanOrEqualTo(200).And.LessThan(300)); + Assert.That(response.Headers.GetFirstOrDefault("Content-Type"), Does.StartWith("application/json")); + + return response.Content; + } + + private void ValidateBatchResult(BatchObject batchObj) + { + Assert.That(batchObj, Is.Not.Null); + Assert.That(batchObj.Id, Is.Not.Null.Or.Empty); + Assert.That(batchObj.Status, Is.Not.Null); + Assert.That(batchObj.Status, Is.AnyOf("validating", "in_progress", "finalizing", "completed")); + } + + private BatchObject ExtractAndValidateBatchObj(ClientResult result) + { + var binaryData = ValidateHasRawJsonResponse(result); + var batchObj = BatchObject.From(binaryData); + ValidateBatchResult(batchObj); + return batchObj; + } + + #endregion + + #region helper classes + + private class BatchOperations : IAsyncDisposable + { + private MockHttpMessageHandler _handler; + private List _operations; + private string? _uploadId; + private FileClient _fileClient; + + public BatchOperations(AoaiTestBase testBase, BatchClient batchClient) + { + _handler = new(MockHttpMessageHandler.ReturnEmptyJson); + _handler.OnRequest += HandleRequest; + _operations = new(); + + BatchFileName = "batch-" + Guid.NewGuid().ToString("D") + ".json"; + + _fileClient = testBase.GetTestClientFrom(batchClient); + + // Generate the fake pipeline to capture requests and save them to a file later + AzureOpenAIClient fakeTopLevel = new AzureOpenAIClient( + new Uri("https://not.a.real.endpoint.fake"), + new ApiKeyCredential("not.a.real.key"), + new() { Transport = _handler.Transport }); + + ChatClient = fakeTopLevel.GetChatClient(testBase.TestConfig.GetConfig().DeploymentOrThrow("chat client")); + EmbeddingClient = fakeTopLevel.GetEmbeddingClient(testBase.TestConfig.GetConfig().DeploymentOrThrow("embedding client")); + } + + public string BatchFileName { get; } + public IReadOnlyList Operations => _operations; + public ChatClient ChatClient { get; } + public EmbeddingClient EmbeddingClient { get; } + + public async Task UploadBatchFileAsync() + { + if (Operations.Count == 0) + { + throw new InvalidOperationException(); + } + + using MemoryStream stream = new MemoryStream(); + JsonHelpers.Serialize(stream, _operations, JsonOptions.OpenAIJsonOptions); + stream.Seek(0, SeekOrigin.Begin); + var data = BinaryData.FromStream(stream); + + using var content = BinaryContent.Create(data); + + OpenAIFileInfo file = await _fileClient.UploadFileAsync(data, BatchFileName, FileUploadPurpose.Batch); + _uploadId = file.Id; + Assert.That(_uploadId, Is.Not.Null.Or.Empty); + return _uploadId; + } + + public async Task DownloadAndValidateResultAsync(string outputId) + { + ClientResult response = await _fileClient.DownloadFileAsync(outputId); + Assert.That(response, Is.Not.Null); + Assert.That(response.Value, Is.Not.Null); + return response.Value; + } + + public async ValueTask DisposeAsync() + { + // clean up any files + if (_uploadId != null) + { + await _fileClient.DeleteFileAsync(_uploadId); + } + + _handler.OnRequest -= HandleRequest; + _handler.Dispose(); + _operations.Clear(); + } + + private void HandleRequest(object? sender, CapturedRequest request) + { + JsonElement? element = null; + if (request.Content != null) + { + using var json = JsonDocument.Parse(request.Content.ToMemory()); + element = json.RootElement.Clone(); + } + + BatchOperation operation = new() + { + Method = request.Method, + Url = request.Uri?.AbsolutePath ?? string.Empty, + Body = element + }; + + _operations.Add(operation); + } + + public class BatchOperation + { + public string CustomId { get; } = Guid.NewGuid().ToString(); + public HttpMethod Method { get; init; } = HttpMethod.Get; + public string Url { get; init; } = string.Empty; + public JsonElement? Body { get; init; } + } + } + + #endregion +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Functions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Functions.cs new file mode 100644 index 000000000..ebe881c6a --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Functions.cs @@ -0,0 +1,280 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using OpenAI.Chat; +using OpenAI.TestFramework; + +namespace Azure.AI.OpenAI.Tests; + +public partial class ChatTests +{ + [Obsolete] + private static readonly ChatFunction FUNCTION_TEMPERATURE = new( + "get_future_temperature", + "requests the anticipated future temperature at a provided location to help inform advice about topics like choice of attire", + BinaryData.FromString( + """ + { + "type": "object", + "properties": { + "locationName": { + "type": "string", + "description": "the name or brief description of a location for weather information" + }, + "date": { + "type": "string", + "description": "the day, month, and year for which to retrieve weather information" + } + } + } + """)); + + public enum FunctionCallTestType + { + Auto, + None, + Function, + } + + [RecordedTest] + [TestCase(FunctionCallTestType.None)] + [TestCase(FunctionCallTestType.Auto)] + [TestCase(FunctionCallTestType.Function)] + [Obsolete] + public async Task SimpleFunctionCallWorks(FunctionCallTestType functionCallType) + { + ChatClient client = GetTestClient(); + + List messages = new() + { + new SystemChatMessage("You are a helpful assistant."), + new UserChatMessage("What should I wear in Honolulu next Thursday?") + }; + var requestOptions = new ChatCompletionOptions() + { + FunctionChoice = functionCallType switch + { + FunctionCallTestType.Auto => ChatFunctionChoice.Auto, + FunctionCallTestType.None => ChatFunctionChoice.None, + FunctionCallTestType.Function => new ChatFunctionChoice(FUNCTION_TEMPERATURE), + _ => throw new NotImplementedException(), + }, + Functions = { FUNCTION_TEMPERATURE }, + MaxTokens = 512, + }; + + ClientResult response = await client.CompleteChatAsync(messages, requestOptions); + Assert.That(response, Is.Not.Null); + + ChatCompletion completion = response.Value; + Assert.IsNotNull(completion); + Assert.That(completion.Id, Is.Not.Null.Or.Empty); + + ContentFilterResultForPrompt filter = completion.GetContentFilterResultForPrompt(); + Assert.IsNotNull(filter); + Assert.That(filter.SelfHarm, Is.Not.Null); + Assert.That(filter.SelfHarm.Filtered, Is.False); + Assert.That(filter.SelfHarm.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + + if (functionCallType == FunctionCallTestType.None) + { + Assert.That(completion.FinishReason, Is.EqualTo(ChatFinishReason.Stop)); + Assert.That(completion.FunctionCall, Is.Null); + + Assert.That(completion.Content, Has.Count.GreaterThan(0)); + Assert.That(completion.Content, Has.All.Not.Null); + + ChatMessageContentPart content = completion.Content[0]; + Assert.That(content.Kind, Is.EqualTo(ChatMessageContentPartKind.Text)); + Assert.That(content.Text, Is.Not.Null.Or.Empty); + + // test complete, as we were merely validating that we didn't get what we shouldn't + return; + } + + // TODO old tests look for stop reason of function_call for both auto and function, but the service currently returns "stop" + // for function + if (functionCallType == FunctionCallTestType.Auto) + { + Assert.That(completion.FinishReason, Is.EqualTo(ChatFinishReason.FunctionCall)); + } + else + { + Assert.That(completion.FinishReason, Is.EqualTo(ChatFinishReason.Stop)); + } + + Assert.That(completion.Content, Has.Count.EqualTo(0)); + + Assert.That(completion.FunctionCall, Is.Not.Null); + Assert.That(completion.FunctionCall.FunctionName, Is.EqualTo(FUNCTION_TEMPERATURE.FunctionName)); + Assert.That(completion.FunctionCall.FunctionArguments, Is.Not.Null); + var parsedArgs = JsonSerializer.Deserialize(completion.FunctionCall.FunctionArguments, SERIALIZER_OPTIONS)!; + Assert.That(parsedArgs, Is.Not.Null); + Assert.That(parsedArgs.LocationName, Is.Not.Null.Or.Empty); + Assert.That(parsedArgs.Date, Is.Not.Null.Or.Empty); + + // Complete the function call + messages.Add(new AssistantChatMessage(completion.FunctionCall)); + messages.Add(new FunctionChatMessage(FUNCTION_TEMPERATURE.FunctionName, JsonSerializer.Serialize(new + { + temperature = 31, + unit = "celsius" + }))); + + requestOptions = new() + { + Functions = { FUNCTION_TEMPERATURE }, + MaxTokens = requestOptions.MaxTokens, + }; + + completion = await client.CompleteChatAsync(messages, requestOptions); + Assert.That(completion, Is.Not.Null); + Assert.That(completion.FinishReason, Is.EqualTo(ChatFinishReason.Stop)); + + ContentFilterResultForResponse responseFilter = completion.GetContentFilterResultForResponse(); + Assert.That(responseFilter, Is.Not.Null); + Assert.That(responseFilter.Hate, Is.Not.Null); + Assert.That(responseFilter.Hate.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + Assert.That(responseFilter.Hate.Filtered, Is.False); + + Assert.That(completion.Content, Has.Count.GreaterThan(0)); + Assert.That(completion.Content[0], Is.Not.Null); + Assert.That(completion.Content[0].Text, Is.Not.Null.Or.Empty); + Assert.That(completion.Content[0].Kind, Is.EqualTo(ChatMessageContentPartKind.Text)); + } + + [RecordedTest] + [TestCase(FunctionCallTestType.None)] + [TestCase(FunctionCallTestType.Auto)] + [TestCase(FunctionCallTestType.Function)] + [Obsolete] + public async Task SimpleFunctionCallWorksStreaming(FunctionCallTestType functionCallType) + { + StringBuilder content = new(); + bool foundPromptFilter = false; + bool foundResponseFilter = false; + string? functionName = null; + StringBuilder functionArgs = new(); + + ChatClient client = GetTestClient(); + + List messages = new() + { + new SystemChatMessage("You are a helpful assistant."), + new UserChatMessage("What should I wear in Honolulu next Thursday?") + }; + var requestOptions = new ChatCompletionOptions() + { + FunctionChoice = functionCallType switch + { + FunctionCallTestType.Auto => ChatFunctionChoice.Auto, + FunctionCallTestType.None => ChatFunctionChoice.None, + FunctionCallTestType.Function => new ChatFunctionChoice(FUNCTION_TEMPERATURE), + _ => throw new NotImplementedException(), + }, + Functions = { FUNCTION_TEMPERATURE }, + MaxTokens = 512, + }; + + Action validateUpdate = (update) => + { + Assert.That(update.ContentUpdate, Is.Not.Null); + Assert.That(update.ContentUpdate, Has.All.Not.Null); + + if (update.FunctionCallUpdate != null) + { + Assert.That(update.FunctionCallUpdate.FunctionName, Is.Null.Or.EqualTo(FUNCTION_TEMPERATURE.FunctionName)); + functionName ??= update.FunctionCallUpdate.FunctionName; + + Assert.That(update.FunctionCallUpdate.FunctionArgumentsUpdate, Is.Not.Null); + functionArgs.Append(update.FunctionCallUpdate.FunctionArgumentsUpdate); + } + + foreach (var part in update.ContentUpdate) + { + Assert.That(part.Kind, Is.EqualTo(ChatMessageContentPartKind.Text)); + Assert.That(part.Text, Is.Not.Null); // Could be empty string + + content.Append(part.Text); + } + + var promptFilter = update.GetContentFilterResultForPrompt(); + if (!foundPromptFilter && promptFilter?.Hate != null) + { + Assert.That(promptFilter.Hate.Filtered, Is.False); + Assert.That(promptFilter.Hate.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + foundPromptFilter = true; + } + + var responseFilter = update.GetContentFilterResultForResponse(); + if (!foundResponseFilter && responseFilter?.Hate != null) + { + Assert.That(responseFilter.Hate.Filtered, Is.False); + Assert.That(responseFilter.Hate.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + foundResponseFilter = true; + } + }; + + AsyncCollectionResult response = client.CompleteChatStreamingAsync(messages, requestOptions); + Assert.That(response, Is.Not.Null); + + await foreach (StreamingChatCompletionUpdate update in response) + { + validateUpdate(update); + } + + Assert.That(foundPromptFilter, Is.True); + + if (functionCallType != FunctionCallTestType.None) + { + Assert.That(functionName, Is.Not.Null); + var parsedArgs = JsonSerializer.Deserialize(functionArgs.ToString(), SERIALIZER_OPTIONS)!; + Assert.That(parsedArgs, Is.Not.Null); + Assert.That(parsedArgs.LocationName, Is.Not.Null.Or.Empty); + Assert.That(parsedArgs.Date, Is.Not.Null.Or.Empty); + + // TODO FIXME: There isn't a clear or obvious way to pass the assitant function message back to the service, and the constructors that allow + // us manual control are internal. So let's use JSON. + var converted = ModelReaderWriter.Read(BinaryData.FromString(JsonSerializer.Serialize(new { name = functionName, arguments = functionArgs.ToString() }))); + messages.Add(new AssistantChatMessage(converted)); + messages.Add(new FunctionChatMessage(FUNCTION_TEMPERATURE.FunctionName, JsonSerializer.Serialize(new + { + temperature = 31, + unit = "celsius" + }))); + + requestOptions = new() + { + Functions = { FUNCTION_TEMPERATURE }, + MaxTokens = requestOptions.MaxTokens, + }; + + content.Clear(); + foundPromptFilter = false; + foundResponseFilter = false; + functionName = null; + functionArgs.Clear(); + + response = client.CompleteChatStreamingAsync(messages, requestOptions); + Assert.That(response, Is.Not.Null); + + await foreach (StreamingChatCompletionUpdate update in response) + { + validateUpdate(update); + } + } + + Assert.That(foundPromptFilter, Is.True); + Assert.That(foundResponseFilter, Is.True); + Assert.That(functionName, Is.Null); + Assert.That(functionArgs, Has.Length.EqualTo(0)); + Assert.That(content.ToString(), Is.Not.Null.Or.Empty); + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Tools.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Tools.cs new file mode 100644 index 000000000..77160c89a --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Tools.cs @@ -0,0 +1,326 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.Collections.Generic; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using OpenAI.Chat; +using OpenAI.TestFramework; + +namespace Azure.AI.OpenAI.Tests +{ + public partial class ChatTests + { + private static readonly JsonSerializerOptions SERIALIZER_OPTIONS = new() + { + PropertyNameCaseInsensitive = true, + PropertyNamingPolicy = JsonNamingPolicy.CamelCase + }; + + private static readonly ChatTool TOOL_TEMPERATURE = ChatTool.CreateFunctionTool( + "get_future_temperature", + "requests the anticipated future temperature at a provided location to help inform advice about topics like choice of attire", + BinaryData.FromString( + """ + { + "type": "object", + "properties": { + "locationName": { + "type": "string", + "description": "the name or brief description of a location for weather information" + }, + "date": { + "type": "string", + "description": "the day, month, and year for which to retrieve weather information" + } + } + } + """)); + + private class TemperatureFunctionRequestArguments + { + public string? LocationName { get; set; } + public string? Date { get; set; } + } + + public enum ToolChoiceTestType + { + None, + Auto, + Tool, + Required + } + + [RecordedTest] + [TestCase(ToolChoiceTestType.None)] + [TestCase(ToolChoiceTestType.Auto)] + [TestCase(ToolChoiceTestType.Tool)] + [TestCase(ToolChoiceTestType.Required, Ignore = "This seems to be considered invalid")] + public async Task SimpleToolWorks(ToolChoiceTestType toolChoice) + { + ChatClient client = GetTestClient(); + + List messages = new() + { + new SystemChatMessage("You are a helpful assistant."), + new UserChatMessage("What should I wear in Honolulu next Thursday?") + }; + var requestOptions = new ChatCompletionOptions() + { + ToolChoice = toolChoice switch + { + ToolChoiceTestType.None => ChatToolChoice.None, + ToolChoiceTestType.Auto => ChatToolChoice.Auto, + ToolChoiceTestType.Tool => new ChatToolChoice(TOOL_TEMPERATURE), + ToolChoiceTestType.Required => ChatToolChoice.Required, + _ => throw new NotImplementedException(), + }, + Tools = { TOOL_TEMPERATURE }, + MaxTokens = 512, + }; + + ClientResult response = await client.CompleteChatAsync(messages, requestOptions); + Assert.That(response, Is.Not.Null); + + ChatCompletion completion = response.Value; + Assert.IsNotNull(completion); + Assert.That(completion.Id, Is.Not.Null.Or.Empty); + + ContentFilterResultForPrompt filter = completion.GetContentFilterResultForPrompt(); + Assert.IsNotNull(filter); + Assert.That(filter.SelfHarm, Is.Not.Null); + Assert.That(filter.SelfHarm.Filtered, Is.False); + Assert.That(filter.SelfHarm.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + + if (toolChoice == ToolChoiceTestType.None) + { + Assert.That(completion.FinishReason, Is.EqualTo(ChatFinishReason.Stop)); + Assert.That(completion.ToolCalls, Has.Count.EqualTo(0)); + + Assert.That(completion.Content, Has.Count.GreaterThan(0)); + Assert.That(completion.Content, Has.All.Not.Null); + + ChatMessageContentPart content = completion.Content[0]; + Assert.That(content.Kind, Is.EqualTo(ChatMessageContentPartKind.Text)); + Assert.That(content.Text, Is.Not.Null.Or.Empty); + + // test complete, as we were merely validating that we didn't get what we shouldn't + return; + } + + // TODO old tests look for stop reason of function_call for both auto and function, but the service currently returns "stop" + // for function + if (toolChoice == ToolChoiceTestType.Auto) + { + Assert.That(completion.FinishReason, Is.EqualTo(ChatFinishReason.ToolCalls)); + } + else + { + Assert.That(completion.FinishReason, Is.EqualTo(ChatFinishReason.Stop)); + } + + Assert.That(completion.Content, Has.Count.EqualTo(0)); + Assert.That(completion.ToolCalls, Has.Count.EqualTo(1)); + Assert.That(completion.ToolCalls, Has.All.Not.Null); + + ChatToolCall toolCall = completion.ToolCalls[0]; + Assert.That(toolCall.Id, Is.Not.Null.Or.Empty); + Assert.That(toolCall.Kind, Is.EqualTo(ChatToolCallKind.Function)); + Assert.That(toolCall.FunctionName, Is.EqualTo(TOOL_TEMPERATURE.FunctionName)); + Assert.That(toolCall.FunctionArguments, Is.Not.Null); + var parsedArgs = JsonSerializer.Deserialize(toolCall.FunctionArguments, SERIALIZER_OPTIONS)!; + Assert.That(parsedArgs, Is.Not.Null); + Assert.That(parsedArgs.LocationName, Is.Not.Null.Or.Empty); + Assert.That(parsedArgs.Date, Is.Not.Null.Or.Empty); + + // Complete the tool call + messages.Add(new AssistantChatMessage([toolCall])); + messages.Add(new ToolChatMessage(toolCall.Id, JsonSerializer.Serialize(new + { + temperature = 31, + unit = "celsius" + }))); + + requestOptions = new() + { + Tools = { TOOL_TEMPERATURE }, + MaxTokens = requestOptions.MaxTokens + }; + + completion = await client.CompleteChatAsync(messages, requestOptions); + Assert.That(completion, Is.Not.Null); + Assert.That(completion.FinishReason, Is.EqualTo(ChatFinishReason.Stop)); + + ContentFilterResultForPrompt promptFilter = completion.GetContentFilterResultForPrompt(); + Assert.That(promptFilter, Is.Not.Null); + Assert.That(promptFilter.Hate, Is.Not.Null); + Assert.That(promptFilter.Hate.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + Assert.That(promptFilter.Hate.Filtered, Is.False); + + ContentFilterResultForResponse responseFilter = completion.GetContentFilterResultForResponse(); + Assert.That(responseFilter, Is.Not.Null); + Assert.That(responseFilter.Hate, Is.Not.Null); + Assert.That(responseFilter.Hate.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + Assert.That(responseFilter.Hate.Filtered, Is.False); + + Assert.That(completion.Content, Has.Count.GreaterThan(0)); + Assert.That(completion.Content, Has.All.Not.Null); + Assert.That(completion.Content[0].Text, Is.Not.Null.Or.Empty); + Assert.That(completion.Content[0].Kind, Is.EqualTo(ChatMessageContentPartKind.Text)); + } + + [RecordedTest] + [TestCase(ToolChoiceTestType.None)] + [TestCase(ToolChoiceTestType.Auto)] + [TestCase(ToolChoiceTestType.Tool)] + [TestCase(ToolChoiceTestType.Required, Ignore = "This seems to be considered invalid")] + public async Task SimpleToolWorksStreaming(ToolChoiceTestType toolChoice) + { + StringBuilder content = new(); + bool foundPromptFilter = false; + bool foundResponseFilter = false; + string? toolId = null; + string? toolName = null; + StringBuilder toolArgs = new(); + + ChatClient client = GetTestClient(); + + List messages = new() + { + new SystemChatMessage("You are a helpful assistant."), + new UserChatMessage("What should I wear in Honolulu next Thursday?") + }; + var requestOptions = new ChatCompletionOptions() + { + ToolChoice = toolChoice switch + { + ToolChoiceTestType.None => ChatToolChoice.None, + ToolChoiceTestType.Auto => ChatToolChoice.Auto, + ToolChoiceTestType.Tool => new ChatToolChoice(TOOL_TEMPERATURE), + ToolChoiceTestType.Required => ChatToolChoice.Required, + _ => throw new NotImplementedException(), + }, + Tools = { TOOL_TEMPERATURE }, + MaxTokens = 512, + }; + + Action validateUpdate = (update) => + { + Assert.That(update.ContentUpdate, Is.Not.Null); + Assert.That(update.ContentUpdate, Has.All.Not.Null); + Assert.That(update.ToolCallUpdates, Is.Not.Null); + Assert.That(update.ToolCallUpdates, Has.All.Not.Null); + + if (update.ToolCallUpdates.Count > 0) + { + Assert.That(update.ToolCallUpdates, Has.Count.EqualTo(1)); + + StreamingChatToolCallUpdate toolUpdate = update.ToolCallUpdates[0]; + Assert.That(toolUpdate.Index, Is.EqualTo(0)); + Assert.That(toolUpdate.Id, Is.Null.Or.Not.Empty); + toolId ??= toolUpdate.Id; + Assert.That(toolUpdate.FunctionName, Is.Null.Or.EqualTo(TOOL_TEMPERATURE.FunctionName)); + toolName ??= toolUpdate.FunctionName; + + Assert.That(toolUpdate.FunctionArgumentsUpdate, Is.Not.Null); + toolArgs.Append(toolUpdate.FunctionArgumentsUpdate); + } + + foreach (var part in update.ContentUpdate) + { + Assert.That(part.Kind, Is.EqualTo(ChatMessageContentPartKind.Text)); + Assert.That(part.Text, Is.Not.Null); // Could be empty string + + content.Append(part.Text); + } + + var promptFilter = update.GetContentFilterResultForPrompt(); + if (!foundPromptFilter && promptFilter?.Hate != null) + { + Assert.That(promptFilter.Hate.Filtered, Is.False); + Assert.That(promptFilter.Hate.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + foundPromptFilter = true; + } + + var responseFilter = update.GetContentFilterResultForResponse(); + if (!foundResponseFilter && responseFilter?.Hate != null) + { + Assert.That(responseFilter.Hate.Filtered, Is.False); + Assert.That(responseFilter.Hate.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + foundResponseFilter = true; + } + }; + + AsyncCollectionResult response = client.CompleteChatStreamingAsync(messages, requestOptions); + Assert.That(response, Is.Not.Null); + + await foreach (StreamingChatCompletionUpdate update in response) + { + validateUpdate(update); + } + + Assert.That(foundPromptFilter, Is.True); + + if (toolChoice != ToolChoiceTestType.None) + { + Assert.That(content, Has.Length.EqualTo(0)); + Assert.That(toolId, Is.Not.Null); + Assert.That(toolName, Is.Not.Null); + Assert.That(toolArgs, Has.Length.GreaterThan(0)); + var parsedArgs = JsonSerializer.Deserialize(toolArgs.ToString(), SERIALIZER_OPTIONS)!; + Assert.That(parsedArgs, Is.Not.Null); + Assert.That(parsedArgs.LocationName, Is.Not.Null.Or.Empty); + Assert.That(parsedArgs.Date, Is.Not.Null.Or.Empty); + + // Complete the tool call + messages.Add( + new AssistantChatMessage( + [ + ChatToolCall.CreateFunctionToolCall( + toolId, + toolName, + toolArgs.ToString() + ) + ] + ) + ); + messages.Add(new ToolChatMessage(toolId, JsonSerializer.Serialize(new + { + temperature = 31, + unit = "celsius" + }))); + + requestOptions = new() + { + Tools = { TOOL_TEMPERATURE }, + MaxTokens = requestOptions.MaxTokens + }; + + content.Clear(); + foundPromptFilter = false; + foundResponseFilter = false; + toolId = null; + toolName = null; + toolArgs.Clear(); + + response = client.CompleteChatStreamingAsync(messages, requestOptions); + Assert.That(response, Is.Not.Null); + + await foreach (StreamingChatCompletionUpdate update in response) + { + validateUpdate(update); + } + } + + Assert.That(foundPromptFilter, Is.True); + Assert.That(foundResponseFilter, Is.True); + Assert.That(content.ToString(), Is.Not.Null.Or.Empty); + Assert.That(toolId, Is.Null); + Assert.That(toolName, Is.Null); + Assert.That(toolArgs, Has.Length.EqualTo(0)); + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Vision.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Vision.cs new file mode 100644 index 000000000..85399aac4 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Vision.cs @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.IO; +using System.Text; +using System.Threading.Tasks; +using OpenAI.Chat; +using OpenAI.TestFramework; + +namespace Azure.AI.OpenAI.Tests +{ + public partial class ChatTests + { + [RecordedTest] + [TestCase(true)] + [TestCase(false)] + public async Task ChatWithImages(bool useUri) + { + var imageAsset = Assets.DogAndCat; + ChatClient client = GetTestClient("vision"); + + ChatMessageContentPart imagePart; + if (useUri) + { + imagePart = ChatMessageContentPart.CreateImageMessageContentPart( + imageAsset.Url, ImageChatMessageContentPartDetail.Low); + } + else + { + using var stream = File.OpenRead(imageAsset.RelativePath); + var imageData = BinaryData.FromStream(stream); + + imagePart = ChatMessageContentPart.CreateImageMessageContentPart( + imageData, imageAsset.MimeType, ImageChatMessageContentPartDetail.Low); + } + + ChatMessage[] messages = + [ + new SystemChatMessage("You are a helpful assistant that helps describe images."), + new UserChatMessage(imagePart, ChatMessageContentPart.CreateTextMessageContentPart("describe this image")) + ]; + + ChatCompletionOptions options = new() + { + MaxTokens = 2048, + }; + + var response = await client.CompleteChatAsync(messages, options); + Assert.That(response, Is.Not.Null); + + Assert.That(response.Value.Id, Is.Not.Null.Or.Empty); + Assert.That(response.Value.CreatedAt, Is.GreaterThan(START_2024)); + Assert.That(response.Value.FinishReason, Is.EqualTo(ChatFinishReason.Stop)); + Assert.That(response.Value.Role, Is.EqualTo(ChatMessageRole.Assistant)); + Assert.That(response.Value.Usage, Is.Not.Null); + Assert.That(response.Value.Usage.InputTokens, Is.GreaterThan(10)); + Assert.That(response.Value.Usage.OutputTokens, Is.GreaterThan(10)); + Assert.That(response.Value.Usage.TotalTokens, Is.GreaterThan(20)); + + Assert.That(response.Value.Content, Has.Count.EqualTo(1)); + ChatMessageContentPart choice = response.Value.Content[0]; + Assert.That(choice.Kind, Is.EqualTo(ChatMessageContentPartKind.Text)); + Assert.That(choice.Text, Is.Not.Null.Or.Empty); + Assert.That(choice.Text.ToLowerInvariant(), Does.Contain("dog").Or.Contain("cat")); + + // TODO FIXME: Some models (e.g. gpt-4o either randomly return prompt filters with some missing entries) + var promptFilter = response.Value.GetContentFilterResultForPrompt(); + Assert.That(promptFilter, Is.Not.Null); + //Assert.That(promptFilter.Hate, Is.Not.Null); + //Assert.That(promptFilter.Hate.Filtered, Is.False); + //Assert.That(promptFilter.Hate.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + + var responseFilter = response.Value.GetContentFilterResultForResponse(); + Assert.That(responseFilter, Is.Not.Null); + Assert.That(responseFilter.Hate, Is.Not.Null); + Assert.That(responseFilter.Hate.Filtered, Is.False); + Assert.That(responseFilter.Hate.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + } + + [RecordedTest] + [TestCase(true)] + [TestCase(false)] + public async Task ChatWithImagesStreaming(bool useUri) + { + bool foundPromptFilter = false; + bool foundResponseFilter = false; + StringBuilder content = new(); + + ChatClient client = GetTestClient("vision"); + + ChatMessageContentPart imagePart; + var imageAsset = Assets.DogAndCat; + if (useUri) + { + imagePart = ChatMessageContentPart.CreateImageMessageContentPart( + imageAsset.Url, ImageChatMessageContentPartDetail.Low); + } + else + { + using var stream = File.OpenRead(imageAsset.RelativePath); + var imageData = BinaryData.FromStream(stream); + + imagePart = ChatMessageContentPart.CreateImageMessageContentPart( + imageData, imageAsset.MimeType, ImageChatMessageContentPartDetail.Low); + } + + ChatMessage[] messages = + [ + new SystemChatMessage("You are a helpful assistant that helps describe images."), + new UserChatMessage(imagePart, ChatMessageContentPart.CreateTextMessageContentPart("describe this image")) + ]; + + ChatCompletionOptions options = new() + { + MaxTokens = 2048, + }; + + AsyncCollectionResult response = client.CompleteChatStreamingAsync(messages, options); + Assert.That(response, Is.Not.Null); + + await foreach (StreamingChatCompletionUpdate update in response) + { + ValidateUpdate(update, content, ref foundPromptFilter, ref foundResponseFilter); + } + + // TOOD FIXME: gpt-4o models seem to return inconsistent prompt filters to skip this for now + //Assert.That(foundPromptFilter, Is.True); + Assert.That(foundResponseFilter, Is.True); + Assert.That(content, Has.Length.GreaterThan(0)); + + string c = content.ToString().ToLowerInvariant(); + Assert.That(c, Does.Contain("dog").Or.Contain("cat")); + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.cs new file mode 100644 index 000000000..8e376ecb4 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.cs @@ -0,0 +1,575 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Net.Http; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using Azure.AI.OpenAI.Chat; +using Azure.AI.OpenAI.Tests.Utils.Config; +using OpenAI.Chat; +using OpenAI.TestFramework; +using OpenAI.TestFramework.Mocks; +using OpenAI.TestFramework.Utils; + +namespace Azure.AI.OpenAI.Tests; + +public partial class ChatTests : AoaiTestBase +{ + public ChatTests(bool isAsync) : base(isAsync) + { } + + #region General tests + + [Test] + [Category("Smoke")] + public async Task DefaultUserAgentStringWorks() + { + using MockHttpMessageHandler pipeline = new(MockHttpMessageHandler.ReturnEmptyJson); + + Uri endpoint = new Uri("https://www.bing.com/"); + string apiKey = "not-a-real-one"; + string model = "ignore"; + + AzureOpenAIClient topLevel = new( + endpoint, + new ApiKeyCredential(apiKey), + new AzureOpenAIClientOptions() + { + Transport = pipeline.Transport + }); + + ChatClient client = WrapClient(topLevel.GetChatClient(model)); + + await client.CompleteChatAsync([new UserChatMessage("Hello")]); + + Assert.That(pipeline.Requests, Is.Not.Empty); + + var request = pipeline.Requests[0]; + Assert.That(request.Method, Is.EqualTo(HttpMethod.Post)); + Assert.That(request.Uri?.GetLeftPart(UriPartial.Authority), Is.EqualTo(endpoint.GetLeftPart(UriPartial.Authority))); + Assert.That(request.Headers.GetValueOrDefault("api-key")?.FirstOrDefault(), Is.EqualTo(apiKey)); + Assert.That(request.Headers.GetValueOrDefault("User-Agent")?.FirstOrDefault(), Does.Contain("azsdk-net-AI.OpenAI/")); + Assert.That(request.Content, Is.Not.Null); + var jsonString = request.Content.ToString(); + Assert.That(jsonString, Is.Not.Null.Or.Empty); + Assert.That(jsonString, Does.Contain("\"messages\"").And.Contain("\"model\"").And.Contain(model)); + } + + [Test] + [Category("Smoke")] + public void DataSourceSerializationWorks() + { + AzureSearchChatDataSource source = new() + { + Endpoint = new Uri("https://some-search-resource.azure.com"), + Authentication = DataSourceAuthentication.FromApiKey("test-api-key"), + IndexName = "index-name-here", + FieldMappings = new() + { + ContentFieldNames = { "hello" }, + TitleFieldName = "hi", + }, + AllowPartialResult = true, + QueryType = DataSourceQueryType.Simple, + OutputContextFlags = DataSourceOutputContextFlags.AllRetrievedDocuments | DataSourceOutputContextFlags.Citations, + VectorizationSource = DataSourceVectorizer.FromEndpoint( + new Uri("https://my-embedding.com"), + DataSourceAuthentication.FromApiKey("embedding-api-key")), + }; + dynamic serialized = ModelReaderWriter.Write(source).ToDynamicFromJson(); + Assert.That(serialized?.type?.ToString(), Is.EqualTo("azure_search")); + Assert.That(serialized?.parameters?.authentication?.type?.ToString(), Is.EqualTo("api_key")); + Assert.That(serialized?.parameters?.authentication?.key?.ToString(), Does.Contain("test")); + Assert.That(serialized?.parameters?.index_name?.ToString(), Is.EqualTo("index-name-here")); + Assert.That(serialized?.parameters?.fields_mapping?.content_fields?[0]?.ToString(), Is.EqualTo("hello")); + Assert.That(serialized?.parameters?.fields_mapping?.title_field?.ToString(), Is.EqualTo("hi")); + Assert.That(bool.TryParse(serialized?.parameters?.allow_partial_result?.ToString(), out bool parsed) && parsed == true); + Assert.That(serialized?.parameters?.query_type?.ToString(), Is.EqualTo("simple")); + Assert.That(serialized?.parameters?.include_contexts?[0]?.ToString(), Is.EqualTo("citations")); + Assert.That(serialized?.parameters?.include_contexts?[1]?.ToString(), Is.EqualTo("all_retrieved_documents")); + Assert.That(serialized?.parameters?.embedding_dependency?.type?.ToString(), Is.EqualTo("endpoint")); + + ChatCompletionOptions options = new(); + options.AddDataSource(new ElasticsearchChatDataSource() + { + Authentication = DataSourceAuthentication.FromAccessToken("foo-token"), + Endpoint = new Uri("https://my-elasticsearch.com"), + IndexName = "my-index-name", + InScope = true, + }); + + IReadOnlyList sourcesFromOptions = options.GetDataSources(); + Assert.That(sourcesFromOptions, Has.Count.EqualTo(1)); + Assert.That(sourcesFromOptions[0], Is.InstanceOf()); + Assert.That(((ElasticsearchChatDataSource)sourcesFromOptions[0]).IndexName, Is.EqualTo("my-index-name")); + + options.AddDataSource(new AzureCosmosDBChatDataSource() + { + Authentication = DataSourceAuthentication.FromApiKey("api-key"), + ContainerName = "my-container-name", + DatabaseName = "my_database_name", + FieldMappings = new() + { + ContentFieldNames = { "hello", "world" }, + }, + IndexName = "my-index-name", + VectorizationSource = DataSourceVectorizer.FromDeploymentName("my-deployment"), + }); + sourcesFromOptions = options.GetDataSources(); + Assert.That(sourcesFromOptions, Has.Count.EqualTo(2)); + Assert.That(sourcesFromOptions[1], Is.InstanceOf()); + } + + [RecordedTest] + public async Task ChatCompletionBadKeyGivesHelpfulError() + { + string mockKey = "not-a-valid-key-and-should-still-be-sanitized"; + + try + { + ChatClient chatClient = GetTestClient(keyCredential: new ApiKeyCredential(mockKey)); + _ = await chatClient.CompleteChatAsync([new UserChatMessage("oops, this won't work with that key!")]); + Assert.Fail("No exception was thrown"); + } + catch (Exception thrownException) + { + Assert.That(thrownException, Is.InstanceOf()); + Assert.That(thrownException.Message, Does.Contain("invalid subscription key")); + Assert.That(thrownException.Message, Does.Not.Contain(mockKey)); + } + } + + [RecordedTest] + [Category("Smoke")] + public async Task DefaultAzureCredentialWorks() + { + ChatClient chatClient = GetTestClient(tokenCredential: this.TestEnvironment.Credential); + ChatCompletion chatCompletion = await chatClient.CompleteChatAsync([ChatMessage.CreateUserMessage("Hello, world!")]); + Assert.That(chatCompletion, Is.Not.Null); + Assert.That(chatCompletion.Content, Is.Not.Null.Or.Empty); + Assert.That(chatCompletion.Content[0].Text, Is.Not.Null.Or.Empty); + } + + [RecordedTest] + [Ignore("Delay behavior not emulated by recordings, and needs to be run manually with some time in between iterations due to service throttling behaviour")] + [TestCase("x-ms-retry-after-ms", "1000", 1000)] + [TestCase("retry-after-ms", "1400", 1400)] + [TestCase("Retry-After", "1", 1000)] + [TestCase("Retry-After", "1.5", 1500)] + [TestCase("retry-after-ms", "200", 200)] + [TestCase("x-fake-test-retry-header", "1400", 800)] + public async Task RateLimitedRetryWorks(string headerName, string headerValue, double expectedDelayMilliseconds) + { + const string responseClass = "HttpClientTransportResponse"; + const string responseField = "_httpResponse"; + IConfiguration testConfig = TestConfig.GetConfig("rate_limited_chat")!; + Assert.That(testConfig, Is.Not.Null); + + int failureCount = 0; + string? clientRequestId = null; + + TestPipelinePolicy replaceHeadersPolicy = new( + requestAction: (request) => + { + clientRequestId ??= request.Headers.GetFirstOrDefault("x-ms-client-request-id"); + }, + responseAction: (response) => + { + if (response.Status != 200) + { + failureCount++; + + Type httpPipelineResponseType = typeof(HttpClientPipelineTransport).GetNestedType(responseClass, BindingFlags.NonPublic) + ?? throw new InvalidOperationException($"Could not the expected {responseClass} inner non public class"); + FieldInfo httpResponseField = httpPipelineResponseType.GetField(responseField, BindingFlags.Instance | BindingFlags.NonPublic) + ?? throw new InvalidOperationException($"Could not find the expected {responseClass}.{responseField} field)"); + HttpResponseMessage httpResponse = httpResponseField.GetValue(response) as HttpResponseMessage + ?? throw new InvalidOperationException($"Could note determine the HttpResponseMessage to modify"); + + httpResponse.Headers.Remove("x-ms-retry-after-ms"); + httpResponse.Headers.Remove("retry-after-ms"); + httpResponse.Headers.Remove("Retry-After"); + httpResponse.Headers.TryAddWithoutValidation(headerName, headerValue); + } + }); + + TestClientOptions options = new(); + options.AddPolicy(replaceHeadersPolicy, PipelinePosition.PerTry); + + ChatClient client = GetTestClient(testConfig, options); + + BinaryContent requestContent = BinaryContent.Create(BinaryData.FromString($$""" + { + "model": "{{testConfig.Deployment}}", + "messages": [ + { "role": "user", "content": "Write three haikus about tropical fruit." } + ] + } + """)); + RequestOptions noThrowOptions = new() { ErrorOptions = ClientErrorBehaviors.NoThrow }; + + TimeSpan? observed200Delay = null; + TimeSpan? observed429Delay = null; + + for (int i = 0; i < 4 && !observed429Delay.HasValue; i++) + { + Stopwatch requestWatch = Stopwatch.StartNew(); + ClientResult protocolResult = await client.CompleteChatAsync(requestContent, noThrowOptions); + PipelineResponse response = protocolResult.GetRawResponse(); + bool responseHasRequestId = response.Headers.TryGetValue("x-ms-client-request-id", out string? requestIdFromResponse); + Assert.That(responseHasRequestId, Is.True); + Assert.That(requestIdFromResponse, Is.EqualTo(clientRequestId)); + switch (response.Status) + { + case 200: + observed200Delay = requestWatch.Elapsed; + break; + case 429: + observed429Delay = requestWatch.Elapsed; + break; + default: + Assert.Fail(); + break; + } + clientRequestId = null; + } + + Assert.That(observed200Delay.HasValue, Is.True); + Assert.That(observed429Delay.HasValue, Is.True); + Assert.That(failureCount, Is.EqualTo(4)); + Assert.That(observed429Delay!.Value.TotalMilliseconds, Is.GreaterThan(expectedDelayMilliseconds)); + Assert.That(observed429Delay!.Value.TotalMilliseconds, Is.LessThan(3 * expectedDelayMilliseconds + 2 * observed200Delay!.Value.TotalMilliseconds)); + } + + #endregion + + #region Regular chat completions tests + + [RecordedTest] + public async Task ChatCompletion() + { + ChatClient chatClient = GetTestClient(); + ClientResult chatCompletion = await chatClient.CompleteChatAsync([new UserChatMessage("hello, world!")]); + Assert.That(chatCompletion, Is.Not.Null); + Assert.That(chatCompletion.Value, Is.Not.Null); + Assert.That(chatCompletion.Value, Is.InstanceOf()); + Assert.That(chatCompletion.Value.Content, Is.Not.Null.Or.Empty); + } + + [RecordedTest] + public async Task ChatCompletionWithHistoryAndLogProbabilities() + { + ChatClient client = GetTestClient(); + + ChatCompletion response = await client.CompleteChatAsync( + [ + new SystemChatMessage("You are a helpful assistant."), + new UserChatMessage("I am baking a pizza, can you help me?"), + new AssistantChatMessage("Of course, I'd be happy to help! What do you need assistance with? Do you need a recipe, cooking time and temperature suggestions, topping ideas, or something else?"), + new UserChatMessage("What temperature should I bake at?") + ], + new ChatCompletionOptions() + { + IncludeLogProbabilities = true, + TopLogProbabilityCount = 3 + }); + + Assert.That(response, Is.Not.Null); + Assert.That(response.Id, Is.Not.Null.Or.Empty); + Assert.That(response.CreatedAt, Is.GreaterThan(new DateTimeOffset(2024, 01, 01, 00, 00, 00, TimeSpan.Zero))); + Assert.That(response.FinishReason, Is.Not.Null.Or.Empty); + Assert.That(response.Content, Is.Not.Null.Or.Empty); + Assert.That(response.Content.Count, Is.EqualTo(1)); + Assert.That(response.Usage, Is.Not.Null); + Assert.That(response.Usage.InputTokens, Is.GreaterThan(10)); + Assert.That(response.Usage.OutputTokens, Is.GreaterThan(10)); + Assert.That(response.Usage.TotalTokens, Is.GreaterThan(20)); + Assert.That(response.ContentTokenLogProbabilities, Is.Not.Null.Or.Empty); + foreach (var logProb in response.ContentTokenLogProbabilities) + { + Assert.That(logProb, Is.Not.Null); + Assert.That(logProb.TopLogProbabilities, Is.Not.Null.Or.Empty); + Assert.That(logProb.TopLogProbabilities.Count, Is.EqualTo(3)); + } + + ChatMessageContentPart content = response.Content[0]; + Assert.That(content.Kind, Is.EqualTo(ChatMessageContentPartKind.Text)); + Assert.That(content.Text, Is.Not.Null.Or.Empty); + Assert.That(content.Text, Does + .Contain("Fahrenheit") + .Or.Contain("Celsius") + .Or.Contain("°F") + .Or.Contain("°C") + .Or.Contain("oven")); + } + + [RecordedTest] + public async Task ChatCompletionWithTextFormat() + { + ChatClient client = GetTestClient(); + ChatCompletionOptions options = new() + { + ResponseFormat = ChatResponseFormat.Text + }; + + ChatCompletion response = await client.CompleteChatAsync([new UserChatMessage("Give me a random number")], options); + Assert.That(response, Is.Not.Null); + Assert.That(response.Content, Is.Not.Null.Or.Empty); + Assert.That(response.Content[0].Text, Is.Not.Null.Or.Empty); + } + + [RecordedTest] + public async Task ChatCompletionContentFilter() + { + ChatClient client = GetTestClient(); + ClientResult chatCompletionResult = await client.CompleteChatAsync([ChatMessage.CreateUserMessage("Hello, world!")]); + Console.WriteLine($"--- RESPONSE ---"); + ChatCompletion chatCompletion = chatCompletionResult; + ContentFilterResultForPrompt promptFilterResult = chatCompletion.GetContentFilterResultForPrompt(); + Assert.That(promptFilterResult, Is.Not.Null); + Assert.That(promptFilterResult.Sexual?.Filtered, Is.False); + Assert.That(promptFilterResult.Sexual?.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + ContentFilterResultForResponse responseFilterResult = chatCompletion.GetContentFilterResultForResponse(); + Assert.That(responseFilterResult, Is.Not.Null); + Assert.That(responseFilterResult.Hate?.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + Assert.That(responseFilterResult.ProtectedMaterialCode, Is.Null); + } + + [RecordedTest] + public async Task SearchExtensionWorks() + { + var searchConfig = TestConfig.GetConfig("search")!; + Assert.That(searchConfig, Is.Not.Null); + string searchIndex = searchConfig.GetValueOrThrow("index"); + + AzureSearchChatDataSource source = new() + { + Endpoint = searchConfig.Endpoint, + Authentication = DataSourceAuthentication.FromApiKey(searchConfig.Key), + IndexName = searchIndex, + AllowPartialResult = true, + QueryType = DataSourceQueryType.Simple, + }; + ChatCompletionOptions options = new(); + options.AddDataSource(source); + + ChatClient client = GetTestClient(); + + ClientResult chatCompletionResult = await client.CompleteChatAsync( + [new UserChatMessage("What does the term 'PR complete' mean?")], + options); + Assert.That(chatCompletionResult, Is.Not.Null); + + ChatCompletion chatCompletion = chatCompletionResult.Value; + Assert.That(chatCompletion, Is.Not.Null); + Assert.That(chatCompletion.FinishReason, Is.EqualTo(ChatFinishReason.Stop)); + Assert.That(chatCompletion.Content, Is.Not.Null.Or.Empty); + + var content = chatCompletion.Content[0]; + Assert.That(content.Kind, Is.EqualTo(ChatMessageContentPartKind.Text)); + Assert.That(content.Text, Is.Not.Null.Or.Empty); + + AzureChatMessageContext context = chatCompletion.GetAzureMessageContext(); + Assert.IsNotNull(context); + Assert.That(context.Intent, Is.Not.Null.Or.Empty); + Assert.That(context.Citations, Has.Count.GreaterThan(0)); + Assert.That(context.Citations[0].Filepath, Is.Not.Null.Or.Empty); + Assert.That(context.Citations[0].Content, Is.Not.Null.Or.Empty); + Assert.That(context.Citations[0].ChunkId, Is.Not.Null.Or.Empty); + Assert.That(context.Citations[0].Title, Is.Not.Null.Or.Empty); + } + + #endregion + + #region Streaming chat completion tests + + [RecordedTest] + public async Task ChatCompletionBadKeyGivesHelpfulErrorStreaming() + { + string mockKey = "not-a-valid-key-and-should-still-be-sanitized"; + + try + { + ChatClient chatClient = GetTestClient(keyCredential: new ApiKeyCredential(mockKey)); + var messages = new[] { new UserChatMessage("oops, this won't work with that key!") }; + + AsyncCollectionResult result = chatClient.CompleteChatStreamingAsync(messages); + await foreach (StreamingChatCompletionUpdate update in result) + { + Assert.Fail("No exception was thrown"); + } + + Assert.Fail("No exception was thrown"); + } + catch (Exception thrownException) + { + Assert.That(thrownException, Is.InstanceOf()); + Assert.That(thrownException.Message, Does.Contain("invalid subscription key")); + Assert.That(thrownException.Message, Does.Not.Contain(mockKey)); + } + } + + [RecordedTest] + public async Task ChatCompletionStreaming() + { + StringBuilder builder = new(); + bool foundPromptFilter = false; + bool foundResponseFilter = false; + + ChatClient chatClient = GetTestClient(); + + ChatMessage[] messages = + [ + new SystemChatMessage("You are a curmudgeon"), + new UserChatMessage("Hello, assitant!") + ]; + ChatCompletionOptions options = new() + { + MaxTokens = 512, + IncludeLogProbabilities = true, + TopLogProbabilityCount = 1, + }; + + AsyncCollectionResult streamingResults = chatClient.CompleteChatStreamingAsync(messages, options); + Assert.That(streamingResults, Is.Not.Null); + + await foreach (StreamingChatCompletionUpdate update in streamingResults) + { + ValidateUpdate(update, builder, ref foundPromptFilter, ref foundResponseFilter); + } + + string allText = builder.ToString(); + Assert.That(allText, Is.Not.Null.Or.Empty); + + Assert.That(foundPromptFilter, Is.True); + Assert.That(foundResponseFilter, Is.True); + } + + [RecordedTest] + public async Task SearchExtensionWorksStreaming() + { + StringBuilder builder = new(); + bool foundPromptFilter = false; + bool foundResponseFilter = false; + List contexts = new(); + + var searchConfig = TestConfig.GetConfig("search")!; + Assert.That(searchConfig, Is.Not.Null); + string searchIndex = searchConfig.GetValueOrThrow("index"); + + AzureSearchChatDataSource source = new() + { + Endpoint = searchConfig.Endpoint, + Authentication = DataSourceAuthentication.FromApiKey(searchConfig.Key), + IndexName = searchIndex, + AllowPartialResult = true, + QueryType = DataSourceQueryType.Simple, + }; + + ChatCompletionOptions options = new(); + options.AddDataSource(source); + + ChatMessage[] messages = [new UserChatMessage("What does the term 'PR complete' mean?")]; + + ChatClient client = GetTestClient(); + + AsyncCollectionResult chatUpdates = client.CompleteChatStreamingAsync(messages, options); + Assert.IsNotNull(chatUpdates); + + await foreach (StreamingChatCompletionUpdate update in chatUpdates) + { + ValidateUpdate(update, builder, ref foundPromptFilter, ref foundResponseFilter); + + AzureChatMessageContext context = update.GetAzureMessageContext(); + if (context != null) + { + contexts.Add(context); + } + } + + string allText = builder.ToString(); + Assert.That(allText, Is.Not.Null.Or.Empty); + + // TODO FIXME: When using data sources, the service does not appear to return request nor response filtering information + //Assert.That(foundPromptFilter, Is.True); + //Assert.That(foundResponseFilter, Is.True); + + Assert.That(contexts, Has.Count.EqualTo(1)); + Assert.That(contexts[0].Intent, Is.Not.Null.Or.Empty); + Assert.That(contexts[0].Citations, Has.Count.GreaterThan(0)); + Assert.That(contexts[0].Citations[0].Content, Is.Not.Null.Or.Empty); + Assert.That(contexts[0].Citations[0].Filepath, Is.Not.Null.Or.Empty); + Assert.That(contexts[0].Citations[0].ChunkId, Is.Not.Null.Or.Empty); + Assert.That(contexts[0].Citations[0].Title, Is.Not.Null.Or.Empty); + } + + #endregion + + #region Helper methods + + private void ValidateUpdate(StreamingChatCompletionUpdate update, StringBuilder builder, ref bool foundPromptFilter, ref bool foundResponseFilter) + { + if (update.CreatedAt == UNIX_EPOCH) + { + // This is the first message that usually contains the service's request content filtering + ContentFilterResultForPrompt promptFilter = update.GetContentFilterResultForPrompt(); + if (promptFilter?.SelfHarm != null) + { + Assert.That(promptFilter.SelfHarm.Filtered, Is.False); + Assert.That(promptFilter.SelfHarm.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + foundPromptFilter = true; + } + } + else + { + Assert.That(update.Id, Is.Not.Null.Or.Empty); + Assert.That(update.CreatedAt, Is.GreaterThan(new DateTimeOffset(2024, 01, 01, 00, 00, 00, TimeSpan.Zero))); + Assert.That(update.FinishReason, Is.Null.Or.EqualTo(ChatFinishReason.Stop)); + if (update.Usage != null) + { + Assert.That(update.Usage.InputTokens, Is.GreaterThanOrEqualTo(0)); + Assert.That(update.Usage.OutputTokens, Is.GreaterThanOrEqualTo(0)); + Assert.That(update.Usage.TotalTokens, Is.GreaterThanOrEqualTo(0)); + } + + Assert.That(update.Model, Is.Not.Null); + Assert.That(update.Role, Is.Null.Or.EqualTo(ChatMessageRole.Assistant)); + Assert.That(update.ContentUpdate, Is.Not.Null); + + Assert.That(update.ContentTokenLogProbabilities, Is.Not.Null); + foreach (var logProb in update.ContentTokenLogProbabilities) + { + Assert.That(logProb.TopLogProbabilities, Is.Not.Null); + Assert.That(logProb.TopLogProbabilities.Count, Is.EqualTo(1)); + } + + foreach (ChatMessageContentPart part in update.ContentUpdate) + { + Assert.That(part.Kind, Is.EqualTo(ChatMessageContentPartKind.Text)); + Assert.That(part.Text, Is.Not.Null); + + builder.Append(part.Text); + } + + if (!foundResponseFilter) + { + ContentFilterResultForResponse responseFilter = update.GetContentFilterResultForResponse(); + if (responseFilter?.Violence != null) + { + Assert.That(responseFilter.Violence.Filtered, Is.False); + Assert.That(responseFilter.Violence.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + foundResponseFilter = true; + } + } + } + + #endregion + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/EmbeddingTests.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/EmbeddingTests.cs new file mode 100644 index 000000000..771d59fbb --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/EmbeddingTests.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.Threading.Tasks; +using OpenAI.Embeddings; +using OpenAI.TestFramework; + +namespace Azure.AI.OpenAI.Tests; + +public class EmbeddingTests : AoaiTestBase +{ + public EmbeddingTests(bool isAsync) : base(isAsync) + { } + + [Test] + [Category("Smoke")] + public void CanCreateClient() => Assert.That(GetTestClient(), Is.InstanceOf()); + + [RecordedTest] + public async Task SimpleEmbeddingWithTopLevelClient() + { + EmbeddingClient embeddingClient = GetTestClient(); + ClientResult embeddingResult = await embeddingClient.GenerateEmbeddingAsync("sample text to embed"); + Assert.That(embeddingResult?.Value?.Vector.Length, Is.GreaterThan(0)); + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/FileTests.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/FileTests.cs new file mode 100644 index 000000000..48b533798 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/FileTests.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Threading.Tasks; +using OpenAI.Files; +using OpenAI.TestFramework; + +namespace Azure.AI.OpenAI.Tests; + +public class FileTests : AoaiTestBase +{ + public FileTests(bool isAsync) : base(isAsync) + { } + + [Test] + [Category("Smoke")] + public void CanCreateClient() => Assert.That(GetTestClient(), Is.InstanceOf()); + + [RecordedTest] + public async Task CanUploadAndDeleteFiles() + { + FileClient client = GetTestClient(); + OpenAIFileInfo file = await client.UploadFileAsync( + BinaryData.FromString("hello, world!"), + "test_file_delete_me.txt", + FileUploadPurpose.Assistants); + Validate(file); + bool deleted = await client.DeleteFileAsync(file.Id); + Assert.IsTrue(deleted); + } + + [RecordedTest] + public async Task CanListFiles() + { + FileClient client = GetTestClient(); + OpenAIFileInfoCollection files = await client.GetFilesAsync(); + Assert.That(files, Has.Count.GreaterThan(0)); + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/FineTuningTests.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/FineTuningTests.cs new file mode 100644 index 000000000..b77689edf --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/FineTuningTests.cs @@ -0,0 +1,416 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.AI.OpenAI.FineTuning; +using Azure.AI.OpenAI.Tests.Models; +using Azure.AI.OpenAI.Tests.Utils; +using Azure.AI.OpenAI.Tests.Utils.Config; +using OpenAI.Chat; +using OpenAI.Files; +using OpenAI.FineTuning; +using OpenAI.TestFramework; +using OpenAI.TestFramework.Utils; + +namespace Azure.AI.OpenAI.Tests; + +public class FineTuningTests : AoaiTestBase +{ + public FineTuningTests(bool isAsync) : base(isAsync) + { } + + [Test] + [Category("Smoke")] + public void CanCreateClient() => Assert.That(GetTestClient(), Is.InstanceOf()); + + [RecordedTest] + public async Task JobsFineTuning() + { + FineTuningClient client = GetTestClient(); + + int count = 25; + + await foreach (FineTuningJob job in EnumerateJobsAsync(client)) + { + if (count-- <= 0) + { + break; + } + + Assert.That(job, Is.Not.Null); + Assert.That(job.ID, !(Is.Null.Or.Empty)); + Assert.That(job.FineTunedModel, Is.Null.Or.Not.Empty); // this either null or set to some non-empty value + Assert.That(job.Status, !(Is.Null.Or.Empty)); + Assert.That(job.Object, Is.EqualTo("fine_tuning.job")); + } + } + + [RecordedTest] + public async Task CheckpointsFineTuning() + { + string fineTunedModel = GetFineTunedModel(); + FineTuningClient client = GetTestClient(); + + // Check if the model exists by searching all jobs + FineTuningJob job = await EnumerateJobsAsync(client) + .FirstOrDefaultAsync(j => j.FineTunedModel == fineTunedModel)!; + Assert.That(job, Is.Not.Null); + Assert.That(job!.Status, Is.EqualTo("succeeded")); + + int count = 25; + await foreach (FineTuningCheckpoint checkpoint in EnumerateCheckpoints(client, job.ID)) + { + if (count-- <= 0) + { + break; + } + + Assert.That(checkpoint, Is.Not.Null); + Assert.That(checkpoint.ID, !(Is.Null.Or.Empty)); + Assert.That(checkpoint.CreatedAt, Is.GreaterThan(START_2024)); + Assert.That(checkpoint.FineTunedModelCheckpoint, !(Is.Null.Or.Empty)); + Assert.That(checkpoint.Metrics, Is.Not.Null); + Assert.That(checkpoint.Metrics.Step, Is.GreaterThan(0)); + Assert.That(checkpoint.Metrics.TrainLoss, Is.GreaterThan(0)); + Assert.That(checkpoint.Metrics.TrainMeanTokenAccuracy, Is.GreaterThan(0)); + //Assert.That(checkpoint.Metrics.ValidLoss, Is.GreaterThan(0)); + //Assert.That(checkpoint.Metrics.ValidMeanTokenAccuracy, Is.GreaterThan(0)); + //Assert.That(checkpoint.Metrics.FullValidLoss, Is.GreaterThan(0)); + //Assert.That(checkpoint.Metrics.FullValidMeanTokenAccuracy, Is.GreaterThan(0)); + } + } + + [RecordedTest] + public async Task EventsFineTuning() + { + string fineTunedModel = GetFineTunedModel(); + FineTuningClient client = GetTestClient(); + + // Check if the model exists by searching all jobs + FineTuningJob job = await EnumerateJobsAsync(client) + .FirstOrDefaultAsync(j => j.FineTunedModel == fineTunedModel)!; + Assert.That(job, Is.Not.Null); + Assert.That(job!.Status, Is.EqualTo("succeeded")); + + HashSet ids = new(); + + int count = 25; + var asyncEnum = EnumerateAsync((after, limit, opt) => client.GetJobEventsAsync(job.ID, after, limit, opt)); + await foreach (FineTuningJobEvent evt in asyncEnum) + { + if (count-- <= 0) + { + break; + } + + Assert.That(evt, Is.Not.Null); + Assert.That(evt.ID, !(Is.Null.Or.Empty)); + Assert.That(evt.Object, Is.EqualTo("fine_tuning.job.event")); + Assert.That(evt.CreatedAt, Is.GreaterThan(START_2024)); + Assert.That(evt.Level, !(Is.Null.Or.Empty)); + Assert.That(evt.Message, !(Is.Null.Or.Empty)); + + bool added = ids.Add(evt.ID); + Assert.That(added, Is.True, "Duplicate event ID detected {0}", evt.ID); + } + } + + [RecordedTest] + public async Task DeleteFineTuningModel() + { + FineTuningClient client = GetTestClient(); + Assert.That(client, Is.Not.Null); + Assert.That(client, Is.InstanceOf()); + + // The service always happily returns HTTP 204 regardless of whether or not the model exists + bool deleted = await DeleteJobAndVerifyAsync(client, "does-not-exist"); + Assert.That(deleted, Is.True); + } + + [RecordedTest] + public async Task CreateAndCancelFineTuning() + { + var fineTuningFile = Assets.FineTuning; + + FineTuningClient client = GetTestClient(); + FileClient fileClient = GetTestClientFrom(client); + + // upload training data + OpenAIFileInfo uploadedFile = await UploadAndWaitForCompleteOrFail(fileClient, fineTuningFile.RelativePath); + + // Create the fine tuning job + using var requestContent = new FineTuningOptions() + { + Model = client.DeploymentOrThrow(), + TrainingFile = uploadedFile.Id + }.ToBinaryContent(); + + ClientResult result = await client.CreateJobAsync(requestContent); + FineTuningJob job = ValidateAndParse(result); + Assert.That(job.ID, !(Is.Null.Or.Empty)); + + await using RunOnScopeExit _ = new(async () => + { + bool deleted = await DeleteJobAndVerifyAsync(client, job.ID); + Assert.True(deleted, "Failed to delete fine tuning job: {0}", job.ID); + }); + + // Wait for some events to become available + ListResponse events; + int maxLoops = 10; + do + { + result = await client.GetJobEventsAsync(job.ID, null, 10, new()).FirstOrDefaultAsync(); + events = ValidateAndParse>(result); + + if (events.Data?.Count > 0) + { + Assert.That(events.Data[0], Is.Not.Null); + Assert.That(events.Data[0].ID, !(Is.Null.Or.Empty)); + Assert.That(events.Data[0].Level, !(Is.Null.Or.Empty)); + Assert.That(events.Data[0].Message, !(Is.Null.Or.Empty)); + Assert.That(events.Data[0].CreatedAt, Is.GreaterThan(START_2024)); + + break; + } + + await Task.Delay(TimeSpan.FromSeconds(2)); + + } while (maxLoops-- > 0); + + // Cancel the fine tuning job + result = await client.CancelJobAsync(job.ID, new()); + job = ValidateAndParse(result); + + // Make sure the job status shows as cancelled + job = await WaitForJobToEnd(client, job); + Assert.That(job.Status, Is.EqualTo("cancelled")); + } + + [RecordedTest] + [Category("LongRunning")] // CAUTION: This test can take up 30 *minutes* to run in live mode + public async Task CreateAndDeleteFineTuning() + { + var fineTuningFile = Assets.FineTuning; + + FineTuningClient client = GetTestClient(); + FileClient fileClient = GetTestClientFrom(client); + + // upload training data + OpenAIFileInfo uploadedFile = await UploadAndWaitForCompleteOrFail(fileClient, fineTuningFile.RelativePath); + Assert.That(uploadedFile.Status, Is.EqualTo(OpenAIFileStatus.Processed)); + + // Create the fine tuning job + using var requestContent = new FineTuningOptions() + { + Model = client.DeploymentOrThrow(), + TrainingFile = uploadedFile.Id + }.ToBinaryContent(); + + ClientResult result = await client.CreateJobAsync(requestContent); + FineTuningJob job = ValidateAndParse(result); + Assert.That(job.ID, Is.Not.Null.Or.Empty); + Assert.That(job.Error, Is.Null); + Assert.That(job.Status, !(Is.Null.Or.EqualTo("failed").Or.EqualTo("cancelled"))); + + // Wait for the fine tuning to complete + job = await WaitForJobToEnd(client, job); + Assert.That(job.Status, Is.EqualTo("succeeded"), "Fine tuning did not succeed"); + Assert.That(job.FineTunedModel, Is.Not.Null.Or.Empty); + + // Delete the fine tuned model + bool deleted = await DeleteJobAndVerifyAsync(client, job.ID); + Assert.True(deleted, "Failed to delete fine tuning model: {0}", job.FineTunedModel); + } + + [RecordedTest] + [Category("LongRunning")] // CAUTION: This test can take around 10 to 15 *minutes* in live mode to run + public async Task DeployAndChatWithModel() + { + string fineTunedModel = GetFineTunedModel(); + FineTuningClient client = GetTestClient(); + + AzureDeploymentClient deploymentClient = GetTestClientFrom(client); + string? deploymentName = null; + await using RunOnScopeExit _ = new(async () => + { + if (deploymentName != null) + { + await deploymentClient.DeleteDeploymentAsync(deploymentName); + } + }); + + // Check if the model exists by searching all jobs + FineTuningJob? job = await EnumerateJobsAsync(client) + .FirstOrDefaultAsync(j => j.FineTunedModel == fineTunedModel); + Assert.That(job, Is.Not.Null); + Assert.That(job!.Status, Is.EqualTo("succeeded")); + + // Deploy the model and wait for the deployment to finish + deploymentName = "azure-ai-openai-test-" + Recording?.Random.NewGuid().ToString(); + AzureDeployedModel deployment = await deploymentClient.CreateDeploymentAsync(deploymentName, fineTunedModel); + Assert.That(deployment, Is.Not.Null); + Assert.That(deployment.ID, !(Is.Null.Or.Empty)); + Assert.That(deployment.Properties, Is.Not.Null); + + deployment = await WaitUntilReturnLast( + deployment, + () => deploymentClient.GetDeploymentAsync(deploymentName), + (d) => + { + Assert.That(deployment?.Properties?.ProvisioningState, !(Is.Null.Or.Empty)); + + return d.Properties.ProvisioningState == "Succeeded" + || d.Properties.ProvisioningState == "Failed" + || d.Properties.ProvisioningState == "Canceled"; + }, + TimeSpan.FromMinutes(1), + TimeSpan.FromMinutes(30)); + + Assert.That(deployment.Properties.ProvisioningState, Is.EqualTo("Succeeded")); + + // Run a chat completion test + ChatClient chatClient = GetTestClientFrom(client, deploymentName); + + ChatCompletion completion = await chatClient.CompleteChatAsync( + [ + new SystemChatMessage("Convert sports headline to JSON: \"player\" (full name), \"team\", \"sport\", and \"gender\". If more than one return an array. No markdown"), + new UserChatMessage("Pavleski will not play in 2024-2025 season") + ]); + Assert.That(completion, Is.Not.Null); + Assert.That(completion.FinishReason, Is.EqualTo(ChatFinishReason.Stop)); + Assert.That(completion.Content, Has.Count.GreaterThan(0)); + Assert.That(completion.Content[0].Kind, Is.EqualTo(ChatMessageContentPartKind.Text)); + Assert.That(completion.Content[0].Text, !Is.Null.Or.Empty); + + // we expect a JSON payload as the response so let's try to deserialize it + using var jsonDoc = JsonDocument.Parse(completion.Content[0].Text, new() + { + AllowTrailingCommas = true, + CommentHandling = JsonCommentHandling.Skip, + MaxDepth = 2 + }); + JsonElement json = jsonDoc.RootElement; + if (json.ValueKind == JsonValueKind.Array) + { + json = json.EnumerateArray().FirstOrDefault(); + } + + Assert.That(json.ValueKind, Is.EqualTo(JsonValueKind.Object)); + Assert.That(json.EnumerateObject().Select(p => p.Name), Has.Some.Match("(player)|(team)|(sport)|(gender)")); + } + + #region helper methods + + private string GetFineTunedModel() + { + string? model = TestConfig.GetConfig() + ?.GetValue("fine_tuned_model"); + Assert.That(model, !(Is.Null.Or.Empty), "Failed to find the already fine tuned model to use"); + return model!; + } + + private async Task UploadAndWaitForCompleteOrFail(FileClient fileClient, string path) + { + OpenAIFileInfo uploadedFile = await fileClient.UploadFileAsync(path, FileUploadPurpose.FineTune); + Validate(uploadedFile); + + uploadedFile = await WaitUntilReturnLast( + uploadedFile, + () => fileClient.GetFileAsync(uploadedFile.Id), + f => f.Status == OpenAIFileStatus.Processed || f.Status == OpenAIFileStatus.Error, + TimeSpan.FromSeconds(5), + TimeSpan.FromMinutes(5)) + .ConfigureAwait(false); + + return uploadedFile; + } + + private Task WaitForJobToEnd(FineTuningClient client, FineTuningJob job) + { + RequestOptions options = new(); + string jobId = job.ID; + + // NOTE: Fine tuning jobs can take up 30 minutes to complete so the timeouts here are longer to account for that + return WaitUntilReturnLast( + job, + async () => + { + ClientResult result = await client.GetJobAsync(jobId, options).ConfigureAwait(false); + return ValidateAndParse(result); + }, + j => j.Status == "cancelled" || j.Status == "failed" || j.Status == "succeeded", + TimeSpan.FromMinutes(1), + TimeSpan.FromMinutes(40)); + } + + private IAsyncEnumerable EnumerateJobsAsync(FineTuningClient client) + => EnumerateAsync((after, limit, opt) => client.GetJobsAsync(after, limit, opt)); + + private IAsyncEnumerable EnumerateCheckpoints(FineTuningClient client, string jobId) + => EnumerateAsync((after, limit, opt) => client.GetJobCheckpointsAsync(jobId, after, limit, opt)); + + private async IAsyncEnumerable EnumerateAsync(Func> getAsyncEnumerable) + where T : FineTuningModelBase + { + int numPerFetch = 10; + RequestOptions reqOptions = new(); + + await foreach (ClientResult pageResult in getAsyncEnumerable(null, numPerFetch, reqOptions)) + { + ListResponse items = ValidateAndParse>(pageResult); + if (items.Data?.Count > 0) + { + foreach (T item in items.Data) + { + yield return item; + } + } + } + } + + private async Task DeleteJobAndVerifyAsync(FineTuningClient client, string jobId, TimeSpan? timeBetween = null, TimeSpan? maxWaitTime = null) + { + var stopTime = DateTimeOffset.Now + (maxWaitTime ?? TimeSpan.FromMinutes(1)); + var sleepTime = timeBetween ?? TimeSpan.FromSeconds(2); + + RequestOptions noThrow = new() + { + ErrorOptions = ClientErrorBehaviors.NoThrow + }; + + // Since the DeleteJob and DeleteJobAsync are extensions methods, we need to call them on the unwrapped type, + // instead of the dynamically wrapped type. + var rawClient = UnWrap(client); + + bool success = false; + while (DateTimeOffset.Now < stopTime) + { + ClientResult result = IsAsync + ? await rawClient.DeleteJobAsync(jobId, noThrow).ConfigureAwait(false) + : rawClient.DeleteJob(jobId, noThrow); + Assert.That(result, Is.Not.Null); + + // verify the deletion actually succeeded + result = await client.GetJobAsync(jobId, noThrow).ConfigureAwait(false); + var rawResponse = result.GetRawResponse(); + success = rawResponse.Status == 404; + if (success) + { + break; + } + + await Task.Delay(sleepTime).ConfigureAwait(false); + } + + return success; + } + + #endregion +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ImageTests.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ImageTests.cs new file mode 100644 index 000000000..f82f5762f --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/ImageTests.cs @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.Threading.Tasks; +using OpenAI.Images; +using OpenAI.TestFramework; + +namespace Azure.AI.OpenAI.Tests; + +public class ImageTests(bool isAsync) : AoaiTestBase(isAsync) +{ + [RecordedTest] + [Category("Smoke")] + public void CanCreateClient() + { + ImageClient client = GetTestClient(tokenCredential: TestEnvironment.Credential); + Assert.That(client, Is.InstanceOf()); + } + + [RecordedTest] + public async Task BadKeyGivesHelpfulError() + { + string mockKey = "not-a-valid-key-and-should-still-be-sanitized"; + + try + { + ImageClient client = GetTestClient(keyCredential: new ApiKeyCredential(mockKey)); + _ = await client.GenerateImageAsync("a delightful exception message, in contemporary watercolor"); + Assert.Fail("No exception was thrown"); + } + catch (Exception thrownException) + { + Assert.That(thrownException, Is.InstanceOf()); + Assert.That(thrownException.Message, Does.Contain("invalid subscription key")); + Assert.That(thrownException.Message, Does.Not.Contain(mockKey)); + } + } + + [RecordedTest] + public async Task CanCreateSimpleImage() + { + ImageClient client = GetTestClient(); + GeneratedImage image = await client.GenerateImageAsync("a tabby cat", new() + { + Quality = GeneratedImageQuality.Standard, + Size = GeneratedImageSize.W1024xH1024, + EndUserId = "test_user", + ResponseFormat = GeneratedImageFormat.Bytes, + }); + Assert.That(image, Is.Not.Null); + Assert.That(image.ImageBytes, Is.Not.Null); + } + + [RecordedTest] + public async Task CanGetContentFilterResults() + { + ImageClient client = GetTestClient(); + ClientResult imageResult = await client.GenerateImageAsync("a tabby cat", new() + { + Quality = GeneratedImageQuality.Standard, + Size = GeneratedImageSize.W1024xH1024, + EndUserId = "test_user", + ResponseFormat = GeneratedImageFormat.Uri, + }); + GeneratedImage image = imageResult.Value; + Assert.That(image, Is.Not.Null); + Assert.That(image.ImageUri, Is.Not.Null); + Console.WriteLine($"RESPONSE--\n{imageResult.GetRawResponse().Content}"); + ImageContentFilterResultForPrompt promptResults = image.GetContentFilterResultForPrompt(); + ImageContentFilterResultForResponse responseResults = image.GetContentFilterResultForResponse(); + Assert.That(promptResults?.Sexual?.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + Assert.That(responseResults?.Sexual?.Severity, Is.EqualTo(ContentFilterSeverity.Safe)); + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/AutoOrLongValue.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/AutoOrLongValue.cs new file mode 100644 index 000000000..e794736a3 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/AutoOrLongValue.cs @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Globalization; +using System.Text.Json; + +namespace Azure.AI.OpenAI.Tests.Models; + +public readonly struct AutoOrLongValue +{ + public const string NULL = "<>"; + public const string AUTO = "auto"; + + private readonly long? _longValue; + private readonly string _stringValue; + + public AutoOrLongValue() + { + _longValue = null; + _stringValue = NULL; + } + + public AutoOrLongValue(string value) + { + if (value == null) + { + throw new ArgumentNullException("value"); + } + else if (string.Equals(value, AUTO, StringComparison.OrdinalIgnoreCase)) + { + _longValue = null; + _stringValue = AUTO; + } + else if (string.Equals(value, NULL, StringComparison.OrdinalIgnoreCase)) + { + _longValue = null; + _stringValue = NULL; + } + else + { + throw new NotSupportedException(); + } + } + + public AutoOrLongValue(long value) + { + _longValue = value; + _stringValue = value.ToString(CultureInfo.InvariantCulture); + } + + public JsonElement? ToJsonElement() + { + if (_stringValue == NULL) + { + return null; + } + + using var json = JsonDocument.Parse( + _longValue?.ToString(CultureInfo.InvariantCulture) + ?? $"\"{_stringValue}\""); + + return json.RootElement.Clone(); + } + + public static AutoOrLongValue FromJsonElement(JsonElement element) + { + if (element.ValueKind == JsonValueKind.String) + { + return new(element.GetString() ?? NULL); + } + else if (element.ValueKind == JsonValueKind.Null) + { + return new(); + } + else if (element.ValueKind == JsonValueKind.Number) + { + return new(element.GetInt64()); + } + else + { + throw new JsonException("Unsupported element kind: " + element.ValueKind); + } + } + + public bool HasValue => _stringValue != NULL && HasLongValue; + public string StringValue => _stringValue; + public bool HasLongValue => _longValue.HasValue; + public long LongValue => _longValue ?? throw new InvalidOperationException("No corresponding long value"); + + public static implicit operator AutoOrLongValue(long val) => new AutoOrLongValue(val); + public static implicit operator AutoOrLongValue(string? val) => new AutoOrLongValue(val ?? NULL); +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/AzureDeployedModel.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/AzureDeployedModel.cs new file mode 100644 index 000000000..bcb3cfaa8 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/AzureDeployedModel.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +namespace Azure.AI.OpenAI.Tests.Models; + +public class AzureDeployedModel +{ + required public string ID { get; init; } + required public string Name { get; init; } + required public Props Properties { get; init; } + + public class Props + { + required public ModelInfo Model { get; init; } + required public string ProvisioningState { get; init; } + } + + public class ModelInfo + { + public string? Model { get; init; } + required public string Name { get; init; } + required public string Version { get; init; } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/AzureDeploymentClient.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/AzureDeploymentClient.cs new file mode 100644 index 000000000..3b51c441d --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/AzureDeploymentClient.cs @@ -0,0 +1,252 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.AI.OpenAI.Tests.Utils; +using Azure.AI.OpenAI.Tests.Utils.Config; +using OpenAI.TestFramework.Utils; + +namespace Azure.AI.OpenAI.Tests.Models; + +internal class AzureDeploymentClient : IDisposable +{ + private const string DEFAULT_API_VERSION = "2023-10-01-preview"; + private const string DEFAULT_SKU_NAME = "standard"; + private const int DEFAULT_CAPACITY = 1; + + private CancellationTokenSource _cts; + private ClientPipeline _pipeline; + private Core.AccessToken? _cachedAuthToken; + private readonly Core.TokenCredential _credential; + private readonly string _subscriptionId; + private readonly string _resourceGroup; + private readonly string _resourceName; + private readonly string _endpointUrl; + private readonly string _apiVersion; + + internal AzureDeploymentClient() + { + // for mocking + _cts = new(); + _pipeline = ClientPipeline.Create(); + _subscriptionId = _resourceGroup = _resourceName = _endpointUrl = string.Empty; + _apiVersion = DEFAULT_API_VERSION; + _credential = null!; + } + + public AzureDeploymentClient(IConfiguration config, Core.TokenCredential credential, string? apiVersion = null, PipelineTransport? transport = null) + { + if (config == null) + { + throw new ArgumentNullException(nameof(config)); + } + + _cts = new(); + _pipeline = ClientPipeline.Create(new ClientPipelineOptions() + { + Transport = transport ?? new HttpClientPipelineTransport() + }); + _credential = credential ?? throw new ArgumentNullException(nameof(credential)); + + _subscriptionId = config.GetValueOrThrow("subscription_id"); + _resourceGroup = config.GetValueOrThrow("resource_group"); + _resourceName = config.Endpoint?.IdnHost.Split('.').FirstOrDefault() + ?? throw new KeyNotFoundException("Could extract the resource name from the endpoint URL in the config"); + + _endpointUrl = $"https://management.azure.com/subscriptions/{_subscriptionId}/resourceGroups/{_resourceGroup}/providers/Microsoft.CognitiveServices/accounts/{_resourceName}/deployments/"; + + _apiVersion = DEFAULT_API_VERSION; + if (!string.IsNullOrWhiteSpace(apiVersion)) + { + _apiVersion = Uri.EscapeDataString(apiVersion); + } + } + + public virtual AzureDeployedModel CreateDeployment(string deploymentName, string modelName, string? skuName = DEFAULT_SKU_NAME, int capacity = DEFAULT_CAPACITY, CancellationToken token = default) + => CreateDeploymentAsync(false, deploymentName, modelName, skuName, capacity, token).GetAwaiter().GetResult(); + + public virtual Task CreateDeploymentAsync(string deploymentName, string modelName, string? skuName = DEFAULT_SKU_NAME, int capacity = DEFAULT_CAPACITY, CancellationToken token = default) + => CreateDeploymentAsync(true, deploymentName, modelName, skuName, capacity, token).AsTask(); + + public virtual AzureDeployedModel GetDeployment(string deploymentName, CancellationToken token = default) + => GetDeploymentAsync(false, deploymentName, token).GetAwaiter().GetResult(); + + public virtual Task GetDeploymentAsync(string deploymentName, CancellationToken token = default) + => GetDeploymentAsync(true, deploymentName, token).AsTask(); + + public virtual bool DeleteDeployment(string deploymentName, CancellationToken token = default) + => DeleteDeploymentAsync(false, deploymentName, token).GetAwaiter().GetResult(); + + public virtual Task DeleteDeploymentAsync(string deploymentName, CancellationToken token = default) + => DeleteDeploymentAsync(true, deploymentName, token).AsTask(); + + public void Dispose() + { + _cts.Cancel(); + _cts.Dispose(); + } + + private async ValueTask CreateDeploymentAsync(bool isAsync, string deploymentName, string modelName, string? skuName, int capacity, CancellationToken token) + { + BinaryContent content = ToJsonContent(new + { + sku = new + { + name = skuName, + capacity = capacity.ToString(CultureInfo.InvariantCulture), + }, + properties = new + { + model = new + { + format = "OpenAI", + name = modelName, + version = "1" + } + } + }); + + PipelineResponse response = await SendRequestAsync(isAsync, HttpMethod.Put, deploymentName, content, token) + .ConfigureAwait(false); + return FromJsonContent(response, token); + } + + private async ValueTask GetDeploymentAsync(bool isAsync, string deploymentName, CancellationToken token) + { + PipelineResponse response = await SendRequestAsync(isAsync, HttpMethod.Get, deploymentName, null, token) + .ConfigureAwait(false); + return FromJsonContent(response, token); + } + + private async ValueTask DeleteDeploymentAsync(bool isAsync, string deploymentName, CancellationToken token) + { + PipelineResponse response = await SendRequestAsync(isAsync, HttpMethod.Delete, deploymentName, null, token) + .ConfigureAwait(false); + ThrowOnFailed(response); + return true; + } + + private static BinaryContent ToJsonContent(T value) + { + Utf8JsonBinaryContent content = new(); + JsonSerializer.Serialize(content.JsonWriter, value, typeof(T), JsonOptions.AzureJsonOptions); + return content; + } + + private class ErrorDetail + { + public string? Code { get; init; } + public string? Message { get; init; } + } + + private class ErrorInfo + { + public ErrorDetail? Error { get; init; } + } + + private static void ThrowOnFailed(PipelineResponse response) + { + if (response.IsError) + { + if (response.Content != null + && response.Headers.GetFirstOrDefault("Content-Type")?.StartsWith("application/json") == true) + { + using Stream errorStream = response.Content.ToStream(); + ErrorInfo? error = JsonHelpers.Deserialize(errorStream, JsonOptions.AzureJsonOptions); + if (error?.Error != null) + { + throw new ClientResultException($"[{response.Status} - {error.Error.Code}] {error.Error.Message}", response); + } + } + + throw new ClientResultException(response); + } + } + + private static T FromJsonContent(PipelineResponse response, CancellationToken token) + { + ThrowOnFailed(response); + + using Stream stream = response.Content.ToStream(); + return JsonHelpers.Deserialize(stream, JsonOptions.AzureJsonOptions) + ?? throw new InvalidDataException("Service returned a null JSON response body"); + } + + private async ValueTask SendRequestAsync(bool isAsync, HttpMethod method, string pathPart, BinaryContent? body, CancellationToken token) + { + var linked = CancellationTokenSource.CreateLinkedTokenSource(_cts.Token, token); + + PipelineMessage message = _pipeline.CreateMessage(); + message.Apply(new() + { + CancellationToken = linked.Token, + ErrorOptions = ClientErrorBehaviors.NoThrow + }); + + string requestId = Guid.NewGuid().ToString(); + string bearerToken = await GetOrRenewAuthTokenAsync(isAsync, requestId, token).ConfigureAwait(false); + + string fullEndpoint = _endpointUrl + pathPart + "?api-version=" + _apiVersion; + + PipelineRequest request = message.Request; + request.Method = method.Method; + request.Uri = new Uri(fullEndpoint); + request.Headers.Add("x-ms-client-request-id", requestId); + request.Headers.Add("Authorization", "Bearer " + bearerToken); + if (body != null) + { + request.Headers.Add("Content-Type", "application/json"); + request.Content = body; + } + + if (isAsync) + { + await _pipeline.SendAsync(message).ConfigureAwait(false); + } + else + { + _pipeline.Send(message); + } + + return message.Response ?? throw new InvalidOperationException("No response was set after sending"); + } + + private async ValueTask GetOrRenewAuthTokenAsync(bool isAsync, string requestId, CancellationToken token) + { + // TODO FIXME: Use more streamlined way to get bearer auth token + if (_cachedAuthToken?.ExpiresOn > DateTimeOffset.Now.AddSeconds(-5)) + { + return _cachedAuthToken.Value.Token; + } + + var context = new Core.TokenRequestContext( + [ + "https://management.azure.com/.default" + ], + requestId); + + Core.AccessToken authToken; + if (isAsync) + { + authToken = await _credential.GetTokenAsync(context, token).ConfigureAwait(false); + } + else + { + authToken = _credential.GetToken(context, token); + } + + string bearerToken = authToken.Token; + _cachedAuthToken = authToken; + return bearerToken; + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/BatchObject.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/BatchObject.cs new file mode 100644 index 000000000..0c0e7d517 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/BatchObject.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Text.Json; +using Azure.AI.OpenAI.Tests.Utils; + +namespace Azure.AI.OpenAI.Tests.Models; + +public class BatchObject +{ + public static BatchObject From(BinaryData data) + { + return JsonSerializer.Deserialize(data, JsonOptions.OpenAIJsonOptions) + ?? throw new InvalidOperationException("Response was null JSON"); + } + + public string? Status { get; set; } + public string? Id { get; set; } + public string? OutputFileID { get; set; } + public string? ErrorFileId { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/BatchOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/BatchOptions.cs new file mode 100644 index 000000000..40a580370 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/BatchOptions.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.Collections.Generic; +using System.IO; +using Azure.AI.OpenAI.Tests.Utils; +using OpenAI.TestFramework.Utils; + +namespace Azure.AI.OpenAI.Tests.Models; + +public class BatchOptions +{ + public string? InputFileId { get; set; } + public string? Endpoint { get; set; } + public string CompletionWindow { get; set; } = "24h"; + public IDictionary Metadata { get; } = new Dictionary(); + + public BinaryContent ToBinaryContent() + { + using MemoryStream stream = new MemoryStream(); + JsonHelpers.Serialize(stream, this, JsonOptions.OpenAIJsonOptions); + + stream.Seek(0, SeekOrigin.Begin); + var data = BinaryData.FromStream(stream); + return BinaryContent.Create(data); + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/BatchResult.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/BatchResult.cs new file mode 100644 index 000000000..9d36d8d18 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/BatchResult.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Text.Json; +using Azure.AI.OpenAI.Tests.Utils; + +namespace Azure.AI.OpenAI.Tests.Models; + +public class BatchResult +{ + public string? ID { get; init; } + public string? CustomId { get; init; } + public T? Response { get; init; } + public JsonElement? Error { get; init; } + + public static IReadOnlyList> From(BinaryData data) + { + List> list = new(); + using var reader = new StreamReader(data.ToStream(), Encoding.UTF8, false); + string? line; + while ((line = reader.ReadLine()) != null) + { + if (string.IsNullOrWhiteSpace(line)) + { + break; + } + + var entry = JsonSerializer.Deserialize>(line, JsonOptions.OpenAIJsonOptions); + if (entry != null) + { + list.Add(entry); + } + } + + return list; + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningCheckpoint.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningCheckpoint.cs new file mode 100644 index 000000000..8ab2bb46a --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningCheckpoint.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; + +namespace Azure.AI.OpenAI.Tests.Models +{ + public class FineTuningCheckpoint : FineTuningModelBase + { + public DateTimeOffset CreatedAt { get; init; } + public string? FineTunedModelCheckpoint { get; init; } + public string? FineTuningJobID { get; init; } + public int StepNumber { get; init; } + public MetricsInfo Metrics { get; init; } = new MetricsInfo(); + + public class MetricsInfo + { + public int Step { get; init; } + public float TrainLoss { get; init; } + public float TrainMeanTokenAccuracy { get; init; } + public float ValidLoss { get; init; } + public float ValidMeanTokenAccuracy { get; init; } + public float FullValidLoss { get; init; } + public float FullValidMeanTokenAccuracy { get; init; } + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningHyperparameters.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningHyperparameters.cs new file mode 100644 index 000000000..9057245c8 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningHyperparameters.cs @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +namespace Azure.AI.OpenAI.Tests.Models; + +public class FineTuningHyperparameters : IJsonModel +{ + private Dictionary _values = new(); + + public AutoOrLongValue? BatchSize + { + get => Get("batch_size"); + set => Set("batch_size", value); + } + + public AutoOrLongValue? LearningRateMultiplier + { + get => Get("learning_rate_multiplier"); + set => Set("learning_rate_multiplier", value); + } + + public AutoOrLongValue? NumEpochs + { + get => Get("n_epochs"); + set => Set("n_epochs", value); + } + + private AutoOrLongValue? Get(string key) + { + if (_values.TryGetValue(key, out JsonElement element)) + { + return AutoOrLongValue.FromJsonElement(element); + } + + return null; + } + + private void Set(string key, AutoOrLongValue? value) + { + JsonElement? element = value?.ToJsonElement(); + if (element == null) + { + _values.Remove(key); + } + else + { + _values[key] = element.Value; + } + } + + FineTuningHyperparameters IJsonModel.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + var dict = JsonSerializer.Deserialize>(ref reader); + FineTuningHyperparameters instance = new(); + instance._values = dict ?? new Dictionary(); + return instance; + } + + FineTuningHyperparameters IPersistableModel.Create(BinaryData data, ModelReaderWriterOptions options) + { + ReadOnlyMemory rawData = data.ToMemory(); + var reader = new Utf8JsonReader(rawData.Span); + return ((IJsonModel)this).Create(ref reader, options); + } + + string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) + => ModelReaderWriterOptions.Json.Format; + + void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + writer.WriteStartObject(); + foreach (var kvp in _values) + { + writer.WritePropertyName(kvp.Key); + kvp.Value.WriteTo(writer); + } + writer.WriteEndObject(); + } + + BinaryData IPersistableModel.Write(ModelReaderWriterOptions options) + => ModelReaderWriter.Write(this, options); +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningJob.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningJob.cs new file mode 100644 index 000000000..e2834f3fe --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningJob.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; + +namespace Azure.AI.OpenAI.Tests.Models; + +public class FineTuningJob : FineTuningModelBase +{ + public DateTimeOffset CreatedAt { get; init; } + public IReadOnlyDictionary? Error { get; set; } + public string? FineTunedModel { get; init; } + public string Model { get; init; } = string.Empty; + public string? OrganizationID { get; init; } + public string Status { get; set; } = string.Empty; + public IReadOnlyList? ResultFiles { get; init; } + public int? TrainedTokens { get; init; } + public DateTimeOffset EstimatedFinish { get; init; } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningJobEvent.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningJobEvent.cs new file mode 100644 index 000000000..4e444b0b7 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningJobEvent.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; + +namespace Azure.AI.OpenAI.Tests.Models; + +public class FineTuningJobEvent : FineTuningModelBase +{ + public DateTimeOffset CreatedAt { get; init; } + public string? Level { get; init; } + public string? Message { get; init; } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningModelBase.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningModelBase.cs new file mode 100644 index 000000000..3e75eee9e --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningModelBase.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +namespace Azure.AI.OpenAI.Tests.Models; + +public abstract class FineTuningModelBase +{ + required public string ID { get; init; } + required public string Object { get; init; } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningOptions.cs new file mode 100644 index 000000000..4c44995f8 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/FineTuningOptions.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.IO; +using Azure.AI.OpenAI.Tests.Utils; +using OpenAI.TestFramework.Utils; + +namespace Azure.AI.OpenAI.Tests.Models; + +public class FineTuningOptions +{ + required public string TrainingFile { get; init; } + required public string Model { get; init; } + public int? Seed { get; set; } + public string? Suffix { get; set; } + public FineTuningHyperparameters? Hyperparameters { get; init; } + + public BinaryContent ToBinaryContent() + { + MemoryStream stream = new(); + JsonHelpers.Serialize(stream, this, JsonOptions.OpenAIJsonOptions); + stream.Seek(0, SeekOrigin.Begin); + return BinaryContent.Create(stream); + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/ListResponse.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/ListResponse.cs new file mode 100644 index 000000000..c5207f796 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Models/ListResponse.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System.Collections.Generic; + +namespace Azure.AI.OpenAI.Tests.Models; + +public class ListResponse +{ + public bool HasMore { get; init; } + public IReadOnlyList? Data { get; init; } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Properties/AssemblyInfo.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Properties/AssemblyInfo.cs new file mode 100644 index 000000000..278184836 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Properties/AssemblyInfo.cs @@ -0,0 +1,4 @@ +using System.Runtime.CompilerServices; +using Castle.Core.Internal; + +[assembly: InternalsVisibleTo(InternalsVisible.ToDynamicProxyGenAssembly2)] diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Samples/00_ClientConfiguration.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Samples/00_ClientConfiguration.cs new file mode 100644 index 000000000..a4ebb8fc1 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Samples/00_ClientConfiguration.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable disable + +using System; +using Azure.Identity; +using OpenAI.Chat; + +namespace Azure.AI.OpenAI.Samples; + +public partial class AzureOpenAISamples +{ + public void CreateAnAzureOpenAIClient() + { + #region Snippet:ConfigureClient:WithAOAITopLevelClient + string keyFromEnvironment = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY"); + + AzureOpenAIClient azureClient = new( + new Uri("https://your-azure-openai-resource.com"), + new AzureKeyCredential(keyFromEnvironment)); + ChatClient chatClient = azureClient.GetChatClient("my-gpt-35-turbo-deployment"); + #endregion + } + + public void CreateAnAzureOpenAIClientWithEntra() + { + #region Snippet:ConfigureClient:WithEntra + AzureOpenAIClient azureClient = new( + new Uri("https://your-azure-openai-resource.com"), + new DefaultAzureCredential()); + ChatClient chatClient = azureClient.GetChatClient("my-gpt-4o-mini-deployment"); + #endregion + } + + public void UseAzureGovernment() + { + #region Snippet:ConfigureClient:GovernmentAudience + AzureOpenAIClientOptions options = new() + { + Audience = AzureOpenAIAudience.AzureGovernment, + }; + AzureOpenAIClient azureClient = new( + new Uri("https://your-azure-openai-resource.com"), + new DefaultAzureCredential()); + ChatClient chatClient = azureClient.GetChatClient("my-gpt-4o-mini-deployment"); + #endregion + } + + public void UseCustomAuthorizationScope() + { + #region Snippet:ConfigureClient:CustomAudience + AzureOpenAIClientOptions optionsWithCustomAudience = new() + { + Audience = "https://cognitiveservices.azure.com/.default", + }; + #endregion + + AzureOpenAIClient azureClient = new( + new Uri("https://your-azure-openai-resource.com"), + new DefaultAzureCredential()); + ChatClient chatClient = azureClient.GetChatClient("my-gpt-4o-mini-deployment"); + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Samples/01_Chat.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Samples/01_Chat.cs new file mode 100644 index 000000000..3ad7584e5 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Samples/01_Chat.cs @@ -0,0 +1,292 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable disable + +using System; +using System.ClientModel; +using System.Collections.Generic; +using System.Text; +using System.Text.Json; +using Azure.Identity; +using OpenAI.Chat; + +namespace Azure.AI.OpenAI.Samples; + +public partial class AzureOpenAISamples +{ + public void BasicChat() + { + #region Snippet:SimpleChatResponse + AzureOpenAIClient azureClient = new( + new Uri("https://your-azure-openai-resource.com"), + new DefaultAzureCredential()); + ChatClient chatClient = azureClient.GetChatClient("my-gpt-35-turbo-deployment"); + + ChatCompletion completion = chatClient.CompleteChat( + [ + // System messages represent instructions or other guidance about how the assistant should behave + new SystemChatMessage("You are a helpful assistant that talks like a pirate."), + // User messages represent user input, whether historical or the most recen tinput + new UserChatMessage("Hi, can you help me?"), + // Assistant messages in a request represent conversation history for responses + new AssistantChatMessage("Arrr! Of course, me hearty! What can I do for ye?"), + new UserChatMessage("What's the best way to train a parrot?"), + ]); + + Console.WriteLine($"{completion.Role}: {completion.Content[0].Text}"); + #endregion + } + + public void StreamingChat() + { + #region Snippet:StreamChatMessages + AzureOpenAIClient azureClient = new( + new Uri("https://your-azure-openai-resource.com"), + new DefaultAzureCredential()); + ChatClient chatClient = azureClient.GetChatClient("my-gpt-35-turbo-deployment"); + + CollectionResult completionUpdates = chatClient.CompleteChatStreaming( + [ + new SystemChatMessage("You are a helpful assistant that talks like a pirate."), + new UserChatMessage("Hi, can you help me?"), + new AssistantChatMessage("Arrr! Of course, me hearty! What can I do for ye?"), + new UserChatMessage("What's the best way to train a parrot?"), + ]); + + foreach (StreamingChatCompletionUpdate completionUpdate in completionUpdates) + { + foreach (ChatMessageContentPart contentPart in completionUpdate.ContentUpdate) + { + Console.Write(contentPart.Text); + } + } + #endregion + } + + public void ChatWithTools() + { + #region Snippet:ChatTools:DefineTool + static string GetCurrentLocation() + { + // Call the location API here. + return "San Francisco"; + } + + static string GetCurrentWeather(string location, string unit = "celsius") + { + // Call the weather API here. + return $"31 {unit}"; + } + + ChatTool getCurrentLocationTool = ChatTool.CreateFunctionTool( + functionName: nameof(GetCurrentLocation), + functionDescription: "Get the user's current location" + ); + + ChatTool getCurrentWeatherTool = ChatTool.CreateFunctionTool( + functionName: nameof(GetCurrentWeather), + functionDescription: "Get the current weather in a given location", + functionParameters: BinaryData.FromString(""" + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. Boston, MA" + }, + "unit": { + "type": "string", + "enum": [ "celsius", "fahrenheit" ], + "description": "The temperature unit to use. Infer this from the specified location." + } + }, + "required": [ "location" ] + } + """) + ); + #endregion + + AzureOpenAIClient azureClient = new( + new Uri("https://your-azure-openai-resource.com"), + new DefaultAzureCredential()); + ChatClient chatClient = azureClient.GetChatClient("my-gpt-35-turbo-deployment"); + + #region Snippet:ChatTools:RequestWithFunctions + ChatCompletionOptions options = new() + { + Tools = { getCurrentLocationTool, getCurrentWeatherTool }, + }; + + List conversationMessages = + [ + new UserChatMessage("What's the weather like in Boston?"), + ]; + ChatCompletion completion = chatClient.CompleteChat(conversationMessages); + #endregion + + #region Snippet:ChatTools:HandleToolCalls + // Purely for convenience and clarity, this standalone local method handles tool call responses. + string GetToolCallContent(ChatToolCall toolCall) + { + if (toolCall.FunctionName == getCurrentWeatherTool.FunctionName) + { + // Validate arguments before using them; it's not always guaranteed to be valid JSON! + try + { + using JsonDocument argumentsDocument = JsonDocument.Parse(toolCall.FunctionArguments); + if (!argumentsDocument.RootElement.TryGetProperty("location", out JsonElement locationElement)) + { + // Handle missing required "location" argument + } + else + { + string location = locationElement.GetString(); + if (argumentsDocument.RootElement.TryGetProperty("unit", out JsonElement unitElement)) + { + return GetCurrentWeather(location, unitElement.GetString()); + } + else + { + return GetCurrentWeather(location); + } + } + } + catch (JsonException) + { + // Handle the JsonException (bad arguments) here + } + } + // Handle unexpected tool calls + throw new NotImplementedException(); + } + + if (completion.FinishReason == ChatFinishReason.ToolCalls) + { + // Add a new assistant message to the conversation history that includes the tool calls + conversationMessages.Add(new AssistantChatMessage(completion)); + + foreach (ChatToolCall toolCall in completion.ToolCalls) + { + conversationMessages.Add(new ToolChatMessage(toolCall.Id, GetToolCallContent(toolCall))); + } + + // Now make a new request with all the messages thus far, including the original + } + #endregion + } + + public void StreamingChatToolCalls() + { + static string GetCurrentLocation() + { + // Call the location API here. + return "San Francisco"; + } + + static string GetCurrentWeather(string location, string unit = "celsius") + { + // Call the weather API here. + return $"31 {unit}"; + } + + ChatTool getCurrentLocationTool = ChatTool.CreateFunctionTool( + functionName: nameof(GetCurrentLocation), + functionDescription: "Get the user's current location" + ); + + ChatTool getCurrentWeatherTool = ChatTool.CreateFunctionTool( + functionName: nameof(GetCurrentWeather), + functionDescription: "Get the current weather in a given location", + functionParameters: BinaryData.FromString(""" + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. Boston, MA" + }, + "unit": { + "type": "string", + "enum": [ "celsius", "fahrenheit" ], + "description": "The temperature unit to use. Infer this from the specified location." + } + }, + "required": [ "location" ] + } + """) + ); + + AzureOpenAIClient azureClient = new( + new Uri("https://your-azure-openai-resource.com"), + new DefaultAzureCredential()); + ChatClient chatClient = azureClient.GetChatClient("my-gpt-35-turbo-deployment"); + + ChatCompletionOptions options = new() + { + Tools = { getCurrentLocationTool, getCurrentWeatherTool }, + }; + + List conversationMessages = + [ + new UserChatMessage("What's the weather like in Boston?"), + ]; + + #region Snippet:ChatTools:StreamingChatTools + Dictionary toolCallIdsByIndex = []; + Dictionary functionNamesByIndex = []; + Dictionary functionArgumentBuildersByIndex = []; + StringBuilder contentBuilder = new(); + + foreach (StreamingChatCompletionUpdate streamingChatUpdate + in chatClient.CompleteChatStreaming(conversationMessages, options)) + { + foreach (ChatMessageContentPart contentPart in streamingChatUpdate.ContentUpdate) + { + contentBuilder.Append(contentPart.Text); + } + foreach (StreamingChatToolCallUpdate toolCallUpdate in streamingChatUpdate.ToolCallUpdates) + { + if (!string.IsNullOrEmpty(toolCallUpdate.Id)) + { + toolCallIdsByIndex[toolCallUpdate.Index] = toolCallUpdate.Id; + } + if (!string.IsNullOrEmpty(toolCallUpdate.FunctionName)) + { + functionNamesByIndex[toolCallUpdate.Index] = toolCallUpdate.FunctionName; + } + if (!string.IsNullOrEmpty(toolCallUpdate.FunctionArgumentsUpdate)) + { + StringBuilder argumentsBuilder + = functionArgumentBuildersByIndex.TryGetValue(toolCallUpdate.Index, out StringBuilder existingBuilder) + ? existingBuilder + : new(); + argumentsBuilder.Append(toolCallUpdate.FunctionArgumentsUpdate); + functionArgumentBuildersByIndex[toolCallUpdate.Index] = argumentsBuilder; + } + } + } + + List toolCalls = []; + foreach (KeyValuePair indexToIdPair in toolCallIdsByIndex) + { + toolCalls.Add(ChatToolCall.CreateFunctionToolCall( + indexToIdPair.Value, + functionNamesByIndex[indexToIdPair.Key], + functionArgumentBuildersByIndex[indexToIdPair.Key].ToString())); + } + + conversationMessages.Add(new AssistantChatMessage(toolCalls, contentBuilder.ToString())); + + // Placeholder: each tool call must be resolved, like in the non-streaming case + string GetToolCallOutput(ChatToolCall toolCall) => null; + + foreach (ChatToolCall toolCall in toolCalls) + { + conversationMessages.Add(new ToolChatMessage(toolCall.Id, GetToolCallOutput(toolCall))); + } + + // Repeat with the history and all tool call resolution messages added + #endregion + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Samples/02_Oyd.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Samples/02_Oyd.cs new file mode 100644 index 000000000..1f1a85d69 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Samples/02_Oyd.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable disable + +using System; +using Azure.AI.OpenAI.Chat; +using Azure.Identity; +using OpenAI.Chat; + +namespace Azure.AI.OpenAI.Samples; + +public partial class AzureOpenAISamples +{ + public void OnYourDataSearch() + { + AzureOpenAIClient azureClient = new( + new Uri("https://your-azure-openai-resource.com"), + new DefaultAzureCredential()); + ChatClient chatClient = azureClient.GetChatClient("my-gpt-35-turbo-deployment"); + + #region Snippet:ChatUsingYourOwnData + // Extension methods to use data sources with options are subject to SDK surface changes. Suppress the + // warning to acknowledge and this and use the subject-to-change AddDataSource method. + #pragma warning disable AOAI001 + + ChatCompletionOptions options = new(); + options.AddDataSource(new AzureSearchChatDataSource() + { + Endpoint = new Uri("https://your-search-resource.search.windows.net"), + IndexName = "contoso-products-index", + Authentication = DataSourceAuthentication.FromApiKey( + Environment.GetEnvironmentVariable("OYD_SEARCH_KEY")), + }); + + ChatCompletion completion = chatClient.CompleteChat( + [ + new UserChatMessage("What are the best-selling Contoso products this month?"), + ], + options); + + AzureChatMessageContext onYourDataContext = completion.GetAzureMessageContext(); + + if (onYourDataContext?.Intent is not null) + { + Console.WriteLine($"Intent: {onYourDataContext.Intent}"); + } + foreach (AzureChatCitation citation in onYourDataContext?.Citations ?? []) + { + Console.WriteLine($"Citation: {citation.Content}"); + } + #endregion + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Samples/03_Assistants.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Samples/03_Assistants.cs new file mode 100644 index 000000000..aaf5b15ff --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Samples/03_Assistants.cs @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable disable + +using System; +using System.Threading.Tasks; +using Azure.Identity; +using OpenAI.Assistants; + +namespace Azure.AI.OpenAI.Samples; + +public partial class AzureOpenAISamples +{ + public async Task StreamingAssistantRunAsync() + { + #region Snippet:Assistants:CreateClient + AzureOpenAIClient azureClient = new( + new Uri("https://your-azure-openai-resource.com"), + new DefaultAzureCredential()); + + // The Assistants feature area is in beta, with API specifics subject to change. + // Suppress the [Experimental] warning via .csproj or, as here, in the code to acknowledge. + #pragma warning disable OPENAI001 + AssistantClient assistantClient = azureClient.GetAssistantClient(); + #endregion + + #region Snippet:Assistants:PrepareToRun + Assistant assistant = await assistantClient.CreateAssistantAsync( + model: "my-gpt-4o-deployment", + new AssistantCreationOptions() + { + Name = "My Friendly Test Assistant", + Instructions = "You politely help with math questions. Use the code interpreter tool when asked to " + + "visualize numbers.", + Tools = { ToolDefinition.CreateCodeInterpreter() }, + }); + ThreadInitializationMessage initialMessage = new( + MessageRole.User, + [ + "Hi, Assistant! Draw a graph for a line with a slope of 4 and y-intercept of 9." + ]); + AssistantThread thread = await assistantClient.CreateThreadAsync(new ThreadCreationOptions() + { + InitialMessages = { initialMessage }, + }); + #endregion + + #region Snippet:Assistants:StreamRun + RunCreationOptions runOptions = new() + { + AdditionalInstructions = "When possible, talk like a pirate." + }; + await foreach (StreamingUpdate streamingUpdate + in assistantClient.CreateRunStreamingAsync(thread, assistant, runOptions)) + { + if (streamingUpdate.UpdateKind == StreamingUpdateReason.RunCreated) + { + Console.WriteLine($"--- Run started! ---"); + } + else if (streamingUpdate is MessageContentUpdate contentUpdate) + { + Console.Write(contentUpdate.Text); + if (contentUpdate.ImageFileId is not null) + { + Console.WriteLine($"[Image content file ID: {contentUpdate.ImageFileId}"); + } + } + } + #endregion + + #region Snippet:Assistants:Cleanup + // Optionally, delete persistent resources that are no longer needed. + _ = await assistantClient.DeleteAssistantAsync(assistant); + _ = await assistantClient.DeleteThreadAsync(thread); + #endregion + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/AoaiTestBase.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/AoaiTestBase.cs new file mode 100644 index 000000000..70c178dd2 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/AoaiTestBase.cs @@ -0,0 +1,742 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.AI.OpenAI.Tests.Models; +using Azure.AI.OpenAI.Tests.Utils; +using Azure.AI.OpenAI.Tests.Utils.Config; +using NUnit.Framework.Interfaces; +using OpenAI.Assistants; +using OpenAI.Audio; +using OpenAI.Batch; +using OpenAI.Chat; +using OpenAI.Embeddings; +using OpenAI.Files; +using OpenAI.FineTuning; +using OpenAI.Images; +using OpenAI.TestFramework; +using OpenAI.TestFramework.Recording.Proxy; +using OpenAI.TestFramework.Recording.Proxy.Service; +using OpenAI.TestFramework.Recording.RecordingProxy; +using OpenAI.TestFramework.Recording.Sanitizers; +using OpenAI.TestFramework.Utils; +using OpenAI.VectorStores; +using TokenCredential = Azure.Core.TokenCredential; + +namespace Azure.AI.OpenAI.Tests; + +public class AoaiTestBase : RecordedClientTestBase where TClient : class +{ + private const string AZURE_URI_SANITIZER_PATTERN = @"(?<=/(subscriptions|resourceGroups|accounts)/)([^/]+?)(?=(/|$))"; + private const string SMALL_1x1_PNG = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAJcEhZcwAAFiQAABYkAZsVxhQAAAAMSURBVBhXY2BgYAAAAAQAAVzN/2kAAAAASUVORK5CYII="; + + public static readonly DateTimeOffset START_2024 = new DateTimeOffset(2024, 01, 01, 00, 00, 00, TimeSpan.Zero); + public static readonly DateTimeOffset UNIX_EPOCH = +#if NETFRAMEWORK + DateTimeOffset.Parse("1970-01-01T00:00:00.0000000+00:00"); +#else + DateTimeOffset.UnixEpoch; +#endif + + internal TestConfig TestConfig { get; } + + internal Assets Assets { get; } + + public AzureTestEnvironment TestEnvironment { get; } + + protected AoaiTestBase(bool isAsync) : this(isAsync, null) + { } + + protected AoaiTestBase(bool isAsync, RecordedTestMode? mode = null) + : base(isAsync, mode) + { + TestConfig = new TestConfig(Mode); + Assets = new Assets(); + TestEnvironment = new AzureTestEnvironment(Mode); + + // Remove some of the default sanitizers to customize their behaviour + RecordingOptions.SanitizersToRemove.AddRange( + [ + "AZSDK2003", // Location header (we use a less restrictive sanitizer) + "AZSDK4001", // Replaces entire host name in URL. We want to mask only subdomain part to make it easier to distinguish requests + "AZSDK3430", // OpenAI liberally uses "id" in its JSON responses, and we want to keep them in the recordings + "AZSDK3493", // $..name in JSON. OpenAI uses this for things that don't need to be sanitized + ]); + + // Prevent resource names from leaking into recordings + RecordingOptions.Sanitizers.AddRange( + [ + new UriRegexSanitizer(SanitizedJsonConfig.HOST_SUBDOMAIN_PATTERN) + { + Value = SanitizedJsonConfig.MASK_STRING + }, + new UriRegexSanitizer(AZURE_URI_SANITIZER_PATTERN) + { + Value = SanitizedJsonConfig.MASK_STRING + }, + new HeaderRegexSanitizer("Location") + { + Regex = AZURE_URI_SANITIZER_PATTERN, + Value = SanitizedJsonConfig.MASK_STRING + }, + new HeaderRegexSanitizer("Azure-AsyncOperation") + { + Regex = AZURE_URI_SANITIZER_PATTERN, + Value = SanitizedJsonConfig.MASK_STRING + }, + new BodyKeySanitizer("$..endpoint") + { + Regex = SanitizedJsonConfig.HOST_SUBDOMAIN_PATTERN, + Value = SanitizedJsonConfig.MASK_STRING + } + ]); + + // Prevent keys from leaking into our recordings + RecordingOptions.SanitizeJsonBody("*..key", "*..api_key"); + + // Because the current implementation of multi-part form content data in OpenAI and Azure OpenAI uses random + // to generate boundaries, this causes problems during playback as the boundary will be different each time. + // Longer term, we should find a way to pass the TestRecording.Random to the multi-part form generator in the + // code. The simplest solution for now is to disable recording the body for these mime types + RecordingOptions.RequestOverride = request => + { + if (request?.Headers.GetFirstOrDefault("Content-Type")?.StartsWith("multipart/form-data") == true) + { + return RequestRecordMode.RecordWithoutRequestBody; + } + + return RequestRecordMode.Record; + }; + RecordingOptions.Sanitizers.Add(new HeaderRegexSanitizer("Content-Type") + { + Regex = @"multipart/form-data; boundary=[^\s]+", + Value = "multipart/form-data; boundary=***" + }); + + // Data URIs trimmed to prevent the recording from being too large + RecordingOptions.Sanitizers.Add(new BodyKeySanitizer("$..url") + { + Regex = @"(?<=data:image/png;base64,)(.+)", + Value = SMALL_1x1_PNG + }); + + // Base64 encoded images in the response are replaced with a 1x1 black pixel PNG image to ensure valid data + RecordingOptions.Sanitizers.Add(new BodyKeySanitizer($"..b64_json") + { + Value = SMALL_1x1_PNG + }); + } + + /// + /// Gets the top level test client to use for testing. + /// + /// The test configuration to use + /// (Optional) The client options to use. + /// (Optional) The token credential to use. If this is null, an API key will be read from the + /// test configuration. + /// (Optional) The key credential to use instead of the one from the configuration. + public virtual AzureOpenAIClient GetTestTopLevelClient( + IConfiguration? config, + TestClientOptions? options = null, + TokenCredential? tokenCredential = null, + ApiKeyCredential? keyCredential = null) + { + // First validate that the config has the parameters we need + if (config == null) + { + throw CreateKeyNotFoundEx("any configuration"); + } + else if (config.Endpoint is null) + { + throw CreateKeyNotFoundEx("endpoint"); + } + else if (tokenCredential == null && keyCredential == null && string.IsNullOrEmpty(config.Key)) + { + throw CreateKeyNotFoundEx("API key"); + } + + // Configure the test options as needed + options ??= new(); + Action? requestAction = options.ShouldOutputRequests ? DumpRequest : null; + Action? responseAction = options.ShouldOutputResponses ? DumpResponse : null; + options.AddPolicy(new TestPipelinePolicy(requestAction, responseAction), PipelinePosition.PerCall); + + options = ConfigureClientOptions(options); + + AzureOpenAIClient topLevelClient; + if (tokenCredential != null) + { + topLevelClient = new AzureOpenAIClient(config.Endpoint, tokenCredential, options); + } + else + { + topLevelClient = new AzureOpenAIClient(config.Endpoint, keyCredential ?? new ApiKeyCredential(config.Key!), options); + } + + return topLevelClient; + } + + /// + /// Gets the properly instrumented client to use for testing. This have proper support for automatic sync/async method testing, + /// as well as recording, and playback support. + /// + /// (Optional) The client options to use. + /// (Optional) The token credential to use. If this is null, an API key will be read from the + /// test configuration. + /// (Optional) The key credential to use instead of the one from the configuration. + /// The test client instance. + public virtual TClient GetTestClient(TestClientOptions? options = null, TokenCredential? tokenCredential = null, ApiKeyCredential? keyCredential = null) + => GetTestClient(TestConfig.GetConfig(), options, tokenCredential, keyCredential); + + /// + /// Gets the properly instrumented client to use for testing. This have proper support for automatic sync/async method testing, + /// as well as recording, and playback support. + /// + /// + /// (Optional) The client options to use. + /// (Optional) The token credential to use. If this is null, an API key will be read from the + /// test configuration. + /// (Optional) The key credential to use instead of the one from the configuration. + /// The test client instance. + public virtual TClient GetTestClient(string configName, TestClientOptions? options = null, TokenCredential? tokenCredential = null, ApiKeyCredential? keyCredential = null) + => GetTestClient(TestConfig.GetConfig(configName), options, tokenCredential, keyCredential); + + /// + /// Gets a different type of client using the same configuration as the specified client. + /// + /// The type of other client to create. + /// The client instance whose configuration we want to use. + /// (Optional) The specific deployment to use instead of the one from the config. + /// + /// The client instance passed was not instrumented + public virtual TExplicitClient GetTestClientFrom(TClient client, string? deploymentName = null) + { + var instrumented = (TopLevelInfo?)GetClientContext(client); + if (instrumented?.TopLevelClient != null + && instrumented?.Config != null) + { + return GetTestClient(instrumented.TopLevelClient, instrumented.Config, deploymentName); + } + + throw new NotSupportedException("The client provided was not properly instrumented. Please make sure to get your test client " + + "instances using the GetTestClient() methods"); + } + + #region overrides + + /// + protected override RecordedTestMode GetDefaultRecordedTestMode() + => AzureTestEnvironment.DefaultRecordMode; + + /// + protected override ProxyServiceOptions CreateProxyServiceOptions() + => new() + { + DotnetExecutable = TestEnvironment.DotNetExe.FullName, + TestProxyDll = TestEnvironment.TestProxyDll.FullName, + DevCertFile = TestEnvironment.TestProxyHttpsCert.FullName, + DevCertPassword = TestEnvironment.TestProxyHttpsCertPassword, + StorageLocationDir = TestEnvironment.RepoRoot.FullName, + }; + + /// + protected override RecordingStartInformation CreateRecordingSessionStartInfo() + { + // This uses the same directory structure as the previous Azure.Core.TestFramework used for an easy drop in replacement. + // For example, suppose your test class is (and your class name matches the file name): + // c:\src\azure-sdk-for-net\sdk\openai\Azure.AI.OpenAI\tests\ChatTests.cs + // Then this would return something like: + // sdk\openai\Azure.AI.OpenAI\tests\SessionRecords\ChatTests\TestName.json + DirectoryInfo? sourceDir = GetType().Assembly.GetAssemblySourceDir(); + string relativeDir = PathHelpers.GetRelativePath( + TestEnvironment.RepoRoot.FullName, + sourceDir?.FullName ?? TestEnvironment.RepoRoot.FullName); + + string recordingFile = Path.Combine( + relativeDir, + "SessionRecords", + GetType().Name, + GetRecordedTestFileName()); + + // Start at the source directory for the current test project, and then walk up the directory structure searching for + // an "assets.json" file. + string? assetsFile = null; + for ( + DirectoryInfo? current = sourceDir; + current != null && current.FullName != TestEnvironment.RepoRoot.FullName; + current = current?.Parent) + { + string file = Path.Combine(current!.FullName, "assets.json"); + if (File.Exists(file)) + { + assetsFile = file; + break; + } + } + + return new() + { + RecordingFile = recordingFile, + AssetsFile = assetsFile + }; + } + + #endregion + + /// + /// Polls until a condition has been met with a maximum wait time. The function will always return the last value even + /// if the condition was not met. + /// + /// The value in the . + /// The initial value. + /// The asynchronous function to get the latest state of the value. + /// When we should stop waiting. + /// (Optional) The amount of time to wait between retries. This will be ignored in playback + /// mode. Default is 2 seconds. + /// (Optional) The maximum amount of time to wait until the condition becomes true. This will be ignored in + /// playback mode. The default is 2 minutes. + /// The final state. This will return when the conditions have been met or we timed out. + protected virtual Task WaitUntilReturnLast(T initialValue, Func>> getAsync, Predicate stopCondition, TimeSpan? waitTimeBetweenRequests = null, TimeSpan? maxWait = null) + => WaitUntilReturnLast(initialValue, new Func>(async () => await getAsync().ConfigureAwait(false)), stopCondition, waitTimeBetweenRequests, maxWait); + + /// + /// Polls until a condition has been met with a maximum wait time. The function will always return the last value even + /// if the condition was not met. + /// + /// The return value. + /// The initial value. + /// The asynchronous function to get the latest state of the value. + /// When we should stop waiting. + /// (Optional) The amount of time to wait between retries. This will be ignored in playback + /// mode. Default is 2 seconds. + /// (Optional) The maximum amount of time to wait until the condition becomes true. This will be ignored in + /// playback mode. The default is 2 minutes. + /// The final state. This will return when the conditions have been met or we timed out. + protected virtual async Task WaitUntilReturnLast(T initialValue, Func> getAsync, Predicate stopCondition, TimeSpan? waitTimeBetweenRequests = null, TimeSpan? maxWait = null) + { + TimeSpan delay, max; + if (Mode == RecordedTestMode.Playback) + { + delay = TimeSpan.FromMilliseconds(10); + max = TimeSpan.FromSeconds(30); + } + else + { + delay = waitTimeBetweenRequests ?? TimeSpan.FromSeconds(2); + max = maxWait ?? TimeSpan.FromMinutes(2); + } + + DateTimeOffset stopTime = DateTimeOffset.Now + max; + T result = initialValue; + + while (!stopCondition(result) && DateTimeOffset.Now < stopTime) + { + await Task.Delay(delay).ConfigureAwait(false); + result = await getAsync().ConfigureAwait(false); + } + + return result; + } + + /// + /// Gets the properly instrumented client to use for testing. This have proper support for automatic sync/async method testing, + /// as well as recording, and playback support. + /// + /// The test configuration to use + /// (Optional) The client options to use. + /// (Optional) The token credential to use. If this is null, an API key will be read from the + /// test configuration. + /// (Optional) The key credential to use instead of the one from the configuration. + /// The test client instance. + protected virtual TClient GetTestClient(IConfiguration? config, TestClientOptions? options = null, TokenCredential? tokenCredential = null, ApiKeyCredential? keyCredential = null) + { + AzureOpenAIClient topLevelClient = GetTestTopLevelClient(config, options, tokenCredential, keyCredential); + return GetTestClient(topLevelClient, config!); + } + + /// + /// Gets the properly instrumented client to use for testing. This have proper support for automatic sync/async method testing, + /// as well as recording, and playback support. + /// + /// The type of test client to get. + /// The top level client to use. + /// The configuration to use to get the deployment information (if needed). + /// The instrumented client instance to use. + /// Support for the type of client being requested has not been implemented yet. + protected virtual TExplicitClient GetTestClient(AzureOpenAIClient topLevelClient, IConfiguration config, string? deploymentName = null) + { + Func getDeployment = () => deploymentName ?? config?.Deployment ?? throw CreateKeyNotFoundEx("deployment"); + object clientObject; + + switch (typeof(TExplicitClient).Name) + { + case nameof(AssistantClient): + clientObject = topLevelClient.GetAssistantClient(); + break; + case nameof(AudioClient): + clientObject = topLevelClient.GetAudioClient(getDeployment()); + break; + case nameof(BatchClient): + clientObject = topLevelClient.GetBatchClient(getDeployment()); + break; + case nameof(ChatClient): + clientObject = topLevelClient.GetChatClient(getDeployment()); + break; + case nameof(EmbeddingClient): + clientObject = topLevelClient.GetEmbeddingClient(getDeployment()); + break; + case nameof(FileClient): + clientObject = topLevelClient.GetFileClient(); + break; + case nameof(FineTuningClient): + clientObject = topLevelClient.GetFineTuningClient(); + break; + case nameof(ImageClient): + clientObject = topLevelClient.GetImageClient(getDeployment()); + break; + case nameof(VectorStoreClient): + clientObject = topLevelClient.GetVectorStoreClient(); + break; + case nameof(AzureDeploymentClient): + var accessor = NonPublic.FromField("_transport"); + clientObject = new AzureDeploymentClient( + config, + TestEnvironment.Credential, + transport: accessor.Get(topLevelClient.Pipeline)); + break; + default: + throw new NotImplementedException($"Test client helpers not yet implemented for {typeof(TExplicitClient)}"); + }; + + object instrumented = WrapClient( + typeof(TExplicitClient), + clientObject, + new TopLevelInfo + { + TopLevelClient = topLevelClient, + Config = config, + }, + null); + + return (TExplicitClient)instrumented; + } + + private Exception CreateKeyNotFoundEx(string whatIsMissing) + { + return new KeyNotFoundException($"Could not find any {whatIsMissing} to use. Please make sure you have the necessary" + + $" {TestConfig.AssetsJson} config file, or have the needed environment variables set"); + } + + private static void DumpRequest(PipelineRequest request) + { + Console.WriteLine($"--- New request ---"); + Console.WriteLine($"{request.Method} {request?.Uri}"); + string headers = string.Join("\n ", + request!.Headers + .Select(kvp => $"{kvp.Key}: {(kvp.Key.ToLowerInvariant().Contains("auth") ? "***" : kvp.Value)}")); + Console.Write(" "); + Console.WriteLine(headers); + + if (request?.Content is not null) + { + using MemoryStream stream = new(); + request.Content.WriteTo(stream, default); + stream.Position = 0; + + string? contentType = request.Headers.GetFirstOrDefault("Content-Type"); + if (IsProbableTextContent(contentType)) + { + DumpText(contentType, stream); + } + else + { + DumpHex(stream); + } + } + } + + private static void DumpResponse(PipelineResponse response) + { + Console.WriteLine($"--- Response ---"); + Console.WriteLine($"{response.Status} - {response.ReasonPhrase}"); + string headers = string.Join( + "\n ", + response.Headers + .Where(kvp => !kvp.Key.ToLowerInvariant().Contains("client-")) + .Select(kvp => $"{kvp.Key}: {kvp.Value}")); + Console.Write(" "); + Console.WriteLine(headers); + + response.BufferContent(); + + if (response!.Content is not null) + { + using Stream stream = response.Content.ToStream(); + string? contentType = response.Headers.GetFirstOrDefault("Content-Type"); + if (IsProbableTextContent(contentType)) + { + DumpText(contentType, stream); + } + else + { + DumpHex(stream); + } + } + + Console.WriteLine(); + } + + private static bool IsProbableTextContent(string? contentType) + { + contentType = contentType?.ToLowerInvariant() ?? string.Empty; + return contentType.StartsWith("application/json") + || contentType.StartsWith("text/"); + } + + private static void DumpText(string? contentType, Stream stream) + { + if (contentType?.ToLowerInvariant().StartsWith("application/json") == true) + { + var json = JsonDocument.Parse(stream); + + stream = new MemoryStream(); + using (Utf8JsonWriter writer = new(stream, new() { Indented = true })) + { + json.WriteTo(writer); + } + + stream.Seek(0, SeekOrigin.Begin); + } + + using StreamReader reader = new(stream); + Console.WriteLine(reader.ReadToEnd()); + } + + private static void DumpHex(Stream stream, int maxLines = 256) + { + byte[] buffer = new byte[32]; + StringBuilder hex = new(3 * buffer.Length); + StringBuilder chars = new(buffer.Length); + + int read = 0; + for (int lines = 0; (read = stream.FillBuffer(buffer)) > 0 && lines < maxLines; lines++) + { + for (int i = 0; i < read; i++) + { + hex.AppendFormat("{0:X2} ", buffer[i]); + + char c = Convert.ToChar(buffer[i]); + chars.Append(char.IsControl(c) ? ' ' : c); + } + + Console.Write(hex.PadRight(buffer.Length * 3)); + Console.Write("| "); + Console.WriteLine(chars); + + hex.Clear(); + chars.Clear(); + } + + if (read != 0) + { + Console.WriteLine(" ... truncated"); + } + } + + protected void ValidateById(string id) + { + Assert.That(id, Is.Not.Null.Or.Empty); + switch (typeof(T).Name) + { + case nameof(Assistant): + _assistantIdsToDelete.Add(id); + break; + case nameof(AssistantThread): + _threadIdsToDelete.Add(id); + break; + case nameof(OpenAIFileInfo): + _fileIdsToDelete.Add(id); + break; + case nameof(ThreadRun): + break; + case nameof(VectorStore): + _vectorStoreIdsToDelete.Add(id); + break; + default: + throw new NotImplementedException(); + } + } + + protected void ValidateById(string id, string parentId) + { + Assert.That(id, Is.Not.Null.Or.Empty); + Assert.That(parentId, Is.Not.Null.Or.Empty); + switch (typeof(T).Name) + { + case nameof(ThreadMessage): + _threadIdsWithMessageIdsToDelete.Add((parentId, id)); + break; + case nameof(VectorStoreFileAssociation): + _vectorStoreFileAssociationsToRemove.Add((parentId, id)); + break; + default: + throw new NotImplementedException(); + } + } + + /// + /// Performs basic, invariant validation of a target that was just instantiated from its corresponding origination + /// mechanism. If applicable, the instance is recorded into the test run for cleanup of persistent resources. + /// + /// Instance type being validated. + /// The instance to validate. + /// The provided instance type isn't supported. + protected void Validate(T target) + { + if (target is ThreadMessage message) + { + ValidateById(message.Id, message.ThreadId); + } + else if (target is VectorStoreFileAssociation fileAssociation) + { + ValidateById(fileAssociation.VectorStoreId, fileAssociation.FileId); + } + else + { + ValidateById(target switch + { + Assistant assistant => assistant.Id, + AssistantThread thread => thread.Id, + OpenAIFileInfo file => file.Id, + ThreadRun run => run.Id, + VectorStore store => store.Id, + _ => throw new NotImplementedException(), + }); + } + } + + [TearDown] + protected void Cleanup() + { + AzureOpenAIClient topLevelCleanupClient = GetTestTopLevelClient(TestConfig.GetConfig(), new() + { + ShouldOutputRequests = false, + ShouldOutputResponses = false, + }); + AssistantClient client = topLevelCleanupClient.GetAssistantClient(); + VectorStoreClient vectorStoreClient = topLevelCleanupClient.GetVectorStoreClient(); + FileClient fileClient = topLevelCleanupClient.GetFileClient(); + RequestOptions requestOptions = new() { ErrorOptions = ClientErrorBehaviors.NoThrow, }; + foreach ((string threadId, string messageId) in _threadIdsWithMessageIdsToDelete) + { + Console.WriteLine($"Cleanup: {messageId} -> {client.DeleteMessage(threadId, messageId, requestOptions)?.GetRawResponse().Status}"); + } + foreach (string assistantId in _assistantIdsToDelete) + { + Console.WriteLine($"Cleanup: {assistantId} -> {client.DeleteAssistant(assistantId, requestOptions)?.GetRawResponse().Status}"); + } + foreach (string threadId in _threadIdsToDelete) + { + Console.WriteLine($"Cleanup: {threadId} -> {client.DeleteThread(threadId, requestOptions)?.GetRawResponse().Status}"); + } + foreach ((string vectorStoreId, string fileId) in _vectorStoreFileAssociationsToRemove) + { + Console.WriteLine($"Cleanup: {vectorStoreId}<->{fileId} => {vectorStoreClient.RemoveFileFromStore(vectorStoreId, fileId, requestOptions)?.GetRawResponse().Status}"); + } + foreach (string vectorStoreId in _vectorStoreIdsToDelete) + { + Console.WriteLine($"Cleanup: {vectorStoreId} => {vectorStoreClient.DeleteVectorStore(vectorStoreId, requestOptions)?.GetRawResponse().Status}"); + } + foreach (string fileId in _fileIdsToDelete) + { + Console.WriteLine($"Cleanup: {fileId} -> {fileClient.DeleteFile(fileId, requestOptions)?.GetRawResponse().Status}"); + } + _threadIdsWithMessageIdsToDelete.Clear(); + _assistantIdsToDelete.Clear(); + _threadIdsToDelete.Clear(); + _vectorStoreFileAssociationsToRemove.Clear(); + _vectorStoreIdsToDelete.Clear(); + _fileIdsToDelete.Clear(); + + // If we are in recording mode, update the recorded playback configuration as well + if (Mode == RecordedTestMode.Record + && TestContext.CurrentContext.Result.Outcome == ResultState.Success) + { + TestConfig.SavePlaybackConfig(); + } + } + + protected static void ValidateClientResult(ClientResult result) + { + Assert.That(result, Is.Not.Null); + Assert.That(result.GetRawResponse(), Is.Not.Null); + } + + protected static PipelineResponse ValidateClientResultResponse(ClientResult result) + { + ValidateClientResult(result); + + PipelineResponse response = result.GetRawResponse(); + Assert.That(response.Status, Is.GreaterThanOrEqualTo(200).And.LessThan(300)); + Assert.That(response.Headers, Is.Not.Null); + Assert.That(response.Headers.GetFirstOrDefault("Content-Type"), Does.StartWith("application/json")); + Assert.That(response.Content, Is.Not.Null); + + return response; + } + + protected virtual TModel ValidateAndParse(ClientResult result) where TModel : IJsonModel + { + var response = ValidateClientResultResponse(result); + + TModel? model = ModelReaderWriter.Read(response.Content, ModelReaderWriterOptions.Json); + Assert.That(model, Is.Not.Null); + return model!; + } + + protected virtual TModel ValidateAndParse(ClientResult result, JsonSerializerOptions? options = null) + { + var response = ValidateClientResultResponse(result); + + using Stream stream = response.Content.ToStream(); + Assert.That(stream, Is.Not.Null); + + TModel? model = JsonHelpers.Deserialize(stream, options ?? JsonOptions.OpenAIJsonOptions); + Assert.That(model, Is.Not.Null); + return model!; + } + + internal class TopLevelInfo + { + //required public object Client { get; init; } + required public AzureOpenAIClient TopLevelClient { get; init; } + required public IConfiguration Config { get; init; } + } + + private readonly List _assistantIdsToDelete = []; + private readonly List _threadIdsToDelete = []; + private readonly List<(string, string)> _threadIdsWithMessageIdsToDelete = []; + private readonly List _fileIdsToDelete = []; + private readonly List<(string, string)> _vectorStoreFileAssociationsToRemove = []; + private readonly List _vectorStoreIdsToDelete = []; +} + +public class TestClientOptions : AzureOpenAIClientOptions +{ + public TestClientOptions() : base() + { } + + public TestClientOptions(ServiceVersion version) : base(version) + { } + + public bool ShouldOutputRequests { get; set; } = true; + public bool ShouldOutputResponses { get; set; } = true; +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Assets.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Assets.cs new file mode 100644 index 000000000..15590bc9c --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Assets.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; + +namespace Azure.AI.OpenAI.Tests +{ + public class Assets + { + public Assets() + { + HelloWorld = new() + { + Type = AssetType.Audio, + Language = "en", + Description = "Hello world", + Name = "hello_world.m4a", + RelativePath = GetPath("hello_world.m4a"), + MimeType = "audio/m4a" + }; + WhisperFrenchDescription = new() + { + Type = AssetType.Audio, + Language = "fr", + Description = "Whisper description in French", + Name = "french.wav", + RelativePath = GetPath("french.wav"), + MimeType = "audio/wave" + }; + DogAndCat = new() + { + Type = AssetType.Image, + Language = null, + Description = "A picture of a cat next to a dog", + Name = "variation_sample_image.jpg", + RelativePath = GetPath("variation_sample_image.png"), + MimeType = "image/png", + Url = new Uri("https://cdn.openai.com/API/images/guides/image_variation_original.webp") + }; + FineTuning = new() + { + Type = AssetType.Text, + Language = "en", + Description = "Fine tuning data for Open AI to generate a JSON object based on sports headlines", + Name = "fine_tuning.jsonl", + RelativePath = GetPath("fine_tuning.jsonl"), + MimeType = "text/plain" + }; + } + + public virtual AssetInfo HelloWorld { get; } + public virtual AssetInfo WhisperFrenchDescription { get; } + public virtual AssetInfo DogAndCat { get; } + public virtual AssetInfo FineTuning { get; } + + protected virtual string GetPath(string assetName) + { + return Path.Combine("Assets", assetName); + } + } + + public enum AssetType + { + Text, + Audio, + Image, + Raw + } + + public class AssetInfo + { + required public AssetType Type { get; init; } + required public string Name { get; init; } + required public string RelativePath { get; init; } + required public string MimeType { get; init; } + public string? Language { get; init; } + public string? Description { get; init; } + public Uri? Url { get; init; } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/AzureTestEnvironment.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/AzureTestEnvironment.cs new file mode 100644 index 000000000..be87fb307 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/AzureTestEnvironment.cs @@ -0,0 +1,236 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using Azure.Core; +using Azure.Identity; +using OpenAI.TestFramework; +using OpenAI.TestFramework.Mocks; +using OpenAI.TestFramework.Recording; +using OpenAI.TestFramework.Utils; + +namespace Azure.AI.OpenAI.Tests.Utils; + +/// +/// Represents an Azure test environment. +/// +public class AzureTestEnvironment +{ + private readonly RecordedTestMode _mode; + private readonly string _optionPrefix; + private TokenCredential? _credential; + + /// + /// Initializes a new instance. + /// + /// The recorded test mode to use. + public AzureTestEnvironment(RecordedTestMode mode) + { + _mode = mode; + + /** + * We want to be able to to find "root" folders: + * - The root of the Git repo on disk + * - The root folder of the source code (eng/sdk) + * These two are usually the same. In external repos, they may however be a little different. + * + * To search for these folders, we use a simple method where we search up from these starting folders: + * - Check the "SourcePath" assembly metadata attribute value. All projects in the Azure C# repo automatically have this attribute + * added as part of the build "magic" (see {repo_root}\Directory.Build.Targets) + * - Where the executing assembly is running from + * Until we find a parent folder that contains a specific subfolder(s). + */ + DirectoryInfo?[] startingPoints = + [ + AssemblyHelper.GetAssemblySourceDir(), + new FileInfo(Assembly.GetExecutingAssembly().Location).Directory, + ]; + + RepoRoot = FindFirstParentWithSubfolders(startingPoints, ".git") + ?? throw new InvalidOperationException("Could not determine the GIT root folder for this repository"); + + string sourceRoot = (FindFirstParentWithSubfolders(startingPoints, "eng", "sdk") ?? RepoRoot) + .FullName; + + DotNetExe = AssemblyHelper.GetDotnetExecutable() + ?? throw new InvalidOperationException( + "Could not determine the dotnet executable to use. Do you have .Net installed or have your paths correctly configured?"); + + TestProxyDll = new FileInfo( + AssemblyHelper.GetAssemblyMetadata("TestProxyPath") + ?? throw new InvalidOperationException("Could not determine the path to the recording test proxy DLL")); + + TestProxyHttpsCert = new FileInfo(Path.Combine( + sourceRoot, + "eng", + "common", + "testproxy", + "dotnet-devcert.pfx")); + if (!TestProxyHttpsCert.Exists) + { + throw new InvalidOperationException("Could not find test proxy HTTPS root certificate to use."); + } + + TestProxyHttpsCertPassword = "password"; + + string? serviceName = null; + DirectoryInfo? sourceDir = GetType().Assembly.GetAssemblySourceDir(); + if (sourceDir != null) + { + string relativePath = PathHelpers.GetRelativePath( + Path.Combine(sourceRoot, "sdk"), + sourceDir.FullName); + serviceName = relativePath + .Split(new char[] { Path.DirectorySeparatorChar }, StringSplitOptions.RemoveEmptyEntries) + .FirstOrDefault()!; + } + + _optionPrefix = serviceName?.ToUpperInvariant() + "_"; + } + + /// + /// Gets the root Git folder. + /// + public DirectoryInfo RepoRoot { get; } + + /// + /// Gets the path to the dotnet executable. This will be used in combination with to start the + /// recording test proxy service. + /// + public FileInfo DotNetExe { get; } + + /// + /// The path to test proxy DLL that will be used when starting the recording test proxy service. + /// + public FileInfo TestProxyDll { get; } + + /// + /// Gets the HTTPS certificate file to use as the signing certificate for HTTPS connections to the test proxy. + /// + public FileInfo TestProxyHttpsCert { get; } + + /// + /// Gets the password for . + /// + public string TestProxyHttpsCertPassword { get; } + + /// + /// Gets the token credential to use during testing. This will change depending on the record mode. + /// + public TokenCredential Credential => _credential ??= GetCredential(); + + /// + /// Gets the default record mode to use for the test. This will attempt to read from the test context, or environment variables. + /// + public static RecordedTestMode DefaultRecordMode + { + get + { + string? modeString = TestContext.Parameters["TestMode"] + ?? Environment.GetEnvironmentVariable("AZURE_TEST_MODE"); + + if (Enum.TryParse(modeString, true, out RecordedTestMode mode)) + { + return mode; + } + + return RecordedTestMode.Playback; + } + } + + /// + /// Gets an optional value from environment variables. + /// + /// The name of the value to retrieve. + /// The value, or null if it did not exist. + public string? GetOptionalVariable(string name) + { + return new[] + { + _optionPrefix + name, + name, + "AZURE_" + name + } + .Select(Environment.GetEnvironmentVariable) + .FirstOrDefault(value => !string.IsNullOrWhiteSpace(value)); + } + + /// + /// Gets a value from environment variables, or throws an exception if it does not exist. + /// + /// The name of the value to retrieve. + /// The value. + /// If the value did not exist. + public string GetVariable(string name) + { + string? optionalVariable = GetOptionalVariable(name); + return optionalVariable + ?? throw new InvalidOperationException($"Could not find required environment variable '{_optionPrefix + name }' or '{name}'."); + } + + private static DirectoryInfo? FindFirstParentWithSubfolders(IEnumerable startingDirs, params string[] subFolders) + => startingDirs + .Select(d => FindParentWithSubfolders(d, subFolders)) + .FirstOrDefault(d => d != null); + + private static DirectoryInfo? FindParentWithSubfolders(DirectoryInfo? start, params string[] subFolders) + { + if (subFolders == null || subFolders.Length == 0) + { + return null; + } + + for (DirectoryInfo? current = start; current != null; current = current.Parent) + { + if (!current.Exists) + { + return null; + } + else if (subFolders.All(sub => current.EnumerateDirectories(sub).Any())) + { + return current; + } + } + + return null; + } + + private TokenCredential GetCredential() + { + if (_mode == RecordedTestMode.Playback) + { + return new MockTokenCredential(); + } + + // I'm not sure exactly what the possible combinations to use here are, so I've essentially copied the logic + // TestEnvironment.cs in Azure.Core.TestFramework (though it is a little simplified here) + string? clientSecret = GetOptionalVariable("CLIENT_SECRET"); + string? systemAccessToken = GetOptionalVariable("SYSTEM_ACCESSTOKEN"); + + if (!string.IsNullOrWhiteSpace(clientSecret)) + { + return new ClientSecretCredential( + GetVariable("TENANT_ID"), + GetVariable("CLIENT_ID"), + clientSecret); + } + else if (!string.IsNullOrWhiteSpace(systemAccessToken)) + { + return new AzurePipelinesCredential( + GetVariable("AZURESUBSCRIPTION_TENANT_ID"), + GetVariable("AZURESUBSCRIPTION_CLIENT_ID"), + GetVariable("AZURESUBSCRIPTION_SERVICE_CONNECTION_ID"), + systemAccessToken, + new AzurePipelinesCredentialOptions { AuthorityHost = new Uri(GetVariable("AZURE_AUTHORITY_HOST")) }); + } + else + { + return new DefaultAzureCredential( + new DefaultAzureCredentialOptions() { ExcludeManagedIdentityCredential = true }); + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/BasicConfig.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/BasicConfig.cs new file mode 100644 index 000000000..d7f7e596d --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/BasicConfig.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; + +namespace Azure.AI.OpenAI.Tests.Utils.Config +{ + /// + /// A basic configuration that allows you to directly set values. + /// + public class BasicConfig : IConfiguration + { + private Dictionary _values = new Dictionary(StringComparer.OrdinalIgnoreCase); + + /// + public Uri? Endpoint { get; set; } + /// + public string? Key { get; set; } + /// + public string? Deployment { get; set; } + + /// + /// Adds an additional value to the configuration. + /// + /// The type of the value to add. + /// The key. + /// The value to add. + /// The instance for chaining. + public BasicConfig AddValue(string key, TVal? value) + { + if (value != null) + { + _values[key] = value; + } + else + { + _values.Remove(key); + } + + return this; + } + + /// + public TVal? GetValue(string key) + { + if (_values.TryGetValue(key, out object? val) + && val is TVal cast) + { + return cast; + } + + return default; + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/EnvironmentValuesConfig.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/EnvironmentValuesConfig.cs new file mode 100644 index 000000000..3cd72d98c --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/EnvironmentValuesConfig.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.ComponentModel; + +namespace Azure.AI.OpenAI.Tests.Utils.Config +{ + /// + /// Configuration that reads from environment variables. + /// + public class EnvironmentValuesConfig : INamedConfiguration + { + private const char ENV_KEY_SEPARATOR = '_'; + private const string SUFFIX_AOAI_API_KEY = "API_KEY"; + private const string SUFFIX_AOAI_ENDPOINT = "ENDPOINT"; + private const string SUFFIX_AOAI_DEPLOYMENT = "DEPLOYMENT"; + + private readonly string _prefix; + + /// + /// Creates a new instance. + /// + /// The environment value prefix to use. For example AZURE_OPENAI. + /// The prefix specified was null. + public EnvironmentValuesConfig(string prefix) + { + _prefix = prefix + ?.TrimEnd(ENV_KEY_SEPARATOR) + .ToUpperInvariant() + ?? throw new ArgumentNullException(nameof(prefix)); + + Endpoint = GetValue(SUFFIX_AOAI_ENDPOINT); + Key = GetValue(SUFFIX_AOAI_API_KEY); + Deployment = GetValue(SUFFIX_AOAI_DEPLOYMENT); + } + + /// + /// Creates a new instance. + /// + /// The environment value prefix to use. For example AZURE_OPENAI. + /// The specific type of client we want to get environment variable for + /// The prefix specified was null. + public EnvironmentValuesConfig(string prefix, string clientName) + : this($"{prefix}{ENV_KEY_SEPARATOR}{clientName}") + { + Name = clientName; + } + + /// + public string? Name { get; } + + /// + public Uri? Endpoint { get; } + + /// + public string? Key { get; } + + /// + public string? Deployment { get; } + + /// + public TVal? GetValue(string key) + { + string envKey = $"{_prefix}{ENV_KEY_SEPARATOR}{key.ToUpperInvariant()}"; + + string? value = Environment.GetEnvironmentVariable(envKey); + if (value == null) + { + return default; + } + else if (value is TVal val) + { + return val; + } + else + { + var defaultConverter = TypeDescriptor.GetConverter(typeof(TVal)); + return (TVal?)defaultConverter.ConvertFromInvariantString(value); + } + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/FlattenedConfig.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/FlattenedConfig.cs new file mode 100644 index 000000000..205a5d072 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/FlattenedConfig.cs @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using OpenAI.TestFramework.Utils; + +namespace Azure.AI.OpenAI.Tests.Utils.Config; + +/// +/// Represents a flattened configuration that reads from one or more configurations in order. It will also +/// record the values read from each configuration. +/// +public class FlattenedConfig : IConfiguration +{ + private IReadOnlyList _configs; + private IDictionary? _recordedConfig; + + /// + /// Creates a new instance. + /// + /// The configurations to read from in order. + /// Where to store the recorded configuration. + /// The configs passed was null. + public FlattenedConfig(INamedConfiguration?[] configs, IDictionary recordedConfig) + { + _configs = configs ?? throw new ArgumentNullException(nameof(configs)); + _recordedConfig = recordedConfig ?? throw new ArgumentNullException(nameof(recordedConfig)); + + Endpoint = GetAndRecordProperty(c => c.Endpoint, (c, v) => c.Endpoint = v); + Key = GetAndRecordProperty(c => c.Key, (c, v) => c.Key = v); + Deployment = GetAndRecordProperty(c => c.Deployment, (c, v) => c.Deployment = v); + } + + /// + public Uri? Endpoint { get; } + /// + public string? Key { get; } + /// + public string? Deployment { get; } + + /// + public TVal? GetValue(string key) + { + TVal? value = default; + INamedConfiguration? selected = _configs + .Where(config => config != null) + .FirstOrDefault(config => (value = config!.GetValue(key)) != null); + + if (_recordedConfig != null && selected != null && value != null) + { + string configName = selected.Name ?? JsonConfig.DEFAULT_CONFIG_NAME; + SanitizedJsonConfig recorded = _recordedConfig.GetOrAdd(configName, _ => new SanitizedJsonConfig()); + recorded.SetValue(key, value); + } + + return value; + } + + private TVal? GetAndRecordProperty(Func getter, Action setter) + { + TVal? value = default; + INamedConfiguration? selected = _configs + .Where(config => config != null) + .FirstOrDefault(config => (value = getter(config!)) != null); + + if (_recordedConfig != null && selected != null && value != null) + { + string configName = selected.Name ?? JsonConfig.DEFAULT_CONFIG_NAME; + SanitizedJsonConfig recorded = _recordedConfig.GetOrAdd(configName, _ => new SanitizedJsonConfig()); + setter(recorded, value); + } + + return value; + } + + +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/IConfiguration.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/IConfiguration.cs new file mode 100644 index 000000000..97511d548 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/IConfiguration.cs @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using OpenAI.TestFramework.AutoSyncAsync; + +namespace Azure.AI.OpenAI.Tests.Utils.Config; + +/// +/// A test configuration for an Azure resource. +/// +public interface IConfiguration +{ + /// + /// The endpoint to use for sending requests to the Azure resource. + /// + Uri? Endpoint { get; } + + /// + /// The API key to use for authenticating requests to the Azure resource. + /// + string? Key { get; } + + /// + /// The deployment to use for this Azure resource. + /// + string? Deployment { get; } + + /// + /// Gets additional values from the test configuration for the Azure resource. + /// + /// The type of the value. + /// The name of the value (usually snake cased). For example: fine_tuned_model. + /// The parsed value for that key, or null of the key was not found, or failed to be parsed. + TVal? GetValue(string key); +} + +/// +/// A named test configuration for an Azure resource. +/// +public interface INamedConfiguration : IConfiguration +{ + /// + /// The name of the configuration. + /// + string? Name { get; } +} + +/// +/// Extensions methods for . +/// +public static class ConfigurationExtensions +{ + /// + /// Gets additional values from the test configuration for the Azure resource, but throws exceptions if the key is not found. + /// + /// The type of the value. + /// The configuration to get a value from. + /// The name of the value (usually snake cased). For example: fine_tuned_model. + /// The successfully parsed value for that key. + /// If the configuration passed was null + /// If the key could not be found + public static TVal GetValueOrThrow(this IConfiguration? config, string key) + { + if (config == null) + { + throw new ArgumentNullException(nameof(config)); + } + + return config.GetValue(key) + ?? throw new KeyNotFoundException($"Could not find a value for '{key}' in the test configuration"); + } + + /// + /// Gets the configuration that was used when creating the client instance. + /// + /// The type of the client. + /// The client instance. + /// The configuration. + /// The client did not have a config associated with it. + public static IConfiguration GetConfigOrThrow(this TExplicitClient client) where TExplicitClient : class + { + var instrumented = GetTopLevelClientInfo(client); + return instrumented.Config ?? throw new ArgumentException("The client was instrumented with a null configuration"); + } + + /// + /// Gets the deployment to use from the configuration, or throws if none was found. + /// + /// The config. + /// (Optional) The client name to include in th exception message. + /// The deployment. + /// The deployment was not set or found. + public static string DeploymentOrThrow(this IConfiguration? config, string? clientName = null) + { + string str = clientName == null ? string.Empty : clientName + " "; + return config?.Deployment + ?? throw new KeyNotFoundException($"Could not find a {str}deployment in the test configuration"); + } + + /// + /// Gets the deployment from the specified client. + /// + /// The type of the client. + /// The client instance. + /// The deployment name used for that client instance. + /// The client either was not properly instrumented. + /// The client did not have a deployment configured. + public static string DeploymentOrThrow(this TExplicitClient client) where TExplicitClient : class + { + var instrumented = GetTopLevelClientInfo(client); + return instrumented.Config.DeploymentOrThrow(client!.GetType().Name); + } + + private static AoaiTestBase.TopLevelInfo GetTopLevelClientInfo(TExplicitClient? client) + where TExplicitClient : class + { + if (client == null) + { + throw new ArgumentNullException(nameof(client)); + } + + return ((AoaiTestBase.TopLevelInfo?)(client as IAutoSyncAsync)?.Context) + ?? throw new ArgumentException( + $"The client was not properly wrapped for automatic sync/async ({client.GetType().Name})", + nameof(client)); + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/JsonConfig.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/JsonConfig.cs new file mode 100644 index 000000000..e7334b11f --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/JsonConfig.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Text.Json.Serialization; +using OpenAI.TestFramework.Utils; + +namespace Azure.AI.OpenAI.Tests.Utils.Config; + +/// +/// A configuration that is deserialized from JSON. +/// +public class JsonConfig : IConfiguration +{ + /// + /// The default configuration key to use. + /// + public const string DEFAULT_CONFIG_NAME = "default"; + + /// + /// The JSON configuration to use when serializing and deserializing. + /// + public static readonly JsonSerializerOptions JSON_OPTIONS = new() + { + PropertyNameCaseInsensitive = true, + PropertyNamingPolicy = JsonOptions.SnakeCaseLower, + DictionaryKeyPolicy = JsonOptions.SnakeCaseLower, + WriteIndented = true, + AllowTrailingCommas = true, +#if NETFRAMEWORK + IgnoreNullValues = true, +#else + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, +#endif + }; + + /// + public Uri? Endpoint { get; init; } + /// + public string? Key { get; init; } + /// + public string? Deployment { get; init; } + + /// + /// Json values that are not part of the class go here. + /// + [JsonExtensionData] + public Dictionary? ExtensionData { get; set; } + + /// + public TVal? GetValue(string key) + { + if (ExtensionData?.TryGetValue(key, out JsonElement value) == true) + { + return value.Deserialize(JSON_OPTIONS); + } + + return default; + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/NamedConfig.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/NamedConfig.cs new file mode 100644 index 000000000..5690f48a9 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/NamedConfig.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Text.Json; + +namespace Azure.AI.OpenAI.Tests.Utils.Config; + +/// +/// A wrapper around a test configuration to associate an optional name. +/// +public class NamedConfig : INamedConfiguration +{ + private readonly IConfiguration? _config; + + /// + /// Creates a new instance. + /// + /// The configuration instance. + /// The name of the config. + public NamedConfig(IConfiguration? config, string? name) + { + _config = config; + Name = name; + } + + /// + public string? Name { get; } + /// + public Uri? Endpoint => _config?.Endpoint; + /// + public string? Key => _config?.Key; + /// + public string? Deployment => _config?.Deployment; + /// + public TVal? GetValue(string key) => _config == null ? default : _config.GetValue(key); +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/SanitizedJsonConfig.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/SanitizedJsonConfig.cs new file mode 100644 index 000000000..54e111aa7 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Config/SanitizedJsonConfig.cs @@ -0,0 +1,182 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.RegularExpressions; +using OpenAI.TestFramework.Utils; + +namespace Azure.AI.OpenAI.Tests.Utils.Config +{ + /// + /// A sanitized JSON configuration. This will automatically sanitize the Endpoint, Key, subscription ID and resource group in the configuration + /// file. Please make sure to add any additional sanitization rules to the dictionary. + /// + public class SanitizedJsonConfig : IConfiguration + { + /// + /// The string to use when masking sensitive data. + /// + public const string MASK_STRING = "Sanitized"; + + /// + /// The pattern to match the subdomain of a URL. + /// + public const string HOST_SUBDOMAIN_PATTERN = @"(?<=.+://)([^\.]+)(?=[\./])"; + + private static readonly Regex HOST_SUBDOMAIN_MATCHER = new Regex(HOST_SUBDOMAIN_PATTERN, RegexOptions.Compiled); + private static readonly IReadOnlyDictionary> SANITIZERS = new Dictionary> + { + ["subscription_id"] = v => MASK_STRING, + ["resource_group"] = v => MASK_STRING, + ["endpoint"] = v => v is not null && (v is string || v is Uri) + ? MaskUriSubdomain(v.ToString())! + : MASK_STRING, + ["key"] = v => MASK_STRING, + ["api_key"] = v => MASK_STRING, + }; + + private Uri? _endpoint; + private string? _key; + private string? _deployment; + + /// + /// Creates a new instance. + /// + public SanitizedJsonConfig() + { + ExtensionData = new SortedDictionary(); + } + + /// + /// Creates a new instance from another . + /// + /// The configuration to create from. + /// If the configuration was null. + public SanitizedJsonConfig(JsonConfig config) : this() + { + if (config == null) + { + throw new ArgumentNullException(nameof(config)); + } + + Endpoint = config.Endpoint; + Key = config.Key; + Deployment = config.Deployment; + + if (config?.ExtensionData != null) + { + foreach (var kvp in config.ExtensionData) + { + switch (kvp.Value.ValueKind) + { + case JsonValueKind.Undefined: + case JsonValueKind.Null: + break; + case JsonValueKind.String: + SetValue(kvp.Key, kvp.Value.GetString()); + break; + default: + ExtensionData[kvp.Key] = kvp.Value.Clone(); + break; + } + } + } + } + + /// + public Uri? Endpoint + { + get => _endpoint; + set => _endpoint = MaskProperty(value); + } + + /// + public string? Key + { + get => _key; + set => _key = MaskProperty(value); + } + + /// + public string? Deployment + { + get => _deployment; + set => _deployment = MaskProperty(value); + } + + /// + /// Json values that are not part of the class go here. + /// + [JsonExtensionData] + public IDictionary ExtensionData { get; } + + /// + public virtual TVal? GetValue(string key) + { + if (ExtensionData?.TryGetValue(key, out JsonElement value) == true) + { + return value.Deserialize(JsonConfig.JSON_OPTIONS); + } + + return default; + } + + /// + /// Sets an additional value in the configuration. If the value is null it will be removed. + /// + /// Type of the value to set. + /// The name of the value (usually snake cased). For example: fine_tuned_model. + /// The value to set. + public virtual void SetValue(string key, TVal? value) + { + if (value == null) + { + if (ExtensionData != null) + { + ExtensionData.Remove(key); + } + } + else + { + value = MaskData(key, value); + JsonElement json = JsonHelpers.SerializeToElement(value, JsonConfig.JSON_OPTIONS); + ExtensionData[key] = json; + } + } + + private static TVal? MaskProperty(TVal? value, [CallerMemberName] string? key = null) + { + string convertedKey = JsonConfig.JSON_OPTIONS.PropertyNamingPolicy?.ConvertName(key ?? string.Empty) ?? string.Empty; + return MaskData(convertedKey, value); + } + + private static TVal? MaskData(string key, TVal? value) + { + if (value == null) + { + return default; + } + else if (SANITIZERS.TryGetValue(key ?? string.Empty, out var sanitizer)) + { + return (TVal?)sanitizer(value); + } + + return value; + } + + private static Uri? MaskUriSubdomain(string? uri) + { + if (uri == null) + { + return null; + } + + string maskedUrl = HOST_SUBDOMAIN_MATCHER.Replace(uri.ToString(), MASK_STRING); + return new Uri(maskedUrl); + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Extensions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Extensions.cs new file mode 100644 index 000000000..d25ef06b3 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/Extensions.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Text; + +namespace Azure.AI.OpenAI.Tests.Utils; + +/// +/// Helper extension methods. +/// +public static class Extensions +{ + /// + /// Attempts to fill the buffer as much as possible from a stream. This will try to keep reading + /// until the buffer is filled, or the stream ends. + /// + /// The stream to read from. + /// The buffer to try to fill. + /// The number of bytes read. + public static int FillBuffer(this Stream stream, byte[] buffer) + { + if (stream == null) + throw new ArgumentNullException(nameof(stream)); + else if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + + int totalRead = 0; + while (totalRead < buffer.Length) + { + int read = stream.Read(buffer, totalRead, buffer.Length - totalRead); + if (read == 0) + { + return totalRead; + } + + totalRead += read; + } + + return totalRead; + } + + /// + /// Pads the current instance with the specified character on the left. + /// + /// The string builder instance + /// The total width we want the string builder to be + /// The padding characters + /// The same builder for chaining, with any needed padding. + public static StringBuilder PadRight(this StringBuilder builder, int totalWidth, char paddingChar = ' ') + { + if (builder == null) + throw new ArgumentNullException(nameof(builder)); + else if (totalWidth < 0) + throw new ArgumentOutOfRangeException(nameof(totalWidth), "Total width must be greater than or equal to 0."); + else if (totalWidth == 0) + return builder; + + int padding = totalWidth - builder.Length; + if (padding > 0) + { + builder.Append(paddingChar, padding); + } + + return builder; + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/JsonOptions.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/JsonOptions.cs new file mode 100644 index 000000000..4489b0169 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/JsonOptions.cs @@ -0,0 +1,168 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.Globalization; +using System.Runtime.CompilerServices; +using System.Text.Json; + +#nullable enable + +namespace Azure.AI.OpenAI.Tests.Utils; + +/// +/// A helper class to make working with older versions of System.Text.Json simpler +/// +public static class JsonOptions +{ + // TODO FIXME once we update to newer versions of System.Text.JSon we should switch to using + // JsonNamingPolicy.SnakeCaseLower + public static JsonNamingPolicy SnakeCaseLower { get; } = + new SnakeCaseNamingPolicy(); + + public static JsonSerializerOptions OpenAIJsonOptions { get; } = new() + { + PropertyNameCaseInsensitive = true, + PropertyNamingPolicy = SnakeCaseLower, +#if NETFRAMEWORK + IgnoreNullValues = true, +#else + DefaultIgnoreCondition = System.Text.Json.Serialization.JsonIgnoreCondition.WhenWritingNull, +#endif + Converters = + { + new ModelReaderWriterConverter(), + new UnixDateTimeConverter() + } + }; + + public static JsonSerializerOptions AzureJsonOptions { get; } = new() + { + PropertyNameCaseInsensitive = true, + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, +#if NETFRAMEWORK + IgnoreNullValues = true, +#else + DefaultIgnoreCondition = System.Text.Json.Serialization.JsonIgnoreCondition.WhenWritingNull, +#endif + }; + + // Ported over from the source code for newer versions of System.Text.Json + private class SnakeCaseNamingPolicy : JsonNamingPolicy + { + private enum SeparatorState + { + NotStarted, + UppercaseLetter, + LowercaseLetterOrDigit, + SpaceSeparator + } + + public override string ConvertName(string name) + { + if (string.IsNullOrEmpty(name)) + { + return string.Empty; + } + + return ConvertName('_', name.AsSpan()); + } + + internal static string ConvertName(char separator, ReadOnlySpan chars) + { + char[]? rentedBuffer = null; + + int num = (int)(1.2 * chars.Length); + Span output = num > 128 + ? (rentedBuffer = ArrayPool.Shared.Rent(num))! + : stackalloc char[128]; + + SeparatorState separatorState = SeparatorState.NotStarted; + int charsWritten = 0; + + for (int i = 0; i < chars.Length; i++) + { + char c = chars[i]; + UnicodeCategory unicodeCategory = char.GetUnicodeCategory(c); + switch (unicodeCategory) + { + case UnicodeCategory.UppercaseLetter: + switch (separatorState) + { + case SeparatorState.LowercaseLetterOrDigit: + case SeparatorState.SpaceSeparator: + WriteChar(separator, ref output); + break; + case SeparatorState.UppercaseLetter: + if (i + 1 < chars.Length && char.IsLower(chars[i + 1])) + { + WriteChar(separator, ref output); + } + break; + } + + c = char.ToLowerInvariant(c); + WriteChar(c, ref output); + separatorState = SeparatorState.UppercaseLetter; + break; + + case UnicodeCategory.LowercaseLetter: + case UnicodeCategory.DecimalDigitNumber: + if (separatorState == SeparatorState.SpaceSeparator) + { + WriteChar(separator, ref output); + } + + WriteChar(c, ref output); + separatorState = SeparatorState.LowercaseLetterOrDigit; + break; + + case UnicodeCategory.SpaceSeparator: + if (separatorState != 0) + { + separatorState = SeparatorState.SpaceSeparator; + } + break; + + default: + WriteChar(c, ref output); + separatorState = SeparatorState.NotStarted; + break; + } + } + + string result = output.Slice(0, charsWritten).ToString(); + if (rentedBuffer != null) + { + output.Slice(0, charsWritten).Clear(); + ArrayPool.Shared.Return(rentedBuffer); + } + return result; + + void ExpandBuffer(ref Span destination) + { + int minimumLength = checked(destination.Length * 2); + char[] array = ArrayPool.Shared.Rent(minimumLength); + destination.CopyTo(array); + if (rentedBuffer != null) + { + destination.Slice(0, charsWritten).Clear(); + ArrayPool.Shared.Return(rentedBuffer); + } + rentedBuffer = array; + destination = rentedBuffer; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void WriteChar(char value, ref Span destination) + { + if (charsWritten == destination.Length) + { + ExpandBuffer(ref destination); + } + destination[charsWritten++] = value; + } + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/MockTokenCredential.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/MockTokenCredential.cs new file mode 100644 index 000000000..8615c0f07 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/MockTokenCredential.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; + +namespace OpenAI.TestFramework.Mocks; + +/// +/// A mock token credential to be used for testing. +/// +public class MockTokenCredential : TokenCredential +{ + /// + /// Event raised when a token is requested. + /// + public event EventHandler? TokenRequested; + + /// + public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + TokenRequested?.Invoke(this, requestContext); + return new AccessToken("TEST TOKEN " + string.Join(",", requestContext.Scopes), DateTimeOffset.MaxValue); + } + + /// + public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + return new(GetToken(requestContext, cancellationToken)); + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/ModelReaderWriterConverter.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/ModelReaderWriterConverter.cs new file mode 100644 index 000000000..255628633 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/ModelReaderWriterConverter.cs @@ -0,0 +1,59 @@ +#nullable enable + +using System; +using System.ClientModel.Primitives; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Azure.AI.OpenAI.Tests.Utils +{ + /// + /// Adapter to allow mixing reflection based JSON serialization and deserialization with the ModelReaderWriter based ones + /// + public class ModelReaderWriterConverter : JsonConverterFactory + { + /// + public override bool CanConvert(Type typeToConvert) + { + bool implementsInterface = typeof(IJsonModel).IsAssignableFrom(typeToConvert); + bool hasParameterlessConstructor = typeToConvert.GetConstructors(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public) + .Any(ci => ci.GetParameters()?.Count() == 0); + return implementsInterface && hasParameterlessConstructor; + } + + /// + public override JsonConverter CreateConverter(Type typeToConvert, JsonSerializerOptions options) + { + return (JsonConverter)Activator.CreateInstance(typeof(InnerModelReaderWriterConverter<>).MakeGenericType([typeToConvert]))!; + } + + private class InnerModelReaderWriterConverter : JsonConverter where T : IJsonModel + { + private IJsonModel _converter; + + /// + /// Creates a new instance + /// + /// The type does not have any paramterless constructor + public InnerModelReaderWriterConverter() + { + _converter = (IJsonModel)(Activator.CreateInstance(typeof(T), true) + ?? throw new ArgumentNullException()); + } + + /// + public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return _converter.Create(ref reader, ModelReaderWriterOptions.Json); + } + + /// + public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options) + { + _converter.Write(writer, ModelReaderWriterOptions.Json); + } + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/NonPublic.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/NonPublic.cs new file mode 100644 index 000000000..f890e4e75 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/NonPublic.cs @@ -0,0 +1,123 @@ +#nullable enable + +using System; +using System.Reflection; + +namespace Azure.AI.OpenAI.Tests.Utils; + +/// +/// Helpers to make accessing the many internal or private members of the Azure test framework more streamlined +/// +public static class NonPublic +{ + /// + /// Creates an accessor for an internal, protected, or private property. + /// + /// The type of the class that defines this property. + /// The type of the property. + /// The name of the property. + /// The property accessor. + /// If a property with that name and type could not be found. + public static Accessor FromProperty(string propertyName) where TObj : class + { + PropertyInfo? prop = typeof(TObj).GetProperty( + propertyName, BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance); + + if (prop == null) + { + throw new ArgumentException($"'{propertyName}' property could not be found in '{typeof(TObj).FullName}'"); + } + else if (prop.PropertyType != typeof(TProp)) + { + throw new ArgumentException($"'{propertyName}' property is not of type '{typeof(TProp).FullName}'"); + } + + Func? getter = null; + Action? setter = null; + + MethodInfo? method = prop.GetGetMethod(true); + if (method != null) + { + getter = (Func)method.CreateDelegate(typeof(Func)); + } + + method = prop.GetSetMethod(true); + if (method != null) + { + setter = (Action)method.CreateDelegate(typeof(Action)); + } + + return new Accessor(getter, setter); + } + + /// + /// Creates an accessory for an internal, protected, or private field. + /// + /// The type of the class that defines this field. + /// The type of the field. + /// The name of the field. + /// The filed accessor. + /// If a field with that name and type could not be found. + public static Accessor FromField(string fieldName) where TObj : class + { + FieldInfo? field = typeof(TObj).GetField( + fieldName, BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance); + + if (field == null) + { + throw new ArgumentException($"'{fieldName}' field could not be found in '{typeof(TObj).FullName}'"); + } + else if (field.FieldType != typeof(TField)) + { + throw new ArgumentException($"'{fieldName}' field is not of type '{typeof(TField).FullName}'"); + } + + Func getter = (instance) => (TField)field.GetValue(instance)!; + Action? setter = (instance, val) => field.SetValue(instance, val); + + return new Accessor(getter, setter); + } + + /// + /// The accessor struct that makes accessing internal, protected, or private properties/fields easier. + /// + /// The type of the class that defines this field. + /// Tye type of the property/field. + public readonly struct Accessor where TObj : class + { + private readonly Func _getter; + private readonly Action _setter; + + public Accessor(Func? getter, Action? setter) + { + HasGet = getter != null; + _getter = getter ?? (_ => throw new InvalidOperationException("Get is not supported")); + HasSet = setter != null; + _setter = setter ?? ((_, __) => throw new InvalidOperationException("Get is not supported")); + } + + /// + /// True if we can read the value of the property/field. + /// + public bool HasGet { get; } + + /// + /// True if we can set the value of the property/field. + /// + public bool HasSet { get; } + + /// + /// Gets the value of the property/field. + /// + /// The instance to get the value from. Can be null for static properties/fields. + /// The value of the property/field. + public TValue Get(TObj? instance) => _getter(instance); + + /// + /// Sets the value of the property/field. + /// + /// The instance to set the value on. Can be null for static properties/fields. + /// The value to set. + public void Set(TObj? instance, TValue value) => _setter(instance, value); + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/RunOnScopeExit.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/RunOnScopeExit.cs new file mode 100644 index 000000000..5ec65283e --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/RunOnScopeExit.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Threading.Tasks; + +namespace Azure.AI.OpenAI.Tests.Utils +{ + public class RunOnScopeExit : IAsyncDisposable + { + private Func _asyncFunc; + + public RunOnScopeExit(Func asyncFunc) + { + _asyncFunc = asyncFunc ?? throw new ArgumentNullException(nameof(asyncFunc)); + } + + public async ValueTask DisposeAsync() + { + await _asyncFunc().ConfigureAwait(false); + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/TestConfig.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/TestConfig.cs new file mode 100644 index 000000000..5b999e118 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/TestConfig.cs @@ -0,0 +1,204 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Text.Json; +using Azure.AI.OpenAI.Tests.Utils.Config; +using OpenAI.TestFramework; +using OpenAI.TestFramework.Utils; + +namespace Azure.AI.OpenAI.Tests; + +internal class TestConfig +{ + private const string AZURE_OPENAI_ENV_KEY_PREFIX = "AZURE_OPENAI"; + + private readonly bool _isPlayback; + private readonly IReadOnlyDictionary _jsonConfig; + private SortedDictionary _recordedConfig; + + public virtual string AssetsSubFolder => "Assets"; + public virtual string AssetsJson => "test_config.json"; + public virtual string PlaybackAssetsJson => $"playback_{AssetsJson}"; + + public TestConfig(RecordedTestMode? mode) + { + _isPlayback = mode == RecordedTestMode.Playback; + _recordedConfig = new(new DefaultFirstStringComparer()); + + // Load the previous playback configuration and use that to initialize the recorded config + string playbackConfigJson = Path.Combine(AssetsSubFolder, PlaybackAssetsJson); + var playbackConfig = ReadJsonConfig(playbackConfigJson); + if (playbackConfig != null) + { + foreach (var kvp in playbackConfig) + { + _recordedConfig.Add(kvp.Key, new SanitizedJsonConfig(kvp.Value)); + } + } + + // When in playback mode, we always use the playback configuration. This ensures that we run in the same way in CI/CD + // as we do locally. + if (_isPlayback) + { + _jsonConfig = playbackConfig + ?? throw new InvalidOperationException($"The playback config file was not found: {playbackConfigJson}"); + } + else + { + _jsonConfig = new[] + { + AssetsJson, + Path.Combine(AssetsSubFolder, AssetsJson), + Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), ".azure", AssetsSubFolder, AssetsJson), + Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), ".azure", AssetsSubFolder, AssetsJson), + } + .Select(f => ReadJsonConfig(f)) + .FirstOrDefault(c => c != null) + ?? new Dictionary(); + } + } + + public virtual IConfiguration? GetConfig() + => GetConfig(ToKey()); + + public virtual IConfiguration? GetConfig(string name) + { + // In order to populate each property of the Config object, the search order is as follows: + // 1. Getting the specific config for the name in the JSON config file + // 2. Getting the value from the default config + // 3. (Not in playback) Getting the value from the AZURE_OPENAI__ environment variable + // 4. (Not in playback) Getting the value from the AZURE_OPENAI_ environment variable + // It will fall through each one if the value is null + + return new FlattenedConfig( + [ + new NamedConfig(_jsonConfig.GetValueOrDefault(name), name), + new NamedConfig(_jsonConfig.GetValueOrDefault(JsonConfig.DEFAULT_CONFIG_NAME), null), + _isPlayback ? null : new EnvironmentValuesConfig(AZURE_OPENAI_ENV_KEY_PREFIX, name), + _isPlayback ? null : new EnvironmentValuesConfig(AZURE_OPENAI_ENV_KEY_PREFIX) + ], _recordedConfig); + } + + public virtual void SavePlaybackConfig() + { + try + { + string? sourceDirectoryPath = typeof(TestConfig).Assembly + .GetCustomAttributes() + .FirstOrDefault(attrib => attrib.Key == "TestProjectSourceBasePath") + ?.Value; + + if (sourceDirectoryPath != null) + { + string playbackConfigJson = Path.Combine(sourceDirectoryPath, AssetsSubFolder, PlaybackAssetsJson); + + string oldJson = string.Empty; + if (File.Exists(playbackConfigJson)) + { + oldJson = File.ReadAllText(playbackConfigJson); + } + + string newJson = JsonSerializer.Serialize(_recordedConfig, JsonConfig.JSON_OPTIONS); + + // Visual Studio's hot reload feature can get upset if you are debugging the code and the playback config + // file changes, so we only save it if it is different + if (oldJson != newJson) + { + File.WriteAllText(playbackConfigJson, newJson, Encoding.UTF8); + } + } + } + catch (Exception ex) + { + Console.Error.WriteLine("Failed to save the playback configuration file. Details: " + ex); + } + } + + protected static string ToKey() + { + string fullName = typeof(TClient).Name; + int stopAt = fullName.LastIndexOf("Client"); + stopAt = stopAt == -1 ? fullName.Length : stopAt; + + StringBuilder builder = new(fullName.Length); + bool prevWasUpper = true; + + for (int i = 0; i < stopAt; i++) + { + char c = fullName[i]; + if (char.IsUpper(c)) + { + if (prevWasUpper) + { + builder.Append(char.ToLowerInvariant(c)); + } + else + { + builder.Append('_'); + builder.Append(char.ToLowerInvariant(c)); + } + + prevWasUpper = true; + } + else + { + builder.Append(c); + prevWasUpper = false; + } + } + + return builder.ToString(); + } + + protected static IReadOnlyDictionary? ReadJsonConfig(string fullPath) + { + try + { + if (File.Exists(fullPath)) + { + string json = File.ReadAllText(fullPath); + return JsonSerializer.Deserialize>(json, JsonConfig.JSON_OPTIONS); + } + } + catch (Exception) + { + } + + return null; + } + + private class DefaultFirstStringComparer : IComparer + { + public int Compare(string? x, string? y) + { + if (ReferenceEquals(x, y)) + { + return 0; + } + else if (x == null) + { + return -1; + } + else if (y == null) + { + return 1; + } + else if (x == JsonConfig.DEFAULT_CONFIG_NAME && y != JsonConfig.DEFAULT_CONFIG_NAME) + { + return -1; + } + else if (x != JsonConfig.DEFAULT_CONFIG_NAME && y == JsonConfig.DEFAULT_CONFIG_NAME) + { + return 1; + } + + return string.Compare(x, y, StringComparison.Ordinal); + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/UnixDateTimeConverter.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/UnixDateTimeConverter.cs new file mode 100644 index 000000000..a4c9b1856 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/Utils/UnixDateTimeConverter.cs @@ -0,0 +1,109 @@ +#nullable enable + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Azure.AI.OpenAI.Tests.Utils +{ + public class UnixDateTimeConverter : JsonConverterFactory + { + private static Lazy _dateTimeOffset = new(() => new DateTimeOffsetConverter(), false); + private static Lazy _nullableDateTimeOffset = new(() => new NullableDateTimeOffsetConverter(), false); + private static Lazy _dateTime = new(() => new DateTimeConverter(), false); + private static Lazy _nullableDateTime = new(() => new NullableDateTimeConverter(), false); + + public override bool CanConvert(Type typeToConvert) + => typeToConvert == typeof(DateTime) + || typeToConvert == typeof(DateTime?) + || typeToConvert == typeof(DateTimeOffset) + || typeToConvert == typeof(DateTimeOffset?); + + public override JsonConverter CreateConverter(Type typeToConvert, JsonSerializerOptions options) + { + switch (typeToConvert) + { + case Type t when t == typeof(DateTime): + return _dateTime.Value; + case Type t when t == typeof(DateTime?): + return _nullableDateTime.Value; + case Type t when t == typeof(DateTimeOffset): + return _dateTimeOffset.Value; + case Type t when t == typeof(DateTimeOffset?): + return _nullableDateTimeOffset.Value; + default: + throw new NotSupportedException(); + } + } + + private static DateTimeOffset? Read(ref Utf8JsonReader reader) + { + if (reader.TokenType == JsonTokenType.Null) + { + return default; + } + else if (reader.TokenType == JsonTokenType.Number) + { + long unixTimeInSeconds = reader.GetInt64(); + return DateTimeOffset.FromUnixTimeSeconds(unixTimeInSeconds).ToLocalTime(); + } + else if (reader.TokenType == JsonTokenType.String + && long.TryParse(reader.GetString(), out long unixTime)) + { + return DateTimeOffset.FromUnixTimeSeconds(unixTime).ToLocalTime(); + } + else + { + throw new JsonException("Expected a number token type but got " + reader.TokenType); + } + } + + private static void Write(Utf8JsonWriter writer, DateTimeOffset? value) + { + if (value == null) + { + writer.WriteNullValue(); + } + else + { + writer.WriteNumberValue(value.Value.ToUnixTimeSeconds()); + } + } + + private class DateTimeOffsetConverter : JsonConverter + { + public override DateTimeOffset Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + => UnixDateTimeConverter.Read(ref reader) ?? default; + + public override void Write(Utf8JsonWriter writer, DateTimeOffset value, JsonSerializerOptions options) + => UnixDateTimeConverter.Write(writer, value); + } + + private class NullableDateTimeOffsetConverter : JsonConverter + { + public override DateTimeOffset? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + => UnixDateTimeConverter.Read(ref reader); + + public override void Write(Utf8JsonWriter writer, DateTimeOffset? value, JsonSerializerOptions options) + => UnixDateTimeConverter.Write(writer, value); + } + + private class DateTimeConverter : JsonConverter + { + public override DateTime Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + => UnixDateTimeConverter.Read(ref reader)?.LocalDateTime ?? default; + + public override void Write(Utf8JsonWriter writer, DateTime value, JsonSerializerOptions options) + => UnixDateTimeConverter.Write(writer, value); + } + + private class NullableDateTimeConverter : JsonConverter + { + public override DateTime? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + => UnixDateTimeConverter.Read(ref reader)?.LocalDateTime ?? default; + + public override void Write(Utf8JsonWriter writer, DateTime? value, JsonSerializerOptions options) + => UnixDateTimeConverter.Write(writer, value); + } + } +} diff --git a/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs new file mode 100644 index 000000000..cb6dc8e17 --- /dev/null +++ b/.dotnet.azure/sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs @@ -0,0 +1,242 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable disable + +using System; +using System.ClientModel; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Azure.AI.OpenAI.Tests.Utils.Config; +using NUnit.Framework; +using OpenAI; +using OpenAI.Files; +using OpenAI.TestFramework; +using OpenAI.VectorStores; + +namespace Azure.AI.OpenAI.Tests; + +public class VectorStoreTests : AoaiTestBase +{ + public VectorStoreTests(bool isAsync) : base(isAsync) + { } + + [Test] + [Category("Smoke")] + public void CanCreateClient() + { + VectorStoreClient client = GetTestClient(); + Assert.That(client, Is.Not.Null); + } + + [RecordedTest] + public async Task CanCreateGetAndDeleteVectorStores() + { + VectorStoreClient client = GetTestClient(); + + VectorStore vectorStore = await client.CreateVectorStoreAsync(); + Validate(vectorStore); + bool deleted = await client.DeleteVectorStoreAsync(vectorStore); + Assert.That(deleted, Is.True); + + IReadOnlyList testFiles = await GetNewTestFilesAsync(client.GetConfigOrThrow(), 5); + + vectorStore = await client.CreateVectorStoreAsync(new VectorStoreCreationOptions() + { + FileIds = { testFiles[0].Id }, + Name = "test vector store", + ExpirationPolicy = new VectorStoreExpirationPolicy() + { + Anchor = VectorStoreExpirationAnchor.LastActiveAt, + Days = 3, + }, + Metadata = + { + ["test-key"] = "test-value", + }, + }); + Validate(vectorStore); + Assert.Multiple(() => + { + Assert.That(vectorStore.Name, Is.EqualTo("test vector store")); + Assert.That(vectorStore.ExpirationPolicy?.Anchor, Is.EqualTo(VectorStoreExpirationAnchor.LastActiveAt)); + Assert.That(vectorStore.ExpirationPolicy?.Days, Is.EqualTo(3)); + Assert.That(vectorStore.FileCounts.Total, Is.EqualTo(1)); + Assert.That(vectorStore.CreatedAt, Is.GreaterThan(s_2024)); + Assert.That(vectorStore.ExpiresAt, Is.GreaterThan(s_2024)); + Assert.That(vectorStore.Status, Is.EqualTo(VectorStoreStatus.InProgress)); + Assert.That(vectorStore.Metadata?.TryGetValue("test-key", out string metadataValue) == true && metadataValue == "test-value"); + }); + vectorStore = await client.GetVectorStoreAsync(vectorStore); + Assert.Multiple(() => + { + Assert.That(vectorStore.Name, Is.EqualTo("test vector store")); + Assert.That(vectorStore.ExpirationPolicy?.Anchor, Is.EqualTo(VectorStoreExpirationAnchor.LastActiveAt)); + Assert.That(vectorStore.ExpirationPolicy?.Days, Is.EqualTo(3)); + Assert.That(vectorStore.FileCounts.Total, Is.EqualTo(1)); + Assert.That(vectorStore.CreatedAt, Is.GreaterThan(s_2024)); + Assert.That(vectorStore.ExpiresAt, Is.GreaterThan(s_2024)); + Assert.That(vectorStore.Metadata?.TryGetValue("test-key", out string metadataValue) == true && metadataValue == "test-value"); + }); + + deleted = await client.DeleteVectorStoreAsync(vectorStore.Id); + Assert.That(deleted, Is.True); + + vectorStore = await client.CreateVectorStoreAsync(new VectorStoreCreationOptions() + { + FileIds = testFiles.Select(file => file.Id).ToList() + }); + Validate(vectorStore); + Assert.Multiple(() => + { + Assert.That(vectorStore.Name, Is.Null.Or.Empty); + Assert.That(vectorStore.FileCounts.Total, Is.EqualTo(5)); + }); + } + + [RecordedTest] + public async Task CanEnumerateVectorStores() + { + VectorStoreClient client = GetTestClient(); + for (int i = 0; i < 10; i++) + { + VectorStore vectorStore = await client.CreateVectorStoreAsync(new VectorStoreCreationOptions() + { + Name = $"Test Vector Store {i}", + }); + Validate(vectorStore); + Assert.That(vectorStore.Name, Is.EqualTo($"Test Vector Store {i}")); + } + + AsyncPageCollection response = client.GetVectorStoresAsync(new VectorStoreCollectionOptions() { Order = ListOrder.NewestFirst }); + Assert.That(response, Is.Not.Null); + + int lastIdSeen = int.MaxValue; + int count = 0; + await foreach (VectorStore vectorStore in response.GetAllValuesAsync()) + { + Assert.That(vectorStore.Id, Is.Not.Null); + if (vectorStore.Name?.StartsWith("Test Vector Store ") == true) + { + string idString = vectorStore.Name.Substring("Test Vector Store ".Length); + + Assert.That(int.TryParse(idString, out int seenId), Is.True); + Assert.That(seenId, Is.LessThan(lastIdSeen)); + lastIdSeen = seenId; + } + if (lastIdSeen == 0 || ++count >= 100) + { + break; + } + } + + Assert.That(lastIdSeen, Is.EqualTo(0)); + } + + [RecordedTest] + public async Task CanAssociateFiles() + { + VectorStoreClient client = GetTestClient(); + VectorStore vectorStore = await client.CreateVectorStoreAsync(); + Validate(vectorStore); + + IReadOnlyList files = await GetNewTestFilesAsync(client.GetConfigOrThrow(), 3); + + foreach (OpenAIFileInfo file in files) + { + VectorStoreFileAssociation association = await client.AddFileToVectorStoreAsync(vectorStore, file); + Validate(association); + Assert.Multiple(() => + { + Assert.That(association.FileId, Is.EqualTo(file.Id)); + Assert.That(association.VectorStoreId, Is.EqualTo(vectorStore.Id)); + Assert.That(association.LastError, Is.Null); + Assert.That(association.CreatedAt, Is.GreaterThan(s_2024)); + Assert.That(association.Status, Is.AnyOf(VectorStoreFileAssociationStatus.InProgress, VectorStoreFileAssociationStatus.Completed)); + }); + } + + bool removed = await client.RemoveFileFromStoreAsync(vectorStore, files[0]); + Assert.True(removed); + + // Errata: removals aren't immediately reflected when requesting the list + Thread.Sleep(1000); + + int count = 0; + AsyncPageCollection response = client.GetFileAssociationsAsync(vectorStore); + await foreach (VectorStoreFileAssociation association in response.GetAllValuesAsync()) + { + count++; + Assert.That(association.FileId, Is.Not.EqualTo(files[0].Id)); + Assert.That(association.VectorStoreId, Is.EqualTo(vectorStore.Id)); + } + + Assert.That(count, Is.EqualTo(2)); + } + + [RecordedTest] + public async Task CanUseBatchIngestion() + { + VectorStoreClient client = GetTestClient(); + VectorStore vectorStore = await client.CreateVectorStoreAsync(); + Validate(vectorStore); + + IReadOnlyList testFiles = await GetNewTestFilesAsync(client.GetConfigOrThrow(), 3); + + VectorStoreBatchFileJob batchJob = await client.CreateBatchFileJobAsync(vectorStore, testFiles); + Assert.Multiple(() => + { + Assert.That(batchJob.BatchId, Is.Not.Null); + Assert.That(batchJob.VectorStoreId, Is.EqualTo(vectorStore.Id)); + Assert.That(batchJob.Status, Is.EqualTo(VectorStoreBatchFileJobStatus.InProgress)); + }); + + batchJob = await WaitUntilReturnLast( + batchJob, + () => client.GetBatchFileJobAsync(batchJob), + b => b.Status != VectorStoreBatchFileJobStatus.InProgress); + Assert.That(batchJob.Status, Is.EqualTo(VectorStoreBatchFileJobStatus.Completed)); + + AsyncPageCollection response = client.GetFileAssociationsAsync(batchJob); + await foreach (VectorStoreFileAssociation association in response.GetAllValuesAsync()) + { + Assert.Multiple(() => + { + Assert.That(association.FileId, Is.Not.Null); + Assert.That(association.VectorStoreId, Is.EqualTo(vectorStore.Id)); + Assert.That(association.Status, Is.EqualTo(VectorStoreFileAssociationStatus.Completed)); + // Assert.That(association.Size, Is.GreaterThan(0)); + Assert.That(association.CreatedAt, Is.GreaterThan(s_2024)); + Assert.That(association.LastError, Is.Null); + }); + } + } + + private async Task> GetNewTestFilesAsync(IConfiguration config, int count) + { + AzureOpenAIClient azureClient = GetTestTopLevelClient(config, new() + { + ShouldOutputRequests = false, + ShouldOutputResponses = false, + }); + FileClient client = GetTestClient(azureClient, config); + + List files = []; + for (int i = 0; i < count; i++) + { + OpenAIFileInfo file = await client.UploadFileAsync( + BinaryData.FromString("This is a test file").ToStream(), + $"test_file_{i.ToString().PadLeft(3, '0')}.txt", + FileUploadPurpose.Assistants) + .ConfigureAwait(false); + Validate(file); + files.Add(file); + } + + return files; + } + + private static readonly DateTimeOffset s_2024 = new(2024, 1, 1, 0, 0, 0, TimeSpan.Zero); +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/Directory.Build.props b/.dotnet.azure/sdk/openai/tools/TestFramework/Directory.Build.props new file mode 100644 index 000000000..f85173f26 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/Directory.Build.props @@ -0,0 +1,18 @@ + + + + false + true + false + false + false + false + true + + + + + diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/TestFramework.sln b/.dotnet.azure/sdk/openai/tools/TestFramework/TestFramework.sln new file mode 100644 index 000000000..a88dc3caf --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/TestFramework.sln @@ -0,0 +1,31 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.10.35013.160 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenAI.TestFramework.Tests", "tests\OpenAI.TestFramework.Tests.csproj", "{61E849EB-F8BC-47C7-B730-874DD678BEA7}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenAI.TestFramework", "src\OpenAI.TestFramework.csproj", "{BE2FF759-255B-44A8-BAE7-73E287AEEB97}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {61E849EB-F8BC-47C7-B730-874DD678BEA7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {61E849EB-F8BC-47C7-B730-874DD678BEA7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {61E849EB-F8BC-47C7-B730-874DD678BEA7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {61E849EB-F8BC-47C7-B730-874DD678BEA7}.Release|Any CPU.Build.0 = Release|Any CPU + {BE2FF759-255B-44A8-BAE7-73E287AEEB97}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BE2FF759-255B-44A8-BAE7-73E287AEEB97}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BE2FF759-255B-44A8-BAE7-73E287AEEB97}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BE2FF759-255B-44A8-BAE7-73E287AEEB97}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {F145C399-D9D8-45F9-87DC-4BFFF983FA91} + EndGlobalSection +EndGlobal diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/assets.json b/.dotnet.azure/sdk/openai/tools/TestFramework/assets.json new file mode 100644 index 000000000..d33e24017 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/assets.json @@ -0,0 +1,6 @@ +{ + "AssetsRepo": "Azure/azure-sdk-assets", + "AssetsRepoPrefixPath": "net", + "TagPrefix": "net/openai/OpenAI.TestFramework", + "Tag": "net/openai/OpenAI.TestFramework_f41330e3ac" +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Adapters/SyncToAsyncCollectionResult.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Adapters/SyncToAsyncCollectionResult.cs new file mode 100644 index 000000000..64096eb4a --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Adapters/SyncToAsyncCollectionResult.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; + +namespace OpenAI.TestFramework.Adapters; + +/// +/// An adapter to make a look and work like a . This +/// simplifies writing test cases +/// +/// The type of the items the enumerator returns +public class SyncToAsyncCollectionResult : AsyncCollectionResult +{ + private bool _responseSet; + private CollectionResult? _syncCollection; + private Exception? _ex; + + /// + /// Creates a new instance + /// + /// The synchronous collection to wrap + /// If the collection was null + public SyncToAsyncCollectionResult(CollectionResult syncCollection) + { + _syncCollection = syncCollection ?? throw new ArgumentNullException(nameof(syncCollection)); + TrySetRawResponse(); + } + + /// + /// Creates a new instance. + /// + /// The exception to throw. + /// If the exception was null. + public SyncToAsyncCollectionResult(Exception ex) + { + _ex = ex ?? throw new ArgumentNullException(nameof(ex)); + _syncCollection = null; + } + + /// + public override IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return InnerEnumerable(cancellationToken).GetAsyncEnumerator(); + } + + private async IAsyncEnumerable InnerEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (_ex != null) + { + ExceptionDispatchInfo.Capture(_ex).Throw(); + } + + var asyncWrapper = new SyncToAsyncEnumerator(_syncCollection?.GetEnumerator()!, cancellationToken); + while (await asyncWrapper.MoveNextAsync().ConfigureAwait(false)) + { + TrySetRawResponse(); + yield return asyncWrapper.Current; + } + } + + private void TrySetRawResponse() + { + if (_responseSet) + { + return; + } + + // Client result doesn't provide virtual methods so we have to manually set it ourselves here + try + { + var raw = _syncCollection?.GetRawResponse(); + if (raw != null) + { + SetRawResponse(raw); + _responseSet = true; + } + } + catch (Exception) { /* dont' care */ } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Adapters/SyncToAsyncEnumerable.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Adapters/SyncToAsyncEnumerable.cs new file mode 100644 index 000000000..c71c1a0e1 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Adapters/SyncToAsyncEnumerable.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Adapters; + +/// +/// Wraps an as an +/// +/// The type of items being enumerated. +public class SyncToAsyncEnumerable : IAsyncEnumerable +{ + private IEnumerable _enumerable; + Exception? _ex; + + /// + /// Creates a new instance. + /// + /// The synchronous enumerable to wrap. + public SyncToAsyncEnumerable(IEnumerable enumerable) + { + _enumerable = enumerable; + } + + /// + /// Creates a new instance. + /// + /// The synchronous enumerable to wrap. + public SyncToAsyncEnumerable(Exception ex) + { + _ex = ex; + _enumerable = Array.Empty(); + } + + /// + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + if (_ex != null) + { + return new SyncToAsyncEnumerator(_ex); + } + else + { + return new SyncToAsyncEnumerator(_enumerable.GetEnumerator(), cancellationToken); + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Adapters/SyncToAsyncEnumerator.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Adapters/SyncToAsyncEnumerator.cs new file mode 100644 index 000000000..fa0ce81b0 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Adapters/SyncToAsyncEnumerator.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Runtime.ExceptionServices; + +namespace OpenAI.TestFramework.Adapters; + +/// +/// Wraps an as an +/// +/// The type of items being enumerated. +public class SyncToAsyncEnumerator : IAsyncEnumerator +{ + private IEnumerator _sync; + private CancellationToken _token; + private Exception? _ex; + + /// + /// Creates a new instance. + /// + /// The synchronous enumerator to wrap. + /// (Optional) The cancellation token to use. + /// If the enumerator was null. + public SyncToAsyncEnumerator(IEnumerator sync, CancellationToken token = default) + { + _sync = sync ?? throw new ArgumentNullException(nameof(sync)); + _token = token; + } + + /// + /// Creates a new instance. + /// + /// The exception to throw. + /// If the exception was null. + public SyncToAsyncEnumerator(Exception ex) + { + _sync = Enumerable.Empty().GetEnumerator(); + _token = default; + _ex = ex ?? throw new ArgumentNullException(nameof(ex)); + } + + /// + public T Current => _sync.Current; + + /// + public ValueTask DisposeAsync() + { + _sync.Dispose(); + return default; + } + + /// + public ValueTask MoveNextAsync() + { + if (_ex != null) + { + ExceptionDispatchInfo.Capture(_ex).Throw(); + } + + _token.ThrowIfCancellationRequested(); + bool ret = _sync.MoveNext(); + return new ValueTask(ret); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Adapters/SyncToAsyncPageCollection.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Adapters/SyncToAsyncPageCollection.cs new file mode 100644 index 000000000..89b963137 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Adapters/SyncToAsyncPageCollection.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.Runtime.ExceptionServices; + +namespace OpenAI.TestFramework.Adapters; + +/// +/// An adapter to make a look and work like a . This +/// simplifies writing test cases. +/// +/// The type of the items the enumerator returns. +public class SyncToAsyncPageCollection : AsyncPageCollection +{ + private PageCollection? _syncCollection; + private Exception? _ex; + + /// + /// Creates a new instance. + /// + /// The synchronous collection to wrap. + /// If the collection was null. + public SyncToAsyncPageCollection(PageCollection syncCollection) + { + _syncCollection = syncCollection ?? throw new ArgumentNullException(nameof(syncCollection)); + } + + /// + /// Creates a new instance. + /// + /// The exception to throw. + /// If the exception was null. + public SyncToAsyncPageCollection(Exception ex) + { + _ex = ex ?? throw new ArgumentNullException(nameof(ex)); + _syncCollection = null; + } + + /// + protected override Task> GetCurrentPageAsyncCore() + { + if (_ex != null) + { + return Task.FromException>(_ex); + } + else + { + return Task.FromResult(_syncCollection!.GetCurrentPage()); + } + } + + /// + protected override async IAsyncEnumerator> GetAsyncEnumeratorCore(CancellationToken cancellationToken = default) + { + if (_ex != null) + { + ExceptionDispatchInfo.Capture(_ex).Throw(); + } + + foreach (PageResult page in _syncCollection!) + { + await Task.Delay(0).ConfigureAwait(false); + yield return page; + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/AsyncOnlyAttribute.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AsyncOnlyAttribute.cs new file mode 100644 index 000000000..3fbfc191f --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AsyncOnlyAttribute.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using NUnit.Framework; + +namespace OpenAI.TestFramework; + +/// +/// Attribute that can be applied to a test to indicate it only runs in asynchronous mode. +/// +[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)] +public class AsyncOnlyAttribute() : NUnitAttribute +{ +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/AsyncToSyncInterceptor.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/AsyncToSyncInterceptor.cs new file mode 100644 index 000000000..08fe268ec --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/AsyncToSyncInterceptor.cs @@ -0,0 +1,429 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Reflection; +using Castle.DynamicProxy; +using OpenAI.TestFramework.Adapters; +using Ext = OpenAI.TestFramework.Utils.TypeExtensions; + +namespace OpenAI.TestFramework.AutoSyncAsync; + +/// +/// An interceptor for Castle dynamic proxies that allows you to call the synchronous version of a method when the asynchronous one +/// is called on the proxy. This is useful for testing where you can write the async version of a test, and then automatically test +/// both async and sync methods with the same test code. +/// +[DebuggerStepThrough] +public class AsyncToSyncInterceptor : IInterceptor +{ + private const string AsyncSuffix = "Async"; + + private static readonly TypeArrayEquality s_typeArrayEquality = new(); + private static readonly ConcurrentDictionary> s_syncAsyncPairs = new(); + private static readonly MethodInfo s_taskFromResult = typeof(Task).GetMethod(nameof(Task.FromResult), BindingFlags.Public | BindingFlags.Static)!; + private static readonly MethodInfo s_taskFromException = typeof(Task) + .GetMethods(BindingFlags.Static | BindingFlags.Public) + .Where(m => m.Name == nameof(Task.FromException) && m.IsGenericMethodDefinition) + .First(); + + private readonly BindingFlags _flags; + + /// + /// Creates a new instance. + /// + /// True if you want to use async methods, false otherwise. + /// The binding flags to use when searching for methods. Default is public instance methods. + public AsyncToSyncInterceptor(bool useAsync, BindingFlags flags = BindingFlags.Public | BindingFlags.Instance) + { + UseAsync = useAsync; + _flags = flags; + } + + /// + /// Gets the shared use sync methods instance. + /// + public static AsyncToSyncInterceptor UseSyncMethods { get; } = new(false); + + /// + /// Gets the shared use async methods instance. + /// + public static AsyncToSyncInterceptor UseAsyncMethods { get; } = new(true); + + /// + [DebuggerStepThrough] + public virtual void Intercept(IInvocation invocation) + { + // 1. Should we even intercept this? + if (ShouldSkipIntercepting(invocation.Method)) + { + invocation.Proceed(); + return; + } + + // 2. Check if this method is one of a pair of Operation and OperationAsync methods. + bool isSyncAsyncPair = IsMethodSyncAsyncPair(invocation.Method); + if (!isSyncAsyncPair) + { + throw CreateEx("Method does not have a synchronous and asynchronous pair", invocation.Method); + } + + // 3. If it is, check if the method is the synchronous version. We only allow async versions in the test code + bool isAsyncMethod = invocation.Method.Name.EndsWith(AsyncSuffix); + if (!isAsyncMethod) + { + throw CreateEx("You must use the asynchronous versions of the methods when writing your tests", invocation.Method); + } + + Type asyncReturnType = invocation.Method.ReturnType; + + // 4. Call the correct synchronous or asynchronous method and warp the returned result or exception + if (UseAsync) + { + // Async method running in async mode, no need to do anything, special, continue normally + invocation.Proceed(); + } + else + { + // Call the equivalent sync method + string methodName = RemoveAsyncSuffix(invocation.Method.Name); + Type expectedReturnType = ToSyncRetType(asyncReturnType); + Type[] expectedArgs = invocation.Method.GetParameters().Select(p => p.ParameterType).ToArray(); + + MethodInfo syncMethod = invocation.TargetType.GetMethod( + methodName, _flags, binder: null, expectedArgs, modifiers: null)!; + + // this should never happen since we've already checked for the existence of the expected method + Debug.Assert(syncMethod != null); + if (syncMethod == null) + { + throw CreateEx("Could not find the synchronous version of the method", invocation.Method); + } + + if (syncMethod.ContainsGenericParameters) + { + syncMethod = syncMethod.MakeGenericMethod(invocation.Method.GetGenericArguments()); + } + + // Call the synchronous method + try + { + object? result = syncMethod.Invoke(invocation.InvocationTarget, invocation.Arguments); + if (result != null && !expectedReturnType.IsAssignableFrom(result.GetType())) + { + throw CreateEx("The synchronous method returned an unexpected type", invocation.Method); + } + + invocation.ReturnValue = ToAsyncResult(asyncReturnType, result); + } + catch (TargetInvocationException ex) + { + invocation.ReturnValue = ToAsyncException(asyncReturnType, ex.InnerException ?? ex); + } + } + } + + /// + /// Whether or not we are using async methods. + /// + public bool UseAsync { get; } + + /// Determines whether or not we should skip intercepting this method or not. + /// + /// The method we are inspecting. + /// True to skip intercepting this method, false otherwise. + protected virtual bool ShouldSkipIntercepting(MethodInfo? method) + { + return method == null + // Skip for special names (i.e. getters and setters) + || method.IsSpecialName + // Also for dispose methods + || method.Name == nameof(IDisposable.Dispose) + || method.Name == nameof(IAsyncDisposable.DisposeAsync); + } + + /// + /// Determines whether or not the specified method is part of a pair of synchronous and asynchronous methods. This will + /// check based on 3 factors: + /// + /// If there is a "???" and "???Async" pair of named methods + /// If the arguments are exactly the same for both methods + /// If we know how to determine the expected return type for the synchronous method, from the asynchronous one + /// + /// + /// The method to check. + /// True if it is, false otherwise. + protected virtual bool IsMethodSyncAsyncPair(MethodInfo? method) + { + if (method == null || method.DeclaringType == null) + { + return false; + } + + ISet validPrefixes = s_syncAsyncPairs.GetOrAdd(method.DeclaringType, t => DetermineValidSyncAsyncPairs(t, _flags)); + return validPrefixes.Contains(RemoveAsyncSuffix(method.Name)); + } + + /// + /// Determines what the corresponding synchronous return type would be for the specified asynchronous return type. + /// + /// The asynchronous return type. + /// The corresponding synchronous return type. + /// If we don't know what the equivalent would be. + protected virtual Type ToSyncRetType(Type asyncReturnType) + { + if (typeof(Task) == asyncReturnType || typeof(ValueTask) == asyncReturnType) + { + return typeof(void); + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(Task<>), out Type[] genericTypes)) + { + return genericTypes[0]; + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(ValueTask<>), out genericTypes)) + { + return genericTypes[0]; + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(AsyncPageCollection<>), out genericTypes)) + { + return typeof(PageCollection<>).MakeGenericType(genericTypes); + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(AsyncCollectionResult<>), out genericTypes)) + { + return typeof(CollectionResult<>).MakeGenericType(genericTypes); + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(IAsyncEnumerable<>), out genericTypes)) + { + return typeof(IEnumerable<>).MakeGenericType(genericTypes); + } + else + { + throw new NotSupportedException("Don't know how to create the sync to async wrapper for " + asyncReturnType.FullName); + } + } + + /// + /// Wraps the result from a synchronous method into the equivalent asynchronous return type. + /// + /// The asynchronous return type. + /// The result to wrap. + /// The wrapped result. + /// If we don't support the conversion. + protected virtual object? ToAsyncResult(Type asyncReturnType, object? result) + { + if (typeof(Task) == asyncReturnType) + { + return Task.CompletedTask; + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(Task<>), out Type[] genericTypes)) + { + return s_taskFromResult + .MakeGenericMethod(genericTypes) + .Invoke(null, [result]); + } + else if (typeof(ValueTask) == asyncReturnType) + { + return new ValueTask(); + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(ValueTask<>), out genericTypes)) + { + return Activator.CreateInstance( + typeof(ValueTask<>).MakeGenericType(genericTypes), + result); + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(AsyncPageCollection<>), out genericTypes)) + { + return Activator.CreateInstance( + typeof(SyncToAsyncPageCollection<>).MakeGenericType(genericTypes), + result); + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(AsyncCollectionResult<>), out genericTypes)) + { + return Activator.CreateInstance( + typeof(SyncToAsyncCollectionResult<>).MakeGenericType(genericTypes), + result); + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(IAsyncEnumerable<>), out genericTypes)) + { + return Activator.CreateInstance( + typeof(SyncToAsyncEnumerable<>).MakeGenericType(genericTypes), + result); + } + else + { + throw new NotSupportedException("Don't know how to wrap the exception for " + asyncReturnType.FullName); + } + } + + /// + /// Wraps the exception from a synchronous method into the equivalent asynchronous return type. + /// + /// The asynchronous return type. + /// The exception to wrap. + /// The wrapped exception. + /// If we don't support the conversion. + protected virtual object? ToAsyncException(Type asyncReturnType, Exception ex) + { + if (typeof(Task) == asyncReturnType) + { + return Task.FromException(ex); + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(Task<>), out Type[] genericTypes)) + { + return s_taskFromException + .MakeGenericMethod(genericTypes) + .Invoke(null, [ex]); + } + else if (typeof(ValueTask) == asyncReturnType) + { + return new ValueTask(Task.FromException(ex)); + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(ValueTask<>), out genericTypes)) + { + var failedTask = s_taskFromException + .MakeGenericMethod(genericTypes) + .Invoke(null, [ex]); + return Activator.CreateInstance( + typeof(ValueTask<>).MakeGenericType(genericTypes), + failedTask); + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(AsyncPageCollection<>), out genericTypes)) + { + return Activator.CreateInstance( + typeof(SyncToAsyncPageCollection<>).MakeGenericType(genericTypes), + ex); + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(AsyncCollectionResult<>), out genericTypes)) + { + return Activator.CreateInstance( + typeof(SyncToAsyncCollectionResult<>).MakeGenericType(genericTypes), + ex); + } + else if (Ext.IsClosedGenericOf(asyncReturnType, typeof(IAsyncEnumerable<>), out genericTypes)) + { + return Activator.CreateInstance( + typeof(SyncToAsyncEnumerable<>).MakeGenericType(genericTypes), + ex); + } + else + { + throw new NotSupportedException("Don't know how to determine the synchronous equivalent return type of " + asyncReturnType.FullName); + } + } + + private static InvalidOperationException CreateEx(string description, MethodInfo method) + { + return new InvalidOperationException($"{description}. '{method.DeclaringType?.Name} -> {method.Name}'"); + } + + private static string RemoveAsyncSuffix(string? name) + { + if (name == null) + return string.Empty; + + int index = name.LastIndexOf(AsyncSuffix); + return index >= 0 + ? name.Substring(0, index) + : name; + } + + [DebuggerStepperBoundary] + private ISet DetermineValidSyncAsyncPairs(Type declaringType, BindingFlags flags) + { + // Group potential pairs based only on the method name removing the "Async" postfix + var potentialPairs = declaringType.GetMethods(flags) + .Where(m => !m.IsSpecialName) + .GroupBy(m => RemoveAsyncSuffix(m.Name)) + .OrderBy(g => g.Key) + .Select(g => new + { + g.Key, + Potentials = g.Select(m => new + { + m.Name, + Args = m.GetParameters().Select(p => p.ParameterType).ToArray(), + Return = m.ReturnType, + }) + // Order by name to ensure OperationName comes before OperationNameAsync + .OrderBy(p => p.Name) + // Match on method arguments + .GroupBy(g => g.Args, s_typeArrayEquality) + .Select(g => g.ToArray()) + }); + + // Now evaluate potential pairs to ensure that for each argument list for that method, there exists both a synchronous + // and asynchronous version with equivalent return types + HashSet validPairPrefixes = new(); + + foreach (var entry in potentialPairs) + { + bool allValid = entry.Potentials.All(matchedPair => + { + // because of the way we sorted above, we should have exactly 2 entries here, the first is the synchronous method + // the second the corresponding asynchronous method + return matchedPair.Length == 2 + && matchedPair[0].Name + AsyncSuffix == matchedPair[1].Name + && matchedPair[0].Return == ToSyncRetType(matchedPair[1].Return); + }); + + if (allValid) + { + validPairPrefixes.Add(entry.Key); + } + } + + return validPairPrefixes; + } + + /// + /// Helper comparer that compares all of the Types in an array for equality. + /// + private class TypeArrayEquality : IEqualityComparer + { + /// + public bool Equals(Type[]? x, Type[]? y) + { + if (ReferenceEquals(x, y)) + { + return true; + } + else if (x == null || y == null) + { + return false; + } + else if (x.LongLength != y.LongLength) + { + return false; + } + + for (long i = 0; i < x.LongLength; i++) + { + if (x[i] != y[i]) + { + return false; + } + } + + return true; + } + + /// + public int GetHashCode(Type[] obj) + { + if (obj == null) + { + return 0; + } + + int rollingHash = 1; // to distinguish empty case from null case + for (long i = 0; i < obj.LongLength; i++) + { + rollingHash = (rollingHash, obj[i].GetHashCode()).GetHashCode(); + } + + return rollingHash; + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/AutoSyncAsyncMixIn.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/AutoSyncAsyncMixIn.cs new file mode 100644 index 000000000..538f8fb75 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/AutoSyncAsyncMixIn.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.AutoSyncAsync; + +/// +/// An implementation of that allows you to get the original back, as well as a place +/// to store an additional context. +/// +public class AutoSyncAsyncMixIn : IAutoSyncAsync +{ + /// + /// Creates a new instance. + /// + /// The original instance. + public AutoSyncAsyncMixIn(object original, object? context = null) + { + Original = original; + Context = context; + } + + /// + public object Original { get; } + + /// + public object? Context { get; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/IAutoSyncAsync.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/IAutoSyncAsync.cs new file mode 100644 index 000000000..551df75c5 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/IAutoSyncAsync.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.AutoSyncAsync; + +/// +/// An interface that serves as a way to identify a dynamically proxied class that supports automatic sync and async testing. This +/// also provides a way to get the the original un-proxied instance. +/// instance. +/// +public interface IAutoSyncAsync +{ + /// + /// Gets the original un-proxied instance back. + /// + public object Original { get; } + + /// + /// Any additional context associated with the instrumented object (e.g. options used to create it). + /// + public object? Context { get; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/TestProxyGenerationHook.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/TestProxyGenerationHook.cs new file mode 100644 index 000000000..d11d5d86f --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/TestProxyGenerationHook.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Reflection; +using Castle.DynamicProxy; +using NUnit.Framework.Interfaces; +using NUnit.Framework.Internal; + +namespace OpenAI.TestFramework.AutoSyncAsync +{ + /// + /// Controls which methods are skipped during dynamic proxy generation. + /// + public class TestProxyGenerationHook : IProxyGenerationHook + { + /// + public void MethodsInspected() + { } + + /// + public void NonProxyableMemberNotification(Type type, MemberInfo memberInfo) + { } + + /// + public bool ShouldInterceptMethod(Type type, MethodInfo methodInfo) + { + IMethodInfo? testMethod = TestExecutionContext.CurrentContext.CurrentTest.Method; + + if (methodInfo == null + // Skip for special names (i.e. getters and setters) + || methodInfo.IsSpecialName + // Also for dispose methods + || methodInfo.Name == nameof(IDisposable.Dispose) + || methodInfo.Name == nameof(IAsyncDisposable.DisposeAsync) + // If we are running a sync only or async only, skip intercepting altogether + || testMethod?.IsDefined(false) == true + || testMethod?.IsDefined(false) == true) + { + return false; + } + + return true; + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/ThisLeakInterceptor.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/ThisLeakInterceptor.cs new file mode 100644 index 000000000..c7d6c1e70 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsync/ThisLeakInterceptor.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Diagnostics; +using Castle.DynamicProxy; + +namespace OpenAI.TestFramework.AutoSyncAsync; + +/// +/// A basic interceptor that prevents the leaking of the original un-proxied this instance as a return value. +/// +public class ThisLeakInterceptor : IInterceptor +{ + /// + [DebuggerStepThrough] + public void Intercept(IInvocation invocation) + { + invocation.Proceed(); + + if (invocation.ReturnValue == invocation.InvocationTarget) + { + invocation.ReturnValue = invocation.Proxy; + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsyncTestFixtureAttribute.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsyncTestFixtureAttribute.cs new file mode 100644 index 000000000..29983a94d --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/AutoSyncAsyncTestFixtureAttribute.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using NUnit.Framework; +using NUnit.Framework.Interfaces; +using NUnit.Framework.Internal; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework; + +/// +/// Attribute used to indicate that a test fixture should automatically be be run in both synchronous and asynchronous mode. +/// +[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = true)] +public class AutoSyncAsyncTestFixtureAttribute : NUnitAttribute, IFixtureBuilder2 +{ + /// + public IEnumerable BuildFrom(ITypeInfo typeInfo) + => BuildFrom(typeInfo, null!); + + /// + public IEnumerable BuildFrom(ITypeInfo typeInfo, IPreFilter filter) + { + List suites = + [ + .. new TestFixtureAttribute([false]).BuildFrom(typeInfo, new AndPreFilter(filter, new SyncAsyncPreFilter(false))), + .. new TestFixtureAttribute([true]).BuildFrom(typeInfo, new AndPreFilter(filter, new SyncAsyncPreFilter(true))), + ]; + + return suites; + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/ClientTestBase.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/ClientTestBase.cs new file mode 100644 index 000000000..7406e738e --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/ClientTestBase.cs @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Diagnostics; +using Castle.DynamicProxy; +using NUnit.Framework; +using NUnit.Framework.Internal; +using OpenAI.TestFramework.AutoSyncAsync; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework; + +/// +/// Base class for client test cases. This provides support for writing only a test that uses the Async version of +/// methods, and automatically creating a test that uses the equivalent Sync version of a method. Please note that +/// this will only work for public virtual methods. In order for this to work, you should write a test that uses the +/// async version of a method. +/// +[AutoSyncAsyncTestFixture] +public abstract class ClientTestBase +{ + private static ProxyGenerator? s_proxyGenerator = null; + private static ThisLeakInterceptor? s_thisLeakInterceptor = null; + private static AsyncToSyncInterceptor? s_asyncInterceptor = null; + private static AsyncToSyncInterceptor? s_syncInterceptor = null; + + private CancellationTokenSource? _cts = null; + + /// + /// Creates a new instance. + /// + /// True to run the async version of a test, false to run the sync version of a test. + public ClientTestBase(bool isAsync) + { + IsAsync = isAsync; + } + + /// + /// Gets whether or not we are running async tests. + /// + public virtual bool IsAsync { get; } + + /// + /// Gets the start time of the test. + /// + public virtual DateTimeOffset TestStartTime => TestExecutionContext.CurrentContext.StartTime.ToUniversalTime(); + + /// + /// Gets the test timeout. + /// + public virtual TimeSpan TestTimeout => Debugger.IsAttached + ? Default.DebuggerAttachedTestTimeout + : Default.TestTimeout; + + /// + /// Gets the cancellation token to use + /// + public virtual CancellationToken Token => _cts?.Token ?? default; + + [SetUp] + public void TestSetup() + { + _cts?.Dispose(); + _cts = new CancellationTokenSource(TestTimeout); + } + + [TearDown] + public void TestCleanup() + { + _cts?.Dispose(); + _cts = null; + } + + /// + /// Gets the instance to use to create proxies of classes + /// that allow you inject additional functionality in for testing. + /// + protected static ProxyGenerator ProxyGenerator => s_proxyGenerator ??= new ProxyGenerator(); + + /// + /// An interceptor that prevents leaking a reference to the original instance as a return value from methods. + /// + protected static ThisLeakInterceptor ThisLeakInterceptor => s_thisLeakInterceptor ??= new ThisLeakInterceptor(); + + /// + /// An interceptor to force the use of async version of a method. + /// + protected static AsyncToSyncInterceptor UseSyncMethodInterceptor => s_syncInterceptor ??= new AsyncToSyncInterceptor(false); + + /// + /// An interceptor to force the use of sync version of a method. + /// + protected static AsyncToSyncInterceptor UseAsyncMethodInterceptor => s_asyncInterceptor ??= new AsyncToSyncInterceptor(true); + + /// + /// Wraps a client for automatic sync/async testing. This will return a proxied version of the client that will allow you to + /// automatically use the sync versions of a method. + /// + /// The type of the client instance. + /// The client instance to instrument for testing. + /// (Optional) Any additional context to associate with the wrapped client. + /// (Optional) Any additional interceptors to use. + /// The proxied version of the client. + public T WrapClient(T client, object? context = null, params IInterceptor[] interceptors) where T : class + => (T)WrapClient(typeof(T), client, context, interceptors); + + /// + /// Gets the original client from a wrapped client. + /// + /// The type of the client. + /// The wrapped client instance. + /// The original client instance. + /// The the client passed was not wrapped. + public virtual T UnWrap(T wrapped) where T : class + { + if (wrapped is IAutoSyncAsync instrumented) + { + return (T)instrumented.Original; + } + + throw new NotSupportedException($"That instance was not wrapped using {nameof(WrapClient)}"); + } + + /// + /// Gets the context associated with the wrapped instance. + /// + /// The type of the client. + /// The wrapped client. + /// The associated context for the wrapped instance. Will be null if none was set. + /// The the instance passed was not wrapped. + public virtual object? GetClientContext(T client) where T : class + { + if (client is IAutoSyncAsync instrumented) + { + return instrumented.Context; + } + + throw new NotSupportedException($"That instance was not wrapped using {nameof(WrapClient)}"); + } + + /// + /// Wraps a client with sync/async equivalent methods for testing. This enables the automatic testing of the sync version + /// of methods if you write an async test case. + /// + /// The type of the client. + /// The client instance to wrap. + /// (Optional) Any additional context to associate with the wrapped client. + /// (Optional) Any additional interceptors to include. + /// The wrapped version of the client. + protected internal virtual object WrapClient(Type instanceType, object client, object? context, IEnumerable? interceptors) + { + List allInterceptors = new(); + + if (interceptors != null) + { + allInterceptors.AddRange(interceptors); + } + + allInterceptors.Add(ThisLeakInterceptor); + allInterceptors.Add(IsAsync ? UseAsyncMethodInterceptor : UseSyncMethodInterceptor); + + ProxyGenerationOptions options = new(new TestProxyGenerationHook()); + options.AddMixinInstance(new AutoSyncAsyncMixIn(client, context)); + + object proxy = ProxyGenerator.CreateClassProxyWithTarget( + instanceType, + [], + client, + options, + allInterceptors.ToArray()); + + return proxy; + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/CapturedMessage.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/CapturedMessage.cs new file mode 100644 index 000000000..854b8cd7d --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/CapturedMessage.cs @@ -0,0 +1,175 @@ +// Copyright(c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Mocks; + +/// +/// A captured message. This is used as part of the . +/// +public abstract class CapturedMessage +{ + private static BinaryData? s_emptyData = null; + private static IReadOnlyDictionary>? s_emptyHeaders = null; + + /// + /// An empty header dictionary. + /// + public static IReadOnlyDictionary> EMPTY_HEADERS + => s_emptyHeaders ??= new Dictionary>(); + + /// + /// Empty binary data. + /// + public static BinaryData EMPTY_DATA => s_emptyData ??= new BinaryData(Array.Empty()); + + /// + /// Gets or sets the headers of the captured message. + /// + public IReadOnlyDictionary> Headers { get; init; } = EMPTY_HEADERS; + + /// + /// Gets or sets the content of the captured message. + /// + public BinaryData Content { get; init; } = EMPTY_DATA; + + /// + /// Copies the content from the provided to a new instance. + /// + /// The to copy the content from. + /// A new instance containing the copied content. + public static BinaryData CopyContent(HttpContent? content) + { + if (content == null) + { + return EMPTY_DATA; + } + + using Stream stream = content.ReadAsStreamAsync().Result; + return BinaryData.FromStream(stream); + } + + /// + /// Copies the headers from the provided and to a new dictionary. + /// + /// The to copy headers from. + /// The to copy headers from. + /// A new dictionary containing the copied headers. + public static IReadOnlyDictionary> CopyHeaders(HttpHeaders header, HttpContentHeaders? contentHeaders) + { + Dictionary> dict = new(StringComparer.OrdinalIgnoreCase); + foreach (var kvp in header) + { + dict[kvp.Key] = new List(kvp.Value); + } + + if (contentHeaders != null) + { + foreach (var kvp in contentHeaders) + { + var list = (List?)dict.GetValueOrDefault(kvp.Key); + if (list == null) + { + list = new List(); + dict[kvp.Key] = list; + } + + list.AddRange(kvp.Value); + } + } + + return dict; + } +} + +/// +/// A captured request. +/// +public class CapturedRequest : CapturedMessage +{ + /// + /// Creates a new instance. + /// + public CapturedRequest() + { } + + /// + /// Creates a new instance of using the provided . + /// + /// The to create the captured request from. + public CapturedRequest(HttpRequestMessage request) + { + if (request == null) + { + throw new ArgumentNullException(nameof(request)); + } + + Method = request.Method; + Uri = request.RequestUri; + Headers = CopyHeaders(request.Headers, request.Content?.Headers); + Content = CopyContent(request.Content); + } + + /// + /// Gets or sets the HTTP method of the captured request. + /// + public HttpMethod Method { get; init; } = HttpMethod.Get; + + /// + /// Gets or sets the URI of the captured request. + /// + public Uri? Uri { get; init; } +} + +/// +/// A captured response. +/// +public class CapturedResponse : CapturedMessage +{ + /// + /// Gets or sets the status code of the captured response. + /// + public HttpStatusCode Status { get; init; } = HttpStatusCode.OK; + + /// + /// Gets or sets the reason phrase of the captured response. + /// + public string? ReasonPhrase { get; init; } = "OK"; + + /// + /// Converts the captured response to an . + /// + /// The . + public HttpResponseMessage ToResponse() + { + const string contentPrefix = "Content-"; + + HttpResponseMessage response = new() + { + StatusCode = Status, + ReasonPhrase = ReasonPhrase + }; + + foreach (var kvp in Headers.Where(h => h.Key?.StartsWith(contentPrefix) == false)) + { + response.Headers.TryAddWithoutValidation(kvp.Key, kvp.Value); + } + + if (Content != null && Content.ToMemory().Length > 0) + { + response.Content = new StreamContent(Content.ToStream()); + foreach (var kvp in Headers.Where(h => h.Key?.StartsWith(contentPrefix) == true)) + { + response.Content.Headers.TryAddWithoutValidation(kvp.Key, kvp.Value); + } + } + + return response; + } +} + + diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockAsyncCollectionResult.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockAsyncCollectionResult.cs new file mode 100644 index 000000000..86e871aa6 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockAsyncCollectionResult.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; + +namespace OpenAI.TestFramework.Mocks; + +/// +/// Represents a mock implementation of the class. +/// +/// The type of the values in the collection. +public class MockAsyncCollectionResult : AsyncCollectionResult +{ + private readonly Func> _enumerateAsyncFunc; + + /// + /// Initializes a new instance of the class + /// with the specified asynchronous enumeration function and optional pipeline response. + /// + /// The function that asynchronously enumerates the values in the collection. + /// The optional pipeline response. + public MockAsyncCollectionResult(Func> enumerateAsyncFunc, PipelineResponse? response = null) : + base(response ?? new MockPipelineResponse()) + { + _enumerateAsyncFunc = enumerateAsyncFunc ?? throw new ArgumentNullException(nameof(enumerateAsyncFunc)); + } + + /// + public override IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + => _enumerateAsyncFunc().GetAsyncEnumerator(cancellationToken); +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockAsyncPageCollection.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockAsyncPageCollection.cs new file mode 100644 index 000000000..424681252 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockAsyncPageCollection.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; + +namespace OpenAI.TestFramework.Mocks; + +/// +/// Represents a mock implementation of the class. +/// +/// The type of the values in the collection. +public class MockAsyncPageCollection : AsyncPageCollection +{ + private readonly Func> _enumerateAsyncFunc; + private readonly PipelineResponse _response; + private readonly int _itemsPerPage; + private PageResult? _currentPage; + + /// + /// Initializes a new instance. + /// + /// The function that enumerates the collection asynchronously. + /// The pipeline response. + public MockAsyncPageCollection(Func> enumerateAsyncFunc, PipelineResponse response, int itemsPerPage = 5) + { + if (itemsPerPage <= 0) + { + throw new ArgumentOutOfRangeException(nameof(itemsPerPage)); + } + + _enumerateAsyncFunc = enumerateAsyncFunc ?? throw new ArgumentNullException(nameof(enumerateAsyncFunc)); + _response = response; + _itemsPerPage = itemsPerPage; + } + + /// + protected override Task> GetCurrentPageAsyncCore() + => Task.FromResult(_currentPage ?? throw new InvalidOperationException("Please call MoveNextAsync first.")); + + /// + protected override async IAsyncEnumerator> GetAsyncEnumeratorCore(CancellationToken cancellationToken = default) + { + List items = new(_itemsPerPage); + int pageStart = 0; + int rolling = 0; + + await foreach (TValue value in _enumerateAsyncFunc()) + { + items.Add(value); + rolling++; + if (items.Count == _itemsPerPage) + { + _currentPage = PageResult.Create(items, ToContinuation(pageStart), ToContinuation(rolling), _response); + yield return _currentPage; + items.Clear(); + pageStart = rolling; + } + } + + if (items.Count > 0) + { + _currentPage = PageResult.Create(items, ToContinuation(pageStart), ToContinuation(rolling), _response); + yield return _currentPage; + } + } + + private static ContinuationToken ToContinuation(int offset) + => ContinuationToken.FromBytes(BinaryData.FromBytes(BitConverter.GetBytes(offset))); +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockCollectionResult.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockCollectionResult.cs new file mode 100644 index 000000000..e12e34b34 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockCollectionResult.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; + +namespace OpenAI.TestFramework.Mocks +{ + /// + /// Represents a mock implementation of the class. + /// + /// The type of the values in the collection. + public class MockCollectionResult : CollectionResult + { + private readonly Func> _enumerateFunc; + + /// + /// Initializes a new instance of the class with the specified enumeration + /// function and optional pipeline response. + /// + /// The function used to enumerate the collection. + /// The pipeline response associated with the collection. + public MockCollectionResult(Func> enumerateFunc, PipelineResponse? response = null) : + base(response ?? new MockPipelineResponse()) + { + _enumerateFunc = enumerateFunc ?? throw new ArgumentNullException(nameof(enumerateFunc)); + } + + /// + /// Returns an enumerator that iterates through the collection. + /// + /// An enumerator that can be used to iterate through the collection. + public override IEnumerator GetEnumerator() + => _enumerateFunc().GetEnumerator(); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockHeaders.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockHeaders.cs new file mode 100644 index 000000000..7cac49376 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockHeaders.cs @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Mocks; + +/// +/// Basic implementation of headers. +/// +public class MockHeaders +{ + private IDictionary> _headers = + new Dictionary>(StringComparer.OrdinalIgnoreCase); + + /// + /// Adds a header value. + /// + /// The name of the header. + /// The value to add. + public virtual void Add(string name, string value) + { + IList? existing; + if (!_headers.TryGetValue(name, out existing)) + { + existing = new List(); + _headers[name] = existing; + } + + existing.Add(value); + } + + /// + /// Removes all values of a header. + /// + /// The name of the header to remove. + /// True if we removed a value, false otherwise. + public virtual bool Remove(string name) => _headers.Remove(name); + + /// + /// Sets the value for a header. This will override all existing values. + /// + /// The name of the header. + /// The value to set. + public virtual void Set(string name, string value) => _headers[name] = new List() { value }; + + /// + /// Gets an enumerator for the header values. In the case of a header with more than one value, they will be joined into + /// a single comma separated string. + /// + /// The enumerator. + public virtual IEnumerator> GetEnumerator() + => _headers + .Select(kvp => new KeyValuePair(kvp.Key, string.Join(",", kvp.Value))) + .GetEnumerator(); + + /// + /// Gets the value for a header. In the case of a header with more than one value, they will be joined into a single comma + /// separated string. + /// + /// The name of the header. + /// The value of the headers + /// True if the header was found, false otherwise. + public virtual bool TryGetValue(string name, out string? value) + { + if (_headers.TryGetValue(name, out IList? existing)) + { + value = string.Join(",", existing); + return true; + } + + value = null; + return false; + } + + /// + /// Gets the values for a header. + /// + /// The name of the header. + /// All of the values for the header. + /// True if the header was found, false otherwise. + public virtual bool TryGetValues(string name, out IEnumerable? values) + { + if (_headers.TryGetValue(name, out IList? existing)) + { + values = existing; + return true; + } + + values = null; + return false; + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockHttpMessageHandler.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockHttpMessageHandler.cs new file mode 100644 index 000000000..74f4b9980 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockHttpMessageHandler.cs @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Net; +using System.Net.Http; + +namespace OpenAI.TestFramework.Mocks; + +/// +/// A mock message handler that doesn't use the network. This captures all received requests, and allows you to specify a handler +/// to hand craft response messages. This can be useful for unit testing. +/// +public class MockHttpMessageHandler : HttpMessageHandler, IDisposable +{ + /// + /// Handles a captured request. + /// + /// The captured request. + /// The corresponding response. + public delegate CapturedResponse RequestHandlerDelegate(CapturedRequest request); + + private RequestHandlerDelegate _handler; + private List _requests; + private List _responses; + private PipelineTransport? _transport; + + /// + /// Creates a new instance. + /// + /// (Optional) The handler to use to generate responses. Default returns an empty + /// response body with HTTP 204 + public MockHttpMessageHandler(RequestHandlerDelegate? requestHandler = null) + { + _handler = requestHandler ?? ReturnEmpty; + _requests = new List(); + _responses = new List(); + } + + /// + /// Event raised when a request is received. + /// + public event EventHandler? OnRequest; + + /// + /// Event raised when a response is generated. + /// + public event EventHandler? OnResponse; + + /// + /// Gets the transport to pass to your System.ClientModel based clients. + /// + public PipelineTransport Transport => _transport ??= new HttpClientPipelineTransport(new HttpClient(this)); + + /// + /// All received requests. + /// + public IReadOnlyList Requests => _requests; + + /// + /// All generated responses. + /// + public IReadOnlyList Responses => _responses; + + /// + /// Default handler that always returns an empty JSON payload as the response with the correct headers set + /// + /// The request + /// An empty successful JSON response + public static CapturedResponse ReturnEmptyJson(CapturedRequest request) + => new() + { + Status = HttpStatusCode.OK, + ReasonPhrase = "OK", + Content = BinaryData.FromString("{}"), + Headers = new Dictionary>() + { + ["Content-Type"] = ["application/json"], + ["Content-Length"] = ["2"] + } + }; + + /// + /// Default handler that returns an empty HTTP 204 payload + /// + /// The request + /// An HTTP 204 empty response + public static CapturedResponse ReturnEmpty(CapturedRequest request) + => new() { Status = HttpStatusCode.NoContent }; + + private HttpResponseMessage HandleRequest(HttpRequestMessage request, CancellationToken token) + { + try + { + CapturedRequest capturedRequest = new(request); + OnRequest?.Invoke(this, capturedRequest); + _requests.Add(capturedRequest); + + CapturedResponse capturedResponse = _handler(capturedRequest); + OnResponse?.Invoke(this, capturedResponse); + _responses.Add(capturedResponse); + + return capturedResponse.ToResponse(); + } + catch (Exception ex) + { + throw new ClientResultException("Failed to process request", null, ex); + } + } + + #region HttpMessagHandler implementation + +#if NET + override +#endif + protected HttpResponseMessage Send(HttpRequestMessage request, CancellationToken cancellationToken) + => HandleRequest(request, cancellationToken); + + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + => Task.FromResult(HandleRequest(request, cancellationToken)); + + #endregion +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockPageCollection.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockPageCollection.cs new file mode 100644 index 000000000..1f08987ae --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockPageCollection.cs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; + +namespace OpenAI.TestFramework.Mocks; + +/// +/// Represents a mock implementation of the class. +/// +/// The type of the values in the collection. +public class MockPageCollection : PageCollection +{ + private readonly Func> _enumerateFunc; + private readonly PipelineResponse _response; + private readonly int _itemsPerPage; + private PageResult? _currentPage; + + /// + /// Creates a new instance. + /// + /// The function used to enumerate the collection. + /// The pipeline response. + /// (Optional) The number of items per page. + public MockPageCollection(Func> enumerateFunc, PipelineResponse response, int itemsPerPage = 5) + { + if (itemsPerPage <= 0) + { + throw new ArgumentOutOfRangeException(nameof(itemsPerPage)); + } + + _enumerateFunc = enumerateFunc ?? throw new ArgumentNullException(nameof(enumerateFunc)); + _response = response; + _itemsPerPage = itemsPerPage; + } + + /// + protected override PageResult GetCurrentPageCore() + => _currentPage ?? throw new InvalidOperationException("Please call MoveNextAsync first."); + + /// + protected override IEnumerator> GetEnumeratorCore() + { + List items = new(_itemsPerPage); + int pageStart = 0; + int rolling = 0; + + foreach (TValue item in _enumerateFunc()) + { + items.Add(item); + rolling++; + if (items.Count == _itemsPerPage) + { + _currentPage = PageResult.Create(items, ToContinuation(pageStart), ToContinuation(rolling), _response); + yield return _currentPage; + items.Clear(); + pageStart = rolling; + } + } + + if (items.Count > 0) + { + _currentPage = PageResult.Create(items, ToContinuation(pageStart), null, _response); + yield return _currentPage; + } + } + + private static ContinuationToken ToContinuation(int offset) + => ContinuationToken.FromBytes(BinaryData.FromBytes(BitConverter.GetBytes(offset))); +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockPipelineResponse.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockPipelineResponse.cs new file mode 100644 index 000000000..1ade396df --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockPipelineResponse.cs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; + +namespace OpenAI.TestFramework.Mocks; + +/// +/// A mock implementation of a pipeline response +/// +public class MockPipelineResponse : PipelineResponse +{ + private Stream? _contentStream; + private BinaryData? _buffered; + + /// + /// Creates a new instance. + /// + /// (Optional) The HTTP status. + /// (Optional) The HTTP reason phrase. + /// (Optional) The HTTP response body content. + public MockPipelineResponse( + int? status = null, + string? reasonPhrase = null, + BinaryData? content = null) + { + Status = status ?? 200; + ReasonPhrase = reasonPhrase ?? "OK"; + _buffered = content; + ContentStream = content?.ToStream(); + HeadersCore = new MockResponseHeaders(); + } + + /// + public override int Status { get; } + + /// + public override string ReasonPhrase { get; } + + /// + public override Stream? ContentStream + { + get => _contentStream; + set + { + _contentStream = value; + _buffered = null; + } + } + + /// + public override BinaryData Content => _buffered ?? throw new InvalidOperationException("Response content is not yet buffered"); + + /// + protected override PipelineResponseHeaders HeadersCore { get; } + + /// + public override BinaryData BufferContent(CancellationToken cancellationToken = default) + => BufferContentSyncAsync(false, cancellationToken).GetAwaiter().GetResult(); + + /// + public override ValueTask BufferContentAsync(CancellationToken cancellationToken = default) + => BufferContentSyncAsync(true, cancellationToken); + + /// + public override void Dispose() + { + ContentStream?.Dispose(); + } + + private async ValueTask BufferContentSyncAsync(bool isAsync, CancellationToken token) + { + if (_buffered != null) + { + return _buffered; + } + + _buffered = ContentStream == null + ? BinaryData.FromBytes(Array.Empty()) + : isAsync + ? await BinaryData.FromStreamAsync(ContentStream, token).ConfigureAwait(false) + : BinaryData.FromStream(ContentStream); + + ContentStream?.Dispose(); + ContentStream = _buffered.ToStream(); + return _buffered; + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockRequestHeaders.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockRequestHeaders.cs new file mode 100644 index 000000000..e03c4cd18 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockRequestHeaders.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; + +namespace OpenAI.TestFramework.Mocks; + +/// +/// Mock implementation of request headers. +/// +public class MockRequestHeaders : PipelineRequestHeaders +{ + private MockHeaders _headers = new(); + + /// + public override void Add(string name, string value) + => _headers.Add(name, value); + + /// + public override bool Remove(string name) + => _headers.Remove(name); + + /// + public override void Set(string name, string value) + => _headers.Set(name, value); + + /// + public override IEnumerator> GetEnumerator() + => _headers.GetEnumerator(); + + /// + public override bool TryGetValue(string name, out string? value) + => _headers.TryGetValue(name, out value); + + /// + public override bool TryGetValues(string name, out IEnumerable? values) + => _headers.TryGetValues(name, out values); +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockResponseHeaders.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockResponseHeaders.cs new file mode 100644 index 000000000..aead0b4b8 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockResponseHeaders.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; + +namespace OpenAI.TestFramework.Mocks; + +/// +/// Mock implementation of response headers. +/// +public class MockResponseHeaders : PipelineResponseHeaders +{ + private MockHeaders _headers = new(); + + /// + public override IEnumerator> GetEnumerator() + => _headers.GetEnumerator(); + + /// + public override bool TryGetValue(string name, out string? value) + => _headers.TryGetValue(name, out value); + + /// + public override bool TryGetValues(string name, out IEnumerable? values) + => _headers.TryGetValues(name, out values); +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockRestService.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockRestService.cs new file mode 100644 index 000000000..58420f679 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockRestService.cs @@ -0,0 +1,413 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Concurrent; +using System.Net; +using System.Net.Sockets; +using System.Text.Json; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Mocks; + +/// +/// Represents a mock REST service for testing purposes. +/// +/// The type of data stored in the service. +public class MockRestService : IDisposable +{ + /// + /// Represents an entry in the mock REST service. + /// + /// The ID of the entry. + /// The data associated with the entry. + public record Entry(string id, TData data) + { +#if NETFRAMEWORK + public Entry() : this(string.Empty, default!) + { + // .Net framework System.Text.Json cannot deserialize records without a parameterless constructor + } +#endif + }; + + /// + /// Represents an error in the mock REST service. + /// + /// The error code. + /// The error message. + /// The stack trace of the error. + public record Error(int error, string message, string? stack = null); + + private static readonly JsonSerializerOptions s_options = new() + { + WriteIndented = true, +#pragma warning disable SYSLIB0020 + IgnoreNullValues = true +#pragma warning restore SYSLIB0020 + }; + + private ConcurrentDictionary _data; + private HttpListener _listener; + private CancellationTokenSource _cts; + private Task _workerTask; + + /// + /// Initializes a new instance of the class. + /// + /// (Optional) The base path of the service. + /// (Optional) The port number to listen on. If set to 0, a port will be automatically selected. + public MockRestService(string? basePath = null, ushort port = 0) + { + _data = new(); + basePath = basePath?.EnsureEndsWith("/"); + + int maxAttempts = port == 0 ? 15 : 1; + Exception? ex = null; + for (int i = 0; _listener == null && i < maxAttempts; i++) + { + _listener = TryStartListener(basePath ?? string.Empty, port, out ex)!; + } + + if (_listener == null || ex != null) + { + throw new ApplicationException("Failed to start the mock rest service", ex); + } + + HttpEndpoint = TerminatePathWithSlash(new Uri(_listener.Prefixes.First())); + _cts = new(); + _workerTask = Task.Run(() => WorkerAsync(_cts.Token), _cts.Token); + } + + /// + /// Gets the HTTP endpoint of the mock REST service. + /// + public Uri HttpEndpoint { get; } + + /// + /// Gets all entries in the mock REST service. + /// + /// An enumerable collection of entries. + public virtual IEnumerable GetAll() + => _data.Select(kvp => new Entry(kvp.Key, kvp.Value)); + + /// + /// Tries to get an entry from the mock REST service. + /// + /// The ID of the entry to get. + /// When this method returns, contains the entry associated with the specified ID, if found; otherwise, null. + /// true if the entry was found; otherwise, false. + public virtual bool TryGet(string id, out Entry? entry) + { + if (_data.TryGetValue(id, out TData? value)) + { + entry = new(id, value); + return true; + } + + entry = null; + return false; + } + + /// + /// Tries to add an entry to the mock REST service. + /// + /// The ID of the entry to add. + /// The data associated with the entry. + /// When this method returns, contains the added entry, if successful; otherwise, null. + /// true if the entry was added successfully; otherwise, false. + public virtual bool TryAdd(string id, TData data, out Entry? entry) + { + entry = null; + + if (_data.TryAdd(id, data)) + { + entry = new(id, data); + return true; + } + + return false; + } + + /// + /// Tries to delete an entry from the mock REST service. + /// + /// The ID of the entry to delete. + /// true if the entry was deleted successfully; otherwise, false. + public virtual bool TryDelete(string id) + => _data.TryRemove(id, out _); + + /// + /// Tries to update an entry in the mock REST service. + /// + /// The ID of the entry to update. + /// The updated data for the entry. + /// When this method returns, contains the updated entry, if successful; otherwise, null. + /// true if the entry was updated successfully; otherwise, false. + public virtual bool TryUpdate(string id, TData data, out Entry? entry) + { + _data[id] = data; + entry = new(id, data); + return true; + } + + /// + /// Resets the mock REST service removing all entries. + /// + public virtual void Reset() + => _data.Clear(); + + /// + /// Disposes of the resources used by the mock REST service. + /// + public void Dispose() + { + _cts.Cancel(); + _listener.Stop(); + try { _workerTask.Wait(500); } catch { } + _listener.Close(); + _cts.Dispose(); + } + + /// + /// Worker method that handles incoming HTTP requests. + /// + /// The cancellation token. + protected virtual async Task WorkerAsync(CancellationToken token) + { + while (!token.IsCancellationRequested) + { + HttpListenerContext context = await _listener.GetContextAsync().ConfigureAwait(false); + HttpListenerRequest request = context.Request; + HttpListenerResponse response = context.Response; + + if (request == null || request.Url == null) + { + context.Response?.Abort(); + continue; + } + + try + { + response.ContentLength64 = 0; + + string? id = GetId(HttpEndpoint, request.Url); + switch (request.HttpMethod.ToUpperInvariant()) + { + case "GET": + if (id == null) + { + // Send down all data + IEnumerable allData = GetAll(); + WriteJsonResponse(response, 200, allData); + } + else if (TryGet(id, out Entry? entry) && entry != null) + { + WriteJsonResponse(response, 200, entry); + } + else + { + response.StatusCode = (int)HttpStatusCode.NotFound; + } + break; + + case "POST": + if (id == null) + { + response.StatusCode = (int)HttpStatusCode.BadRequest; + } + else + { + TData? data = ReadBody(request); + if (data == null) + { + response.StatusCode = (int)HttpStatusCode.GatewayTimeout; + } + else if (TryAdd(id, data, out Entry? entry)) + { + if (entry == null) + { + response.StatusCode = (int)HttpStatusCode.NoContent; + } + else + { + WriteJsonResponse(response, 200, entry); + } + } + else + { + response.StatusCode = (int)HttpStatusCode.Conflict; + } + } + break; + + case "PUT": + if (id == null) + { + response.StatusCode = (int)HttpStatusCode.BadRequest; + } + else + { + TData? data = ReadBody(request); + if (data == null) + { + response.StatusCode = (int)HttpStatusCode.GatewayTimeout; + } + else if (TryUpdate(id, data, out Entry? entry)) + { + if (entry == null) + { + response.StatusCode = (int)HttpStatusCode.NoContent; + } + else + { + WriteJsonResponse(response, 200, entry); + } + } + else + { + response.StatusCode = (int)HttpStatusCode.NotFound; + response.ContentLength64 = 0; + } + } + break; + + case "DELETE": + response.ContentLength64 = 0; + if (id == null) + { + response.StatusCode = (int)HttpStatusCode.BadRequest; + } + else if (TryDelete(id)) + { + response.StatusCode = (int)HttpStatusCode.NoContent; + } + else + { + response.StatusCode = (int)HttpStatusCode.NotFound; + } + break; + + default: + response.StatusCode = (int)HttpStatusCode.MethodNotAllowed; + break; + } + + response.Close(); + } + catch (Exception ex) + { + response.StatusCode = (int)HttpStatusCode.InternalServerError; + try + { + if (response.OutputStream.Length > 0 || response.OutputStream.CanSeek) + { + response.OutputStream.SetLength(0); + } + + if (response.OutputStream.Length == 0) + { + WriteJsonResponse( + response, + (int)HttpStatusCode.InternalServerError, + new Error( + 500, + ex.Message +#if DEBUG + , ex.StackTrace +#endif + )); + } + } + catch { /* we tried */ } + } + } + } + + private static ushort GetFreePort() + { + TcpListener? listener = null; + try + { + listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + return (ushort)((IPEndPoint)listener.LocalEndpoint).Port; + } + finally + { + listener?.Stop(); + } + } + + private static HttpListener? TryStartListener(string basePath, ushort port, out Exception? ex) + { + if (port == 0) + { + port = GetFreePort(); + } + + HttpListener? listener = null; + try + { + listener = new(); + listener.Prefixes.Add($"http://localhost:{port}/{basePath}"); + listener.Start(); + ex = null; + return listener; + } + catch (Exception e) + { + listener?.Close(); + ex = e; + return null; + } + } + + private static Uri TerminatePathWithSlash(Uri uri) + { + if (uri.IsAbsoluteUri) + { + if (!uri.AbsolutePath.EndsWith("/")) + { + UriBuilder builder = new(uri); + builder.Path += '/'; + return builder.Uri; + } + } + else if (!uri.OriginalString.EndsWith("/")) + { + return new Uri(uri.OriginalString + '/', UriKind.RelativeOrAbsolute); + } + + return uri; + } + + private static string? GetId(Uri baseUri, Uri requestUri) + { + Uri normalizedRequestUri = TerminatePathWithSlash(requestUri); + Uri relative = baseUri.MakeRelativeUri(normalizedRequestUri); + return relative.OriginalString.Split(["/"], StringSplitOptions.RemoveEmptyEntries).FirstOrDefault(); + } + + private static TData? ReadBody(HttpListenerRequest request) + { + if (request.ContentLength64 == 0) + { + return default; + } + + return JsonHelpers.Deserialize(request.InputStream, s_options); + } + + private static void WriteJsonResponse(HttpListenerResponse response, int status, T data) + { + response.StatusCode = status; + + using MemoryStream buffer = new(); + JsonHelpers.Serialize(buffer, data, s_options); + buffer.Seek(0, SeekOrigin.Begin); + + response.ContentType = "application/json"; + response.ContentLength64 = buffer.Length; + buffer.CopyTo(response.OutputStream); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockRestServiceClient.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockRestServiceClient.cs new file mode 100644 index 000000000..a3fb851d3 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Mocks/MockRestServiceClient.cs @@ -0,0 +1,274 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Globalization; +using System.Net.Http; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Mocks; + +/// +/// A client for . +/// +/// The type of data used by the client. +public class MockRestServiceClient : IDisposable +{ + private ClientPipeline _pipeline; + private Uri _baseUri; + + /// + /// Only used to generate a dynamic proxy for testing. Do not use this yourself. + /// + internal MockRestServiceClient() + { + _pipeline = null!; + _baseUri = null!; + } + + /// + /// Initializes a new instance of the class with the specified service URI and options. + /// + /// The service URI. + /// The client pipeline options. + public MockRestServiceClient(Uri serviceUri, ClientPipelineOptions? options = null) + { + _pipeline = ClientPipeline.Create(options); + _baseUri = serviceUri ?? throw new ArgumentNullException(nameof(serviceUri)); + } + + /// + /// Adds data asynchronously to the service with the specified ID. + /// + /// The ID of the data. + /// The data to add. + /// The cancellation token. + /// A task representing the asynchronous operation. + public virtual Task AddAsync(string id, TData data, CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(id)) + throw new ArgumentException("Value cannot be null or whitespace.", nameof(id)); + + ValidateData(data); + return SendSyncOrAsync(true, HttpMethod.Post, id, data, token).AsTask(); + } + + /// + /// Adds data synchronously to the service with the specified ID. + /// + /// The ID of the data. + /// The data to add. + /// The cancellation token. + /// The result of the operation. + public virtual ClientResult Add(string id, TData data, CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(id)) + throw new ArgumentException("Value cannot be null or whitespace.", nameof(id)); + + ValidateData(data); + return SendSyncOrAsync(false, HttpMethod.Post, id, data, token).GetAwaiter().GetResult(); + } + + /// + /// Gets data asynchronously from the service with the specified ID. Will return null if the data does not exist. + /// + /// The ID of the data. + /// The cancellation token. + /// A task representing the asynchronous operation. + public virtual async Task> GetAsync(string id, CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(id)) + throw new ArgumentException("Value cannot be null or whitespace.", nameof(id)); + + try + { + ClientResult result = await SendSyncOrAsync(true, HttpMethod.Get, id, default, token) + .ConfigureAwait(false); + + var response = result.GetRawResponse(); + return ClientResult.FromOptionalValue( + response.Content.ToObjectFromJson.Entry>().data, + response); + } + catch (ClientResultException ex) + { + if (ex.GetRawResponse()?.Status == 404) + { + return ClientResult.FromOptionalValue(default, ex.GetRawResponse()!); + } + + throw; + } + } + + /// + /// Gets data synchronously from the service with the specified ID. Will return null if the data does not exist. + /// + /// The ID of the data. + /// The cancellation token. + /// The result of the operation. + public virtual ClientResult Get(string id, CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(id)) + throw new ArgumentException("Value cannot be null or whitespace.", nameof(id)); + + try + { + ClientResult result = SendSyncOrAsync(false, HttpMethod.Get, id, default, token).GetAwaiter().GetResult(); + var response = result.GetRawResponse(); + return ClientResult.FromOptionalValue( + response.Content.ToObjectFromJson.Entry>().data, + response); + } + catch (ClientResultException ex) + { + if (ex.GetRawResponse()?.Status == 404) + { + return ClientResult.FromOptionalValue(default, ex.GetRawResponse()!); + } + + throw; + } + } + + /// + /// Removes data asynchronously from the service with the specified ID. + /// + /// The ID of the data. + /// The cancellation token. + /// A task representing the asynchronous operation. + public virtual async Task> RemoveAsync(string id, CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(id)) + throw new ArgumentException("Value cannot be null or whitespace.", nameof(id)); + + try + { + ClientResult result = await SendSyncOrAsync(true, HttpMethod.Delete, id, default, token); + return ClientResult.FromValue(true, result.GetRawResponse()); + } + catch (ClientResultException ex) + { + if (ex.GetRawResponse()?.Status == 404) + { + return ClientResult.FromValue(false, ex.GetRawResponse()!); + } + + throw; + } + } + + /// + /// Removes data synchronously from the service with the specified ID. + /// + /// The ID of the data. + /// The cancellation token. + /// The result of the operation. + public virtual ClientResult Remove(string id, CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(id)) + throw new ArgumentException("Value cannot be null or whitespace.", nameof(id)); + + try + { + ClientResult result = SendSyncOrAsync(false, HttpMethod.Delete, id, default, token).GetAwaiter().GetResult(); + return ClientResult.FromValue(true, result.GetRawResponse()); + } + catch (ClientResultException ex) + { + if (ex.GetRawResponse()?.Status == 404) + { + return ClientResult.FromValue(false, ex.GetRawResponse()!); + } + + throw; + } + } + + /// + /// Disposes of the resources used by the client. + /// + public virtual void Dispose() + { + // no obvious way to dispose of the pipeline, nor the inner transport + } + + /// + /// Validates the data before sending it to the service. + /// + /// The data to validate. + protected virtual void ValidateData(TData? data) + { + if (data == null) + { + throw new ArgumentNullException(nameof(data)); + } + } + + /// + /// Sends the request to the service synchronously or asynchronously. This will serialize the passed in data to JSON using the default + /// serializer. + /// + /// Indicates whether the request should be sent asynchronously. + /// The HTTP method. + /// The ID of the data. + /// The data to send. + /// The cancellation token. + /// The result of the operation. + protected async ValueTask SendSyncOrAsync(bool isAsync, HttpMethod method, string? id, TData? data, CancellationToken token) + { + UriBuilder builder = new(_baseUri); + if (id != null) + { + builder.Path += id; + } + + PipelineMessage message = _pipeline.CreateMessage(); + message.Request.Method = method.Method; + message.Request.Uri = builder.Uri; + message.Apply(new RequestOptions() + { + CancellationToken = token, + BufferResponse = true + }); + + if (data == null) + { + message.Request.Headers.Set("Content-Length", "0"); + } + else + { + using MemoryStream stream = new(); + JsonHelpers.Serialize(stream, data); + var binaryData = BinaryData.FromBytes(new ReadOnlyMemory(stream.GetBuffer(), 0, (int)stream.Length)); + + message.Request.Headers.Set("Content-Length", stream.Length.ToString(CultureInfo.InvariantCulture)); + message.Request.Headers.Set("Content-Type", "application/json"); + message.Request.Content = BinaryContent.Create(binaryData); + } + + if (isAsync) + { + await _pipeline.SendAsync(message).ConfigureAwait(false); + } + else + { + _pipeline.Send(message); + } + + if (message.Response?.IsError == true) + { + if (message.Response.Content?.ToMemory().Length > 0) + { + var error = message.Response.Content.ToObjectFromJson.Error>(); + throw new ClientResultException($"Error {error.error}: {error.message}", message.Response); + } + + throw new ClientResultException(message.Response); + } + + return ClientResult.FromResponse(message.Response!); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/OpenAI.TestFramework.csproj b/.dotnet.azure/sdk/openai/tools/TestFramework/src/OpenAI.TestFramework.csproj new file mode 100644 index 000000000..dc92fd798 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/OpenAI.TestFramework.csproj @@ -0,0 +1,39 @@ + + + + $(RequiredTargetFrameworks) + enable + enable + latest + + + + + Utils\Polyfill\%(RecursiveDir)\%(Filename).cs + + + + + + + + + + + + + + 0024000004800000940000000602000000240000525341310004000001000100c547cac37abd99c8db225ef2f6c8a3602f3b3606cc9891605d02baa56104f4cfc0734aa39b93bf7852f7d9266654753cc297e7d2edfe0bac1cdcf9f717241550e0a7b191195b7667bb4f64bcb8e2121380fd1d9d46ad2d92d2d15605093924cceaf74c4861eff62abf69b9291ed0a340e113be11e6a7d3113e92484cf7045cc7 + + + + + + + <_Parameter1>TestProxyPath + <_Parameter2>$(NuGetPackageRoot)\azure.sdk.tools.testproxy\$(TestProxyVersion)\tools\net6.0\any\Azure.Sdk.Tools.TestProxy.dll + + + + diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/RecordedClientTestBase.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/RecordedClientTestBase.cs new file mode 100644 index 000000000..27d6ca2e7 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/RecordedClientTestBase.cs @@ -0,0 +1,423 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Diagnostics; +using System.Net; +using System.Text; +using NUnit.Framework; +using NUnit.Framework.Internal; +using OpenAI.TestFramework.Recording; +using OpenAI.TestFramework.Recording.Proxy; +using OpenAI.TestFramework.Recording.Proxy.Service; +using OpenAI.TestFramework.Recording.RecordingProxy; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework; + +/// +/// Base class for client test cases that supports recording and playback of HTTP/HTTPS REST requests. This recording +/// support is provided by use of the Test Proxy . +/// This provides the basic framework to start the Test Proxy, create a recording for a test or playback a recording +/// for a test. It also provides support for automatic testing of async and sync versions of methods (see +/// for more details). +/// +[NonParallelizable] +public abstract class RecordedClientTestBase : ClientTestBase +{ + /// + /// Invalid characters that will be removed from test names when creating recordings. + /// + /// + /// Using Windows version as it is the most restrictive of all platforms: + /// + /// + protected static readonly ISet s_invalidChars = new HashSet() + { + '\"', '<', '>', '|', '\0', + (char)1, (char)2, (char)3, (char)4, (char)5, (char)6, (char)7, (char)8, (char)9, (char)10, + (char)11, (char)12, (char)13, (char)14, (char)15, (char)16, (char)17, (char)18, (char)19, (char)20, + (char)21, (char)22, (char)23, (char)24, (char)25, (char)26, (char)27, (char)28, (char)29, (char)30, + (char)31, ':', '*', '?', '\\', '/' + }; + + private DateTimeOffset _testStartTime; + private TestRecordingOptions _options; + + /// + /// Creates a new instance. + /// + /// True to run the async version of a test, false to run the sync version of a test. + public RecordedClientTestBase(bool isAsync) : this(isAsync, null) + { } + + /// + /// Creates a new instance. + /// + /// True to run the async version of a test, false to run the sync version of a test. + /// (Optional) The recorded test mode to use. If unset, the default recorded test mode will be used. + public RecordedClientTestBase(bool isAsync, RecordedTestMode? mode = null) : base(isAsync) + { + _options = new TestRecordingOptions(); + Mode = mode ?? GetDefaultRecordedTestMode(); + } + + /// + public override DateTimeOffset TestStartTime => _testStartTime; + + /// + /// Gets the test proxy instance to use for the current test case. + /// + public ProxyService? Proxy { get; protected internal set; } + + /// + /// Gets or sets the current recording mode for the test. + /// + public RecordedTestMode Mode { get; set; } + + /// + /// Gets or sets the recording options to use for the current test. This will be pre-populated with a sensible configuration. + /// + public TestRecordingOptions RecordingOptions + { + get => _options; + set => _options = value ?? throw new ArgumentNullException(nameof(value)); + } + + /// + /// Gets the recording for the current test. + /// + public TestRecording? Recording { get; protected internal set; } + + /// + /// Gets the maximum amount of time to wait for starting/tearing down the test proxy, as well as the maximum amount of time + /// to wait for configuring a recording session, and then saving it or closing it. + /// + public virtual TimeSpan TestProxyWaitTime => Debugger.IsAttached + ? Default.DebuggerAttachedTestTimeout + : Default.TestProxyWaitTime; + + /// + /// Gets the test timeout. + /// + public override TimeSpan TestTimeout + { + get + { + if (Debugger.IsAttached) + { + return Default.DebuggerAttachedTestTimeout; + } + + switch (Mode) + { + default: + case RecordedTestMode.Record: + case RecordedTestMode.Live: + return TimeSpan.FromSeconds(60); + + case RecordedTestMode.Playback: + return Default.TestTimeout; + } + } + } + + /// + /// Determines whether or not to use Fiddler. If this is true, then the recording transport will be updated to use Fiddler + /// as the intermediary when talking to the test proxy, as well as accept the Fiddler root certificate. + /// + public virtual bool UseFiddler + { + get + { + // Check to see if Fiddler is already running and capturing traffic by checking to see if a proxy is configured for + // 127.0.0.1:8888 with no credentials + try + { + Uri dummyUri = new("https://not.a.real.uri.com"); + + IWebProxy webProxy = WebRequest.GetSystemWebProxy(); + Uri? proxyUri = webProxy?.GetProxy(dummyUri); + if (proxyUri == null || proxyUri == dummyUri) + { + return false; + } + + // assume default of 127.0.0.1:8888 with no credentials + var cred = webProxy?.Credentials?.GetCredential(dummyUri, string.Empty); + return proxyUri.Host == "127.0.0.1" + && proxyUri.Port == 8888 + && string.IsNullOrWhiteSpace(cred?.UserName) + && string.IsNullOrWhiteSpace(cred?.Password); + } + catch + { + return false; + } + } + } + + /// + /// Checks if the recording has a recorded value for . If there is none, the + /// will be added and return. Otherwise the existing value will be returned. + /// + /// The name of the value. + /// The value to add. + /// The existing value, or the newly added value. + /// If you called this function outside of a test run. + public string? GetOrAddRecordedValue(string name, string valueToAdd) + => GetOrAddRecordedValue(name, () => valueToAdd); + + /// + /// Checks if the recording has a recorded value for . If there is none, a value will be created, added + /// and returned. Otherwise the existing value will be returned. + /// + /// The name of the value. + /// The factory used to create the value. + /// The existing value, or the newly added value. + /// If you called this function outside of a test run. + public virtual string GetOrAddRecordedValue(string name, Func valueFactory) + { + if (Recording == null) + { + throw new InvalidOperationException("Recorded value should not be retrieved outside the test method invocation"); + } + + return Recording.GetOrAddVariable(name, valueFactory); + } + + /// + /// Starts the test proxy for the current test. This will be called once at the start of the test fixture. + /// + /// Asynchronous task. + [OneTimeSetUp] + public virtual async Task StartTestProxyAsync() + { + using CancellationTokenSource cts = new(TestProxyWaitTime); + + ProxyServiceOptions options = CreateProxyServiceOptions(); + Proxy = await ProxyService.CreateNewAsync(options, cts.Token).ConfigureAwait(false); + } + + [OneTimeTearDown] + public virtual Task StopTestProxyAsync() + { + Proxy?.Dispose(); + Proxy = null; + + //TODO FIXME: Do we need to do any cleanup here? + return Task.CompletedTask; + } + + /// + /// Starts the test proxy (if it has not already been started), and then configures the recording session for the current + /// test. This should also set the property to the new recording session. + /// + /// Asynchronous task. + [SetUp] + public virtual async Task StartTestRecordingAsync() + { + // Check if the current NUnit test method has a specific attribute applied to it + if (!IsCurrentTestRecorded()) + { + return; + } + + if (Proxy == null) + { + throw new InvalidOperationException("The proxy service was not set and/or started"); + } + + _testStartTime = DateTimeOffset.UtcNow; + + // TODO FIXME: Add logic to ignore certain tests here by throwing IgnoreException()? + + using CancellationTokenSource cts = new(TestProxyWaitTime); + Recording = await StartAndConfigureRecordingSessionAsync(Proxy, cts.Token).ConfigureAwait(false); + + // don't include test proxy overhead as part of the test time + _testStartTime = DateTimeOffset.UtcNow; + } + + /// + /// Stops a recording session for the current test. If the test passed and we are in recording mode, the recording will be saved, + /// otherwise it will be discarded. + /// + /// Asynchronous task. + [TearDown] + public virtual async Task StopTestRecordingAsync() + { + if (!IsCurrentTestRecorded()) + { + return; + } + + bool testsPassed = TestContext.CurrentContext.Result.Outcome.Status == NUnit.Framework.Interfaces.TestStatus.Passed; + using CancellationTokenSource cts = new(TestProxyWaitTime); + + if (Recording != null) + { + await Recording.FinishAsync(testsPassed, cts.Token).ConfigureAwait(false); + } + } + + /// + /// Configures the client options for a System.ClientModel based service client. This will be used to configure the transport + /// such that all requests are routed to the test proxy during recording (for capture), and playback (for replaying captured + /// requests). + /// + /// The type of the client options. + /// The options to configure. + /// The configured client options. + /// The current recording mode is not supported. + /// There was no test recording configured for this test. + public virtual TClientOptions ConfigureClientOptions(TClientOptions options) + where TClientOptions : ClientPipelineOptions + { + if (!IsCurrentTestRecorded()) + { + return options; + } + + // If we are in playback, or record mode we should set the transport to the test proxy transport, except + // in the case where we've explicitly specified the transport ourselves in case we are doing some custom + // work. + if (options.Transport != null) + { + return options; + } + + switch (Mode) + { + case RecordedTestMode.Live: + // no need to to anything special + return options; + + case RecordedTestMode.Record: + // continue + break; + + case RecordedTestMode.Playback: + // force the use of a fixed retry with a short timeout + options.RetryPolicy = new TestClientRetryPolicy(delay: TimeSpan.FromMilliseconds(100)); + break; + + default: + throw new NotSupportedException("The following mode is not supported: " + Mode); + } + + if (Recording == null) + { + throw new InvalidOperationException("Please call this from within a test method invocation"); + } + + ProxyTransportOptions transportOptions = Recording.GetProxyTransportOptions(); + transportOptions.UseFiddler = UseFiddler; + if (_options.RequestOverride != null) + { + transportOptions.ShouldRecordRequest = _options.RequestOverride; + } + + options.Transport = new ProxyTransport(transportOptions); + return options; + } + + /// + /// Gets the default recorded test mode to use. + /// + /// The test mode to use. + protected virtual RecordedTestMode GetDefaultRecordedTestMode() => RecordedTestMode.Playback; + + /// + /// Gets the name of recording JSON file that contains the recording. This will be based on a sanitized version + /// of test name, and "Async" will be automatically appended when running the asynchronous versions of tests. + /// + /// The name of the test to use. + protected virtual string GetRecordedTestFileName() + { + const string c_asyncSuffix = "Async"; + TestContext.TestAdapter testAdapter = TestContext.CurrentContext.Test; + + StringBuilder builder = new(testAdapter.Name.Length + c_asyncSuffix.Length); + foreach (char c in testAdapter.Name) + { + builder.Append(s_invalidChars.Contains(c) ? '%' : c); + } + + if (IsAsync) + { + builder.Append(c_asyncSuffix); + } + + builder.Append(".json"); + + return builder.ToString(); + } + + /// + /// Configures a recording/playback session for the current test on the test proxy. This is called at the start of every test. + /// It is responsible for configuring all the necessary sanitizers, matchers, and transforms for the test proxy. + /// + /// The test proxy service to configure the recording session for. + /// The cancellation token to use. + /// The configured test recording session. + /// The test proxy service instance did not have a valid client configured. + /// The recording mode is not supported. + protected virtual async Task StartAndConfigureRecordingSessionAsync(ProxyService proxy, CancellationToken token) + { + var client = proxy.Client ?? throw new ArgumentNullException("Test proxy client was null"); + IDictionary? variables = null; + + ProxyClientResult result; + switch (Mode) + { + case RecordedTestMode.Live: + // nothing to see here + return new TestRecording(string.Empty, RecordedTestMode.Live, proxy); + + case RecordedTestMode.Playback: + var playbackResult = await client.StartPlaybackAsync(CreateRecordingSessionStartInfo(), token).ConfigureAwait(false); + variables = playbackResult.Value; + result = playbackResult; + break; + + case RecordedTestMode.Record: + result = await client.StartRecordingAsync(CreateRecordingSessionStartInfo(), token).ConfigureAwait(false); + break; + + default: + throw new NotSupportedException("Don't know how to handle recording mode: " + Mode); + } + + string? recordingId = result.RecordingId; + if (string.IsNullOrWhiteSpace(recordingId)) + { + throw new InvalidOperationException("Recording test proxy did not return a recording ID"); + } + + TestRecording recording = new TestRecording(recordingId!, Mode, proxy, variables); + await recording.ApplyOptions(_options, token).ConfigureAwait(false); + return recording; + } + + /// + /// Determines whether or not the current test should be recorded (or played back from a file). + /// + /// True to enable the use of the recording test proxy, false otherwise. + protected virtual bool IsCurrentTestRecorded() + { + return TestExecutionContext.CurrentContext.CurrentTest.GetCustomAttributes(true).Any(); + } + + /// + /// Creates the options used when starting a new instance of the test proxy service. + /// + /// The options to use. + protected abstract ProxyServiceOptions CreateProxyServiceOptions(); + + /// + /// Creates the information used to configured a recording/playback session for the current test on the test proxy. + /// + /// The information to use. + protected abstract RecordingStartInformation CreateRecordingSessionStartInfo(); +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/RecordedTestAttribute.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/RecordedTestAttribute.cs new file mode 100644 index 000000000..81d94268c --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/RecordedTestAttribute.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using NUnit.Framework; + +namespace OpenAI.TestFramework; + +/// +/// An attribute used to indicate that a test should be recorded (or played back from a file). When you inherit from +/// in your test class, and add this attribute to your test function, and then +/// make sure to call +/// on the client options you use to configure a client, this should automatically enable the recording/playback +/// functionality. +/// +[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)] +public class RecordedTestAttribute : TestAttribute +{ +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/RecordedTestMode.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/RecordedTestMode.cs new file mode 100644 index 000000000..bc0371ccf --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/RecordedTestMode.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework; + +/// +/// The recording mode. +/// +public enum RecordedTestMode +{ + /// + /// Talk to live services. No recording or playback is used. + /// + Live, + + /// + /// Record the test and overwrite any existing recordings. + /// + Record, + + /// + /// Playback the test from a recording. + /// + Playback, +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Condition.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Condition.cs new file mode 100644 index 000000000..6b0bdad00 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Condition.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording; + +/// +/// A condition used to evaluate whether or not a sanitizer should be applied. +/// +public class Condition +{ + /// Gets or sets the uri regex. + public string? UriRegex { get; set; } + + /// Header condition to apply. + public HeaderCondition? ResponseHeader { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/HeaderCondition.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/HeaderCondition.cs new file mode 100644 index 000000000..12d2ba55c --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/HeaderCondition.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording; + +/// +/// Header condition to apply. +/// +public class HeaderCondition +{ + /// Gets or sets the key. + public string? Key { get; set; } + /// Gets or sets the value regex. + public string? ValueRegex { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Matchers/BaseMatcher.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Matchers/BaseMatcher.cs new file mode 100644 index 000000000..c4578d763 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Matchers/BaseMatcher.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Text.Json; +using System.Text.Json.Serialization; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Recording.Matchers; + +/// +/// The base class for matchers that are applied during a playback session to match an incoming request +/// to a recorded one. +/// +public abstract class BaseMatcher : IUtf8JsonSerializable +{ + /// + /// Creates a new instance. + /// + /// The type of this sanitizer (e.g. GeneralRegexSanitizer). + /// If the type was null. + protected BaseMatcher(string type) + { + Type = type ?? throw new ArgumentNullException(nameof(type)); + } + + /// + /// Gets the type of the matcher (e.g. BodilessMatcher). + /// + [JsonIgnore] + public string Type { get; } + + /// + public virtual void Write(Utf8JsonWriter writer, JsonSerializerOptions? options = null) + { + // By default use reflection based serialization + JsonSerializer.Serialize(writer, this, GetType(), Default.InnerRecordingJsonOptions); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Matchers/CustomMatcher.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Matchers/CustomMatcher.cs new file mode 100644 index 000000000..80e0f1b6d --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Matchers/CustomMatcher.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording.Matchers; + +/// +/// This matcher exposes the default matcher in a customizable way. Currently this merely includes enabling/disabling body match and +/// adding additional excluded headers. All optional settings are safely defaulted. This means that providing zero additional +/// configuration will produce a sanitizer that is functionally identical to the default. +/// +public class CustomMatcher() : BaseMatcher("CustomDefaultMatcher") +{ + /// + /// A comma separated list of additional headers that should be excluded during matching. "Excluded" headers are entirely ignored. + /// Unlike "ignored" headers, the presence (or lack of presence) of a header will not cause mismatch. + /// + public string? ExcludedHeaders { get; set; } + + /// + /// Should the body value be compared during lookup operations? + /// + public bool? CompareBodies { get; set; } + + /// + /// A comma separated list of additional headers that should be ignored during matching. Any headers that are "ignored" will not + /// do value comparison when matching. This means that if the recording has a header that isn't in the request, a test mismatch + /// exception will be thrown noting the lack of header in the request. This also applies if the header is present in the request + /// but not recording. + /// + public string? IgnoredHeaders { get; set; } + + /// + /// A comma separated list of query parameters that should be ignored during matching. + /// + public string? IgnoredQueryParameters { get; set; } + + /// + /// By default, the test-proxy does not sort query params before matching. Setting true will sort query params alphabetically + /// before comparing URI. + /// + public bool? IgnoreQueryOrdering { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Matchers/ExistingMatcher.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Matchers/ExistingMatcher.cs new file mode 100644 index 000000000..8d2cb4eab --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Matchers/ExistingMatcher.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Text.Json; + +namespace OpenAI.TestFramework.Recording.Matchers; + +/// +/// Used for specifying the use of pre-existing matchers defined in the test proxy. +/// +/// The name of the existing matcher. +public class ExistingMatcher(string existingMatcherName) : BaseMatcher(existingMatcherName) +{ + private static ExistingMatcher? _bodiless = null; + private static ExistingMatcher? _headerless = null; + + /// + /// This matcher adjusts the "match" operation to EXCLUDE the body when matching a request to a recording's entries. + /// + public static ExistingMatcher Bodiless => _bodiless ??= new ExistingMatcher("BodilessMatcher"); + + /// + /// NOT RECOMMENDED. This matcher adjusts the "match" operation to ignore header differences when matching a request. + /// Be aware that wholly ignoring headers during matching might incur unexpected issues down the line. + /// + public static ExistingMatcher Headerless => _headerless ??= new ExistingMatcher("HeaderlessMatcher"); + + /// + public override void Write(Utf8JsonWriter writer, JsonSerializerOptions? options = null) + { + // Pre-existing matchers use an empty JSON object. + writer.WriteStartObject(); + writer.WriteEndObject(); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyClient.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyClient.cs new file mode 100644 index 000000000..55660d990 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyClient.cs @@ -0,0 +1,679 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Net.Http; +using System.Text.Json; +using OpenAI.TestFramework.Recording.Matchers; +using OpenAI.TestFramework.Recording.Proxy; +using OpenAI.TestFramework.Recording.Proxy.Service; +using OpenAI.TestFramework.Recording.Sanitizers; +using OpenAI.TestFramework.Recording.Transforms; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Recording.RecordingProxy; + +/// +/// A client for configuring the recording text proxy. Please see here for more information: +/// https://github.com/Azure/azure-sdk-tools/blob/main/tools/test-proxy/Azure.Sdk.Tools.TestProxy/README.md +/// +public class ProxyClient +{ + protected internal const string X_RECORDING_ID_HEADER = "x-recording-id"; + + private ProxyClientOptions _options; + private ClientPipeline _pipeline; + + /// + /// For testing only. + /// + internal ProxyClient() + { + _options = new(new Uri("http://localhost:0")); + _pipeline = ClientPipeline.Create(); + } + + /// + /// Creates a new instance. + /// + /// The options to use. + public ProxyClient(ProxyClientOptions options) + { + _options = options ?? throw new ArgumentNullException(nameof(options)); + _pipeline = ClientPipeline.Create(options); + } + + /// + /// Starts playback session of recordings. + /// + /// The configuration to use for starting playback. + /// The cancellation token to use. + /// The result that includes any recorded variables. + public virtual ProxyClientResult> StartPlayback(RecordingStartInformation startInfo, CancellationToken token = default) + { + if (startInfo == null) + { + throw new ArgumentNullException(nameof(startInfo)); + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "playback/start", startInfo, token); + return SendSyncOrAsync>(false, message, token).GetAwaiter().GetResult(); + } + + /// + /// Starts playback session of recordings asynchronously. + /// + /// The configuration to use for starting playback. + /// The cancellation token to use. + /// The result that includes any recorded variables. + public virtual async Task>> StartPlaybackAsync(RecordingStartInformation startInfo, CancellationToken token = default) + { + if (startInfo == null) + { + throw new ArgumentNullException(nameof(startInfo)); + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "playback/start", startInfo, token); + return await SendSyncOrAsync>(true, message, token).ConfigureAwait(false); + } + + /// + /// Stops a playback session. + /// + /// The ID for the playback session to stop. + /// The cancellation token to use. + /// The client result. + public virtual ProxyClientResult StopPlayback(string recordingId, CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(recordingId)) + { + throw new ArgumentException("Recording ID cannot be null, empty, or white space only"); + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "playback/stop", null, token, new() + { + [X_RECORDING_ID_HEADER] = recordingId, + }); + return SendSyncOrAsync(false, message, token).GetAwaiter().GetResult(); + } + + /// + /// Stops a playback session asynchronously. + /// + /// The ID for the playback session to stop. + /// The cancellation token to use. + /// The client result. + public virtual async Task StopPlaybackAsync(string recordingId, CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(recordingId)) + { + throw new ArgumentException("Recording ID cannot be null, empty, or white space only"); + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "playback/stop", null, token, new() + { + [X_RECORDING_ID_HEADER] = recordingId, + }); + return await SendSyncOrAsync(true, message, token).ConfigureAwait(false); + } + + /// + /// Starts a recording session. + /// + /// The configuration to use for the recording session. + /// The cancellation token to use. + /// The client result. + public virtual ProxyClientResult StartRecording(RecordingStartInformation startInfo, CancellationToken token = default) + { + if (startInfo == null) + { + throw new ArgumentNullException(nameof(startInfo)); + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "record/start", startInfo, token); + return SendSyncOrAsync(false, message, token).GetAwaiter().GetResult(); + } + + /// + /// Starts a recording session asynchronously. + /// + /// The configuration to use for the recording session. + /// The cancellation token to use. + /// The client result. + public virtual async Task StartRecordingAsync(RecordingStartInformation startInfo, CancellationToken token = default) + { + if (startInfo == null) + { + throw new ArgumentNullException(nameof(startInfo)); + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "record/start", startInfo, token); + return await SendSyncOrAsync(true, message, token).ConfigureAwait(false); + } + + /// + /// Stops a recording session. + /// + /// The identifier for the recording session. + /// (Optional) Any additional variables to include with the recording. + /// (Optional) Set this to true to turn off recording. + /// The cancellation token to use. + /// The client result. + public virtual ProxyClientResult StopRecording(string recordingId, IDictionary? variables = null, bool skipRecording = false, CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(recordingId)) + { + throw new ArgumentException("Recording ID cannot be null, empty, or white space only"); + } + + Dictionary additionalHeaders = new() + { + [X_RECORDING_ID_HEADER] = recordingId + }; + + if (skipRecording) + { + additionalHeaders["x-recording-skip"] = "request-response"; + } + + variables ??= new Dictionary(); + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "record/stop", variables, token, additionalHeaders); + return SendSyncOrAsync(false, message, token).GetAwaiter().GetResult(); + } + + /// + /// Stops a recording session asynchronously. + /// + /// The ID for the recording session to stop. + /// (Optional) Any additional variables to include with the recording. + /// (Optional) Set this to true to turn off recording. + /// The cancellation token to use. + /// The client result. + public virtual async Task StopRecordingAsync(string recordingId, IDictionary? variables = null, bool skipRecording = false, CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(recordingId)) + { + throw new ArgumentException("Recording ID cannot be null, empty, or white space only"); + } + + Dictionary additionalHeaders = new() + { + [X_RECORDING_ID_HEADER] = recordingId + }; + + if (skipRecording) + { + additionalHeaders["x-recording-skip"] = "request-response"; + } + + variables ??= new Dictionary(); + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "record/stop", variables, token, additionalHeaders); + return await SendSyncOrAsync(true, message, token).ConfigureAwait(false); + } + + /// + /// Sets options for the proxy. + /// + /// The identifier for the playback/recording session. + /// The options to set. + /// The cancellation token to use. + /// The client result. + public virtual ProxyClientResult SetRecordingTransportOptions(string recordingId, ProxyServiceOptions options, CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(recordingId)) + { + throw new ArgumentException("Recording ID cannot be null, empty, or white space only"); + } + else if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "admin/setrecordingoptions", options, token, new() + { + [X_RECORDING_ID_HEADER] = recordingId, + }); + return SendSyncOrAsync(false, message, token).GetAwaiter().GetResult(); + } + + /// + /// Sets options for the proxy asynchronously. + /// + /// The identifier for the playback/recording session. + /// The options to set. + /// The cancellation token to use. + /// The client result. + public virtual async Task SetRecordingTransportOptionsAsync(string recordingId, ProxyServiceOptions options, CancellationToken token = default) + { + if (string.IsNullOrWhiteSpace(recordingId)) + { + throw new ArgumentException("Recording ID cannot be null, empty, or white space only"); + } + else if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "admin/setrecordingoptions", options, token, new() + { + [X_RECORDING_ID_HEADER] = recordingId, + }); + return await SendSyncOrAsync(true, message, token).ConfigureAwait(false); + } + + /// + /// Removes some pre-defined sanitizers to be used during recording/playback by specifying their IDs. + /// + /// The set of sanitizer IDs to remove. + /// (Optional) If specified, the sanitizers will be removed for a particular session only. + /// If null, the sanitizers will be removed globally on the test proxy. + /// The cancellation token to use. + /// The client result. + public virtual ProxyClientResult RemoveSanitizers(ISet sanitizerIds, string? recordingId = null, CancellationToken token = default) + { + if (sanitizerIds == null) + { + throw new ArgumentNullException(nameof(sanitizerIds)); + } + + Dictionary headers = new(); + if (recordingId != null) + { + headers[X_RECORDING_ID_HEADER] = recordingId; + } + + PipelineMessage message = CreateJsonRequest( + HttpMethod.Post, + "admin/removesanitizers", + new SanitizerIdList() { Sanitizers = sanitizerIds.ToArray() }, + token, + headers); + return SendSyncOrAsync(false, message, token).GetAwaiter().GetResult(); + } + + /// + /// Removes some pre-defined sanitizers to be used during recording/playback by specifying their IDs. + /// + /// The set of sanitizer IDs to remove. + /// (Optional) If specified, the sanitizers will be removed for a particular session only. + /// If null, the sanitizers will be removed globally on the test proxy. + /// The cancellation token to use. + /// The client result. + public virtual async Task RemoveSanitizersAsync(ISet sanitizerIds, string? recordingId = null, CancellationToken token = default) + { + if (sanitizerIds == null) + { + throw new ArgumentNullException(nameof(sanitizerIds)); + } + + Dictionary headers = new(); + if (recordingId != null) + { + headers[X_RECORDING_ID_HEADER] = recordingId; + } + + PipelineMessage message = CreateJsonRequest( + HttpMethod.Post, + "admin/removesanitizers", + new SanitizerIdList() { Sanitizers = sanitizerIds.ToArray() }, + token, + headers); + return await SendSyncOrAsync(true, message, token).ConfigureAwait(false); + } + + /// + /// Adds sanitizers for the recording test proxy. + /// + /// The sanitizers to add. + /// (Optional) If specified, the sanitizers will added for a particular session only. + /// If null, the sanitizers will be added globally on the test proxy. + /// The cancellation token to use. + /// The client result with the set of sanitizer IDs added. + public virtual ProxyClientResult> AddSanitizers(IEnumerable sanitizers, string? recordingId = null, CancellationToken token = default) + { + if (sanitizers == null) + { + throw new ArgumentNullException(nameof(sanitizers)); + } + + Dictionary headers = new(); + if (recordingId != null) + { + headers[X_RECORDING_ID_HEADER] = recordingId; + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "Admin/AddSanitizers", sanitizers, token, headers); + ProxyClientResult result = SendSyncOrAsync(false, message, token).GetAwaiter().GetResult(); + return new ProxyClientResult>( + result.Value.Sanitizers ?? Array.Empty(), + result.GetRawResponse()); + } + + /// + /// Adds sanitizers for the recording test proxy asynchronously. + /// + /// The sanitizers to add. + /// (Optional) If specified, the sanitizers will added for a particular session only. + /// If null, the sanitizers will be added globally on the test proxy. + /// The cancellation token to use. + /// The client result with the set of sanitizer IDs added. + public virtual async Task>> AddSanitizersAsync(IEnumerable sanitizers, string? recordingId = null, CancellationToken token = default) + { + if (sanitizers == null) + { + throw new ArgumentNullException(nameof(sanitizers)); + } + + Dictionary headers = new(); + if (recordingId != null) + { + headers[X_RECORDING_ID_HEADER] = recordingId; + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "Admin/AddSanitizers", sanitizers, token, headers); + ProxyClientResult result = await SendSyncOrAsync(true, message, token).ConfigureAwait(false); + return new ProxyClientResult>( + result.Value.Sanitizers ?? Array.Empty(), + result.GetRawResponse()); + } + + /// + /// Sets the matcher to use. + /// + /// The matcher to use. + /// (Optional) If specified, the matcher will be set for a particular session only. + /// If null, the matcher will be set globally on the test proxy. + /// The cancellation token to use. + /// The client result. + public virtual ProxyClientResult SetMatcher(BaseMatcher matcher, string? recordingId = null, CancellationToken token = default) + { + if (matcher == null) + { + throw new ArgumentNullException(nameof(matcher)); + } + + Dictionary headers = new() + { + ["x-abstraction-identifier"] = matcher.Type + }; + + if (recordingId != null) + { + headers[X_RECORDING_ID_HEADER] = recordingId; + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "admin/setmatcher", matcher, token, headers); + return SendSyncOrAsync(false, message, token).GetAwaiter().GetResult(); + } + + /// + /// Sets the matcher to use asynchronously. + /// + /// The matcher to use. + /// (Optional) If specified, the matcher will be set for a particular session only. + /// If null, the matcher will be set globally on the test proxy. + /// The cancellation token to use. + /// The client result. + public virtual async Task SetMatcherAsync(BaseMatcher matcher, string? recordingId = null, CancellationToken token = default) + { + if (matcher == null) + { + throw new ArgumentNullException(nameof(matcher)); + } + + Dictionary headers = new() + { + ["x-abstraction-identifier"] = matcher.Type + }; + + if (recordingId != null) + { + headers[X_RECORDING_ID_HEADER] = recordingId; + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "admin/setmatcher", matcher, token, headers); + return await SendSyncOrAsync(true, message, token).ConfigureAwait(false); + } + + /// + /// Adds a transform. + /// + /// The transform to add. + /// (Optional) If specified, the transform will be added for a particular session only. + /// If null, the transform will be added globally on the test proxy. + /// The cancellation token to use. + /// The client result. + public virtual ProxyClientResult AddTransform(BaseTransform transform, string? recordingId = null, CancellationToken token = default) + { + if (transform == null) + { + throw new ArgumentNullException(nameof(transform)); + } + + Dictionary headers = new() + { + ["x-abstraction-identifier"] = transform.Type + }; + + if (recordingId != null) + { + headers[X_RECORDING_ID_HEADER] = recordingId; + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "admin/addtransform", transform, token, headers); + return SendSyncOrAsync(false, message, token).GetAwaiter().GetResult(); + } + + /// + /// Adds a transform asynchronously. + /// + /// The transform to add. + /// (Optional) If specified, the transform will be added for a particular session only. + /// If null, the transform will be added globally on the test proxy. + /// The cancellation token to use. + /// The client result. + public virtual async Task AddTransformAsync(BaseTransform transform, string? recordingId = null, CancellationToken token = default) + { + if (transform == null) + { + throw new ArgumentNullException(nameof(transform)); + } + + Dictionary headers = new() + { + ["x-abstraction-identifier"] = transform.Type + }; + + if (recordingId != null) + { + headers[X_RECORDING_ID_HEADER] = recordingId; + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "admin/addtransform", transform, token, headers); + return await SendSyncOrAsync(true, message, token).ConfigureAwait(false); + } + + /// + /// Resets the sanitizers, matcher, and transforms to the default. + /// + /// (Optional) If specified, only the particular session will be reset. + /// If null, the reset will apply globally. + /// The cancellation token to use. + /// The client result. + public virtual ProxyClientResult Reset(string? recordingId = null, CancellationToken token = default) + { + Dictionary headers = new(); + if (recordingId != null) + { + headers[X_RECORDING_ID_HEADER] = recordingId; + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "Admin/Reset", null, token, headers); + return SendSyncOrAsync(false, message, token).GetAwaiter().GetResult(); + } + + /// + /// Resets the sanitizers, matcher, and transforms to the default asynchronously. + /// + /// (Optional) If specified, only the particular session will be reset. + /// If null, the reset will apply globally. + /// The cancellation token to use. + /// The client result. + public virtual async Task ResetAsync(string? recordingId = null, CancellationToken token = default) + { + Dictionary headers = new(); + if (recordingId != null) + { + headers[X_RECORDING_ID_HEADER] = recordingId; + } + + PipelineMessage message = CreateJsonRequest(HttpMethod.Post, "Admin/Reset", null, token, headers); + return await SendSyncOrAsync(true, message, token).ConfigureAwait(false); + } + + /// + /// Lists the available sanitizers, matchers, and transforms. + /// + /// The cancellation token. + /// The client result with the HTML returned from the service. + public virtual ProxyClientResult ListAvailable(CancellationToken token = default) + { + PipelineMessage message = CreateJsonRequest(HttpMethod.Get, "Info/Available", null, token); + ProxyClientResult result = SendSyncOrAsync(false, message, token).GetAwaiter().GetResult(); + return new ProxyClientResult(result.GetRawResponse().Content.ToString(), result.GetRawResponse()); + } + + /// + /// Lists the available sanitizers, matchers, and transforms asynchronously. + /// + /// The cancellation token. + /// The client result with the HTML returned from the service. + public virtual async Task> ListAvailableAsync(CancellationToken token = default) + { + PipelineMessage message = CreateJsonRequest(HttpMethod.Get, "Info/Available", null, token); + ProxyClientResult result = await SendSyncOrAsync(true, message, token).ConfigureAwait(false); + return new ProxyClientResult(result.GetRawResponse().Content.ToString(), result.GetRawResponse()); + } + + protected virtual PipelineMessage CreateJsonRequest(HttpMethod method, string path, TBody? body, CancellationToken token, Dictionary? headers = null) + { + PipelineMessage message = _pipeline.CreateMessage(); + message.Apply(new RequestOptions + { + CancellationToken = token, + BufferResponse = true + }); + + PipelineRequest request = message.Request; + request.Method = method.Method; + request.Uri = new Uri(_options.HttpEndpoint, path); + request.Headers.Add("Accept", "application/json"); + + if (headers != null) + { + foreach (var kvp in headers) + { + request.Headers.Add(kvp.Key, kvp.Value); + } + } + + if (body != null) + { + MemoryStream stream = new(); + using Utf8JsonWriter writer = new(stream); + JsonSerializer.Serialize(writer, body, Default.RecordingJsonOptions); + BinaryData jsonBody = BinaryData.FromBytes(new ReadOnlyMemory(stream.GetBuffer(), 0, (int)stream.Length)); + + request.Headers.Add("Content-Type", "application/json"); + request.Content = BinaryContent.Create(jsonBody); + } + + return message; + } + + protected virtual async ValueTask SendSyncOrAsync(bool isAsync, PipelineMessage message, CancellationToken token) + { + if (isAsync) + { + await _pipeline.SendAsync(message).ConfigureAwait(false); + } + else + { + _pipeline.Send(message); + } + + PipelineResponse response = message.Response ?? throw new ClientResultException("Response was null", message.Response); + if (response.IsError) + { + if (response.Content.ToMemory().Length > 0) + { + string contentType = response.Headers.GetFirstOrDefault("Content-Type") ?? string.Empty; + + if (contentType.StartsWith("text/", StringComparison.OrdinalIgnoreCase)) + { + string error = response.Content.ToString(); + throw new ClientResultException(error, response); + } + else if (contentType.StartsWith("application/json", StringComparison.OrdinalIgnoreCase)) + { + string error; + try + { + var parsed = response.Content.ToObjectFromJson(new() + { + PropertyNameCaseInsensitive = true + }); + + error = $"{parsed.Status}: {parsed.Message}"; + } + catch + { + error = response.Content.ToString(); + } + + throw new ClientResultException(error, response); + } + } + + throw new ClientResultException(response); + } + + return new ProxyClientResult(response); + } + + protected virtual async ValueTask> SendSyncOrAsync(bool isAsync, PipelineMessage message, CancellationToken token) + { + if (isAsync) + { + await SendSyncOrAsync(isAsync, message, token).ConfigureAwait(false); + } + else + { + SendSyncOrAsync(isAsync, message, token).GetAwaiter().GetResult(); + } + + PipelineResponse response = message.Response!; // we've already validated this is not null in the previous call + + try + { + TResponse? parsed = JsonSerializer.Deserialize(response.Content.ToMemory().Span, Default.TestProxyJsonOptions); + if (parsed == null) + { + throw new InvalidDataException("Response parsed to null"); + } + + return new ProxyClientResult(parsed, response); + } + catch (Exception ex) + { + throw new ClientResultException("Failed to deserialize response", message.Response, ex); + } + } + + private struct ErrorResponse + { + public string? Message { get; set; } + public string? Status { get; set; } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyClientOptions.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyClientOptions.cs new file mode 100644 index 000000000..bb35323c9 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyClientOptions.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; + +namespace OpenAI.TestFramework.Recording.RecordingProxy; + +/// +/// Options for the test proxy client. +/// +public class ProxyClientOptions : ClientPipelineOptions +{ + /// + /// Creates a new instance. + /// + /// The HTTP endpoint. + /// The endpoint was null. + /// The endpoint was not absolute. + public ProxyClientOptions(Uri http) + { + if (http == null) throw new ArgumentNullException(nameof(http)); + else if (!http.IsAbsoluteUri) throw new ArgumentException("URI must be absolute", nameof(http)); + + HttpEndpoint = http; + } + + /// + /// The HTTP endpoint to use + /// + public Uri HttpEndpoint { get; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyClientResult.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyClientResult.cs new file mode 100644 index 000000000..a1e16d300 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyClientResult.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Recording.RecordingProxy +{ + /// + /// Represents the result of a proxy client operation. + /// + public class ProxyClientResult : ClientResult + { + /// + /// Initializes a new instance of the class. + /// + /// (Optional) The pipeline response. + public ProxyClientResult(PipelineResponse? response = null) + { + if (response != null) + { + SetRawResponse(response); + } + } + + /// + /// Gets the recording ID from the response headers. + /// + public string? RecordingId => GetRawResponse().Headers.GetFirstOrDefault(ProxyClient.X_RECORDING_ID_HEADER); + } + + /// + /// Represents the result of a proxy client operation. + /// + /// The type of the result value. + public class ProxyClientResult : ProxyClientResult + { + /// + /// Initializes a new instance of the class. + /// + /// The result value. + /// (Optional) The pipeline response. + public ProxyClientResult(TResult value, PipelineResponse? response = null) + { + Value = value; + if (response != null) + { + SetRawResponse(response); + } + } + + /// + /// Gets the result value. + /// + public virtual TResult Value { get; } + + /// + /// Implicitly converts the to the result value. + /// + /// The instance. + /// The result value. + public static implicit operator TResult(ProxyClientResult result) => result.Value; + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyService.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyService.cs new file mode 100644 index 000000000..8d6460afd --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyService.cs @@ -0,0 +1,256 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Text; +using NUnit.Framework; +using OpenAI.TestFramework.Recording.RecordingProxy; +using OpenAI.TestFramework.Utils.Processes; + +namespace OpenAI.TestFramework.Recording.Proxy; + +/// +/// Represents the test proxy. See here for more information: +/// https://github.com/Azure/azure-sdk-tools/blob/main/tools/test-proxy/Azure.Sdk.Tools.TestProxy/README.md +/// +public class ProxyService : IDisposable +{ + private const int c_maxLines = 50; + + private Process _testProxyProcess; + private Uri? _http; + private Uri? _https; + private TaskCompletionSource<(int, int)> _portsAvailableTcs; + private StringBuilder _errorOutput; + private int _lines; + private ProxyClient? _client; + private WindowsJob? _windowsJob; + + /// + /// Creates a new instance. + /// + /// The options to use. + /// was null. + private ProxyService(ProxyServiceOptions options) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + options.Validate(); + + ProcessStartInfo startInfo = new() + { + FileName = options.DotnetExecutable, + Arguments = $@"""{options.TestProxyDll}"" start -u --storage-location=""{options.StorageLocationDir}""", + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false, + EnvironmentVariables = + { + ["ASPNETCORE_URLS"] = $"http://127.0.0.1:{options.HttpPort};https://127.0.0.1:{options.HttpsPort}", + ["Logging__LogLevel__Azure.Sdk.Tools.TestProxy"] = "Error", + ["Logging__LogLevel__Default"] = "Error", + ["Logging__LogLevel__Microsoft.AspNetCore"] = "Error", + ["Logging__LogLevel__Microsoft.Hosting.Lifetime"] = "Information", + } + }; + + if (options.DevCertFile != null) + { + startInfo.EnvironmentVariables["ASPNETCORE_Kestrel__Certificates__Default__Path"] = options.DevCertFile; + if (options.DevCertPassword != null) + { + startInfo.EnvironmentVariables["ASPNETCORE_Kestrel__Certificates__Default__Password"] = options.DevCertPassword; + } + } + + _errorOutput = new(); + _portsAvailableTcs = new(); + _testProxyProcess = new Process() + { + EnableRaisingEvents = true, + StartInfo = startInfo + }; + + _testProxyProcess.Exited += (_, _) => + { + _portsAvailableTcs.TrySetException(new InvalidOperationException("Test proxy process exited unexpectedly")); + }; + _testProxyProcess.ErrorDataReceived += HandleStdErr; + _testProxyProcess.OutputDataReceived += HandleStdOut; + + _windowsJob = null; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + // If running on Windows, use a Job to instruct the OS to kill the test proxy service process + // should this current process die for any reason. + _windowsJob = new($"TestProxy_{Process.GetCurrentProcess().Id}"); + } + } + + /// + /// Gets the client to use to communicate with this recording test proxy. + /// + public ProxyClient Client => _client + ?? throw new InvalidOperationException("Please wait for the proxy to finish starting first"); + + /// + /// Gets the HTTP endpoint the test recording proxy is listening on. + /// + public Uri HttpEndpoint => _http + ?? throw new InvalidOperationException("Please wait for the proxy to finish starting first"); + + /// + /// Gets the HTTPS endpoint the test recording proxy is listening on. + /// + public Uri HttpsEndpoint => _https + ?? throw new InvalidOperationException("Please wait for the proxy to finish starting first"); + + /// + /// Creates a new instance of the recording test proxy. + /// + /// The options to use for the proxy. + /// The cancellation token to use. + /// The initialized recording test proxy instance. + public static async Task CreateNewAsync(ProxyServiceOptions options, CancellationToken token = default) + { + token.ThrowIfCancellationRequested(); + + ProxyService proxy = new ProxyService(options); + + // Try to make sure the test proxy process is terminated when we exit + AppDomain.CurrentDomain.DomainUnload += (_, _) => proxy.Dispose(); + // TODO FIXME: On Windows, use a job to ensure the OS will properly kill the process + + await proxy.StartAsync(token).ConfigureAwait(false); + return proxy; + } + + /// + /// Tears down the recording test proxy instance. + /// + public void Dispose() + { + _portsAvailableTcs.TrySetException(new ObjectDisposedException(nameof(ProxyService))); + try + { + _testProxyProcess.Kill(); + if (_windowsJob != null) + { + // do NOT call Dispose here. This will terminate this process too. + } + } catch { /* we tried */ } + } + + /// + /// Checks to see if any errors were encountered in the test proxy, and if so throws an exception. + /// + /// If there were any errors encountered. + public void ThrowOnErrors() + { + lock (_errorOutput) + { + if (_errorOutput.Length > 0) + { + string error = _errorOutput.ToString(); + _errorOutput.Clear(); + throw new InvalidOperationException($"An error occurred in the test proxy:\n{error}"); + } + } + } + + /// + /// For testing purposes only + /// + /// The client to set. + internal void SetClient(ProxyClient client) + { + _client = client; + } + + /// + /// Starts the recording test proxy instance, and waits until we can read the ports it is listening on for + /// HTTP and HTTPS. + /// + /// The cancellation token to use. + /// Asynchronous tas + /// The test proxy failed to start, or we encountered some other error. + protected async Task StartAsync(CancellationToken token = default) + { + token.Register(_portsAvailableTcs.SetCanceled); + + bool success = _testProxyProcess.Start(); + if (!success) + { + throw new InvalidOperationException("The test proxy process failed to start"); + } + + _windowsJob?.Add(_testProxyProcess); + + _testProxyProcess.BeginOutputReadLine(); + _testProxyProcess.BeginErrorReadLine(); + + await _portsAvailableTcs.Task.ConfigureAwait(false); + } + + private static Uri? ParseListeningOnUri(string line) + { + const string nowListeningOn = "Now listening on: "; + int index = line.IndexOf(nowListeningOn, StringComparison.OrdinalIgnoreCase); + if (index < 0) + { + return null; + } + + Uri.TryCreate(line.AsSpan().Slice(index + nowListeningOn.Length).Trim().ToString(), UriKind.Absolute, out Uri? uri); + return uri; + } + + private void HandleStdErr(object sender, DataReceivedEventArgs args) + { + if (args?.Data != null) + { + lock (_errorOutput) + { + _errorOutput.Append(args.Data); + } + + TestContext.Progress.WriteLine(args.Data); + } + } + + private void HandleStdOut(object sender, DataReceivedEventArgs args) + { + if (_lines++ >= c_maxLines) + { + _portsAvailableTcs.TrySetException(new InvalidOperationException( + $"Failed to start the test proxy. One or both the ports was not populated. http: {_http}, https: {_https}")); + _testProxyProcess.OutputDataReceived -= HandleStdOut; + return; + } + else if (args?.Data == null) + { + return; + } + + Uri? uri = ParseListeningOnUri(args.Data); + if (_http == null && uri?.Scheme == "http") + { + _http = uri; + _client = new ProxyClient(new ProxyClientOptions(_http!)); + } + else if (_https == null && uri?.Scheme == "https") + { + _https = uri; + } + + if (_http != null && _https != null) + { + _testProxyProcess.OutputDataReceived -= HandleStdOut; + _portsAvailableTcs.TrySetResult((_http.Port, _https.Port)); + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyServiceOptions.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyServiceOptions.cs new file mode 100644 index 000000000..2f3e3d27f --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyServiceOptions.cs @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording.Proxy; + +/// +/// Options for starting the recording test proxy. +/// +public class ProxyServiceOptions +{ + /// + /// Gets the full path to the dotnet executable. + /// + required public string DotnetExecutable { get; set; } + + /// + /// Gets the full path to the test proxy DLL. + /// + required public string TestProxyDll { get; set; } + + /// + /// The path to the directory to store or read recordings from. + /// + required public string StorageLocationDir { get; set; } + + /// + /// (Optional) The file to use for the HTTPS endpoint certificate. + /// + public string? DevCertFile { get; set; } + + /// + /// (Optional) The password to use for opening the for the HTTPS endpoint. + /// + public string? DevCertPassword { get; set; } + + /// + /// (Optional) The HTTP port the test proxy should listen on. Set this to 0 to have the next available port be automatically selected. + /// + public ushort HttpPort { get; set; } + + /// + /// (Optional) The HTTPS port the test proxy should listen on. Set this to 0 to have the next available port be automatically selected. + /// + public ushort HttpsPort { get; set; } + + /// + /// Validates the configuration. + /// + /// The storage location directory was could not be found. + /// The HTTPS certificate file could not be found. + /// No password was specified for the developer certificate file. + internal protected virtual void Validate() + { + List exceptions = new(); + + if (!File.Exists(DotnetExecutable)) + { + exceptions.Add(new FileNotFoundException("Could not find (or read from) the dotnet executable: " + DotnetExecutable)); + } + else if (!File.Exists(TestProxyDll)) + { + exceptions.Add(new FileNotFoundException("Could not find (or read from) the test proxy DLL: " + TestProxyDll)); + } + else if (!Directory.Exists(StorageLocationDir)) + { + exceptions.Add(new DirectoryNotFoundException("Could not find (or read from) the following directory: " + StorageLocationDir)); + } + else if (DevCertFile != null && !File.Exists(DevCertFile)) + { + exceptions.Add(new FileNotFoundException("Could not find (or read from) the HTTPS certificate file: " + DevCertFile)); + } + else if (DevCertFile != null && DevCertPassword == null) + { + exceptions.Add(new InvalidOperationException($"You must set the {nameof(DevCertPassword)} property if you specify the {nameof(DevCertFile)}")); + } + + if (exceptions.Any()) + { + throw new AggregateException("The test proxy service configuration is invalid", exceptions); ; + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyTransport.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyTransport.cs new file mode 100644 index 000000000..69cb9a8fc --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyTransport.cs @@ -0,0 +1,216 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Net.Http; +using System.Text.Json; + +namespace OpenAI.TestFramework.Recording.RecordingProxy; + +/// +/// Implements a that will redirect all HTTP/HTTPS requests to the test proxy for recording or playback. +/// Depending on the mode, the test proxy will then either forward the request to the upstream service and record the request and response, +/// or playback the response from a previous recording. +/// +public class ProxyTransport : PipelineTransport +{ + private const string DevCertIssuer = "CN=localhost"; + private const string FiddlerCertIssuer = "CN=DO_NOT_TRUST_FiddlerRoot, O=DO_NOT_TRUST, OU=Created by http://www.fiddler2.com"; + private const string FiddlerHost = "ipv4.fiddler"; + + private readonly ProxyTransportOptions _options; + + /// + /// Initializes a new instance of the class. + /// + /// The options for the proxy transport. + public ProxyTransport(ProxyTransportOptions options) + { + _options = options ?? throw new ArgumentNullException(nameof(options)); + + string certIssuer; + if (_options.UseFiddler) + { + certIssuer = FiddlerCertIssuer; + } + else + { + certIssuer = DevCertIssuer; + } + + HttpClientHandler handler = new() + { + ServerCertificateCustomValidationCallback = (_, certificate, _, _) => certificate?.Issuer == certIssuer, + UseCookies = _options.AllowCookies, + AllowAutoRedirect = _options.AllowAutoRedirect + }; + + InnerTransport = new HttpClientPipelineTransport(new HttpClient(handler)); + } + + /// + /// The actual transport to use for sending requests, and receiving responses. + /// + protected PipelineTransport InnerTransport { get; } + + /// + protected override PipelineMessage CreateMessageCore() + { + Exception? ex = _options.MismatchException?.GetValue(); + if (ex != null) + { + throw ex; + } + + PipelineMessage message = InnerTransport.CreateMessage(); + PipelineRequest request = message.Request; + + // PipelineRequest no longer has a ClientRequestId property, so we need to set it on the headers directly + request.Headers.Add("x-ms-client-request-id", _options.RequestId); + + return message; + } + + /// + protected override void ProcessCore(PipelineMessage message) + => ProcessCoreSyncOrAsync(message, async: false).GetAwaiter().GetResult(); + + /// + protected override ValueTask ProcessCoreAsync(PipelineMessage message) + => ProcessCoreSyncOrAsync(message, async: true); + + /// + /// Processes the pipeline message synchronously or asynchronously. + /// + /// The pipeline message to process. + /// A flag indicating whether to process asynchronously. + /// A representing the asynchronous operation. + protected virtual async ValueTask ProcessCoreSyncOrAsync(PipelineMessage message, bool async) + { + try + { + RedirectToTestProxy(message); + if (async) + { + await InnerTransport.ProcessAsync(message).ConfigureAwait(false); + } + else + { + InnerTransport.Process(message); + } + + await ProcessResponseSyncAsync(message, async).ConfigureAwait(false); + } + finally + { + // revert the original URI - this is important for tests that rely on aspects of the URI in the pipeline + // e.g. KeyVault caches tokens based on URI + message.Request.Headers.TryGetValue("x-recording-upstream-base-uri", out string? original); + if (message.Request.Uri is null) + { + throw new InvalidOperationException("The request cannot have a null URI"); + } + if (original == null) + { + throw new InvalidOperationException("The TestProxy response did not contain the expected \"x-recording-upstream-base-uri\" header"); + } + + var originalBaseUri = new Uri(original); + var builder = new UriBuilder(message.Request.Uri); + builder.Scheme = originalBaseUri.Scheme; + builder.Host = originalBaseUri.Host; + builder.Port = originalBaseUri.Port; + + message.Request.Uri = builder.Uri; + } + } + + /// + /// Processes the response synchronously or asynchronously. + /// + /// The pipeline message containing the response. + /// A flag indicating whether to process asynchronously. + /// A representing the asynchronous operation. + protected virtual async ValueTask ProcessResponseSyncAsync(PipelineMessage message, bool async) + { + if (message.Response?.Headers.TryGetValues("x-request-mismatch", out _) == true) + { + if (message.Response.ContentStream == null) + { + throw new TestRecordingMismatchException("Detected a mismatch but the response had no body"); + } + + using var doc = async + ? await JsonDocument.ParseAsync(message.Response.ContentStream).ConfigureAwait(false) + : JsonDocument.Parse(message.Response.ContentStream); + throw new TestRecordingMismatchException(doc.RootElement.GetProperty("Message").GetString(), null); + } + } + + // copied from https://github.com/Azure/azure-sdk-for-net/blob/main/common/Perf/Azure.Test.Perf/TestProxyPolicy.cs + /// + /// Redirects the pipeline message to the test proxy based on the recording mode. + /// + /// The pipeline message to redirect. + protected virtual void RedirectToTestProxy(PipelineMessage message) + { + if (_options.Mode == RecordedTestMode.Record) + { + switch (_options.ShouldRecordRequest(message.Request)) + { + case RequestRecordMode.Record: + break; + case RequestRecordMode.RecordWithoutRequestBody: + message.Request.Headers.Set("x-recording-skip", "request-body"); + break; + case RequestRecordMode.DoNotRecord: + message.Request.Headers.Set("x-recording-skip", "request-response"); + break; + } + } + else if (_options.Mode == RecordedTestMode.Playback) + { + switch (_options.ShouldRecordRequest(message.Request)) + { + case RequestRecordMode.Record: + break; + case RequestRecordMode.RecordWithoutRequestBody: + // CAUTION: setting the request content to null has the unfortunate side effect of causing any HttpClient backed + // implementation of networking to not send up any Content-??? headers as well which can cause test + // mismatches. Let's work around this by setting some empty content. + message.Request.Content = BinaryContent.Create(BinaryData.FromBytes(Array.Empty())); + break; + case RequestRecordMode.DoNotRecord: + throw new InvalidOperationException( + "Cannot playback when recording has been disabled. Please make sure to skip the test or request."); + } + } + + var request = message.Request; + request.Headers.Set("x-recording-id", _options.RecordingId); + request.Headers.Set("x-recording-mode", _options.Mode.ToString().ToLowerInvariant()); + + if (request.Uri is null) + { + throw new InvalidOperationException("Request URI cannot be null"); + } + + // Intentionally reset the upstream URI in case the request URI changes between retries - e.g. when using GeoRedundant secondary Storage + var builder = new UriBuilder() + { + Scheme = request.Uri.Scheme, + Host = request.Uri.Host, + Port = request.Uri.Port, + }; + request.Headers.Set("x-recording-upstream-base-uri", builder.ToString()); + + Uri baseUri = request.Uri.Scheme == "https" ? _options.HttpsEndpoint : _options.HttpEndpoint; + + builder = new(request.Uri); + builder.Host = _options.UseFiddler ? FiddlerHost : baseUri.Host; + builder.Port = baseUri.Port; + + request.Uri = builder.Uri; + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyTransportOptions.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyTransportOptions.cs new file mode 100644 index 000000000..a0087e850 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/ProxyTransportOptions.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Recording.RecordingProxy; + +/// +/// The options for the recording test proxy transport. +/// +public class ProxyTransportOptions +{ + private Func? _shouldRecordRequest; + + /// + /// Gets or sets the test proxy HTTP endpoint. + /// + required public Uri HttpEndpoint { get; set; } + + /// + /// Gets or sets the test proxy HTTPS endpoint. + /// + required public Uri HttpsEndpoint { get; set; } + + /// + /// Gets or sets the current test recording mode. + /// + required public RecordedTestMode Mode { get; set; } + + /// + /// Gets or sets the identifier for the recording. + /// + required public string RecordingId { get; set; } + + /// + /// The ID for the request. Please make sure that a consistent ID is used during recording and playback to avoid + /// mismatches. + /// + required public string RequestId { get; set; } + + /// + /// Gets or sets the delegate used to get/set the test recording mismatch exception. + /// + public PropertyDelegate? MismatchException { get; set; } + + /// + /// Gets or sets a value indicating whether to use Fiddler. If this is true, the transport will be updated to accept + /// the Fiddler root certificate. + /// + public bool UseFiddler { get; set; } + + /// + /// Gets or sets the predicate used to determine whether or not a particular request should not be recorded. + /// Default behaviour is to defer to what the matchers/sanitizers do. + /// + public Func ShouldRecordRequest + { + get => _shouldRecordRequest ?? (_ => RequestRecordMode.Record); + set => _shouldRecordRequest = value; + } + + /// + /// Gets or sets a value indicating whether to allow cookies while sending and receiving requests. + /// + public bool AllowCookies { get; set; } + + /// + /// Gets or sets a value indicating whether to allow auto redirect when processing server responses. + /// + public bool AllowAutoRedirect { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/RequestRecordMode.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/RequestRecordMode.cs new file mode 100644 index 000000000..d8a782327 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/RequestRecordMode.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording.RecordingProxy; + +/// +/// Enumeration of possible values of how to record a request. This acts as an override. +/// +public enum RequestRecordMode +{ + /// + /// Records the request. + /// + Record, + /// + /// Records the request headers but skips the request body. + /// + RecordWithoutRequestBody, + /// + /// Does not record the request (nor the response). + /// + DoNotRecord, +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/PemPair.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/PemPair.cs new file mode 100644 index 000000000..15c72d94e --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/PemPair.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording.Proxy.Service; + +/// +/// Information about certificates for the test proxy service. +/// +public class PemPair +{ + /// Gets or sets the pem value. + public string? PemValue { get; set; } + /// Gets or sets the pem key. + public string? PemKey { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/ProxyServiceRecordingOptions.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/ProxyServiceRecordingOptions.cs new file mode 100644 index 000000000..449e90926 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/ProxyServiceRecordingOptions.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording.Proxy.Service; + +/// +/// Options for the test proxy. +/// +public class ProxyServiceRecordingOptions +{ + /// + /// Whether or not to follow redirects + /// + public bool? HandleRedirects { get; set; } + + /// + /// If set, this will change the "root" path the test proxy uses when loading a recording. + /// + public string? ContextDirectory { get; set; } + + /// + /// Options for the transport. + /// + public ProxyServiceTransportCustomizations? Transport { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/ProxyServiceTransportCustomizations.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/ProxyServiceTransportCustomizations.cs new file mode 100644 index 000000000..6be1ba257 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/ProxyServiceTransportCustomizations.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Recording.Proxy.Service; + +/// +/// Transport customizations for the test proxy service. +/// +public class ProxyServiceTransportCustomizations() +{ + /// Gets or sets the allow auto redirect. + public bool? AllowAutoRedirect { get; set; } + + /// + /// If specified, the public key contained here will be used during validation of the SSL connection by + /// comparing thumbprints. + /// + public string? TLSValidationCert { get; set; } + + /// + /// If specified, the will only be applied to the specified host. + /// + public string? TSLValidationCertHost { get; set; } + + /// + /// Each certificate pair contained within this list should be added to the clientHandler for the server + /// or an individual recording. + /// + public IList? Certificates { get; set; } + + /// + /// During playback, a response is normally returned all at once. By offering this response time, we can + /// "stretch" the writing of the response bytes over a time range of milliseconds. + /// + [JsonConverter(typeof(TimespanToMillisecondConverter))] + public TimeSpan? PlaybackResponseTime { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/RecordingStartInformation.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/RecordingStartInformation.cs new file mode 100644 index 000000000..e3526592a --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/RecordingStartInformation.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace OpenAI.TestFramework.Recording.Proxy.Service; + +/// +/// Information for starting a recording or playback session with the recording test proxy. +/// +public class RecordingStartInformation +{ + /// + /// Gets or sets the file to save recordings to, or to play back requests from. + /// + [JsonPropertyName("x-recording-file")] + required public string RecordingFile { get; set; } + + /// + /// Gets or sets the path to the "assets.json" file to use for integration with external Git + /// repositories. This enables the proxy to work against repositories that do not emplace their + /// test recordings directly alongside their test implementations. + /// + /// + /// Please refer to the documentation for more information: + /// https://github.com/Azure/azure-sdk-tools/blob/main/tools/test-proxy/documentation/asset-sync/README.md + /// + [JsonPropertyName("x-recording-assets-file")] + public string? AssetsFile { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/SanitizerIdList.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/SanitizerIdList.cs new file mode 100644 index 000000000..f0982542b --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Proxy/Service/SanitizerIdList.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording.Proxy.Service; + +/// +/// Request to remove sanitizers for the test proxy. +/// +public struct SanitizerIdList +{ + /// + /// The IDs of the sanitizers to remove. + /// + public string[]? Sanitizers { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/BaseRegexSanitizer.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/BaseRegexSanitizer.cs new file mode 100644 index 000000000..3c52f35b8 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/BaseRegexSanitizer.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording.Sanitizers; + +/// +/// The case class for regex based sanitizers +/// +public abstract class BaseRegexSanitizer(string type) : BaseSanitizer(type) +{ + /// + /// Gets the regular expression to match what to replace. + /// + public string? Regex { get; set; } + + /// + /// Gets or sets the value to replace the match with. + /// + public string? Value { get; set; } + + /// + /// Gets or sets the group in the regex match to replace. + /// + public string? GroupForReplace { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/BaseSanitizer.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/BaseSanitizer.cs new file mode 100644 index 000000000..011145bf5 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/BaseSanitizer.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Text.Json; +using System.Text.Json.Serialization; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Recording.Sanitizers; + +/// +/// The base class for all test proxy recording sanitizers +/// +public abstract class BaseSanitizer : IUtf8JsonSerializable +{ + /// + /// Creates a new instance. + /// + /// The type of this sanitizer (e.g. GeneralRegexSanitizer). + /// If the type was null. + protected BaseSanitizer(string type) + { + Type = type ?? throw new ArgumentNullException(nameof(Type)); + } + + /// + /// Gets the type of the sanitizer (e.g. HeaderRegexSanitizer). + /// + [JsonIgnore] + public string Type { get; } + + /// + public void Write(Utf8JsonWriter writer, JsonSerializerOptions? options = null) + { + writer.WriteStartObject(); + { + writer.WriteString("Name"u8, Type); + writer.WritePropertyName("Body"u8); + + SerializeInner(writer, options); + } + writer.WriteEndObject(); + } + + /// + /// Serializes the child types. By default this will use reflection based serialization. + /// + /// The writer to write to. + protected virtual void SerializeInner(Utf8JsonWriter writer, JsonSerializerOptions? options = null) + { + // By default use reflection based serialization + JsonSerializer.Serialize(writer, this, GetType(), Default.InnerRecordingJsonOptions); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/BodyKeySanitizer.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/BodyKeySanitizer.cs new file mode 100644 index 000000000..87b3f67e1 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/BodyKeySanitizer.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording.Sanitizers; + +/// +/// Sanitizer for a request body that matches a particular value in JSON using a JPath expression. +/// +public class BodyKeySanitizer : BaseRegexSanitizer +{ + /// + /// Creates a new instance. + /// + /// The JSON path to match. + /// If the JSON path is null. + public BodyKeySanitizer(string jsonPath) : base("BodyKeySanitizer") + { + JsonPath = jsonPath ?? throw new ArgumentNullException(nameof(jsonPath)); + } + + /// + /// The JPath expression to match a particular value to sanitize. + /// + public string JsonPath { get; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/BodyRegexSanitizer.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/BodyRegexSanitizer.cs new file mode 100644 index 000000000..e49b6f625 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/BodyRegexSanitizer.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording.Sanitizers; + +/// +/// Sanitizer for the body of a request or response. +/// +public class BodyRegexSanitizer : BaseRegexSanitizer +{ + /// + /// Creates a new instance. + /// + /// Gets the regular expression to match what to replace. + /// If was null. + public BodyRegexSanitizer(string regex) : base("BodyRegexSanitizer") + { + Regex = regex ?? throw new ArgumentNullException(nameof(regex)); + } + + /// + /// Condition to apply for the sanitization or transform. If the condition is not met, sanitization is not performed. + /// + public Condition? Condition { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/HeaderRegexSanitizer.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/HeaderRegexSanitizer.cs new file mode 100644 index 000000000..d1a76fc04 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/HeaderRegexSanitizer.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording.Sanitizers; + +/// +/// Sanitizer for a request header. +/// +public class HeaderRegexSanitizer : BaseRegexSanitizer +{ + /// + /// Creates a new instance. + /// + /// The header to sanitize. + /// If the is null. + public HeaderRegexSanitizer(string key) : base("HeaderRegexSanitizer") + { + Key = key ?? throw new ArgumentNullException(nameof(key)); + } + + /// + /// The name of the header to sanitize. + /// + public string Key { get; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/UriRegexSanitizer.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/UriRegexSanitizer.cs new file mode 100644 index 000000000..3c5bad68e --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Sanitizers/UriRegexSanitizer.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording.Sanitizers; + +/// +/// Sanitizer for a request URI. +/// +public class UriRegexSanitizer : BaseRegexSanitizer +{ + /// + /// Creates a new instance. + /// + /// The regular expression to match in the request URI. + /// If the regular expression is null. + public UriRegexSanitizer(string regex) : base("UriRegexSanitizer") + { + Regex = regex ?? throw new ArgumentNullException(nameof(regex)); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/TestRandom.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/TestRandom.cs new file mode 100644 index 000000000..d3b6bccb6 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/TestRandom.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording; + +/// +/// Represents an implementation of the class used for test recordings. +/// +public class TestRandom : Random +{ + private RecordedTestMode _mode; + + /// + /// Initializes a new instance of the class. + /// + /// The recorded test mode. + /// The seed value. + public TestRandom(RecordedTestMode mode, int seed) : base(seed) + { + _mode = mode; + } + + /// + /// Generates a new based on the recorded test mode. + /// + /// A new . + public Guid NewGuid() + { + if (_mode == RecordedTestMode.Live) + { + return Guid.NewGuid(); + } + + var bytes = new byte[16]; + NextBytes(bytes); + return new Guid(bytes); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/TestRecording.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/TestRecording.cs new file mode 100644 index 000000000..d58573f9e --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/TestRecording.cs @@ -0,0 +1,250 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Globalization; +using System.Security.Cryptography; +using OpenAI.TestFramework.Recording.Matchers; +using OpenAI.TestFramework.Recording.Proxy; +using OpenAI.TestFramework.Recording.RecordingProxy; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Recording; + +/// +/// Represents a test recording session. This is used to record or playback requests and responses. It also provides +/// a random generator that is consistent between recording and playback sessions. +/// +public class TestRecording : IAsyncDisposable +{ + /// + /// The key to use to store the random seed in the recording. + /// + public const string RandomSeedVariableKey = "RandomSeed"; + + private SortedDictionary _variables; + + /// + /// Creates a new instance. + /// + /// The unique identifier for the recording. + /// The current recording mode. + /// The test proxy service instance to use for the recording. + /// (Optional) Any variables populate this recording this. This is normally used in + /// playback mode to pass in any variables saved as part of the recording. + /// Any of the required parameters are null. + /// Some expected values were missing or null. + /// The current recording mode is not supported. + public TestRecording(string id, RecordedTestMode mode, ProxyService proxy, IDictionary? variables = null) + { + ID = id ?? throw new ArgumentNullException(nameof(id)); + Mode = mode; + Proxy = proxy ?? throw new ArgumentNullException(nameof(proxy)); + _variables = variables == null + ? new() + : new(variables); + + if (Proxy.Client == null) + { + throw new InvalidOperationException("Recording test proxy did not have a client defined"); + } + + int seed; + switch (Mode) + { + case RecordedTestMode.Live: + Random = new TestRandom(Mode, GetRandomSeed()); + break; + + case RecordedTestMode.Record: + seed = GetRandomSeed(); + _variables[RandomSeedVariableKey] = seed.ToString(CultureInfo.InvariantCulture); + Random = new TestRandom(Mode, seed); + break; + + case RecordedTestMode.Playback: + if (Variables.TryGetValue(RandomSeedVariableKey, out string? seedString) + && int.TryParse(seedString, NumberStyles.Integer, CultureInfo.InvariantCulture, out seed)) + { + Random = new TestRandom(Mode, seed); + } + else + { + // To maximise backwards compatibility with the recordings from the previous test framework, we'll just use a random + // seed if one wasn't set instead of failing here. Worst case, we'll get recording mismatches if this is not configured + // correctly. + Random = new TestRandom(Mode, GetRandomSeed()); + } + break; + + default: + throw new NotSupportedException("Unsupported recording mode: " + Mode); + } + } + + /// + /// Gets the unique identifier for this recording. + /// + public string ID { get; } + + /// + /// Gets the current recording mode. + /// + public RecordedTestMode Mode { get; } + + /// + /// Gets the random generator to use for this recording. Using this ensures consistent random values generated during + /// recording, as well as during playback. + /// + public TestRandom Random { get; } + + /// + /// Gets the proxy service associated with the recording. + /// + protected internal ProxyService Proxy { get; } + + /// + /// Gets any variables associated with the recording. + /// + protected IReadOnlyDictionary Variables => _variables; + + /// + /// Disposes of the recording session. If you were recording, this will try to save your captured requests and + /// responses. If you were playing back, this will stop the playback session. + /// + /// Asynchronous task + public virtual ValueTask DisposeAsync() => FinishAsync(true); + + /// + /// Finishes the recording session. This will stop recording or playback. If you were recording, you can use + /// to determine whether or not captured requests and responses will be saved. + /// + /// True to save any captured requests and responses to the file specified in your + /// . False to not save. This is only used if + /// you were recording. + /// The cancellation token to use. + /// Asynchronous task + /// If the recording mode is not supported. + public async virtual ValueTask FinishAsync(bool save, CancellationToken token = default) + { + switch (Mode) + { + case RecordedTestMode.Live: + // nothing to see here, move along + break; + case RecordedTestMode.Playback: + await Proxy.Client.StopPlaybackAsync(ID, token).ConfigureAwait(false); + break; + case RecordedTestMode.Record: + await Proxy.Client.StopRecordingAsync(ID, _variables, !save, token).ConfigureAwait(false); + break; + default: + throw new NotSupportedException("The following mode is not supported: " + Mode); + } + + Proxy.ThrowOnErrors(); + } + + /// + /// Gets a recorded variable. + /// + /// The name of the variable. + /// The variable value, or null if the variable was not set. + public virtual string? GetVariable(string name) + { + return _variables.GetValueOrDefault(name); + } + + /// + /// Sets a recorded variable to a value. + /// + /// The name of the variable. + /// The value to set. + public virtual void SetVariable(string name, string value) + { + _variables[name] = value; + } + + /// + /// Gets a recorded variable, or if it was not set, creates and adds a new variable. + /// + /// The name of the variable. + /// The factory used to create a value if none was previously set. + /// The already existing value, or the newly added value. + public virtual string GetOrAddVariable(string name, Func valueFactory) + { + string? value; + if (!_variables.TryGetValue(name, out value) || value == null) + { + value = valueFactory(); + SetVariable(name, value); + } + + return value; + } + + /// + /// Gets the options to use as the options for creating transport to pass to clients. This will allow the clients to + /// forward requests to the test proxy. + /// + /// The options to use. + public virtual ProxyTransportOptions GetProxyTransportOptions() + { + return new() + { + HttpEndpoint = Proxy.HttpEndpoint, + HttpsEndpoint = Proxy.HttpsEndpoint, + Mode = Mode, + RecordingId = ID, + RequestId = Random.NewGuid().ToString() + }; + } + + /// + /// Applies recording options to the current recording. + /// + /// The recording options to apply for this recording/playback session. + /// The cancellation token to use. + /// Asynchronous task + public virtual async Task ApplyOptions(TestRecordingOptions options, CancellationToken token) + { + if (options.Sanitizers.Any()) + { + await Proxy.Client.AddSanitizersAsync(options.Sanitizers, ID, token).ConfigureAwait(false); + } + + if (options.SanitizersToRemove.Any()) + { + await Proxy.Client.RemoveSanitizersAsync(options.SanitizersToRemove, ID, token).ConfigureAwait(false); + } + + if (Mode == RecordedTestMode.Playback) + { + BaseMatcher matcher = options.Matcher ?? new CustomMatcher() + { + CompareBodies = options.CompareBodies, + ExcludedHeaders = options.ExcludedHeaders.JoinOrNull(","), + IgnoredHeaders = options.IgnoredHeaders.JoinOrNull(","), + IgnoredQueryParameters = options.IgnoredQueryParameters.JoinOrNull(","), + }; + + await Proxy.Client.SetMatcherAsync(matcher, ID, token).ConfigureAwait(false); + + foreach (var transform in options.Transforms) + { + await Proxy.Client.AddTransformAsync(transform, ID, token).ConfigureAwait(false); + } + } + } + + private static int GetRandomSeed() + { +#if NET6_0_OR_GREATER + return RandomNumberGenerator.GetInt32(int.MaxValue); +#else + byte[] bytes = new byte[4]; + using var rng = RandomNumberGenerator.Create(); + rng.GetBytes(bytes); + return BitConverter.ToInt32(bytes, 0); +#endif + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/TestRecordingMismatchException.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/TestRecordingMismatchException.cs new file mode 100644 index 000000000..3f6af0242 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/TestRecordingMismatchException.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Runtime.Serialization; + +namespace OpenAI.TestFramework.Recording; + +/// +/// Exception thrown when the test recording does not match during playback. +/// +[Serializable] +public class TestRecordingMismatchException : Exception +{ + /// + /// Creates a new instance + /// + public TestRecordingMismatchException() + { + } + + /// + /// Creates a new instance. + /// + /// The exception message. + public TestRecordingMismatchException(string message) : base(message) + { + } + + /// + /// Creates a new instance. + /// + /// The exception message. + /// The inner exception. + public TestRecordingMismatchException(string? message, Exception? innerException = null) : base(message, innerException) + { + } + +#if !NET8_0_OR_GREATER + /// + protected TestRecordingMismatchException(SerializationInfo info, StreamingContext context) : base(info, context) + { + } +#endif +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/TestRecordingOptions.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/TestRecordingOptions.cs new file mode 100644 index 000000000..de97b0d89 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/TestRecordingOptions.cs @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using OpenAI.TestFramework.Recording.Matchers; +using OpenAI.TestFramework.Recording.RecordingProxy; +using OpenAI.TestFramework.Recording.Sanitizers; +using OpenAI.TestFramework.Recording.Transforms; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Recording; + +/// +/// Options to configure a test recording. This can be used to set sanitizers to apply to the URI, headers, and/or body of a request +/// before matching, and before saving the recording. This can also be used to specify which matcher will be used to match a request +/// to a recorded one during playback. Finally this can be used to set the transforms applied to responses from the test proxy. +/// +public class TestRecordingOptions +{ + /// + /// Creates a new instance + /// + public TestRecordingOptions() + { } + + /// + /// The list of sanitizers to apply to request before matching, and before saving a recording. + /// + public IList Sanitizers { get; } = new List(); + + /// + /// Gets or sets the matcher to use. If this is unset, a custom matcher will be created based on the options specified in this class. + /// + public BaseMatcher? Matcher { get; set; } + + /// + /// The list of transforms to apply when returning a response during playback. + /// + public IList Transforms { get; } = new List(); + + /// + /// The sanitizers to remove from the list of default sanitizers. More details about default sanitizers can be found here: + /// https://github.com/Azure/azure-sdk-tools/blob/main/tools/test-proxy/Azure.Sdk.Tools.TestProxy/README.md#removing-a-sanitizer. + /// + /// You can find the list of sanitizer IDs to remove in two ways: + /// + /// Sending a GET request to http://{proxy_endpoint}/Info/Active + /// Looking at the source code for the test proxy here: + /// https://github.com/Azure/azure-sdk-tools/blob/main/tools/test-proxy/Azure.Sdk.Tools.TestProxy/Common/SanitizerDictionary.cs + /// + /// + public ISet SanitizersToRemove { get; } = new HashSet() + { + // For now, we should leave the default sanitizers in place since it is better to err on the side of caution + }; + + /// + /// Query parameters that we are only interested in checking if a value is set, but don't care about the actual value set. + /// + public ISet IgnoredQueryParameters { get; } = new HashSet(); + + /// + /// Headers that we are only interested in checking if a value is set, but don't care about the actual value set. + /// + public ISet IgnoredHeaders { get; } = new HashSet() + { + "Date", + "x-ms-date", + "User-Agent", + }; + + /// + /// Headers to completely disregard when recording and matching. In other words it is as if these headers were never set. + /// + public ISet ExcludedHeaders { get; } = new HashSet() + { +#if NETFRAMEWORK + // .Net framework will add some headers not found in newer .Net versions so let's completely ignore them here. It is also + // different in how it handles setting the Content-Length header when there is no body as compared to .Net + "Connection", + "Content-Length", +#endif + }; + + /// + /// Whether or not we want to compare bodies from the request and the recorded request during playback. Default + /// is true. + /// + public bool CompareBodies { get; set; } = true; + + /// + /// A function used to override if recording is enabled for a particular request. This will override other settings present + /// here. + /// + public Func? RequestOverride { get; set; } + + /// + /// Helper method to simplify sanitizing specific headers values. This will add a entry + /// to . The default replacement value will be set to . + /// + /// The keys to sanitize. + public void SanitizeHeaders(params string[] keys) + => SanitizeHeaders(Default.SanitizedValue, keys); + + /// + /// Helper method to simplify sanitizing specific headers values. This will add a entry + /// to . + /// + /// The value to replace matches with. + /// The keys to sanitize. + public virtual void SanitizeHeaders(string sanitizedValue, IEnumerable keys) + { + if (keys == null) + { + return; + } + + foreach (var key in keys) + { + Sanitizers.Add(new HeaderRegexSanitizer(key) { Value = sanitizedValue }); + } + } + + /// + /// Helper method to sanitize specific parts of a JSON request body. This will add a entry + /// to for each JSON path provided in . The default replacement value + /// will be set to . + /// + /// The JSON paths to sanitize. + public void SanitizeJsonBody(params string[] jsonPaths) + => SanitizeJsonBody(Default.SanitizedValue, jsonPaths); + + /// + /// Helper method to sanitize specific parts of a JSON request body. This will add a entry + /// to for each JSON path provided in . + /// + /// The value to replace matches with. + /// The JSON paths to sanitize. + public virtual void SanitizeJsonBody(string sanitizedValue, IEnumerable jsonPaths) + { + if (jsonPaths == null) + { + return; + } + + foreach (var key in jsonPaths) + { + Sanitizers.Add(new BodyKeySanitizer(key) { Value = sanitizedValue }); + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Transforms/BaseTransform.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Transforms/BaseTransform.cs new file mode 100644 index 000000000..11be5e1be --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Transforms/BaseTransform.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Text.Json; +using System.Text.Json.Serialization; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Recording.Transforms; + +/// +/// Base class for test recording proxy transforms. Transforms are applied when returning a request during playback. +/// +public abstract class BaseTransform : IUtf8JsonSerializable +{ + /// + /// Creates a new instance. + /// + /// The type of this sanitizer (e.g. GeneralRegexSanitizer). + /// If the type was null. + protected BaseTransform(string type) + { + Type = type ?? throw new ArgumentNullException(nameof(Type)); + } + + /// + /// Gets the type of the sanitizer (e.g. HeaderRegexSanitizer). + /// + [JsonIgnore] + public string Type { get; } + + /// + public virtual void Write(Utf8JsonWriter writer, JsonSerializerOptions? options = null) + { + // By default use reflection based serialization + JsonSerializer.Serialize(writer, this, GetType(), Default.InnerRecordingJsonOptions); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Transforms/HeaderTransform.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Transforms/HeaderTransform.cs new file mode 100644 index 000000000..4817f84c5 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Recording/Transforms/HeaderTransform.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Recording.Transforms; + +/// +/// Transform applied to headers before the response is generated during recording playback. +/// +public class HeaderTransform : BaseTransform +{ + /// + /// Creates a new instance. + /// + /// The response header to set. + /// If the is null. + public HeaderTransform(string key) : base("HeaderTransform") + { + Key = key ?? throw new ArgumentNullException(nameof(key)); + } + + /// + /// Gets the header to transform. + /// + public string Key { get; } + + /// + /// Gets or sets the value to set. + /// + public string? Value { get; set; } + + /// + /// The condition to apply for this transform. If the condition is not met, no transform is performed. + /// + public Condition? Condition { get; set; } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/SyncOnlyAttribute.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/SyncOnlyAttribute.cs new file mode 100644 index 000000000..2d00681f9 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/SyncOnlyAttribute.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using NUnit.Framework; + +namespace OpenAI.TestFramework; + +/// +/// Attribute that can be applied to a test to indicate it only runs in synchronous mode. +/// +[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)] +public class SyncOnlyAttribute() : NUnitAttribute +{ +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/AndPreFilters.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/AndPreFilters.cs new file mode 100644 index 000000000..714bb78e1 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/AndPreFilters.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Reflection; +using NUnit.Framework.Interfaces; + +namespace OpenAI.TestFramework.Utils; + +/// +/// Represents a pre-filter that combines multiple pre-filters using a logical AND operation. +/// +public class AndPreFilter : IPreFilter +{ + private IEnumerable _filters; + + /// + /// Initializes a new instance. + /// + /// The pre-filters to combine. + public AndPreFilter(params IPreFilter[] filters) : this((IEnumerable)filters) + { } + + /// + /// Initializes a new instance. + /// + /// The pre-filters to combine. + public AndPreFilter(IEnumerable filters) + { + _filters = filters?.Where(p => p != null) ?? Array.Empty(); + } + + /// + public bool IsMatch(Type type) => _filters.All(p => p.IsMatch(type)); + + /// + public bool IsMatch(Type type, MethodInfo method) => _filters.All(p => p.IsMatch(type, method)); +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/AssemblyHelper.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/AssemblyHelper.cs new file mode 100644 index 000000000..ae11a0eed --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/AssemblyHelper.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Reflection; +using System.Runtime.InteropServices; + +namespace OpenAI.TestFramework.Utils +{ + /// + /// Assembly related helper methods + /// + public static class AssemblyHelper + { + /// + /// Gets the value of the named assembly metadata attribute for the assembly where the is defined. + /// + /// The type whose assembly we want to read from. + /// The name of the metadata assembly attribute to read. + /// The value of the metadata attribute, or null if none was specified or could be found. + public static string? GetAssemblyMetadata(string name) + => GetAssemblyMetadata(typeof(T).Assembly, name); + + /// + /// Gets the value of the named assembly metadata attribute from assembly. + /// + /// The assembly to read the metadata attribute from + /// The name of the metadata assembly attribute to read. + /// The value of the metadata attribute, or null if none was specified or could be found. + public static string? GetAssemblyMetadata(this Assembly assembly, string name) + { + return assembly + ?.GetCustomAttributes() + .FirstOrDefault(a => a.Key == name && !string.IsNullOrWhiteSpace(a.Value)) + ?.Value; + } + + /// + /// Gets the root source directory for the assembly that defines the type . + /// + /// The type whose assembly source path we want to read. + /// The directory containing the original source path, or null if it was not set or did not exist. + public static DirectoryInfo? GetAssemblySourceDir() + => GetAssemblySourceDir(typeof(T).Assembly); + + /// + /// Gets the source path for the assembly. In order for this to work, you will need to set the assembly metadata attribute + /// your project file as follows: + /// + /// <ItemGroup> + /// <AssemblyAttribute Include="System.Reflection.AssemblyMetadataAttribute"> + /// <_Parameter1>SourcePath</_Parameter1> + /// <_Parameter2>$(MSBuildProjectDirectory)</_Parameter2> + /// </AssemblyAttribute> + /// </ItemGroup> + /// + /// + /// The assembly whose source path we want to find. + /// The directory containing the original source path, or null if it was not set or did not exist. + public static DirectoryInfo? GetAssemblySourceDir(this Assembly assembly) + { + string? sourcePath = assembly.GetAssemblyMetadata("SourcePath"); + if (sourcePath == null) + { + return null; + } + + DirectoryInfo dir = new(sourcePath); + return dir.Exists + ? dir + : null; + } + + /// + /// Finds the dotnet executable path for the current system. It does this by reading the DOTNET_INSTALL_DIR environment variable + /// first, and then inspecting all folders in the current PATH environment variable. + /// + /// The path to the found dotnet executable, or null if none could be found. + public static FileInfo? GetDotnetExecutable() + { + string dotnetExeName = "dotnet"; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + dotnetExeName += ".exe"; + } + + List searchDirs = + [ + Environment.GetEnvironmentVariable("DOTNET_INSTALL_DIR"), + ..Environment.GetEnvironmentVariable("PATH") + ?.Split(Path.PathSeparator) + ?? Array.Empty() + ]; + + return searchDirs + .Where(dir => !string.IsNullOrWhiteSpace(dir)) + .Select(dir => new FileInfo(Path.Combine(dir!, dotnetExeName))) + .FirstOrDefault(file => file.Exists); + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/Default.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/Default.cs new file mode 100644 index 000000000..5774a4774 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/Default.cs @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace OpenAI.TestFramework.Utils; + +/// +/// Options used for various recordings. +/// +public static class Default +{ + private static JsonSerializerOptions? _recordingJsonOptions; + private static JsonSerializerOptions? _innerRecordingJsonOptions; + private static JsonSerializerOptions? _testProxyJsonOptions; + private static TimeSpan? _testProxyWaitTime; + private static TimeSpan? _requestRetryDelay; + private static TimeSpan? _debuggerTestTimeout; + private static TimeSpan? _defaultTestTimeout; + + /// + /// Gets the default value to replace matches with while sanitizing. + /// + public const string SanitizedValue = "Sanitized"; + + /// + /// Gets the JSON serialization options to use for recording sanitizers, matchers, and transforms child instances. + /// + public static JsonSerializerOptions InnerRecordingJsonOptions => _innerRecordingJsonOptions ??= new() + { + PropertyNameCaseInsensitive = true, + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + WriteIndented = true, +#if NET + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, +#else + IgnoreNullValues = true, +#endif + }; + + /// + /// Gets the JSON serialization options to use for recording sanitizers, matchers, and transforms. + /// + public static JsonSerializerOptions RecordingJsonOptions + { + get + { + if (_recordingJsonOptions == null) + { + _recordingJsonOptions = InnerRecordingJsonOptions.Clone(); + _recordingJsonOptions.Converters.Add( + +#if NET6_0 + // .Net 6.0 seems to have a weird bug here. This is not needed for .Net framework, nor .Net 7+ + new Utf8JsonSerializableConverterFactory() +#else + new Utf8JsonSerializableConverter() +#endif + ); + } + + return _recordingJsonOptions; + } + } + + + /// + /// Gets the JSON serialization options to use for the test proxy + /// + public static JsonSerializerOptions TestProxyJsonOptions => _testProxyJsonOptions ??= new() + { + PropertyNameCaseInsensitive = true, + WriteIndented = true, +#if NET + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, +#else + IgnoreNullValues = true, +#endif + }; + + /// + /// The default maximum amount of time to wait to for the test proxy operations to finish (e.g. start up + /// and configuration, or saving a recording and teardown). + /// + public static TimeSpan TestProxyWaitTime => _testProxyWaitTime ??= TimeSpan.FromMinutes(2); + + /// + /// Gets the maximum number of times to retry requests + /// + public const int MaxRequestRetries = 3; + + /// + /// The amount of time to wait between requests. + /// + public static TimeSpan RequestRetryDelay => _requestRetryDelay ??= TimeSpan.FromSeconds(0.8); + + /// + /// The amount of time to wait when the debugger is attached. This is much higher than normal to allow for more time while debugging. + /// + public static TimeSpan DebuggerAttachedTestTimeout => _debuggerTestTimeout ??= TimeSpan.FromMinutes(15); + + /// + /// The default test timeout. + /// + public static TimeSpan TestTimeout => _defaultTestTimeout ??= TimeSpan.FromSeconds(15); +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/Extensions.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/Extensions.cs new file mode 100644 index 000000000..437e8dc05 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/Extensions.cs @@ -0,0 +1,414 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Diagnostics; + +namespace OpenAI.TestFramework.Utils; + +/// +/// String related extension methods. +/// +public static class StringExtensions +{ + /// + /// Ensures that a string ends with a specified suffix. + /// + /// The string value. + /// The suffix to check for. + /// The string comparison type. Default is . + /// The original string if it ended in the suffix, or a new string value with the suffix appended. + public static string EnsureEndsWith(this string value, string suffix, StringComparison comparison = StringComparison.Ordinal) + { + if (value == null) + { + return null!; + } + + if (value.EndsWith(suffix, comparison)) + { + return value; + } + + return value + suffix; + } + + /// + /// Ensures that a string ends with a specified suffix. + /// + /// The string value. + /// The suffix to check for. + /// The string comparison type. Default is . + /// The original string if it ended in the suffix, or a new string value with the suffix appended. + public static string EnsureEndsWith(this string value, char suffix, StringComparison comparison = StringComparison.Ordinal) + => EnsureEndsWith(value, suffix.ToString(), comparison); +} + +/// +/// Extension methods for System.ClientModel types. +/// +public static class ScmExtensions +{ + /// + /// Gets the first value associated with the specified header name from the pipeline request headers. + /// + /// The pipeline request headers. + /// The name of the header. + /// The first non-empty value associated with the specified header name, or null if the header is not found or has no non-empty values. + public static string? GetFirstOrDefault(this PipelineRequestHeaders headers, string name) + { + if (headers?.TryGetValues(name, out IEnumerable? values) == true) + { + return values?.FirstOrDefault(v => !string.IsNullOrWhiteSpace(v)); + } + + return null; + } + + /// + /// Gets the first value associated with the specified header name from the pipeline response headers. + /// + /// The pipeline response headers. + /// The name of the header. + /// The first non-empty value associated with the specified header name, or null if the header is not found or has no non-empty values. + public static string? GetFirstOrDefault(this PipelineResponseHeaders headers, string name) + { + if (headers?.TryGetValues(name, out IEnumerable? values) == true) + { + return values?.FirstOrDefault(v => !string.IsNullOrWhiteSpace(v)); + } + + return null; + } +} + +/// +/// Extensions for collections +/// +public static class CollectionExtensions +{ + /// + /// Adds the elements to a collection. + /// + /// The type of the elements in the collection. + /// The collection to add elements to. + /// The items to add. + public static void AddRange(this ICollection collection, IEnumerable itemsToAdd) + { + foreach (T item in itemsToAdd) + { + collection.Add(item); + } + } + + /// + /// Joins the elements of a collection into a single string using the specified separator. + /// Returns null if the collection is null or empty. + /// + /// The collection of strings to join. + /// The separator string. + /// A string that consists of the elements of the collection joined by the separator, or null if the collection is null or empty. + public static string? JoinOrNull(this IEnumerable values, string separator) + { + if (values == null || !values.Any()) + { + return null; + } + + return string.Join(separator, values); + } + +#if NETFRAMEWORK + /// + /// Gets the value associated with the specified key from the dictionary, or returns the default value if the key is not found. + /// + /// The type of the keys in the dictionary. + /// The type of the values in the dictionary. + /// The dictionary. + /// The key to locate. + /// The value associated with the specified key, or the default value if the key is not found. + public static TVal? GetValueOrDefault(this IReadOnlyDictionary dict, TKey key) + => GetValueOrDefault(dict, key, default!); + + /// + /// Gets the value associated with the specified key from the dictionary, or returns the specified default value if the key is not found. + /// + /// The type of the keys in the dictionary. + /// The type of the values in the dictionary. + /// The dictionary. + /// The key to locate. + /// The default value to return if the key is not found. + /// The value associated with the specified key, or the specified default value if the key is not found. + public static TVal GetValueOrDefault(this IReadOnlyDictionary dict, TKey key, TVal defaultValue) + { + if (dict?.TryGetValue(key, out TVal? value) == true) + { + return value; + } + + return defaultValue; + } +#endif + + /// + /// Gets the value associated with the specified key from the dictionary, or returns the default value if the key is not found. + /// + /// The type of the keys in the dictionary. + /// The type of the values in the dictionary. + /// The dictionary. + /// The key to locate. + /// The value associated with the specified key, or the default value if the key is not found. + public static TVal? GetValueOrDefault(this Dictionary dict, TKey key) where TKey : notnull + => GetValueOrDefault((IDictionary)dict, key, default!); + + /// + /// Gets the value associated with the specified key from the dictionary, or returns the specified default value if the key is not found. + /// + /// The type of the keys in the dictionary. + /// The type of the values in the dictionary. + /// The dictionary. + /// The key to locate. + /// The default value to return if the key is not found. + /// The value associated with the specified key, or the specified default value if the key is not found. + public static TVal GetValueOrDefault(this Dictionary dict, TKey key, TVal defaultValue) where TKey : notnull + => GetValueOrDefault((IDictionary)dict, key, defaultValue); + + /// + /// Gets the value associated with the specified key from the sorted dictionary, or returns the default value if the key is not found. + /// + /// The type of the keys in the sorted dictionary. + /// The type of the values in the sorted dictionary. + /// The sorted dictionary. + /// The key to locate. + /// The value associated with the specified key, or the default value if the key is not found. + public static TVal? GetValueOrDefault(this SortedDictionary dict, TKey key) where TKey : notnull + => GetValueOrDefault((IDictionary)dict, key, default!); + + /// + /// Gets the value associated with the specified key from the sorted dictionary, or returns the specified default value if the key is not found. + /// + /// The type of the keys in the sorted dictionary. + /// The type of the values in the sorted dictionary. + /// The sorted dictionary. + /// The key to locate. + /// The default value to return if the key is not found. + /// The value associated with the specified key, or the specified default value if the key is not found. + public static TVal GetValueOrDefault(this SortedDictionary dict, TKey key, TVal defaultValue) where TKey : notnull + => GetValueOrDefault((IDictionary)dict, key, defaultValue); + + /// + /// Gets the value associated with the specified key from the dictionary, or returns the default value if the key is not found. + /// + /// The type of the keys in the dictionary. + /// The type of the values in the dictionary. + /// The dictionary. + /// The key to locate. + /// The value associated with the specified key, or the default value if the key is not found. + public static TVal? GetValueOrDefault(this IDictionary dict, TKey key) + => GetValueOrDefault(dict, key, default!); + + /// + /// Gets the value associated with the specified key from the dictionary, or returns the specified default value if the key is not found. + /// + /// The type of the keys in the dictionary. + /// The type of the values in the dictionary. + /// The dictionary. + /// The key to locate. + /// The default value to return if the key is not found. + /// The value associated with the specified key, or the specified default value if the key is not found. + public static TVal GetValueOrDefault(this IDictionary dict, TKey key, TVal defaultValue) + { + if (dict?.TryGetValue(key, out TVal? value) == true) + { + return value; + } + + return defaultValue; + } + + /// + /// Gets the value associated with the specified key from the dictionary, or creates and adds a new value if the key did not exist. + /// + /// The type of the keys in the dictionary. + /// The type of the values in the dictionary. + /// The dictionary. + /// The key to locate. + /// The function used to create a value for the key if it is not found in the dictionary. + /// The value associated with the specified key, or the value created by the if the key is not found. + public static TValue GetOrAdd(this IDictionary dictionary, TKey key, Func valueFactory) + { + if (dictionary == null) + { + throw new ArgumentNullException(nameof(dictionary)); + } + + if (!dictionary.TryGetValue(key, out TValue? value)) + { + value = valueFactory(key); + dictionary[key] = value; + } + + return value!; + } + + /// + /// Asynchronously returns the first element of a sequence. + /// is found. + /// + /// The type of the elements in the sequence. + /// The sequence to search. + /// A cancellation token to cancel the operation. + /// Asynchronous task. + public static ValueTask FirstOrDefaultAsync(this IAsyncEnumerable enumerable, CancellationToken token = default) + => FirstOrDefaultAsync(enumerable, _ => true); + + /// + /// Asynchronously returns the first element of a sequence that satisfies a specified condition or a default value if no such element + /// is found. + /// + /// The type of the elements in the sequence. + /// The sequence to search. + /// A function to test each element for a condition. + /// A cancellation token to cancel the operation. + /// Asynchronous task. + public static async ValueTask FirstOrDefaultAsync(this IAsyncEnumerable enumerable, Predicate predicate, CancellationToken token = default) + { + await foreach (T item in enumerable.WithCancellation(token)) + { + if (predicate(item)) + { + return item; + } + } + + return default!; + } + + /// + /// Converts an to a asynchronously. + /// + /// The type of the elements in the enumerable. + /// The to convert. + /// The cancellation token. + /// Asynchronous task to do the conversion. + public static async Task> ToListAsync(this IAsyncEnumerable asyncEnumerable, CancellationToken token = default) + { + List list = new List(); + await foreach (T item in asyncEnumerable.WithCancellation(token)) + { + list.Add(item); + } + return list; + } + + /// + /// Converts an async enumerable of pages to a asynchronously. + /// + /// The type of the elements in the enumerable. + /// The to convert. + /// The cancellation token. + /// Asynchronous task to do the conversion. + public static async Task> ToListAsync(this IAsyncEnumerable> pageAsyncEnumerable, CancellationToken token = default) + { + List list = new List(); + await foreach(PageResult page in pageAsyncEnumerable.WithCancellation(token)) + { + list.AddRange(page.Values); + } + return list; + } +} + +/// +/// Helpers for working with paths. +/// +public static class PathHelpers +{ + /// + /// Create a relative path from one path to another. Paths will be resolved before calculating the difference. + /// + /// The source path the output should be relative to. This path is always considered to be a directory. + /// The destination path. + /// The relative path or if the paths don't share the same root. + public static string GetRelativePath(string relativeTo, string path) + { + +#if NET + return Path.GetRelativePath(relativeTo, path); +#else + relativeTo = Path.GetFullPath(relativeTo).EnsureEndsWith(Path.DirectorySeparatorChar); + path = Path.GetFullPath(path).EnsureEndsWith(Path.DirectorySeparatorChar); + + Uri relativeToUri = new Uri(relativeTo); + Uri pathUri = new Uri(path); + + if (relativeToUri.Scheme != pathUri.Scheme) + { + return path; + } + + Uri relative = relativeToUri.MakeRelativeUri(pathUri); + return Uri.UnescapeDataString(relative.ToString()) + .Replace('/', '\\'); +#endif + } +} + + +/// +/// Extensions for types. +/// +public static class TypeExtensions +{ + /// + /// Determines whether the specified type either implements the open generic type specified, + /// or inherits from the open generic type specified. + /// + /// The type to inspect. + /// The open generic type. + /// The arguments of the closed generic type. + /// True if the type implements, or inherits, or is a closed version of the open type. + [DebuggerStepThrough] + public static bool IsClosedGenericOf(this Type type, Type openGeneric, out Type[] closedTypeArguments) + { + Type? closedType = null; + + if (openGeneric.IsInterface) + { + closedType = type.GetInterfaces() + .FirstOrDefault(iType => IsAssignableToOpen(iType, openGeneric)); + } + + if (closedType == null) + { + for (Type? current = type; current != null && closedType == null; current = current.BaseType) + { + if (IsAssignableToOpen(current, openGeneric)) + { + closedType = current; + } + } + } + + closedTypeArguments = closedType?.GetGenericArguments() ?? Array.Empty(); + return closedType != null; + } + + /// + /// Determines if the type is or inherits from the open generic type. + /// + /// The type. + /// The open generic type. + /// True if the open generic type could be assigned from the type. + [DebuggerStepThrough] + public static bool IsAssignableToOpen(this Type type, Type openGeneric) + { + if (!type.IsGenericType || !type.IsConstructedGenericType) + { + return false; + } + + return openGeneric.IsAssignableFrom(type.GetGenericTypeDefinition()); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/IUtf8JsonSerializable.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/IUtf8JsonSerializable.cs new file mode 100644 index 000000000..7b81c3532 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/IUtf8JsonSerializable.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Text.Json; + +namespace OpenAI.TestFramework.Utils; + +/// +/// Interface applied to types that can be serialized to JSON. +/// +public interface IUtf8JsonSerializable +{ + /// + /// Writes this instance as JSON to the writer. + /// + /// The writer to write to. + /// The options to use when writing. + void Write(Utf8JsonWriter writer, JsonSerializerOptions? options = null); +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/JsonHelpers.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/JsonHelpers.cs new file mode 100644 index 000000000..ed6752ca4 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/JsonHelpers.cs @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace OpenAI.TestFramework.Utils; + +public static class JsonHelpers +{ + /// + /// Serializes the specified data to a stream using as a UTF-8 encoded JSON text. + /// + /// The type of the data to serialize. + /// The stream to write the serialized data to. + /// The data to serialize. + /// (Optional) Options to use when serializing. + public static void Serialize(Stream stream, T data, JsonSerializerOptions? options = null) + { +#if NETFRAMEWORK + using (Utf8JsonWriter writer = new(stream)) + { + JsonSerializer.Serialize(writer, data, options); + writer.Flush(); + } +#else + JsonSerializer.Serialize(stream, data, options); +#endif + } + + /// + /// Deserializes UTF-8 encoded JSON text from a stream. + /// + /// The type of the data to deserialize. + /// The stream to read the serialized data from. + /// (Optional) Options to use when deserializing. + /// The deserialized data. + public static T? Deserialize(Stream stream, JsonSerializerOptions? options = null) + { +#if NETFRAMEWORK + // For now let's keep it simple and load entire JSON bytes into memory + using MemoryStream buffer = new(); + stream.CopyTo(buffer); + + ReadOnlySpan jsonBytes = buffer.GetBuffer().AsSpan(0, (int)buffer.Length); + return JsonSerializer.Deserialize(jsonBytes, options); +#else + return JsonSerializer.Deserialize(stream, options); +#endif + } + +#if NET6_0_OR_GREATER + // .Net 6 and newer already have the extension method we need defined in JsonSerializer +#else + // TODO FIXME once we move to newer versions of System.Text.Json we can directly use the + // JsonSerializer extension method for elements + public static T? Deserialize(this JsonElement element, JsonSerializerOptions? options = null) + { + using MemoryStream stream = new(); + using Utf8JsonWriter writer = new(stream, new() + { + Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping, + Indented = false, + SkipValidation = true + }); + element.WriteTo(writer); + writer.Flush(); + + stream.Seek(0, SeekOrigin.Begin); + if (((ulong)stream.Length & 0xffffffff00000000) != 0ul) + { + throw new ArgumentOutOfRangeException("JsonElement is too large"); + } + + ReadOnlySpan span = new(stream.GetBuffer(), 0, (int)stream.Length); + return JsonSerializer.Deserialize(span, options); + } +#endif + + /// + /// Serializes a value to a JsonElement. + /// + /// Type of the data to serialize. + /// The value to serialize. + /// (Optional) Options to use when serializing. + /// The serialized value as a JsonElement. + public static JsonElement SerializeToElement(T value, JsonSerializerOptions? options = null) + { +#if NET6_0_OR_GREATER + return JsonSerializer.SerializeToElement(value, options); +#else + using MemoryStream stream = new(); + Serialize(stream, value, options); + stream.Seek(0, SeekOrigin.Begin); + return JsonDocument.Parse(stream).RootElement; +#endif + } + + /// + /// Creates a clone of the specified JSON serializer options. + /// + /// The JSON serializer options to clone. + /// (Optional) Filter to apply for selecting specific converters to include in the cloned options. + /// A clone of the JSON serializer options. + public static JsonSerializerOptions Clone(this JsonSerializerOptions options, Predicate? converterFilter = null) + { +#if NET + JsonSerializerOptions cloned = new JsonSerializerOptions(options); + if (converterFilter != null) + { + cloned.Converters.Clear(); + foreach (var converter in options.Converters.Where(c => converterFilter(c))) + { + cloned.Converters.Add(converter); + } + } + + return cloned; +#else + JsonSerializerOptions clone = new() + { + AllowTrailingCommas = options.AllowTrailingCommas, + DefaultBufferSize = options.DefaultBufferSize, + DictionaryKeyPolicy = options.DictionaryKeyPolicy, + Encoder = options.Encoder, + IgnoreNullValues = options.IgnoreNullValues, + IgnoreReadOnlyProperties = options.IgnoreReadOnlyProperties, + MaxDepth = options.MaxDepth, + PropertyNameCaseInsensitive = options.PropertyNameCaseInsensitive, + PropertyNamingPolicy = options.PropertyNamingPolicy, + ReadCommentHandling = options.ReadCommentHandling, + WriteIndented = options.WriteIndented, + }; + + foreach (var converter in options.Converters.Where(c => converterFilter?.Invoke(c) ?? true)) + { + clone.Converters.Add(converter); + } + + return clone; +#endif + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/PropertyDelegate.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/PropertyDelegate.cs new file mode 100644 index 000000000..3c2f467cf --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/PropertyDelegate.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Utils; + +/// +/// Represents a delegate for getting and setting property values. +/// +/// The type of the property value. +public struct PropertyDelegate +{ + private Func? _getter; + private Action? _setter; + + /// + /// Initializes a new instance of the struct. + /// + /// The delegate used to get the property value. + /// The delegate used to set the property value. + public PropertyDelegate(Func getter, Action setter) + { + _getter = getter ?? throw new ArgumentNullException(nameof(getter)); + _setter = setter ?? throw new ArgumentNullException(nameof(setter)); + } + + /// + /// Gets the value of the property. + /// + /// The value of the property. + public TVal GetValue() + { + if (_getter != null) + return _getter(); + else + throw new InvalidOperationException("No getter was set"); + } + + /// + /// Sets the value of the property. + /// + /// The value to set. + public void SetValue(TVal val) + { + if (_setter != null) + _setter(val); + else + throw new InvalidOperationException("No setter was set"); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/SyncAsyncPreFilter.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/SyncAsyncPreFilter.cs new file mode 100644 index 000000000..a2c32fc1a --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/SyncAsyncPreFilter.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Reflection; +using NUnit.Framework.Interfaces; + +namespace OpenAI.TestFramework.Utils +{ + /// + /// Filter to exclude sync only or async only tests in the appropriate test run. + /// + public class SyncAsyncPreFilter : IPreFilter + { + private bool _isAsync; + + /// + /// Creates a new instance. + /// + /// True to filter for an async test run, false to filter for sync test run. + public SyncAsyncPreFilter(bool isAsync) + { + _isAsync = isAsync; + } + + /// + public bool IsMatch(Type type) + => type.GetCustomAttribute() != null; + + /// + public bool IsMatch(Type type, MethodInfo method) + { + if (!IsMatch(type)) + { + return false; + } + + return _isAsync && method.GetCustomAttribute() == null + || !_isAsync && method.GetCustomAttribute() == null; + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/TestClientRetryPolicy.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/TestClientRetryPolicy.cs new file mode 100644 index 000000000..517ff5576 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/TestClientRetryPolicy.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Diagnostics; +using System.Reflection; + +namespace OpenAI.TestFramework.Utils; + +/// +/// Represents a retry policy to be used when testing clients. +/// +public class TestClientRetryPolicy : ClientRetryPolicy +{ + private Func _getRetries; + + /// + /// Initializes a new instance of the class. + /// + /// The maximum number of retries. + /// The delay between retries. + /// Indicates whether the delay should be exponential. + public TestClientRetryPolicy(int maxRetries = Utils.Default.MaxRequestRetries, TimeSpan? delay = null, bool exponentialDelay = false) + : base(maxRetries) + { + MaxRetries = MaxRetries; + Delay = delay ?? Utils.Default.RequestRetryDelay; + IsExponentialDelay = exponentialDelay; + + // Of course, even reading the number of retries property on the PipelineMessage is internal only. + // So reflection it is + _getRetries = (Func) + (typeof(PipelineMessage).GetProperty("RetryCount", BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance) + ?.GetGetMethod(true) + ?.CreateDelegate(typeof(Func)) + ?? throw new InvalidOperationException("Failed to get RetryCount property")); + } + + /// + /// Gets the maximum number of retries. + /// + public int MaxRetries { get; } + + /// + /// Gets the delay between retries. + /// + public TimeSpan Delay { get; } + + /// + /// Gets a value indicating whether the delay should be exponential. + /// + public bool IsExponentialDelay { get; } + + /// + protected override TimeSpan GetNextDelay(PipelineMessage message, int tryCount) + { + TimeSpan delay = IsExponentialDelay + ? TimeSpan.FromMilliseconds((1 << tryCount - 1) * Delay.TotalMilliseconds) + : Delay; + + return delay; + } + + /// + protected override bool ShouldRetry(PipelineMessage message, Exception? exception) + { + if (_getRetries(message) >= MaxRetries) + { + return false; + } + + if (!message.ResponseClassifier.TryClassify(message, exception, out bool isRetriable) + && !PipelineMessageClassifier.Default.TryClassify(message, exception, out isRetriable)) + { + Debug.Assert(false, "Failed to classify message"); + } + + return isRetriable; + } + + /// + protected override ValueTask ShouldRetryAsync(PipelineMessage message, Exception? exception) + => new ValueTask(ShouldRetry(message, exception)); +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/TestPipelinePolicy.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/TestPipelinePolicy.cs new file mode 100644 index 000000000..bc004f1bc --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/TestPipelinePolicy.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; + +namespace OpenAI.TestFramework.Utils; + +/// +/// A pipeline policy that raises events before a request sent, and after response has been received. +/// +public class TestPipelinePolicy() : PipelinePolicy() +{ + /// + /// Creates a new instance. This will instantiate the and + /// events based on and respectively. + /// + /// (Optional) Action to perform before sending a request. + /// (Optional) Action to perform after a response is received. + public TestPipelinePolicy(Action? requestAction, Action? responseAction) : this() + { + if (requestAction != null) BeforeRequest += (s, e) => requestAction(e); + + if (responseAction != null) AfterResponse += (s, e) => responseAction(e); + } + + /// + /// Event raised before a request is sent. + /// + public event EventHandler? BeforeRequest; + + /// + /// Event raised after a response has been received. + /// + public event EventHandler? AfterResponse; + + /// + public override void Process(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) + { + BeforeRequest?.Invoke(this, message.Request); + ProcessNext(message, pipeline, currentIndex); + if (message.Response != null) + { + AfterResponse?.Invoke(this, message.Response); + } + } + + /// + public override async ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) + { + BeforeRequest?.Invoke(this, message.Request); + await ProcessNextAsync(message, pipeline, currentIndex).ConfigureAwait(false); + if (message.Response != null) + { + AfterResponse?.Invoke(this, message.Response); + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/TimespanToMillisecondConverter.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/TimespanToMillisecondConverter.cs new file mode 100644 index 000000000..7045f956b --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/TimespanToMillisecondConverter.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Globalization; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace OpenAI.TestFramework.Utils; + +/// +/// Converter for TimeSpans to/from integer millisecond values in JSON. +/// +public class TimespanToMillisecondConverter : JsonConverter +{ + /// + /// Reads a value from JSON. + /// + /// The to read from. + /// The type of the object to convert. + /// The serializer options. + /// The deserialized value. + public override TimeSpan? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + switch (reader.TokenType) + { + case JsonTokenType.Null: + return null; + + case JsonTokenType.Number: + return TimeSpan.FromMilliseconds(reader.GetInt32()); + + case JsonTokenType.String: + string? strValue = reader.GetString(); + if (int.TryParse(strValue, NumberStyles.Integer, CultureInfo.InvariantCulture, out int milliseconds)) + { + return TimeSpan.FromMilliseconds(milliseconds); + } + else + { + throw new JsonException("Invalid millisecond value: " + strValue); + } + + default: + throw new JsonException($"Don't know how to parse '{reader.TokenType}' as a millisecond value"); + } + } + + /// + /// Writes a value to JSON. + /// + /// The to write to. + /// The value to write. + /// The serializer options. + public override void Write(Utf8JsonWriter writer, TimeSpan? value, JsonSerializerOptions options) + { + if (value == null) + { + writer.WriteNullValue(); + } + else + { + writer.WriteNumberValue((int)Math.Ceiling(value.Value.TotalMilliseconds)); + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/Utf8JsonSerializableConverter.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/Utf8JsonSerializableConverter.cs new file mode 100644 index 000000000..3613f66c0 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/Utf8JsonSerializableConverter.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace OpenAI.TestFramework.Utils; + +/// +/// Converter for types that implement . +/// +public class Utf8JsonSerializableConverter : JsonConverter +{ + private static Utf8JsonSerializableConverter? s_instance; + + /// + /// Gets the shared instance of the converter. + /// + public static Utf8JsonSerializableConverter Instance => s_instance ??= new(); + + /// + public override bool CanConvert(Type typeToConvert) + => typeof(IUtf8JsonSerializable).IsAssignableFrom(typeToConvert); + + /// + public override IUtf8JsonSerializable Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + => throw new NotSupportedException("Only writing JSON is supported"); + + /// + public override void Write(Utf8JsonWriter writer, IUtf8JsonSerializable value, JsonSerializerOptions options) + => value.Write(writer); +} + +#if NET6_0 +/// +/// .Net 6.0 has some odd quirks and is particularly pedantic with converters so directly using Utf8JsonSerializableConverter would +/// result in an InvalidCastException. The work around is to use a converter factory. Thankfully, neither .Net Framework, nor .Net 7+ +/// exhibit this behavior. +/// +public class Utf8JsonSerializableConverterFactory : JsonConverterFactory +{ + public override bool CanConvert(Type typeToConvert) => typeof(IUtf8JsonSerializable).IsAssignableFrom(typeToConvert); + public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) + => (JsonConverter?)Activator.CreateInstance(typeof(InnerConverter<>).MakeGenericType(typeToConvert)); + + private class InnerConverter : JsonConverter where T : IUtf8JsonSerializable + { + public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + => (T)Utf8JsonSerializableConverter.Instance.Read(ref reader, typeToConvert, options); + + public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options) + => Utf8JsonSerializableConverter.Instance.Write(writer, value, options); + } +} +#endif diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/WindowsJob.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/WindowsJob.cs new file mode 100644 index 000000000..a5eb2570b --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/src/Utils/WindowsJob.cs @@ -0,0 +1,208 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security; + +namespace OpenAI.TestFramework.Utils.Processes; + +/// +/// A job provides a way to link several processes together on Windows. In this way, they can all be +/// terminated by calling the method. The OS will also automatically terminate +/// the linked processes if the owner process terminates. +/// +public class WindowsJob : IDisposable +{ + private IntPtr _jobHandle; + private int _disposed; + + /// + /// Creates a new job + /// + /// (Optional) The name to associate + public WindowsJob(string? name = null) + { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + throw new NotSupportedException("This is only supported on Windows platforms"); + } + + var securityAttributes = new SECURITY_ATTRIBUTES() + { + nLength = (uint)Marshal.SizeOf(typeof(SECURITY_ATTRIBUTES)), + lpSecurityDescriptor = IntPtr.Zero, + bInheritHandle = false + }; + + // Create the job handle + _jobHandle = CreateJobObject(ref securityAttributes, name); + if (_jobHandle == IntPtr.Zero) + { + throw new COMException("Failed to create job", Marshal.GetLastWin32Error()); + } + + // Set the job state so that all associated handles are closed + var extendedInfo = new JOBOBJECT_EXTENDED_LIMIT_INFORMATION() + { + BasicLimitInformation = new JOBOBJECT_BASIC_LIMIT_INFORMATION() + { + LimitFlags = JobObjectLimits.LIMIT_KILL_ON_JOB_CLOSE + } + }; + + int length = Marshal.SizeOf(typeof(JOBOBJECT_EXTENDED_LIMIT_INFORMATION)); + IntPtr ptr = IntPtr.Zero; + try + { + ptr = Marshal.AllocHGlobal(length); + Marshal.StructureToPtr(extendedInfo, ptr, false); + + bool success = SetInformationJobObject( + _jobHandle, + JOBOBJECTINFOCLASS.JobObjectExtendedLimitInformation, + ptr, + (uint)length); + + if (!success) + { + throw new COMException("Failed to set the job extended information", Marshal.GetLastWin32Error()); + } + } + finally + { + Marshal.FreeHGlobal(ptr); + } + } + + /// + /// Adds a process to the job + /// + /// The process to add + public void Add(Process process) + { + if (process == null) + { + throw new ArgumentNullException(nameof(process)); + } + else if (process.Handle == IntPtr.Zero) + { + throw new ArgumentException("The specified process has a NULL handle"); + } + + bool success = AssignProcessToJobObject(_jobHandle, process.Handle); + if (!success) + { + throw new COMException("Failed to add the process to the job", Marshal.GetLastWin32Error()); + } + } + + /// + /// Closes the job. This will close all linked processes + /// + public void Close() + { + CloseHandle(_jobHandle); + _jobHandle = IntPtr.Zero; + } + + /// + /// Disposes of the job. This will also close all linked process. + /// + public void Dispose() + { + if (Interlocked.Exchange(ref _disposed, 1) == 0) + { + Close(); + } + } + + #region native methods + + [DllImport("kernel32.dll", CharSet = CharSet.Auto, SetLastError = true)] + internal static extern IntPtr CreateJobObject([In] ref SECURITY_ATTRIBUTES lpJobAttributes, string? lpName); + + [DllImport("kernel32.dll", CharSet = CharSet.Auto, SetLastError = true)] + internal static extern IntPtr OpenJobObject(uint dwDesiredAccess, bool bInheritHandles, string lpName); + + [DllImport("kernel32.dll", SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + internal static extern bool AssignProcessToJobObject(IntPtr hJob, IntPtr hProcess); + + [DllImport("kernel32.dll", SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + internal static extern bool SetInformationJobObject( + [In] IntPtr hJob, + JOBOBJECTINFOCLASS JobObjectInfoClass, + [In] IntPtr lpJobObjectInfo, + uint cbJobObjectInfoLength); + + [DllImport("kernel32.dll", SetLastError = true)] +#if NETFRAMEWORK + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] +#endif + [SuppressUnmanagedCodeSecurity] + [return: MarshalAs(UnmanagedType.Bool)] + internal static extern bool CloseHandle(IntPtr hObject); + +#endregion + + #region native types + + [StructLayout(LayoutKind.Sequential)] + internal struct SECURITY_ATTRIBUTES + { + public uint nLength; + public IntPtr lpSecurityDescriptor; + public bool bInheritHandle; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct JOBOBJECT_BASIC_LIMIT_INFORMATION + { + public Int64 PerProcessUserTimeLimit; + public Int64 PerJobUserTimeLimit; + public JobObjectLimits LimitFlags; + public UIntPtr MinimumWorkingSetSize; + public UIntPtr MaximumWorkingSetSize; + public UInt32 ActiveProcessLimit; + public UIntPtr Affinity; + public UInt32 PriorityClass; + public UInt32 SchedulingClass; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct JOBOBJECT_EXTENDED_LIMIT_INFORMATION + { + public JOBOBJECT_BASIC_LIMIT_INFORMATION BasicLimitInformation; + public IO_COUNTERS IoInfo; + public UIntPtr ProcessMemoryLimit; + public UIntPtr JobMemoryLimit; + public UIntPtr PeakProcessMemoryUsed; + public UIntPtr PeakJobMemoryUsed; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct IO_COUNTERS + { + public UInt64 ReadOperationCount; + public UInt64 WriteOperationCount; + public UInt64 OtherOperationCount; + public UInt64 ReadTransferCount; + public UInt64 WriteTransferCount; + public UInt64 OtherTransferCount; + } + + internal enum JOBOBJECTINFOCLASS + { + JobObjectExtendedLimitInformation = 9, + } + + internal enum JobObjectLimits : UInt32 + { + LIMIT_KILL_ON_JOB_CLOSE = 0x00002000, + } +} + +#endregion diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/tests/AdaptersTests.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/AdaptersTests.cs new file mode 100644 index 000000000..266d388a3 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/AdaptersTests.cs @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.Diagnostics; +using NUnit.Framework; +using OpenAI.TestFramework.Adapters; +using OpenAI.TestFramework.Mocks; + +namespace OpenAI.TestFramework.Tests; + +[TestFixture] +public class AdaptersTests +{ + public CancellationToken Token => + new CancellationTokenSource(Debugger.IsAttached + ? TimeSpan.FromMinutes(15) + : TimeSpan.FromSeconds(5)) + .Token; + + [Test] + public async Task TestSyncToAsyncEnumerator() + { + const int start = 0; + const int num = 100; + + IEnumerator sync = Enumerable.Range(start, num).GetEnumerator(); + await using SyncToAsyncEnumerator async = new(sync, Token); + + for (int i = start; i < num; i++) + { + bool success = await async.MoveNextAsync(); + Assert.That(success, Is.True); + Assert.That(async.Current, Is.EqualTo(i)); + } + } + + [Test] + public async Task TestSyncToAsyncResultCollection() + { + const int start = 0; + const int num = 100; + + MockCollectionResult sync = new(() => Enumerable.Range(start, num)); + SyncToAsyncCollectionResult asyncAdapter = new(sync); + + await using var asyncEnumerator = asyncAdapter.GetAsyncEnumerator(Token); + + for (int i = start; i < num; i++) + { + bool success = await asyncEnumerator.MoveNextAsync(); + Assert.That(success, Is.True); + Assert.That(asyncEnumerator.Current, Is.EqualTo(i)); + } + } + + [Test] + public async Task TestFailedSyncToAsyncResultCollection() + { + MockCollectionResult sync = new(Fail); + SyncToAsyncCollectionResult asyncAdapter = new(sync); + + await using var asyncEnumerator = asyncAdapter.GetAsyncEnumerator(Token); + Assert.ThrowsAsync(() => asyncEnumerator.MoveNextAsync().AsTask()); + } + + [Test] + public async Task TestSyncToAsyncPageableCollection() + { + const int start = 0; + const int num = 100; + const int itemsPerPage = 10; + int expectedPages = (int)Math.Ceiling((double)num / itemsPerPage); + + MockPageCollection sync = new(() => Enumerable.Range(start, num), new MockPipelineResponse(), itemsPerPage); + SyncToAsyncPageCollection asyncAdapter = new(sync); + + int numPages = 0; + int expected = 0; + await foreach (var page in asyncAdapter) + { + numPages++; + foreach (int actual in page.Values) + { + Assert.That(actual, Is.EqualTo(expected)); + expected++; + } + } + + Assert.That(numPages, Is.EqualTo(expectedPages)); + } + + [Test] + public async Task TestFailedSyncToAsyncPageableCollection() + { + MockPageCollection sync = new(Fail, new MockPipelineResponse()); + SyncToAsyncPageCollection asyncAdapter = new(sync); + + await using var asyncEnumerator = ((IAsyncEnumerable>)asyncAdapter).GetAsyncEnumerator(Token); + Assert.ThrowsAsync(() => asyncEnumerator.MoveNextAsync().AsTask()); + } + + private static IEnumerable Fail() + { + throw new ApplicationException("This should fail"); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/tests/AutoSyncAsyncTests.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/AutoSyncAsyncTests.cs new file mode 100644 index 000000000..44f13bee3 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/AutoSyncAsyncTests.cs @@ -0,0 +1,201 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using NUnit.Framework; +using OpenAI.TestFramework.Tests.Helpers; + +namespace OpenAI.TestFramework.Tests; + +public class AutoSyncAsyncTests(bool useAsync) : ClientTestBase(useAsync) +{ + private static readonly string EX_MSG = Guid.NewGuid().ToString(); + + [Test] + [SyncOnly] + public void OnlyInSyncMode() + { + Assert.That(IsAsync, Is.False); + } + + [Test] + [AsyncOnly] + public void OnlyInAsyncMode() + { + Assert.That(IsAsync, Is.True); + } + + [Test] + public void CanGetOriginal() + { + MockClient original = new MockClient(); + + MockClient instrumented = WrapClient(original); + Assert.That(instrumented, Is.Not.Null); + Assert.That(ReferenceEquals(original, instrumented), Is.False); + Assert.That(typeof(MockClient).IsAssignableFrom(instrumented.GetType()), Is.True); + + MockClient recovered = UnWrap(instrumented); + Assert.That(recovered, Is.Not.Null); + Assert.That(ReferenceEquals(original, recovered), Is.True); + } + + [Test] + public void CanGetContext() + { + var context = new MockClientContext(); + + MockClient client = WrapClient(new MockClient(), context); + Assert.That(client, Is.Not.Null); + + var recoveredContext = GetClientContext(client) as MockClientContext; + Assert.That(recoveredContext, Is.Not.Null); + Assert.That(recoveredContext!.Id, Is.EqualTo(context.Id)); + Assert.That(ReferenceEquals(recoveredContext, context), Is.True); + } + + [Test] + public async Task TaskWorks() + { + MockClient client = WrapClient(new MockClient()); + await client.DoAsync(); + AssertCorrectFunctionCalled(client); + } + + [Test] + public void FailedTaskWorks() + { + MockClient client = WrapClient(new MockClient()); + ArgumentException? ex = Assert.ThrowsAsync(() => client.FailAsync(EX_MSG)); + Assert.That(ex, Is.Not.Null); + Assert.That(ex!.Message, Is.EqualTo(EX_MSG)); + AssertCorrectFunctionCalled(client); + } + + [Test] + public async Task TaskWithResultWorks() + { + MockClient client = WrapClient(new MockClient()); + int count = await client.CountAsync(); + Assert.That(count, Is.EqualTo(IsAsync ? 12 : 5)); + AssertCorrectFunctionCalled(client); + } + + [Test] + public void FailedTaskWithResultWorks() + { + MockClient client = WrapClient(new MockClient()); + ArgumentException? ex = Assert.ThrowsAsync(() => client.FailWithResultAsync(EX_MSG)); + Assert.That(ex, Is.Not.Null); + Assert.That(ex!.Message, Is.EqualTo(EX_MSG)); + AssertCorrectFunctionCalled(client); + } + + [Test] + public async Task ResultCollectionWorks() + { + const int num = 3; + const int increment = 2; + + MockClient client = WrapClient(new MockClient()); + AsyncCollectionResult coll = client.ResultCollectionAsync(num, increment); + + Assert.IsNotNull(coll); + Assert.That(coll.GetRawResponse(), Is.Not.Null); + Assert.That(coll.GetRawResponse().Status, Is.EqualTo(200)); + Assert.That(coll.GetRawResponse().ReasonPhrase, Is.EqualTo("OK")); + + int numResults = 0; + await foreach (int i in coll) + { + Assert.That(i, Is.EqualTo(numResults * increment)); + numResults++; + } + + Assert.That(numResults, Is.EqualTo(num)); + AssertCorrectFunctionCalled(client); + } + + [Test] + public void FailedResultCollection() + { + MockClient client = WrapClient(new MockClient()); + + // For now we mimic how the OpenAI and Azure OpenAI libraries work in that no service requests are sent + // until we try to enumerate the async collections. So exceptions aren't expected initially + AsyncCollectionResult coll = client.FailResultCollectionAsync(EX_MSG); + Assert.That(coll, Is.Not.Null); + + IAsyncEnumerator enumerator = coll.GetAsyncEnumerator(); + Assert.That(enumerator, Is.Not.Null); + ArgumentException? ex = Assert.ThrowsAsync(() => enumerator.MoveNextAsync().AsTask()); + Assert.That(ex, Is.Not.Null); + Assert.That(ex!.Message, Is.EqualTo(EX_MSG)); + AssertCorrectFunctionCalled(client); + } + + [Test] + public async Task PageableCollectionWorks() + { + const int num = 50; + const int increment = 1; + const int itemsPerPage = 20; + int expectedPages = (int)Math.Ceiling((double)num / itemsPerPage); + + MockClient client = WrapClient(new MockClient()); + AsyncPageCollection coll = client.PageableCollectionAsync(num, increment, itemsPerPage); + Assert.IsNotNull(coll); + + int numPages = 0; + int numResults = 0; + await foreach(PageResult page in coll) + { + Assert.That(page.GetRawResponse(), Is.Not.Null); + Assert.That(page.GetRawResponse().Status, Is.EqualTo(200)); + Assert.That(page.GetRawResponse().ReasonPhrase, Is.EqualTo("OK")); + + numPages++; + foreach (int actual in page.Values) + { + Assert.That(actual, Is.EqualTo(numResults * increment)); + numResults++; + } + } + + Assert.That(numResults, Is.EqualTo(num)); + Assert.That(numPages, Is.EqualTo(expectedPages)); + AssertCorrectFunctionCalled(client); + } + + [Test] + public void FailedPageableCollection() + { + MockClient client = WrapClient(new MockClient()); + + // For now we mimic how the OpenAI and Azure OpenAI libraries work in that no service requests are sent + // until we try to enumerate the async collections. So exceptions aren't expected initially + AsyncPageCollection coll = client.FailPageableCollectionAsync(EX_MSG); + Assert.That(coll, Is.Not.Null); + + IAsyncEnumerator> enumerator = ((IAsyncEnumerable>)coll).GetAsyncEnumerator(); + Assert.That(enumerator, Is.Not.Null); + ArgumentException? ex = Assert.ThrowsAsync(() => enumerator.MoveNextAsync().AsTask()); + Assert.That(ex, Is.Not.Null); + Assert.That(ex!.Message, Is.EqualTo(EX_MSG)); + AssertCorrectFunctionCalled(client); + } + + private void AssertCorrectFunctionCalled(MockClient client, int expectedCalls = 1) + { + if (IsAsync) + { + Assert.That(client.AsyncHit, Is.EqualTo(expectedCalls)); + Assert.That(client.SyncHit, Is.EqualTo(0)); + } + else + { + Assert.That(client.AsyncHit, Is.EqualTo(0)); + Assert.That(client.SyncHit, Is.EqualTo(expectedCalls)); + } + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/tests/Helpers/MockClient.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/Helpers/MockClient.cs new file mode 100644 index 000000000..9eac6054d --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/Helpers/MockClient.cs @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.Runtime.CompilerServices; +using OpenAI.TestFramework.Mocks; + +namespace OpenAI.TestFramework.Tests.Helpers; + +public class MockClient +{ + private int _asyncHit; + private int _syncHit; + + public virtual int AsyncHit => _asyncHit; + public virtual int SyncHit => _syncHit; + + public virtual Task DoAsync() + { + Interlocked.Increment(ref _asyncHit); + return Task.Delay(200); + } + + public virtual void Do() + { + Interlocked.Increment(ref _syncHit); + } + + public virtual Task FailAsync(string message) + { + Interlocked.Increment(ref _asyncHit); + return Task.FromException(new ArgumentException(message)); + } + + public virtual void Fail(string message) + { + Interlocked.Increment(ref _syncHit); + throw new ArgumentException(message); + } + + public virtual async Task CountAsync() + { + Interlocked.Increment(ref _asyncHit); + await Task.Delay(100).ConfigureAwait(false); + return 12; + } + + public virtual int Count() + { + Interlocked.Increment(ref _syncHit); + return 5; + } + + public virtual Task FailWithResultAsync(string message) + { + Interlocked.Increment(ref _asyncHit); + return Task.FromException(new ArgumentException(message)); + } + + public virtual int FailWithResult(string message) + { + Interlocked.Increment(ref _syncHit); + throw new ArgumentException(message); + } + + public virtual AsyncCollectionResult ResultCollectionAsync(int num, int increment = 5) + { + Interlocked.Increment(ref _asyncHit); + return new MockAsyncCollectionResult(() => EnumerateAsync(num, increment)); + } + + public virtual CollectionResult ResultCollection(int num, int increment = 5) + { + Interlocked.Increment(ref _syncHit); + return new MockCollectionResult(() => Enumerate(num, increment)); + } + + public virtual AsyncCollectionResult FailResultCollectionAsync(string message) + { + Interlocked.Increment(ref _asyncHit); + return new MockAsyncCollectionResult(() => FailEnumerateAsync(message)); + } + + public virtual CollectionResult FailResultCollection(string message) + { + Interlocked.Increment(ref _syncHit); + return new MockCollectionResult(() => FailEnumerate(message)); + } + + public virtual AsyncPageCollection PageableCollectionAsync(int num, int increment, int itemsPerPage) + { + Interlocked.Increment(ref _asyncHit); + return new MockAsyncPageCollection(() => EnumerateAsync(num, increment), new MockPipelineResponse(), itemsPerPage); + } + + public virtual PageCollection PageableCollection(int num, int increment, int itemsPerPage) + { + Interlocked.Increment(ref _syncHit); + return new MockPageCollection(() => Enumerate(num, increment), new MockPipelineResponse(), itemsPerPage); + } + + public virtual AsyncPageCollection FailPageableCollectionAsync(string message) + { + Interlocked.Increment(ref _asyncHit); + return new MockAsyncPageCollection(() => FailEnumerateAsync(message), new MockPipelineResponse()); + } + + public virtual PageCollection FailPageableCollection(string message) + { + Interlocked.Increment(ref _syncHit); + return new MockPageCollection(() => FailEnumerate(message), new MockPipelineResponse()); + } + + private async IAsyncEnumerable EnumerateAsync(int num, int increment, [EnumeratorCancellation] CancellationToken token = default) + { + int running = 0; + for (int i = 0; i < num; i++, running += increment) + { + await Task.Delay(100); + yield return running; + } + } + + private IEnumerable Enumerate(int num, int increment) + { + int running = 0; + for (int i = 0; i < num; i++, running += increment) + { + yield return running; + } + } + + private async IAsyncEnumerable FailEnumerateAsync(string message, [EnumeratorCancellation] CancellationToken token = default) + { + bool c = true; + await Task.Delay(100).ConfigureAwait(false); + if (c) + { + throw new ArgumentException(message); + } + + yield break; + } + + private IEnumerable FailEnumerate(string message) + { + throw new ArgumentException(message); + } +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/tests/Helpers/MockClientContext.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/Helpers/MockClientContext.cs new file mode 100644 index 000000000..e36ed1e1f --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/Helpers/MockClientContext.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OpenAI.TestFramework.Tests.Helpers; + +public class MockClientContext +{ + public string Id { get; } = Guid.NewGuid().ToString(); +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/tests/MockStringServiceTests.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/MockStringServiceTests.cs new file mode 100644 index 000000000..d5d3edff4 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/MockStringServiceTests.cs @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Reflection; +using NUnit.Framework; +using OpenAI.TestFramework.Mocks; +using OpenAI.TestFramework.Recording.Proxy; +using OpenAI.TestFramework.Recording.Proxy.Service; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Tests; + +public class MockStringServiceTests : RecordedClientTestBase +{ + private const string c_basePath = "data"; + + public MockStringServiceTests(bool isAsync) + : base(isAsync, null) + { + RecordingOptions.SanitizersToRemove.Add("AZSDK3430"); // $..id + } + + public DirectoryInfo RepositoryRoot { get; } = FindRepoRoot(); + + [Test] + public async Task AddAndGet() + { + const string id = "first.one"; + const string expected = "The first value goes here"; + + using MockRestService service = new(c_basePath); + var options = ConfigureClientOptions(new ClientPipelineOptions()); + using var client = WrapClient(new MockRestServiceClient(service.HttpEndpoint, options)); + + ClientResult add = await client.AddAsync(id, expected, Token); + Assert.That(add, Is.Not.Null); + Assert.That(add.GetRawResponse().Status, Is.EqualTo(200)); + + string? retrieved = await client.GetAsync("first.one", Token); + Assert.That(retrieved, Is.EqualTo(expected)); + } + + [Test] + public async Task AddAndDelete() + { + const string id = "first.one"; + const string expected = "The first value goes here"; + + using MockRestService service = new(c_basePath); + var options = ConfigureClientOptions(new ClientPipelineOptions()); + using var client = WrapClient(new MockRestServiceClient(service.HttpEndpoint, options)); + + ClientResult add = await client.AddAsync(id, expected, Token); + Assert.That(add, Is.Not.Null); + Assert.That(add.GetRawResponse().Status, Is.EqualTo(200)); + + bool deleted = await client.RemoveAsync(id, Token); + Assert.That(deleted, Is.True); + + string? retrieved = await client.GetAsync("first.one", Token); + Assert.That(retrieved, Is.Null); + } + + #region overrides + + protected override ProxyServiceOptions CreateProxyServiceOptions() + => new() + { + DotnetExecutable = AssemblyHelper.GetDotnetExecutable()?.FullName!, + TestProxyDll = AssemblyHelper.GetAssemblyMetadata("TestProxyPath")!, + DevCertFile = Path.Combine( + RepositoryRoot.FullName, + "eng", + "common", + "testproxy", + "dotnet-devcert.pfx"), + DevCertPassword = "password", + StorageLocationDir = RepositoryRoot.FullName, + }; + + protected override RecordingStartInformation CreateRecordingSessionStartInfo() + => new() + { + RecordingFile = GetRecordingFile(), + AssetsFile = GetAssetsFile() + }; + + #endregion + + #region helper methods + + private static DirectoryInfo FindRepoRoot() + { + /** + * This code assumes that we are running in the standard Azure .Net SDK repository layout. With this in mind, + * we generally assume that we are running our test code from + * /artifacts/bin/// + * So to find the root we keep navigating up until we find a folder with a .git subfolder + * + * Another alternative would be to call: git rev-parse --show-toplevel + */ + + DirectoryInfo? current = new FileInfo(Assembly.GetExecutingAssembly().Location).Directory; + while (current != null && !current.EnumerateDirectories(".git").Any()) + { + current = current.Parent; + } + + return current + ?? throw new InvalidOperationException("Could not determine the root folder for this repository"); + } + + private string GetRecordingFile() + { + DirectoryInfo sourceDir = AssemblyHelper.GetAssemblySourceDir() + ?? throw new InvalidOperationException("Could not determine the source path for this assembly"); + string relativeDir = PathHelpers.GetRelativePath(RepositoryRoot.FullName, sourceDir.FullName); + return Path.Combine( + relativeDir, + "SessionRecords", + GetType().Name, + GetRecordedTestFileName()); + } + + private string? GetAssetsFile() + { + DirectoryInfo? sourceDir = AssemblyHelper.GetAssemblySourceDir() + ?? throw new InvalidOperationException("Could not determine the source path for this assembly"); + + // walk up the tree until we hit either the repository root, or found a folder with an "assets.json" file + for (; sourceDir != null && sourceDir?.FullName != RepositoryRoot.FullName; sourceDir = sourceDir.Parent) + { + string assetsFile = Path.Combine(sourceDir!.FullName, "assets.json"); + if (File.Exists(assetsFile)) + { + return assetsFile; + } + } + + return null; + } + + #endregion +} diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/tests/OpenAI.TestFramework.Tests.csproj b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/OpenAI.TestFramework.Tests.csproj new file mode 100644 index 000000000..e6934e292 --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/OpenAI.TestFramework.Tests.csproj @@ -0,0 +1,21 @@ + + + + $(RequiredTargetFrameworks);net8.0 + enable + enable + latest + + + + + + + + + + + + + + diff --git a/.dotnet.azure/sdk/openai/tools/TestFramework/tests/ProxyServiceTests.cs b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/ProxyServiceTests.cs new file mode 100644 index 000000000..496afbc6d --- /dev/null +++ b/.dotnet.azure/sdk/openai/tools/TestFramework/tests/ProxyServiceTests.cs @@ -0,0 +1,351 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel; +using System.ClientModel.Primitives; +using NUnit.Framework; +using OpenAI.TestFramework.Mocks; +using OpenAI.TestFramework.Recording; +using OpenAI.TestFramework.Recording.Matchers; +using OpenAI.TestFramework.Recording.Proxy; +using OpenAI.TestFramework.Recording.Proxy.Service; +using OpenAI.TestFramework.Recording.RecordingProxy; +using OpenAI.TestFramework.Recording.Sanitizers; +using OpenAI.TestFramework.Recording.Transforms; +using OpenAI.TestFramework.Utils; + +namespace OpenAI.TestFramework.Tests +{ + [NonParallelizable] + public class ProxyServiceTests(bool isAsync) : ClientTestBase(isAsync) + { + #region Properties and setup/teardown methods + + public DirectoryInfo? RecordingDir { get; private set; } + + public FileInfo? RecordingFile { get; private set; } + + [SetUp] + public void CreateRecordingFile() + { + RecordingDir = new DirectoryInfo(Path.Combine(Path.GetTempPath(), "RecordingTests", Guid.NewGuid().ToString())); + if (!RecordingDir.Exists) + { + RecordingDir.Create(); + } + + RecordingFile = new FileInfo(Path.Combine(RecordingDir.FullName, Path.GetRandomFileName() + ".json")); + } + + [TearDown] + public void DeleteRecordingFile() + { + if (RecordingFile != null) + { + RecordingFile.Delete(); + } + + if (RecordingDir != null) + { + RecordingDir.Delete(true); + } + } + + #endregion + + [Test] + public async Task StartProxy() + { + using ProxyService proxy = await CreateProxyServiceAsync(); + + Assert.That(proxy.HttpEndpoint, Is.Not.Null); + Assert.That(proxy.HttpEndpoint.Port, Is.GreaterThan(0).And.LessThanOrEqualTo(ushort.MaxValue)); + Assert.That(proxy.HttpsEndpoint, Is.Not.Null); + Assert.That(proxy.HttpsEndpoint.Port, Is.GreaterThan(0).And.LessThanOrEqualTo(ushort.MaxValue)); + + ProxyClientResult available = await proxy.Client.ListAvailableAsync(Token); + Assert.That(available, Is.Not.Null); + Assert.That(available.GetRawResponse(), Is.Not.Null); + Assert.That(available.GetRawResponse().Status, Is.EqualTo(200)); + Assert.That(available.Value, Is.Not.Null); + Assert.That(available.Value, Does.Contain("BodilessMatcher")); + } + + [Test] + public async Task AddSanitizers() + { + using ProxyService proxy = await CreateProxyServiceAsync(); + + List sanitizers = + [ + new BodyKeySanitizer("body.key"), + new BodyRegexSanitizer("(.*)") + { + GroupForReplace = "1", + Condition = new Recording.Condition() + { + ResponseHeader = new() + { + Key = "Content-Type", + ValueRegex = "json$" + }, + UriRegex = "https://[^/]+/sub" + } + }, + new HeaderRegexSanitizer("Authentication") + { + Value = "replacement", + GroupForReplace = "1", + Regex = "^Bearer " + }, + new UriRegexSanitizer("https://[^/]+/sub") + { + GroupForReplace = "1", + Value = "replacement" + } + ]; + + ProxyClientResult> result = await proxy.Client.AddSanitizersAsync(sanitizers, token: Token); + Assert.That(result, Is.Not.Null); + Assert.That(result.GetRawResponse(), Is.Not.Null); + Assert.That(result.GetRawResponse().Status, Is.EqualTo(200)); + Assert.That(result.Value, Is.Not.Null); + Assert.That(result.Value, Has.Count.EqualTo(sanitizers.Count)); + } + + [Test] + public async Task SetMatcher() + { + using ProxyService proxy = await CreateProxyServiceAsync(); + + BaseMatcher[] matchers = + [ + ExistingMatcher.Headerless, + ExistingMatcher.Bodiless, + new CustomMatcher() + { + CompareBodies = false, + ExcludedHeaders = "Authorization", + IgnoredHeaders = "Content-Length,Content-Type", + IgnoredQueryParameters = "page,version", + IgnoreQueryOrdering = true, + } + ]; + + foreach (var matcher in matchers) + { + ProxyClientResult result = await proxy.Client.SetMatcherAsync(matcher, token: Token); + Assert.That(result, Is.Not.Null); + Assert.That(result.GetRawResponse(), Is.Not.Null); + Assert.That(result.GetRawResponse().Status, Is.EqualTo(200)); + } + } + + [Test] + public async Task SetTransform() + { + using ProxyService proxy = await CreateProxyServiceAsync(); + + HeaderTransform transform = new("X-Client-RequestId") + { + Value = "replacement", + Condition = new() + { + UriRegex = "http.*://[^/]+/(.*)" + } + }; + + ProxyClientResult result = await proxy.Client.AddTransformAsync(transform, token: Token); + Assert.That(result, Is.Not.Null); + Assert.That(result.GetRawResponse(), Is.Not.Null); + Assert.That(result.GetRawResponse().Status, Is.EqualTo(200)); + } + + [Test] + public async Task StartStopRecording() + { + const string key1 = "key1"; + string value1 = Guid.NewGuid().ToString(); + const string key2 = "the.others"; + string value2 = "value"; + + using ProxyService proxy = await CreateProxyServiceAsync(); + + RecordingStartInformation startInfo = new() + { + RecordingFile = RecordingFile!.FullName, + }; + + ProxyClientResult result = await proxy.Client.StartRecordingAsync(startInfo, token: Token); + Assert.That(result, Is.Not.Null); + Assert.That(result.GetRawResponse(), Is.Not.Null); + Assert.That(result.GetRawResponse().Status, Is.EqualTo(200)); + + string recordingId = result.RecordingId!; + Assert.That(recordingId, Is.Not.Null); + + Dictionary additional = new() + { + [key1] = value1, + [key2] = value2, + }; + + result = await proxy.Client.StopRecordingAsync(recordingId, additional, false, Token); + + // At this point we should have a recording file + string recordedJson = File.ReadAllText(RecordingFile.FullName); + Assert.That(recordedJson, Does.Contain(key1) + .And.Contain(value1) + .And.Contain(key2) + .And.Contain(value2)); + } + + [Test] + public async Task RecordAndPlayback() + { + using ProxyService recordingProxyService = await CreateProxyServiceAsync(); + RecordingStartInformation startInfo = new() { RecordingFile = RecordingFile!.FullName }; + + using MockRestService mockRestService = new(); + TestRecordingOptions recordingOptions = new() + { + SanitizersToRemove = + { + "AZSDK3430", // $..id + } + }; + + string id1; + string id2; + + // Start recording, and capture some requests + { + ProxyClientResult result = await recordingProxyService.Client.StartRecordingAsync(startInfo, Token); + Assert.That(result, Is.Not.Null); + Assert.That(result.RecordingId, !Is.Null.Or.Empty); + string recordingId = result.RecordingId!; + + await using TestRecording recording = new(recordingId, RecordedTestMode.Record, recordingProxyService); + await recording.ApplyOptions(recordingOptions, Token); + + id1 = recording.Random.NewGuid().ToString(); + id2 = recording.Random.NewGuid().ToString(); + + await SendRequestsAsync(recording, mockRestService.HttpEndpoint, id1, id2, Token); + } + + // validate the service has what we expect + var serviceIds = mockRestService.GetAll() + .Select(e => e.id) + .ToArray(); + Assert.That(serviceIds, Is.EquivalentTo(new[] { id1, id2 })); + + mockRestService.Reset(); + + // Playback the recording + { + ProxyClientResult> result = await recordingProxyService.Client.StartPlaybackAsync(startInfo, Token); + Assert.That(result, Is.Not.Null); + Assert.That(result.RecordingId, !Is.Null.Or.Empty); + string recordingId = result.RecordingId!; + + await using TestRecording playback = new(recordingId, RecordedTestMode.Playback, recordingProxyService, result.Value); + await playback.ApplyOptions(recordingOptions, Token); + + string id = playback.Random.NewGuid().ToString(); + Assert.That(id, Is.EqualTo(id1)); + id = playback.Random.NewGuid().ToString(); + Assert.That(id, Is.EqualTo(id2)); + + await SendRequestsAsync(playback, mockRestService.HttpEndpoint, id1, id2, Token); + } + + // since we are playing back, the service should not have been called + Assert.That(mockRestService.GetAll().Count(), Is.EqualTo(0)); + + static async Task SendRequestsAsync(TestRecording recording, Uri restEndpoint, string id1, string id2, CancellationToken token) + { + const string value1 = "The value for the first item"; + const string value2 = "The secondary value goes here"; + const string id3 = "random"; + const string value3 = "Sure why not"; + + ClientPipelineOptions options = new(); + options.RetryPolicy = new TestClientRetryPolicy(0, TimeSpan.FromMilliseconds(100)); + options.Transport = new ProxyTransport(recording.GetProxyTransportOptions()); + + using MockRestServiceClient client = new(restEndpoint, options); + + ClientResult add = await client.AddAsync(id1, value1, token); + Assert.That(add, Is.Not.Null); + Assert.That(add.GetRawResponse().Status, Is.EqualTo(200)); + + add = await client.AddAsync(id2, value2, token); + Assert.That(add, Is.Not.Null); + Assert.That(add.GetRawResponse().Status, Is.EqualTo(200)); + + add = await client.AddAsync(id3, value3, token); + Assert.That(add, Is.Not.Null); + Assert.That(add.GetRawResponse().Status, Is.EqualTo(200)); + + ClientResult get = await client.GetAsync(id2, token); + Assert.That(add, Is.Not.Null); + Assert.That(add.GetRawResponse().Status, Is.EqualTo(200)); + Assert.That(get.Value, Is.EqualTo(value2)); + + get = await client.GetAsync(id3, token); + Assert.That(add, Is.Not.Null); + Assert.That(add.GetRawResponse().Status, Is.EqualTo(200)); + Assert.That(get.Value, Is.EqualTo(value3)); + + ClientResult remove = await client.RemoveAsync(id3, token); + Assert.That(remove.Value, Is.True); + + remove = await client.RemoveAsync("does.not.exist", token); + Assert.That(remove.Value, Is.False); + + get = await client.GetAsync(id3, token); + Assert.That(get, Is.Not.Null); + Assert.That(get.GetRawResponse().Status, Is.EqualTo(404)); + Assert.That(get.Value, Is.Null); + } + } + + #region helper methods + + private async Task CreateProxyServiceAsync() + { + ProxyService? proxy = null; + try + { + proxy = await ProxyService.CreateNewAsync( + new ProxyServiceOptions() + { + DotnetExecutable = AssemblyHelper.GetDotnetExecutable()?.FullName!, + TestProxyDll = AssemblyHelper.GetAssemblyMetadata("TestProxyPath")!, + StorageLocationDir = RecordingDir!.FullName + }, + Token); + + Assert.That(proxy, Is.Not.Null); + Assert.DoesNotThrow(proxy.ThrowOnErrors); + Assert.That(proxy.Client, Is.Not.Null); + + var wrappedClient = WrapClient(proxy.Client); + var setter = typeof(ProxyService).GetMethod("SetClient", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Instance) + ?? throw new InvalidOperationException("Could not find the ProxyService.SetClient method"); + setter.Invoke(proxy, [wrappedClient]); + + var ret = proxy; + proxy = null; + return ret; + } + finally + { + proxy?.Dispose(); + } + } + + #endregion + } +} diff --git a/.dotnet.azure/src/Azure.AI.OpenAI.csproj b/.dotnet.azure/src/Azure.AI.OpenAI.csproj deleted file mode 100644 index a0036c103..000000000 --- a/.dotnet.azure/src/Azure.AI.OpenAI.csproj +++ /dev/null @@ -1,41 +0,0 @@ - - - - - - - Azure OpenAI's official extension package for using OpenAI's .NET library with the Azure OpenAI Service. - - Azure.AI.OpenAI Client Library - 2.0.0 - beta.3 - Microsoft Azure OpenAI - true - - netstandard2.0 - true - $(NoWarn);CS1591;AZC0012;AZC0102;CS8002;CS0436;AZC0112;OPENAI001; - true - enable - preview - disable - Unsigned - - - - - - - - - - - - - - diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ba2c7d825..af321d330 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -15,12 +15,12 @@ jobs: version_suffix_args: ${{ format('--version-suffix="alpha.{0}"', github.run_number) }} steps: - name: Setup .NET - uses: actions/setup-dotnet@v1 + uses: actions/setup-dotnet@v4 with: dotnet-version: 8.x - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Build run: dotnet build @@ -52,7 +52,7 @@ jobs: working-directory: .dotnet - name: Upload artifact - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: build-artifacts path: ${{github.workspace}}/artifacts @@ -77,16 +77,52 @@ jobs: working-directory: .dotnet azure_build: # Development mirror only; validate AOAI compilation - runs-on: ubuntu-latest + strategy: + matrix: + # Builds on all supported .Net versions on all supported OS platforms. It also + # distributes the tests across these OSes. + os_net: + - runs-on: ubuntu-latest + test_framework: net8.0 + - runs-on: macos-latest + test_framework: net6.0 + - runs-on: windows-latest + test_framework: net462 + runs-on: ${{ matrix.os_net.runs-on }} steps: - - name: Setup .NET - uses: actions/setup-dotnet@v1 + - name: Setup .NET 6 and .Net 8 + uses: actions/setup-dotnet@v4 with: - dotnet-version: 8.x + # .Net Framework 4.6.2 is pre-installed on Windows 10 versions 1607 and newer + dotnet-version: | + 6.x + 8.x - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Build - run: dotnet build - working-directory: .dotnet.azure \ No newline at end of file + run: dotnet build -p:Configuration=unsigned + working-directory: .dotnet.azure/sdk/openai + + - name: Test (${{ matrix.os_net.test_framework }}) + run: >- + dotnet test + --framework ${{ matrix.os_net.test_framework }} + --no-build + --filter "(TestCategory!=Live)" + --logger "trx;LogFileName=Azure.AI.OpenAI.Tests.trx" + --logger:"console;verbosity=quiet" + --blame-crash-dump-type full + --blame-hang-dump-type full + --blame-hang-timeout 15minutes + --results-directory "${{github.workspace}}/TestResults" + Azure.AI.OpenAI/tests/Azure.AI.OpenAI.Tests.csproj + working-directory: .dotnet.azure/sdk/openai + + - name: Publish test results + uses: actions/upload-artifact@v4 + with: + name: TestResults-${{ matrix.os_net.runs-on }}-${{ matrix.os_net.test_framework }} + path: TestResults + if: ${{ always() }}