Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 37 additions & 28 deletions stream_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,43 +30,52 @@ func (stream *streamReader[T]) Recv() (response T, err error) {
return
}

response, err = stream.processLines()
return
}

func (stream *streamReader[T]) processLines() (T, error) {
var emptyMessagesCount uint

waitForData:
line, err := stream.reader.ReadBytes('\n')
if err != nil {
respErr := stream.unmarshalError()
if respErr != nil {
err = fmt.Errorf("error, %w", respErr.Error)
for {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Due to added level of indentaion I wonder if we should factor out for body in the separate function (or maybe parts of the for body)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you. Extracting the big, deeply nested for-loop into a separate method would be a good idea.

I will do so👍

rawLine, readErr := stream.reader.ReadBytes('\n')
if readErr != nil {
respErr := stream.unmarshalError()
if respErr != nil {
return *new(T), fmt.Errorf("error, %w", respErr.Error)
}
return *new(T), readErr
}
return
}

var headerData = []byte("data: ")
line = bytes.TrimSpace(line)
if !bytes.HasPrefix(line, headerData) {
if writeErr := stream.errAccumulator.Write(line); writeErr != nil {
err = writeErr
return
var headerData = []byte("data: ")
noSpaceLine := bytes.TrimSpace(rawLine)
if !bytes.HasPrefix(noSpaceLine, headerData) {
writeErr := stream.errAccumulator.Write(noSpaceLine)
if writeErr != nil {
return *new(T), writeErr
}
emptyMessagesCount++
if emptyMessagesCount > stream.emptyMessagesLimit {
return *new(T), ErrTooManyEmptyStreamMessages
}

continue
}
emptyMessagesCount++
if emptyMessagesCount > stream.emptyMessagesLimit {
err = ErrTooManyEmptyStreamMessages
return

noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
if string(noPrefixLine) == "[DONE]" {
stream.isFinished = true
return *new(T), io.EOF
}

goto waitForData
}
var response T
unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response)
if unmarshalErr != nil {
return *new(T), unmarshalErr
}

line = bytes.TrimPrefix(line, headerData)
if string(line) == "[DONE]" {
stream.isFinished = true
err = io.EOF
return
return response, nil
}

err = stream.unmarshaler.Unmarshal(line, &response)
return
}

func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) {
Expand Down
160 changes: 160 additions & 0 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai_test

import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
Expand Down Expand Up @@ -217,6 +218,165 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
t.Logf("%+v\n", apiErr)
}

func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

// Totally 301 empty messages (300 is the limit)
for i := 0; i < 299; i++ {
dataBytes = append(dataBytes, '\n')
}

dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: done\n")...)
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)

_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
}))
defer server.Close()

// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}

client := NewClientWithConfig(config)
ctx := context.Background()

request := CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
}

stream, err := client.CreateCompletionStream(ctx, request)
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

_, _ = stream.Recv()
_, streamErr := stream.Recv()
if !errors.Is(streamErr, ErrTooManyEmptyStreamMessages) {
t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages")
}
}

func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

// Stream is terminated without sending "done" message

_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
}))
defer server.Close()

// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}

client := NewClientWithConfig(config)
ctx := context.Background()

request := CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
}

stream, err := client.CreateCompletionStream(ctx, request)
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

_, _ = stream.Recv()
_, streamErr := stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("TestCreateCompletionStreamUnexpectedTerminatedError did not return io.EOF")
}
}

func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

// Send broken json
dataBytes = append(dataBytes, []byte("event: message\n")...)
data = `{"id":"2","object":"completion","created":1598069255,"model":`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: done\n")...)
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)

_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
}))
defer server.Close()

// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}

client := NewClientWithConfig(config)
ctx := context.Background()

request := CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
}

stream, err := client.CreateCompletionStream(ctx, request)
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

_, _ = stream.Recv()
_, streamErr := stream.Recv()
var syntaxError *json.SyntaxError
if !errors.As(streamErr, &syntaxError) {
t.Errorf("TestCreateCompletionStreamBrokenJSONError did not return json.SyntaxError")
}
}

// Helper funcs.
func compareResponses(r1, r2 CompletionResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
Expand Down