Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,14 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
var errRes ErrorResponse
err = json.NewDecoder(res.Body).Decode(&errRes)
if err != nil || errRes.Error == nil {
return fmt.Errorf("error, status code: %d", res.StatusCode)
reqErr := RequestError{
StatusCode: res.StatusCode,
Err: err,
}
return fmt.Errorf("error, %w", &reqErr)
}
return fmt.Errorf("error, status code: %d, message: %s", res.StatusCode, errRes.Error.Message)
errRes.Error.StatusCode = res.StatusCode
return fmt.Errorf("error, status code: %d, message: %w", res.StatusCode, errRes.Error)
}

if v != nil {
Expand Down
47 changes: 47 additions & 0 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,53 @@ func TestAPI(t *testing.T) {
}
}

func TestAPIError(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
}

var err error
c := NewClient(apiToken + "_invalid")
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err == nil {
t.Fatal("ListEngines did not fail")
}

var apiErr *APIError
if !errors.As(err, &apiErr) {
t.Fatalf("Error is not an APIError: %+v", err)
}

if apiErr.StatusCode != 401 {
t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode)
}
if *apiErr.Code != "invalid_api_key" {
t.Fatalf("Unexpected API error code: %s", *apiErr.Code)
}
}

func TestRequestError(t *testing.T) {
var err error
c := NewClient("dummy")
c.BaseURL = "https://httpbin.org/status/418?"
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err == nil {
t.Fatal("ListEngines request did not fail")
}

var reqErr *RequestError
if !errors.As(err, &reqErr) {
t.Fatalf("Error is not a RequestError: %+v", err)
}

if reqErr.StatusCode != 418 {
t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode)
}
}

// numTokens Returns the number of GPT-3 encoded tokens in the given text.
// This function approximates based on the rule of thumb stated by OpenAI:
// https://beta.openai.com/tokenizer
Expand Down
39 changes: 33 additions & 6 deletions error.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,37 @@
package gogpt

import "fmt"

// APIError provides error information returned by the OpenAI API.
type APIError struct {
Code *string `json:"code,omitempty"`
Message string `json:"message"`
Param *string `json:"param,omitempty"`
Type string `json:"type"`
StatusCode int `json:"-"`
}

// RequestError provides informations about generic request errors.
type RequestError struct {
StatusCode int
Err error
}

type ErrorResponse struct {
Error *struct {
Code *int `json:"code,omitempty"`
Message string `json:"message"`
Param *string `json:"param,omitempty"`
Type string `json:"type"`
} `json:"error,omitempty"`
Error *APIError `json:"error,omitempty"`
}

func (e *APIError) Error() string {
return e.Message
}

func (e *RequestError) Error() string {
if e.Err != nil {
return e.Err.Error()
}
return fmt.Sprintf("status code %d", e.StatusCode)
}

func (e *RequestError) Unwrap() error {
return e.Err
}