From 35ca7458585a04c6ce0f5d7f61ac8aef7a12c12a Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Fri, 31 Mar 2023 17:17:07 +0800 Subject: [PATCH 01/26] feat: add azure openai support --- api.go | 16 +++++++++++++++- config.go | 34 +++++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/api.go b/api.go index 00d6d3514..b8d8c207e 100644 --- a/api.go +++ b/api.go @@ -83,6 +83,13 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { } func (c *Client) fullURL(suffix string) string { + // /openai/deployments/{engine}/chat/completions?api-version={api_version} + if c.config.ApiType == ApiTypeAzure || c.config.ApiType == ApiTypeAzureAD { + return fmt.Sprintf("%s%s/%s/%s%s?api-version=%s", + c.config.BaseURL, azureApiPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.ApiVersion) + } + + // c.config.ApiType == ApiTypeOpenAI || c.config.ApiType == "" return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } @@ -100,7 +107,14 @@ func (c *Client) newStreamRequest( req.Header.Set("Accept", "text/event-stream") req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication + // Azure API Key authentication + if c.config.ApiType == ApiTypeAzure { + req.Header.Set("api-key", c.config.authToken) + } else { + // OpenAI or Azure AD authentication + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + } return req, nil } diff --git a/config.go b/config.go index e09c256f2..021aff221 100644 --- a/config.go +++ b/config.go @@ -7,16 +7,30 @@ import ( const ( apiURLv1 = "https://api.openai.com/v1" defaultEmptyMessagesLimit uint = 300 + + azureApiPrefix = "openai" + azureDeploymentsPrefix = "deployments" +) + +type ApiType string + +const ( + ApiTypeOpenAI ApiType = "OPEN_AI" + ApiTypeAzure ApiType = "AZURE" + ApiTypeAzureAD ApiType = "AZURE_AD" ) // ClientConfig is a configuration of a client. type ClientConfig struct { + ApiType ApiType + Engine string + ApiVersion string + authToken string HTTPClient *http.Client - - BaseURL string - OrgID string + BaseURL string + OrgID string EmptyMessagesLimit uint } @@ -31,3 +45,17 @@ func DefaultConfig(authToken string) ClientConfig { EmptyMessagesLimit: defaultEmptyMessagesLimit, } } + +func DefaultAzureConfig(apiBase, engine, apiKey, apiVersion string) ClientConfig { + return ClientConfig{ + ApiType: ApiTypeAzure, + Engine: engine, + ApiVersion: apiVersion, + HTTPClient: &http.Client{}, + BaseURL: apiBase, + OrgID: "", + authToken: apiKey, + + EmptyMessagesLimit: defaultEmptyMessagesLimit, + } +} From 4ef2708fd9d4991ed3f4f67617a42480d65615d4 Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Fri, 31 Mar 2023 17:46:48 +0800 Subject: [PATCH 02/26] chore: refine config --- config.go | 86 +++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 77 insertions(+), 9 deletions(-) diff --git a/config.go b/config.go index 021aff221..0b6fd728f 100644 --- a/config.go +++ b/config.go @@ -1,11 +1,12 @@ package openai import ( + "fmt" "net/http" ) const ( - apiURLv1 = "https://api.openai.com/v1" + openaiApiURLv1 = "https://api.openai.com/v1" defaultEmptyMessagesLimit uint = 300 azureApiPrefix = "openai" @@ -20,6 +21,12 @@ const ( ApiTypeAzureAD ApiType = "AZURE_AD" ) +var supportedApiType = map[ApiType]struct{}{ + ApiTypeOpenAI: {}, + ApiTypeAzure: {}, + ApiTypeAzureAD: {}, +} + // ClientConfig is a configuration of a client. type ClientConfig struct { ApiType ApiType @@ -38,7 +45,7 @@ type ClientConfig struct { func DefaultConfig(authToken string) ClientConfig { return ClientConfig{ HTTPClient: &http.Client{}, - BaseURL: apiURLv1, + BaseURL: openaiApiURLv1, OrgID: "", authToken: authToken, @@ -46,16 +53,77 @@ func DefaultConfig(authToken string) ClientConfig { } } -func DefaultAzureConfig(apiBase, engine, apiKey, apiVersion string) ClientConfig { - return ClientConfig{ - ApiType: ApiTypeAzure, - Engine: engine, - ApiVersion: apiVersion, +func NewConfig(authTokenOrKey string, opts ...Option) (ClientConfig, error) { + cfg := ClientConfig{ + ApiType: ApiTypeOpenAI, + Engine: "", + ApiVersion: "", HTTPClient: &http.Client{}, - BaseURL: apiBase, + BaseURL: openaiApiURLv1, OrgID: "", - authToken: apiKey, + authToken: authTokenOrKey, EmptyMessagesLimit: defaultEmptyMessagesLimit, } + for _, o := range opts { + o(&cfg) + } + if authTokenOrKey == "" { + return ClientConfig{}, fmt.Errorf("auth token or key is required") + } + + if _, ok := supportedApiType[cfg.ApiType]; !ok { + return ClientConfig{}, fmt.Errorf("unsupported API type %s", cfg.ApiType) + } + + if cfg.ApiType == ApiTypeAzure || cfg.ApiType == ApiTypeAzureAD { + if cfg.ApiVersion == "" { + return ClientConfig{}, fmt.Errorf("an API version is required for the Azure API type") + } + } + + return cfg, nil +} + +type Option func(*ClientConfig) + +// WithApiType sets the API type to use. +func WithApiType(apiType ApiType) Option { + return func(o *ClientConfig) { + o.ApiType = apiType + } +} + +// WithEngine sets the engine to use. +func WithEngine(engine string) Option { + return func(o *ClientConfig) { + o.Engine = engine + } +} + +// WithApiVersion sets the API version to use. +func WithApiVersion(apiVersion string) Option { + return func(o *ClientConfig) { + o.ApiVersion = apiVersion + } +} + +// WithHTTPClient sets the HTTP client to use. +func WithHTTPClient(client *http.Client) Option { + return func(o *ClientConfig) { + o.HTTPClient = client + } +} + +func WithBaseURL(apiBase string) Option { + return func(o *ClientConfig) { + o.BaseURL = apiBase + } +} + +// WithOrgID sets the organization ID to use. +func WithOrgID(orgID string) Option { + return func(o *ClientConfig) { + o.OrgID = orgID + } } From 350e961b2f27b547b5006cb9ca9d4ef2dc1be902 Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Fri, 31 Mar 2023 17:52:09 +0800 Subject: [PATCH 03/26] chore: make config options like the python one --- api.go | 10 +++++----- api_test.go | 2 +- audio_test.go | 4 ++-- chat_stream_test.go | 6 +++--- chat_test.go | 6 +++--- completion_test.go | 4 ++-- config.go | 36 ++++++++++++++++++------------------ edits_test.go | 2 +- error_accumulator_test.go | 2 +- files_test.go | 2 +- fine_tunes_test.go | 2 +- image_test.go | 8 ++++---- models_test.go | 2 +- moderation_test.go | 2 +- request_builder_test.go | 2 +- stream_test.go | 6 +++--- 16 files changed, 48 insertions(+), 48 deletions(-) diff --git a/api.go b/api.go index b8d8c207e..3c889fe82 100644 --- a/api.go +++ b/api.go @@ -39,7 +39,7 @@ func NewOrgClient(authToken, org string) *Client { func (c *Client) sendRequest(req *http.Request, v interface{}) error { req.Header.Set("Accept", "application/json; charset=utf-8") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.ApiKey)) // Check whether Content-Type is already set, Upload Files API requires // Content-Type == multipart/form-data @@ -86,11 +86,11 @@ func (c *Client) fullURL(suffix string) string { // /openai/deployments/{engine}/chat/completions?api-version={api_version} if c.config.ApiType == ApiTypeAzure || c.config.ApiType == ApiTypeAzureAD { return fmt.Sprintf("%s%s/%s/%s%s?api-version=%s", - c.config.BaseURL, azureApiPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.ApiVersion) + c.config.ApiBase, azureApiPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.ApiVersion) } // c.config.ApiType == ApiTypeOpenAI || c.config.ApiType == "" - return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) + return fmt.Sprintf("%s%s", c.config.ApiBase, suffix) } func (c *Client) newStreamRequest( @@ -111,10 +111,10 @@ func (c *Client) newStreamRequest( // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication if c.config.ApiType == ApiTypeAzure { - req.Header.Set("api-key", c.config.authToken) + req.Header.Set("api-key", c.config.ApiKey) } else { // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.ApiKey)) } return req, nil } diff --git a/api_test.go b/api_test.go index 478a274d4..91b939405 100644 --- a/api_test.go +++ b/api_test.go @@ -132,7 +132,7 @@ func TestRequestError(t *testing.T) { var err error config := DefaultConfig("dummy") - config.BaseURL = "https://httpbin.org/status/418?" + config.ApiBase = "https://httpbin.org/status/418?" c := NewClientWithConfig(config) ctx := context.Background() _, err = c.ListEngines(ctx) diff --git a/audio_test.go b/audio_test.go index 087084805..7527c960c 100644 --- a/audio_test.go +++ b/audio_test.go @@ -31,7 +31,7 @@ func TestAudio(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) testcases := []struct { @@ -79,7 +79,7 @@ func TestAudioWithOptionalArgs(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) testcases := []struct { diff --git a/chat_stream_test.go b/chat_stream_test.go index 24046db6c..aa98d3cb8 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -16,7 +16,7 @@ import ( func TestChatCompletionsStreamWrongModel(t *testing.T) { config := DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1" + config.ApiBase = "http://localhost/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -62,7 +62,7 @@ func TestCreateChatCompletionStream(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" + config.ApiBase = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), http.DefaultTransport, @@ -169,7 +169,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" + config.ApiBase = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), http.DefaultTransport, diff --git a/chat_test.go b/chat_test.go index ce302a69f..30fd791f0 100644 --- a/chat_test.go +++ b/chat_test.go @@ -18,7 +18,7 @@ import ( func TestChatCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1" + config.ApiBase = "http://localhost/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -39,7 +39,7 @@ func TestChatCompletionsWrongModel(t *testing.T) { func TestChatCompletionsWithStream(t *testing.T) { config := DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1" + config.ApiBase = "http://localhost/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -61,7 +61,7 @@ func TestChatCompletions(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/completion_test.go b/completion_test.go index 2e302591a..2b3d47d9f 100644 --- a/completion_test.go +++ b/completion_test.go @@ -19,7 +19,7 @@ import ( func TestCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1" + config.ApiBase = "http://localhost/v1" client := NewClientWithConfig(config) _, err := client.CreateCompletion( @@ -57,7 +57,7 @@ func TestCompletions(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/config.go b/config.go index 0b6fd728f..85e82fe9c 100644 --- a/config.go +++ b/config.go @@ -33,43 +33,37 @@ type ClientConfig struct { Engine string ApiVersion string - authToken string + ApiKey string HTTPClient *http.Client - BaseURL string + ApiBase string OrgID string EmptyMessagesLimit uint } -func DefaultConfig(authToken string) ClientConfig { - return ClientConfig{ - HTTPClient: &http.Client{}, - BaseURL: openaiApiURLv1, - OrgID: "", - authToken: authToken, - - EmptyMessagesLimit: defaultEmptyMessagesLimit, - } +func DefaultConfig(apiKey string) (ClientConfig, error) { + return NewConfig(WithApiKey(apiKey)) } -func NewConfig(authTokenOrKey string, opts ...Option) (ClientConfig, error) { +func NewConfig(opts ...Option) (ClientConfig, error) { cfg := ClientConfig{ ApiType: ApiTypeOpenAI, Engine: "", ApiVersion: "", HTTPClient: &http.Client{}, - BaseURL: openaiApiURLv1, + ApiBase: openaiApiURLv1, OrgID: "", - authToken: authTokenOrKey, + ApiKey: "", EmptyMessagesLimit: defaultEmptyMessagesLimit, } for _, o := range opts { o(&cfg) } - if authTokenOrKey == "" { - return ClientConfig{}, fmt.Errorf("auth token or key is required") + + if cfg.ApiKey == "" { + return ClientConfig{}, fmt.Errorf("api key is required") } if _, ok := supportedApiType[cfg.ApiType]; !ok { @@ -115,9 +109,15 @@ func WithHTTPClient(client *http.Client) Option { } } -func WithBaseURL(apiBase string) Option { +func WithApiBase(apiBase string) Option { + return func(o *ClientConfig) { + o.ApiBase = apiBase + } +} + +func WithApiKey(apiKey string) Option { return func(o *ClientConfig) { - o.BaseURL = apiBase + o.ApiKey = apiKey } } diff --git a/edits_test.go b/edits_test.go index fa6c12825..88c059142 100644 --- a/edits_test.go +++ b/edits_test.go @@ -25,7 +25,7 @@ func TestEdits(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/error_accumulator_test.go b/error_accumulator_test.go index 637bf3678..deb142456 100644 --- a/error_accumulator_test.go +++ b/error_accumulator_test.go @@ -76,7 +76,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/files_test.go b/files_test.go index 3e8dfc442..f63d18b13 100644 --- a/files_test.go +++ b/files_test.go @@ -24,7 +24,7 @@ func TestFileUpload(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/fine_tunes_test.go b/fine_tunes_test.go index c60254993..0c65c3e24 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -66,7 +66,7 @@ func TestFineTunes(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/image_test.go b/image_test.go index 9917b7881..180f66a51 100644 --- a/image_test.go +++ b/image_test.go @@ -25,7 +25,7 @@ func TestImages(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -95,7 +95,7 @@ func TestImageEdit(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -139,7 +139,7 @@ func TestImageEditWithoutMask(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -205,7 +205,7 @@ func TestImageVariation(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/models_test.go b/models_test.go index dad59be79..98f35a5be 100644 --- a/models_test.go +++ b/models_test.go @@ -23,7 +23,7 @@ func TestListModels(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/moderation_test.go b/moderation_test.go index 3535bc807..483771ef4 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -27,7 +27,7 @@ func TestModerations(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/request_builder_test.go b/request_builder_test.go index e5b65df0c..998157385 100644 --- a/request_builder_test.go +++ b/request_builder_test.go @@ -45,7 +45,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { defer ts.Close() config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" + config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} diff --git a/stream_test.go b/stream_test.go index a80504d24..8ee869e74 100644 --- a/stream_test.go +++ b/stream_test.go @@ -15,7 +15,7 @@ import ( func TestCompletionsStreamWrongModel(t *testing.T) { config := DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1" + config.ApiBase = "http://localhost/v1" client := NewClientWithConfig(config) _, err := client.CreateCompletionStream( @@ -56,7 +56,7 @@ func TestCreateCompletionStream(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" + config.ApiBase = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), http.DefaultTransport, @@ -141,7 +141,7 @@ func TestCreateCompletionStreamError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" + config.ApiBase = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), http.DefaultTransport, From 61287b2d831da0bb4d4c27636dab7629601b3718 Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Fri, 31 Mar 2023 17:53:33 +0800 Subject: [PATCH 04/26] chore: adjust config struct field order --- api.go | 18 ++++++++++++------ config.go | 23 +++++++++++------------ 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/api.go b/api.go index 3c889fe82..f468465cd 100644 --- a/api.go +++ b/api.go @@ -15,9 +15,12 @@ type Client struct { } // NewClient creates new OpenAI API client. -func NewClient(authToken string) *Client { - config := DefaultConfig(authToken) - return NewClientWithConfig(config) +func NewClient(authToken string) (*Client, error) { + config, err := DefaultConfig(authToken) + if err != nil { + return nil, err + } + return NewClientWithConfig(config), nil } // NewClientWithConfig creates new OpenAI API client for specified config. @@ -31,10 +34,13 @@ func NewClientWithConfig(config ClientConfig) *Client { // NewOrgClient creates new OpenAI API client for specified Organization ID. // // Deprecated: Please use NewClientWithConfig. -func NewOrgClient(authToken, org string) *Client { - config := DefaultConfig(authToken) +func NewOrgClient(authToken, org string) (*Client, error) { + config, err := DefaultConfig(authToken) + if err != nil { + return nil, err + } config.OrgID = org - return NewClientWithConfig(config) + return NewClientWithConfig(config), nil } func (c *Client) sendRequest(req *http.Request, v interface{}) error { diff --git a/config.go b/config.go index 85e82fe9c..d8115f450 100644 --- a/config.go +++ b/config.go @@ -30,14 +30,14 @@ var supportedApiType = map[ApiType]struct{}{ // ClientConfig is a configuration of a client. type ClientConfig struct { ApiType ApiType - Engine string + ApiKey string + ApiBase string ApiVersion string - ApiKey string + Engine string + OrgID string HTTPClient *http.Client - ApiBase string - OrgID string EmptyMessagesLimit uint } @@ -48,14 +48,13 @@ func DefaultConfig(apiKey string) (ClientConfig, error) { func NewConfig(opts ...Option) (ClientConfig, error) { cfg := ClientConfig{ - ApiType: ApiTypeOpenAI, - Engine: "", - ApiVersion: "", - HTTPClient: &http.Client{}, - ApiBase: openaiApiURLv1, - OrgID: "", - ApiKey: "", - + ApiType: ApiTypeOpenAI, + ApiKey: "", + ApiBase: openaiApiURLv1, + ApiVersion: "", + Engine: "", + OrgID: "", + HTTPClient: &http.Client{}, EmptyMessagesLimit: defaultEmptyMessagesLimit, } for _, o := range opts { From 9f06a7b3c71828bce57c87e8e1ed9800042a5a0d Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Fri, 31 Mar 2023 18:32:25 +0800 Subject: [PATCH 05/26] test: fix tests --- api_test.go | 6 +++--- audio_test.go | 4 ++-- chat_stream_test.go | 6 +++--- chat_test.go | 6 +++--- completion_test.go | 6 +++--- edits_test.go | 2 +- error_accumulator_test.go | 2 +- files_test.go | 2 +- fine_tunes_test.go | 2 +- image_test.go | 8 ++++---- models_test.go | 2 +- moderation_test.go | 2 +- request_builder_test.go | 2 +- stream_test.go | 6 +++--- 14 files changed, 28 insertions(+), 28 deletions(-) diff --git a/api_test.go b/api_test.go index 91b939405..f1bb949a3 100644 --- a/api_test.go +++ b/api_test.go @@ -18,7 +18,7 @@ func TestAPI(t *testing.T) { } var err error - c := NewClient(apiToken) + c, _ := NewClient(apiToken) ctx := context.Background() _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines error") @@ -107,7 +107,7 @@ func TestAPIError(t *testing.T) { } var err error - c := NewClient(apiToken + "_invalid") + c, _ := NewClient(apiToken + "_invalid") ctx := context.Background() _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines did not fail") @@ -131,7 +131,7 @@ func TestAPIError(t *testing.T) { func TestRequestError(t *testing.T) { var err error - config := DefaultConfig("dummy") + config, _ := DefaultConfig("dummy") config.ApiBase = "https://httpbin.org/status/418?" c := NewClientWithConfig(config) ctx := context.Background() diff --git a/audio_test.go b/audio_test.go index 7527c960c..da233c192 100644 --- a/audio_test.go +++ b/audio_test.go @@ -30,7 +30,7 @@ func TestAudio(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) @@ -78,7 +78,7 @@ func TestAudioWithOptionalArgs(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) diff --git a/chat_stream_test.go b/chat_stream_test.go index aa98d3cb8..2fb6c6d77 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -15,7 +15,7 @@ import ( ) func TestChatCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config, _ := DefaultConfig("whatever") config.ApiBase = "http://localhost/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -61,7 +61,7 @@ func TestCreateChatCompletionStream(t *testing.T) { defer server.Close() // Client portion of the test - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), @@ -168,7 +168,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { defer server.Close() // Client portion of the test - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), diff --git a/chat_test.go b/chat_test.go index 30fd791f0..5f937bb9f 100644 --- a/chat_test.go +++ b/chat_test.go @@ -17,7 +17,7 @@ import ( ) func TestChatCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config, _ := DefaultConfig("whatever") config.ApiBase = "http://localhost/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -38,7 +38,7 @@ func TestChatCompletionsWrongModel(t *testing.T) { } func TestChatCompletionsWithStream(t *testing.T) { - config := DefaultConfig("whatever") + config, _ := DefaultConfig("whatever") config.ApiBase = "http://localhost/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -60,7 +60,7 @@ func TestChatCompletions(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/completion_test.go b/completion_test.go index 2b3d47d9f..17a40f2d4 100644 --- a/completion_test.go +++ b/completion_test.go @@ -18,7 +18,7 @@ import ( ) func TestCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config, _ := DefaultConfig("whatever") config.ApiBase = "http://localhost/v1" client := NewClientWithConfig(config) @@ -35,7 +35,7 @@ func TestCompletionsWrongModel(t *testing.T) { } func TestCompletionWithStream(t *testing.T) { - config := DefaultConfig("whatever") + config, _ := DefaultConfig("whatever") client := NewClientWithConfig(config) ctx := context.Background() @@ -56,7 +56,7 @@ func TestCompletions(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/edits_test.go b/edits_test.go index 88c059142..8d9387e52 100644 --- a/edits_test.go +++ b/edits_test.go @@ -24,7 +24,7 @@ func TestEdits(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/error_accumulator_test.go b/error_accumulator_test.go index deb142456..02725e784 100644 --- a/error_accumulator_test.go +++ b/error_accumulator_test.go @@ -75,7 +75,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) diff --git a/files_test.go b/files_test.go index f63d18b13..9d23b7dbe 100644 --- a/files_test.go +++ b/files_test.go @@ -23,7 +23,7 @@ func TestFileUpload(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/fine_tunes_test.go b/fine_tunes_test.go index 0c65c3e24..fb9c99456 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -65,7 +65,7 @@ func TestFineTunes(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/image_test.go b/image_test.go index 180f66a51..66f712591 100644 --- a/image_test.go +++ b/image_test.go @@ -24,7 +24,7 @@ func TestImages(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -94,7 +94,7 @@ func TestImageEdit(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -138,7 +138,7 @@ func TestImageEditWithoutMask(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -204,7 +204,7 @@ func TestImageVariation(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/models_test.go b/models_test.go index 98f35a5be..c270ec499 100644 --- a/models_test.go +++ b/models_test.go @@ -22,7 +22,7 @@ func TestListModels(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/moderation_test.go b/moderation_test.go index 483771ef4..75fd4d576 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -26,7 +26,7 @@ func TestModerations(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/request_builder_test.go b/request_builder_test.go index 998157385..9c1c6bfb9 100644 --- a/request_builder_test.go +++ b/request_builder_test.go @@ -44,7 +44,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = ts.URL + "/v1" client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} diff --git a/stream_test.go b/stream_test.go index 8ee869e74..0881b2572 100644 --- a/stream_test.go +++ b/stream_test.go @@ -14,7 +14,7 @@ import ( ) func TestCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config, _ := DefaultConfig("whatever") config.ApiBase = "http://localhost/v1" client := NewClientWithConfig(config) @@ -55,7 +55,7 @@ func TestCreateCompletionStream(t *testing.T) { defer server.Close() // Client portion of the test - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), @@ -140,7 +140,7 @@ func TestCreateCompletionStreamError(t *testing.T) { defer server.Close() // Client portion of the test - config := DefaultConfig(test.GetTestToken()) + config, _ := DefaultConfig(test.GetTestToken()) config.ApiBase = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), From 41a5ae53aaa03199630dc0cde97d3ed00b2be77f Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Fri, 31 Mar 2023 18:37:04 +0800 Subject: [PATCH 06/26] style: make the linter happy --- api.go | 16 ++++----- api_test.go | 2 +- audio_test.go | 4 +-- chat_stream_test.go | 6 ++-- chat_test.go | 6 ++-- completion_test.go | 4 +-- config.go | 68 +++++++++++++++++++-------------------- edits_test.go | 2 +- error_accumulator_test.go | 2 +- files_test.go | 2 +- fine_tunes_test.go | 2 +- image_test.go | 8 ++--- models_test.go | 2 +- moderation_test.go | 2 +- request_builder_test.go | 2 +- stream_test.go | 6 ++-- 16 files changed, 67 insertions(+), 67 deletions(-) diff --git a/api.go b/api.go index f468465cd..371f3c579 100644 --- a/api.go +++ b/api.go @@ -45,7 +45,7 @@ func NewOrgClient(authToken, org string) (*Client, error) { func (c *Client) sendRequest(req *http.Request, v interface{}) error { req.Header.Set("Accept", "application/json; charset=utf-8") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.ApiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.APIKey)) // Check whether Content-Type is already set, Upload Files API requires // Content-Type == multipart/form-data @@ -90,13 +90,13 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { func (c *Client) fullURL(suffix string) string { // /openai/deployments/{engine}/chat/completions?api-version={api_version} - if c.config.ApiType == ApiTypeAzure || c.config.ApiType == ApiTypeAzureAD { + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { return fmt.Sprintf("%s%s/%s/%s%s?api-version=%s", - c.config.ApiBase, azureApiPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.ApiVersion) + c.config.APIBase, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion) } - // c.config.ApiType == ApiTypeOpenAI || c.config.ApiType == "" - return fmt.Sprintf("%s%s", c.config.ApiBase, suffix) + // c.config.APIType == APITypeOpenAI || c.config.APIType == "" + return fmt.Sprintf("%s%s", c.config.APIBase, suffix) } func (c *Client) newStreamRequest( @@ -116,11 +116,11 @@ func (c *Client) newStreamRequest( // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication - if c.config.ApiType == ApiTypeAzure { - req.Header.Set("api-key", c.config.ApiKey) + if c.config.APIType == APITypeAzure { + req.Header.Set("api-key", c.config.APIKey) } else { // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.ApiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.APIKey)) } return req, nil } diff --git a/api_test.go b/api_test.go index f1bb949a3..eecb99880 100644 --- a/api_test.go +++ b/api_test.go @@ -132,7 +132,7 @@ func TestRequestError(t *testing.T) { var err error config, _ := DefaultConfig("dummy") - config.ApiBase = "https://httpbin.org/status/418?" + config.APIBase = "https://httpbin.org/status/418?" c := NewClientWithConfig(config) ctx := context.Background() _, err = c.ListEngines(ctx) diff --git a/audio_test.go b/audio_test.go index da233c192..e1c186910 100644 --- a/audio_test.go +++ b/audio_test.go @@ -31,7 +31,7 @@ func TestAudio(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) testcases := []struct { @@ -79,7 +79,7 @@ func TestAudioWithOptionalArgs(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) testcases := []struct { diff --git a/chat_stream_test.go b/chat_stream_test.go index 2fb6c6d77..514d1293b 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -16,7 +16,7 @@ import ( func TestChatCompletionsStreamWrongModel(t *testing.T) { config, _ := DefaultConfig("whatever") - config.ApiBase = "http://localhost/v1" + config.APIBase = "http://localhost/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -62,7 +62,7 @@ func TestCreateChatCompletionStream(t *testing.T) { // Client portion of the test config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = server.URL + "/v1" + config.APIBase = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), http.DefaultTransport, @@ -169,7 +169,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { // Client portion of the test config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = server.URL + "/v1" + config.APIBase = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), http.DefaultTransport, diff --git a/chat_test.go b/chat_test.go index 5f937bb9f..2b2638c1c 100644 --- a/chat_test.go +++ b/chat_test.go @@ -18,7 +18,7 @@ import ( func TestChatCompletionsWrongModel(t *testing.T) { config, _ := DefaultConfig("whatever") - config.ApiBase = "http://localhost/v1" + config.APIBase = "http://localhost/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -39,7 +39,7 @@ func TestChatCompletionsWrongModel(t *testing.T) { func TestChatCompletionsWithStream(t *testing.T) { config, _ := DefaultConfig("whatever") - config.ApiBase = "http://localhost/v1" + config.APIBase = "http://localhost/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -61,7 +61,7 @@ func TestChatCompletions(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/completion_test.go b/completion_test.go index 17a40f2d4..4cf0a6130 100644 --- a/completion_test.go +++ b/completion_test.go @@ -19,7 +19,7 @@ import ( func TestCompletionsWrongModel(t *testing.T) { config, _ := DefaultConfig("whatever") - config.ApiBase = "http://localhost/v1" + config.APIBase = "http://localhost/v1" client := NewClientWithConfig(config) _, err := client.CreateCompletion( @@ -57,7 +57,7 @@ func TestCompletions(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/config.go b/config.go index d8115f450..7791f8785 100644 --- a/config.go +++ b/config.go @@ -6,33 +6,33 @@ import ( ) const ( - openaiApiURLv1 = "https://api.openai.com/v1" + openaiAPIURLv1 = "https://api.openai.com/v1" defaultEmptyMessagesLimit uint = 300 - azureApiPrefix = "openai" + azureAPIPrefix = "openai" azureDeploymentsPrefix = "deployments" ) -type ApiType string +type APIType string const ( - ApiTypeOpenAI ApiType = "OPEN_AI" - ApiTypeAzure ApiType = "AZURE" - ApiTypeAzureAD ApiType = "AZURE_AD" + APITypeOpenAI APIType = "OPEN_AI" + APITypeAzure APIType = "AZURE" + APITypeAzureAD APIType = "AZURE_AD" ) -var supportedApiType = map[ApiType]struct{}{ - ApiTypeOpenAI: {}, - ApiTypeAzure: {}, - ApiTypeAzureAD: {}, +var supportedAPIType = map[APIType]struct{}{ + APITypeOpenAI: {}, + APITypeAzure: {}, + APITypeAzureAD: {}, } // ClientConfig is a configuration of a client. type ClientConfig struct { - ApiType ApiType - ApiKey string - ApiBase string - ApiVersion string + APIType APIType + APIKey string + APIBase string + APIVersion string Engine string OrgID string @@ -43,15 +43,15 @@ type ClientConfig struct { } func DefaultConfig(apiKey string) (ClientConfig, error) { - return NewConfig(WithApiKey(apiKey)) + return NewConfig(WithAPIKey(apiKey)) } func NewConfig(opts ...Option) (ClientConfig, error) { cfg := ClientConfig{ - ApiType: ApiTypeOpenAI, - ApiKey: "", - ApiBase: openaiApiURLv1, - ApiVersion: "", + APIType: APITypeOpenAI, + APIKey: "", + APIBase: openaiAPIURLv1, + APIVersion: "", Engine: "", OrgID: "", HTTPClient: &http.Client{}, @@ -61,16 +61,16 @@ func NewConfig(opts ...Option) (ClientConfig, error) { o(&cfg) } - if cfg.ApiKey == "" { + if cfg.APIKey == "" { return ClientConfig{}, fmt.Errorf("api key is required") } - if _, ok := supportedApiType[cfg.ApiType]; !ok { - return ClientConfig{}, fmt.Errorf("unsupported API type %s", cfg.ApiType) + if _, ok := supportedAPIType[cfg.APIType]; !ok { + return ClientConfig{}, fmt.Errorf("unsupported API type %s", cfg.APIType) } - if cfg.ApiType == ApiTypeAzure || cfg.ApiType == ApiTypeAzureAD { - if cfg.ApiVersion == "" { + if cfg.APIType == APITypeAzure || cfg.APIType == APITypeAzureAD { + if cfg.APIVersion == "" { return ClientConfig{}, fmt.Errorf("an API version is required for the Azure API type") } } @@ -80,10 +80,10 @@ func NewConfig(opts ...Option) (ClientConfig, error) { type Option func(*ClientConfig) -// WithApiType sets the API type to use. -func WithApiType(apiType ApiType) Option { +// WithAPIType sets the API type to use. +func WithAPIType(apiType APIType) Option { return func(o *ClientConfig) { - o.ApiType = apiType + o.APIType = apiType } } @@ -94,10 +94,10 @@ func WithEngine(engine string) Option { } } -// WithApiVersion sets the API version to use. -func WithApiVersion(apiVersion string) Option { +// WithAPIVersion sets the API version to use. +func WithAPIVersion(apiVersion string) Option { return func(o *ClientConfig) { - o.ApiVersion = apiVersion + o.APIVersion = apiVersion } } @@ -108,15 +108,15 @@ func WithHTTPClient(client *http.Client) Option { } } -func WithApiBase(apiBase string) Option { +func WithAPIBase(apiBase string) Option { return func(o *ClientConfig) { - o.ApiBase = apiBase + o.APIBase = apiBase } } -func WithApiKey(apiKey string) Option { +func WithAPIKey(apiKey string) Option { return func(o *ClientConfig) { - o.ApiKey = apiKey + o.APIKey = apiKey } } diff --git a/edits_test.go b/edits_test.go index 8d9387e52..03b90d1e9 100644 --- a/edits_test.go +++ b/edits_test.go @@ -25,7 +25,7 @@ func TestEdits(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/error_accumulator_test.go b/error_accumulator_test.go index 02725e784..96258377a 100644 --- a/error_accumulator_test.go +++ b/error_accumulator_test.go @@ -76,7 +76,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/files_test.go b/files_test.go index 9d23b7dbe..f6896297f 100644 --- a/files_test.go +++ b/files_test.go @@ -24,7 +24,7 @@ func TestFileUpload(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/fine_tunes_test.go b/fine_tunes_test.go index fb9c99456..dc61c418e 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -66,7 +66,7 @@ func TestFineTunes(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/image_test.go b/image_test.go index 66f712591..71674dfc1 100644 --- a/image_test.go +++ b/image_test.go @@ -25,7 +25,7 @@ func TestImages(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -95,7 +95,7 @@ func TestImageEdit(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -139,7 +139,7 @@ func TestImageEditWithoutMask(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -205,7 +205,7 @@ func TestImageVariation(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/models_test.go b/models_test.go index c270ec499..ce8135fd9 100644 --- a/models_test.go +++ b/models_test.go @@ -23,7 +23,7 @@ func TestListModels(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/moderation_test.go b/moderation_test.go index 75fd4d576..1339e1baa 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -27,7 +27,7 @@ func TestModerations(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/request_builder_test.go b/request_builder_test.go index 9c1c6bfb9..8b03c90eb 100644 --- a/request_builder_test.go +++ b/request_builder_test.go @@ -45,7 +45,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { defer ts.Close() config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = ts.URL + "/v1" + config.APIBase = ts.URL + "/v1" client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} diff --git a/stream_test.go b/stream_test.go index 0881b2572..a40d10398 100644 --- a/stream_test.go +++ b/stream_test.go @@ -15,7 +15,7 @@ import ( func TestCompletionsStreamWrongModel(t *testing.T) { config, _ := DefaultConfig("whatever") - config.ApiBase = "http://localhost/v1" + config.APIBase = "http://localhost/v1" client := NewClientWithConfig(config) _, err := client.CreateCompletionStream( @@ -56,7 +56,7 @@ func TestCreateCompletionStream(t *testing.T) { // Client portion of the test config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = server.URL + "/v1" + config.APIBase = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), http.DefaultTransport, @@ -141,7 +141,7 @@ func TestCreateCompletionStreamError(t *testing.T) { // Client portion of the test config, _ := DefaultConfig(test.GetTestToken()) - config.ApiBase = server.URL + "/v1" + config.APIBase = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), http.DefaultTransport, From 8cdf9a87db859d48ca34f21a0450450f2e96bfed Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Fri, 31 Mar 2023 19:44:16 +0800 Subject: [PATCH 07/26] fix: support Azure API Key authentication in sendRequest --- api.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/api.go b/api.go index 371f3c579..3e4ca10aa 100644 --- a/api.go +++ b/api.go @@ -45,7 +45,13 @@ func NewOrgClient(authToken, org string) (*Client, error) { func (c *Client) sendRequest(req *http.Request, v interface{}) error { req.Header.Set("Accept", "application/json; charset=utf-8") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.APIKey)) + // Azure API Key authentication + if c.config.APIType == APITypeAzure { + req.Header.Set("api-key", c.config.APIKey) + } else { + // OpenAI or Azure AD authentication + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.APIKey)) + } // Check whether Content-Type is already set, Upload Files API requires // Content-Type == multipart/form-data From f5157ea63187916a104969f14ea6fb2b46a45f39 Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Fri, 31 Mar 2023 21:19:24 +0800 Subject: [PATCH 08/26] chore: check error in CreateChatCompletionStream --- chat_stream.go | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 009e1b135..126dbbd47 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -3,6 +3,9 @@ package openai import ( "bufio" "context" + "encoding/json" + "fmt" + "net/http" ) type ChatCompletionStreamChoiceDelta struct { @@ -50,16 +53,32 @@ func (c *Client) CreateChatCompletionStream( return } - resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + res, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() if err != nil { return } + if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { + var errRes ErrorResponse + err = json.NewDecoder(res.Body).Decode(&errRes) + if err != nil || errRes.Error == nil { + reqErr := RequestError{ + StatusCode: res.StatusCode, + Err: err, + } + err = fmt.Errorf("error, %w", &reqErr) + return + } + errRes.Error.StatusCode = res.StatusCode + err = fmt.Errorf("error, status code: %d, message: %w", res.StatusCode, errRes.Error) + return + } + stream = &ChatCompletionStream{ streamReader: &streamReader[ChatCompletionStreamResponse]{ emptyMessagesLimit: c.config.EmptyMessagesLimit, - reader: bufio.NewReader(resp.Body), - response: resp, + reader: bufio.NewReader(res.Body), + response: res, errAccumulator: newErrorAccumulator(), unmarshaler: &jsonUnmarshaler{}, }, From a8a6b7e0b3e390a2ef2735d7ee2ef04ff9648e76 Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Fri, 31 Mar 2023 21:22:56 +0800 Subject: [PATCH 09/26] chore: pass tests --- chat_stream.go | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 126dbbd47..4a5d714f6 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -3,7 +3,6 @@ package openai import ( "bufio" "context" - "encoding/json" "fmt" "net/http" ) @@ -59,18 +58,7 @@ func (c *Client) CreateChatCompletionStream( } if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { - var errRes ErrorResponse - err = json.NewDecoder(res.Body).Decode(&errRes) - if err != nil || errRes.Error == nil { - reqErr := RequestError{ - StatusCode: res.StatusCode, - Err: err, - } - err = fmt.Errorf("error, %w", &reqErr) - return - } - errRes.Error.StatusCode = res.StatusCode - err = fmt.Errorf("error, status code: %d, message: %w", res.StatusCode, errRes.Error) + err = fmt.Errorf("error, status code: %d", res.StatusCode) return } From ca1315e724dd9fef878f7fc473de5ef8859e489e Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Fri, 31 Mar 2023 21:24:23 +0800 Subject: [PATCH 10/26] chore: try pass tests again --- chat_stream.go | 1 - 1 file changed, 1 deletion(-) diff --git a/chat_stream.go b/chat_stream.go index 4a5d714f6..5488e080a 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -59,7 +59,6 @@ func (c *Client) CreateChatCompletionStream( if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { err = fmt.Errorf("error, status code: %d", res.StatusCode) - return } stream = &ChatCompletionStream{ From 5da05cf5098cc058a487f908d02aaf7a7fa39d9a Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Mon, 3 Apr 2023 11:10:42 +0800 Subject: [PATCH 11/26] chore: change ClientConfig back due to this lib does not like WithXxx config style --- api.go | 30 +++++------ api_test.go | 8 +-- audio_test.go | 8 +-- chat_stream_test.go | 12 ++--- chat_test.go | 12 ++--- completion_test.go | 10 ++-- config.go | 105 +++++--------------------------------- edits_test.go | 4 +- error_accumulator_test.go | 4 +- files_test.go | 4 +- fine_tunes_test.go | 4 +- image_test.go | 16 +++--- models_test.go | 4 +- moderation_test.go | 4 +- request_builder_test.go | 4 +- stream_test.go | 12 ++--- 16 files changed, 78 insertions(+), 163 deletions(-) diff --git a/api.go b/api.go index 3e4ca10aa..ebc669309 100644 --- a/api.go +++ b/api.go @@ -15,12 +15,9 @@ type Client struct { } // NewClient creates new OpenAI API client. -func NewClient(authToken string) (*Client, error) { - config, err := DefaultConfig(authToken) - if err != nil { - return nil, err - } - return NewClientWithConfig(config), nil +func NewClient(authToken string) *Client { + config := DefaultConfig(authToken) + return NewClientWithConfig(config) } // NewClientWithConfig creates new OpenAI API client for specified config. @@ -34,23 +31,20 @@ func NewClientWithConfig(config ClientConfig) *Client { // NewOrgClient creates new OpenAI API client for specified Organization ID. // // Deprecated: Please use NewClientWithConfig. -func NewOrgClient(authToken, org string) (*Client, error) { - config, err := DefaultConfig(authToken) - if err != nil { - return nil, err - } +func NewOrgClient(authToken, org string) *Client { + config := DefaultConfig(authToken) config.OrgID = org - return NewClientWithConfig(config), nil + return NewClientWithConfig(config) } func (c *Client) sendRequest(req *http.Request, v interface{}) error { req.Header.Set("Accept", "application/json; charset=utf-8") // Azure API Key authentication if c.config.APIType == APITypeAzure { - req.Header.Set("api-key", c.config.APIKey) + req.Header.Set("api-key", c.config.authToken) } else { // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.APIKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) } // Check whether Content-Type is already set, Upload Files API requires @@ -98,11 +92,11 @@ func (c *Client) fullURL(suffix string) string { // /openai/deployments/{engine}/chat/completions?api-version={api_version} if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { return fmt.Sprintf("%s%s/%s/%s%s?api-version=%s", - c.config.APIBase, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion) + c.config.BaseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion) } // c.config.APIType == APITypeOpenAI || c.config.APIType == "" - return fmt.Sprintf("%s%s", c.config.APIBase, suffix) + return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } func (c *Client) newStreamRequest( @@ -123,10 +117,10 @@ func (c *Client) newStreamRequest( // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication if c.config.APIType == APITypeAzure { - req.Header.Set("api-key", c.config.APIKey) + req.Header.Set("api-key", c.config.authToken) } else { // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.APIKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) } return req, nil } diff --git a/api_test.go b/api_test.go index eecb99880..478a274d4 100644 --- a/api_test.go +++ b/api_test.go @@ -18,7 +18,7 @@ func TestAPI(t *testing.T) { } var err error - c, _ := NewClient(apiToken) + c := NewClient(apiToken) ctx := context.Background() _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines error") @@ -107,7 +107,7 @@ func TestAPIError(t *testing.T) { } var err error - c, _ := NewClient(apiToken + "_invalid") + c := NewClient(apiToken + "_invalid") ctx := context.Background() _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines did not fail") @@ -131,8 +131,8 @@ func TestAPIError(t *testing.T) { func TestRequestError(t *testing.T) { var err error - config, _ := DefaultConfig("dummy") - config.APIBase = "https://httpbin.org/status/418?" + config := DefaultConfig("dummy") + config.BaseURL = "https://httpbin.org/status/418?" c := NewClientWithConfig(config) ctx := context.Background() _, err = c.ListEngines(ctx) diff --git a/audio_test.go b/audio_test.go index e1c186910..087084805 100644 --- a/audio_test.go +++ b/audio_test.go @@ -30,8 +30,8 @@ func TestAudio(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) testcases := []struct { @@ -78,8 +78,8 @@ func TestAudioWithOptionalArgs(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) testcases := []struct { diff --git a/chat_stream_test.go b/chat_stream_test.go index 514d1293b..24046db6c 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -15,8 +15,8 @@ import ( ) func TestChatCompletionsStreamWrongModel(t *testing.T) { - config, _ := DefaultConfig("whatever") - config.APIBase = "http://localhost/v1" + config := DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -61,8 +61,8 @@ func TestCreateChatCompletionStream(t *testing.T) { defer server.Close() // Client portion of the test - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = server.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), http.DefaultTransport, @@ -168,8 +168,8 @@ func TestCreateChatCompletionStreamError(t *testing.T) { defer server.Close() // Client portion of the test - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = server.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), http.DefaultTransport, diff --git a/chat_test.go b/chat_test.go index 2b2638c1c..ce302a69f 100644 --- a/chat_test.go +++ b/chat_test.go @@ -17,8 +17,8 @@ import ( ) func TestChatCompletionsWrongModel(t *testing.T) { - config, _ := DefaultConfig("whatever") - config.APIBase = "http://localhost/v1" + config := DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -38,8 +38,8 @@ func TestChatCompletionsWrongModel(t *testing.T) { } func TestChatCompletionsWithStream(t *testing.T) { - config, _ := DefaultConfig("whatever") - config.APIBase = "http://localhost/v1" + config := DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -60,8 +60,8 @@ func TestChatCompletions(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/completion_test.go b/completion_test.go index 4cf0a6130..2e302591a 100644 --- a/completion_test.go +++ b/completion_test.go @@ -18,8 +18,8 @@ import ( ) func TestCompletionsWrongModel(t *testing.T) { - config, _ := DefaultConfig("whatever") - config.APIBase = "http://localhost/v1" + config := DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" client := NewClientWithConfig(config) _, err := client.CreateCompletion( @@ -35,7 +35,7 @@ func TestCompletionsWrongModel(t *testing.T) { } func TestCompletionWithStream(t *testing.T) { - config, _ := DefaultConfig("whatever") + config := DefaultConfig("whatever") client := NewClientWithConfig(config) ctx := context.Background() @@ -56,8 +56,8 @@ func TestCompletions(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/config.go b/config.go index 7791f8785..07606bc15 100644 --- a/config.go +++ b/config.go @@ -1,7 +1,6 @@ package openai import ( - "fmt" "net/http" ) @@ -21,108 +20,30 @@ const ( APITypeAzureAD APIType = "AZURE_AD" ) -var supportedAPIType = map[APIType]struct{}{ - APITypeOpenAI: {}, - APITypeAzure: {}, - APITypeAzureAD: {}, -} - // ClientConfig is a configuration of a client. type ClientConfig struct { + authToken string + BaseURL string + OrgID string + APIType APIType - APIKey string - APIBase string APIVersion string - - Engine string - OrgID string + Engine string HTTPClient *http.Client EmptyMessagesLimit uint } -func DefaultConfig(apiKey string) (ClientConfig, error) { - return NewConfig(WithAPIKey(apiKey)) -} - -func NewConfig(opts ...Option) (ClientConfig, error) { - cfg := ClientConfig{ - APIType: APITypeOpenAI, - APIKey: "", - APIBase: openaiAPIURLv1, - APIVersion: "", - Engine: "", - OrgID: "", - HTTPClient: &http.Client{}, - EmptyMessagesLimit: defaultEmptyMessagesLimit, - } - for _, o := range opts { - o(&cfg) - } - - if cfg.APIKey == "" { - return ClientConfig{}, fmt.Errorf("api key is required") - } - - if _, ok := supportedAPIType[cfg.APIType]; !ok { - return ClientConfig{}, fmt.Errorf("unsupported API type %s", cfg.APIType) - } - - if cfg.APIType == APITypeAzure || cfg.APIType == APITypeAzureAD { - if cfg.APIVersion == "" { - return ClientConfig{}, fmt.Errorf("an API version is required for the Azure API type") - } - } - - return cfg, nil -} - -type Option func(*ClientConfig) +func DefaultConfig(authToken string) ClientConfig { + return ClientConfig{ + authToken: authToken, + BaseURL: openaiAPIURLv1, + APIType: APITypeOpenAI, + OrgID: "", -// WithAPIType sets the API type to use. -func WithAPIType(apiType APIType) Option { - return func(o *ClientConfig) { - o.APIType = apiType - } -} - -// WithEngine sets the engine to use. -func WithEngine(engine string) Option { - return func(o *ClientConfig) { - o.Engine = engine - } -} - -// WithAPIVersion sets the API version to use. -func WithAPIVersion(apiVersion string) Option { - return func(o *ClientConfig) { - o.APIVersion = apiVersion - } -} - -// WithHTTPClient sets the HTTP client to use. -func WithHTTPClient(client *http.Client) Option { - return func(o *ClientConfig) { - o.HTTPClient = client - } -} - -func WithAPIBase(apiBase string) Option { - return func(o *ClientConfig) { - o.APIBase = apiBase - } -} + HTTPClient: &http.Client{}, -func WithAPIKey(apiKey string) Option { - return func(o *ClientConfig) { - o.APIKey = apiKey - } -} - -// WithOrgID sets the organization ID to use. -func WithOrgID(orgID string) Option { - return func(o *ClientConfig) { - o.OrgID = orgID + EmptyMessagesLimit: defaultEmptyMessagesLimit, } } diff --git a/edits_test.go b/edits_test.go index 03b90d1e9..fa6c12825 100644 --- a/edits_test.go +++ b/edits_test.go @@ -24,8 +24,8 @@ func TestEdits(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/error_accumulator_test.go b/error_accumulator_test.go index 96258377a..637bf3678 100644 --- a/error_accumulator_test.go +++ b/error_accumulator_test.go @@ -75,8 +75,8 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/files_test.go b/files_test.go index f6896297f..3e8dfc442 100644 --- a/files_test.go +++ b/files_test.go @@ -23,8 +23,8 @@ func TestFileUpload(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/fine_tunes_test.go b/fine_tunes_test.go index dc61c418e..c60254993 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -65,8 +65,8 @@ func TestFineTunes(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/image_test.go b/image_test.go index 71674dfc1..9917b7881 100644 --- a/image_test.go +++ b/image_test.go @@ -24,8 +24,8 @@ func TestImages(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -94,8 +94,8 @@ func TestImageEdit(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -138,8 +138,8 @@ func TestImageEditWithoutMask(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() @@ -204,8 +204,8 @@ func TestImageVariation(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/models_test.go b/models_test.go index ce8135fd9..dad59be79 100644 --- a/models_test.go +++ b/models_test.go @@ -22,8 +22,8 @@ func TestListModels(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/moderation_test.go b/moderation_test.go index 1339e1baa..3535bc807 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -26,8 +26,8 @@ func TestModerations(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() diff --git a/request_builder_test.go b/request_builder_test.go index 8b03c90eb..e5b65df0c 100644 --- a/request_builder_test.go +++ b/request_builder_test.go @@ -44,8 +44,8 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { ts.Start() defer ts.Close() - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = ts.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} diff --git a/stream_test.go b/stream_test.go index a40d10398..a80504d24 100644 --- a/stream_test.go +++ b/stream_test.go @@ -14,8 +14,8 @@ import ( ) func TestCompletionsStreamWrongModel(t *testing.T) { - config, _ := DefaultConfig("whatever") - config.APIBase = "http://localhost/v1" + config := DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" client := NewClientWithConfig(config) _, err := client.CreateCompletionStream( @@ -55,8 +55,8 @@ func TestCreateCompletionStream(t *testing.T) { defer server.Close() // Client portion of the test - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = server.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), http.DefaultTransport, @@ -140,8 +140,8 @@ func TestCreateCompletionStreamError(t *testing.T) { defer server.Close() // Client portion of the test - config, _ := DefaultConfig(test.GetTestToken()) - config.APIBase = server.URL + "/v1" + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" config.HTTPClient.Transport = &tokenRoundTripper{ test.GetTestToken(), http.DefaultTransport, From 94438d4e70dd376fd2c152599a969377019a5e3a Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Mon, 3 Apr 2023 11:20:53 +0800 Subject: [PATCH 12/26] chore: revert fix to CreateChatCompletionStream() due to cause tests not pass --- chat_stream.go | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 5488e080a..009e1b135 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -3,8 +3,6 @@ package openai import ( "bufio" "context" - "fmt" - "net/http" ) type ChatCompletionStreamChoiceDelta struct { @@ -52,20 +50,16 @@ func (c *Client) CreateChatCompletionStream( return } - res, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() if err != nil { return } - if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { - err = fmt.Errorf("error, status code: %d", res.StatusCode) - } - stream = &ChatCompletionStream{ streamReader: &streamReader[ChatCompletionStreamResponse]{ emptyMessagesLimit: c.config.EmptyMessagesLimit, - reader: bufio.NewReader(res.Body), - response: res, + reader: bufio.NewReader(resp.Body), + response: resp, errAccumulator: newErrorAccumulator(), unmarshaler: &jsonUnmarshaler{}, }, From 37eecb89c3a54ded1e38244d0efd88cbc82c529a Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Mon, 3 Apr 2023 14:29:42 +0800 Subject: [PATCH 13/26] chore: at least add some comment about the required fields --- config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 07606bc15..d308e56f4 100644 --- a/config.go +++ b/config.go @@ -27,8 +27,8 @@ type ClientConfig struct { OrgID string APIType APIType - APIVersion string - Engine string + APIVersion string // required for APITypeAzure or APITypeAzureAD + Engine string // required for APITypeAzure or APITypeAzureAD HTTPClient *http.Client From 60cafde1810f7fea708be7a2391e2af356db49cc Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Mon, 3 Apr 2023 20:05:37 +0800 Subject: [PATCH 14/26] chore: re order ClientConfig fields --- config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index d308e56f4..ae9c1544e 100644 --- a/config.go +++ b/config.go @@ -23,9 +23,9 @@ const ( // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string - BaseURL string - OrgID string + BaseURL string + OrgID string APIType APIType APIVersion string // required for APITypeAzure or APITypeAzureAD Engine string // required for APITypeAzure or APITypeAzureAD From 006ac907e282fdc778793e61025326ae932f3085 Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Mon, 3 Apr 2023 20:08:02 +0800 Subject: [PATCH 15/26] chore: add DefaultAzure() --- config.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/config.go b/config.go index ae9c1544e..5057b85f6 100644 --- a/config.go +++ b/config.go @@ -47,3 +47,18 @@ func DefaultConfig(authToken string) ClientConfig { EmptyMessagesLimit: defaultEmptyMessagesLimit, } } + +func DefaultAzure(apiKey, baseURl, engine, apiVersion string) ClientConfig { + return ClientConfig{ + authToken: apiKey, + BaseURL: baseURl, + OrgID: "", + APIType: APITypeAzure, + APIVersion: apiVersion, + Engine: engine, + + HTTPClient: &http.Client{}, + + EmptyMessagesLimit: defaultEmptyMessagesLimit, + } +} From 84ea88179c9d9be8a817141bbcde1a108ce07608 Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Mon, 3 Apr 2023 20:14:24 +0800 Subject: [PATCH 16/26] chore: set default api_version the same as py one "2023-03-15-preview" --- config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 5057b85f6..4213c0ace 100644 --- a/config.go +++ b/config.go @@ -48,13 +48,13 @@ func DefaultConfig(authToken string) ClientConfig { } } -func DefaultAzure(apiKey, baseURl, engine, apiVersion string) ClientConfig { +func DefaultAzure(apiKey, baseURl, engine string) ClientConfig { return ClientConfig{ authToken: apiKey, BaseURL: baseURl, OrgID: "", APIType: APITypeAzure, - APIVersion: apiVersion, + APIVersion: "2023-03-15-preview", Engine: engine, HTTPClient: &http.Client{}, From becc1bb4221d6edd15ae4ca0b13c9ed985b66123 Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Mon, 3 Apr 2023 20:18:18 +0800 Subject: [PATCH 17/26] style: fixup typo --- config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 4213c0ace..ab64e8a4d 100644 --- a/config.go +++ b/config.go @@ -48,10 +48,10 @@ func DefaultConfig(authToken string) ClientConfig { } } -func DefaultAzure(apiKey, baseURl, engine string) ClientConfig { +func DefaultAzure(apiKey, baseURL, engine string) ClientConfig { return ClientConfig{ authToken: apiKey, - BaseURL: baseURl, + BaseURL: baseURL, OrgID: "", APIType: APITypeAzure, APIVersion: "2023-03-15-preview", From f5c822da563f00ffec40d9d80f210efc5d5ef859 Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Tue, 4 Apr 2023 01:30:13 +0800 Subject: [PATCH 18/26] test: add api_internal_test.go --- api_internal_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 api_internal_test.go diff --git a/api_internal_test.go b/api_internal_test.go new file mode 100644 index 000000000..b070e6db5 --- /dev/null +++ b/api_internal_test.go @@ -0,0 +1,18 @@ +package openai + +import ( + "fmt" + "testing" +) + +func TestAzureFullURL(t *testing.T) { + az := DefaultAzure("dummy", "https://httpbin.org/", "chatgpt-demo") + cli := NewClientWithConfig(az) + // /openai/deployments/{engine}/chat/completions?api-version={api_version} + expect := fmt.Sprintf("https://httpbin.org/openai/deployments/chatgpt-demo/chat/completions?api-version=2023-03-15-preview") + actual := cli.fullURL("/chat/completions") + if actual != expect { + t.Errorf("Expected %s, got %s", expect, actual) + } + t.Logf("Full URL: %s", actual) +} From ee4bd5dbc608ef7514ef9742122647a958557b95 Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Tue, 4 Apr 2023 01:31:27 +0800 Subject: [PATCH 19/26] style: make lint happy --- api_internal_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index b070e6db5..38f3c675a 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -1,7 +1,6 @@ package openai import ( - "fmt" "testing" ) @@ -9,7 +8,8 @@ func TestAzureFullURL(t *testing.T) { az := DefaultAzure("dummy", "https://httpbin.org/", "chatgpt-demo") cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} - expect := fmt.Sprintf("https://httpbin.org/openai/deployments/chatgpt-demo/chat/completions?api-version=2023-03-15-preview") + expect := "https://httpbin.org/" + + "openai/deployments/chatgpt-demo/chat/completions?api-version=2023-03-15-preview" actual := cli.fullURL("/chat/completions") if actual != expect { t.Errorf("Expected %s, got %s", expect, actual) From 416dc264b2e6ce6a17126dc1e6806bc881a6cb5a Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Tue, 4 Apr 2023 01:36:06 +0800 Subject: [PATCH 20/26] chore: add constant AzureAPIKeyHeader --- api.go | 2 +- config.go | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/api.go b/api.go index ebc669309..0091b70f5 100644 --- a/api.go +++ b/api.go @@ -117,7 +117,7 @@ func (c *Client) newStreamRequest( // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication if c.config.APIType == APITypeAzure { - req.Header.Set("api-key", c.config.authToken) + req.Header.Set(AzureAPIKeyHeader, c.config.authToken) } else { // OpenAI or Azure AD authentication req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) diff --git a/config.go b/config.go index ab64e8a4d..9033204a5 100644 --- a/config.go +++ b/config.go @@ -20,6 +20,8 @@ const ( APITypeAzureAD APIType = "AZURE_AD" ) +const AzureAPIKeyHeader = "api-key" + // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string From 27be57276eb1c80e20e3e8b79f5ce191fb6b5845 Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Tue, 4 Apr 2023 02:04:22 +0800 Subject: [PATCH 21/26] chore: use AzureAPIKeyHeader for api-key header, fix azure base url auto trim suffix / --- api.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/api.go b/api.go index 0091b70f5..2c978bc25 100644 --- a/api.go +++ b/api.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" ) // Client is OpenAI GPT-3 API client. @@ -41,7 +42,7 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { req.Header.Set("Accept", "application/json; charset=utf-8") // Azure API Key authentication if c.config.APIType == APITypeAzure { - req.Header.Set("api-key", c.config.authToken) + req.Header.Set(AzureAPIKeyHeader, c.config.authToken) } else { // OpenAI or Azure AD authentication req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) @@ -91,8 +92,10 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { func (c *Client) fullURL(suffix string) string { // /openai/deployments/{engine}/chat/completions?api-version={api_version} if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { - return fmt.Sprintf("%s%s/%s/%s%s?api-version=%s", - c.config.BaseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion) + baseURL := c.config.BaseURL + baseURL = strings.TrimRight(baseURL, "/") + return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", + baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion) } // c.config.APIType == APITypeOpenAI || c.config.APIType == "" From ebef9a9219e75bb4a29ad8c8433e55083c9d248d Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Tue, 4 Apr 2023 02:04:52 +0800 Subject: [PATCH 22/26] test: add TestAzureFullURL, TestRequestAuthHeader and TestOpenAIFullURL --- api_internal_test.go | 128 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 119 insertions(+), 9 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index 38f3c675a..ea007ffdd 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -1,18 +1,128 @@ package openai import ( + "context" "testing" ) +func TestOpenAIFullURL(t *testing.T) { + cases := []struct { + Name string + BaseURL string + Engine string + Expect string + }{ + { + "DefaultConfig", + "", + "", + "https://api.openai.com/v1/chat/completions", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultConfig("dummy") + cli := NewClientWithConfig(az) + // /openai/deployments/{engine}/chat/completions?api-version={api_version} + actual := cli.fullURL("/chat/completions") + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} + +func TestRequestAuthHeader(t *testing.T) { + cases := []struct { + Name string + APIType APIType + HeaderKey string + Token string + Expect string + }{ + { + "OpenAI", + APITypeOpenAI, + "Authorization", + "dummy-token-openai", + "Bearer dummy-token-openai", + }, + { + "AzureAD", + APITypeAzureAD, + "Authorization", + "dummy-token-azure", + "Bearer dummy-token-azure", + }, + { + "Azure", + APITypeAzure, + AzureAPIKeyHeader, + "dummy-api-key-here", + "dummy-api-key-here", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultConfig(c.Token) + if c.APIType == APITypeAzureAD { + az.APIType = APITypeAzureAD + } else if c.APIType == APITypeAzure { + az.APIType = APITypeAzure + } + + cli := NewClientWithConfig(az) + req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil) + if err != nil { + t.Errorf("Failed to create request: %v", err) + } + actual := req.Header.Get(c.HeaderKey) + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("%s: %s", c.HeaderKey, actual) + }) + } +} + func TestAzureFullURL(t *testing.T) { - az := DefaultAzure("dummy", "https://httpbin.org/", "chatgpt-demo") - cli := NewClientWithConfig(az) - // /openai/deployments/{engine}/chat/completions?api-version={api_version} - expect := "https://httpbin.org/" + - "openai/deployments/chatgpt-demo/chat/completions?api-version=2023-03-15-preview" - actual := cli.fullURL("/chat/completions") - if actual != expect { - t.Errorf("Expected %s, got %s", expect, actual) + cases := []struct { + Name string + BaseURL string + Engine string + Expect string + }{ + { + "AzureBaseURLWithSlashAutoStrip", + "https://httpbin.org/", + "chatgpt-demo", + "https://httpbin.org/" + + "openai/deployments/chatgpt-demo" + + "/chat/completions?api-version=2023-03-15-preview", + }, + { + "AzureBaseURLWithoutSlashOK", + "https://httpbin.org", + "chatgpt-demo", + "https://httpbin.org/" + + "openai/deployments/chatgpt-demo" + + "/chat/completions?api-version=2023-03-15-preview", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultAzure("dummy", c.BaseURL, c.Engine) + cli := NewClientWithConfig(az) + // /openai/deployments/{engine}/chat/completions?api-version={api_version} + actual := cli.fullURL("/chat/completions") + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) } - t.Logf("Full URL: %s", actual) } From 862ce64b8c91137728701f688e740ce0ac78863c Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Tue, 4 Apr 2023 02:08:47 +0800 Subject: [PATCH 23/26] test: simplify TestRequestAuthHeader --- api_internal_test.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index ea007ffdd..888a26617 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -42,6 +42,13 @@ func TestRequestAuthHeader(t *testing.T) { Token string Expect string }{ + { + "OpenAIDefault", + "", + "Authorization", + "dummy-token-openai", + "Bearer dummy-token-openai", + }, { "OpenAI", APITypeOpenAI, @@ -68,11 +75,7 @@ func TestRequestAuthHeader(t *testing.T) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { az := DefaultConfig(c.Token) - if c.APIType == APITypeAzureAD { - az.APIType = APITypeAzureAD - } else if c.APIType == APITypeAzure { - az.APIType = APITypeAzure - } + az.APIType = c.APIType cli := NewClientWithConfig(az) req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil) From 115595b50897b54a2c266528d8070cd1de80f1ef Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Tue, 4 Apr 2023 02:18:06 +0800 Subject: [PATCH 24/26] test: refine TestOpenAIFullURL --- api_internal_test.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index 888a26617..f9395f2f9 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -7,25 +7,27 @@ import ( func TestOpenAIFullURL(t *testing.T) { cases := []struct { - Name string - BaseURL string - Engine string - Expect string + Name string + Suffix string + Expect string }{ { - "DefaultConfig", - "", - "", + "ChatCompletionsURL", + "/chat/completions", "https://api.openai.com/v1/chat/completions", }, + { + "CompletionsURL", + "/completions", + "https://api.openai.com/v1/completions", + }, } for _, c := range cases { t.Run(c.Name, func(t *testing.T) { az := DefaultConfig("dummy") cli := NewClientWithConfig(az) - // /openai/deployments/{engine}/chat/completions?api-version={api_version} - actual := cli.fullURL("/chat/completions") + actual := cli.fullURL(c.Suffix) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } From 59461eca5799ae05f9fa764786479ba9461f6abd Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Tue, 4 Apr 2023 02:22:37 +0800 Subject: [PATCH 25/26] chore: refine comments --- config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 9033204a5..8836f1c5d 100644 --- a/config.go +++ b/config.go @@ -29,8 +29,8 @@ type ClientConfig struct { BaseURL string OrgID string APIType APIType - APIVersion string // required for APITypeAzure or APITypeAzureAD - Engine string // required for APITypeAzure or APITypeAzureAD + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + Engine string // required when APIType is APITypeAzure or APITypeAzureAD HTTPClient *http.Client From 222aa40c4fbbae0e42763b2d001e416deb028145 Mon Sep 17 00:00:00 2001 From: ttyS3 Date: Tue, 4 Apr 2023 02:29:08 +0800 Subject: [PATCH 26/26] feat: DefaultAzureConfig --- api_internal_test.go | 2 +- config.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index f9395f2f9..83dcafcf2 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -120,7 +120,7 @@ func TestAzureFullURL(t *testing.T) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - az := DefaultAzure("dummy", c.BaseURL, c.Engine) + az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine) cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} actual := cli.fullURL("/chat/completions") diff --git a/config.go b/config.go index 8836f1c5d..52e1efc3f 100644 --- a/config.go +++ b/config.go @@ -50,7 +50,7 @@ func DefaultConfig(authToken string) ClientConfig { } } -func DefaultAzure(apiKey, baseURL, engine string) ClientConfig { +func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig { return ClientConfig{ authToken: apiKey, BaseURL: baseURL,