Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"net/http"
"os"

utils "github.com/sashabaranov/go-openai/internal"
)

// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.
Expand Down Expand Up @@ -72,7 +74,7 @@ func (c *Client) callAudioAPI(
if err != nil {
return AudioResponse{}, err
}
req.Header.Add("Content-Type", builder.formDataContentType())
req.Header.Add("Content-Type", builder.FormDataContentType())

if request.HasJSONResponse() {
err = c.sendRequest(req, &response)
Expand All @@ -92,55 +94,55 @@ func (r AudioRequest) HasJSONResponse() bool {

// audioMultipartForm creates a form with audio file contents and the name of the model to use for
// audio processing.
func audioMultipartForm(request AudioRequest, b formBuilder) error {
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
f, err := os.Open(request.FilePath)
if err != nil {
return fmt.Errorf("opening audio file: %w", err)
}
defer f.Close()

err = b.createFormFile("file", f)
err = b.CreateFormFile("file", f)
if err != nil {
return fmt.Errorf("creating form file: %w", err)
}

err = b.writeField("model", request.Model)
err = b.WriteField("model", request.Model)
if err != nil {
return fmt.Errorf("writing model name: %w", err)
}

// Create a form field for the prompt (if provided)
if request.Prompt != "" {
err = b.writeField("prompt", request.Prompt)
err = b.WriteField("prompt", request.Prompt)
if err != nil {
return fmt.Errorf("writing prompt: %w", err)
}
}

// Create a form field for the format (if provided)
if request.Format != "" {
err = b.writeField("response_format", string(request.Format))
err = b.WriteField("response_format", string(request.Format))
if err != nil {
return fmt.Errorf("writing format: %w", err)
}
}

// Create a form field for the temperature (if provided)
if request.Temperature != 0 {
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))
err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature))
if err != nil {
return fmt.Errorf("writing temperature: %w", err)
}
}

// Create a form field for the language (if provided)
if request.Language != "" {
err = b.writeField("language", request.Language)
err = b.WriteField("language", request.Language)
if err != nil {
return fmt.Errorf("writing language: %w", err)
}
}

// Close the multipart writer
return b.close()
return b.Close()
}
8 changes: 5 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ import (
"io"
"net/http"
"strings"

utils "github.com/sashabaranov/go-openai/internal"
)

// Client is OpenAI GPT-3 API client.
type Client struct {
config ClientConfig

requestBuilder requestBuilder
createFormBuilder func(io.Writer) formBuilder
createFormBuilder func(io.Writer) utils.FormBuilder
}

// NewClient creates new OpenAI API client.
Expand All @@ -28,8 +30,8 @@ func NewClientWithConfig(config ClientConfig) *Client {
return &Client{
config: config,
requestBuilder: newRequestBuilder(),
createFormBuilder: func(body io.Writer) formBuilder {
return newFormBuilder(body)
createFormBuilder: func(body io.Writer) utils.FormBuilder {
return utils.NewFormBuilder(body)
},
}
}
Expand Down
8 changes: 4 additions & 4 deletions files.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
var b bytes.Buffer
builder := c.createFormBuilder(&b)

err = builder.writeField("purpose", request.Purpose)
err = builder.WriteField("purpose", request.Purpose)
if err != nil {
return
}
Expand All @@ -46,12 +46,12 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
return
}

err = builder.createFormFile("file", fileData)
err = builder.CreateFormFile("file", fileData)
if err != nil {
return
}

err = builder.close()
err = builder.Close()
if err != nil {
return
}
Expand All @@ -61,7 +61,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
return
}

req.Header.Set("Content-Type", builder.formDataContentType())
req.Header.Set("Content-Type", builder.FormDataContentType())

err = c.sendRequest(req, &file)

Expand Down
3 changes: 2 additions & 1 deletion files_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openai //nolint:testpackage // testing private field

import (
. "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

Expand Down Expand Up @@ -85,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
config.BaseURL = ""
client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder {
client.createFormBuilder = func(io.Writer) FormBuilder {
return mockBuilder
}

Expand Down
49 changes: 0 additions & 49 deletions form_builder.go

This file was deleted.

28 changes: 14 additions & 14 deletions image.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,40 +69,40 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
builder := c.createFormBuilder(body)

// image
err = builder.createFormFile("image", request.Image)
err = builder.CreateFormFile("image", request.Image)
if err != nil {
return
}

// mask, it is optional
if request.Mask != nil {
err = builder.createFormFile("mask", request.Mask)
err = builder.CreateFormFile("mask", request.Mask)
if err != nil {
return
}
}

err = builder.writeField("prompt", request.Prompt)
err = builder.WriteField("prompt", request.Prompt)
if err != nil {
return
}

err = builder.writeField("n", strconv.Itoa(request.N))
err = builder.WriteField("n", strconv.Itoa(request.N))
if err != nil {
return
}

err = builder.writeField("size", request.Size)
err = builder.WriteField("size", request.Size)
if err != nil {
return
}

err = builder.writeField("response_format", request.ResponseFormat)
err = builder.WriteField("response_format", request.ResponseFormat)
if err != nil {
return
}

err = builder.close()
err = builder.Close()
if err != nil {
return
}
Expand All @@ -113,7 +113,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
return
}

req.Header.Set("Content-Type", builder.formDataContentType())
req.Header.Set("Content-Type", builder.FormDataContentType())
err = c.sendRequest(req, &response)
return
}
Expand All @@ -133,27 +133,27 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
builder := c.createFormBuilder(body)

// image
err = builder.createFormFile("image", request.Image)
err = builder.CreateFormFile("image", request.Image)
if err != nil {
return
}

err = builder.writeField("n", strconv.Itoa(request.N))
err = builder.WriteField("n", strconv.Itoa(request.N))
if err != nil {
return
}

err = builder.writeField("size", request.Size)
err = builder.WriteField("size", request.Size)
if err != nil {
return
}

err = builder.writeField("response_format", request.ResponseFormat)
err = builder.WriteField("response_format", request.ResponseFormat)
if err != nil {
return
}

err = builder.close()
err = builder.Close()
if err != nil {
return
}
Expand All @@ -165,7 +165,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
return
}

req.Header.Set("Content-Type", builder.formDataContentType())
req.Header.Set("Content-Type", builder.FormDataContentType())
err = c.sendRequest(req, &response)
return
}
13 changes: 7 additions & 6 deletions image_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openai //nolint:testpackage // testing private field

import (
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

Expand Down Expand Up @@ -268,19 +269,19 @@ type mockFormBuilder struct {
mockClose func() error
}

func (fb *mockFormBuilder) createFormFile(fieldname string, file *os.File) error {
func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
return fb.mockCreateFormFile(fieldname, file)
}

func (fb *mockFormBuilder) writeField(fieldname, value string) error {
func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
return fb.mockWriteField(fieldname, value)
}

func (fb *mockFormBuilder) close() error {
func (fb *mockFormBuilder) Close() error {
return fb.mockClose()
}

func (fb *mockFormBuilder) formDataContentType() string {
func (fb *mockFormBuilder) FormDataContentType() string {
return ""
}

Expand All @@ -290,7 +291,7 @@ func TestImageFormBuilderFailures(t *testing.T) {
client := NewClientWithConfig(config)

mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder {
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder
}
ctx := context.Background()
Expand Down Expand Up @@ -357,7 +358,7 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
client := NewClientWithConfig(config)

mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder {
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder
}
ctx := context.Background()
Expand Down
Loading