Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 52 additions & 10 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ var retryableStatusCodes = map[int]bool{
429: true, // Too Many Requests
}

// BifrostContextKey is the string-based type used for context keys declared by Bifrost.
// Declaring a dedicated key type (rather than using a bare string) follows the
// context package's guidance for avoiding collisions with keys from other packages.
type BifrostContextKey string

// BifrostContextKeyRequestType is the context key under which the request type
// is stored. The request handlers attach it with context.WithValue before the
// primary provider is tried.
const BifrostContextKeyRequestType BifrostContextKey = "bifrost-request-type"

// INITIALIZATION

// Init initializes a new Bifrost instance with the given configuration.
Expand Down Expand Up @@ -836,6 +842,14 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR
return nil, err
}

// Handle nil context early to prevent blocking
if ctx == nil {
ctx = bifrost.backgroundCtx
}

// Add request type to context
ctx = context.WithValue(ctx, BifrostContextKeyRequestType, requestType)

// Try the primary provider first
primaryResult, primaryErr := bifrost.tryRequest(req, ctx, requestType)

Expand Down Expand Up @@ -880,6 +894,14 @@ func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.Bi
return nil, err
}

// Handle nil context early to prevent blocking
if ctx == nil {
ctx = bifrost.backgroundCtx
}

// Add request type to context
ctx = context.WithValue(ctx, BifrostContextKeyRequestType, requestType)

// Try the primary provider first
primaryResult, primaryErr := bifrost.tryStreamRequest(req, ctx, requestType)

Expand Down Expand Up @@ -922,11 +944,6 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
return nil, newBifrostError(err)
}

// Handle nil context early to prevent blocking
if ctx == nil {
ctx = bifrost.backgroundCtx
}

// Add MCP tools to request if MCP is configured and requested
if requestType != EmbeddingRequest && requestType != SpeechRequest && bifrost.mcpManager != nil {
req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req)
Expand Down Expand Up @@ -1012,11 +1029,6 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
return nil, newBifrostError(err)
}

// Handle nil context early to prevent blocking
if ctx == nil {
ctx = bifrost.backgroundCtx
}

// Add MCP tools to request if MCP is configured and requested
if requestType != SpeechStreamRequest && requestType != TranscriptionStreamRequest && bifrost.mcpManager != nil {
req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req)
Expand All @@ -1035,6 +1047,36 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
}
return newBifrostMessageChan(resp), nil
}
// Handle short-circuit with stream
if shortCircuit.Stream != nil {
outputStream := make(chan *schemas.BifrostStream)

// Create a post-hook runner closure, because the pipeline object is put back in the pool on defer
pipelinePostHookRunner := func(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
return pipeline.RunPostHooks(ctx, result, err, preCount)
}

go func() {
defer close(outputStream)

for streamMsg := range shortCircuit.Stream {
if streamMsg == nil {
continue
}

// Run post hooks on the stream message
processedResp, processedErr := pipelinePostHookRunner(&ctx, streamMsg.BifrostResponse, streamMsg.BifrostError)

// Send the processed message to the output stream
outputStream <- &schemas.BifrostStream{
BifrostResponse: processedResp,
BifrostError: processedErr,
}
}
}()

return outputStream, nil
}
// Handle short-circuit with error
if shortCircuit.Error != nil {
resp, bifrostErr := pipeline.RunPostHooks(&ctx, nil, shortCircuit.Error, preCount)
Expand Down
24 changes: 17 additions & 7 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,7 @@ func handleAnthropicStreaming(
defer resp.Body.Close()

scanner := bufio.NewScanner(resp.Body)
chunkIndex := -1

// Track minimal state needed for response format
var messageID string
Expand Down Expand Up @@ -908,6 +909,8 @@ func handleAnthropicStreaming(
continue
}

chunkIndex++

// Handle different event types
switch eventType {
case "message_start":
Expand Down Expand Up @@ -958,7 +961,8 @@ func handleAnthropicStreaming(
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: providerType,
Provider: providerType,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -996,7 +1000,8 @@ func handleAnthropicStreaming(
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: providerType,
Provider: providerType,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -1037,7 +1042,8 @@ func handleAnthropicStreaming(
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: providerType,
Provider: providerType,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -1075,7 +1081,8 @@ func handleAnthropicStreaming(
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: providerType,
Provider: providerType,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -1106,7 +1113,8 @@ func handleAnthropicStreaming(
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: providerType,
Provider: providerType,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -1155,7 +1163,8 @@ func handleAnthropicStreaming(
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: providerType,
Provider: providerType,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -1194,7 +1203,8 @@ func handleAnthropicStreaming(
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: providerType,
Provider: providerType,
ChunkIndex: chunkIndex,
},
}

Expand Down
19 changes: 14 additions & 5 deletions core/providers/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -1366,6 +1366,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
buf := make([]byte, 0, 64*1024)
scanner.Buffer(buf, 1024*1024)

chunkIndex := -1

for scanner.Scan() {
line := scanner.Text()

Expand Down Expand Up @@ -1403,6 +1405,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
continue
}

chunkIndex++

// Extract the complete JSON object
jsonStr := jsonData[:jsonEnd]

Expand Down Expand Up @@ -1447,7 +1451,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Bedrock,
Provider: schemas.Bedrock,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -1504,7 +1509,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Bedrock,
Provider: schemas.Bedrock,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -1538,7 +1544,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Bedrock,
Provider: schemas.Bedrock,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -1568,7 +1575,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Bedrock,
Provider: schemas.Bedrock,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -1617,7 +1625,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Bedrock,
Provider: schemas.Bedrock,
ChunkIndex: chunkIndex,
},
}

Expand Down
16 changes: 12 additions & 4 deletions core/providers/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,8 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo
// Create response channel
responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize)

chunkIndex := -1

// Start streaming in a goroutine
go func() {
defer close(responseChan)
Expand Down Expand Up @@ -792,6 +794,8 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo
continue
}

chunkIndex++

switch eventType {
case "stream-start":
var startEvent CohereStreamStartEvent
Expand Down Expand Up @@ -819,7 +823,8 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Cohere,
Provider: schemas.Cohere,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -854,7 +859,8 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo
},
Model: model,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Cohere,
Provider: schemas.Cohere,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -898,7 +904,8 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo
},
Model: model,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Cohere,
Provider: schemas.Cohere,
ChunkIndex: chunkIndex,
},
}

Expand Down Expand Up @@ -954,7 +961,8 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo
},
Model: model,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Cohere,
Provider: schemas.Cohere,
ChunkIndex: chunkIndex,
},
}

Expand Down
Loading
Loading