From b681440cc391d1424d811facc1ca1dab7ab50e56 Mon Sep 17 00:00:00 2001 From: Richard Park Date: Mon, 10 Jul 2023 23:29:08 -0700 Subject: [PATCH 1/8] Updating to the 2023-07-01 API surface - Adding in functions support and example. - Added in accomodation for content filtering info. - Make it so we can use separate service instances for some tests so we can test against the latest upcoming fixes/changes. --- sdk/cognitiveservices/azopenai/autorest.md | 60 ++- sdk/cognitiveservices/azopenai/client.go | 26 +- .../azopenai/client_chat_completions_test.go | 58 +- .../azopenai/client_completions_test.go | 100 ++-- .../azopenai/client_embeddings_test.go | 4 + .../azopenai/client_functions_test.go | 101 ++++ .../azopenai/client_rai_test.go | 93 ++++ .../azopenai/client_shared_test.go | 90 +++- sdk/cognitiveservices/azopenai/client_test.go | 19 +- sdk/cognitiveservices/azopenai/constants.go | 41 ++ .../azopenai/custom_client.go | 4 + .../azopenai/custom_client_functions.go | 34 ++ .../azopenai/custom_client_image_test.go | 12 +- .../azopenai/custom_client_test.go | 34 +- .../azopenai/custom_models.go | 52 ++ .../azopenai/custom_models_test.go | 66 +++ .../azopenai/event_reader.go | 1 - .../azopenai/event_reader_test.go | 1 - sdk/cognitiveservices/azopenai/go.mod | 1 - sdk/cognitiveservices/azopenai/go.sum | 2 - sdk/cognitiveservices/azopenai/models.go | 242 ++++++++- .../azopenai/models_serde.go | 495 +++++++++++++++++- .../content_filter_response_error.json | 30 ++ .../azopenai/testdata/tsp-location.yaml | 2 +- 24 files changed, 1418 insertions(+), 150 deletions(-) create mode 100644 sdk/cognitiveservices/azopenai/client_functions_test.go create mode 100644 sdk/cognitiveservices/azopenai/client_rai_test.go create mode 100644 sdk/cognitiveservices/azopenai/custom_client_functions.go create mode 100644 sdk/cognitiveservices/azopenai/custom_models_test.go create mode 100644 sdk/cognitiveservices/azopenai/testdata/content_filter_response_error.json diff --git a/sdk/cognitiveservices/azopenai/autorest.md b/sdk/cognitiveservices/azopenai/autorest.md index f886d97fba3e..ad7af1f0a84e 100644 --- a/sdk/cognitiveservices/azopenai/autorest.md +++ b/sdk/cognitiveservices/azopenai/autorest.md @@ -16,7 +16,8 @@ go: true use: "@autorest/go@4.0.0-preview.52" title: "OpenAI" slice-elements-byval: true -remove-non-reference-schema: true +# can't use this since it removes an innererror type that we want () +# remove-non-reference-schema: true ``` ## Transformations @@ -81,7 +82,7 @@ directive: where: $.components.schemas["ImageOperation"].properties.status transform: $["$ref"] = $.anyOf[0]["$ref"];delete $.anyOf; - from: openapi-document - where: $.components.schemas["ImageGenerationOptions"].properties + where: $.components.schemas.ImageGenerationOptions.properties transform: | $.size["$ref"] = "#/components/schemas/ImageSize"; delete $.allOf; $.response_format["$ref"] = "#/components/schemas/ImageGenerationResponseFormat"; delete $.allOf; @@ -93,11 +94,12 @@ directive: - from: openapi-document where: $.components.schemas["ImageOperationStatus"].properties.status transform: $["$ref"] = "#/components/schemas/State"; delete $.allOf; + - from: openapi-document + where: $.components.schemas["ContentFilterResult"].properties.severity + transform: $["$ref"] = "#/components/schemas/ContentFilterSeverity"; delete $.allOf; - from: openapi-document where: $.components.schemas["ChatChoice"].properties.finish_reason - transform: > - delete $.oneOf; - $["$ref"] = "#/components/schemas/CompletionsFinishReason"; + transform: $["$ref"] = "#/components/schemas/CompletionsFinishReason"; delete $.oneOf; # Fix "AutoGenerated" models - from: openapi-document where: $.components.schemas["ChatCompletions"].properties.usage @@ -163,7 +165,7 @@ directive: - client.go - models.go - options.go - - response_types.go + - response_types.go where: $ transform: return $.replace(/Client(\w+)((?:Options|Response))/g, "$1$2"); @@ -172,10 +174,19 @@ directive: where: $ transform: return $.replace(/runtime\.JoinPaths\(client.endpoint, urlPath\)/g, "client.formatURL(urlPath)"); + # Some ImageGenerations hackery to represent the ImageLocation/ImagePayload polymorphism. + # - Remove the auto-generated ImageGenerationsDataItem. + # - Replace the ImageGenerations.Data type with []ImageGenerationDataItem + # - from: models.go + # where: $ + # transform: | + # return $.replace(/type ImageGenerationsDataItem struct {[^}]+}/, "// ImageGenerationsDataItem represents an image URL or payload\ntype ImageGenerationsDataItem struct{\nImageLocation\nImagePayload\n}") + # $.replace(/(type ImageGenerations struct.+?)Data any/g, "$1Data []ImageGenerationsDataItem") + - from: models.go where: $ transform: | - return $.replace(/type ImageGenerationsDataItem struct {[^}]+}/, "// ImageGenerationsDataItem represents an image URL or payload\ntype ImageGenerationsDataItem struct{\nImageLocation\nImagePayload\n}"); + return $.replace(/(type ImageGenerations struct.+?)Data any/sg, "$1Data []ImageGenerationsDataItem") # delete the auto-generated ImageGenerationsDataItem, we handle that custom - from: models.go @@ -218,6 +229,17 @@ directive: .replace(/BeginAzureBatchImageGenerationInternal/g, "beginAzureBatchImageGeneration") .replace(/BatchImageGenerationOperationResponse/g, "batchImageGenerationOperationResponse"); + # BUG: ChatCompletionsOptionsFunctionCall is another one of those "here's mutually exclusive values" options... + - from: + - models.go + - models_serde.go + where: $ + transform: | + return $ + .replace(/populateAny\(objectMap, "function_call", c.FunctionCall\)/, 'populate(objectMap, "function_call", c.FunctionCall)') + .replace(/\/\/ ChatCompletionsOptionsFunctionCall.+?\n}/, "") + .replace(/FunctionCall any/, "FunctionCall *ChatCompletionsOptionsFunctionCall"); + # fix some casing - from: - client.go @@ -228,8 +250,30 @@ directive: where: $ transform: return $.replace(/Logprobs/g, "LogProbs") - # remove PossibleazureOpenAIOperationStateValues, since we don't expose the poller + # delete ContentFilterResult in favor of our custom representation. + - from: + - models.go + - models_serde.go + where: $ + transform: | + return $.replace(/\/\/ ContentFilterResult.+?\n}/s, "") + .replace(/\/\/ MarshalJSON implements the json.Marshaller interface for type ContentFilterResult.+?\n}/s, "") + .replace(/\/\/ UnmarshalJSON implements the json.Unmarshaller interface for type ContentFilterResult.+?\n}/s, ""); + - from: constants.go where: $ transform: return $.replace(/\/\/ PossibleazureOpenAIOperationStateValues returns.+?\n}/s, ""); + + # fix incorrect property name for content filtering + # TODO: I imagine we should able to fix this in the tsp? + - from: models_serde.go + where: $ + transform: | + return $ + .replace(/ case "selfHarm":/g, ' case "self_harm":') + .replace(/populate\(objectMap, "selfHarm", c.SelfHarm\)/g, 'populate(objectMap, "self_harm", c.SelfHarm)'); + + - from: client.go + where: $ + transform: return $.replace(/runtime\.NewResponseError/sg, "client.newError"); ``` diff --git a/sdk/cognitiveservices/azopenai/client.go b/sdk/cognitiveservices/azopenai/client.go index d1cfafc12d56..3fccca28a9f6 100644 --- a/sdk/cognitiveservices/azopenai/client.go +++ b/sdk/cognitiveservices/azopenai/client.go @@ -27,7 +27,7 @@ type Client struct { // beginAzureBatchImageGeneration - Starts the generation of a batch of images from a text caption // If the operation fails it returns an *azcore.ResponseError type. // -// Generated from API version 2023-06-01-preview +// Generated from API version 2023-07-01-preview // - options - beginAzureBatchImageGenerationOptions contains the optional parameters for the Client.beginAzureBatchImageGeneration // method. func (client *Client) beginAzureBatchImageGeneration(ctx context.Context, body ImageGenerationOptions, options *beginAzureBatchImageGenerationOptions) (*runtime.Poller[azureBatchImageGenerationInternalResponse], error) { @@ -46,7 +46,7 @@ func (client *Client) beginAzureBatchImageGeneration(ctx context.Context, body I // AzureBatchImageGenerationInternal - Starts the generation of a batch of images from a text caption // If the operation fails it returns an *azcore.ResponseError type. // -// Generated from API version 2023-06-01-preview +// Generated from API version 2023-07-01-preview func (client *Client) azureBatchImageGenerationInternal(ctx context.Context, body ImageGenerationOptions, options *beginAzureBatchImageGenerationOptions) (*http.Response, error) { var err error req, err := client.azureBatchImageGenerationInternalCreateRequest(ctx, body, options) @@ -58,7 +58,7 @@ func (client *Client) azureBatchImageGenerationInternal(ctx context.Context, bod return nil, err } if !runtime.HasStatusCode(httpResp, http.StatusAccepted) { - err = runtime.NewResponseError(httpResp) + err = client.newError(httpResp) return nil, err } return httpResp, nil @@ -72,7 +72,7 @@ func (client *Client) azureBatchImageGenerationInternalCreateRequest(ctx context return nil, err } reqQP := req.Raw().URL.Query() - reqQP.Set("api-version", "2023-06-01-preview") + reqQP.Set("api-version", "2023-07-01-preview") req.Raw().URL.RawQuery = reqQP.Encode() req.Raw().Header["Accept"] = []string{"application/json"} if err := runtime.MarshalAsJSON(req, body); err != nil { @@ -85,7 +85,7 @@ func (client *Client) azureBatchImageGenerationInternalCreateRequest(ctx context // and generate text that continues from or "completes" provided prompt data. // If the operation fails it returns an *azcore.ResponseError type. // -// Generated from API version 2023-06-01-preview +// Generated from API version 2023-07-01-preview // - options - GetChatCompletionsOptions contains the optional parameters for the Client.GetChatCompletions method. func (client *Client) GetChatCompletions(ctx context.Context, body ChatCompletionsOptions, options *GetChatCompletionsOptions) (GetChatCompletionsResponse, error) { var err error @@ -98,7 +98,7 @@ func (client *Client) GetChatCompletions(ctx context.Context, body ChatCompletio return GetChatCompletionsResponse{}, err } if !runtime.HasStatusCode(httpResp, http.StatusOK) { - err = runtime.NewResponseError(httpResp) + err = client.newError(httpResp) return GetChatCompletionsResponse{}, err } resp, err := client.getChatCompletionsHandleResponse(httpResp) @@ -113,7 +113,7 @@ func (client *Client) getChatCompletionsCreateRequest(ctx context.Context, body return nil, err } reqQP := req.Raw().URL.Query() - reqQP.Set("api-version", "2023-06-01-preview") + reqQP.Set("api-version", "2023-07-01-preview") req.Raw().URL.RawQuery = reqQP.Encode() req.Raw().Header["Accept"] = []string{"application/json"} if err := runtime.MarshalAsJSON(req, body); err != nil { @@ -135,7 +135,7 @@ func (client *Client) getChatCompletionsHandleResponse(resp *http.Response) (Get // text that continues from or "completes" provided prompt data. // If the operation fails it returns an *azcore.ResponseError type. // -// Generated from API version 2023-06-01-preview +// Generated from API version 2023-07-01-preview // - options - GetCompletionsOptions contains the optional parameters for the Client.GetCompletions method. func (client *Client) GetCompletions(ctx context.Context, body CompletionsOptions, options *GetCompletionsOptions) (GetCompletionsResponse, error) { var err error @@ -148,7 +148,7 @@ func (client *Client) GetCompletions(ctx context.Context, body CompletionsOption return GetCompletionsResponse{}, err } if !runtime.HasStatusCode(httpResp, http.StatusOK) { - err = runtime.NewResponseError(httpResp) + err = client.newError(httpResp) return GetCompletionsResponse{}, err } resp, err := client.getCompletionsHandleResponse(httpResp) @@ -163,7 +163,7 @@ func (client *Client) getCompletionsCreateRequest(ctx context.Context, body Comp return nil, err } reqQP := req.Raw().URL.Query() - reqQP.Set("api-version", "2023-06-01-preview") + reqQP.Set("api-version", "2023-07-01-preview") req.Raw().URL.RawQuery = reqQP.Encode() req.Raw().Header["Accept"] = []string{"application/json"} if err := runtime.MarshalAsJSON(req, body); err != nil { @@ -184,7 +184,7 @@ func (client *Client) getCompletionsHandleResponse(resp *http.Response) (GetComp // GetEmbeddings - Return the embeddings for a given prompt. // If the operation fails it returns an *azcore.ResponseError type. // -// Generated from API version 2023-06-01-preview +// Generated from API version 2023-07-01-preview // - options - GetEmbeddingsOptions contains the optional parameters for the Client.GetEmbeddings method. func (client *Client) GetEmbeddings(ctx context.Context, body EmbeddingsOptions, options *GetEmbeddingsOptions) (GetEmbeddingsResponse, error) { var err error @@ -197,7 +197,7 @@ func (client *Client) GetEmbeddings(ctx context.Context, body EmbeddingsOptions, return GetEmbeddingsResponse{}, err } if !runtime.HasStatusCode(httpResp, http.StatusOK) { - err = runtime.NewResponseError(httpResp) + err = client.newError(httpResp) return GetEmbeddingsResponse{}, err } resp, err := client.getEmbeddingsHandleResponse(httpResp) @@ -212,7 +212,7 @@ func (client *Client) getEmbeddingsCreateRequest(ctx context.Context, body Embed return nil, err } reqQP := req.Raw().URL.Query() - reqQP.Set("api-version", "2023-06-01-preview") + reqQP.Set("api-version", "2023-07-01-preview") req.Raw().URL.RawQuery = reqQP.Encode() req.Raw().Header["Accept"] = []string{"application/json"} if err := runtime.MarshalAsJSON(req, body); err != nil { diff --git a/sdk/cognitiveservices/azopenai/client_chat_completions_test.go b/sdk/cognitiveservices/azopenai/client_chat_completions_test.go index a0d60c6fa557..98d6c8fd971a 100644 --- a/sdk/cognitiveservices/azopenai/client_chat_completions_test.go +++ b/sdk/cognitiveservices/azopenai/client_chat_completions_test.go @@ -44,30 +44,33 @@ func TestClient_GetChatCompletions(t *testing.T) { chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, chatCompletionsModelDeployment, newClientOptionsForTest(t)) require.NoError(t, err) - testGetChatCompletions(t, chatClient) + testGetChatCompletions(t, chatClient, true) } func TestClient_GetChatCompletionsStream(t *testing.T) { - cred, err := azopenai.NewKeyCredential(apiKey) - require.NoError(t, err) - - chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, chatCompletionsModelDeployment, newClientOptionsForTest(t)) - require.NoError(t, err) - - testGetChatCompletionsStream(t, chatClient, true) + chatClient := newAzureOpenAIClientForTest(t, canaryChatCompletionsModelDeployment, true) + testGetChatCompletionsStream(t, chatClient) } func TestClient_OpenAI_GetChatCompletions(t *testing.T) { + if testing.Short() { + t.Skip("Skipping OpenAI tests when attempting to do quick tests") + } + chatClient := newOpenAIClientForTest(t) - testGetChatCompletions(t, chatClient) + testGetChatCompletions(t, chatClient, false) } func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) { + if testing.Short() { + t.Skip("Skipping OpenAI tests when attempting to do quick tests") + } + chatClient := newOpenAIClientForTest(t) - testGetChatCompletionsStream(t, chatClient, false) + testGetChatCompletionsStream(t, chatClient) } -func testGetChatCompletions(t *testing.T, client *azopenai.Client) { +func testGetChatCompletions(t *testing.T, client *azopenai.Client, isAzure bool) { expected := azopenai.ChatCompletions{ Choices: []azopenai.ChatChoice{ { @@ -91,6 +94,15 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client) { resp, err := client.GetChatCompletions(context.Background(), chatCompletionsRequest, nil) require.NoError(t, err) + if isAzure { + // Azure also provides content-filtering. This particular prompt and responses + // will be considered safe. + expected.PromptAnnotations = []azopenai.PromptFilterResult{ + {PromptIndex: to.Ptr[int32](0), ContentFilterResults: (*azopenai.PromptFilterResultContentFilterResults)(safeContentFilter)}, + } + expected.Choices[0].ContentFilterResults = safeContentFilter + } + require.NotEmpty(t, resp.ID) require.NotEmpty(t, resp.Created) @@ -100,17 +112,10 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client) { require.Equal(t, expected, resp.ChatCompletions) } -func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, isAzure bool) { +func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client) { streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil) require.NoError(t, err) - if isAzure { - // there's a bug right now where the first event comes back empty - // Issue: https://github.com/Azure/azure-sdk-for-go/issues/21086 - _, err := streamResp.ChatCompletionsStream.Read() - require.NoError(t, err) - } - // the data comes back differently for streaming // 1. the text comes back in the ChatCompletion.Delta field // 2. the role is only sent on the first streamed ChatCompletion @@ -125,6 +130,18 @@ func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, isAzure } require.NoError(t, err) + + if completion.PromptAnnotations != nil { + require.Equal(t, []azopenai.PromptFilterResult{ + {PromptIndex: to.Ptr[int32](0), ContentFilterResults: (*azopenai.PromptFilterResultContentFilterResults)(safeContentFilter)}, + }, completion.PromptAnnotations) + } + + if len(completion.Choices) == 0 { + // you can get empty entries that contain just metadata (ie, prompt annotations) + continue + } + require.Equal(t, 1, len(completion.Choices)) choices = append(choices, completion.Choices[0]) } @@ -140,7 +157,6 @@ func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, isAzure } require.Equal(t, expectedContent, message, "Ultimately, the same result as GetChatCompletions(), just sent across the .Delta field instead") - require.Equal(t, azopenai.ChatRoleAssistant, expectedRole) } @@ -167,7 +183,7 @@ func TestClient_GetChatCompletions_DefaultAzureCredential(t *testing.T) { }) require.NoError(t, err) - testGetChatCompletions(t, chatClient) + testGetChatCompletions(t, chatClient, true) } func TestClient_GetChatCompletions_InvalidModel(t *testing.T) { diff --git a/sdk/cognitiveservices/azopenai/client_completions_test.go b/sdk/cognitiveservices/azopenai/client_completions_test.go index 4a78d10e620a..766791501c4f 100644 --- a/sdk/cognitiveservices/azopenai/client_completions_test.go +++ b/sdk/cognitiveservices/azopenai/client_completions_test.go @@ -8,81 +8,61 @@ package azopenai_test import ( "context" - "log" "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/require" ) -func TestClient_GetCompletions(t *testing.T) { - type args struct { - ctx context.Context - deploymentID string - body azopenai.CompletionsOptions - options *azopenai.GetCompletionsOptions - } +func TestClient_GetCompletions_AzureOpenAI(t *testing.T) { cred, err := azopenai.NewKeyCredential(apiKey) require.NoError(t, err) client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, completionsModelDeployment, newClientOptionsForTest(t)) - if err != nil { - log.Fatalf("%v", err) + require.NoError(t, err) + + testGetCompletions(t, client) +} + +func TestClient_GetCompletions_OpenAI(t *testing.T) { + if testing.Short() { + t.Skip("Skipping OpenAI tests when attempting to do quick tests") } - tests := []struct { - name string - client *azopenai.Client - args args - want azopenai.GetCompletionsResponse - wantErr bool - }{ - { - name: "chatbot", - client: client, - args: args{ - ctx: context.TODO(), - deploymentID: completionsModelDeployment, - body: azopenai.CompletionsOptions{ - Prompt: []string{"What is Azure OpenAI?"}, - MaxTokens: to.Ptr(int32(2048 - 127)), - Temperature: to.Ptr(float32(0.0)), + + client := newOpenAIClientForTest(t) + testGetCompletions(t, client) +} + +func testGetCompletions(t *testing.T, client *azopenai.Client) { + resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{ + Prompt: []string{"What is Azure OpenAI?"}, + MaxTokens: to.Ptr(int32(2048 - 127)), + Temperature: to.Ptr(float32(0.0)), + Model: &openAICompletionsModel, + }, nil) + require.NoError(t, err) + + want := azopenai.GetCompletionsResponse{ + Completions: azopenai.Completions{ + Choices: []azopenai.Choice{ + { + Text: to.Ptr("\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models."), + Index: to.Ptr(int32(0)), + FinishReason: to.Ptr(azopenai.CompletionsFinishReason("stop")), + LogProbs: nil, }, - options: nil, }, - want: azopenai.GetCompletionsResponse{ - Completions: azopenai.Completions{ - Choices: []azopenai.Choice{ - { - Text: to.Ptr("\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models."), - Index: to.Ptr(int32(0)), - FinishReason: to.Ptr(azopenai.CompletionsFinishReason("stop")), - LogProbs: nil, - }, - }, - Usage: &azopenai.CompletionsUsage{ - CompletionTokens: to.Ptr(int32(85)), - PromptTokens: to.Ptr(int32(6)), - TotalTokens: to.Ptr(int32(91)), - }, - }, + Usage: &azopenai.CompletionsUsage{ + CompletionTokens: to.Ptr(int32(85)), + PromptTokens: to.Ptr(int32(6)), + TotalTokens: to.Ptr(int32(91)), }, - wantErr: false, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.client.GetCompletions(tt.args.ctx, tt.args.body, tt.args.options) - if (err != nil) != tt.wantErr { - t.Errorf("Client.GetCompletions() error = %v, wantErr %v", err, tt.wantErr) - return - } - opts := cmpopts.IgnoreFields(azopenai.Completions{}, "Created", "ID") - if diff := cmp.Diff(tt.want.Completions, got.Completions, opts); diff != "" { - t.Errorf("Client.GetCompletions(): -want, +got:\n%s", diff) - } - }) - } + + want.ID = resp.Completions.ID + want.Created = resp.Completions.Created + + require.Equal(t, want, resp) } diff --git a/sdk/cognitiveservices/azopenai/client_embeddings_test.go b/sdk/cognitiveservices/azopenai/client_embeddings_test.go index 2fb598468d2a..8d61d0c0067d 100644 --- a/sdk/cognitiveservices/azopenai/client_embeddings_test.go +++ b/sdk/cognitiveservices/azopenai/client_embeddings_test.go @@ -27,6 +27,10 @@ func TestClient_GetEmbeddings_InvalidModel(t *testing.T) { } func TestClient_OpenAI_GetEmbeddings(t *testing.T) { + if testing.Short() { + t.Skip("Skipping OpenAI tests when attempting to do quick tests") + } + client := newOpenAIClientForTest(t) modelID := "text-similarity-curie-001" testGetEmbeddings(t, client, modelID) diff --git a/sdk/cognitiveservices/azopenai/client_functions_test.go b/sdk/cognitiveservices/azopenai/client_functions_test.go new file mode 100644 index 000000000000..56deb8ea1abe --- /dev/null +++ b/sdk/cognitiveservices/azopenai/client_functions_test.go @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai_test + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" + "github.com/stretchr/testify/require" +) + +type Params struct { + Type string `json:"type"` + Properties map[string]ParamProperty `json:"properties"` + Required []string `json:"required,omitempty"` +} + +type ParamProperty struct { + Type string `json:"type"` + Description string `json:"description,omitempty"` + Enum []string `json:"enum,omitempty"` +} + +func getClientForFunctionsTest(t *testing.T, azure bool) *azopenai.Client { + if azure { + cred, err := azopenai.NewKeyCredential(apiKey) + require.NoError(t, err) + + chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, chatCompletionsModelDeployment, newClientOptionsForTest(t)) + require.NoError(t, err) + + return chatClient + } else { + cred, err := azopenai.NewKeyCredential(openAIKey) + require.NoError(t, err) + + chatClient, err := azopenai.NewClientForOpenAI("https://api.openai.com/v1", cred, newClientOptionsForTest(t)) + require.NoError(t, err) + + return chatClient + } +} + +func TestFunctions(t *testing.T) { + // https://platform.openai.com/docs/guides/gpt/function-calling#:~:text=For%20example%2C%20you%20can%3A%201%20Create%20chatbots%20that,...%203%20Extract%20structured%20data%20from%20text%20 + chatClient := getClientForFunctionsTest(t, false) + + resp, err := chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{ + Model: to.Ptr("gpt-3.5-turbo-0613"), + Messages: []azopenai.ChatMessage{ + { + Role: to.Ptr(azopenai.ChatRoleUser), + Content: to.Ptr("What's the weather like in Boston, MA, in celsius?"), + }, + }, + FunctionCall: &azopenai.ChatCompletionsOptionsFunctionCall{ + Value: to.Ptr("auto"), + }, + Functions: []azopenai.FunctionDefinition{ + { + Name: to.Ptr("get_current_weather"), + Description: to.Ptr("Get the current weather in a given location"), + Parameters: Params{ + Required: []string{"location"}, + Type: "object", + Properties: map[string]ParamProperty{ + "location": { + Type: "string", + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: "string", + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + }, + }, + }, + Temperature: to.Ptr[float32](0.0), + }, nil) + require.NoError(t, err) + + funcCall := resp.ChatCompletions.Choices[0].Message.FunctionCall + + require.Equal(t, "get_current_weather", *funcCall.Name) + + type location struct { + Location string `json:"location"` + Unit string `json:"unit"` + } + + var funcParams *location + err = json.Unmarshal([]byte(*funcCall.Arguments), &funcParams) + require.NoError(t, err) + + require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams) +} diff --git a/sdk/cognitiveservices/azopenai/client_rai_test.go b/sdk/cognitiveservices/azopenai/client_rai_test.go new file mode 100644 index 000000000000..69de5c96c4af --- /dev/null +++ b/sdk/cognitiveservices/azopenai/client_rai_test.go @@ -0,0 +1,93 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai_test + +import ( + "context" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" + "github.com/stretchr/testify/require" +) + +func TestClient_GetCompletions_AzureOpenAI_ContentFilter_Response(t *testing.T) { + // Scenario: Your API call asks for multiple responses (N>1) and at least 1 of the responses is filtered + // https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/cognitive-services/openai/concepts/content-filter.md#scenario-your-api-call-asks-for-multiple-responses-n1-and-at-least-1-of-the-responses-is-filtered + client := newAzureOpenAIClientForTest(t, completionsModelDeployment, false) + + resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{ + Prompt: []string{"How do I rob a bank?"}, + MaxTokens: to.Ptr(int32(2048 - 127)), + Temperature: to.Ptr(float32(0.0)), + Model: &openAICompletionsModel, + }, nil) + + require.Empty(t, resp) + assertContentFilterError(t, err, false) +} + +func TestClient_GetChatCompletions_AzureOpenAI_ContentFilterWithError(t *testing.T) { + client := newAzureOpenAIClientForTest(t, canaryChatCompletionsModelDeployment, true) + + resp, err := client.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{ + Messages: []azopenai.ChatMessage{ + {Role: to.Ptr(azopenai.ChatRoleSystem), Content: to.Ptr("You are a helpful assistant.")}, + {Role: to.Ptr(azopenai.ChatRoleUser), Content: to.Ptr("How do I rob a bank?")}, + }, + MaxTokens: to.Ptr(int32(2048 - 127)), + Temperature: to.Ptr(float32(0.0)), + Model: &openAIChatCompletionsModel, + }, nil) + require.Empty(t, resp) + assertContentFilterError(t, err, true) +} + +func TestClient_GetChatCompletions_AzureOpenAI_ContentFilter_WithResponse(t *testing.T) { + client := newAzureOpenAIClientForTest(t, canaryChatCompletionsModelDeployment, true) + + resp, err := client.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{ + Messages: []azopenai.ChatMessage{ + {Role: to.Ptr(azopenai.ChatRoleUser), Content: to.Ptr("How do I cook a bell pepper?")}, + }, + MaxTokens: to.Ptr(int32(2048 - 127)), + Temperature: to.Ptr(float32(0.0)), + Model: &openAIChatCompletionsModel, + }, nil) + + require.NoError(t, err) + + require.Equal(t, safeContentFilter, resp.ChatCompletions.Choices[0].ContentFilterResults) +} + +// assertContentFilterError checks that the content filtering error came back from Azure OpenAI. +func assertContentFilterError(t *testing.T, err error, requireAnnotations bool) { + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, "content_filter", respErr.ErrorCode) + + require.Contains(t, respErr.Error(), "The response was filtered due to the prompt triggering") + + // Azure also returns error information when content filtering happens. + var contentFilterErr *azopenai.ContentFilterResponseError + require.ErrorAs(t, err, &contentFilterErr) + + if requireAnnotations { + require.Equal(t, &azopenai.ContentFilterResultsHate{Filtered: to.Ptr(false), Severity: to.Ptr(azopenai.ContentFilterSeveritySafe)}, contentFilterErr.ContentFilterResults.Hate) + require.Equal(t, &azopenai.ContentFilterResultsSelfHarm{Filtered: to.Ptr(false), Severity: to.Ptr(azopenai.ContentFilterSeveritySafe)}, contentFilterErr.ContentFilterResults.SelfHarm) + require.Equal(t, &azopenai.ContentFilterResultsSexual{Filtered: to.Ptr(false), Severity: to.Ptr(azopenai.ContentFilterSeveritySafe)}, contentFilterErr.ContentFilterResults.Sexual) + require.Equal(t, &azopenai.ContentFilterResultsViolence{Filtered: to.Ptr(true), Severity: to.Ptr(azopenai.ContentFilterSeverityMedium)}, contentFilterErr.ContentFilterResults.Violence) + } +} + +var safeContentFilter = &azopenai.ChatChoiceContentFilterResults{ + Hate: &azopenai.ContentFilterResultsHate{Filtered: to.Ptr(false), Severity: to.Ptr(azopenai.ContentFilterSeveritySafe)}, + SelfHarm: &azopenai.ContentFilterResultsSelfHarm{Filtered: to.Ptr(false), Severity: to.Ptr(azopenai.ContentFilterSeveritySafe)}, + Sexual: &azopenai.ContentFilterResultsSexual{Filtered: to.Ptr(false), Severity: to.Ptr(azopenai.ContentFilterSeveritySafe)}, + Violence: &azopenai.ContentFilterResultsViolence{Filtered: to.Ptr(false), Severity: to.Ptr(azopenai.ContentFilterSeveritySafe)}, +} diff --git a/sdk/cognitiveservices/azopenai/client_shared_test.go b/sdk/cognitiveservices/azopenai/client_shared_test.go index 543175a171a7..e2a51001e3e1 100644 --- a/sdk/cognitiveservices/azopenai/client_shared_test.go +++ b/sdk/cognitiveservices/azopenai/client_shared_test.go @@ -11,6 +11,7 @@ import ( "regexp" "strings" "testing" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -26,12 +27,32 @@ var ( completionsModelDeployment string // env: AOAI_COMPLETIONS_MODEL_DEPLOYMENT chatCompletionsModelDeployment string // env: AOAI_CHAT_COMPLETIONS_MODEL_DEPLOYMENT + canaryEndpoint string // env: AOAI_ENDPOINT_CANARY + canaryAPIKey string // env: AOAI_API_KEY_CANARY + canaryCompletionsModelDeployment string // env: AOAI_COMPLETIONS_MODEL_DEPLOYMENT_CANARY + canaryChatCompletionsModelDeployment string // env: AOAI_CHAT_COMPLETIONS_MODEL_DEPLOYMENT_CANARY + openAIKey string // env: OPENAI_API_KEY openAIEndpoint string // env: OPENAI_ENDPOINT openAICompletionsModel string // env: OPENAI_CHAT_COMPLETIONS_MODEL openAIChatCompletionsModel string // env: OPENAI_COMPLETIONS_MODEL ) +func getVars(suffix string) (endpoint, apiKey, completionsModelDeployment, chatCompletionsModelDeployment string) { + endpoint = os.Getenv("AOAI_ENDPOINT" + suffix) + + if endpoint != "" && !strings.HasSuffix(endpoint, "/") { + // (this just makes recording replacement easier) + endpoint += "/" + } + + apiKey = os.Getenv("AOAI_API_KEY" + suffix) + completionsModelDeployment = os.Getenv("AOAI_COMPLETIONS_MODEL_DEPLOYMENT" + suffix) + chatCompletionsModelDeployment = os.Getenv("AOAI_CHAT_COMPLETIONS_MODEL_DEPLOYMENT" + suffix) + + return +} + const fakeEndpoint = "https://recordedhost/" const fakeAPIKey = "redacted" @@ -42,6 +63,11 @@ func init() { openAIKey = fakeAPIKey openAIEndpoint = fakeEndpoint + canaryEndpoint = fakeEndpoint + canaryAPIKey = fakeAPIKey + canaryCompletionsModelDeployment = "" + canaryChatCompletionsModelDeployment = "gpt-4" + completionsModelDeployment = "text-davinci-003" openAICompletionsModel = "text-davinci-003" @@ -53,18 +79,8 @@ func init() { os.Exit(1) } - endpoint = os.Getenv("AOAI_ENDPOINT") - - if endpoint != "" && !strings.HasSuffix(endpoint, "/") { - // (this just makes recording replacement easier) - endpoint += "/" - } - - apiKey = os.Getenv("AOAI_API_KEY") - - // Ex: text-davinci-003 - completionsModelDeployment = os.Getenv("AOAI_COMPLETIONS_MODEL_DEPLOYMENT") - chatCompletionsModelDeployment = os.Getenv("AOAI_CHAT_COMPLETIONS_MODEL_DEPLOYMENT") + endpoint, apiKey, completionsModelDeployment, chatCompletionsModelDeployment = getVars("") + canaryEndpoint, canaryAPIKey, canaryCompletionsModelDeployment, canaryChatCompletionsModelDeployment = getVars("_CANARY") openAIKey = os.Getenv("OPENAI_API_KEY") openAIEndpoint = os.Getenv("OPENAI_ENDPOINT") @@ -96,6 +112,9 @@ func newRecordingTransporter(t *testing.T) policy.Transporter { err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(endpoint), nil) require.NoError(t, err) + err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(canaryEndpoint), nil) + require.NoError(t, err) + err = recording.AddURISanitizer("/openai/operations/images/00000000-AAAA-BBBB-CCCC-DDDDDDDDDDDD", "/openai/operations/images/[A-Za-z-0-9]+", nil) require.NoError(t, err) @@ -143,6 +162,53 @@ func newClientOptionsForTest(t *testing.T) *azopenai.ClientOptions { return co } +// newAzureOpenAIClientForTest can create a client pointing to the "canary" endpoint (basically - leading fixes or features) +// or the current deployed endpoint. +func newAzureOpenAIClientForTest(t *testing.T, modelDeploymentID string, useCanary bool) *azopenai.Client { + var apiKey = apiKey + var endpoint = endpoint + + if useCanary { + apiKey = canaryAPIKey + endpoint = canaryEndpoint + } + + cred, err := azopenai.NewKeyCredential(apiKey) + require.NoError(t, err) + + client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, modelDeploymentID, newClientOptionsForTest(t)) + require.NoError(t, err) + + return client +} + +func newOpenAIClientForTest(t *testing.T) *azopenai.Client { + if openAIKey == "" { + t.Skipf("OPENAI_API_KEY not defined, skipping OpenAI public endpoint test") + } + + cred, err := azopenai.NewKeyCredential(openAIKey) + require.NoError(t, err) + + // we get rate limited quite a bit. + options := newClientOptionsForTest(t) + + if options == nil { + options = &azopenai.ClientOptions{} + } + + options.Retry = policy.RetryOptions{ + MaxRetries: 60, + RetryDelay: time.Second, + MaxRetryDelay: time.Second, + } + + chatClient, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t)) + require.NoError(t, err) + + return chatClient +} + // newBogusAzureOpenAIClient creates a client that uses an invalid key, which will cause Azure OpenAI to return // a failure. func newBogusAzureOpenAIClient(t *testing.T, modelDeploymentID string) *azopenai.Client { diff --git a/sdk/cognitiveservices/azopenai/client_test.go b/sdk/cognitiveservices/azopenai/client_test.go index 2de81f43ede5..0186ec003003 100644 --- a/sdk/cognitiveservices/azopenai/client_test.go +++ b/sdk/cognitiveservices/azopenai/client_test.go @@ -14,10 +14,15 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" + "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" "github.com/stretchr/testify/require" ) func TestClient_OpenAI_InvalidModel(t *testing.T) { + if recording.GetRecordMode() == recording.PlaybackMode || testing.Short() { + t.Skip() + } + chatClient := newOpenAIClientForTest(t) _, err := chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{ @@ -35,17 +40,3 @@ func TestClient_OpenAI_InvalidModel(t *testing.T) { require.Equal(t, http.StatusNotFound, respErr.StatusCode) require.Contains(t, respErr.Error(), "The model `non-existent-model` does not exist") } - -func newOpenAIClientForTest(t *testing.T) *azopenai.Client { - if openAIKey == "" { - t.Skipf("OPENAI_API_KEY not defined, skipping OpenAI public endpoint test") - } - - cred, err := azopenai.NewKeyCredential(openAIKey) - require.NoError(t, err) - - chatClient, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t)) - require.NoError(t, err) - - return chatClient -} diff --git a/sdk/cognitiveservices/azopenai/constants.go b/sdk/cognitiveservices/azopenai/constants.go index a6fb5b35b0ec..1c3925d7c9ed 100644 --- a/sdk/cognitiveservices/azopenai/constants.go +++ b/sdk/cognitiveservices/azopenai/constants.go @@ -24,6 +24,7 @@ type ChatRole string const ( ChatRoleAssistant ChatRole = "assistant" + ChatRoleFunction ChatRole = "function" ChatRoleSystem ChatRole = "system" ChatRoleUser ChatRole = "user" ) @@ -32,6 +33,7 @@ const ( func PossibleChatRoleValues() []ChatRole { return []ChatRole{ ChatRoleAssistant, + ChatRoleFunction, ChatRoleSystem, ChatRoleUser, } @@ -42,6 +44,7 @@ type CompletionsFinishReason string const ( CompletionsFinishReasonContentFilter CompletionsFinishReason = "content_filter" + CompletionsFinishReasonFunctionCall CompletionsFinishReason = "function_call" CompletionsFinishReasonLength CompletionsFinishReason = "length" CompletionsFinishReasonStop CompletionsFinishReason = "stop" ) @@ -50,11 +53,49 @@ const ( func PossibleCompletionsFinishReasonValues() []CompletionsFinishReason { return []CompletionsFinishReason{ CompletionsFinishReasonContentFilter, + CompletionsFinishReasonFunctionCall, CompletionsFinishReasonLength, CompletionsFinishReasonStop, } } +// ContentFilterSeverity - Ratings for the intensity and risk level of harmful content. +type ContentFilterSeverity string + +const ( + ContentFilterSeverityHigh ContentFilterSeverity = "high" + ContentFilterSeverityLow ContentFilterSeverity = "low" + ContentFilterSeverityMedium ContentFilterSeverity = "medium" + ContentFilterSeveritySafe ContentFilterSeverity = "safe" +) + +// PossibleContentFilterSeverityValues returns the possible values for the ContentFilterSeverity const type. +func PossibleContentFilterSeverityValues() []ContentFilterSeverity { + return []ContentFilterSeverity{ + ContentFilterSeverityHigh, + ContentFilterSeverityLow, + ContentFilterSeverityMedium, + ContentFilterSeveritySafe, + } +} + +// FunctionCallPreset - The collection of predefined behaviors for handling request-provided function information in a chat +// completions operation. +type FunctionCallPreset string + +const ( + FunctionCallPresetAuto FunctionCallPreset = "auto" + FunctionCallPresetNone FunctionCallPreset = "none" +) + +// PossibleFunctionCallPresetValues returns the possible values for the FunctionCallPreset const type. +func PossibleFunctionCallPresetValues() []FunctionCallPreset { + return []FunctionCallPreset{ + FunctionCallPresetAuto, + FunctionCallPresetNone, + } +} + // ImageGenerationResponseFormat - The format in which the generated images are returned. type ImageGenerationResponseFormat string diff --git a/sdk/cognitiveservices/azopenai/custom_client.go b/sdk/cognitiveservices/azopenai/custom_client.go index 72591b215e14..6fba8cab1756 100644 --- a/sdk/cognitiveservices/azopenai/custom_client.go +++ b/sdk/cognitiveservices/azopenai/custom_client.go @@ -231,6 +231,10 @@ func (client *Client) formatURL(path string) string { } } +func (client *Client) newError(resp *http.Response) error { + return newContentFilterResponseError(resp) +} + type clientData struct { endpoint string baseEndpoint string diff --git a/sdk/cognitiveservices/azopenai/custom_client_functions.go b/sdk/cognitiveservices/azopenai/custom_client_functions.go new file mode 100644 index 000000000000..87b6b92ead9b --- /dev/null +++ b/sdk/cognitiveservices/azopenai/custom_client_functions.go @@ -0,0 +1,34 @@ +package azopenai + +import ( + "encoding/json" + "errors" +) + +// ChatCompletionsOptionsFunctionCall - Controls how the model responds to function calls. "none" means the model does not +// call a function, and responds to the end-user. "auto" means the model can pick between an end-user or calling a +// function. Specifying a particular function via {"name": "my_function"} forces the model to call that function. "none" is +// the default when no functions are present. "auto" is the default if functions +// are present. +type ChatCompletionsOptionsFunctionCall struct { + // IsFunction is true if Value refers to a function name. + IsFunction bool + + // Value is one of: + // - "auto", meaning the model can pick between an end-user or calling a function + // - "none", meaning the model does not call a function, + // - name of a function, in which case [IsFunction] should be set to true. + Value *string +} + +func (c ChatCompletionsOptionsFunctionCall) MarshalJSON() ([]byte, error) { + if c.IsFunction { + if c.Value == nil { + return nil, errors.New("the Value should be the function name to call, not nil") + } + + return json.Marshal(map[string]string{"name": *c.Value}) + } + + return json.Marshal(c.Value) +} diff --git a/sdk/cognitiveservices/azopenai/custom_client_image_test.go b/sdk/cognitiveservices/azopenai/custom_client_image_test.go index 661364cc8e28..b7bd5a39f8c8 100644 --- a/sdk/cognitiveservices/azopenai/custom_client_image_test.go +++ b/sdk/cognitiveservices/azopenai/custom_client_image_test.go @@ -36,6 +36,10 @@ func TestImageGeneration_AzureOpenAI(t *testing.T) { } func TestImageGeneration_OpenAI(t *testing.T) { + if testing.Short() { + t.Skip("Skipping OpenAI tests when attempting to do quick tests") + } + client := newOpenAIClientForTest(t) testImageGeneration(t, client, azopenai.ImageGenerationResponseFormatURL) } @@ -50,8 +54,8 @@ func TestImageGeneration_AzureOpenAI_WithError(t *testing.T) { } func TestImageGeneration_OpenAI_WithError(t *testing.T) { - if recording.GetRecordMode() == recording.PlaybackMode { - t.Skip() + if testing.Short() { + t.Skip("Skipping OpenAI tests when attempting to do quick tests") } client := newBogusOpenAIClient(t) @@ -59,6 +63,10 @@ func TestImageGeneration_OpenAI_WithError(t *testing.T) { } func TestImageGeneration_OpenAI_Base64(t *testing.T) { + if testing.Short() { + t.Skip("Skipping OpenAI tests when attempting to do quick tests") + } + client := newOpenAIClientForTest(t) testImageGeneration(t, client, azopenai.ImageGenerationResponseFormatB64JSON) } diff --git a/sdk/cognitiveservices/azopenai/custom_client_test.go b/sdk/cognitiveservices/azopenai/custom_client_test.go index 2319f93378a1..a45aeffd15af 100644 --- a/sdk/cognitiveservices/azopenai/custom_client_test.go +++ b/sdk/cognitiveservices/azopenai/custom_client_test.go @@ -85,15 +85,19 @@ func TestGetCompletionsStream_AzureOpenAI(t *testing.T) { client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, completionsModelDeployment, newClientOptionsForTest(t)) require.NoError(t, err) - testGetCompletionsStream(t, client, true) + testGetCompletionsStream(t, client) } func TestGetCompletionsStream_OpenAI(t *testing.T) { + if testing.Short() { + t.Skip("Skipping OpenAI tests when attempting to do quick tests") + } + client := newOpenAIClientForTest(t) - testGetCompletionsStream(t, client, false) + testGetCompletionsStream(t, client) } -func testGetCompletionsStream(t *testing.T, client *azopenai.Client, isAzure bool) { +func testGetCompletionsStream(t *testing.T, client *azopenai.Client) { body := azopenai.CompletionsOptions{ Prompt: []string{"What is Azure OpenAI?"}, MaxTokens: to.Ptr(int32(2048)), @@ -102,6 +106,8 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, isAzure boo } response, err := client.GetCompletionsStream(context.TODO(), body, nil) + require.NoError(t, err) + if err != nil { t.Errorf("Client.GetCompletionsStream() error = %v", err) return @@ -112,25 +118,29 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, isAzure boo var sb strings.Builder var eventCount int - if isAzure { - // there's a bug right now where the first event comes back empty - // Issue: https://github.com/Azure/azure-sdk-for-go/issues/21086 - _, err := reader.Read() - require.NoError(t, err) - } - for { - event, err := reader.Read() + completion, err := reader.Read() + if err == io.EOF { break } + + if completion.PromptAnnotations != nil { + require.Equal(t, []azopenai.PromptFilterResult{ + {PromptIndex: to.Ptr[int32](0), ContentFilterResults: (*azopenai.PromptFilterResultContentFilterResults)(safeContentFilter)}, + }, completion.PromptAnnotations) + } + eventCount++ + if err != nil { t.Errorf("reader.Read() error = %v", err) return } - sb.WriteString(*event.Choices[0].Text) + if len(completion.Choices) > 0 { + sb.WriteString(*completion.Choices[0].Text) + } } got := sb.String() const want = "\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models." diff --git a/sdk/cognitiveservices/azopenai/custom_models.go b/sdk/cognitiveservices/azopenai/custom_models.go index c77252ebdf39..8dc45bf3c0e4 100644 --- a/sdk/cognitiveservices/azopenai/custom_models.go +++ b/sdk/cognitiveservices/azopenai/custom_models.go @@ -6,6 +6,14 @@ package azopenai +import ( + "encoding/json" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" +) + // Models for methods that return streaming response // GetCompletionsStreamOptions contains the optional parameters for the [Client.GetCompletionsStream] method. @@ -43,3 +51,47 @@ type ImageGenerationsDataItem struct { // to [ImageGenerationResponseFormatURL]. URL *string `json:"url"` } + +// ContentFilterResponseError is an error as a result of a request being filtered. +type ContentFilterResponseError struct { + azcore.ResponseError + + // ContentFilterResults contains Information about the content filtering category, if it has been detected. + ContentFilterResults *ContentFilterResults +} + +// Unwrap returns the inner error for this error. +func (e ContentFilterResponseError) Unwrap() error { + return &e.ResponseError +} + +func newContentFilterResponseError(resp *http.Response) error { + respErr := runtime.NewResponseError(resp).(*azcore.ResponseError) + + if respErr.ErrorCode != "content_filter" { + return respErr + } + + body, err := runtime.Payload(resp) + + if err != nil { + return err + } + + var envelope *struct { + Error struct { + InnerError struct { + FilterResult *ContentFilterResults `json:"content_filter_result"` + } `json:"innererror"` + } + } + + if err := json.Unmarshal(body, &envelope); err != nil { + return err + } + + return &ContentFilterResponseError{ + ResponseError: *respErr, + ContentFilterResults: envelope.Error.InnerError.FilterResult, + } +} diff --git a/sdk/cognitiveservices/azopenai/custom_models_test.go b/sdk/cognitiveservices/azopenai/custom_models_test.go new file mode 100644 index 000000000000..909a2c6dcd58 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/custom_models_test.go @@ -0,0 +1,66 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +import ( + "bytes" + "io" + "net/http" + "net/url" + "os" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/stretchr/testify/require" +) + +func TestParseResponseError(t *testing.T) { + bodyBytes, err := os.ReadFile("testdata/content_filter_response_error.json") + require.NoError(t, err) + + buff := bytes.NewBuffer(bodyBytes) + + fakeURL, err := url.Parse("https://openai-something.microsoft.com") + require.NoError(t, err) + + resp := &http.Response{ + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(buff), + Request: &http.Request{ + Method: "POST", + URL: fakeURL, + }, + } + + err = newContentFilterResponseError(resp) + + // this is the outer error, which is the standard Azure response error. + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, http.StatusBadRequest, respErr.StatusCode) + require.Equal(t, "content_filter", respErr.ErrorCode) + + // Azure also returns error information when content filtering happens. + var contentFilterErr *ContentFilterResponseError + require.ErrorAs(t, err, &contentFilterErr) + + // we're still a response error + require.Equal(t, http.StatusBadRequest, respErr.StatusCode) + require.Equal(t, "content_filter", respErr.ErrorCode) + + contentFilterResults := contentFilterErr.ContentFilterResults + + // thsi comment was considered violent, so it was filtered. + require.Equal(t, &ContentFilterResultsViolence{ + Filtered: to.Ptr(true), + Severity: to.Ptr(ContentFilterSeverityMedium)}, contentFilterResults.Violence) + + require.Equal(t, &ContentFilterResultsHate{Filtered: to.Ptr(false), Severity: to.Ptr(ContentFilterSeveritySafe)}, contentFilterResults.Hate) + require.Equal(t, &ContentFilterResultsSelfHarm{Filtered: to.Ptr(false), Severity: to.Ptr(ContentFilterSeveritySafe)}, contentFilterResults.SelfHarm) + require.Equal(t, &ContentFilterResultsSexual{Filtered: to.Ptr(false), Severity: to.Ptr(ContentFilterSeveritySafe)}, contentFilterResults.Sexual) +} diff --git a/sdk/cognitiveservices/azopenai/event_reader.go b/sdk/cognitiveservices/azopenai/event_reader.go index c98b74ddd6af..28aea580e849 100644 --- a/sdk/cognitiveservices/azopenai/event_reader.go +++ b/sdk/cognitiveservices/azopenai/event_reader.go @@ -46,7 +46,6 @@ func (er *EventReader[T]) Read() (T, error) { } err := json.Unmarshal([]byte(tokens[1]), &data) return data, err - default: // Any other event type is an unexpected return data, errors.New("Unexpected event type: " + tokens[0]) } diff --git a/sdk/cognitiveservices/azopenai/event_reader_test.go b/sdk/cognitiveservices/azopenai/event_reader_test.go index cecca41d0828..0644c6b43b91 100644 --- a/sdk/cognitiveservices/azopenai/event_reader_test.go +++ b/sdk/cognitiveservices/azopenai/event_reader_test.go @@ -25,7 +25,6 @@ func TestEventReader_InvalidType(t *testing.T) { firstEvent, err := eventReader.Read() require.Empty(t, firstEvent) require.EqualError(t, err, "Unexpected event type: invaliddata") - } type badReader struct{} diff --git a/sdk/cognitiveservices/azopenai/go.mod b/sdk/cognitiveservices/azopenai/go.mod index 710379f09aa8..c78eaaad2109 100644 --- a/sdk/cognitiveservices/azopenai/go.mod +++ b/sdk/cognitiveservices/azopenai/go.mod @@ -6,7 +6,6 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 - github.com/google/go-cmp v0.5.9 github.com/joho/godotenv v1.3.0 github.com/stretchr/testify v1.7.0 ) diff --git a/sdk/cognitiveservices/azopenai/go.sum b/sdk/cognitiveservices/azopenai/go.sum index d3d94d1ae54a..10901f2ec9f9 100644 --- a/sdk/cognitiveservices/azopenai/go.sum +++ b/sdk/cognitiveservices/azopenai/go.sum @@ -13,8 +13,6 @@ github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= diff --git a/sdk/cognitiveservices/azopenai/models.go b/sdk/cognitiveservices/azopenai/models.go index 9a6ee087b5b0..35afc8aad73b 100644 --- a/sdk/cognitiveservices/azopenai/models.go +++ b/sdk/cognitiveservices/azopenai/models.go @@ -111,6 +111,11 @@ type ChatChoice struct { // REQUIRED; The ordered index associated with this chat completions choice. Index *int32 + // Information about the content filtering category (hate, sexual, violence, selfharm), if it has been detected, as well as + // the severity level (verylow, low, medium, high-scale that determines the + // intensity and risk level of harmful content) and if it has been filtered or not. + ContentFilterResults *ChatChoiceContentFilterResults + // The delta message content for a streaming response. Delta *ChatChoiceDelta @@ -118,6 +123,30 @@ type ChatChoice struct { Message *ChatChoiceMessage } +// ChatChoiceContentFilterResults - Information about the content filtering category (hate, sexual, violence, selfharm), if +// it has been detected, as well as the severity level (verylow, low, medium, high-scale that determines the +// intensity and risk level of harmful content) and if it has been filtered or not. +type ChatChoiceContentFilterResults struct { + // REQUIRED; Describes language attacks or uses that include pejorative or discriminatory language with reference to a person + // or identity group on the basis of certain differentiating attributes of these groups + // including but not limited to race, ethnicity, nationality, gender identity and expression, sexual orientation, religion, + // immigration status, ability status, personal appearance, and body size. + Hate *ContentFilterResultsHate + + // REQUIRED; Describes language related to physical actions intended to purposely hurt, injure, or damage one’s body, or kill + // oneself. + SelfHarm *ContentFilterResultsSelfHarm + + // REQUIRED; Describes language related to anatomical organs and genitals, romantic relationships, acts portrayed in erotic + // or affectionate terms, physical sexual acts, including those portrayed as an assault or a + // forced sexual violent act against one’s will, prostitution, pornography, and abuse. + Sexual *ContentFilterResultsSexual + + // REQUIRED; Describes language related to physical actions intended to hurt, injure, damage, or kill someone or something; + // describes weapons, etc. + Violence *ContentFilterResultsViolence +} + // ChatChoiceDelta - The delta message content for a streaming response. type ChatChoiceDelta struct { // REQUIRED; The role associated with this message payload. @@ -125,6 +154,14 @@ type ChatChoiceDelta struct { // The text associated with this message payload. Content *string + + // The name and arguments of a function that should be called, as generated by the model. + FunctionCall *ChatMessageFunctionCall + + // The name of the author of this message. name is required if role is function, and it should be the name of the function + // whose response is in the content. May contain a-z, A-Z, 0-9, and underscores, + // with a maximum length of 64 characters. + Name *string } // ChatChoiceMessage - The chat message for a given chat completions prompt. @@ -134,6 +171,14 @@ type ChatChoiceMessage struct { // The text associated with this message payload. Content *string + + // The name and arguments of a function that should be called, as generated by the model. + FunctionCall *ChatMessageFunctionCall + + // The name of the author of this message. name is required if role is function, and it should be the name of the function + // whose response is in the content. May contain a-z, A-Z, 0-9, and underscores, + // with a maximum length of 64 characters. + Name *string } // ChatCompletions - Representation of the response data from a chat completions request. Completions support a wide variety @@ -153,6 +198,10 @@ type ChatCompletions struct { // REQUIRED; Usage information for tokens processed and generated as part of this completions operation. Usage *CompletionsUsage + + // Content filtering results for zero or more prompts in the request. In a streaming request, results for different prompts + // may arrive at different times or in different orders. + PromptAnnotations []PromptFilterResult } // ChatCompletionsOptions - The configuration information for a chat completions request. Completions support a wide variety @@ -168,6 +217,16 @@ type ChatCompletionsOptions struct { // increases and decrease the likelihood of the model repeating the same statements verbatim. FrequencyPenalty *float32 + // Controls how the model responds to function calls. "none" means the model does not call a function, and responds to the + // end-user. "auto" means the model can pick between an end-user or calling a + // function. Specifying a particular function via {"name": "my_function"} forces the model to call that function. "none" is + // the default when no functions are present. "auto" is the default if functions + // are present. + FunctionCall *ChatCompletionsOptionsFunctionCall + + // A list of functions the model may generate JSON inputs for. + Functions []FunctionDefinition + // A map between GPT token IDs and bias scores that influences the probability of specific tokens appearing in a completions // response. Token IDs are computed via external tokenizer tools, while bias // scores reside in the range of -100 to 100 with minimum and maximum values corresponding to a full ban or exclusive selection @@ -219,6 +278,25 @@ type ChatMessage struct { // The text associated with this message payload. Content *string + + // The name and arguments of a function that should be called, as generated by the model. + FunctionCall *ChatMessageFunctionCall + + // The name of the author of this message. name is required if role is function, and it should be the name of the function + // whose response is in the content. May contain a-z, A-Z, 0-9, and underscores, + // with a maximum length of 64 characters. + Name *string +} + +// ChatMessageFunctionCall - The name and arguments of a function that should be called, as generated by the model. +type ChatMessageFunctionCall struct { + // REQUIRED; The arguments to call the function with, as generated by the model in JSON format. Note that the model does not + // always generate valid JSON, and may hallucinate parameters not defined by your function + // schema. Validate the arguments in your code before calling your function. + Arguments *string + + // REQUIRED; The name of the function to call. + Name *string } // Choice - The representation of a single prompt completion as part of an overall completions request. Generally, n choices @@ -236,6 +314,35 @@ type Choice struct { // REQUIRED; The generated text for a given completions prompt. Text *string + + // Information about the content filtering category (hate, sexual, violence, selfharm), if it has been detected, as well as + // the severity level (verylow, low, medium, high-scale that determines the + // intensity and risk level of harmful content) and if it has been filtered or not. + ContentFilterResults *ChoiceContentFilterResults +} + +// ChoiceContentFilterResults - Information about the content filtering category (hate, sexual, violence, selfharm), if it +// has been detected, as well as the severity level (verylow, low, medium, high-scale that determines the +// intensity and risk level of harmful content) and if it has been filtered or not. +type ChoiceContentFilterResults struct { + // REQUIRED; Describes language attacks or uses that include pejorative or discriminatory language with reference to a person + // or identity group on the basis of certain differentiating attributes of these groups + // including but not limited to race, ethnicity, nationality, gender identity and expression, sexual orientation, religion, + // immigration status, ability status, personal appearance, and body size. + Hate *ContentFilterResultsHate + + // REQUIRED; Describes language related to physical actions intended to purposely hurt, injure, or damage one’s body, or kill + // oneself. + SelfHarm *ContentFilterResultsSelfHarm + + // REQUIRED; Describes language related to anatomical organs and genitals, romantic relationships, acts portrayed in erotic + // or affectionate terms, physical sexual acts, including those portrayed as an assault or a + // forced sexual violent act against one’s will, prostitution, pornography, and abuse. + Sexual *ContentFilterResultsSexual + + // REQUIRED; Describes language related to physical actions intended to hurt, injure, damage, or kill someone or something; + // describes weapons, etc. + Violence *ContentFilterResultsViolence } // ChoiceLogProbs - The log probabilities model for tokens associated with this completions choice. @@ -270,6 +377,10 @@ type Completions struct { // REQUIRED; Usage information for tokens processed and generated as part of this completions operation. Usage *CompletionsUsage + + // Content filtering results for zero or more prompts in the request. In a streaming request, results for different prompts + // may arrive at different times or in different orders. + PromptAnnotations []PromptFilterResult } // CompletionsLogProbabilityModel - Representation of a log probabilities model for a completions generation. @@ -368,9 +479,75 @@ type CompletionsUsage struct { TotalTokens *int32 } +// ContentFilterResults - Information about the content filtering category, if it has been detected. +type ContentFilterResults struct { + // REQUIRED; Describes language attacks or uses that include pejorative or discriminatory language with reference to a person + // or identity group on the basis of certain differentiating attributes of these groups + // including but not limited to race, ethnicity, nationality, gender identity and expression, sexual orientation, religion, + // immigration status, ability status, personal appearance, and body size. + Hate *ContentFilterResultsHate + + // REQUIRED; Describes language related to physical actions intended to purposely hurt, injure, or damage one’s body, or kill + // oneself. + SelfHarm *ContentFilterResultsSelfHarm + + // REQUIRED; Describes language related to anatomical organs and genitals, romantic relationships, acts portrayed in erotic + // or affectionate terms, physical sexual acts, including those portrayed as an assault or a + // forced sexual violent act against one’s will, prostitution, pornography, and abuse. + Sexual *ContentFilterResultsSexual + + // REQUIRED; Describes language related to physical actions intended to hurt, injure, damage, or kill someone or something; + // describes weapons, etc. + Violence *ContentFilterResultsViolence +} + +// ContentFilterResultsHate - Describes language attacks or uses that include pejorative or discriminatory language with reference +// to a person or identity group on the basis of certain differentiating attributes of these groups +// including but not limited to race, ethnicity, nationality, gender identity and expression, sexual orientation, religion, +// immigration status, ability status, personal appearance, and body size. +type ContentFilterResultsHate struct { + // REQUIRED; A value indicating whether or not the content has been filtered. + Filtered *bool + + // REQUIRED; Ratings for the intensity and risk level of filtered content. + Severity *ContentFilterSeverity +} + +// ContentFilterResultsSelfHarm - Describes language related to physical actions intended to purposely hurt, injure, or damage +// one’s body, or kill oneself. +type ContentFilterResultsSelfHarm struct { + // REQUIRED; A value indicating whether or not the content has been filtered. + Filtered *bool + + // REQUIRED; Ratings for the intensity and risk level of filtered content. + Severity *ContentFilterSeverity +} + +// ContentFilterResultsSexual - Describes language related to anatomical organs and genitals, romantic relationships, acts +// portrayed in erotic or affectionate terms, physical sexual acts, including those portrayed as an assault or a +// forced sexual violent act against one’s will, prostitution, pornography, and abuse. +type ContentFilterResultsSexual struct { + // REQUIRED; A value indicating whether or not the content has been filtered. + Filtered *bool + + // REQUIRED; Ratings for the intensity and risk level of filtered content. + Severity *ContentFilterSeverity +} + +// ContentFilterResultsViolence - Describes language related to physical actions intended to hurt, injure, damage, or kill +// someone or something; describes weapons, etc. +type ContentFilterResultsViolence struct { + // REQUIRED; A value indicating whether or not the content has been filtered. + Filtered *bool + + // REQUIRED; Ratings for the intensity and risk level of filtered content. + Severity *ContentFilterSeverity +} + // Deployment - A specific deployment type Deployment struct { - // READ-ONLY; deployment id of the deployed model + // READ-ONLY; Specifies either the model deployment name (when using Azure OpenAI) or model name (when using non-Azure OpenAI) + // to use for this request. DeploymentID *string } @@ -430,6 +607,38 @@ type EmbeddingsUsageAutoGenerated struct { TotalTokens *int32 } +// FunctionCall - The name and arguments of a function that should be called, as generated by the model. +type FunctionCall struct { + // REQUIRED; The arguments to call the function with, as generated by the model in JSON format. Note that the model does not + // always generate valid JSON, and may hallucinate parameters not defined by your function + // schema. Validate the arguments in your code before calling your function. + Arguments *string + + // REQUIRED; The name of the function to call. + Name *string +} + +// FunctionDefinition - The definition of a caller-specified function that chat completions may invoke in response to matching +// user input. +type FunctionDefinition struct { + // REQUIRED; The name of the function to be called. + Name *string + + // A description of what the function does. The model will use this description when selecting the function and interpreting + // its parameters. + Description *string + + // The parameters the functions accepts, described as a JSON Schema object. + Parameters any +} + +// FunctionName - A structure that specifies the exact name of a specific, request-provided function to use when processing +// a chat completions operation. +type FunctionName struct { + // REQUIRED; The name of the function to call. + Name *string +} + // ImageGenerationOptions - Represents the request data used to generate images. type ImageGenerationOptions struct { // REQUIRED; A description of the desired images. @@ -468,3 +677,34 @@ type ImagePayload struct { // REQUIRED; The complete data for an image represented as a base64-encoded string. B64JSON *string } + +// PromptFilterResult - Content filtering results for a single prompt in the request. +type PromptFilterResult struct { + // REQUIRED; The index of this prompt in the set of prompt results + PromptIndex *int32 + + // Content filtering results for this prompt + ContentFilterResults *PromptFilterResultContentFilterResults +} + +// PromptFilterResultContentFilterResults - Content filtering results for this prompt +type PromptFilterResultContentFilterResults struct { + // REQUIRED; Describes language attacks or uses that include pejorative or discriminatory language with reference to a person + // or identity group on the basis of certain differentiating attributes of these groups + // including but not limited to race, ethnicity, nationality, gender identity and expression, sexual orientation, religion, + // immigration status, ability status, personal appearance, and body size. + Hate *ContentFilterResultsHate + + // REQUIRED; Describes language related to physical actions intended to purposely hurt, injure, or damage one’s body, or kill + // oneself. + SelfHarm *ContentFilterResultsSelfHarm + + // REQUIRED; Describes language related to anatomical organs and genitals, romantic relationships, acts portrayed in erotic + // or affectionate terms, physical sexual acts, including those portrayed as an assault or a + // forced sexual violent act against one’s will, prostitution, pornography, and abuse. + Sexual *ContentFilterResultsSexual + + // REQUIRED; Describes language related to physical actions intended to hurt, injure, damage, or kill someone or something; + // describes weapons, etc. + Violence *ContentFilterResultsViolence +} diff --git a/sdk/cognitiveservices/azopenai/models_serde.go b/sdk/cognitiveservices/azopenai/models_serde.go index 89b65bf201fd..4b591a60fca8 100644 --- a/sdk/cognitiveservices/azopenai/models_serde.go +++ b/sdk/cognitiveservices/azopenai/models_serde.go @@ -272,6 +272,7 @@ func (b *batchImageGenerationOperationResponse) UnmarshalJSON(data []byte) error // MarshalJSON implements the json.Marshaller interface for type ChatChoice. func (c ChatChoice) MarshalJSON() ([]byte, error) { objectMap := make(map[string]any) + populate(objectMap, "content_filter_results", c.ContentFilterResults) populate(objectMap, "delta", c.Delta) populate(objectMap, "finish_reason", c.FinishReason) populate(objectMap, "index", c.Index) @@ -288,6 +289,9 @@ func (c *ChatChoice) UnmarshalJSON(data []byte) error { for key, val := range rawMsg { var err error switch key { + case "content_filter_results": + err = unpopulate(val, "ContentFilterResults", &c.ContentFilterResults) + delete(rawMsg, key) case "delta": err = unpopulate(val, "Delta", &c.Delta) delete(rawMsg, key) @@ -308,10 +312,51 @@ func (c *ChatChoice) UnmarshalJSON(data []byte) error { return nil } +// MarshalJSON implements the json.Marshaller interface for type ChatChoiceContentFilterResults. +func (c ChatChoiceContentFilterResults) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "hate", c.Hate) + populate(objectMap, "self_harm", c.SelfHarm) + populate(objectMap, "sexual", c.Sexual) + populate(objectMap, "violence", c.Violence) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChatChoiceContentFilterResults. +func (c *ChatChoiceContentFilterResults) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "hate": + err = unpopulate(val, "Hate", &c.Hate) + delete(rawMsg, key) + case "self_harm": + err = unpopulate(val, "SelfHarm", &c.SelfHarm) + delete(rawMsg, key) + case "sexual": + err = unpopulate(val, "Sexual", &c.Sexual) + delete(rawMsg, key) + case "violence": + err = unpopulate(val, "Violence", &c.Violence) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + // MarshalJSON implements the json.Marshaller interface for type ChatChoiceDelta. func (c ChatChoiceDelta) MarshalJSON() ([]byte, error) { objectMap := make(map[string]any) populate(objectMap, "content", c.Content) + populate(objectMap, "function_call", c.FunctionCall) + populate(objectMap, "name", c.Name) populate(objectMap, "role", c.Role) return json.Marshal(objectMap) } @@ -328,6 +373,12 @@ func (c *ChatChoiceDelta) UnmarshalJSON(data []byte) error { case "content": err = unpopulate(val, "Content", &c.Content) delete(rawMsg, key) + case "function_call": + err = unpopulate(val, "FunctionCall", &c.FunctionCall) + delete(rawMsg, key) + case "name": + err = unpopulate(val, "Name", &c.Name) + delete(rawMsg, key) case "role": err = unpopulate(val, "Role", &c.Role) delete(rawMsg, key) @@ -343,6 +394,8 @@ func (c *ChatChoiceDelta) UnmarshalJSON(data []byte) error { func (c ChatChoiceMessage) MarshalJSON() ([]byte, error) { objectMap := make(map[string]any) populate(objectMap, "content", c.Content) + populate(objectMap, "function_call", c.FunctionCall) + populate(objectMap, "name", c.Name) populate(objectMap, "role", c.Role) return json.Marshal(objectMap) } @@ -359,6 +412,12 @@ func (c *ChatChoiceMessage) UnmarshalJSON(data []byte) error { case "content": err = unpopulate(val, "Content", &c.Content) delete(rawMsg, key) + case "function_call": + err = unpopulate(val, "FunctionCall", &c.FunctionCall) + delete(rawMsg, key) + case "name": + err = unpopulate(val, "Name", &c.Name) + delete(rawMsg, key) case "role": err = unpopulate(val, "Role", &c.Role) delete(rawMsg, key) @@ -376,6 +435,7 @@ func (c ChatCompletions) MarshalJSON() ([]byte, error) { populate(objectMap, "choices", c.Choices) populate(objectMap, "created", c.Created) populate(objectMap, "id", c.ID) + populate(objectMap, "prompt_annotations", c.PromptAnnotations) populate(objectMap, "usage", c.Usage) return json.Marshal(objectMap) } @@ -398,6 +458,9 @@ func (c *ChatCompletions) UnmarshalJSON(data []byte) error { case "id": err = unpopulate(val, "ID", &c.ID) delete(rawMsg, key) + case "prompt_annotations": + err = unpopulate(val, "PromptAnnotations", &c.PromptAnnotations) + delete(rawMsg, key) case "usage": err = unpopulate(val, "Usage", &c.Usage) delete(rawMsg, key) @@ -413,6 +476,8 @@ func (c *ChatCompletions) UnmarshalJSON(data []byte) error { func (c ChatCompletionsOptions) MarshalJSON() ([]byte, error) { objectMap := make(map[string]any) populate(objectMap, "frequency_penalty", c.FrequencyPenalty) + populate(objectMap, "function_call", c.FunctionCall) + populate(objectMap, "functions", c.Functions) populate(objectMap, "logit_bias", c.LogitBias) populate(objectMap, "max_tokens", c.MaxTokens) populate(objectMap, "messages", c.Messages) @@ -438,6 +503,12 @@ func (c *ChatCompletionsOptions) UnmarshalJSON(data []byte) error { case "frequency_penalty": err = unpopulate(val, "FrequencyPenalty", &c.FrequencyPenalty) delete(rawMsg, key) + case "function_call": + err = unpopulate(val, "FunctionCall", &c.FunctionCall) + delete(rawMsg, key) + case "functions": + err = unpopulate(val, "Functions", &c.Functions) + delete(rawMsg, key) case "logit_bias": err = unpopulate(val, "LogitBias", &c.LogitBias) delete(rawMsg, key) @@ -480,6 +551,8 @@ func (c *ChatCompletionsOptions) UnmarshalJSON(data []byte) error { func (c ChatMessage) MarshalJSON() ([]byte, error) { objectMap := make(map[string]any) populate(objectMap, "content", c.Content) + populate(objectMap, "function_call", c.FunctionCall) + populate(objectMap, "name", c.Name) populate(objectMap, "role", c.Role) return json.Marshal(objectMap) } @@ -496,6 +569,12 @@ func (c *ChatMessage) UnmarshalJSON(data []byte) error { case "content": err = unpopulate(val, "Content", &c.Content) delete(rawMsg, key) + case "function_call": + err = unpopulate(val, "FunctionCall", &c.FunctionCall) + delete(rawMsg, key) + case "name": + err = unpopulate(val, "Name", &c.Name) + delete(rawMsg, key) case "role": err = unpopulate(val, "Role", &c.Role) delete(rawMsg, key) @@ -507,9 +586,41 @@ func (c *ChatMessage) UnmarshalJSON(data []byte) error { return nil } +// MarshalJSON implements the json.Marshaller interface for type ChatMessageFunctionCall. +func (c ChatMessageFunctionCall) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "arguments", c.Arguments) + populate(objectMap, "name", c.Name) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChatMessageFunctionCall. +func (c *ChatMessageFunctionCall) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "arguments": + err = unpopulate(val, "Arguments", &c.Arguments) + delete(rawMsg, key) + case "name": + err = unpopulate(val, "Name", &c.Name) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + // MarshalJSON implements the json.Marshaller interface for type Choice. func (c Choice) MarshalJSON() ([]byte, error) { objectMap := make(map[string]any) + populate(objectMap, "content_filter_results", c.ContentFilterResults) populate(objectMap, "finish_reason", c.FinishReason) populate(objectMap, "index", c.Index) populate(objectMap, "logprobs", c.LogProbs) @@ -526,6 +637,9 @@ func (c *Choice) UnmarshalJSON(data []byte) error { for key, val := range rawMsg { var err error switch key { + case "content_filter_results": + err = unpopulate(val, "ContentFilterResults", &c.ContentFilterResults) + delete(rawMsg, key) case "finish_reason": err = unpopulate(val, "FinishReason", &c.FinishReason) delete(rawMsg, key) @@ -546,6 +660,45 @@ func (c *Choice) UnmarshalJSON(data []byte) error { return nil } +// MarshalJSON implements the json.Marshaller interface for type ChoiceContentFilterResults. +func (c ChoiceContentFilterResults) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "hate", c.Hate) + populate(objectMap, "self_harm", c.SelfHarm) + populate(objectMap, "sexual", c.Sexual) + populate(objectMap, "violence", c.Violence) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChoiceContentFilterResults. +func (c *ChoiceContentFilterResults) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "hate": + err = unpopulate(val, "Hate", &c.Hate) + delete(rawMsg, key) + case "self_harm": + err = unpopulate(val, "SelfHarm", &c.SelfHarm) + delete(rawMsg, key) + case "sexual": + err = unpopulate(val, "Sexual", &c.Sexual) + delete(rawMsg, key) + case "violence": + err = unpopulate(val, "Violence", &c.Violence) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + // MarshalJSON implements the json.Marshaller interface for type ChoiceLogProbs. func (c ChoiceLogProbs) MarshalJSON() ([]byte, error) { objectMap := make(map[string]any) @@ -591,6 +744,7 @@ func (c Completions) MarshalJSON() ([]byte, error) { populate(objectMap, "choices", c.Choices) populate(objectMap, "created", c.Created) populate(objectMap, "id", c.ID) + populate(objectMap, "prompt_annotations", c.PromptAnnotations) populate(objectMap, "usage", c.Usage) return json.Marshal(objectMap) } @@ -613,6 +767,9 @@ func (c *Completions) UnmarshalJSON(data []byte) error { case "id": err = unpopulate(val, "ID", &c.ID) delete(rawMsg, key) + case "prompt_annotations": + err = unpopulate(val, "PromptAnnotations", &c.PromptAnnotations) + delete(rawMsg, key) case "usage": err = unpopulate(val, "Usage", &c.Usage) delete(rawMsg, key) @@ -777,6 +934,169 @@ func (c *CompletionsUsage) UnmarshalJSON(data []byte) error { return nil } +// MarshalJSON implements the json.Marshaller interface for type ContentFilterResults. +func (c ContentFilterResults) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "hate", c.Hate) + populate(objectMap, "self_harm", c.SelfHarm) + populate(objectMap, "sexual", c.Sexual) + populate(objectMap, "violence", c.Violence) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ContentFilterResults. +func (c *ContentFilterResults) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "hate": + err = unpopulate(val, "Hate", &c.Hate) + delete(rawMsg, key) + case "self_harm": + err = unpopulate(val, "SelfHarm", &c.SelfHarm) + delete(rawMsg, key) + case "sexual": + err = unpopulate(val, "Sexual", &c.Sexual) + delete(rawMsg, key) + case "violence": + err = unpopulate(val, "Violence", &c.Violence) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type ContentFilterResultsHate. +func (c ContentFilterResultsHate) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "filtered", c.Filtered) + populate(objectMap, "severity", c.Severity) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ContentFilterResultsHate. +func (c *ContentFilterResultsHate) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "filtered": + err = unpopulate(val, "Filtered", &c.Filtered) + delete(rawMsg, key) + case "severity": + err = unpopulate(val, "Severity", &c.Severity) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type ContentFilterResultsSelfHarm. +func (c ContentFilterResultsSelfHarm) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "filtered", c.Filtered) + populate(objectMap, "severity", c.Severity) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ContentFilterResultsSelfHarm. +func (c *ContentFilterResultsSelfHarm) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "filtered": + err = unpopulate(val, "Filtered", &c.Filtered) + delete(rawMsg, key) + case "severity": + err = unpopulate(val, "Severity", &c.Severity) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type ContentFilterResultsSexual. +func (c ContentFilterResultsSexual) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "filtered", c.Filtered) + populate(objectMap, "severity", c.Severity) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ContentFilterResultsSexual. +func (c *ContentFilterResultsSexual) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "filtered": + err = unpopulate(val, "Filtered", &c.Filtered) + delete(rawMsg, key) + case "severity": + err = unpopulate(val, "Severity", &c.Severity) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type ContentFilterResultsViolence. +func (c ContentFilterResultsViolence) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "filtered", c.Filtered) + populate(objectMap, "severity", c.Severity) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ContentFilterResultsViolence. +func (c *ContentFilterResultsViolence) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "filtered": + err = unpopulate(val, "Filtered", &c.Filtered) + delete(rawMsg, key) + case "severity": + err = unpopulate(val, "Severity", &c.Severity) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + // MarshalJSON implements the json.Marshaller interface for type Deployment. func (d Deployment) MarshalJSON() ([]byte, error) { objectMap := make(map[string]any) @@ -963,6 +1283,99 @@ func (e *EmbeddingsUsageAutoGenerated) UnmarshalJSON(data []byte) error { return nil } +// MarshalJSON implements the json.Marshaller interface for type FunctionCall. +func (f FunctionCall) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "arguments", f.Arguments) + populate(objectMap, "name", f.Name) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type FunctionCall. +func (f *FunctionCall) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", f, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "arguments": + err = unpopulate(val, "Arguments", &f.Arguments) + delete(rawMsg, key) + case "name": + err = unpopulate(val, "Name", &f.Name) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", f, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type FunctionDefinition. +func (f FunctionDefinition) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "description", f.Description) + populate(objectMap, "name", f.Name) + populateAny(objectMap, "parameters", f.Parameters) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type FunctionDefinition. +func (f *FunctionDefinition) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", f, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "description": + err = unpopulate(val, "Description", &f.Description) + delete(rawMsg, key) + case "name": + err = unpopulate(val, "Name", &f.Name) + delete(rawMsg, key) + case "parameters": + err = unpopulate(val, "Parameters", &f.Parameters) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", f, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type FunctionName. +func (f FunctionName) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "name", f.Name) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type FunctionName. +func (f *FunctionName) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", f, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "name": + err = unpopulate(val, "Name", &f.Name) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", f, err) + } + } + return nil +} + // MarshalJSON implements the json.Marshaller interface for type ImageGenerationOptions. func (i ImageGenerationOptions) MarshalJSON() ([]byte, error) { objectMap := make(map[string]any) @@ -1010,7 +1423,7 @@ func (i *ImageGenerationOptions) UnmarshalJSON(data []byte) error { func (i ImageGenerations) MarshalJSON() ([]byte, error) { objectMap := make(map[string]any) populate(objectMap, "created", i.Created) - populate(objectMap, "data", i.Data) + populateAny(objectMap, "data", i.Data) return json.Marshal(objectMap) } @@ -1037,6 +1450,76 @@ func (i *ImageGenerations) UnmarshalJSON(data []byte) error { return nil } +// MarshalJSON implements the json.Marshaller interface for type PromptFilterResult. +func (p PromptFilterResult) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "content_filter_results", p.ContentFilterResults) + populate(objectMap, "prompt_index", p.PromptIndex) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type PromptFilterResult. +func (p *PromptFilterResult) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", p, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "content_filter_results": + err = unpopulate(val, "ContentFilterResults", &p.ContentFilterResults) + delete(rawMsg, key) + case "prompt_index": + err = unpopulate(val, "PromptIndex", &p.PromptIndex) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", p, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type PromptFilterResultContentFilterResults. +func (p PromptFilterResultContentFilterResults) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "hate", p.Hate) + populate(objectMap, "self_harm", p.SelfHarm) + populate(objectMap, "sexual", p.Sexual) + populate(objectMap, "violence", p.Violence) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type PromptFilterResultContentFilterResults. +func (p *PromptFilterResultContentFilterResults) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", p, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "hate": + err = unpopulate(val, "Hate", &p.Hate) + delete(rawMsg, key) + case "self_harm": + err = unpopulate(val, "SelfHarm", &p.SelfHarm) + delete(rawMsg, key) + case "sexual": + err = unpopulate(val, "Sexual", &p.Sexual) + delete(rawMsg, key) + case "violence": + err = unpopulate(val, "Violence", &p.Violence) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", p, err) + } + } + return nil +} + func populate(m map[string]any, k string, v any) { if v == nil { return @@ -1047,6 +1530,16 @@ func populate(m map[string]any, k string, v any) { } } +func populateAny(m map[string]any, k string, v any) { + if v == nil { + return + } else if azcore.IsNullValue(v) { + m[k] = nil + } else { + m[k] = v + } +} + func unpopulate(data json.RawMessage, fn string, v any) error { if data == nil { return nil diff --git a/sdk/cognitiveservices/azopenai/testdata/content_filter_response_error.json b/sdk/cognitiveservices/azopenai/testdata/content_filter_response_error.json new file mode 100644 index 000000000000..709a49a6151c --- /dev/null +++ b/sdk/cognitiveservices/azopenai/testdata/content_filter_response_error.json @@ -0,0 +1,30 @@ +{ + "error": { + "message": "The response was filtered due to the prompt triggering Azure OpenAI’s content management policy. Please modify your prompt and retry. To learn more about our content filtering policies please read our documentation: https://go.microsoft.com/fwlink/?linkid=2198766", + "type": null, + "param": "prompt", + "code": "content_filter", + "status": 400, + "innererror": { + "code": "ResponsibleAIPolicyViolation", + "content_filter_result": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": false, + "severity": "safe" + }, + "violence": { + "filtered": true, + "severity": "medium" + } + } + } + } +} \ No newline at end of file diff --git a/sdk/cognitiveservices/azopenai/testdata/tsp-location.yaml b/sdk/cognitiveservices/azopenai/testdata/tsp-location.yaml index bae87096672f..f621a7595107 100644 --- a/sdk/cognitiveservices/azopenai/testdata/tsp-location.yaml +++ b/sdk/cognitiveservices/azopenai/testdata/tsp-location.yaml @@ -1,4 +1,4 @@ #location: https://github.com/Azure/azure-rest-api-specs/tree/1393b6e34d7370733e3e2236c4df686280a96f36/specification/cognitiveservices/OpenAI.Inference directory: specification/cognitiveservices/OpenAI.Inference -commit: 1393b6e34d7370733e3e2236c4df686280a96f36 +commit: 812c8a0322c016efec774d5682797d5a40336131 repo: Azure/azure-rest-api-specs \ No newline at end of file From e08d49df544957d0b7655d4e23c495ef6dbf3b6f Mon Sep 17 00:00:00 2001 From: Richard Park Date: Thu, 13 Jul 2023 20:04:52 -0700 Subject: [PATCH 2/8] rerecording --- sdk/cognitiveservices/azopenai/assets.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cognitiveservices/azopenai/assets.json b/sdk/cognitiveservices/azopenai/assets.json index d0419e0d7f21..1a5ae0fa7048 100644 --- a/sdk/cognitiveservices/azopenai/assets.json +++ b/sdk/cognitiveservices/azopenai/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "go", "TagPrefix": "go/cognitiveservices/azopenai", - "Tag": "go/cognitiveservices/azopenai_2b6f93a94d" + "Tag": "go/cognitiveservices/azopenai_b024443c88" } From df00fce6051e1609b0c7cf148a72f36cdfefe210 Mon Sep 17 00:00:00 2001 From: Richard Park Date: Thu, 13 Jul 2023 20:42:23 -0700 Subject: [PATCH 3/8] Adding missing doc comment --- sdk/cognitiveservices/azopenai/custom_client_functions.go | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/cognitiveservices/azopenai/custom_client_functions.go b/sdk/cognitiveservices/azopenai/custom_client_functions.go index 87b6b92ead9b..340f0d66541d 100644 --- a/sdk/cognitiveservices/azopenai/custom_client_functions.go +++ b/sdk/cognitiveservices/azopenai/custom_client_functions.go @@ -21,6 +21,7 @@ type ChatCompletionsOptionsFunctionCall struct { Value *string } +// MarshalJSON implements the json.Marshaller interface for type ChatCompletionsOptionsFunctionCall. func (c ChatCompletionsOptionsFunctionCall) MarshalJSON() ([]byte, error) { if c.IsFunction { if c.Value == nil { From 8f5248ad1d4ab156daf7f5a6ce7ea4a36400bf57 Mon Sep 17 00:00:00 2001 From: Richard Park Date: Thu, 13 Jul 2023 20:42:53 -0700 Subject: [PATCH 4/8] license header --- sdk/cognitiveservices/azopenai/custom_client_functions.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sdk/cognitiveservices/azopenai/custom_client_functions.go b/sdk/cognitiveservices/azopenai/custom_client_functions.go index 340f0d66541d..480f82a862fc 100644 --- a/sdk/cognitiveservices/azopenai/custom_client_functions.go +++ b/sdk/cognitiveservices/azopenai/custom_client_functions.go @@ -1,3 +1,9 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + package azopenai import ( From 372157349b5c534330c4522ae2ed9f21612d4c56 Mon Sep 17 00:00:00 2001 From: Richard Park Date: Thu, 13 Jul 2023 20:50:22 -0700 Subject: [PATCH 5/8] Accidentally double recording --- sdk/cognitiveservices/azopenai/client_shared_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cognitiveservices/azopenai/client_shared_test.go b/sdk/cognitiveservices/azopenai/client_shared_test.go index e2a51001e3e1..3e4242842bec 100644 --- a/sdk/cognitiveservices/azopenai/client_shared_test.go +++ b/sdk/cognitiveservices/azopenai/client_shared_test.go @@ -203,7 +203,7 @@ func newOpenAIClientForTest(t *testing.T) *azopenai.Client { MaxRetryDelay: time.Second, } - chatClient, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t)) + chatClient, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, options) require.NoError(t, err) return chatClient From 3c3c4ea7cdc318d41c6f3f2a9d61a75e1b2d5619 Mon Sep 17 00:00:00 2001 From: Richard Park Date: Thu, 13 Jul 2023 20:53:32 -0700 Subject: [PATCH 6/8] clean record --- sdk/cognitiveservices/azopenai/assets.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cognitiveservices/azopenai/assets.json b/sdk/cognitiveservices/azopenai/assets.json index 1a5ae0fa7048..b7b04a7814db 100644 --- a/sdk/cognitiveservices/azopenai/assets.json +++ b/sdk/cognitiveservices/azopenai/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "go", "TagPrefix": "go/cognitiveservices/azopenai", - "Tag": "go/cognitiveservices/azopenai_b024443c88" + "Tag": "go/cognitiveservices/azopenai_63852f374c" } From ad2db780a36d567949e4d66182ef2b075ac8639b Mon Sep 17 00:00:00 2001 From: Richard Park Date: Thu, 13 Jul 2023 21:02:14 -0700 Subject: [PATCH 7/8] Use my variable for the endpoint --- sdk/cognitiveservices/azopenai/client_functions_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cognitiveservices/azopenai/client_functions_test.go b/sdk/cognitiveservices/azopenai/client_functions_test.go index 56deb8ea1abe..352155c1eefb 100644 --- a/sdk/cognitiveservices/azopenai/client_functions_test.go +++ b/sdk/cognitiveservices/azopenai/client_functions_test.go @@ -38,7 +38,7 @@ func getClientForFunctionsTest(t *testing.T, azure bool) *azopenai.Client { cred, err := azopenai.NewKeyCredential(openAIKey) require.NoError(t, err) - chatClient, err := azopenai.NewClientForOpenAI("https://api.openai.com/v1", cred, newClientOptionsForTest(t)) + chatClient, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t)) require.NoError(t, err) return chatClient From 656bb82e0cb47e8995d71bf5f3b199a524f4a69c Mon Sep 17 00:00:00 2001 From: Richard Park Date: Fri, 14 Jul 2023 13:04:07 -0700 Subject: [PATCH 8/8] Pointer to error, not just error. --- sdk/cognitiveservices/azopenai/custom_models.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cognitiveservices/azopenai/custom_models.go b/sdk/cognitiveservices/azopenai/custom_models.go index 8dc45bf3c0e4..8a639b6ecc85 100644 --- a/sdk/cognitiveservices/azopenai/custom_models.go +++ b/sdk/cognitiveservices/azopenai/custom_models.go @@ -61,7 +61,7 @@ type ContentFilterResponseError struct { } // Unwrap returns the inner error for this error. -func (e ContentFilterResponseError) Unwrap() error { +func (e *ContentFilterResponseError) Unwrap() error { return &e.ResponseError }