Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/cognitiveservices/azopenai/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Release History

## 0.1.0 (2023-07-19)
## 0.1.0 (2023-07-20)

* Initial release of the `azopenai` library
2 changes: 1 addition & 1 deletion sdk/cognitiveservices/azopenai/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/cognitiveservices/azopenai",
"Tag": "go/cognitiveservices/azopenai_e8362ae205"
"Tag": "go/cognitiveservices/azopenai_8fdad86997"
}
20 changes: 19 additions & 1 deletion sdk/cognitiveservices/azopenai/autorest.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ directive:
# allow interception of formatting the URL path
- from: client.go
where: $
transform: return $.replace(/runtime\.JoinPaths\(client.endpoint, urlPath\)/g, "client.formatURL(urlPath)");
transform: |
return $
.replace(/runtime\.JoinPaths\(client.endpoint, urlPath\)/g, "client.formatURL(urlPath, getDeploymentID(body))");

# Some ImageGenerations hackery to represent the ImageLocation/ImagePayload polymorphism.
# - Remove the auto-generated ImageGenerationsDataItem.
Expand Down Expand Up @@ -276,4 +278,20 @@ directive:
- from: client.go
where: $
transform: return $.replace(/runtime\.NewResponseError/sg, "client.newError");

#
# rename `Model` to `DeploymentID`
#
- from: models.go
where: $
transform: |
return $
.replace(/\/\/ The model name.*?Model \*string/sg, "// DeploymentID specifies the name of the deployment (for Azure OpenAI) or model (for OpenAI) to use for this request.\nDeploymentID string");

- from: models_serde.go
where: $
transform: |
return $
.replace(/populate\(objectMap, "model", (c|e).Model\)/g, 'populate(objectMap, "model", &$1.DeploymentID)')
.replace(/err = unpopulate\(val, "Model", &(c|e).Model\)/g, 'err = unpopulate(val, "Model", &$1.DeploymentID)');
```
8 changes: 4 additions & 4 deletions sdk/cognitiveservices/azopenai/client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

72 changes: 36 additions & 36 deletions sdk/cognitiveservices/azopenai/client_chat_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,36 @@ import (
"github.com/stretchr/testify/require"
)

var chatCompletionsRequest = azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatMessage{
{
Role: to.Ptr(azopenai.ChatRole("user")),
Content: to.Ptr("Count to 10, with a comma between each number, no newlines and a period at the end. E.g., 1, 2, 3, ..."),
func newTestChatCompletionOptions(tv testVars) azopenai.ChatCompletionsOptions {
return azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatMessage{
{
Role: to.Ptr(azopenai.ChatRole("user")),
Content: to.Ptr("Count to 10, with a comma between each number, no newlines and a period at the end. E.g., 1, 2, 3, ..."),
},
},
},
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
Model: &openAIChatCompletionsModel,
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
DeploymentID: tv.ChatCompletions,
}
}

var expectedContent = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10."
var expectedRole = azopenai.ChatRoleAssistant

func TestClient_GetChatCompletions(t *testing.T) {
cred, err := azopenai.NewKeyCredential(apiKey)
cred, err := azopenai.NewKeyCredential(azureOpenAI.APIKey)
require.NoError(t, err)

chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, chatCompletionsModelDeployment, newClientOptionsForTest(t))
chatClient, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)

testGetChatCompletions(t, chatClient, true)
testGetChatCompletions(t, chatClient, azureOpenAI)
}

func TestClient_GetChatCompletionsStream(t *testing.T) {
chatClient := newAzureOpenAIClientForTest(t, canaryChatCompletionsModelDeployment, true)
testGetChatCompletionsStream(t, chatClient)
chatClient := newAzureOpenAIClientForTest(t, azureOpenAICanary)
testGetChatCompletionsStream(t, chatClient, azureOpenAICanary)
}

func TestClient_OpenAI_GetChatCompletions(t *testing.T) {
Expand All @@ -58,7 +60,7 @@ func TestClient_OpenAI_GetChatCompletions(t *testing.T) {
}

chatClient := newOpenAIClientForTest(t)
testGetChatCompletions(t, chatClient, false)
testGetChatCompletions(t, chatClient, openAI)
}

func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) {
Expand All @@ -67,10 +69,10 @@ func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) {
}

chatClient := newOpenAIClientForTest(t)
testGetChatCompletionsStream(t, chatClient)
testGetChatCompletionsStream(t, chatClient, openAI)
}

func testGetChatCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
func testGetChatCompletions(t *testing.T, client *azopenai.Client, tv testVars) {
expected := azopenai.ChatCompletions{
Choices: []azopenai.ChatChoice{
{
Expand All @@ -91,10 +93,10 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client, isAzure bool)
},
}

resp, err := client.GetChatCompletions(context.Background(), chatCompletionsRequest, nil)
resp, err := client.GetChatCompletions(context.Background(), newTestChatCompletionOptions(tv), nil)
require.NoError(t, err)

if isAzure {
if tv.Azure {
// Azure also provides content-filtering. This particular prompt and responses
// will be considered safe.
expected.PromptAnnotations = []azopenai.PromptFilterResult{
Expand All @@ -112,8 +114,8 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client, isAzure bool)
require.Equal(t, expected, resp.ChatCompletions)
}

func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client) {
streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil)
func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, tv testVars) {
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(tv), nil)
require.NoError(t, err)

// the data comes back differently for streaming
Expand Down Expand Up @@ -178,19 +180,19 @@ func TestClient_GetChatCompletions_DefaultAzureCredential(t *testing.T) {
})
require.NoError(t, err)

chatClient, err := azopenai.NewClient(endpoint, dac, chatCompletionsModelDeployment, &azopenai.ClientOptions{
chatClient, err := azopenai.NewClient(azureOpenAI.Endpoint, dac, &azopenai.ClientOptions{
ClientOptions: policy.ClientOptions{Transport: recordingTransporter},
})
require.NoError(t, err)

testGetChatCompletions(t, chatClient, true)
testGetChatCompletions(t, chatClient, azureOpenAI)
}

func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
cred, err := azopenai.NewKeyCredential(apiKey)
cred, err := azopenai.NewKeyCredential(azureOpenAI.APIKey)
require.NoError(t, err)

chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, "thisdoesntexist", newClientOptionsForTest(t))
chatClient, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)

_, err = chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
Expand All @@ -200,8 +202,9 @@ func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
Content: to.Ptr("Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."),
},
},
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
DeploymentID: "invalid model name",
}, nil)

var respErr *azcore.ResponseError
Expand All @@ -214,20 +217,17 @@ func TestClient_GetChatCompletionsStream_Error(t *testing.T) {
t.Skip()
}

doTest := func(t *testing.T, client *azopenai.Client) {
t.Helper()
streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil)
t.Run("AzureOpenAI", func(t *testing.T) {
client := newBogusAzureOpenAIClient(t)
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(azureOpenAI), nil)
require.Empty(t, streamResp)
assertResponseIsError(t, err)
}

t.Run("AzureOpenAI", func(t *testing.T) {
client := newBogusAzureOpenAIClient(t, chatCompletionsModelDeployment)
doTest(t, client)
})

t.Run("OpenAI", func(t *testing.T) {
client := newBogusOpenAIClient(t)
doTest(t, client)
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(openAI), nil)
require.Empty(t, streamResp)
assertResponseIsError(t, err)
})
}
24 changes: 15 additions & 9 deletions sdk/cognitiveservices/azopenai/client_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ import (
)

func TestClient_GetCompletions_AzureOpenAI(t *testing.T) {
cred, err := azopenai.NewKeyCredential(apiKey)
cred, err := azopenai.NewKeyCredential(azureOpenAI.APIKey)
require.NoError(t, err)

client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, completionsModelDeployment, newClientOptionsForTest(t))
client, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)

testGetCompletions(t, client)
testGetCompletions(t, client, true)
}

func TestClient_GetCompletions_OpenAI(t *testing.T) {
Expand All @@ -31,15 +31,21 @@ func TestClient_GetCompletions_OpenAI(t *testing.T) {
}

client := newOpenAIClientForTest(t)
testGetCompletions(t, client)
testGetCompletions(t, client, false)
}

func testGetCompletions(t *testing.T, client *azopenai.Client) {
func testGetCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
deploymentID := openAI.Completions

if isAzure {
deploymentID = azureOpenAI.Completions
}

resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
Model: &openAICompletionsModel,
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
DeploymentID: deploymentID,
}, nil)
require.NoError(t, err)

Expand Down
24 changes: 11 additions & 13 deletions sdk/cognitiveservices/azopenai/client_embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ import (
)

func TestClient_GetEmbeddings_InvalidModel(t *testing.T) {
cred, err := azopenai.NewKeyCredential(apiKey)
cred, err := azopenai.NewKeyCredential(azureOpenAI.APIKey)
require.NoError(t, err)

chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, "thisdoesntexist", newClientOptionsForTest(t))
chatClient, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)

_, err = chatClient.GetEmbeddings(context.Background(), azopenai.EmbeddingsOptions{}, nil)
_, err = chatClient.GetEmbeddings(context.Background(), azopenai.EmbeddingsOptions{
DeploymentID: "thisdoesntexist",
}, nil)

var respErr *azcore.ResponseError
require.ErrorAs(t, err, &respErr)
Expand All @@ -32,21 +34,17 @@ func TestClient_OpenAI_GetEmbeddings(t *testing.T) {
}

client := newOpenAIClientForTest(t)
modelID := "text-similarity-curie-001"
testGetEmbeddings(t, client, modelID)
testGetEmbeddings(t, client, openAI.Embeddings)
}

func TestClient_GetEmbeddings(t *testing.T) {
// model deployment points to `text-similarity-curie-001`
deploymentID := "embedding"

cred, err := azopenai.NewKeyCredential(apiKey)
cred, err := azopenai.NewKeyCredential(azureOpenAI.APIKey)
require.NoError(t, err)

client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, deploymentID, newClientOptionsForTest(t))
client, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)

testGetEmbeddings(t, client, deploymentID)
testGetEmbeddings(t, client, azureOpenAI.Embeddings)
}

func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentID string) {
Expand All @@ -71,8 +69,8 @@ func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentI
ctx: context.TODO(),
deploymentID: modelOrDeploymentID,
body: azopenai.EmbeddingsOptions{
Input: []string{"\"Your text string goes here\""},
Model: &modelOrDeploymentID,
Input: []string{"\"Your text string goes here\""},
DeploymentID: modelOrDeploymentID,
},
options: nil,
},
Expand Down
16 changes: 9 additions & 7 deletions sdk/cognitiveservices/azopenai/client_functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@ func TestGetChatCompletions_usingFunctions(t *testing.T) {

t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testChatCompletionsFunctions(t, chatClient)
testChatCompletionsFunctions(t, chatClient, openAI)
})

t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newAzureOpenAIClientForTest(t, chatCompletionsModelDeployment, false)
testChatCompletionsFunctions(t, chatClient)
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
testChatCompletionsFunctions(t, chatClient, azureOpenAI)
})
}

func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client) {
resp, err := chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
Model: to.Ptr("gpt-4-0613"),
func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, tv testVars) {
body := azopenai.ChatCompletionsOptions{
DeploymentID: tv.ChatCompletions,
Messages: []azopenai.ChatMessage{
{
Role: to.Ptr(azopenai.ChatRoleUser),
Expand Down Expand Up @@ -72,7 +72,9 @@ func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client) {
},
},
Temperature: to.Ptr[float32](0.0),
}, nil)
}

resp, err := chatClient.GetChatCompletions(context.Background(), body, nil)
require.NoError(t, err)

funcCall := resp.ChatCompletions.Choices[0].Message.FunctionCall
Expand Down
Loading