Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions api_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ func TestOpenAIFullURL(t *testing.T) {

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultConfig("dummy")
cli := NewClientWithConfig(az)
cli := NewClient("dummy")
actual := cli.fullURL(c.Suffix)
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
Expand Down Expand Up @@ -89,11 +88,7 @@ func TestRequestAuthHeader(t *testing.T) {

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultConfig(c.Token)
az.APIType = c.APIType
az.OrgID = c.OrgID

cli := NewClientWithConfig(az)
cli := NewAzureClient(c.Token, "", "", WithOrganizationID(c.OrgID), WithSpecificAPIType(c.APIType))
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil)
if err != nil {
t.Errorf("Failed to create request: %v", err)
Expand Down Expand Up @@ -134,8 +129,7 @@ func TestAzureFullURL(t *testing.T) {

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine)
cli := NewClientWithConfig(az)
cli := NewAzureClient("dummy", c.BaseURL, c.Engine)
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
actual := cli.fullURL("/chat/completions")
if actual != c.Expect {
Expand Down
4 changes: 1 addition & 3 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) {
func TestRequestError(t *testing.T) {
var err error

config := DefaultConfig("dummy")
config.BaseURL = "https://httpbin.org/status/418?"
c := NewClientWithConfig(config)
c := NewClient("dummy", WithCustomBaseURL("https://httpbin.org/status/418?"))
ctx := context.Background()
_, err = c.ListEngines(ctx)
checks.HasError(t, err, "ListEngines did not fail")
Expand Down
8 changes: 2 additions & 6 deletions audio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ func TestAudio(t *testing.T) {
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1"))

testcases := []struct {
name string
Expand Down Expand Up @@ -78,9 +76,7 @@ func TestAudioWithOptionalArgs(t *testing.T) {
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1"))

testcases := []struct {
name string
Expand Down
40 changes: 19 additions & 21 deletions chat_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ import (
)

func TestChatCompletionsStreamWrongModel(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config)
client := NewClient("whatever", WithCustomBaseURL("http://localhost/v1"))
ctx := context.Background()

req := ChatCompletionRequest{
Expand Down Expand Up @@ -61,14 +59,14 @@ func TestCreateChatCompletionStream(t *testing.T) {
defer server.Close()

// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
httpClient := &http.Client{
Transport: &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
},
}

client := NewClientWithConfig(config)
client := NewClient(test.GetTestToken(), WithCustomBaseURL(server.URL+"/v1"), WithCustomClient(httpClient))
ctx := context.Background()

request := ChatCompletionRequest{
Expand Down Expand Up @@ -168,14 +166,14 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
defer server.Close()

// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
httpClient := &http.Client{
Transport: &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
},
}

client := NewClientWithConfig(config)
client := NewClient(test.GetTestToken(), WithCustomBaseURL(server.URL+"/v1"), WithCustomClient(httpClient))
ctx := context.Background()

request := ChatCompletionRequest{
Expand Down Expand Up @@ -225,14 +223,14 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
defer ts.Close()

// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
httpClient := &http.Client{
Transport: &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
},
}

client := NewClientWithConfig(config)
client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1"), WithCustomClient(httpClient))
ctx := context.Background()

request := ChatCompletionRequest{
Expand Down
12 changes: 3 additions & 9 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ import (
)

func TestChatCompletionsWrongModel(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config)
client := NewClient("whatever", WithCustomBaseURL("http://localhost/v1"))
ctx := context.Background()

req := ChatCompletionRequest{
Expand All @@ -38,9 +36,7 @@ func TestChatCompletionsWrongModel(t *testing.T) {
}

func TestChatCompletionsWithStream(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config)
client := NewClient("whatever", WithCustomBaseURL("http://localhost/v1"))
ctx := context.Background()

req := ChatCompletionRequest{
Expand All @@ -60,9 +56,7 @@ func TestChatCompletions(t *testing.T) {
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1"))
ctx := context.Background()

req := ChatCompletionRequest{
Expand Down
32 changes: 19 additions & 13 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,28 @@ type Client struct {
}

// NewClient creates new OpenAI API client.
func NewClient(authToken string) *Client {
func NewClient(authToken string, options ...Option) *Client {
config := DefaultConfig(authToken)
return NewClientWithConfig(config)

for _, opt := range options {
opt(&config)
}

return newClient(config)
}

// NewClientWithConfig creates new OpenAI API client for specified config.
func NewClientWithConfig(config ClientConfig) *Client {
// NewAzureClient create new openAI API from Azure client
func NewAzureClient(authToken string, baseUrl string, engine string, options ...Option) *Client {
config := DefaultAzureConfig(authToken, baseUrl, engine)

for _, opt := range options {
opt(&config)
}

return newClient(config)
}

func newClient(config ClientConfig) *Client {
return &Client{
config: config,
requestBuilder: newRequestBuilder(),
Expand All @@ -34,15 +49,6 @@ func NewClientWithConfig(config ClientConfig) *Client {
}
}

// NewOrgClient creates new OpenAI API client for specified Organization ID.
//
// Deprecated: Please use NewClientWithConfig.
func NewOrgClient(authToken, org string) *Client {
config := DefaultConfig(authToken)
config.OrgID = org
return NewClientWithConfig(config)
}

func (c *Client) sendRequest(req *http.Request, v any) error {
req.Header.Set("Accept", "application/json; charset=utf-8")
// Azure API Key authentication
Expand Down
6 changes: 3 additions & 3 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ import (

func TestClient(t *testing.T) {
const mockToken = "mock token"
client := NewClient(mockToken)
const mockOrg = "mock org"

client := NewClient(mockToken, WithOrganizationID(mockOrg))
if client.config.authToken != mockToken {
t.Errorf("Client does not contain proper token")
}

const mockOrg = "mock org"
client = NewOrgClient(mockToken, mockOrg)
if client.config.authToken != mockToken {
t.Errorf("Client does not contain proper token")
}
Expand Down
11 changes: 3 additions & 8 deletions completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ import (
)

func TestCompletionsWrongModel(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config)
client := NewClient("whatever", WithCustomBaseURL("http://localhost/v1"))

_, err := client.CreateCompletion(
context.Background(),
Expand All @@ -35,8 +33,7 @@ func TestCompletionsWrongModel(t *testing.T) {
}

func TestCompletionWithStream(t *testing.T) {
config := DefaultConfig("whatever")
client := NewClientWithConfig(config)
client := NewClient("whatever")

ctx := context.Background()
req := CompletionRequest{Stream: true}
Expand All @@ -56,9 +53,7 @@ func TestCompletions(t *testing.T) {
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1"))
ctx := context.Background()

req := CompletionRequest{
Expand Down
4 changes: 1 addition & 3 deletions edits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ func TestEdits(t *testing.T) {
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1"))
ctx := context.Background()

// create an edit request
Expand Down
4 changes: 1 addition & 3 deletions embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ func TestEmbeddingEndpoint(t *testing.T) {
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1"))
ctx := context.Background()

_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
Expand Down
4 changes: 1 addition & 3 deletions error_accumulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1"))

ctx := context.Background()

Expand Down
50 changes: 50 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package openai

import (
"context"
"fmt"
)

func ExampleNewClient() {
cli := NewClient("your-api-key")

resp, err := cli.CreateChatCompletion(context.Background(), ChatCompletionRequest{
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
})

if err != nil {
fmt.Println(err)
}

fmt.Println(resp)
}

func ExampleNewAzureClient() {
cli := NewAzureClient("your-api-key", "https://your Azure OpenAI Endpoint ", "your Model deployment name")

resp, err := cli.CreateChatCompletion(
context.Background(),
ChatCompletionRequest{
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello Azure OpenAI!",
},
},
},
)

if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err)
return
}

fmt.Println(resp.Choices[0].Message.Content)
}
12 changes: 3 additions & 9 deletions files_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ func TestFileUpload(t *testing.T) {
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1"))
ctx := context.Background()

req := FileRequest{
Expand Down Expand Up @@ -81,9 +79,7 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) {
}

func TestFileUploadWithFailingFormBuilder(t *testing.T) {
config := DefaultConfig("")
config.BaseURL = ""
client := NewClientWithConfig(config)
client := NewClient("", WithCustomBaseURL(""))
mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder {
return mockBuilder
Expand Down Expand Up @@ -128,9 +124,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
}

func TestFileUploadWithNonExistentPath(t *testing.T) {
config := DefaultConfig("")
config.BaseURL = ""
client := NewClientWithConfig(config)
client := NewClient("", WithCustomBaseURL(""))

ctx := context.Background()
req := FileRequest{
Expand Down
4 changes: 1 addition & 3 deletions fine_tunes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ func TestFineTunes(t *testing.T) {
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client := NewClient(test.GetTestToken(), WithCustomBaseURL(ts.URL+"/v1"))
ctx := context.Background()

_, err = client.ListFineTunes(ctx)
Expand Down
Loading