diff --git a/core/changelog.md b/core/changelog.md index e69de29bb2..c0ebddec8b 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -0,0 +1 @@ +- fix: adds timeout and connection disconnect handling for streaming responses \ No newline at end of file diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index a8037c705a..0d5eced0e9 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -496,9 +496,20 @@ func HandleAnthropicChatCompletionStreaming( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + model := "unknown" + if meta != nil { + model = meta.Model + } + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) - + if resp.BodyStream() == nil { bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", @@ -510,6 +521,10 @@ func HandleAnthropicChatCompletionStreaming( return } + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() + scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) scanner.Buffer(buf, 10*1024*1024) @@ -531,13 +546,15 @@ func HandleAnthropicChatCompletionStreaming( var eventData string for scanner.Scan() { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } line := scanner.Text() - // Skip empty lines and comments if line == "" || strings.HasPrefix(line, ":") { continue } - // Parse SSE 
event - track event type and data separately if after, ok := strings.CutPrefix(line, "event: "); ok { eventType = after @@ -547,22 +564,18 @@ func HandleAnthropicChatCompletionStreaming( } else { continue } - // Skip if we don't have both event type and data if eventType == "" || eventData == "" { continue } - var event AnthropicStreamEvent if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse message_start event: %v", err)) continue } - if event.Type == AnthropicStreamEventTypeMessageStart && event.Message != nil && event.Message.ID != "" { messageID = event.Message.ID } - // Check for usage in both top-level event.Usage and nested event.Message.Usage // message_start events have usage nested in message.usage, while message_delta has it at top level var usageToProcess *AnthropicUsage @@ -571,7 +584,6 @@ func HandleAnthropicChatCompletionStreaming( } else if event.Message != nil && event.Message.Usage != nil { usageToProcess = event.Message.Usage } - if usageToProcess != nil { // Collect usage information and send at the end of the stream // Here in some cases usage comes before final message @@ -606,7 +618,6 @@ func HandleAnthropicChatCompletionStreaming( } } } - if event.Delta != nil && event.Delta.StopReason != nil { mappedReason := ConvertAnthropicFinishReasonToBifrost(*event.Delta.StopReason) finishReason = &mappedReason @@ -615,7 +626,6 @@ func HandleAnthropicChatCompletionStreaming( // Handle different event types modelName = event.Message.Model } - response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream() if bifrostErr != nil { bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ @@ -652,36 +662,40 @@ func HandleAnthropicChatCompletionStreaming( providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) } - if isLastChunk { break } - // Reset for next event eventType = "" eventData = "" } - 
if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, modelName, logger) - } else { - response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, modelName) - if postResponseConverter != nil { - response = postResponseConverter(response) - if response == nil { - logger.Warn("postResponseConverter returned nil; skipping chunk") - return - } - } - // Set raw request if enabled - if sendBackRawRequest { - providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) + return + } + response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, modelName) + if postResponseConverter != nil { + response = postResponseConverter(response) + if response == nil { + logger.Warn("postResponseConverter returned nil; skipping chunk") + // Setting error on the context to signal to the defer that we need to close the stream + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + return } - response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) } + // Set raw request if enabled + if sendBackRawRequest { + providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) + } + response.ExtraFields.Latency = time.Since(startTime).Milliseconds() + 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) }() return responseChan, nil @@ -850,9 +864,23 @@ func HandleAnthropicResponsesStream( // Start streaming in a goroutine go func() { + defer func() { + model := "" + if meta != nil { + model = meta.Model + } + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) - defer close(responseChan) - + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() + // If body stream is nil, return an error if resp.BodyStream() == nil { bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", @@ -883,13 +911,15 @@ func HandleAnthropicResponsesStream( var modelName string for scanner.Scan() { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } line := scanner.Text() - // Skip empty lines and comments if line == "" || strings.HasPrefix(line, ":") { continue } - // Parse SSE event - track event type and data separately if after, ok := strings.CutPrefix(line, "event: "); ok { eventType = after @@ -899,22 +929,18 @@ func HandleAnthropicResponsesStream( } else { continue } - // Skip if we don't have both event type and data if eventType == "" || eventData == "" { continue } - var event AnthropicStreamEvent if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { 
logger.Warn(fmt.Sprintf("Failed to parse message_start event: %v", err)) continue } - if event.Message != nil && modelName == "" { modelName = event.Message.Model } - // Note: response.created and response.in_progress are now emitted by ToBifrostResponsesStream // from the message_start event, so we don't need to call them manually here @@ -969,6 +995,10 @@ func HandleAnthropicResponsesStream( Provider: providerName, ModelRequested: modelName, } + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) break @@ -1020,8 +1050,12 @@ func HandleAnthropicResponsesStream( eventType = "" eventData = "" } - if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, modelName, logger) } @@ -1568,13 +1602,6 @@ func (provider *AnthropicProvider) TranscriptionStream(ctx *schemas.BifrostConte return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } -// parseStreamAnthropicError parses Anthropic streaming error responses. -func parseStreamAnthropicError(resp *fasthttp.Response, providerType schemas.ModelProvider) *schemas.BifrostError { - statusCode := resp.StatusCode() - body := resp.Body() - return providerUtils.NewProviderAPIError(string(body), nil, statusCode, providerType, nil, nil) -} - // FileUpload uploads a file to Anthropic's Files API. 
func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.FileUploadRequest); err != nil { diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index dd31038614..f9bc2f13ba 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -1006,7 +1006,19 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger) + } + close(responseChan) + }() + // Always release response on exit; bodyStream close should prevent indefinite blocking. 
+ defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() // Check if response is compressed bodyStream := resp.BodyStream() @@ -1021,13 +1033,10 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo var accumulated []byte for { - // Check if context is done - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - // Read from stream n, readErr := bodyStream.Read(readBuffer) if n > 0 { @@ -1057,7 +1066,6 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // Check if this has "data: " prefix (standard SSE format) if bytes.HasPrefix(event, []byte("data: ")) { audioData = event[6:] // Skip "data: " prefix - // Check for [DONE] marker if bytes.Equal(audioData, []byte("[DONE]")) { return @@ -1115,6 +1123,10 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // Handle read errors if readErr != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } if readErr != io.EOF { provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", readErr)) } diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index 1cba86c322..373b90aea9 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -604,8 +604,18 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, provider.logger) + } else if ctx.Err() == 
context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, provider.logger) + } + close(responseChan) + }() defer resp.Body.Close() + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.Body, provider.logger) + defer stopCancellation() // Process AWS Event Stream format startTime := time.Now() @@ -613,14 +623,22 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex payloadBuf := make([]byte, 0, 1024*1024) // 1MB payload buffer for { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } // Decode a single EventStream message message, err := decoder.Decode(resp.Body, payloadBuf) if err != nil { - ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } if err == io.EOF { // End of stream - this is normal break } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error decoding %s EventStream message: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger) return @@ -778,8 +796,18 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, 
provider.logger) + } + close(responseChan) + }() defer resp.Body.Close() + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.Body, provider.logger) + defer stopCancellation() // Process AWS Event Stream format usage := &schemas.BifrostLLMUsage{} @@ -796,13 +824,22 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex id := uuid.New().String() for { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } // Decode a single EventStream message message, err := decoder.Decode(resp.Body, payloadBuf) if err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + // End of stream - this is normal if err == io.EOF { - // End of stream - this is normal break } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error decoding %s EventStream message: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) return @@ -1019,8 +1056,19 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + } + close(responseChan) + }() + // Always release response on exit; bodyStream close should prevent indefinite blocking. 
defer resp.Body.Close() + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.Body, provider.logger) + defer stopCancellation() // Process AWS Event Stream format usage := &schemas.ResponsesResponseUsage{} @@ -1038,9 +1086,17 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po payloadBuf := make([]byte, 0, 1024*1024) // 1MB payload buffer for { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } // Decode a single EventStream message message, err := decoder.Decode(resp.Body, payloadBuf) if err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } if err == io.EOF { // End of stream - finalize any open items finalResponses := FinalizeBedrockStream(streamState, chunkIndex, usage) @@ -1072,7 +1128,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil), responseChan) } break - } + } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error decoding %s EventStream message: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) @@ -1134,7 +1190,6 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po } } } - responses, bifrostErr, _ := streamEvent.ToBifrostResponsesStream(chunkIndex, streamState) if bifrostErr != nil { bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go index 65628bd990..2adbfd3ac5 100644 --- a/core/providers/cohere/cohere.go +++ b/core/providers/cohere/cohere.go @@ -417,8 +417,18 @@ func 
(provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) @@ -430,6 +440,10 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext var responseID string for scanner.Scan() { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } line := scanner.Text() // Skip empty lines and comments @@ -503,6 +517,11 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext } if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) } @@ -650,8 +669,18 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + 
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) @@ -671,6 +700,10 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos var eventData string for scanner.Scan() { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } line := scanner.Text() // Skip empty lines and comments @@ -756,6 +789,11 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos } if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) } diff --git a/core/providers/elevenlabs/elevenlabs.go b/core/providers/elevenlabs/elevenlabs.go index 4d7e3779a1..07d27b059d 100644 --- a/core/providers/elevenlabs/elevenlabs.go +++ b/core/providers/elevenlabs/elevenlabs.go @@ -376,8 +376,18 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) go func() { + defer func() { + 
if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) - defer close(responseChan) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() // read binary audio chunks from the stream // 4KB buffer for reading chunks @@ -387,18 +397,20 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po lastChunkTime := time.Now() for { - // Check if context is done before processing - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - n, err := bodyStream.Read(buffer) if err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } if err == io.EOF { break - } + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger) return diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index 3de4259bc9..a325b59f73 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -380,9 +380,31 @@ func HandleGeminiChatCompletionStream( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, 
postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + if resp.BodyStream() == nil { + bifrostErr := providerUtils.NewBifrostOperationError( + "Provider returned an empty response", + fmt.Errorf("provider returned an empty response"), + providerName, + ) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + return + } + + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() + scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) scanner.Buffer(buf, 10*1024*1024) @@ -395,10 +417,9 @@ func HandleGeminiChatCompletionStream( var modelName string for scanner.Scan() { - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } line := scanner.Text() @@ -406,19 +427,15 @@ func HandleGeminiChatCompletionStream( if line == "" || strings.HasPrefix(line, ":") { continue } - // Parse SSE data if !strings.HasPrefix(line, "data: ") { continue } - eventData := strings.TrimPrefix(line, "data: ") - // Skip empty data if strings.TrimSpace(eventData) == "" { continue } - // Process chunk using shared function geminiResponse, err := processGeminiStreamChunk(eventData) if err != nil { @@ -511,6 +528,11 @@ func HandleGeminiChatCompletionStream( // Handle scanner errors if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, model, logger) } @@ -687,9 +709,21 @@ func HandleGeminiResponsesStream( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + } + close(responseChan) + }() + defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() + scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) scanner.Buffer(buf, 10*1024*1024) @@ -706,11 +740,11 @@ func HandleGeminiResponsesStream( var lastUsageMetadata *GenerateContentResponseUsageMetadata for scanner.Scan() { - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } + line := scanner.Text() // Skip empty lines and comments @@ -829,43 +863,50 @@ func HandleGeminiResponsesStream( // Handle scanner errors if err := scanner.Err(); err != nil { + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, model, logger) - } else { - // Finalize the stream by closing any open items - finalResponses := 
FinalizeGeminiResponsesStream(streamState, lastUsageMetadata, sequenceNumber) - for i, finalResponse := range finalResponses { - finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), - } - - if postResponseConverter != nil { - finalResponse = postResponseConverter(finalResponse) - if finalResponse == nil { - logger.Warn("postResponseConverter returned nil; skipping final response") - continue - } - } - - chunkIndex++ - sequenceNumber++ + return + } + // Finalize the stream by closing any open items + finalResponses := FinalizeGeminiResponsesStream(streamState, lastUsageMetadata, sequenceNumber) + for i, finalResponse := range finalResponses { + if finalResponse == nil { + logger.Warn("FinalizeGeminiResponsesStream returned nil; skipping final response") + continue + } + finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{ + RequestType: schemas.ResponsesStreamRequest, + Provider: providerName, + ModelRequested: model, + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), + } - if sendBackRawResponse { - finalResponse.ExtraFields.RawResponse = "{}" // Final event has no payload + if postResponseConverter != nil { + finalResponse = postResponseConverter(finalResponse) + if finalResponse == nil { + logger.Warn("postResponseConverter returned nil; skipping final response") + continue } + } - // Set final latency on the last response (completed event) - if i == len(finalResponses)-1 { - finalResponse.ExtraFields.Latency = time.Since(startTime).Milliseconds() - } + chunkIndex++ + sequenceNumber++ - ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil), responseChan) + if sendBackRawResponse { + 
finalResponse.ExtraFields.RawResponse = "{}" // Final event has no payload } + isLast := i == len(finalResponses)-1 + // Set final latency on the last response (completed event) + if isLast { + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + finalResponse.ExtraFields.Latency = time.Since(startTime).Milliseconds() + } + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil), responseChan) } }() @@ -1091,8 +1132,20 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo // Start streaming in a goroutine go func() { + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + } + close(responseChan) + }() + defer providerUtils.ReleaseStreamingResponse(resp) - defer close(responseChan) + + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) // Increase buffer size to handle large chunks (especially for audio data) @@ -1104,11 +1157,11 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo lastChunkTime := startTime for scanner.Scan() { - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } + line := scanner.Text() // Skip empty lines @@ -1208,31 +1261,33 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo providerUtils.ProcessAndSendResponse(ctx, postHookRunner, 
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil), responseChan) } } - // Handle scanner errors if err := scanner.Err(); err != nil { - provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger) - } else { - response := &schemas.BifrostSpeechStreamResponse{ - Type: schemas.SpeechStreamResponseTypeDone, - Usage: usage, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), - }, - } - - // Set raw request if enabled - if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { - providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) + if ctx.Err() != nil { + return } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil), responseChan) + provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger) + return + } + response := &schemas.BifrostSpeechStreamResponse{ + Type: schemas.SpeechStreamResponseTypeDone, + Usage: usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.SpeechStreamRequest, + Provider: providerName, + ModelRequested: request.Model, + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), + }, } + // Set raw request if enabled + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) + } + 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil), responseChan) }() return responseChan, nil @@ -1360,8 +1415,18 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) // Increase buffer size to handle large chunks (especially for audio data) @@ -1375,11 +1440,11 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, var fullTranscriptionText string for scanner.Scan() { - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } + line := scanner.Text() // Skip empty lines @@ -1489,34 +1554,39 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, // Handle scanner errors if err := scanner.Err(); err != nil { + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, 
responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger) - } else { - response := &schemas.BifrostTranscriptionStreamResponse{ - Type: schemas.TranscriptionStreamResponseTypeDone, - Text: fullTranscriptionText, - Usage: &schemas.TranscriptionUsage{ - Type: "tokens", - InputTokens: usage.InputTokens, - OutputTokens: usage.OutputTokens, - TotalTokens: usage.TotalTokens, - }, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), - }, - } + return + } + response := &schemas.BifrostTranscriptionStreamResponse{ + Type: schemas.TranscriptionStreamResponseTypeDone, + Text: fullTranscriptionText, + Usage: &schemas.TranscriptionUsage{ + Type: "tokens", + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + TotalTokens: usage.TotalTokens, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.TranscriptionStreamRequest, + Provider: providerName, + ModelRequested: request.Model, + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), + }, + } - // Set raw request if enabled - if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { - providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) - } - ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response), responseChan) + // Set raw request if enabled + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, 
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response), responseChan) + }() return responseChan, nil diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index 1faa6f04aa..34e6c32d13 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -438,8 +438,18 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) // Increase buffer size to handle large chunks @@ -454,13 +464,12 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext var currentData string for scanner.Scan() { - // Check if context is done before processing - select { - case <-ctx.Done(): + + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - + line := scanner.Text() // Skip empty lines (event delimiter) @@ -498,6 +507,11 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext // Handle scanner errors if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger) } diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index aeb116a243..f4b2e61777 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -446,8 +446,18 @@ func HandleOpenAITextCompletionStreaming( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) @@ -462,13 +472,10 @@ func HandleOpenAITextCompletionStreaming( lastChunkTime := startTime for scanner.Scan() { - // Check if context is done before processing - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - line := scanner.Text() // Skip empty lines and comments @@ -597,21 +604,31 @@ func HandleOpenAITextCompletionStreaming( // Handle scanner errors first if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, logger) - } else { - response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest, providerName, request.Model) - if postResponseConverter != nil { - response = postResponseConverter(response) - } - // Set raw request if enabled - if sendBackRawRequest { - providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) + return + } + + response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest, providerName, request.Model) + if postResponseConverter != nil { + response = postResponseConverter(response) + if response == nil { + logger.Warn("postResponseConverter returned nil; leaving chunk unmodified") + return } - response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(response, nil, nil, nil, nil), responseChan) } + // Set raw request if enabled + if sendBackRawRequest { + providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) + } + response.ExtraFields.Latency = time.Since(startTime).Milliseconds() + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(response, nil, nil, nil, nil), responseChan) }() return responseChan, nil @@ -880,10 +897,26 @@ func HandleOpenAIChatCompletionStreaming( // Create response channel responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + // Determine 
request type for cleanup + streamRequestType := schemas.ChatCompletionStreamRequest + if isResponsesToChatCompletionsFallback { + streamRequestType = schemas.ResponsesStreamRequest + } + // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, streamRequestType, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, streamRequestType, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) @@ -899,13 +932,10 @@ func HandleOpenAIChatCompletionStreaming( var messageID string for scanner.Scan() { - // Check if context is done before processing - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - line := scanner.Text() // Skip empty lines and comments @@ -940,7 +970,7 @@ func HandleOpenAIChatCompletionStreaming( bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ Provider: providerName, ModelRequested: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, + RequestType: streamRequestType, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, logger) @@ -964,7 +994,7 @@ func HandleOpenAIChatCompletionStreaming( IsBifrostError: false, Error: &schemas.ErrorField{}, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, + RequestType: streamRequestType, 
Provider: providerName, ModelRequested: request.Model, }, @@ -985,7 +1015,7 @@ func HandleOpenAIChatCompletionStreaming( return } - response.ExtraFields.RequestType = schemas.ResponsesStreamRequest + response.ExtraFields.RequestType = streamRequestType response.ExtraFields.Provider = providerName response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = response.SequenceNumber @@ -1097,10 +1127,18 @@ func HandleOpenAIChatCompletionStreaming( // Handle scanner errors first if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, logger) - } else if !isResponsesToChatCompletionsFallback { - response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, request.Model) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, streamRequestType, providerName, request.Model, logger) + return + } + + if !isResponsesToChatCompletionsFallback { + response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, streamRequestType, providerName, request.Model) if postResponseConverter != nil { response = postResponseConverter(response) } @@ -1359,8 +1397,18 @@ func HandleOpenAIResponsesStreaming( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + 
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) @@ -1370,13 +1418,10 @@ func HandleOpenAIResponsesStreaming( lastChunkTime := startTime for scanner.Scan() { - // Check if context is done before processing - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - line := scanner.Text() // Skip empty lines, comments, and event lines @@ -1476,6 +1521,11 @@ func HandleOpenAIResponsesStreaming( } // Handle scanner errors first if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, logger) } @@ -1827,8 +1877,18 @@ func HandleOpenAISpeechStreamRequest( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to 
close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) chunkIndex := -1 @@ -1837,11 +1897,9 @@ func HandleOpenAISpeechStreamRequest( lastChunkTime := startTime for scanner.Scan() { - // Check if context is done before processing - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } line := scanner.Text() @@ -1931,6 +1989,11 @@ func HandleOpenAISpeechStreamRequest( // Handle scanner errors if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, logger) } @@ -2196,8 +2259,18 @@ func HandleOpenAITranscriptionStreamRequest( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) chunkIndex := -1 @@ -2206,13 +2279,11 @@ func HandleOpenAITranscriptionStreamRequest( lastChunkTime := 
startTime for scanner.Scan() { - // Check if context is done before processing - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - + line := scanner.Text() // Skip empty lines and comments @@ -2295,6 +2366,11 @@ func HandleOpenAITranscriptionStreamRequest( // Handle scanner errors if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger) } diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index 15d77a57ca..cebcc7d056 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -980,6 +980,113 @@ func ProcessAndSendBifrostError( } } +// SetupStreamCancellation spawns a goroutine that closes the body stream when +// the context is cancelled or deadline exceeded, unblocking any blocked Read/Scan operations. +// Returns a cleanup function that MUST be called when streaming is done to +// prevent the goroutine from closing the stream during normal operation. +// Works with both fasthttp's BodyStream() (io.Reader) and net/http's resp.Body (io.ReadCloser). 
+func SetupStreamCancellation(ctx context.Context, bodyStream io.Reader, logger schemas.Logger) (cleanup func()) { + done := make(chan struct{}) + + go func() { + select { + case <-ctx.Done(): + // Context cancelled or deadline exceeded - close the body stream to unblock reads + if closer, ok := bodyStream.(io.Closer); ok { + if err := closer.Close(); err != nil && logger != nil { + logger.Debug(fmt.Sprintf("Error closing body stream on context done: %v", err)) + } + } + case <-done: + // Normal completion - do nothing + } + }() + + return func() { close(done) } +} + +// HandleStreamCancellation should be called when a streaming goroutine exits +// due to context cancellation. It ensures proper cleanup by: +// 1. Checking if StreamEndIndicator was already set (to avoid duplicate handling) +// 2. Setting StreamEndIndicator to true +// 3. Sending a cancellation error through PostHook chain +// +// This is critical for the logging plugin to update log status from "processing" to "error" +// when a client disconnects mid-stream. 
+func HandleStreamCancellation( + ctx *schemas.BifrostContext, + postHookRunner schemas.PostHookRunner, + responseChan chan *schemas.BifrostStream, + provider schemas.ModelProvider, + model string, + requestType schemas.RequestType, + logger schemas.Logger, +) { + // Check if already handled (StreamEndIndicator already set) + if indicator := ctx.GetAndSetValue(schemas.BifrostContextKeyStreamEndIndicator, true); indicator != nil { + if set, ok := indicator.(bool); ok && set { + return // Already handled + } + } + // Create cancellation error + cancelErr := &schemas.BifrostError{ + StatusCode: schemas.Ptr(499), // Client Closed Request + Error: &schemas.ErrorField{ + Message: "Request cancelled: client disconnected", + Type: schemas.Ptr(schemas.RequestCancelled), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + Provider: provider, + ModelRequested: model, + RequestType: requestType, + }, + } + + // Send through PostHook chain - this updates the log to "error" status + ProcessAndSendBifrostError(ctx, postHookRunner, cancelErr, responseChan, logger) +} + +// HandleStreamTimeout should be called when a streaming goroutine exits +// due to context deadline exceeded. It ensures proper cleanup by: +// 1. Checking if StreamEndIndicator was already set (to avoid duplicate handling) +// 2. Setting StreamEndIndicator to true +// 3. Sending a timeout error through PostHook chain +// +// This is critical for the logging plugin to update log status from "processing" to "error" +// when a request times out mid-stream. 
+func HandleStreamTimeout( + ctx *schemas.BifrostContext, + postHookRunner schemas.PostHookRunner, + responseChan chan *schemas.BifrostStream, + provider schemas.ModelProvider, + model string, + requestType schemas.RequestType, + logger schemas.Logger, +) { + // Check if already handled (StreamEndIndicator already set) + if indicator := ctx.GetAndSetValue(schemas.BifrostContextKeyStreamEndIndicator, true); indicator != nil { + if set, ok := indicator.(bool); ok && set { + return // Already handled + } + } + // Create timeout error + timeoutErr := &schemas.BifrostError{ + StatusCode: schemas.Ptr(504), // Gateway Timeout + Error: &schemas.ErrorField{ + Message: "Request timed out: deadline exceeded", + Type: schemas.Ptr(schemas.RequestTimedOut), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + Provider: provider, + ModelRequested: model, + RequestType: requestType, + }, + } + + // Send through PostHook chain - this updates the log to "error" status + ProcessAndSendBifrostError(ctx, postHookRunner, timeoutErr, responseChan, logger) +} + // ProcessAndSendError handles post-hook processing and sends the error to the channel. 
// This utility reduces code duplication across streaming implementations by encapsulating
 // the common pattern of running post hooks, handling errors, and sending responses with
@@ -1148,6 +1255,7 @@ func ReleaseStreamingResponse(resp *fasthttp.Response) {
 	// Drain any remaining data from the body stream before releasing
 	// This prevents "whitespace in header" errors when the response is reused
 	if resp.BodyStream() != nil {
+		// Drain the body stream
 		io.Copy(io.Discard, resp.BodyStream())
 	}
 	fasthttp.ReleaseResponse(resp)
diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go
index eb60899476..dd1e669a9c 100644
--- a/core/schemas/bifrost.go
+++ b/core/schemas/bifrost.go
@@ -448,6 +448,7 @@ type BifrostCacheDebug struct {
 
 const (
 	RequestCancelled = "request_cancelled"
+	RequestTimedOut  = "request_timed_out"
 )
 
 // BifrostStream represents a stream of responses from the Bifrost system.
diff --git a/core/schemas/context.go b/core/schemas/context.go
index e69cb47d77..c39afd8713 100644
--- a/core/schemas/context.go
+++ b/core/schemas/context.go
@@ -231,6 +231,23 @@ func (bc *BifrostContext) SetValue(key, value any) {
 	bc.userValues[key] = value
 }
 
+// GetAndSetValue stores value under key in the internal userValues map and
+// returns the value that was previously stored under that key (nil if there
+// was none). The read and the write happen while valuesMu is held, so callers
+// can use the returned value to tell whether they were the first to set it.
+//
+// If restricted writes are blocked and key is one of reservedKeys, nothing is
+// written and the currently stored value is returned instead.
+func (bc *BifrostContext) GetAndSetValue(key any, value any) any {
+	bc.valuesMu.Lock()
+	defer bc.valuesMu.Unlock()
+	// Check if the key is a reserved key
+	if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) {
+		// we silently drop writes for these reserved keys
+		// (indexing a nil map is safe here and simply yields nil)
+		return bc.userValues[key]
+	}
+	// Lazily create the map so the write below cannot panic on a nil map
+	if bc.userValues == nil {
+		bc.userValues = make(map[any]any)
+	}
+	oldValue := bc.userValues[key]
+	bc.userValues[key] = value
+	return oldValue
+}
+
 // GetUserValues returns a copy of all user-set values in this context.
 // If the parent is also a PluginContext, the values are merged with parent values
 // (this context's values take precedence over parent values).
diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index 08dfef5b4b..3c19880a6d 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -66,6 +66,7 @@ type MCPClientConfig struct { // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => auto-execute only the specified tools // Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped. + ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) } // MCPConnectionType defines the communication protocol for MCP connections diff --git a/core/version b/core/version index 3336003dcc..06c7347f09 100644 --- a/core/version +++ b/core/version @@ -1 +1 @@ -1.3.7 +1.3.8 \ No newline at end of file diff --git a/framework/changelog.md b/framework/changelog.md index e69de29bb2..c8195a0ce2 100644 --- a/framework/changelog.md +++ b/framework/changelog.md @@ -0,0 +1 @@ +- chore: updated core version to 1.3.8 \ No newline at end of file diff --git a/framework/version b/framework/version index c04c650a7a..5975b143a0 100644 --- a/framework/version +++ b/framework/version @@ -1 +1 @@ -1.2.7 +1.2.8 \ No newline at end of file diff --git a/plugins/governance/changelog.md b/plugins/governance/changelog.md index e69de29bb2..66f23e0307 100644 --- a/plugins/governance/changelog.md +++ b/plugins/governance/changelog.md @@ -0,0 +1 @@ +- chore: updated core version to 1.3.8 and framework version to 1.2.8 \ No newline at end of file diff --git a/plugins/governance/version b/plugins/governance/version index 721b9931f4..5596554988 100644 --- a/plugins/governance/version +++ b/plugins/governance/version @@ -1 +1 @@ -1.4.8 \ No newline at end of file +1.4.9 \ No newline at end of file diff --git a/plugins/jsonparser/changelog.md b/plugins/jsonparser/changelog.md index e69de29bb2..66f23e0307 100644 --- a/plugins/jsonparser/changelog.md +++ b/plugins/jsonparser/changelog.md @@ -0,0 +1 @@ +- chore: updated core version to 1.3.8 and framework version to 1.2.8 \ No newline at end 
of file diff --git a/plugins/jsonparser/version b/plugins/jsonparser/version index be05bba982..721b9931f4 100644 --- a/plugins/jsonparser/version +++ b/plugins/jsonparser/version @@ -1 +1 @@ -1.4.7 +1.4.8 \ No newline at end of file diff --git a/plugins/logging/changelog.md b/plugins/logging/changelog.md index e69de29bb2..66f23e0307 100644 --- a/plugins/logging/changelog.md +++ b/plugins/logging/changelog.md @@ -0,0 +1 @@ +- chore: updated core version to 1.3.8 and framework version to 1.2.8 \ No newline at end of file diff --git a/plugins/logging/version b/plugins/logging/version index be05bba982..721b9931f4 100644 --- a/plugins/logging/version +++ b/plugins/logging/version @@ -1 +1 @@ -1.4.7 +1.4.8 \ No newline at end of file diff --git a/plugins/maxim/changelog.md b/plugins/maxim/changelog.md index e69de29bb2..fee2201bf1 100644 --- a/plugins/maxim/changelog.md +++ b/plugins/maxim/changelog.md @@ -0,0 +1 @@ +- chore: updated core version to 1.3.8 and framework version to 1.2.8 diff --git a/plugins/maxim/version b/plugins/maxim/version index f01291b87f..fa5512aeca 100644 --- a/plugins/maxim/version +++ b/plugins/maxim/version @@ -1 +1 @@ -1.5.7 +1.5.8 \ No newline at end of file diff --git a/plugins/mocker/changelog.md b/plugins/mocker/changelog.md index e69de29bb2..66f23e0307 100644 --- a/plugins/mocker/changelog.md +++ b/plugins/mocker/changelog.md @@ -0,0 +1 @@ +- chore: updated core version to 1.3.8 and framework version to 1.2.8 \ No newline at end of file diff --git a/plugins/mocker/version b/plugins/mocker/version index be05bba982..721b9931f4 100644 --- a/plugins/mocker/version +++ b/plugins/mocker/version @@ -1 +1 @@ -1.4.7 +1.4.8 \ No newline at end of file diff --git a/plugins/otel/changelog.md b/plugins/otel/changelog.md index e69de29bb2..66f23e0307 100644 --- a/plugins/otel/changelog.md +++ b/plugins/otel/changelog.md @@ -0,0 +1 @@ +- chore: updated core version to 1.3.8 and framework version to 1.2.8 \ No newline at end of file diff --git 
a/plugins/otel/version b/plugins/otel/version index 2bf1ca5f54..db15278970 100644 --- a/plugins/otel/version +++ b/plugins/otel/version @@ -1 +1 @@ -1.1.7 +1.1.8 \ No newline at end of file diff --git a/plugins/semanticcache/changelog.md b/plugins/semanticcache/changelog.md index e69de29bb2..66f23e0307 100644 --- a/plugins/semanticcache/changelog.md +++ b/plugins/semanticcache/changelog.md @@ -0,0 +1 @@ +- chore: updated core version to 1.3.8 and framework version to 1.2.8 \ No newline at end of file diff --git a/plugins/semanticcache/version b/plugins/semanticcache/version index be05bba982..721b9931f4 100644 --- a/plugins/semanticcache/version +++ b/plugins/semanticcache/version @@ -1 +1 @@ -1.4.7 +1.4.8 \ No newline at end of file diff --git a/plugins/telemetry/changelog.md b/plugins/telemetry/changelog.md index e69de29bb2..66f23e0307 100644 --- a/plugins/telemetry/changelog.md +++ b/plugins/telemetry/changelog.md @@ -0,0 +1 @@ +- chore: updated core version to 1.3.8 and framework version to 1.2.8 \ No newline at end of file diff --git a/plugins/telemetry/main.go b/plugins/telemetry/main.go index d47dae53c9..b1b4898c5a 100644 --- a/plugins/telemetry/main.go +++ b/plugins/telemetry/main.go @@ -320,45 +320,46 @@ func (p *PrometheusPlugin) PostHook(ctx *schemas.BifrostContext, result *schemas customerID := getStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-customer-id")) customerName := getStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-customer-name")) - // Calculate cost and record metrics in a separate goroutine to avoid blocking the main thread - go func() { - labelValues := map[string]string{ - "provider": string(provider), - "model": model, - "method": string(requestType), - "virtual_key_id": virtualKeyID, - "virtual_key_name": virtualKeyName, - "selected_key_id": selectedKeyID, - "selected_key_name": selectedKeyName, - "number_of_retries": strconv.Itoa(numberOfRetries), - "fallback_index": strconv.Itoa(fallbackIndex), - 
"team_id": teamID, - "team_name": teamName, - "customer_id": customerID, - "customer_name": customerName, - } + // Extract ALL context values BEFORE spawning the goroutine. + labelValues := map[string]string{ + "provider": string(provider), + "model": model, + "method": string(requestType), + "virtual_key_id": virtualKeyID, + "virtual_key_name": virtualKeyName, + "selected_key_id": selectedKeyID, + "selected_key_name": selectedKeyName, + "number_of_retries": strconv.Itoa(numberOfRetries), + "fallback_index": strconv.Itoa(fallbackIndex), + "team_id": teamID, + "team_name": teamName, + "customer_id": customerID, + "customer_name": customerName, + } - // Get all prometheus labels from context - for _, key := range p.customLabels { - if value := ctx.Value(schemas.BifrostContextKey(key)); value != nil { - if strValue, ok := value.(string); ok { - labelValues[key] = strValue - } + // Get all custom prometheus labels from context BEFORE the goroutine + for _, key := range p.customLabels { + if value := ctx.Value(schemas.BifrostContextKey(key)); value != nil { + if strValue, ok := value.(string); ok { + labelValues[key] = strValue } } + } + + // Get label values in the correct order (cache_type will be handled separately for cache hits) + promLabelValues := getPrometheusLabelValues(append(p.defaultBifrostLabels, p.customLabels...), labelValues) - // Get label values in the correct order (cache_type will be handled separately for cache hits) - promLabelValues := getPrometheusLabelValues(append(p.defaultBifrostLabels, p.customLabels...), labelValues) + // Extract stream end indicator BEFORE the goroutine + streamEndIndicatorValue := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator) + isFinalChunk, hasFinalChunkIndicator := streamEndIndicatorValue.(bool) + // Calculate cost and record metrics in a separate goroutine to avoid blocking the main thread + go func() { // For streaming requests, handle per-token metrics for intermediate chunks if 
bifrost.IsStreamRequestType(requestType) { - // Determine if this is the final chunk - streamEndIndicatorValue := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator) - isFinalChunk, ok := streamEndIndicatorValue.(bool) - // For intermediate chunks, record per-token metrics and exit. // The final chunk will fall through to record full request metrics. - if !ok || !isFinalChunk { + if !hasFinalChunkIndicator || !isFinalChunk { // Record metrics for the first token if result != nil { extraFields := result.GetExtraFields() diff --git a/plugins/telemetry/version b/plugins/telemetry/version index 721b9931f4..5596554988 100644 --- a/plugins/telemetry/version +++ b/plugins/telemetry/version @@ -1 +1 @@ -1.4.8 \ No newline at end of file +1.4.9 \ No newline at end of file diff --git a/tests/integrations/typescript/config.json b/tests/integrations/typescript/config.json index a927698d52..cab65df51d 100644 --- a/tests/integrations/typescript/config.json +++ b/tests/integrations/typescript/config.json @@ -4,6 +4,7 @@ "openai": { "keys": [ { + "name": "OpenAI API Key", "value": "env.OPENAI_API_KEY", "weight": 1, "use_for_batch_api": true @@ -16,6 +17,7 @@ "anthropic": { "keys": [ { + "name": "Anthropic API Key", "value": "env.ANTHROPIC_API_KEY", "weight": 1, "use_for_batch_api": true @@ -28,6 +30,7 @@ "gemini": { "keys": [ { + "name": "Gemini API Key", "value": "env.GEMINI_API_KEY", "weight": 1, "use_for_batch_api": true @@ -40,7 +43,8 @@ "vertex": { "keys": [ { - "vertex_key_config": { + "name": "Vertex API Key", + "vertex_key_config": { "project_id": "env.GOOGLE_PROJECT_ID", "region": "env.GOOGLE_LOCATION" }, @@ -54,6 +58,7 @@ "mistral": { "keys": [ { + "name": "Mistral API Key", "value": "env.MISTRAL_API_KEY", "weight": 1 } @@ -65,6 +70,7 @@ "cohere": { "keys": [ { + "name": "Cohere API Key", "value": "env.COHERE_API_KEY", "weight": 1 } @@ -76,6 +82,7 @@ "groq": { "keys": [ { + "name": "Groq API Key", "value": "env.GROQ_API_KEY", "weight": 1 } @@ -87,6 +94,7 @@ 
"perplexity": { "keys": [ { + "name": "Perplexity API Key", "value": "env.PERPLEXITY_API_KEY", "weight": 1 } @@ -98,6 +106,7 @@ "cerebras": { "keys": [ { + "name": "Cerebras API Key", "value": "env.CEREBRAS_API_KEY", "weight": 1 } @@ -109,6 +118,7 @@ "openrouter": { "keys": [ { + "name": "OpenRouter API Key", "value": "env.OPENROUTER_API_KEY", "weight": 1 } @@ -120,6 +130,7 @@ "azure": { "keys": [ { + "name": "Azure OpenAI API Key", "value": "env.AZURE_OPENAI_API_KEY", "azure_key_config": { "endpoint": "env.AZURE_OPENAI_ENDPOINT", @@ -135,6 +146,7 @@ "bedrock": { "keys": [ { + "name": "Bedrock API Key", "bedrock_key_config": { "access_key": "env.AWS_ACCESS_KEY_ID", "secret_key": "env.AWS_SECRET_ACCESS_KEY", diff --git a/tests/integrations/typescript/tests/test-anthropic.test.ts b/tests/integrations/typescript/tests/test-anthropic.test.ts index 0db9dcb2a2..611b4e00cf 100644 --- a/tests/integrations/typescript/tests/test-anthropic.test.ts +++ b/tests/integrations/typescript/tests/test-anthropic.test.ts @@ -293,6 +293,68 @@ describe('Anthropic SDK Integration Tests', () => { }) }) + // ============================================================================ + // Streaming Client Disconnect Tests + // ============================================================================ + + describe('Streaming Chat - Client Disconnect', () => { + it('should handle client disconnect mid-stream', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + const abortController = new AbortController() + + // Request a longer response to ensure we have time to abort mid-stream + const stream = client.messages.stream({ + model, + max_tokens: 1000, + messages: [ + { + role: 'user', + content: 'Write a detailed essay about the history of computing, including at least 10 paragraphs.', + }, + ], + }, { + signal: abortController.signal, + }) + + let chunkCount = 0 + let content = '' + let wasAborted = false + + 
try { + for await (const event of stream) { + chunkCount++ + if (event.type === 'content_block_delta' && event.delta.type === 'text_delta') { + content += event.delta.text + } + + // Abort after receiving a few chunks + if (chunkCount >= 5) { + abortController.abort() + } + } + } catch (error) { + wasAborted = true + expect(error).toBeDefined() + // The error should be an AbortError or contain abort-related message + const errorMessage = error instanceof Error ? error.message.toLowerCase() : String(error).toLowerCase() + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + // Verify we received some content before aborting + expect(chunkCount).toBeGreaterThanOrEqual(5) + expect(content.length).toBeGreaterThan(0) + expect(wasAborted).toBe(true) + console.log(`✅ Streaming client disconnect passed for anthropic/${model} (${chunkCount} chunks before abort)`) + }) + }) + // ============================================================================ // Tool Calling Tests // ============================================================================ @@ -891,6 +953,71 @@ describe('Anthropic SDK Integration Tests', () => { }) }) + describe('Extended Thinking Streaming - Client Disconnect', () => { + it('should handle client disconnect mid-stream during extended thinking', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'thinking') + + if (!model) { + console.log('⚠️ Skipping thinking streaming disconnect test: No thinking model configured') + return + } + + const abortController = new AbortController() + + try { + // Use type assertion for beta thinking feature + const stream = client.messages.stream({ + model, + max_tokens: 5000, + thinking: { + type: 'enabled', + budget_tokens: 3000, + }, + messages: [ + { + role: 'user', + 
content: 'Solve this complex problem step by step: A train leaves Station A at 8:00 AM traveling at 60 mph. Another train leaves Station B, 300 miles away, at 9:00 AM traveling toward Station A at 80 mph. At what time will they meet? Show all your detailed reasoning.', + }, + ], + } as never, { + signal: abortController.signal, + }) + + let chunkCount = 0 + let wasAborted = false + + try { + for await (const event of stream) { + chunkCount++ + + // Abort after receiving a few chunks + if (chunkCount >= 10) { + abortController.abort() + } + } + } catch (error) { + wasAborted = true + expect(error).toBeDefined() + const errorMessage = error instanceof Error ? error.message.toLowerCase() : String(error).toLowerCase() + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + expect(chunkCount).toBeGreaterThanOrEqual(10) + expect(wasAborted).toBe(true) + console.log(`✅ Extended thinking streaming client disconnect passed for anthropic/${model} (${chunkCount} chunks before abort)`) + } catch (error) { + console.log(`⚠️ Extended thinking streaming disconnect test skipped: ${error instanceof Error ? 
error.message : 'Unknown error'}`) + } + }) + }) + // ============================================================================ // Files API Tests // ============================================================================ diff --git a/tests/integrations/typescript/tests/test-bedrock.test.ts b/tests/integrations/typescript/tests/test-bedrock.test.ts index 04e353bead..b82d04c7b4 100644 --- a/tests/integrations/typescript/tests/test-bedrock.test.ts +++ b/tests/integrations/typescript/tests/test-bedrock.test.ts @@ -321,6 +321,84 @@ describe('Bedrock SDK Integration Tests', () => { ) }) + // ============================================================================ + // Streaming Client Disconnect Tests + // ============================================================================ + + describe('Streaming Chat - Client Disconnect', () => { + const testCases = getCrossProviderParamsWithVkForScenario('streaming', ['bedrock']) + + it.each(testCases)( + 'should handle client disconnect mid-stream - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for streaming') + return + } + + const client = getBedrockRuntimeClient() + const abortController = new AbortController() + + // Request a longer response to ensure we have time to abort mid-stream + const messages = convertToBedrockMessages([ + { role: 'user', content: 'Write a detailed essay about the history of computing, including at least 10 paragraphs.' 
}, + ]) + const modelId = formatProviderModel(provider, model) + + const command = new ConverseStreamCommand({ + modelId, + messages, + inferenceConfig: { maxTokens: 1000 }, + }) + + const response = await client.send(command, { + abortSignal: abortController.signal, + }) + + let chunkCount = 0 + let content = '' + let wasAborted = false + + try { + if (response.stream) { + for await (const event of response.stream) { + chunkCount++ + if (event.contentBlockDelta) { + const delta = event.contentBlockDelta.delta + if (delta && 'text' in delta && delta.text) { + content += delta.text + } + } + + // Abort after receiving a few chunks + if (chunkCount >= 5) { + abortController.abort() + } + } + } + } catch (error) { + wasAborted = true + expect(error).toBeDefined() + // The error should be an AbortError or contain abort-related message + const errorMessage = error instanceof Error ? error.message.toLowerCase() : String(error).toLowerCase() + const errorName = (error as { name?: string })?.name?.toLowerCase() || '' + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + errorName.includes('abort') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + // Verify we received some content before aborting + expect(chunkCount).toBeGreaterThanOrEqual(5) + expect(content.length).toBeGreaterThan(0) + expect(wasAborted).toBe(true) + console.log(`✅ Streaming client disconnect passed for ${modelId} (${chunkCount} chunks before abort)`) + } + ) + }) + // ============================================================================ // Tool Calling Tests // ============================================================================ diff --git a/tests/integrations/typescript/tests/test-langchain.test.ts b/tests/integrations/typescript/tests/test-langchain.test.ts index ae92920351..3ff6fdd369 100644 --- a/tests/integrations/typescript/tests/test-langchain.test.ts +++ 
b/tests/integrations/typescript/tests/test-langchain.test.ts @@ -205,6 +205,68 @@ describe('LangChain.js Integration Tests', () => { }) }) + describe('Streaming Chat - Client Disconnect', () => { + it('should handle client disconnect mid-stream', async () => { + if (skipTests) return + + const baseUrl = getIntegrationUrl('openai') + const apiKey = hasApiKey('openai') ? getApiKey('openai') : 'dummy-key' + const modelName = getProviderModel('openai', 'chat') + + // Create model with longer max tokens for a longer response + const model = new ChatOpenAI({ + modelName, + openAIApiKey: apiKey, + configuration: { + baseURL: baseUrl, + }, + maxTokens: 1000, + timeout: 300000, + }) + + const abortController = new AbortController() + const messages = convertToLangChainMessages([ + { role: 'user', content: 'Write a detailed essay about the history of computing, including at least 10 paragraphs.' }, + ]) + + const stream = await model.stream(messages, { + signal: abortController.signal, + }) + + let chunkCount = 0 + let content = '' + let wasAborted = false + + try { + for await (const chunk of stream) { + chunkCount++ + if (chunk.content) { + content += typeof chunk.content === 'string' ? chunk.content : JSON.stringify(chunk.content) + } + + // Abort after receiving a few chunks + if (chunkCount >= 3) { + abortController.abort() + } + } + } catch (error) { + wasAborted = true + expect(error).toBeDefined() + const errorMessage = error instanceof Error ? 
error.message.toLowerCase() : String(error).toLowerCase() + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + expect(chunkCount).toBeGreaterThanOrEqual(3) + expect(content.length).toBeGreaterThan(0) + expect(wasAborted).toBe(true) + console.log(`✅ LangChain OpenAI streaming client disconnect passed (${chunkCount} chunks before abort)`) + }) + }) + describe('Tool Calling', () => { it('should make tool calls', async () => { if (skipTests) return @@ -313,6 +375,66 @@ describe('LangChain.js Integration Tests', () => { }) }) + describe('Streaming Chat - Client Disconnect', () => { + it('should handle client disconnect mid-stream', async () => { + if (skipTests) return + + const baseUrl = getIntegrationUrl('anthropic') + const apiKey = hasApiKey('anthropic') ? getApiKey('anthropic') : 'dummy-key' + const modelName = getProviderModel('anthropic', 'chat') + + // Create model with longer max tokens for a longer response + const model = new ChatAnthropic({ + modelName, + anthropicApiKey: apiKey, + anthropicApiUrl: baseUrl, + maxTokens: 1000, + maxRetries: 3, + }) + + const abortController = new AbortController() + const messages = convertToLangChainMessages([ + { role: 'user', content: 'Write a detailed essay about the history of computing, including at least 10 paragraphs.' }, + ]) + + const stream = await model.stream(messages, { + signal: abortController.signal, + }) + + let chunkCount = 0 + let content = '' + let wasAborted = false + + try { + for await (const chunk of stream) { + chunkCount++ + if (chunk.content) { + content += typeof chunk.content === 'string' ? 
chunk.content : JSON.stringify(chunk.content) + } + + // Abort after receiving a few chunks + if (chunkCount >= 5) { + abortController.abort() + } + } + } catch (error) { + wasAborted = true + expect(error).toBeDefined() + const errorMessage = error instanceof Error ? error.message.toLowerCase() : String(error).toLowerCase() + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + expect(chunkCount).toBeGreaterThanOrEqual(5) + expect(content.length).toBeGreaterThan(0) + expect(wasAborted).toBe(true) + console.log(`✅ LangChain Anthropic streaming client disconnect passed (${chunkCount} chunks before abort)`) + }) + }) + describe('Tool Calling', () => { it('should make tool calls', async () => { if (skipTests) return diff --git a/tests/integrations/typescript/tests/test-openai.test.ts b/tests/integrations/typescript/tests/test-openai.test.ts index 154f260fc6..72bd7cabaa 100644 --- a/tests/integrations/typescript/tests/test-openai.test.ts +++ b/tests/integrations/typescript/tests/test-openai.test.ts @@ -267,6 +267,71 @@ describe('OpenAI SDK Integration Tests', () => { ) }) + // ============================================================================ + // Streaming Client Disconnect Tests + // ============================================================================ + + describe('Streaming Chat - Client Disconnect', () => { + const testCases = getCrossProviderParamsWithVkForScenario('streaming') + + it.each(testCases)( + 'should handle client disconnect mid-stream - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for streaming') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const abortController = new 
AbortController() + + // Request a longer response to ensure we have time to abort mid-stream + const stream = await client.chat.completions.create({ + model: formatProviderModel(provider, model), + messages: [{ role: 'user', content: 'Write a detailed essay about the history of computing, including at least 10 paragraphs.' }], + max_tokens: 1000, + stream: true, + }, { + signal: abortController.signal, + }) + + let chunkCount = 0 + let content = '' + let wasAborted = false + + try { + for await (const chunk of stream) { + chunkCount++ + const delta = chunk.choices[0]?.delta?.content || '' + content += delta + + // Abort after receiving a few chunks + if (chunkCount >= 3) { + abortController.abort() + } + } + } catch (error) { + // Expect an abort error + wasAborted = true + expect(error).toBeDefined() + // The error should be an AbortError or contain abort-related message + const errorMessage = error instanceof Error ? error.message.toLowerCase() : String(error).toLowerCase() + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + // Verify we received some content before aborting + expect(chunkCount).toBeGreaterThanOrEqual(3) + expect(content.length).toBeGreaterThan(0) + expect(wasAborted).toBe(true) + console.log(`✅ Streaming client disconnect passed for ${formatProviderModel(provider, model)} (${chunkCount} chunks before abort)`) + } + ) + }) + // ============================================================================ // Tool Calling Tests // ============================================================================ @@ -1434,6 +1499,74 @@ describe('OpenAI SDK Integration Tests', () => { ) }) + describe('Responses API - Streaming Client Disconnect', () => { + const testCases = getCrossProviderParamsWithVkForScenario('responses') + + it.each(testCases)( + 'should handle client disconnect 
mid-stream - $provider (VK: $vkEnabled)',
+      async ({ provider, model, vkEnabled }: ProviderModelVkParam) => {
+        if (shouldSkipNoProviders({ provider, model, vkEnabled })) {
+          console.log('Skipping: No providers available for responses')
+          return
+        }
+
+        const client = getProviderOpenAIClient(provider, vkEnabled)
+        const abortController = new AbortController()
+        // The responses API is reached through a structural cast; create() resolves
+        // to an async-iterable event stream when `stream: true` is requested.
+        const responses = (client as unknown as {
+          responses: {
+            create: (params: unknown, options?: { signal?: AbortSignal }) => Promise<AsyncIterable<unknown>>
+          }
+        }).responses
+
+        try {
+          const stream = await responses.create({
+            model: formatProviderModel(provider, model),
+            input: 'Write a detailed essay about the history of artificial intelligence, including at least 10 paragraphs covering different eras and breakthroughs.',
+            max_output_tokens: 2000,
+            stream: true,
+          }, {
+            signal: abortController.signal,
+          })
+
+          let chunkCount = 0
+          let content = ''
+          let wasAborted = false
+
+          try {
+            for await (const event of stream as AsyncIterable<{ type?: string; delta?: { text?: string } }>) {
+              chunkCount++
+              if (event.type === 'content_part.delta' || event.type === 'response.output_text.delta') {
+                if (event.delta?.text) {
+                  content += event.delta.text
+                }
+              }
+
+              // Abort after receiving a few chunks
+              if (chunkCount >= 3) {
+                abortController.abort()
+              }
+            }
+          } catch (error) {
+            wasAborted = true
+            expect(error).toBeDefined()
+            const errorMessage = error instanceof Error ?
error.message.toLowerCase() : String(error).toLowerCase() + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + expect(chunkCount).toBeGreaterThanOrEqual(3) + expect(wasAborted).toBe(true) + console.log(`✅ Responses API streaming client disconnect passed for ${formatProviderModel(provider, model)} (${chunkCount} chunks before abort)`) + } catch (error) { + console.log(`⚠️ Responses API streaming client disconnect test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) + describe('Responses API - Streaming With Tools', () => { const testCases = getCrossProviderParamsWithVkForScenario('responses') diff --git a/transports/bifrost-http/handlers/devpprof.go b/transports/bifrost-http/handlers/devpprof.go index 9f613d9602..fd5db1ed6c 100644 --- a/transports/bifrost-http/handlers/devpprof.go +++ b/transports/bifrost-http/handlers/devpprof.go @@ -328,6 +328,11 @@ func getTopAllocations() []AllocationInfo { continue } + // Skip allocations from the profiler itself + if isProfilerFunction(fn.Name, fn.Filename) { + continue + } + key := fn.Name if existing, ok := allocMap[key]; ok { existing.Bytes += sample.Value[allocSpaceIdx] @@ -415,24 +420,35 @@ func (h *DevPprofHandler) getGoroutines(ctx *fasthttp.RequestCtx) { } rawProfile := buf.String() - groups := parseGoroutineProfile(rawProfile) + allGroups := parseGoroutineProfile(rawProfile) - // Calculate summary + // Filter out profiler goroutines and calculate summary + groups := make([]GoroutineGroup, 0, len(allGroups)) summary := GoroutineSummary{} - for i := range groups { - categorizeGoroutine(&groups[i]) + profilerGoroutineCount := 0 - switch groups[i].Category { + for i := range allGroups { + categorizeGoroutine(&allGroups[i]) + + // Skip profiler's own goroutines + if isProfilerGoroutine(&allGroups[i]) { + 
profilerGoroutineCount += allGroups[i].Count + continue + } + + groups = append(groups, allGroups[i]) + + switch allGroups[i].Category { case "background": - summary.Background += groups[i].Count + summary.Background += allGroups[i].Count case "per-request": - summary.PerRequest += groups[i].Count + summary.PerRequest += allGroups[i].Count } - if groups[i].WaitMinutes >= 1 { - summary.LongWaiting += groups[i].Count - if groups[i].Category == "per-request" { - summary.PotentiallyStuck += groups[i].Count + if allGroups[i].WaitMinutes >= 1 { + summary.LongWaiting += allGroups[i].Count + if allGroups[i].Category == "per-request" { + summary.PotentiallyStuck += allGroups[i].Count } } } @@ -453,9 +469,16 @@ func (h *DevPprofHandler) getGoroutines(ctx *fasthttp.RequestCtx) { return groups[i].Count > groups[j].Count }) + // Calculate app goroutines (total minus profiler goroutines) + // Calculate total goroutines from profile snapshot + totalFromProfile := 0 + for _, g := range groups { + totalFromProfile += g.Count + } + response := GoroutineProfile{ Timestamp: time.Now().Format(time.RFC3339), - TotalGoroutines: runtime.NumGoroutine(), + TotalGoroutines: totalFromProfile, Groups: groups, Summary: summary, } @@ -476,18 +499,18 @@ func categorizeGoroutine(g *GoroutineGroup) { // Background goroutines - expected to run forever backgroundPatterns := []string{ - "requestWorker", // Provider queue workers - "collectLoop", // Metrics collector - "cleanupWorker", // Various cleanup workers - "startAccumulatorMapCleanup", // Stream accumulator cleanup - "cleanupOldTraces", // Trace store cleanup - "startCleanup", // Generic cleanup - "monitorLoop", // Health monitor - "StartHeartbeat", // WebSocket heartbeat - "time.Sleep", // Ticker-based workers - "runtime.gopark", // Runtime parking (often tickers) - "sync.(*Cond).Wait", // Condition variable waits - "net/http.(*persistConn)", // HTTP connection pool + "requestWorker", // Provider queue workers + "collectLoop", // Metrics 
collector + "cleanupWorker", // Various cleanup workers + "startAccumulatorMapCleanup", // Stream accumulator cleanup + "cleanupOldTraces", // Trace store cleanup + "startCleanup", // Generic cleanup + "monitorLoop", // Health monitor + "StartHeartbeat", // WebSocket heartbeat + "time.Sleep", // Ticker-based workers + "runtime.gopark", // Runtime parking (often tickers) + "sync.(*Cond).Wait", // Condition variable waits + "net/http.(*persistConn)", // HTTP connection pool "internal/poll.runtime_pollWait", // Network polling } @@ -651,6 +674,41 @@ func parseGoroutineProfile(profile string) []GoroutineGroup { return groups } +// profilerPatterns contains patterns to identify profiler-related code +var profilerPatterns = []string{ + "devpprof", + "pprof.WriteHeapProfile", + "pprof.Lookup", + "profile.Parse", + "MetricsCollector", + "collectLoop", + "getTopAllocations", + "parseGoroutineProfile", + "getGoroutines", + "getCPUSample", +} + +// isProfilerFunction checks if a function belongs to the profiler itself +func isProfilerFunction(funcName, fileName string) bool { + for _, pattern := range profilerPatterns { + if strings.Contains(funcName, pattern) || strings.Contains(fileName, pattern) { + return true + } + } + return false +} + +// isProfilerGoroutine checks if a goroutine belongs to the profiler +func isProfilerGoroutine(g *GoroutineGroup) bool { + stackStr := strings.Join(g.Stack, " ") + for _, pattern := range profilerPatterns { + if strings.Contains(stackStr, pattern) { + return true + } + } + return false +} + // Cleanup stops the metrics collector func (h *DevPprofHandler) Cleanup() { if h.collector != nil { diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index ae6c272135..924cb837ed 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -671,14 +671,12 @@ func (h *CompletionHandler) chatCompletion(ctx *fasthttp.RequestCtx) { SendError(ctx, 
fasthttp.StatusInternalServerError, "Failed to convert context") return } - if req.Stream != nil && *req.Stream { h.handleStreamingChatCompletion(ctx, bifrostChatReq, bifrostCtx, cancel) return } - defer cancel() // Ensure cleanup on function exit - + // Complete the request resp, bifrostErr := h.client.ChatCompletionRequest(bifrostCtx, bifrostChatReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) @@ -1295,6 +1293,7 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, ge } // Note: OpenAI responses API doesn't use [DONE] marker, it ends when the stream closes // Stream completed normally, Bifrost handles cleanup internally + cancel() }) } diff --git a/transports/changelog.md b/transports/changelog.md index e69de29bb2..699ab0a1d4 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -0,0 +1 @@ +- fix: handles client disconnects and server timeouts gracefully for streaming responses \ No newline at end of file diff --git a/transports/version b/transports/version index f9cfd27b77..dac8c45cbe 100644 --- a/transports/version +++ b/transports/version @@ -1 +1 @@ -1.4.0-prerelease8 +1.4.0-prerelease9 \ No newline at end of file diff --git a/ui/app/workspace/logs/page.tsx b/ui/app/workspace/logs/page.tsx index 8468fcf801..c538cb82c7 100644 --- a/ui/app/workspace/logs/page.tsx +++ b/ui/app/workspace/logs/page.tsx @@ -72,7 +72,7 @@ export default function LogsPage() { content_search: parseAsString.withDefault(""), start_time: parseAsInteger.withDefault(DEFAULT_START_TIME), end_time: parseAsInteger.withDefault(DEFAULT_END_TIME), - limit: parseAsInteger.withDefault(50), + limit: parseAsInteger.withDefault(25), // Default fallback, actual value calculated based on table height offset: parseAsInteger.withDefault(0), sort_by: parseAsString.withDefault("timestamp"), order: parseAsString.withDefault("desc"), diff --git a/ui/app/workspace/logs/views/logsTable.tsx b/ui/app/workspace/logs/views/logsTable.tsx index 
85425ae3b9..64c0b8966f 100644 --- a/ui/app/workspace/logs/views/logsTable.tsx +++ b/ui/app/workspace/logs/views/logsTable.tsx @@ -2,10 +2,11 @@ import { Button } from "@/components/ui/button"; import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/ui/table"; +import { useTablePageSize } from "@/hooks/useTablePageSize"; import type { LogEntry, LogFilters, Pagination } from "@/lib/types/logs"; import { ColumnDef, flexRender, getCoreRowModel, SortingState, useReactTable } from "@tanstack/react-table"; import { ChevronLeft, ChevronRight, Pause, RefreshCw, X } from "lucide-react"; -import { useState } from "react"; +import { useEffect, useRef, useState } from "react"; import { LogFilters as LogFiltersComponent } from "./filters"; interface DataTableProps { @@ -42,6 +43,25 @@ export function LogsDataTable({ fetchStats, }: DataTableProps) { const [sorting, setSorting] = useState([{ id: pagination.sort_by, desc: pagination.order === "desc" }]); + const tableContainerRef = useRef(null); + const calculatedPageSize = useTablePageSize(tableContainerRef); + + // Refs to avoid stale closures in the page size effect + const paginationRef = useRef(pagination); + const onPaginationChangeRef = useRef(onPaginationChange); + paginationRef.current = pagination; + onPaginationChangeRef.current = onPaginationChange; + + // Update pagination limit when calculated page size increases (don't reduce on size reduction) + useEffect(() => { + if (calculatedPageSize && calculatedPageSize > paginationRef.current.limit) { + onPaginationChangeRef.current({ + ...paginationRef.current, + limit: calculatedPageSize, + offset: 0, // Reset to first page when page size changes + }); + } + }, [calculatedPageSize]); const handleSortingChange = (updaterOrValue: SortingState | ((old: SortingState) => SortingState)) => { const newSorting = typeof updaterOrValue === "function" ? 
updaterOrValue(sorting) : updaterOrValue; @@ -86,8 +106,8 @@ export function LogsDataTable({ return (
-
- +
+
{table.getHeaderGroups().map((headerGroup) => ( diff --git a/ui/components/devProfiler.tsx b/ui/components/devProfiler.tsx index e0c53fcabe..5d32ce1350 100644 --- a/ui/components/devProfiler.tsx +++ b/ui/components/devProfiler.tsx @@ -1,9 +1,10 @@ 'use client' -import { useGetDevPprofQuery } from '@/lib/store' +import { useGetDevGoroutinesQuery, useGetDevPprofQuery } from '@/lib/store' +import type { GoroutineGroup } from '@/lib/store/apis/devApi' import { isDevelopmentMode } from '@/lib/utils/port' -import { Activity, ChevronDown, ChevronUp, Cpu, HardDrive, X } from 'lucide-react' -import React, { useCallback, useMemo, useState } from 'react' +import { Activity, AlertTriangle, ChevronDown, ChevronRight, ChevronUp, Cpu, EyeOff, HardDrive, RotateCcw, TrendingUp, X } from 'lucide-react' +import React, { useCallback, useEffect, useMemo, useState } from 'react' import { Area, AreaChart, @@ -52,10 +53,282 @@ function truncateFunction (fn: string): string { return last } +// Get category badge color +function getCategoryColor (category: string): string { + switch (category) { + case 'per-request': + return 'text-amber-400 bg-amber-400/10' + case 'background': + return 'text-blue-400 bg-blue-400/10' + default: + return 'text-zinc-400 bg-zinc-400/10' + } +} + +// Extract file path from stack (first line containing .go:) +function getStackFilePath (stack: string[]): string { + for (const line of stack) { + // Match file path like "/path/to/file.go:123" and extract just the path + const match = line.match(/^\s*([^\s]+\.go):\d+/) + if (match) { + return match[1] + } + } + return '' +} + +// Generate a stable ID for a goroutine group +function getGoroutineId (g: GoroutineGroup): string { + return `${g.top_func}::${g.state}::${g.count}::${g.wait_minutes ?? 
0}` +} + +// localStorage key for skipped goroutine file paths +const SKIPPED_GOROUTINE_FILES_KEY = 'devProfiler.skippedGoroutineFiles' + +// Load skipped goroutine file paths from localStorage +function loadSkippedGoroutineFiles (): Set { + if (typeof window === 'undefined') return new Set() + try { + const stored = localStorage.getItem(SKIPPED_GOROUTINE_FILES_KEY) + return stored ? new Set(JSON.parse(stored)) : new Set() + } catch { + return new Set() + } +} + +// Save skipped goroutine file paths to localStorage +function saveSkippedGoroutineFiles (skipped: Set): void { + if (typeof window === 'undefined') return + try { + localStorage.setItem(SKIPPED_GOROUTINE_FILES_KEY, JSON.stringify([...skipped])) + } catch { + // Ignore storage errors + } +} + +// Goroutine Health Section subcomponent +interface GoroutineHealthProps { + goroutineData: { + summary: { + background: number + per_request: number + long_waiting: number + potentially_stuck: number + } + total_goroutines: number + } | undefined + goroutineHealth: 'healthy' | 'warning' | 'critical' + goroutineTrend: { + isGrowing: boolean + growthPercent: number + avg: number + } | null + problemGoroutines: GoroutineGroup[] + expandedGoroutines: Set + toggleGoroutineExpand: (id: string) => void + skippedGoroutines: Set + onSkipGoroutine: (topFunc: string) => void + onClearSkipped: () => void +} + +function GoroutineHealthSection ({ + goroutineData, + goroutineHealth, + goroutineTrend, + problemGoroutines, + expandedGoroutines, + toggleGoroutineExpand, + skippedGoroutines, + onSkipGoroutine, + onClearSkipped, +}: GoroutineHealthProps): React.ReactNode { + if (!goroutineData) return null + + const { summary, total_goroutines } = goroutineData + + return ( +
+ {/* Header with health status */} +
+
+ + Goroutine Health +
+
+ {goroutineTrend?.isGrowing && ( + + + +{goroutineTrend.growthPercent.toFixed(0)}% + + )} + {goroutineHealth === 'critical' && ( + + + Stuck + + )} + {goroutineHealth === 'warning' && ( + + + Long Wait + + )} + {goroutineHealth === 'healthy' && ( + + Healthy + + )} +
+
+ + {/* Summary stats */} +
+
+ Total + {total_goroutines} +
+
+ Background + {summary.background} +
+
+ Per-Request + {summary.per_request} +
+
+ Stuck + 0 ? 'text-red-400' : 'text-zinc-500'}`}> + {summary.potentially_stuck} + +
+
+ + {/* Problem goroutines list */} + {(problemGoroutines.length > 0 || skippedGoroutines.size > 0) && ( +
+
+ Potential Leaks + {skippedGoroutines.size > 0 && ( + + )} +
+ {problemGoroutines.map((g) => { + const gid = getGoroutineId(g) + return ( +
+
toggleGoroutineExpand(gid)} + onKeyDown={(e) => { + if (e.key === 'Enter' || e.key === ' ') { + e.preventDefault() + toggleGoroutineExpand(gid) + } + }} + className="flex w-full cursor-pointer flex-col gap-1 px-2 py-1.5 pr-8 text-left hover:bg-zinc-700/50" + > +
+ {expandedGoroutines.has(gid) ? ( + + ) : ( + + )} + + {truncateFunction(g.top_func)} + +
+
+ + {g.category} + + {g.count}x + {g.wait_minutes != null && ( + {g.wait_minutes}m waiting + )} +
+
+ + {expandedGoroutines.has(gid) && ( +
+
+ State: {g.state} + {g.wait_reason && ( + Wait: {g.wait_reason} + )} +
+
+ {g.stack.slice(0, 10).map((line, j) => ( +
+ {line} +
+ ))} + {g.stack.length > 10 && ( +
+ ... {g.stack.length - 10} more frames +
+ )} +
+
+ )} +
+ )})} + + {problemGoroutines.length === 0 && skippedGoroutines.size > 0 && ( +
+ All potential leaks hidden +
+ )} + {problemGoroutines.length === 0 && skippedGoroutines.size === 0 && (summary.long_waiting > 0 || summary.potentially_stuck > 0) && ( +
+ {summary.long_waiting > 0 && summary.potentially_stuck > 0 + ? `${summary.long_waiting} long-waiting and ${summary.potentially_stuck} stuck goroutines (background workers filtered)` + : summary.long_waiting > 0 + ? `${summary.long_waiting} long-waiting goroutines (background workers filtered)` + : `${summary.potentially_stuck} stuck goroutines (background workers filtered)`} +
+ )} +
+ )} + + {/* No problems message */} + {problemGoroutines.length === 0 && summary.long_waiting === 0 && summary.potentially_stuck === 0 && ( +
+ No goroutine leaks detected +
+ )} +
+ ) +} + export function DevProfiler (): React.ReactNode { const [isVisible, setIsVisible] = useState(true) const [isExpanded, setIsExpanded] = useState(true) const [isDismissed, setIsDismissed] = useState(false) + const [expandedGoroutines, setExpandedGoroutines] = useState>(new Set()) + const [skippedGoroutines, setSkippedGoroutines] = useState>(() => loadSkippedGoroutineFiles()) + + // Sync skipped goroutines to localStorage + useEffect(() => { + saveSkippedGoroutineFiles(skippedGoroutines) + }, [skippedGoroutines]) // Only fetch in development mode and when not dismissed const shouldFetch = isDevelopmentMode() && !isDismissed @@ -65,6 +338,11 @@ export function DevProfiler (): React.ReactNode { skip: !shouldFetch, }) + const { data: goroutineData } = useGetDevGoroutinesQuery(undefined, { + pollingInterval: shouldFetch ? 10000 : 0, // Poll every 10 seconds + skip: !shouldFetch, + }) + // Memoize chart data transformation const memoryChartData = useMemo(() => { if (!data?.history) return [] @@ -84,10 +362,68 @@ export function DevProfiler (): React.ReactNode { })) }, [data?.history]) + // Detect goroutine count trend (growing = potential leak) + const goroutineTrend = useMemo(() => { + if (!data?.history || data.history.length < 5 || !data?.runtime) return null + const recent = data.history.slice(-5) + const avg = recent.reduce((sum, p) => sum + p.goroutines, 0) / recent.length + const current = data.runtime.num_goroutine + const isGrowing = current > avg * 1.1 // 10% above average + const growthPercent = avg > 0 ? 
((current - avg) / avg) * 100 : 0 + return { isGrowing, growthPercent, avg } + }, [data?.history, data?.runtime?.num_goroutine]) + + // Filter problem goroutines (stuck or long-waiting, excluding expected background workers and skipped) + const problemGoroutines = useMemo(() => { + if (!goroutineData?.groups) return [] + return goroutineData.groups + .filter((g) => { + if (!g.wait_minutes || g.wait_minutes < 1) return false + if (g.category === 'background') return false + const filePath = getStackFilePath(g.stack) + if (filePath && skippedGoroutines.has(filePath)) return false + return true + }) + .slice(0, 5) + }, [goroutineData?.groups, skippedGoroutines]) + + // Get goroutine health status + const goroutineHealth = useMemo(() => { + if (!goroutineData?.summary) return 'healthy' + const { potentially_stuck, long_waiting } = goroutineData.summary + if (potentially_stuck > 0) return 'critical' + if (long_waiting > 0) return 'warning' + return 'healthy' + }, [goroutineData?.summary]) + const handleDismiss = useCallback(() => { setIsDismissed(true) }, []) + const toggleGoroutineExpand = useCallback((id: string) => { + setExpandedGoroutines((prev) => { + const next = new Set(prev) + if (next.has(id)) { + next.delete(id) + } else { + next.add(id) + } + return next + }) + }, []) + + const handleSkipGoroutine = useCallback((filePath: string) => { + setSkippedGoroutines((prev) => { + const next = new Set(prev) + next.add(filePath) + return next + }) + }, []) + + const handleClearSkipped = useCallback(() => { + setSkippedGoroutines(new Set()) + }, []) + const handleToggleExpand = useCallback(() => { setIsExpanded((prev) => !prev) }, []) @@ -160,7 +496,7 @@ export function DevProfiler (): React.ReactNode { )} {isExpanded && data && ( -
+
{/* Current Stats */}
@@ -360,7 +696,7 @@ export function DevProfiler (): React.ReactNode {
{/* Top Allocations */} -
+
Top Allocations @@ -395,6 +731,19 @@ export function DevProfiler (): React.ReactNode {
+ {/* Goroutine Health */} + + {/* Footer with info */}
CPUs: {data.runtime.num_cpu} | GOMAXPROCS: {data.runtime.gomaxprocs} | diff --git a/ui/hooks/useTablePageSize.ts b/ui/hooks/useTablePageSize.ts new file mode 100644 index 0000000000..6ea4e383e9 --- /dev/null +++ b/ui/hooks/useTablePageSize.ts @@ -0,0 +1,67 @@ +"use client" + +import { RefObject, useCallback, useEffect, useState } from "react" + +const ROW_HEIGHT = 48 // h-12 = 3rem = 48px +const HEADER_HEIGHT = 44 // approximate table header height +const STATUS_ROW_HEIGHT = 48 // the "Listening for logs..." row (h-12) +const MIN_PAGE_SIZE = 5 // minimum items per page + +interface UseTablePageSizeOptions { + debounceMs?: number +} + +export function useTablePageSize ( + containerRef: RefObject, + options: UseTablePageSizeOptions = {} +): number | null { + const { debounceMs = 150 } = options + const [pageSize, setPageSize] = useState(null) + + const calculatePageSize = useCallback((height: number): number => { + const availableHeight = height - HEADER_HEIGHT - STATUS_ROW_HEIGHT + const calculated = Math.floor(availableHeight / ROW_HEIGHT) + return Math.max(calculated, MIN_PAGE_SIZE) + }, []) + + useEffect(() => { + const element = containerRef.current + if (!element) return + + let timeoutId: ReturnType | null = null + + const handleResize = (entries: ResizeObserverEntry[]) => { + const entry = entries[0] + if (!entry) return + + const height = entry.contentRect.height + + if (timeoutId) { + clearTimeout(timeoutId) + } + + timeoutId = setTimeout(() => { + const newPageSize = calculatePageSize(height) + setPageSize(newPageSize) + }, debounceMs) + } + + const resizeObserver = new ResizeObserver(handleResize) + resizeObserver.observe(element) + + // Calculate initial size immediately + const initialHeight = element.getBoundingClientRect().height + if (initialHeight > 0) { + setPageSize(calculatePageSize(initialHeight)) + } + + return () => { + if (timeoutId) { + clearTimeout(timeoutId) + } + resizeObserver.disconnect() + } + }, [containerRef, calculatePageSize, 
debounceMs]) + + return pageSize +} diff --git a/ui/lib/store/apis/devApi.ts b/ui/lib/store/apis/devApi.ts index c3e857ec6b..e8cc281dd3 100644 --- a/ui/lib/store/apis/devApi.ts +++ b/ui/lib/store/apis/devApi.ts @@ -54,6 +54,33 @@ export interface PprofData { history: HistoryPoint[] } +// Goroutine group representing goroutines with same stack trace +export interface GoroutineGroup { + count: number + state: string + wait_reason?: string + wait_minutes?: number + top_func: string + stack: string[] + category: 'background' | 'per-request' | 'unknown' +} + +// Goroutine health summary +export interface GoroutineSummary { + background: number + per_request: number + long_waiting: number + potentially_stuck: number +} + +// Goroutine profile response +export interface GoroutineProfile { + timestamp: string + total_goroutines: number + groups: GoroutineGroup[] + summary: GoroutineSummary +} + export const devApi = baseApi.injectEndpoints({ endpoints: (builder) => ({ // Get dev pprof data - polls every 10 seconds @@ -62,11 +89,19 @@ export const devApi = baseApi.injectEndpoints({ url: '/dev/pprof', }), }), + // Get goroutine profile for leak detection + getDevGoroutines: builder.query({ + query: () => ({ + url: '/dev/pprof/goroutines', + }), + }), }), }) export const { useGetDevPprofQuery, useLazyGetDevPprofQuery, + useGetDevGoroutinesQuery, + useLazyGetDevGoroutinesQuery, } = devApi