diff --git a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs index b7f6431121f0..3e80ae5d0394 100644 --- a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs +++ b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs @@ -11,6 +11,9 @@ public static partial class AIOpenAIModelFactory } public static partial class AzureOpenAIModelFactory { + public static Azure.AI.OpenAI.ChatChoice ChatChoice(Azure.AI.OpenAI.ChatMessage message = null, int index = 0, Azure.AI.OpenAI.CompletionsFinishReason reason = default(Azure.AI.OpenAI.CompletionsFinishReason)) { throw null; } + public static Azure.AI.OpenAI.ChatChoice ChatChoice(int index = 0, Azure.AI.OpenAI.CompletionsFinishReason finishReason = default(Azure.AI.OpenAI.CompletionsFinishReason)) { throw null; } + public static Azure.AI.OpenAI.ChatCompletions ChatCompletions(string id = null, System.DateTimeOffset created = default(System.DateTimeOffset), System.Collections.Generic.IEnumerable<Azure.AI.OpenAI.ChatChoice> choices = null, Azure.AI.OpenAI.CompletionsUsage usage = null) { throw null; } public static Azure.AI.OpenAI.Choice Choice(string text = null, int index = 0, Azure.AI.OpenAI.CompletionsLogProbabilityModel logProbabilityModel = null, Azure.AI.OpenAI.CompletionsFinishReason finishReason = default(Azure.AI.OpenAI.CompletionsFinishReason)) { throw null; } } public partial class ChatChoice diff --git a/sdk/openai/Azure.AI.OpenAI/src/Azure.AI.OpenAI.csproj b/sdk/openai/Azure.AI.OpenAI/src/Azure.AI.OpenAI.csproj index 374e26185599..9229976ca5a4 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Azure.AI.OpenAI.csproj +++ b/sdk/openai/Azure.AI.OpenAI/src/Azure.AI.OpenAI.csproj @@ -24,5 +24,8 @@ + + + diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs index 9914245030d6..4d6083f9f7f9 100644 ---
a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs @@ -1,13 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// - #nullable disable using System; using System.Collections.Generic; -using System.Linq; namespace Azure.AI.OpenAI { @@ -31,5 +28,46 @@ public static Choice Choice(string text = null, int index = default, Completions return new Choice(text, index, logProbabilityModel, finishReason); } + + /// <summary> Initializes a new instance of ChatChoice. </summary> + /// <param name="index"> The ordered index associated with this chat completions choice. </param> + /// <param name="finishReason"> The reason that this chat completions choice completed its generation. </param> + /// <returns> A new <see cref="ChatChoice"/> instance for mocking. </returns> + public static ChatChoice ChatChoice(int index = default, CompletionsFinishReason finishReason = default) + { + return new ChatChoice(index, finishReason); + } + + /// <summary> Initializes a new instance of ChatChoice. </summary> + /// <param name="message"> The chat message for a given chat completions prompt. </param> + /// <param name="index"> The ordered index associated with this chat completions choice. </param> + /// <param name="reason"> The reason that this chat completions choice completed its generation. </param> + /// <returns> A new <see cref="ChatChoice"/> instance for mocking. </returns> + public static ChatChoice ChatChoice(ChatMessage message = null, int index = default, CompletionsFinishReason reason = default) + { + return new ChatChoice(message, index, reason, null); + } + + /// <summary> Initializes a new instance of ChatCompletions. </summary> + /// <param name="id"> A unique identifier associated with this chat completions response. </param> + /// <param name="created"> The first timestamp associated with generation activity for this completions response. </param> + /// <param name="choices"> The collection of completions choices associated with this completions response. </param> + /// <param name="usage"> Usage information for tokens processed and generated as part of this completions operation. </param> + /// <returns> A new <see cref="ChatCompletions"/> instance for mocking. </returns>
+ public static ChatCompletions ChatCompletions(string id = null, DateTimeOffset created = default(DateTimeOffset), IEnumerable<ChatChoice> choices = null, CompletionsUsage usage = null) + { + choices ??= new List<ChatChoice>(); + usage ??= AIOpenAIModelFactory.CompletionsUsage(); + + long constrainedUnixTimeInSec = Math.Max( + Math.Min(created.ToUnixTimeSeconds(), int.MaxValue), + int.MinValue); + + return new ChatCompletions( + id: id ?? string.Empty, + internalCreatedSecondsAfterUnixEpoch: (int)constrainedUnixTimeInSec, + choices: choices, + usage: usage); + } } } diff --git a/sdk/openai/Azure.AI.OpenAI/tests/Azure.AI.OpenAI.Tests.csproj b/sdk/openai/Azure.AI.OpenAI/tests/Azure.AI.OpenAI.Tests.csproj index 1782df2211f8..c09b9f56b63f 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/Azure.AI.OpenAI.Tests.csproj +++ b/sdk/openai/Azure.AI.OpenAI/tests/Azure.AI.OpenAI.Tests.csproj @@ -3,7 +3,7 @@ $(RequiredTargetFrameworks) - $(NoWarn);CS1591 + $(NoWarn);CS1591;SA1508;SA1507;SA1505 diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs new file mode 100644 index 000000000000..b4dbf736dcec --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+ +using System; +using System.Linq; +using NUnit.Framework; + +namespace Azure.AI.OpenAI.Tests +{ + [TestFixture] + public class OpenAIInferenceModelFactoryTests + { + [Test] + public void TestCompletionsLogProbabilityModel() + { + var logProbabilityModel = AIOpenAIModelFactory.CompletionsLogProbabilityModel( + new[] { "one", "two" }, + new float?[] { 0.9f, 0.72f }); + Assert.That(logProbabilityModel, Is.Not.Null); + Assert.That(logProbabilityModel.Tokens.Count, Is.EqualTo(2)); + Assert.That(logProbabilityModel.Tokens[0], Is.EqualTo("one")); + Assert.That(logProbabilityModel.Tokens[1], Is.EqualTo("two")); + Assert.That(logProbabilityModel.TokenLogProbabilities.Count, Is.EqualTo(2)); + Assert.That(logProbabilityModel.TokenLogProbabilities[0], Is.EqualTo(0.9F).Within(2).Percent); + Assert.That(logProbabilityModel.TokenLogProbabilities[1], Is.EqualTo(0.72F).Within(2).Percent); + Assert.That(logProbabilityModel.TopLogProbabilities, Is.Empty); + Assert.That(logProbabilityModel.TextOffsets, Is.Empty); + } + + [Test] + public void TestChatChoices() + { + var expectedChoices = new[] + { + new { role = ChatRole.Assistant, text = "First one", index = 0, reason = CompletionsFinishReason.ContentFiltered }, + new { role = ChatRole.System, text = "Second one", index = -1, reason = CompletionsFinishReason.Stopped }, + new { role = ChatRole.User, text = "Final one", index = 3, reason = CompletionsFinishReason.TokenLimitReached }, + }; + + var chatChoices = expectedChoices + .Select(e => AzureOpenAIModelFactory.ChatChoice( + new ChatMessage(e.role, e.text), + e.index, + e.reason)) + .ToArray(); + Assert.That(chatChoices, Is.All.Not.Null); + + for (int i = 0; i < chatChoices.Length; i++) + { + var actual = chatChoices[i]; + var expected = expectedChoices[i]; + + Assert.That(actual.Message, Is.Not.Null); + Assert.That(actual.Message.Role, Is.EqualTo(expected.role)); + Assert.That(actual.Message.Content, Is.EqualTo(expected.text)); + Assert.That(actual.Index, 
Is.EqualTo(expected.index)); + Assert.That(actual.FinishReason, Is.EqualTo(expected.reason)); + } + } + + [Test] + public void TestChatCompletions() + { + string expectedId = Guid.NewGuid().ToString(); + DateTimeOffset expectedCreationTime = DateTimeOffset.Now; + + var expectedChoices = new[] + { + new { role = ChatRole.Assistant, text = "First one", index = 0, reason = CompletionsFinishReason.ContentFiltered }, + new { role = ChatRole.System, text = "Second one", index = -1, reason = CompletionsFinishReason.Stopped }, + new { role = ChatRole.User, text = "Final one", index = 3, reason = CompletionsFinishReason.TokenLimitReached }, + }; + + var chatChoices = expectedChoices + .Select(e => AzureOpenAIModelFactory.ChatChoice( + new ChatMessage(e.role, e.text), + e.index, + e.reason)) + .ToArray(); + + var chatCompletions = AzureOpenAIModelFactory.ChatCompletions( + expectedId, + expectedCreationTime, + chatChoices, + AIOpenAIModelFactory.CompletionsUsage(2, 5, 7)); + + Assert.That(chatCompletions, Is.Not.Null); + Assert.That(chatCompletions.Id, Is.EqualTo(expectedId)); + Assert.That(chatCompletions.Created, Is.EqualTo(expectedCreationTime).Within(TimeSpan.FromSeconds(1))); // Internally we use Unix time with second precision + Assert.That(chatCompletions.Choices, Is.EquivalentTo(chatChoices)); + Assert.That(chatCompletions.Usage, Is.Not.Null); + Assert.That(chatCompletions.Usage.CompletionTokens, Is.EqualTo(2)); + Assert.That(chatCompletions.Usage.PromptTokens, Is.EqualTo(5)); + Assert.That(chatCompletions.Usage.TotalTokens, Is.EqualTo(7)); + } + } +}