diff --git a/sdk/cognitiveservices/azopenai/assets.json b/sdk/cognitiveservices/azopenai/assets.json index 1d6ca0f9347d..e927cd33d454 100644 --- a/sdk/cognitiveservices/azopenai/assets.json +++ b/sdk/cognitiveservices/azopenai/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "go", "TagPrefix": "go/cognitiveservices/azopenai", - "Tag": "go/cognitiveservices/azopenai_49bcacb061" + "Tag": "go/cognitiveservices/azopenai_bf5b07347b" } diff --git a/sdk/cognitiveservices/azopenai/client_chat_completions_test.go b/sdk/cognitiveservices/azopenai/client_chat_completions_test.go new file mode 100644 index 000000000000..b15bdc3ec06b --- /dev/null +++ b/sdk/cognitiveservices/azopenai/client_chat_completions_test.go @@ -0,0 +1,187 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai_test + +import ( + "context" + "errors" + "io" + "os" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" + "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" + "github.com/stretchr/testify/require" +) + +var chatCompletionsRequest = azopenai.ChatCompletionsOptions{ + Messages: []*azopenai.ChatMessage{ + { + Role: to.Ptr(azopenai.ChatRole("user")), + Content: to.Ptr("Count to 10, with a comma between each number, no newlines and a period at the end. E.g., 1, 2, 3, ..."), + }, + }, + MaxTokens: to.Ptr(int32(1024)), + Temperature: to.Ptr(float32(0.0)), + Model: &openAIChatCompletionsModelDeployment, +} + +var expectedContent = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10." 
+var expectedRole = azopenai.ChatRoleAssistant + +func TestClient_GetChatCompletions(t *testing.T) { + cred, err := azopenai.NewKeyCredential(apiKey) + require.NoError(t, err) + + chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, chatCompletionsModelDeployment, newClientOptionsForTest(t)) + require.NoError(t, err) + + testGetChatCompletions(t, chatClient) +} + +func TestClient_GetChatCompletionsStream(t *testing.T) { + cred, err := azopenai.NewKeyCredential(apiKey) + require.NoError(t, err) + + chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, chatCompletionsModelDeployment, newClientOptionsForTest(t)) + require.NoError(t, err) + + testGetChatCompletionsStream(t, chatClient) +} + +func TestClient_OpenAI_GetChatCompletions(t *testing.T) { + chatClient := newOpenAIClientForTest(t) + testGetChatCompletions(t, chatClient) +} + +func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) { + chatClient := newOpenAIClientForTest(t) + testGetChatCompletionsStream(t, chatClient) +} + +func testGetChatCompletions(t *testing.T, client *azopenai.Client) { + expected := azopenai.ChatCompletions{ + Choices: []*azopenai.ChatChoice{ + { + Message: &azopenai.ChatChoiceMessage{ + Role: &expectedRole, + Content: &expectedContent, + }, + Index: to.Ptr(int32(0)), + FinishReason: to.Ptr(azopenai.CompletionsFinishReason("stop")), + }, + }, + Usage: &azopenai.CompletionsUsage{ + // these change depending on which model you use. These #'s work for gpt-4, which is + // what I'm using for these tests. 
+ CompletionTokens: to.Ptr(int32(29)), + PromptTokens: to.Ptr(int32(42)), + TotalTokens: to.Ptr(int32(71)), + }, + } + + resp, err := client.GetChatCompletions(context.Background(), chatCompletionsRequest, nil) + require.NoError(t, err) + + require.NotEmpty(t, resp.ID) + require.NotEmpty(t, resp.Created) + + expected.ID = resp.ID + expected.Created = resp.Created + + require.Equal(t, expected, resp.ChatCompletions) +} + +func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client) { + streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil) + require.NoError(t, err) + + // the data comes back differently for streaming + // 1. the text comes back in the ChatCompletion.Delta field + // 2. the role is only sent on the first streamed ChatCompletion + // check that the role came back as well. + var choices []*azopenai.ChatChoice + + for { + completion, err := streamResp.ChatCompletionsStream.Read() + + if errors.Is(err, io.EOF) { + break + } + + require.NoError(t, err) + require.Equal(t, 1, len(completion.Choices)) + choices = append(choices, completion.Choices[0]) + } + + var message string + + for _, choice := range choices { + if choice.Delta.Content == nil { + continue + } + + message += *choice.Delta.Content + } + + require.Equal(t, expectedContent, message, "Ultimately, the same result as GetChatCompletions(), just sent across the .Delta field instead") + + require.Equal(t, azopenai.ChatRoleAssistant, expectedRole) +} + +func TestClient_GetChatCompletions_DefaultAzureCredential(t *testing.T) { + if recording.GetRecordMode() == recording.PlaybackMode { + t.Skipf("Not running this test in playback (for now)") + } + + if os.Getenv("USE_TOKEN_CREDS") != "true" { + t.Skipf("USE_TOKEN_CREDS is not true, disabling token credential tests") + } + + recordingTransporter := newRecordingTransporter(t) + + dac, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{ + ClientOptions: 
policy.ClientOptions{ + Transport: recordingTransporter, + }, + }) + require.NoError(t, err) + + chatClient, err := azopenai.NewClient(endpoint, dac, chatCompletionsModelDeployment, &azopenai.ClientOptions{ + ClientOptions: policy.ClientOptions{Transport: recordingTransporter}, + }) + require.NoError(t, err) + + testGetChatCompletions(t, chatClient) +} + +func TestClient_GetChatCompletions_InvalidModel(t *testing.T) { + cred, err := azopenai.NewKeyCredential(apiKey) + require.NoError(t, err) + + chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, "thisdoesntexist", newClientOptionsForTest(t)) + require.NoError(t, err) + + _, err = chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{ + Messages: []*azopenai.ChatMessage{ + { + Role: to.Ptr(azopenai.ChatRole("user")), + Content: to.Ptr("Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."), + }, + }, + MaxTokens: to.Ptr(int32(1024)), + Temperature: to.Ptr(float32(0.0)), + }, nil) + + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, "DeploymentNotFound", respErr.ErrorCode) +} diff --git a/sdk/cognitiveservices/azopenai/client_completions_test.go b/sdk/cognitiveservices/azopenai/client_completions_test.go new file mode 100644 index 000000000000..dfd96747b4bb --- /dev/null +++ b/sdk/cognitiveservices/azopenai/client_completions_test.go @@ -0,0 +1,88 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. 
+ +package azopenai_test + +import ( + "context" + "log" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/require" +) + +func TestClient_GetCompletions(t *testing.T) { + type args struct { + ctx context.Context + deploymentID string + body azopenai.CompletionsOptions + options *azopenai.GetCompletionsOptions + } + cred, err := azopenai.NewKeyCredential(apiKey) + require.NoError(t, err) + + client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, completionsModelDeployment, newClientOptionsForTest(t)) + if err != nil { + log.Fatalf("%v", err) + } + tests := []struct { + name string + client *azopenai.Client + args args + want azopenai.GetCompletionsResponse + wantErr bool + }{ + { + name: "chatbot", + client: client, + args: args{ + ctx: context.TODO(), + deploymentID: completionsModelDeployment, + body: azopenai.CompletionsOptions{ + Prompt: []*string{to.Ptr("What is Azure OpenAI?")}, + MaxTokens: to.Ptr(int32(2048 - 127)), + Temperature: to.Ptr(float32(0.0)), + }, + options: nil, + }, + want: azopenai.GetCompletionsResponse{ + Completions: azopenai.Completions{ + Choices: []*azopenai.Choice{ + { + Text: to.Ptr("\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. 
Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models."), + Index: to.Ptr(int32(0)), + FinishReason: to.Ptr(azopenai.CompletionsFinishReason("stop")), + Logprobs: nil, + }, + }, + Usage: &azopenai.CompletionsUsage{ + CompletionTokens: to.Ptr(int32(85)), + PromptTokens: to.Ptr(int32(6)), + TotalTokens: to.Ptr(int32(91)), + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.client.GetCompletions(tt.args.ctx, tt.args.body, tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("Client.GetCompletions() error = %v, wantErr %v", err, tt.wantErr) + return + } + opts := cmpopts.IgnoreFields(azopenai.Completions{}, "Created", "ID") + if diff := cmp.Diff(tt.want.Completions, got.Completions, opts); diff != "" { + t.Errorf("Client.GetCompletions(): -want, +got:\n%s", diff) + } + }) + } +} diff --git a/sdk/cognitiveservices/azopenai/client_embeddings_test.go b/sdk/cognitiveservices/azopenai/client_embeddings_test.go new file mode 100644 index 000000000000..747120751283 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/client_embeddings_test.go @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. 
+ +package azopenai_test + +import ( + "context" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" + "github.com/stretchr/testify/require" +) + +func TestClient_GetEmbeddings_InvalidModel(t *testing.T) { + cred, err := azopenai.NewKeyCredential(apiKey) + require.NoError(t, err) + + chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, "thisdoesntexist", newClientOptionsForTest(t)) + require.NoError(t, err) + + _, err = chatClient.GetEmbeddings(context.Background(), azopenai.EmbeddingsOptions{}, nil) + + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, "DeploymentNotFound", respErr.ErrorCode) +} + +func TestClient_OpenAI_GetEmbeddings(t *testing.T) { + client := newOpenAIClientForTest(t) + modelID := "text-similarity-curie-001" + testGetEmbeddings(t, client, modelID) +} + +func TestClient_GetEmbeddings(t *testing.T) { + // model deployment points to `text-similarity-curie-001` + deploymentID := "embedding" + + cred, err := azopenai.NewKeyCredential(apiKey) + require.NoError(t, err) + + client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, deploymentID, newClientOptionsForTest(t)) + require.NoError(t, err) + + testGetEmbeddings(t, client, deploymentID) +} + +func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentID string) { + type args struct { + ctx context.Context + deploymentID string + body azopenai.EmbeddingsOptions + options *azopenai.GetEmbeddingsOptions + } + + tests := []struct { + name string + client *azopenai.Client + args args + want azopenai.GetEmbeddingsResponse + wantErr bool + }{ + { + name: "Embeddings", + client: client, + args: args{ + ctx: context.TODO(), + deploymentID: modelOrDeploymentID, + body: azopenai.EmbeddingsOptions{ + Input: []byte("\"Your text string goes here\""), + Model: &modelOrDeploymentID, + }, + options: nil, + }, + want: azopenai.GetEmbeddingsResponse{ 
+ azopenai.Embeddings{ + Data: []*azopenai.EmbeddingItem{}, + Usage: &azopenai.EmbeddingsUsage{}, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.client.GetEmbeddings(tt.args.ctx, tt.args.body, tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("Client.GetEmbeddings() error = %v, wantErr %v", err, tt.wantErr) + return + } + if len(got.Embeddings.Data[0].Embedding) != 4096 { + t.Errorf("Client.GetEmbeddings() len(Data) want 4096, got %d", len(got.Embeddings.Data)) + return + } + }) + } +} diff --git a/sdk/cognitiveservices/azopenai/client_shared_test.go b/sdk/cognitiveservices/azopenai/client_shared_test.go index c3d643709801..e1a879fed305 100644 --- a/sdk/cognitiveservices/azopenai/client_shared_test.go +++ b/sdk/cognitiveservices/azopenai/client_shared_test.go @@ -1,28 +1,34 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. See License.txt in the project root for license information. 
-package azopenai +package azopenai_test import ( + "crypto/tls" "fmt" + "net/http" "os" "regexp" "strings" "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" "github.com/joho/godotenv" "github.com/stretchr/testify/require" ) var ( - endpoint string // env: AOAI_ENDPOINT - apiKey string // env: AOAI_API_KEY - streamingModelDeployment string // env: AOAI_STREAMING_MODEL_DEPLOYMENT - - openAIKey string // env: OPENAI_API_KEY - openAIEndpoint string // env: OPENAI_ENDPOINT + endpoint string // env: AOAI_ENDPOINT + apiKey string // env: AOAI_API_KEY + completionsModelDeployment string // env: AOAI_COMPLETIONS_MODEL_DEPLOYMENT + chatCompletionsModelDeployment string // env: AOAI_CHAT_COMPLETIONS_MODEL_DEPLOYMENT + + openAIKey string // env: OPENAI_API_KEY + openAIEndpoint string // env: OPENAI_ENDPOINT + openAICompletionsModelDeployment string // env: OPENAI_COMPLETIONS_MODEL + openAIChatCompletionsModelDeployment string // env: OPENAI_CHAT_COMPLETIONS_MODEL ) const fakeEndpoint = "https://recordedhost/" @@ -34,7 +40,12 @@ func init() { apiKey = fakeAPIKey openAIKey = fakeAPIKey openAIEndpoint = fakeEndpoint - streamingModelDeployment = "text-davinci-003" + + completionsModelDeployment = "text-davinci-003" + openAICompletionsModelDeployment = "text-davinci-003" + + chatCompletionsModelDeployment = "gpt-4" + openAIChatCompletionsModelDeployment = "gpt-4" } else { if err := godotenv.Load(); err != nil { fmt.Printf("Failed to load .env file: %s\n", err) } @@ -51,10 +62,13 @@ func init() { apiKey = os.Getenv("AOAI_API_KEY") // Ex: text-davinci-003 - streamingModelDeployment = os.Getenv("AOAI_STREAMING_MODEL_DEPLOYMENT") + completionsModelDeployment = os.Getenv("AOAI_COMPLETIONS_MODEL_DEPLOYMENT") + chatCompletionsModelDeployment = os.Getenv("AOAI_CHAT_COMPLETIONS_MODEL_DEPLOYMENT") openAIKey = os.Getenv("OPENAI_API_KEY") openAIEndpoint = 
os.Getenv("OPENAI_ENDPOINT") + openAICompletionsModelDeployment = os.Getenv("OPENAI_COMPLETIONS_MODEL") + openAIChatCompletionsModelDeployment = os.Getenv("OPENAI_CHAT_COMPLETIONS_MODEL") if openAIEndpoint != "" && !strings.HasSuffix(openAIEndpoint, "/") { // (this just makes recording replacement easier) @@ -92,8 +106,32 @@ func newRecordingTransporter(t *testing.T) policy.Transporter { return transport } -func newClientOptionsForTest(t *testing.T) *ClientOptions { - co := &ClientOptions{} - co.Transport = newRecordingTransporter(t) +func newClientOptionsForTest(t *testing.T) *azopenai.ClientOptions { + co := &azopenai.ClientOptions{} + + if recording.GetRecordMode() == recording.LiveMode { + keyLogPath := os.Getenv("SSLKEYLOGFILE") + + if keyLogPath == "" { + return nil + } + + keyLogWriter, err := os.OpenFile(keyLogPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0777) + require.NoError(t, err) + + t.Cleanup(func() { + _ = keyLogWriter.Close() + }) + + tp := http.DefaultTransport.(*http.Transport).Clone() + tp.TLSClientConfig = &tls.Config{ + KeyLogWriter: keyLogWriter, + } + + co.Transport = &http.Client{Transport: tp} + } else { + co.Transport = newRecordingTransporter(t) + } + return co } diff --git a/sdk/cognitiveservices/azopenai/client_test.go b/sdk/cognitiveservices/azopenai/client_test.go index ecfe03dc29ff..e4dc512061df 100644 --- a/sdk/cognitiveservices/azopenai/client_test.go +++ b/sdk/cognitiveservices/azopenai/client_test.go @@ -4,77 +4,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. See License.txt in the project root for license information. 
-package azopenai +package azopenai_test import ( "context" - "log" "net/http" - "os" "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" + "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" "github.com/stretchr/testify/require" ) -func TestClient_GetChatCompletions(t *testing.T) { - deploymentID := "gpt-35-turbo" - - cred, err := NewKeyCredential(apiKey) - require.NoError(t, err) - - chatClient, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, newClientOptionsForTest(t)) - require.NoError(t, err) - - testGetChatCompletions(t, chatClient, deploymentID) -} - -func TestClient_GetChatCompletions_DefaultAzureCredential(t *testing.T) { - if recording.GetRecordMode() == recording.PlaybackMode { - t.Skipf("Not running this test in playback (for now)") - } - - if os.Getenv("USE_TOKEN_CREDS") != "true" { - t.Skipf("USE_TOKEN_CREDS is not true, disabling token credential tests") - } - - deploymentID := "gpt-35-turbo" - - recordingTransporter := newRecordingTransporter(t) - - dac, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{ - ClientOptions: policy.ClientOptions{ - Transport: recordingTransporter, - }, - }) - require.NoError(t, err) - - chatClient, err := NewClient(endpoint, dac, deploymentID, &ClientOptions{ - ClientOptions: policy.ClientOptions{Transport: recordingTransporter}, - }) - require.NoError(t, err) - - testGetChatCompletions(t, chatClient, deploymentID) -} - -func TestClient_OpenAI_GetChatCompletions(t *testing.T) { - chatClient := newOpenAIClientForTest(t) - testGetChatCompletions(t, chatClient, "gpt-3.5-turbo") -} - func TestClient_OpenAI_InvalidModel(t *testing.T) { chatClient := 
newOpenAIClientForTest(t) - _, err := chatClient.GetChatCompletions(context.Background(), ChatCompletionsOptions{ - Messages: []*ChatMessage{ + _, err := chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{ + Messages: []*azopenai.ChatMessage{ { - Role: to.Ptr(ChatRoleSystem), + Role: to.Ptr(azopenai.ChatRoleSystem), Content: to.Ptr("hello"), }, }, @@ -87,263 +36,15 @@ func TestClient_OpenAI_InvalidModel(t *testing.T) { require.Contains(t, respErr.Error(), "The model `non-existent-model` does not exist") } -func testGetChatCompletions(t *testing.T, chatClient *Client, modelOrDeployment string) { - type args struct { - ctx context.Context - deploymentID string - body ChatCompletionsOptions - options *GetChatCompletionsOptions - } - - tests := []struct { - name string - client *Client - args args - want GetChatCompletionsResponse - - wantErr bool - }{ - { - name: "ChatCompletions", - client: chatClient, - args: args{ - ctx: context.TODO(), - deploymentID: modelOrDeployment, - body: ChatCompletionsOptions{ - Messages: []*ChatMessage{ - { - Role: to.Ptr(ChatRole("user")), - Content: to.Ptr("Count to 100, with a comma between each number and no newlines. 
E.g., 1, 2, 3, ..."), - }, - }, - MaxTokens: to.Ptr(int32(1024)), - Temperature: to.Ptr(float32(0.0)), - Model: &modelOrDeployment, - }, - options: nil, - }, - want: GetChatCompletionsResponse{ - ChatCompletions: ChatCompletions{ - Choices: []*ChatChoice{ - { - Message: &ChatChoiceMessage{ - Role: to.Ptr(ChatRole("assistant")), - Content: to.Ptr("1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100."), - }, - Index: to.Ptr(int32(0)), - FinishReason: to.Ptr(CompletionsFinishReason("stop")), - }, - }, - Usage: &CompletionsUsage{ - CompletionTokens: to.Ptr(int32(299)), - PromptTokens: to.Ptr(int32(37)), - TotalTokens: to.Ptr(int32(336)), - }, - }, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.client.GetChatCompletions(tt.args.ctx, tt.args.body, tt.args.options) - if (err != nil) != tt.wantErr { - t.Errorf("Client.GetChatCompletions() error = %v, wantErr %v", err, tt.wantErr) - return - } - opts := cmpopts.IgnoreFields(ChatCompletions{}, "Created", "ID") - if diff := cmp.Diff(tt.want.ChatCompletions, got.ChatCompletions, opts); diff != "" { - t.Errorf("Client.GetCompletions(): -want, +got:\n%s", diff) - } - }) - } -} - -func TestClient_GetChatCompletions_InvalidModel(t *testing.T) { - cred, err := NewKeyCredential(apiKey) - require.NoError(t, err) - - chatClient, err := NewClientWithKeyCredential(endpoint, cred, "thisdoesntexist", newClientOptionsForTest(t)) - require.NoError(t, err) - - _, err = chatClient.GetChatCompletions(context.Background(), ChatCompletionsOptions{ - Messages: []*ChatMessage{ - { - Role: to.Ptr(ChatRole("user")), - Content: to.Ptr("Count to 100, 
with a comma between each number and no newlines. E.g., 1, 2, 3, ..."), - }, - }, - MaxTokens: to.Ptr(int32(1024)), - Temperature: to.Ptr(float32(0.0)), - }, nil) - - var respErr *azcore.ResponseError - require.ErrorAs(t, err, &respErr) - require.Equal(t, "DeploymentNotFound", respErr.ErrorCode) -} - -func TestClient_GetEmbeddings_InvalidModel(t *testing.T) { - cred, err := NewKeyCredential(apiKey) - require.NoError(t, err) - - chatClient, err := NewClientWithKeyCredential(endpoint, cred, "thisdoesntexist", newClientOptionsForTest(t)) - require.NoError(t, err) - - _, err = chatClient.GetEmbeddings(context.Background(), EmbeddingsOptions{}, nil) - - var respErr *azcore.ResponseError - require.ErrorAs(t, err, &respErr) - require.Equal(t, "DeploymentNotFound", respErr.ErrorCode) -} - -func TestClient_GetCompletions(t *testing.T) { - type args struct { - ctx context.Context - deploymentID string - body CompletionsOptions - options *GetCompletionsOptions - } - cred, err := NewKeyCredential(apiKey) - require.NoError(t, err) - - client, err := NewClientWithKeyCredential(endpoint, cred, streamingModelDeployment, newClientOptionsForTest(t)) - if err != nil { - log.Fatalf("%v", err) - } - tests := []struct { - name string - client *Client - args args - want GetCompletionsResponse - wantErr bool - }{ - { - name: "chatbot", - client: client, - args: args{ - ctx: context.TODO(), - deploymentID: streamingModelDeployment, - body: CompletionsOptions{ - Prompt: []*string{to.Ptr("What is Azure OpenAI?")}, - MaxTokens: to.Ptr(int32(2048 - 127)), - Temperature: to.Ptr(float32(0.0)), - }, - options: nil, - }, - want: GetCompletionsResponse{ - Completions: Completions{ - Choices: []*Choice{ - { - Text: to.Ptr("\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. 
Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models."), - Index: to.Ptr(int32(0)), - FinishReason: to.Ptr(CompletionsFinishReason("stop")), - Logprobs: nil, - }, - }, - Usage: &CompletionsUsage{ - CompletionTokens: to.Ptr(int32(85)), - PromptTokens: to.Ptr(int32(6)), - TotalTokens: to.Ptr(int32(91)), - }, - }, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.client.GetCompletions(tt.args.ctx, tt.args.body, tt.args.options) - if (err != nil) != tt.wantErr { - t.Errorf("Client.GetCompletions() error = %v, wantErr %v", err, tt.wantErr) - return - } - opts := cmpopts.IgnoreFields(Completions{}, "Created", "ID") - if diff := cmp.Diff(tt.want.Completions, got.Completions, opts); diff != "" { - t.Errorf("Client.GetCompletions(): -want, +got:\n%s", diff) - } - }) - } -} - -func TestClient_OpenAI_GetEmbeddings(t *testing.T) { - client := newOpenAIClientForTest(t) - modelID := "text-similarity-curie-001" - testGetEmbeddings(t, client, modelID) -} - -func TestClient_GetEmbeddings(t *testing.T) { - // model deployment points to `text-similarity-curie-001` - deploymentID := "embedding" - - cred, err := NewKeyCredential(apiKey) - require.NoError(t, err) - - client, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, newClientOptionsForTest(t)) - require.NoError(t, err) - - testGetEmbeddings(t, client, deploymentID) -} - -func testGetEmbeddings(t *testing.T, client *Client, modelOrDeploymentID string) { - type args struct { - ctx context.Context - deploymentID string - body EmbeddingsOptions - options *GetEmbeddingsOptions - } - - tests := []struct { - name string - client *Client - args args - want GetEmbeddingsResponse - wantErr bool - }{ - { - name: "Embeddings", - client: client, - args: args{ - ctx: context.TODO(), - deploymentID: 
modelOrDeploymentID, - body: EmbeddingsOptions{ - Input: []byte("\"Your text string goes here\""), - Model: &modelOrDeploymentID, - }, - options: nil, - }, - want: GetEmbeddingsResponse{ - Embeddings{ - Data: []*EmbeddingItem{}, - Usage: &EmbeddingsUsage{}, - }, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.client.GetEmbeddings(tt.args.ctx, tt.args.body, tt.args.options) - if (err != nil) != tt.wantErr { - t.Errorf("Client.GetEmbeddings() error = %v, wantErr %v", err, tt.wantErr) - return - } - if len(got.Embeddings.Data[0].Embedding) != 4096 { - t.Errorf("Client.GetEmbeddings() len(Data) want 4096, got %d", len(got.Embeddings.Data)) - return - } - }) - } -} - -func newOpenAIClientForTest(t *testing.T) *Client { +func newOpenAIClientForTest(t *testing.T) *azopenai.Client { if openAIKey == "" { t.Skipf("OPENAI_API_KEY not defined, skipping OpenAI public endpoint test") } - cred, err := NewKeyCredential(openAIKey) + cred, err := azopenai.NewKeyCredential(openAIKey) require.NoError(t, err) - chatClient, err := NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t)) + chatClient, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t)) require.NoError(t, err) return chatClient diff --git a/sdk/cognitiveservices/azopenai/custom_client.go b/sdk/cognitiveservices/azopenai/custom_client.go index cc4654dcf733..a2679b0ff883 100644 --- a/sdk/cognitiveservices/azopenai/custom_client.go +++ b/sdk/cognitiveservices/azopenai/custom_client.go @@ -109,17 +109,20 @@ func (b *openAIPolicy) Do(req *policy.Request) (*http.Response, error) { } // Methods that return streaming response - type streamCompletionsOptions struct { - CompletionsOptions + // we strip out the 'stream' field from the options exposed to the customer so + // now we need to add it back in. 
+ any Stream bool `json:"stream"` } func (o streamCompletionsOptions) MarshalJSON() ([]byte, error) { - bytes, err := o.CompletionsOptions.MarshalJSON() + bytes, err := json.Marshal(o.any) + if err != nil { return nil, err } + objectMap := make(map[string]any) err = json.Unmarshal(bytes, &objectMap) if err != nil { @@ -164,3 +167,34 @@ func formatAzureOpenAIURL(endpoint, deploymentID string) string { escapedDeplID := url.PathEscape(deploymentID) return runtime.JoinPaths(endpoint, "openai", "deployments", escapedDeplID) } + +// GetChatCompletionsStream - Return the chat completions for a given prompt as a sequence of events. +// If the operation fails it returns an *azcore.ResponseError type. +// - options - GetChatCompletionsStreamOptions contains the optional parameters for the Client.GetChatCompletionsStream method. +func (client *Client) GetChatCompletionsStream(ctx context.Context, body ChatCompletionsOptions, options *GetChatCompletionsStreamOptions) (GetChatCompletionsStreamResponse, error) { + req, err := client.getChatCompletionsCreateRequest(ctx, ChatCompletionsOptions{}, &GetChatCompletionsOptions{}) + + if err != nil { + return GetChatCompletionsStreamResponse{}, err + } + + if err := runtime.MarshalAsJSON(req, streamCompletionsOptions{body, true}); err != nil { + return GetChatCompletionsStreamResponse{}, err + } + + runtime.SkipBodyDownload(req) + + resp, err := client.internal.Pipeline().Do(req) + + if err != nil { + return GetChatCompletionsStreamResponse{}, err + } + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return GetChatCompletionsStreamResponse{}, runtime.NewResponseError(resp) + } + + return GetChatCompletionsStreamResponse{ + ChatCompletionsStream: newEventReader[ChatCompletions](resp.Body), + }, nil +} diff --git a/sdk/cognitiveservices/azopenai/custom_client_test.go b/sdk/cognitiveservices/azopenai/custom_client_test.go index 99694ccb1315..9bea7687bd8b 100644 --- a/sdk/cognitiveservices/azopenai/custom_client_test.go +++ 
b/sdk/cognitiveservices/azopenai/custom_client_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. See License.txt in the project root for license information. -package azopenai +package azopenai_test import ( "context" @@ -15,6 +15,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" "github.com/stretchr/testify/require" ) @@ -23,19 +24,19 @@ func TestNewClient(t *testing.T) { endpoint string credential azcore.TokenCredential deploymentID string - options *ClientOptions + options *azopenai.ClientOptions } tests := []struct { name string args args - want *Client + want *azopenai.Client wantErr bool }{ // TODO: Add test cases. } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewClient(tt.args.endpoint, tt.args.credential, tt.args.deploymentID, tt.args.options) + got, err := azopenai.NewClient(tt.args.endpoint, tt.args.credential, tt.args.deploymentID, tt.args.options) if (err != nil) != tt.wantErr { t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr) return @@ -50,21 +51,21 @@ func TestNewClient(t *testing.T) { func TestNewClientWithKeyCredential(t *testing.T) { type args struct { endpoint string - credential KeyCredential + credential azopenai.KeyCredential deploymentID string - options *ClientOptions + options *azopenai.ClientOptions } tests := []struct { name string args args - want *Client + want *azopenai.Client wantErr bool }{ // TODO: Add test cases. 
} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewClientWithKeyCredential(tt.args.endpoint, tt.args.credential, tt.args.deploymentID, tt.args.options) + got, err := azopenai.NewClientWithKeyCredential(tt.args.endpoint, tt.args.credential, tt.args.deploymentID, tt.args.options) if (err != nil) != tt.wantErr { t.Errorf("NewClientWithKeyCredential() error = %v, wantErr %v", err, tt.wantErr) return @@ -77,16 +78,16 @@ func TestNewClientWithKeyCredential(t *testing.T) { } func TestClient_GetCompletionsStream(t *testing.T) { - body := CompletionsOptions{ + body := azopenai.CompletionsOptions{ Prompt: []*string{to.Ptr("What is Azure OpenAI?")}, MaxTokens: to.Ptr(int32(2048)), Temperature: to.Ptr(float32(0.0)), } - cred, err := NewKeyCredential(apiKey) + cred, err := azopenai.NewKeyCredential(apiKey) require.NoError(t, err) - client, err := NewClientWithKeyCredential(endpoint, cred, streamingModelDeployment, newClientOptionsForTest(t)) + client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, completionsModelDeployment, newClientOptionsForTest(t)) if err != nil { t.Errorf("NewClientWithKeyCredential() error = %v", err) return diff --git a/sdk/cognitiveservices/azopenai/custom_models.go b/sdk/cognitiveservices/azopenai/custom_models.go index fb320bebfb02..82822110346b 100644 --- a/sdk/cognitiveservices/azopenai/custom_models.go +++ b/sdk/cognitiveservices/azopenai/custom_models.go @@ -8,13 +8,24 @@ package azopenai // Models for methods that return streaming response -// GetCompletionsStreamOptions contains the optional parameters for the Client.GetCompletions method. +// GetCompletionsStreamOptions contains the optional parameters for the [Client.GetCompletionsStream] method. type GetCompletionsStreamOptions struct { // placeholder for future optional parameters } -// GetCompletionsStreamResponse is the response from [GetCompletionsStream]. +// GetCompletionsStreamResponse is the response from [Client.GetCompletionsStream]. 
type GetCompletionsStreamResponse struct { // CompletionsStream returns the stream of completions. Token limits and other settings may limit the number of completions returned by the service. CompletionsStream *EventReader[Completions] } + +// GetChatCompletionsStreamOptions contains the optional parameters for the [Client.GetChatCompletionsStream] method. +type GetChatCompletionsStreamOptions struct { + // placeholder for future optional parameters +} + +// GetChatCompletionsStreamResponse is the response from [Client.GetChatCompletionsStream]. +type GetChatCompletionsStreamResponse struct { + // ChatCompletionsStream returns the stream of completions. Token limits and other settings may limit the number of chat completions returned by the service. + ChatCompletionsStream *EventReader[ChatCompletions] +} diff --git a/sdk/cognitiveservices/azopenai/sample.env b/sdk/cognitiveservices/azopenai/sample.env new file mode 100644 index 000000000000..c4f5e9396a5d --- /dev/null +++ b/sdk/cognitiveservices/azopenai/sample.env @@ -0,0 +1,17 @@ +# Azure OpenAI +AOAI_ENDPOINT=https://.openai.azure.com/ +AOAI_API_KEY= + +# These names will come from your model deployments in your Azure OpenAI resource +# ex: text-davinci-003 +AOAI_COMPLETIONS_MODEL_DEPLOYMENT= +# ex: gpt-4 +AOAI_CHAT_COMPLETIONS_MODEL_DEPLOYMENT= + +# public OpenAI +OPENAI_ENDPOINT=https://api.openai.com/v1 +OPENAI_API_KEY= +# ex: text-davinci-003 +OPENAI_COMPLETIONS_MODEL= +# ex: gpt-4 +OPENAI_CHAT_COMPLETIONS_MODEL=