Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions api_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) {
az.OrgID = c.OrgID

cli := NewClientWithConfig(az)
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil)
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "")
if err != nil {
t.Errorf("Failed to create request: %v", err)
}
Expand All @@ -109,14 +109,16 @@ func TestRequestAuthHeader(t *testing.T) {

func TestAzureFullURL(t *testing.T) {
cases := []struct {
Name string
BaseURL string
Engine string
Expect string
Name string
BaseURL string
AzureModelMapper map[string]string
Model string
Expect string
}{
{
"AzureBaseURLWithSlashAutoStrip",
"https://httpbin.org/",
nil,
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
Expand All @@ -125,6 +127,7 @@ func TestAzureFullURL(t *testing.T) {
{
"AzureBaseURLWithoutSlashOK",
"https://httpbin.org",
nil,
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
Expand All @@ -134,10 +137,10 @@ func TestAzureFullURL(t *testing.T) {

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine)
az := DefaultAzureConfig("dummy", c.BaseURL)
cli := NewClientWithConfig(az)
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
actual := cli.fullURL("/chat/completions")
actual := cli.fullURL("/chat/completions", c.Model)
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
Expand Down
2 changes: 1 addition & 1 deletion audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (c *Client) callAudioAPI(
}

urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), &formBody)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody)
if err != nil {
return AudioResponse{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (c *Client) CreateChatCompletion(
return
}

req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (c *Client) CreateChatCompletionStream(
}

request.Stream = true
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
if err != nil {
return
}
Expand Down
22 changes: 17 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ func decodeString(body io.Reader, output *string) error {
return nil
}

func (c *Client) fullURL(suffix string) string {
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
// fullURL returns full URL for request.
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
func (c *Client) fullURL(suffix string, args ...any) string {
// /openai/deployments/{model}/chat/completions?api-version={api_version}
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
Expand All @@ -108,8 +110,17 @@ func (c *Client) fullURL(suffix string) string {
if strings.Contains(suffix, "/models") {
return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion)
}
azureDeploymentName := "UNKNOWN"
if len(args) > 0 {
model, ok := args[0].(string)
if ok {
azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
}
}
return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s",
baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion)
baseURL, azureAPIPrefix, azureDeploymentsPrefix,
azureDeploymentName, suffix, c.config.APIVersion,
)
}

// c.config.APIType == APITypeOpenAI || c.config.APIType == ""
Expand All @@ -120,8 +131,9 @@ func (c *Client) newStreamRequest(
ctx context.Context,
method string,
urlSuffix string,
body any) (*http.Request, error) {
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body)
body any,
model string) (*http.Request, error) {
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func (c *Client) CreateCompletion(
return
}

req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
if err != nil {
return
}
Expand Down
28 changes: 19 additions & 9 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai

import (
"net/http"
"regexp"
)

const (
Expand All @@ -26,13 +27,12 @@ const AzureAPIKeyHeader = "api-key"
type ClientConfig struct {
authToken string

BaseURL string
OrgID string
APIType APIType
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
Engine string // required when APIType is APITypeAzure or APITypeAzureAD

HTTPClient *http.Client
BaseURL string
OrgID string
APIType APIType
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
HTTPClient *http.Client

EmptyMessagesLimit uint
}
Expand All @@ -50,14 +50,16 @@ func DefaultConfig(authToken string) ClientConfig {
}
}

func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig {
func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
return ClientConfig{
authToken: apiKey,
BaseURL: baseURL,
OrgID: "",
APIType: APITypeAzure,
APIVersion: "2023-03-15-preview",
Engine: engine,
AzureModelMapperFunc: func(model string) string {
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
},

HTTPClient: &http.Client{},

Expand All @@ -68,3 +70,11 @@ func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig {
func (ClientConfig) String() string {
return "<OpenAI API ClientConfig>"
}

func (c ClientConfig) GetAzureDeploymentByModel(model string) string {
if c.AzureModelMapperFunc != nil {
return c.AzureModelMapperFunc(model)
}

return model
}
62 changes: 62 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package openai_test

import (
"testing"

. "github.com/sashabaranov/go-openai"
)

func TestGetAzureDeploymentByModel(t *testing.T) {
cases := []struct {
Model string
AzureModelMapperFunc func(model string) string
Expect string
}{
{
Model: "gpt-3.5-turbo",
Expect: "gpt-35-turbo",
},
{
Model: "gpt-3.5-turbo-0301",
Expect: "gpt-35-turbo-0301",
},
{
Model: "text-embedding-ada-002",
Expect: "text-embedding-ada-002",
},
{
Model: "",
Expect: "",
},
{
Model: "models",
Expect: "models",
},
{
Model: "gpt-3.5-turbo",
Expect: "my-gpt35",
AzureModelMapperFunc: func(model string) string {
modelmapper := map[string]string{
"gpt-3.5-turbo": "my-gpt35",
}
if val, ok := modelmapper[model]; ok {
return val
}
return model
},
},
}

for _, c := range cases {
t.Run(c.Model, func(t *testing.T) {
conf := DefaultAzureConfig("", "https://test.openai.azure.com/")
if c.AzureModelMapperFunc != nil {
conf.AzureModelMapperFunc = c.AzureModelMapperFunc
}
actual := conf.GetAzureDeploymentByModel(c.Model)
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
})
}
}
3 changes: 2 additions & 1 deletion edits.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai

import (
"context"
"fmt"
"net/http"
)

Expand Down Expand Up @@ -31,7 +32,7 @@ type EditsResponse struct {

// Perform an API call to the Edits endpoint.
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request)
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ type EmbeddingRequest struct {
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
// https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request)
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
if err != nil {
return
}
Expand Down
3 changes: 1 addition & 2 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,7 @@ func Example_chatbot() {
func ExampleDefaultAzureConfig() {
azureKey := os.Getenv("AZURE_OPENAI_API_KEY") // Your azure API key
azureEndpoint := os.Getenv("AZURE_OPENAI_ENDPOINT") // Your azure OpenAI endpoint
azureModel := os.Getenv("AZURE_OPENAI_MODEL") // Your model deployment name
config := openai.DefaultAzureConfig(azureKey, azureEndpoint, azureModel)
config := openai.DefaultAzureConfig(azureKey, azureEndpoint)
client := openai.NewClientWithConfig(config)
resp, err := client.CreateChatCompletion(
context.Background(),
Expand Down
2 changes: 1 addition & 1 deletion models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestAzureListModels(t *testing.T) {
ts.Start()
defer ts.Close()

config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/", "dummyengine")
config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
config.BaseURL = ts.URL
client := NewClientWithConfig(config)
ctx := context.Background()
Expand Down
2 changes: 1 addition & 1 deletion moderation.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type ModerationResponse struct {
// Moderations — perform a moderation api call over a string.
// Input can be an array or slice but a string will reduce the complexity.
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request)
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (c *Client) CreateCompletionStream(
}

request.Stream = true
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
if err != nil {
return
}
Expand Down