From dabdbb0a4b9bc196a468001b6d6e7bb86b6b0e12 Mon Sep 17 00:00:00 2001 From: fregie Date: Tue, 27 Jun 2023 10:37:38 +0800 Subject: [PATCH] openai.DefaultAzureConfig support configure apiVersion --- README.md | 10 ++++++++++ config.go | 19 ++++++++++++++++--- config_test.go | 7 +++++++ 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index da1a2804d..4800cf2be 100644 --- a/README.md +++ b/README.md @@ -444,6 +444,11 @@ func main() { // return azureModelMapping[model] // } + // If you met the error "ChatCompletion error: error, status code: 404,message: The API deployment for this resource does not exist.If you created the deployment within the last 5 minutes, please wait a moment and try again." + // You can set the APIVersion to the correct version by: + // config := openai.DefaultAzureConfig("your Azure OpenAI Key", "https://your Azure OpenAI Endpoint", openai.WithAzureAPIVersion("2021-03-01-preview")) + // find the corrent version in Azure AI Studio -> Chat playgrount -> Chat session -> view Code + client := openai.NewClientWithConfig(config) resp, err := client.CreateChatCompletion( context.Background(), @@ -494,6 +499,11 @@ func main() { // return azureModelMapping[model] //} + // If you met the error "ChatCompletion error: error, status code: 404,message: The API deployment for this resource does not exist.If you created the deployment within the last 5 minutes, please wait a moment and try again." + // You can set the APIVersion to the correct version by: + // config := openai.DefaultAzureConfig("your Azure OpenAI Key", "https://your Azure OpenAI Endpoint", openai.WithAzureAPIVersion("2021-03-01-preview")) + // find the corrent version in Azure AI Studio -> Chat playgrount -> Chat session -> view Code + input := "Text to vectorize" client := openai.NewClientWithConfig(config) diff --git a/config.go b/config.go index c58b71ec6..fd49d3080 100644 --- a/config.go +++ b/config.go @@ -11,6 +11,7 @@ const ( azureAPIPrefix = "openai" azureDeploymentsPrefix = "deployments" + azureDefaultAPIVersion = "2023-05-15" ) type APIType string @@ -50,13 +51,21 @@ func DefaultConfig(authToken string) ClientConfig { } } -func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { - return ClientConfig{ +type AzureConfigOption func(*ClientConfig) + +func WithAzureAPIVersion(apiVersion string) AzureConfigOption { + return func(c *ClientConfig) { + c.APIVersion = apiVersion + } +} + +func DefaultAzureConfig(apiKey, baseURL string, opts ...AzureConfigOption) ClientConfig { + c := ClientConfig{ authToken: apiKey, BaseURL: baseURL, OrgID: "", APIType: APITypeAzure, - APIVersion: "2023-05-15", + APIVersion: azureDefaultAPIVersion, AzureModelMapperFunc: func(model string) string { return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") }, @@ -65,6 +74,10 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { EmptyMessagesLimit: defaultEmptyMessagesLimit, } + for _, opt := range opts { + opt(&c) + } + return c } func (ClientConfig) String() string { diff --git a/config_test.go b/config_test.go index 488511b11..df475be21 100644 --- a/config_test.go +++ b/config_test.go @@ -60,3 +60,10 @@ func TestGetAzureDeploymentByModel(t *testing.T) { }) } } + +func TestSetApiVersion(t *testing.T) { + conf := DefaultAzureConfig("", "https://test.openai.azure.com/", WithAzureAPIVersion("2021-03-01-preview")) + if conf.APIVersion != "2021-03-01-preview" { + t.Errorf("Expected %s, got %s", "2021-03-01-preview", conf.APIVersion) + } +}