Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion internal/translator/anthropic_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"bytes"
"cmp"
"encoding/base64"
"errors"
"fmt"
"io"
"strings"
Expand Down Expand Up @@ -767,6 +768,13 @@ func newAnthropicStreamParser(requestModel string) *anthropicStreamParser {
func (p *anthropicStreamParser) writeChunk(eventBlock []byte, buf *[]byte) error {
chunk, err := p.parseAndHandleEvent(eventBlock)
if err != nil {
var streamErr anthropicStreamErrorEvent
if errors.As(err, &streamErr) {
*buf = append(*buf, sseDataPrefix...)
*buf = append(*buf, streamErr.payload...)
*buf = append(*buf, '\n', '\n')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still append the DONE event after the error type event?

return nil
}
return err
}
if chunk != nil {
Expand Down Expand Up @@ -876,6 +884,14 @@ func (p *anthropicStreamParser) Process(body io.Reader, endOfStream bool, span t
return
}

// anthropicStreamErrorEvent is an error value that carries the raw data of an
// Anthropic stream "error" event, so a caller can recover the bytes (for
// example via errors.As) instead of only seeing a formatted message.
type anthropicStreamErrorEvent struct {
	// payload holds the event's data bytes as they were handed to the parser.
	payload []byte
}

// Error implements the error interface. The message is a fixed string; the
// event details live in payload rather than in the message text.
func (ev anthropicStreamErrorEvent) Error() string {
	const msg = "anthropic stream error event"
	return msg
}

func (p *anthropicStreamParser) parseAndHandleEvent(eventBlock []byte) (*openai.ChatCompletionResponseChunk, error) {
var eventType []byte
var eventData []byte
Expand Down Expand Up @@ -1085,7 +1101,7 @@ func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, dat
if err := json.Unmarshal(data, &errEvent); err != nil {
return nil, fmt.Errorf("unparsable error event: %s", string(data))
}
return nil, fmt.Errorf("anthropic stream error: %s - %s", errEvent.Error.Type, errEvent.Error.Message)
return nil, anthropicStreamErrorEvent{payload: data}

case "ping":
// Per documentation, ping events can be ignored.
Expand Down
36 changes: 36 additions & 0 deletions internal/translator/openai_awsanthropic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,19 @@ func TestAWSAnthropicStreamParser_ErrorHandling(t *testing.T) {
return err
}

// runStreamBodyTest wraps sseStream with wrapAnthropicSSEInEventStream, builds
// an OpenAI-to-AWS-Anthropic chat completion translator primed with a
// streaming request, and feeds the wrapped bytes to ResponseBody.
//
// It returns the translated response body together with the error reported by
// ResponseBody; setup failures abort the test through require.
runStreamBodyTest := func(t *testing.T, sseStream string, endOfStream bool) ([]byte, error) {
	// Mark as a helper so failures are attributed to the calling subtest.
	t.Helper()

	eventStreamData, err := wrapAnthropicSSEInEventStream(sseStream)
	require.NoError(t, err)

	openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)}
	translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "").(*openAIToAWSAnthropicTranslatorV1ChatCompletion)
	// RequestBody is invoked before ResponseBody so the translator has seen the
	// streaming request; its outputs are not needed here.
	_, _, err = translator.RequestBody(nil, openAIReq, false)
	require.NoError(t, err)

	_, body, _, _, err := translator.ResponseBody(map[string]string{}, bytes.NewReader(eventStreamData), endOfStream, nil)
	return body, err
}

tests := []struct {
name string
sseStream string
Expand Down Expand Up @@ -706,6 +719,29 @@ func TestAWSAnthropicStreamParser_ErrorHandling(t *testing.T) {
})
}

// Checks that an Anthropic "error" event in the middle of a stream is
// forwarded in the response body instead of failing the call, that events
// after it are still translated, and that the DONE message is still emitted.
t.Run("forwards anthropic error event and continues stream", func(t *testing.T) {
// Fixture order: text delta, error event, text delta, message_stop.
sseStream := `event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"before error"}}

event: error
data: {"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}

event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"after error"}}

event: message_stop
data: {"type":"message_stop"}
`
// endOfStream is true so the whole fixture is processed as a finished stream.
body, err := runStreamBodyTest(t, sseStream, true)
require.NoError(t, err)

bodyStr := string(body)
// Content on both sides of the error event must be present.
require.Contains(t, bodyStr, `"content":"before error"`)
// The error object must appear in the output with its original fields.
require.Contains(t, bodyStr, `"error":{"type":"overloaded_error","message":"Overloaded"}`)
require.Contains(t, bodyStr, `"content":"after error"`)
// The stream must still be terminated with the DONE message.
require.Contains(t, bodyStr, string(sseDoneMessage))
})

t.Run("body read error", func(t *testing.T) {
parser := newAnthropicStreamParser("test-model")
_, _, _, _, err := parser.Process(&mockErrorReader{}, false, nil)
Expand Down