Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai)
[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai)

This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support:
This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support:

* ChatGPT 4o, o1
* GPT-3, GPT-4
Expand Down Expand Up @@ -660,7 +660,7 @@ if errors.As(err, &e) {
case 401:
// invalid auth or key (do not retry)
case 429:
// rate limiting or engine overload (wait and retry)
// rate limiting or engine overload (wait and retry)
case 500:
// openai server error (retry)
default:
Expand Down Expand Up @@ -807,6 +807,58 @@ func main() {
}
```
</details>

<details>
<summary>Using ExtraFields</summary>

```go
package main

import (
"context"
"fmt"
openai "github.com/sashabaranov/go-openai"
)

func main() {
client := openai.NewClient("your token")
ctx := context.Background()

// Create chat request
req := openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
}

// Add custom fields
extraFields := map[string]any{
"custom_field": "test_value",
"numeric_field": 42,
"bool_field": true,
}
req.SetExtraFields(extraFields)

// Get custom fields
gotFields := req.GetExtraFields()
fmt.Printf("Extra fields: %v\n", gotFields)

// Send request
resp, err := client.CreateChatCompletion(ctx, req)
if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err)
return
}

fmt.Println(resp.Choices[0].Message.Content)
}
```
</details>

See the `examples/` folder for more.

## Frequently Asked Questions
Expand All @@ -827,18 +879,18 @@ Due to the factors mentioned above, different answers may be returned even for t

By adopting these strategies, you can expect more consistent results.

**Related Issues:**
**Related Issues:**
[omitempty option of request struct will generate incorrect request when parameter is 0.](https://github.com/sashabaranov/go-openai/issues/9)

### Does Go OpenAI provide a method to count tokens?

No, Go OpenAI does not offer a feature to count tokens, and there are no plans to provide such a feature in the future. However, if there's a way to implement a token counting feature with zero dependencies, it might be possible to merge that feature into Go OpenAI. Otherwise, it would be more appropriate to implement it in a dedicated library or repository.

For counting tokens, you might find the following links helpful:
For counting tokens, you might find the following links helpful:
- [Counting Tokens For Chat API Calls](https://github.com/pkoukk/tiktoken-go#counting-tokens-for-chat-api-calls)
- [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb)

**Related Issues:**
**Related Issues:**
[Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62)

## Contributing
Expand Down
17 changes: 17 additions & 0 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,23 @@ type ChatCompletionRequest struct {
ReasoningEffort string `json:"reasoning_effort,omitempty"`
// Metadata to store with the completion.
Metadata map[string]string `json:"metadata,omitempty"`

// Extra fields to be sent in the request.
// Useful for experimental features not yet officially supported.
extraFields map[string]any
}

// SetExtraFields adds extra fields to the JSON object.
//
// SetExtraFields will override any existing fields with the same key.
// For security reasons, ensure this is only used with trusted input data.
func (r *ChatCompletionRequest) SetExtraFields(extraFields map[string]any) {
r.extraFields = extraFields
}

// GetExtraFields returns the extra fields set in the request.
func (r ChatCompletionRequest) GetExtraFields() map[string]any {
return r.extraFields
}

type StreamOptions struct {
Expand Down
49 changes: 49 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -871,3 +871,52 @@ func TestFinishReason(t *testing.T) {
}
}
}

func TestChatCompletionRequestExtraFields(t *testing.T) {
req := openai.ChatCompletionRequest{
Model: "gpt-4",
}

// 测试设置额外字段
extraFields := map[string]any{
"custom_field": "test_value",
"numeric_field": 42,
"bool_field": true,
}
req.SetExtraFields(extraFields)

// 测试获取额外字段
gotFields := req.GetExtraFields()

// 验证字段数量
if len(gotFields) != len(extraFields) {
t.Errorf("Expected %d extra fields, got %d", len(extraFields), len(gotFields))
}

// 验证字段值
for key, expectedValue := range extraFields {
gotValue, exists := gotFields[key]
if !exists {
t.Errorf("Expected field %s not found", key)
continue
}
if gotValue != expectedValue {
t.Errorf("Field %s: expected %v, got %v", key, expectedValue, gotValue)
}
}

// 测试覆盖已存在的字段
newFields := map[string]any{
"custom_field": "new_value",
}
req.SetExtraFields(newFields)
gotFields = req.GetExtraFields()

if len(gotFields) != len(newFields) {
t.Errorf("Expected %d extra fields after override, got %d", len(newFields), len(gotFields))
}

if gotFields["custom_field"] != "new_value" {
t.Errorf("Expected overridden value 'new_value', got %v", gotFields["custom_field"])
}
}
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
module github.com/meguminnnnnnnnn/go-openai

go 1.18

require github.com/evanphx/json-patch v0.5.2

require github.com/pkg/errors v0.9.1 // indirect
5 changes: 5 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k=
github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
30 changes: 29 additions & 1 deletion internal/marshaller.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package openai

import (
"encoding/json"
"fmt"

jsonpatch "github.com/evanphx/json-patch"
)

type Marshaller interface {
Expand All @@ -11,5 +14,30 @@ type Marshaller interface {
type JSONMarshaller struct{}

func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) {
return json.Marshal(value)
originalBytes, err := json.Marshal(value)
if err != nil {
return nil, err
}
// Check if the value implements the GetExtraFields interface
getExtraFieldsBody, ok := value.(interface {
GetExtraFields() map[string]any
})
if !ok {
// If not, return the original bytes
return originalBytes, nil
}
extraFields := getExtraFieldsBody.GetExtraFields()
if len(extraFields) == 0 {
// If there are no extra fields, return the original bytes
return originalBytes, nil
}
patchBytes, err := json.Marshal(extraFields)
if err != nil {
return nil, fmt.Errorf("Marshal extraFields(%+v) err: %w", extraFields, err)
}
finalBytes, err := jsonpatch.MergePatch(originalBytes, patchBytes)
if err != nil {
return nil, fmt.Errorf("MergePatch originalBytes(%s) patchBytes(%s) err: %w", originalBytes, patchBytes, err)
}
return finalBytes, nil
}
1 change: 1 addition & 0 deletions internal/request_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func (b *HTTPRequestBuilder) Build(
if err != nil {
return
}

bodyReader = bytes.NewBuffer(reqBytes)
}
}
Expand Down
49 changes: 49 additions & 0 deletions internal/request_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai //nolint:testpackage // testing private field
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"reflect"
Expand Down Expand Up @@ -59,3 +60,51 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) {
t.Errorf("Build() got = %v, want %v", got, want)
}
}

type testExtraFieldsRequest struct {
Model string `json:"model"`
extraFields map[string]any
}

func (r *testExtraFieldsRequest) GetExtraFields() map[string]any {
return r.extraFields
}

func TestRequestBuilderReturnsRequestWhenRequestHasExtraFields(t *testing.T) {
b := NewRequestBuilder()
var (
ctx = context.Background()
method = http.MethodPost
url = "/foo"
request = &testExtraFieldsRequest{
Model: "test-model",
}
)
request.extraFields = map[string]any{"extra_field": "extra_value"}

reqBytes, err := b.marshaller.Marshal(request)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}

// 验证序列化结果包含原始字段和额外字段
var result map[string]interface{}
if err := json.Unmarshal(reqBytes, &result); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}

if result["model"] != "test-model" {
t.Errorf("Expected model to be 'test-model', got %v", result["model"])
}
if result["extra_field"] != "extra_value" {
t.Errorf("Expected extra_field to be 'extra_value', got %v", result["extra_field"])
}

want, _ := http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes))
got, _ := b.Build(ctx, method, url, request, nil)
if !reflect.DeepEqual(got.Body, want.Body) ||
!reflect.DeepEqual(got.URL, want.URL) ||
!reflect.DeepEqual(got.Method, want.Method) {
t.Errorf("Build() got = %v, want %v", got, want)
}
}