From 2e4d90c65747768cb4a029972f3aabd7ec04633c Mon Sep 17 00:00:00 2001 From: JoyShi Date: Mon, 15 May 2023 11:05:49 +0800 Subject: [PATCH 1/5] Move form_uilder into internal pkg. --- audio.go | 2 +- client.go | 8 +++++--- form_builder.go => internal/test/utils/form_builder.go | 4 ++-- .../test/utils/form_builder_test.go | 4 ++-- 4 files changed, 10 insertions(+), 8 deletions(-) rename form_builder.go => internal/test/utils/form_builder.go (91%) rename form_builder_test.go => internal/test/utils/form_builder_test.go (94%) diff --git a/audio.go b/audio.go index 12c6ccc22..51e1c6484 100644 --- a/audio.go +++ b/audio.go @@ -92,7 +92,7 @@ func (r AudioRequest) HasJSONResponse() bool { // audioMultipartForm creates a form with audio file contents and the name of the model to use for // audio processing. -func audioMultipartForm(request AudioRequest, b formBuilder) error { +func audioMultipartForm(request AudioRequest, b FormBuilder) error { f, err := os.Open(request.FilePath) if err != nil { return fmt.Errorf("opening audio file: %w", err) diff --git a/client.go b/client.go index 9579ba27b..9a6900a87 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,8 @@ import ( "io" "net/http" "strings" + + "github.com/sashabaranov/go-openai/internal/test/checks/utils" ) // Client is OpenAI GPT-3 API client. @@ -14,7 +16,7 @@ type Client struct { config ClientConfig requestBuilder requestBuilder - createFormBuilder func(io.Writer) formBuilder + createFormBuilder func(io.Writer) utils.FormBuilder } // NewClient creates new OpenAI API client. @@ -28,8 +30,8 @@ func NewClientWithConfig(config ClientConfig) *Client { return &Client{ config: config, requestBuilder: newRequestBuilder(), - createFormBuilder: func(body io.Writer) formBuilder { - return newFormBuilder(body) + createFormBuilder: func(body io.Writer) utils.FormBuilder { + return utils.NewFormBuilder(body) }, } } diff --git a/form_builder.go b/internal/test/utils/form_builder.go similarity index 91% rename from form_builder.go rename to internal/test/utils/form_builder.go index 7fbb1643a..13bebad5f 100644 --- a/form_builder.go +++ b/internal/test/utils/form_builder.go @@ -6,7 +6,7 @@ import ( "os" ) -type formBuilder interface { +type FormBuilder interface { createFormFile(fieldname string, file *os.File) error writeField(fieldname, value string) error close() error @@ -17,7 +17,7 @@ type defaultFormBuilder struct { writer *multipart.Writer } -func newFormBuilder(body io.Writer) *defaultFormBuilder { +func NewFormBuilder(body io.Writer) *defaultFormBuilder { return &defaultFormBuilder{ writer: multipart.NewWriter(body), } diff --git a/form_builder_test.go b/internal/test/utils/form_builder_test.go similarity index 94% rename from form_builder_test.go rename to internal/test/utils/form_builder_test.go index 78e2ec968..4432636cf 100644 --- a/form_builder_test.go +++ b/internal/test/utils/form_builder_test.go @@ -30,7 +30,7 @@ func TestFormBuilderWithFailingWriter(t *testing.T) { defer file.Close() defer os.Remove(file.Name()) - builder := newFormBuilder(&failingWriter{}) + builder := NewFormBuilder(&failingWriter{}) err = builder.createFormFile("file", file) checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") } @@ -47,7 +47,7 @@ func TestFormBuilderWithClosedFile(t *testing.T) { defer os.Remove(file.Name()) body := &bytes.Buffer{} - builder := newFormBuilder(body) + builder := NewFormBuilder(body) err = builder.createFormFile("file", file) checks.HasError(t, err, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") From 97909c392ec5c705e4e19fa9e2fce5e4e9792792 Mon Sep 17 00:00:00 2001 From: JoyShi <286753440@qq.com> Date: Mon, 15 May 2023 11:10:02 +0800 Subject: [PATCH 2/5] Fix import of audio.go --- audio.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/audio.go b/audio.go index 51e1c6484..ef799a11b 100644 --- a/audio.go +++ b/audio.go @@ -6,6 +6,8 @@ import ( "fmt" "net/http" "os" + + "github.com/sashabaranov/go-openai/internal/test/checks/utils" ) // Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI. From f2edf62e75648c6bdb9671934f219d981a90058b Mon Sep 17 00:00:00 2001 From: JoyShi Date: Mon, 15 May 2023 11:32:44 +0800 Subject: [PATCH 3/5] Reorganize. --- audio.go | 20 ++++++------- client.go | 8 +++--- files.go | 8 +++--- files_test.go | 3 +- image.go | 28 +++++++++---------- image_test.go | 13 +++++---- internal/{test/utils => }/form_builder.go | 16 +++++------ .../{test/utils => }/form_builder_test.go | 4 +-- 8 files changed, 51 insertions(+), 49 deletions(-) rename internal/{test/utils => }/form_builder.go (62%) rename internal/{test/utils => }/form_builder_test.go (93%) diff --git a/audio.go b/audio.go index ef799a11b..47d0e9487 100644 --- a/audio.go +++ b/audio.go @@ -6,8 +6,8 @@ import ( "fmt" "net/http" "os" - - "github.com/sashabaranov/go-openai/internal/test/checks/utils" + + . "github.com/sashabaranov/go-openai/internal" ) // Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI. @@ -74,7 +74,7 @@ func (c *Client) callAudioAPI( if err != nil { return AudioResponse{}, err } - req.Header.Add("Content-Type", builder.formDataContentType()) + req.Header.Add("Content-Type", builder.FormDataContentType()) if request.HasJSONResponse() { err = c.sendRequest(req, &response) @@ -101,19 +101,19 @@ func audioMultipartForm(request AudioRequest, b FormBuilder) error { } defer f.Close() - err = b.createFormFile("file", f) + err = b.CreateFormFile("file", f) if err != nil { return fmt.Errorf("creating form file: %w", err) } - err = b.writeField("model", request.Model) + err = b.WriteField("model", request.Model) if err != nil { return fmt.Errorf("writing model name: %w", err) } // Create a form field for the prompt (if provided) if request.Prompt != "" { - err = b.writeField("prompt", request.Prompt) + err = b.WriteField("prompt", request.Prompt) if err != nil { return fmt.Errorf("writing prompt: %w", err) } @@ -121,7 +121,7 @@ func audioMultipartForm(request AudioRequest, b FormBuilder) error { // Create a form field for the format (if provided) if request.Format != "" { - err = b.writeField("response_format", string(request.Format)) + err = b.WriteField("response_format", string(request.Format)) if err != nil { return fmt.Errorf("writing format: %w", err) } @@ -129,7 +129,7 @@ func audioMultipartForm(request AudioRequest, b FormBuilder) error { // Create a form field for the temperature (if provided) if request.Temperature != 0 { - err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature)) + err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature)) if err != nil { return fmt.Errorf("writing temperature: %w", err) } @@ -137,12 +137,12 @@ func audioMultipartForm(request AudioRequest, b FormBuilder) error { // Create a form field for the language (if provided) if request.Language != "" { - err = b.writeField("language", request.Language) + err = b.WriteField("language", request.Language) if err != nil { return fmt.Errorf("writing language: %w", err) } } // Close the multipart writer - return b.close() + return b.Close() } diff --git a/client.go b/client.go index 9a6900a87..f6a8b6837 100644 --- a/client.go +++ b/client.go @@ -8,7 +8,7 @@ import ( "net/http" "strings" - "github.com/sashabaranov/go-openai/internal/test/checks/utils" + . "github.com/sashabaranov/go-openai/internal" ) // Client is OpenAI GPT-3 API client. @@ -16,7 +16,7 @@ type Client struct { config ClientConfig requestBuilder requestBuilder - createFormBuilder func(io.Writer) utils.FormBuilder + createFormBuilder func(io.Writer) FormBuilder } // NewClient creates new OpenAI API client. @@ -30,8 +30,8 @@ func NewClientWithConfig(config ClientConfig) *Client { return &Client{ config: config, requestBuilder: newRequestBuilder(), - createFormBuilder: func(body io.Writer) utils.FormBuilder { - return utils.NewFormBuilder(body) + createFormBuilder: func(body io.Writer) FormBuilder { + return NewFormBuilder(body) }, } } diff --git a/files.go b/files.go index b701b9454..5667ec861 100644 --- a/files.go +++ b/files.go @@ -36,7 +36,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File var b bytes.Buffer builder := c.createFormBuilder(&b) - err = builder.writeField("purpose", request.Purpose) + err = builder.WriteField("purpose", request.Purpose) if err != nil { return } @@ -46,12 +46,12 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File return } - err = builder.createFormFile("file", fileData) + err = builder.CreateFormFile("file", fileData) if err != nil { return } - err = builder.close() + err = builder.Close() if err != nil { return } @@ -61,7 +61,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File return } - req.Header.Set("Content-Type", builder.formDataContentType()) + req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &file) diff --git a/files_test.go b/files_test.go index bb06498c8..56dbb414f 100644 --- a/files_test.go +++ b/files_test.go @@ -1,6 +1,7 @@ package openai //nolint:testpackage // testing private field import ( + . "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -85,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) { config.BaseURL = "" client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) formBuilder { + client.createFormBuilder = func(io.Writer) FormBuilder { return mockBuilder } diff --git a/image.go b/image.go index 21703bda7..87ffea25e 100644 --- a/image.go +++ b/image.go @@ -69,40 +69,40 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) builder := c.createFormBuilder(body) // image - err = builder.createFormFile("image", request.Image) + err = builder.CreateFormFile("image", request.Image) if err != nil { return } // mask, it is optional if request.Mask != nil { - err = builder.createFormFile("mask", request.Mask) + err = builder.CreateFormFile("mask", request.Mask) if err != nil { return } } - err = builder.writeField("prompt", request.Prompt) + err = builder.WriteField("prompt", request.Prompt) if err != nil { return } - err = builder.writeField("n", strconv.Itoa(request.N)) + err = builder.WriteField("n", strconv.Itoa(request.N)) if err != nil { return } - err = builder.writeField("size", request.Size) + err = builder.WriteField("size", request.Size) if err != nil { return } - err = builder.writeField("response_format", request.ResponseFormat) + err = builder.WriteField("response_format", request.ResponseFormat) if err != nil { return } - err = builder.close() + err = builder.Close() if err != nil { return } @@ -113,7 +113,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req.Header.Set("Content-Type", builder.formDataContentType()) + req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &response) return } @@ -133,27 +133,27 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) builder := c.createFormBuilder(body) // image - err = builder.createFormFile("image", request.Image) + err = builder.CreateFormFile("image", request.Image) if err != nil { return } - err = builder.writeField("n", strconv.Itoa(request.N)) + err = builder.WriteField("n", strconv.Itoa(request.N)) if err != nil { return } - err = builder.writeField("size", request.Size) + err = builder.WriteField("size", request.Size) if err != nil { return } - err = builder.writeField("response_format", request.ResponseFormat) + err = builder.WriteField("response_format", request.ResponseFormat) if err != nil { return } - err = builder.close() + err = builder.Close() if err != nil { return } @@ -165,7 +165,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req.Header.Set("Content-Type", builder.formDataContentType()) + req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &response) return } diff --git a/image_test.go b/image_test.go index 4a7dad58f..ed63061f5 100644 --- a/image_test.go +++ b/image_test.go @@ -1,6 +1,7 @@ package openai //nolint:testpackage // testing private field import ( + . "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -268,19 +269,19 @@ type mockFormBuilder struct { mockClose func() error } -func (fb *mockFormBuilder) createFormFile(fieldname string, file *os.File) error { +func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { return fb.mockCreateFormFile(fieldname, file) } -func (fb *mockFormBuilder) writeField(fieldname, value string) error { +func (fb *mockFormBuilder) WriteField(fieldname, value string) error { return fb.mockWriteField(fieldname, value) } -func (fb *mockFormBuilder) close() error { +func (fb *mockFormBuilder) Close() error { return fb.mockClose() } -func (fb *mockFormBuilder) formDataContentType() string { +func (fb *mockFormBuilder) FormDataContentType() string { return "" } @@ -290,7 +291,7 @@ func TestImageFormBuilderFailures(t *testing.T) { client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) formBuilder { + client.createFormBuilder = func(io.Writer) FormBuilder { return mockBuilder } ctx := context.Background() @@ -357,7 +358,7 @@ func TestVariImageFormBuilderFailures(t *testing.T) { client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) formBuilder { + client.createFormBuilder = func(io.Writer) FormBuilder { return mockBuilder } ctx := context.Background() diff --git a/internal/test/utils/form_builder.go b/internal/form_builder.go similarity index 62% rename from internal/test/utils/form_builder.go rename to internal/form_builder.go index 13bebad5f..f4de790f9 100644 --- a/internal/test/utils/form_builder.go +++ b/internal/form_builder.go @@ -7,10 +7,10 @@ import ( ) type FormBuilder interface { - createFormFile(fieldname string, file *os.File) error - writeField(fieldname, value string) error - close() error - formDataContentType() string + CreateFormFile(fieldname string, file *os.File) error + WriteField(fieldname, value string) error + Close() error + FormDataContentType() string } type defaultFormBuilder struct { @@ -23,7 +23,7 @@ func NewFormBuilder(body io.Writer) *defaultFormBuilder { } } -func (fb *defaultFormBuilder) createFormFile(fieldname string, file *os.File) error { +func (fb *defaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error { fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name()) if err != nil { return err @@ -36,14 +36,14 @@ func (fb *defaultFormBuilder) createFormFile(fieldname string, file *os.File) er return nil } -func (fb *defaultFormBuilder) writeField(fieldname, value string) error { +func (fb *defaultFormBuilder) WriteField(fieldname, value string) error { return fb.writer.WriteField(fieldname, value) } -func (fb *defaultFormBuilder) close() error { +func (fb *defaultFormBuilder) Close() error { return fb.writer.Close() } -func (fb *defaultFormBuilder) formDataContentType() string { +func (fb *defaultFormBuilder) FormDataContentType() string { return fb.writer.FormDataContentType() } diff --git a/internal/test/utils/form_builder_test.go b/internal/form_builder_test.go similarity index 93% rename from internal/test/utils/form_builder_test.go rename to internal/form_builder_test.go index 4432636cf..d3faf9982 100644 --- a/internal/test/utils/form_builder_test.go +++ b/internal/form_builder_test.go @@ -31,7 +31,7 @@ func TestFormBuilderWithFailingWriter(t *testing.T) { defer os.Remove(file.Name()) builder := NewFormBuilder(&failingWriter{}) - err = builder.createFormFile("file", file) + err = builder.CreateFormFile("file", file) checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") } @@ -48,7 +48,7 @@ func TestFormBuilderWithClosedFile(t *testing.T) { body := &bytes.Buffer{} builder := NewFormBuilder(body) - err = builder.createFormFile("file", file) + err = builder.CreateFormFile("file", file) checks.HasError(t, err, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") } From f49df15f774d717430b93ac643826c9a805d1e5e Mon Sep 17 00:00:00 2001 From: JoyShi Date: Mon, 15 May 2023 14:06:54 +0800 Subject: [PATCH 4/5] Fix import. --- audio.go | 4 ++-- client.go | 8 ++++---- image_test.go | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/audio.go b/audio.go index 47d0e9487..bf2365391 100644 --- a/audio.go +++ b/audio.go @@ -7,7 +7,7 @@ import ( "net/http" "os" - . "github.com/sashabaranov/go-openai/internal" + utils "github.com/sashabaranov/go-openai/internal" ) // Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI. @@ -94,7 +94,7 @@ func (r AudioRequest) HasJSONResponse() bool { // audioMultipartForm creates a form with audio file contents and the name of the model to use for // audio processing. -func audioMultipartForm(request AudioRequest, b FormBuilder) error { +func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { f, err := os.Open(request.FilePath) if err != nil { return fmt.Errorf("opening audio file: %w", err) diff --git a/client.go b/client.go index f6a8b6837..c55166aa6 100644 --- a/client.go +++ b/client.go @@ -8,7 +8,7 @@ import ( "net/http" "strings" - . "github.com/sashabaranov/go-openai/internal" + utils "github.com/sashabaranov/go-openai/internal" ) // Client is OpenAI GPT-3 API client. @@ -16,7 +16,7 @@ type Client struct { config ClientConfig requestBuilder requestBuilder - createFormBuilder func(io.Writer) FormBuilder + createFormBuilder func(io.Writer) utils.FormBuilder } // NewClient creates new OpenAI API client. @@ -30,8 +30,8 @@ func NewClientWithConfig(config ClientConfig) *Client { return &Client{ config: config, requestBuilder: newRequestBuilder(), - createFormBuilder: func(body io.Writer) FormBuilder { - return NewFormBuilder(body) + createFormBuilder: func(body io.Writer) utils.FormBuilder { + return utils.NewFormBuilder(body) }, } } diff --git a/image_test.go b/image_test.go index ed63061f5..5cf6a268d 100644 --- a/image_test.go +++ b/image_test.go @@ -1,7 +1,7 @@ package openai //nolint:testpackage // testing private field import ( - . "github.com/sashabaranov/go-openai/internal" + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -291,7 +291,7 @@ func TestImageFormBuilderFailures(t *testing.T) { client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) FormBuilder { + client.createFormBuilder = func(io.Writer) utils.FormBuilder { return mockBuilder } ctx := context.Background() @@ -358,7 +358,7 @@ func TestVariImageFormBuilderFailures(t *testing.T) { client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) FormBuilder { + client.createFormBuilder = func(io.Writer) utils.FormBuilder { return mockBuilder } ctx := context.Background() From 521b0083cdef87ca18aa12b0bad3ef7e9acd9bc4 Mon Sep 17 00:00:00 2001 From: JoyShi Date: Mon, 15 May 2023 14:50:56 +0800 Subject: [PATCH 5/5] Fix --- internal/form_builder.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/form_builder.go b/internal/form_builder.go index f4de790f9..359dd7e2a 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -13,17 +13,17 @@ type FormBuilder interface { FormDataContentType() string } -type defaultFormBuilder struct { +type DefaultFormBuilder struct { writer *multipart.Writer } -func NewFormBuilder(body io.Writer) *defaultFormBuilder { - return &defaultFormBuilder{ +func NewFormBuilder(body io.Writer) *DefaultFormBuilder { + return &DefaultFormBuilder{ writer: multipart.NewWriter(body), } } -func (fb *defaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error { +func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error { fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name()) if err != nil { return err @@ -36,14 +36,14 @@ func (fb *defaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er return nil } -func (fb *defaultFormBuilder) WriteField(fieldname, value string) error { +func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error { return fb.writer.WriteField(fieldname, value) } -func (fb *defaultFormBuilder) Close() error { +func (fb *DefaultFormBuilder) Close() error { return fb.writer.Close() } -func (fb *defaultFormBuilder) FormDataContentType() string { +func (fb *DefaultFormBuilder) FormDataContentType() string { return fb.writer.FormDataContentType() }