Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 102 additions & 7 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,108 @@ func TestAPIError(t *testing.T) {
}
}

func TestAPIErrorUnmarshalJSONMessageField(t *testing.T) {
type testCase struct {
name string
response string
hasError bool
checkFn func(t *testing.T, apiErr APIError)
}
testCases := []testCase{
{
name: "parse succeeds when the message is string",
response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`,
hasError: false,
checkFn: func(t *testing.T, apiErr APIError) {
expected := "foo"
if apiErr.Message != expected {
t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected)
}
},
},
{
name: "parse succeeds when the message is array with single item",
response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`,
hasError: false,
checkFn: func(t *testing.T, apiErr APIError) {
expected := "foo"
if apiErr.Message != expected {
t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected)
}
},
},
{
name: "parse succeeds when the message is array with multiple items",
response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`,
hasError: false,
checkFn: func(t *testing.T, apiErr APIError) {
expected := "foo, bar, baz"
if apiErr.Message != expected {
t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected)
}
},
},
{
name: "parse succeeds when the message is empty array",
response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`,
hasError: false,
checkFn: func(t *testing.T, apiErr APIError) {
if apiErr.Message != "" {
t.Fatalf("Unexpected API message: %v; expected: empty", apiErr)
}
},
},
{
name: "parse succeeds when the message is null",
response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`,
hasError: false,
checkFn: func(t *testing.T, apiErr APIError) {
if apiErr.Message != "" {
t.Fatalf("Unexpected API message: %v; expected: empty", apiErr)
}
},
},
{
name: "parse failed when the message is object",
response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`,
hasError: true,
},
{
name: "parse failed when the message is int",
response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`,
hasError: true,
},
{
name: "parse failed when the message is float",
response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`,
hasError: true,
},
{
name: "parse failed when the message is bool",
response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`,
hasError: true,
},
{
name: "parse failed when the message is not exists",
response: `{"type":"invalid_request_error","param":null,"code":null}`,
hasError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var apiErr APIError
err := json.Unmarshal([]byte(tc.response), &apiErr)
if (err != nil) != tc.hasError {
t.Errorf("Unexpected error: %v", err)
return
}
if tc.checkFn != nil {
tc.checkFn(t, apiErr)
}
})
}
}

func TestAPIErrorUnmarshalJSONInteger(t *testing.T) {
var apiErr APIError
response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`
Expand Down Expand Up @@ -217,13 +319,6 @@ func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) {
checks.HasError(t, err, "Type should be a string")
}

func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) {
var apiErr APIError
response := `{"code":418,"message":false,"param":"prompt","type":"teapot_error"}`
err := json.Unmarshal([]byte(response), &apiErr)
checks.HasError(t, err, "Message should be a string")
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test has been moved to the newly added table test.

func TestRequestError(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
Expand Down
10 changes: 9 additions & 1 deletion error.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai
import (
"encoding/json"
"fmt"
"strings"
)

// APIError provides error information returned by the OpenAI API.
Expand Down Expand Up @@ -41,7 +42,14 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {

err = json.Unmarshal(rawMap["message"], &e.Message)
if err != nil {
return
// If the parameter field of a function call is invalid as a JSON schema
// refs: https://github.com/sashabaranov/go-openai/issues/381
var messages []string
err = json.Unmarshal(rawMap["message"], &messages)
if err != nil {
return
}
e.Message = strings.Join(messages, ", ")
}

// optional fields for azure openai
Expand Down