diff --git a/api_internal_test.go b/api_internal_test.go index 9651ad402..bbf5f8b43 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -25,8 +25,7 @@ func TestOpenAIFullURL(t *testing.T) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - az := DefaultConfig("dummy") - cli := NewClientWithConfig(az) + cli := NewClient("dummy") actual := cli.fullURL(c.Suffix) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) @@ -89,11 +88,7 @@ func TestRequestAuthHeader(t *testing.T) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - az := DefaultConfig(c.Token) - az.APIType = c.APIType - az.OrgID = c.OrgID - - cli := NewClientWithConfig(az) + cli := NewAzureClient(c.Token, "", "", WithOrganizationID(c.OrgID), WithSpecificAPIType(c.APIType)) req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil) if err != nil { t.Errorf("Failed to create request: %v", err) @@ -134,8 +129,7 @@ func TestAzureFullURL(t *testing.T) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine) - cli := NewClientWithConfig(az) + cli := NewAzureClient("dummy", c.BaseURL, c.Engine) // /openai/deployments/{engine}/chat/completions?api-version={api_version} actual := cli.fullURL("/chat/completions") if actual != c.Expect { diff --git a/api_test.go b/api_test.go index ecba25625..56f3d43e6 100644 --- a/api_test.go +++ b/api_test.go @@ -226,9 +226,7 @@ func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) { func TestRequestError(t *testing.T) { var err error - config := DefaultConfig("dummy") - config.BaseURL = "https://httpbin.org/status/418?" - c := NewClientWithConfig(config) + c := NewClient("dummy", WithCustomBaseURL("https://httpbin.org/status/418?")) ctx := context.Background() _, err = c.ListEngines(ctx) checks.HasError(t, err, "ListEngines did not fail") diff --git a/audio_test.go b/audio_test.go index daf51f28c..420f0b9f2 100644 --- a/audio_test.go +++ b/audio_test.go @@ -30,9 +30,7 @@ func TestAudio(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) testcases := []struct { name string @@ -78,9 +76,7 @@ func TestAudioWithOptionalArgs(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) testcases := []struct { name string diff --git a/chat_stream_test.go b/chat_stream_test.go index afcb86d5e..1a7e267fa 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -15,9 +15,7 @@ import ( ) func TestChatCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := NewClient("whatever", WithCustomBaseURL("http://localhost/v1")) ctx := context.Background() req := ChatCompletionRequest{ @@ -61,14 +59,14 @@ func TestCreateChatCompletionStream(t *testing.T) { defer server.Close() // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + httpClient := &http.Client{ + Transport: &tokenRoundTripper{ + test.GetTestToken(), + http.DefaultTransport, + }, } - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(server.URL+"/v1"), WithCustomClient(httpClient)) ctx := context.Background() request := ChatCompletionRequest{ @@ -168,14 +166,14 @@ func TestCreateChatCompletionStreamError(t *testing.T) { defer server.Close() // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + httpClient := &http.Client{ + Transport: &tokenRoundTripper{ + test.GetTestToken(), + http.DefaultTransport, + }, } - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(server.URL+"/v1"), WithCustomClient(httpClient)) ctx := context.Background() request := ChatCompletionRequest{ @@ -225,14 +223,14 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { defer ts.Close() // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + httpClient := &http.Client{ + Transport: &tokenRoundTripper{ + test.GetTestToken(), + http.DefaultTransport, + }, } - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1"), WithCustomClient(httpClient)) ctx := context.Background() request := ChatCompletionRequest{ diff --git a/chat_test.go b/chat_test.go index ce302a69f..639cfcb76 100644 --- a/chat_test.go +++ b/chat_test.go @@ -17,9 +17,7 @@ import ( ) func TestChatCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := NewClient("whatever", WithCustomBaseURL("http://localhost/v1")) ctx := context.Background() req := ChatCompletionRequest{ @@ -38,9 +36,7 @@ func TestChatCompletionsWrongModel(t *testing.T) { } func TestChatCompletionsWithStream(t *testing.T) { - config := DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := NewClient("whatever", WithCustomBaseURL("http://localhost/v1")) ctx := context.Background() req := ChatCompletionRequest{ @@ -60,9 +56,7 @@ func TestChatCompletions(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() req := ChatCompletionRequest{ diff --git a/client.go b/client.go index 368947b23..023b60443 100644 --- a/client.go +++ b/client.go @@ -18,13 +18,28 @@ type Client struct { } // NewClient creates new OpenAI API client. -func NewClient(authToken string) *Client { +func NewClient(authToken string, options ...Option) *Client { config := DefaultConfig(authToken) - return NewClientWithConfig(config) + + for _, opt := range options { + opt(&config) + } + + return newClient(config) } -// NewClientWithConfig creates new OpenAI API client for specified config. -func NewClientWithConfig(config ClientConfig) *Client { +// NewAzureClient create new openAI API from Azure client +func NewAzureClient(authToken string, baseUrl string, engine string, options ...Option) *Client { + config := DefaultAzureConfig(authToken, baseUrl, engine) + + for _, opt := range options { + opt(&config) + } + + return newClient(config) +} + +func newClient(config ClientConfig) *Client { return &Client{ config: config, requestBuilder: newRequestBuilder(), @@ -34,15 +49,6 @@ func NewClientWithConfig(config ClientConfig) *Client { } } -// NewOrgClient creates new OpenAI API client for specified Organization ID. -// -// Deprecated: Please use NewClientWithConfig. -func NewOrgClient(authToken, org string) *Client { - config := DefaultConfig(authToken) - config.OrgID = org - return NewClientWithConfig(config) -} - func (c *Client) sendRequest(req *http.Request, v any) error { req.Header.Set("Accept", "application/json; charset=utf-8") // Azure API Key authentication diff --git a/client_test.go b/client_test.go index 7bea6dd87..f56991390 100644 --- a/client_test.go +++ b/client_test.go @@ -8,13 +8,13 @@ import ( func TestClient(t *testing.T) { const mockToken = "mock token" - client := NewClient(mockToken) + const mockOrg = "mock org" + + client := NewClient(mockToken, WithOrganizationID(mockOrg)) if client.config.authToken != mockToken { t.Errorf("Client does not contain proper token") } - const mockOrg = "mock org" - client = NewOrgClient(mockToken, mockOrg) if client.config.authToken != mockToken { t.Errorf("Client does not contain proper token") } diff --git a/completion_test.go b/completion_test.go index 2e302591a..4c51e4a19 100644 --- a/completion_test.go +++ b/completion_test.go @@ -18,9 +18,7 @@ import ( ) func TestCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := NewClient("whatever", WithCustomBaseURL("http://localhost/v1")) _, err := client.CreateCompletion( context.Background(), @@ -35,8 +33,7 @@ func TestCompletionsWrongModel(t *testing.T) { } func TestCompletionWithStream(t *testing.T) { - config := DefaultConfig("whatever") - client := NewClientWithConfig(config) + client := NewClient("whatever") ctx := context.Background() req := CompletionRequest{Stream: true} @@ -56,9 +53,7 @@ func TestCompletions(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() req := CompletionRequest{ diff --git a/edits_test.go b/edits_test.go index fa6c12825..174b2d150 100644 --- a/edits_test.go +++ b/edits_test.go @@ -24,9 +24,7 @@ func TestEdits(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() // create an edit request diff --git a/embeddings_test.go b/embeddings_test.go index 252f7a5a0..d029edf30 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -81,9 +81,7 @@ func TestEmbeddingEndpoint(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) diff --git a/error_accumulator_test.go b/error_accumulator_test.go index ecf954d58..b8cb4db90 100644 --- a/error_accumulator_test.go +++ b/error_accumulator_test.go @@ -80,9 +80,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() diff --git a/example_test.go b/example_test.go new file mode 100644 index 000000000..8ff23e1e8 --- /dev/null +++ b/example_test.go @@ -0,0 +1,50 @@ +package openai + +import ( + "context" + "fmt" +) + +func ExampleNewClient() { + cli := NewClient("your-api-key") + + resp, err := cli.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + + if err != nil { + fmt.Println(err) + } + + fmt.Println(resp) +} + +func ExampleNewAzureClient() { + cli := NewAzureClient("your-api-key", "https://your Azure OpenAI Endpoint ", "your Model deployment name") + + resp, err := cli.CreateChatCompletion( + context.Background(), + ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello Azure OpenAI!", + }, + }, + }, + ) + + if err != nil { + fmt.Printf("ChatCompletion error: %v\n", err) + return + } + + fmt.Println(resp.Choices[0].Message.Content) +} diff --git a/files_test.go b/files_test.go index bb06498c8..88fe921f0 100644 --- a/files_test.go +++ b/files_test.go @@ -24,9 +24,7 @@ func TestFileUpload(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() req := FileRequest{ @@ -81,9 +79,7 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { } func TestFileUploadWithFailingFormBuilder(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) + client := NewClient("", WithCustomBaseURL("")) mockBuilder := &mockFormBuilder{} client.createFormBuilder = func(io.Writer) formBuilder { return mockBuilder @@ -128,9 +124,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) { } func TestFileUploadWithNonExistentPath(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) + client := NewClient("", WithCustomBaseURL("")) ctx := context.Background() req := FileRequest{ diff --git a/fine_tunes_test.go b/fine_tunes_test.go index c60254993..2b37a73f9 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -65,9 +65,7 @@ func TestFineTunes(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() _, err = client.ListFineTunes(ctx) diff --git a/image_test.go b/image_test.go index 4a7dad58f..e85dd452e 100644 --- a/image_test.go +++ b/image_test.go @@ -23,9 +23,7 @@ func TestImages(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() req := ImageRequest{} @@ -93,9 +91,7 @@ func TestImageEdit(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() origin, err := os.Create("image.png") @@ -138,9 +134,7 @@ func TestImageEditWithoutMask(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() origin, err := os.Create("image.png") @@ -205,9 +199,7 @@ func TestImageVariation(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() origin, err := os.Create("image.png") @@ -285,9 +277,7 @@ func (fb *mockFormBuilder) formDataContentType() string { } func TestImageFormBuilderFailures(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) + client := NewClient("", WithCustomBaseURL("")) mockBuilder := &mockFormBuilder{} client.createFormBuilder = func(io.Writer) formBuilder { @@ -352,9 +342,7 @@ func TestImageFormBuilderFailures(t *testing.T) { } func TestVariImageFormBuilderFailures(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) + client := NewClient("", WithCustomBaseURL("")) mockBuilder := &mockFormBuilder{} client.createFormBuilder = func(io.Writer) formBuilder { diff --git a/models_test.go b/models_test.go index dad59be79..951eef60e 100644 --- a/models_test.go +++ b/models_test.go @@ -22,9 +22,7 @@ func TestListModels(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() _, err = client.ListModels(ctx) diff --git a/moderation_test.go b/moderation_test.go index 2c1145627..bd7656bb8 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -26,9 +26,7 @@ func TestModerations(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) ctx := context.Background() // create an edit request diff --git a/options.go b/options.go new file mode 100644 index 000000000..d74a3735b --- /dev/null +++ b/options.go @@ -0,0 +1,47 @@ +package openai + +import "net/http" + +type Option func(c *ClientConfig) + +func WithCustomBaseURL(url string) Option { + return func(c *ClientConfig) { + c.BaseURL = url + } +} + +func WithOrganizationID(orgID string) Option { + return func(c *ClientConfig) { + c.OrgID = orgID + } +} + +func WithSpecificAPIType(apiType APIType) Option { + return func(c *ClientConfig) { + c.APIType = apiType + } +} + +func WithCustomAPIVersion(version string) Option { + return func(c *ClientConfig) { + c.APIVersion = version + } +} + +func WithCustomEngine(engine string) Option { + return func(c *ClientConfig) { + c.Engine = engine + } +} + +func WithCustomClient(client *http.Client) Option { + return func(c *ClientConfig) { + c.HTTPClient = client + } +} + +func WithEmptyMessagesLimit(limit uint) Option { + return func(c *ClientConfig) { + c.EmptyMessagesLimit = limit + } +} diff --git a/request_builder_test.go b/request_builder_test.go index b1adbf1c6..b6ca6c727 100644 --- a/request_builder_test.go +++ b/request_builder_test.go @@ -44,9 +44,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) client.requestBuilder = &failingRequestBuilder{} ctx := context.Background() @@ -158,9 +156,7 @@ func TestReturnsRequestBuilderErrorsAddtion(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1")) client.requestBuilder = &failingRequestBuilder{} ctx := context.Background() diff --git a/stream_test.go b/stream_test.go index a5c591fde..167a0da33 100644 --- a/stream_test.go +++ b/stream_test.go @@ -14,9 +14,7 @@ import ( ) func TestCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := NewClient("whatever", WithCustomBaseURL("http://localhost/v1")) _, err := client.CreateCompletionStream( context.Background(), @@ -55,14 +53,14 @@ func TestCreateCompletionStream(t *testing.T) { defer server.Close() // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + httpClient := &http.Client{ + Transport: &tokenRoundTripper{ + test.GetTestToken(), + http.DefaultTransport, + }, } - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(server.URL+"/v1"), WithCustomClient(httpClient)) ctx := context.Background() request := CompletionRequest{ @@ -140,14 +138,14 @@ func TestCreateCompletionStreamError(t *testing.T) { defer server.Close() // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + httpClient := &http.Client{ + Transport: &tokenRoundTripper{ + test.GetTestToken(), + http.DefaultTransport, + }, } - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(server.URL+"/v1"), WithCustomClient(httpClient)) ctx := context.Background() request := CompletionRequest{ @@ -192,14 +190,14 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { defer ts.Close() // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + httpClient := &http.Client{ + Transport: &tokenRoundTripper{ + test.GetTestToken(), + http.DefaultTransport, + }, } - client := NewClientWithConfig(config) + client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1"), WithCustomClient(httpClient)) ctx := context.Background() request := CompletionRequest{