Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/changelog.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- fix: adds timeout and connection disconnect handling for streaming responses
117 changes: 72 additions & 45 deletions core/providers/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,20 @@ func HandleAnthropicChatCompletionStreaming(

// Start streaming in a goroutine
go func() {
defer close(responseChan)
defer func() {
model := "unknown"
if meta != nil {
model = meta.Model
}
if ctx.Err() == context.Canceled {
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger)
} else if ctx.Err() == context.DeadlineExceeded {
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger)
}
close(responseChan)
}()
defer providerUtils.ReleaseStreamingResponse(resp)

if resp.BodyStream() == nil {
bifrostErr := providerUtils.NewBifrostOperationError(
"Provider returned an empty response",
Expand All @@ -510,6 +521,10 @@ func HandleAnthropicChatCompletionStreaming(
return
}

// Setup cancellation handler to close body stream on ctx cancellation
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger)
defer stopCancellation()

scanner := bufio.NewScanner(resp.BodyStream())
buf := make([]byte, 0, 1024*1024)
scanner.Buffer(buf, 10*1024*1024)
Expand All @@ -531,13 +546,15 @@ func HandleAnthropicChatCompletionStreaming(
var eventData string

for scanner.Scan() {
// If context was cancelled/timed out, let defer handle it
if ctx.Err() != nil {
return
}
line := scanner.Text()

// Skip empty lines and comments
if line == "" || strings.HasPrefix(line, ":") {
continue
}

// Parse SSE event - track event type and data separately
if after, ok := strings.CutPrefix(line, "event: "); ok {
eventType = after
Expand All @@ -547,22 +564,18 @@ func HandleAnthropicChatCompletionStreaming(
} else {
continue
}

// Skip if we don't have both event type and data
if eventType == "" || eventData == "" {
continue
}

var event AnthropicStreamEvent
if err := sonic.Unmarshal([]byte(eventData), &event); err != nil {
logger.Warn(fmt.Sprintf("Failed to parse message_start event: %v", err))
continue
}

if event.Type == AnthropicStreamEventTypeMessageStart && event.Message != nil && event.Message.ID != "" {
messageID = event.Message.ID
}

// Check for usage in both top-level event.Usage and nested event.Message.Usage
// message_start events have usage nested in message.usage, while message_delta has it at top level
var usageToProcess *AnthropicUsage
Expand All @@ -571,7 +584,6 @@ func HandleAnthropicChatCompletionStreaming(
} else if event.Message != nil && event.Message.Usage != nil {
usageToProcess = event.Message.Usage
}

if usageToProcess != nil {
// Collect usage information and send at the end of the stream
// Here in some cases usage comes before final message
Expand Down Expand Up @@ -606,7 +618,6 @@ func HandleAnthropicChatCompletionStreaming(
}
}
}

if event.Delta != nil && event.Delta.StopReason != nil {
mappedReason := ConvertAnthropicFinishReasonToBifrost(*event.Delta.StopReason)
finishReason = &mappedReason
Expand All @@ -615,7 +626,6 @@ func HandleAnthropicChatCompletionStreaming(
// Handle different event types
modelName = event.Message.Model
}

response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream()
if bifrostErr != nil {
bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
Expand Down Expand Up @@ -652,36 +662,40 @@ func HandleAnthropicChatCompletionStreaming(

providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan)
}

if isLastChunk {
break
}

// Reset for next event
eventType = ""
eventData = ""
}

if err := scanner.Err(); err != nil {
// If context was cancelled/timed out, let defer handle it
if ctx.Err() != nil {
return
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerName, err))
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, modelName, logger)
} else {
response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, modelName)
if postResponseConverter != nil {
response = postResponseConverter(response)
if response == nil {
logger.Warn("postResponseConverter returned nil; skipping chunk")
return
}
}
// Set raw request if enabled
if sendBackRawRequest {
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
return
}
response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, modelName)
if postResponseConverter != nil {
response = postResponseConverter(response)
if response == nil {
logger.Warn("postResponseConverter returned nil; skipping chunk")
// Set the stream-end indicator on the context (this is not an error value) so that later handling sees the stream as finished even though no final chunk is sent
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
return
Comment thread
akshaydeo marked this conversation as resolved.
}
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan)
}
// Set raw request if enabled
if sendBackRawRequest {
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
}
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan)
Comment thread
akshaydeo marked this conversation as resolved.
}()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment thread
akshaydeo marked this conversation as resolved.

return responseChan, nil
Expand Down Expand Up @@ -850,9 +864,23 @@ func HandleAnthropicResponsesStream(

// Start streaming in a goroutine
go func() {
defer func() {
model := "<unknown>"
if meta != nil {
model = meta.Model
}
if ctx.Err() == context.Canceled {
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger)
} else if ctx.Err() == context.DeadlineExceeded {
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger)
}
close(responseChan)
}()
defer providerUtils.ReleaseStreamingResponse(resp)
defer close(responseChan)

// Setup cancellation handler to close body stream on ctx cancellation
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger)
defer stopCancellation()
// If body stream is nil, return an error
if resp.BodyStream() == nil {
bifrostErr := providerUtils.NewBifrostOperationError(
"Provider returned an empty response",
Expand Down Expand Up @@ -883,13 +911,15 @@ func HandleAnthropicResponsesStream(
var modelName string

for scanner.Scan() {
// If context was cancelled/timed out, let defer handle it
if ctx.Err() != nil {
return
}
line := scanner.Text()

// Skip empty lines and comments
if line == "" || strings.HasPrefix(line, ":") {
continue
}

// Parse SSE event - track event type and data separately
if after, ok := strings.CutPrefix(line, "event: "); ok {
eventType = after
Expand All @@ -899,22 +929,18 @@ func HandleAnthropicResponsesStream(
} else {
continue
}

// Skip if we don't have both event type and data
if eventType == "" || eventData == "" {
continue
}

var event AnthropicStreamEvent
if err := sonic.Unmarshal([]byte(eventData), &event); err != nil {
logger.Warn(fmt.Sprintf("Failed to parse message_start event: %v", err))
continue
}

if event.Message != nil && modelName == "" {
modelName = event.Message.Model
}

// Note: response.created and response.in_progress are now emitted by ToBifrostResponsesStream
// from the message_start event, so we don't need to call them manually here

Expand Down Expand Up @@ -969,6 +995,10 @@ func HandleAnthropicResponsesStream(
Provider: providerName,
ModelRequested: modelName,
}
// If context was cancelled/timed out, let defer handle it
if ctx.Err() != nil {
return
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger)
break
Expand Down Expand Up @@ -1020,8 +1050,12 @@ func HandleAnthropicResponsesStream(
eventType = ""
eventData = ""
}

if err := scanner.Err(); err != nil {
// If context was cancelled/timed out, let defer handle it
if ctx.Err() != nil {
return
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerName, err))
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, modelName, logger)
}
Expand Down Expand Up @@ -1568,13 +1602,6 @@ func (provider *AnthropicProvider) TranscriptionStream(ctx *schemas.BifrostConte
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
}

// parseStreamAnthropicError builds a BifrostError from an Anthropic streaming
// response. It reads the HTTP status code and the response body from resp and
// hands both, together with providerType, to providerUtils.NewProviderAPIError;
// the body is passed along as a string.
func parseStreamAnthropicError(resp *fasthttp.Response, providerType schemas.ModelProvider) *schemas.BifrostError {
	code := resp.StatusCode()
	payload := string(resp.Body())
	return providerUtils.NewProviderAPIError(payload, nil, code, providerType, nil, nil)
}

// FileUpload uploads a file to Anthropic's Files API.
func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.FileUploadRequest); err != nil {
Expand Down
26 changes: 19 additions & 7 deletions core/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,19 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo

// Start streaming in a goroutine
go func() {
defer close(responseChan)
defer func() {
if ctx.Err() == context.Canceled {
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger)
} else if ctx.Err() == context.DeadlineExceeded {
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger)
}
close(responseChan)
}()
Comment thread
akshaydeo marked this conversation as resolved.
// Always release response on exit; bodyStream close should prevent indefinite blocking.
defer providerUtils.ReleaseStreamingResponse(resp)
// Setup cancellation handler to close body stream on ctx cancellation
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
defer stopCancellation()
Comment thread
akshaydeo marked this conversation as resolved.

Comment thread
coderabbitai[bot] marked this conversation as resolved.
// Check if response is compressed
bodyStream := resp.BodyStream()
Expand All @@ -1021,13 +1033,10 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
var accumulated []byte

for {
// Check if context is done
select {
case <-ctx.Done():
// If context was cancelled/timed out, let defer handle it
if ctx.Err() != nil {
return
default:
}

// Read from stream
n, readErr := bodyStream.Read(readBuffer)
if n > 0 {
Expand Down Expand Up @@ -1057,7 +1066,6 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
// Check if this has "data: " prefix (standard SSE format)
if bytes.HasPrefix(event, []byte("data: ")) {
audioData = event[6:] // Skip "data: " prefix

// Check for [DONE] marker
if bytes.Equal(audioData, []byte("[DONE]")) {
return
Expand Down Expand Up @@ -1115,6 +1123,10 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo

// Handle read errors
if readErr != nil {
// If context was cancelled/timed out, let defer handle it
if ctx.Err() != nil {
return
}
if readErr != io.EOF {
provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", readErr))
}
Expand Down
Loading