Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,12 @@ type AnthropicImageContent struct {

// AnthropicProvider implements the Provider interface for Anthropic's Claude API.
type AnthropicProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
streamClient *http.Client // HTTP client for streaming requests
apiVersion string // API version for the provider
networkConfig schemas.NetworkConfig // Network configuration including extra headers
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
streamClient *http.Client // HTTP client for streaming requests
apiVersion string // API version for the provider
networkConfig schemas.NetworkConfig // Network configuration including extra headers
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
}

// anthropicChatResponsePool provides a pool for Anthropic chat response objects.
Expand Down Expand Up @@ -228,11 +229,12 @@ func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger)
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")

return &AnthropicProvider{
logger: logger,
client: client,
streamClient: streamClient,
apiVersion: "2023-06-01",
networkConfig: config.NetworkConfig,
logger: logger,
client: client,
streamClient: streamClient,
apiVersion: "2023-06-01",
networkConfig: config.NetworkConfig,
sendBackRawResponse: config.SendBackRawResponse,
}
}

Expand Down Expand Up @@ -330,7 +332,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model str
response := acquireAnthropicTextResponse()
defer releaseAnthropicTextResponse(response)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
if bifrostErr != nil {
return nil, bifrostErr
}
Expand Down Expand Up @@ -358,11 +360,15 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model str
},
Model: response.Model,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Anthropic,
RawResponse: rawResponse,
Provider: schemas.Anthropic,
},
}

// Set raw response if enabled
if provider.sendBackRawResponse {
bifrostResponse.ExtraFields.RawResponse = rawResponse
}

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}
Expand Down Expand Up @@ -391,7 +397,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model str
response := acquireAnthropicChatResponse()
defer releaseAnthropicChatResponse(response)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -404,8 +410,12 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model str
}

bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{
Provider: schemas.Anthropic,
RawResponse: rawResponse,
Provider: schemas.Anthropic,
}

// Set raw response if enabled
if provider.sendBackRawResponse {
bifrostResponse.ExtraFields.RawResponse = rawResponse
}

if params != nil {
Expand Down
38 changes: 24 additions & 14 deletions core/providers/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ func releaseAzureTextResponse(resp *AzureTextResponse) {

// AzureProvider implements the Provider interface for Azure's OpenAI API.
type AzureProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
streamClient *http.Client // HTTP client for streaming requests
networkConfig schemas.NetworkConfig // Network configuration including extra headers
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
streamClient *http.Client // HTTP client for streaming requests
networkConfig schemas.NetworkConfig // Network configuration including extra headers
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
}

// NewAzureProvider creates a new Azure provider instance.
Expand Down Expand Up @@ -147,10 +148,11 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*A
client = configureProxy(client, config.ProxyConfig, logger)

return &AzureProvider{
logger: logger,
client: client,
streamClient: streamClient,
networkConfig: config.NetworkConfig,
logger: logger,
client: client,
streamClient: streamClient,
networkConfig: config.NetworkConfig,
sendBackRawResponse: config.SendBackRawResponse,
}, nil
}

Expand Down Expand Up @@ -263,7 +265,7 @@ func (provider *AzureProvider) TextCompletion(ctx context.Context, model string,
response := acquireAzureTextResponse()
defer releaseAzureTextResponse(response)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
if bifrostErr != nil {
return nil, bifrostErr
}
Expand Down Expand Up @@ -298,11 +300,15 @@ func (provider *AzureProvider) TextCompletion(ctx context.Context, model string,
SystemFingerprint: response.SystemFingerprint,
Usage: &response.Usage,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Azure,
RawResponse: rawResponse,
Provider: schemas.Azure,
},
}

// Set raw response if enabled
if provider.sendBackRawResponse {
bifrostResponse.ExtraFields.RawResponse = rawResponse
}

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}
Expand Down Expand Up @@ -331,7 +337,7 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, model string,
response := acquireAzureChatResponse()
defer releaseAzureChatResponse(response)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -345,11 +351,15 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, model string,
SystemFingerprint: response.SystemFingerprint,
Usage: &response.Usage,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Azure,
RawResponse: rawResponse,
Provider: schemas.Azure,
},
}

// Set raw response if enabled
if provider.sendBackRawResponse {
bifrostResponse.ExtraFields.RawResponse = rawResponse
}

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}
Expand Down
43 changes: 25 additions & 18 deletions core/providers/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,11 @@ type BedrockStreamMetadataEvent struct {

// BedrockProvider implements the Provider interface for AWS Bedrock.
type BedrockProvider struct {
logger schemas.Logger // Logger for provider operations
client *http.Client // HTTP client for API requests
meta schemas.MetaConfig // Bedrock-specific configuration
networkConfig schemas.NetworkConfig // Network configuration including extra headers
logger schemas.Logger // Logger for provider operations
client *http.Client // HTTP client for API requests
meta schemas.MetaConfig // Bedrock-specific configuration
networkConfig schemas.NetworkConfig // Network configuration including extra headers
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
}

// bedrockChatResponsePool provides a pool for Bedrock response objects.
Expand Down Expand Up @@ -239,10 +240,11 @@ func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) (
}

return &BedrockProvider{
logger: logger,
client: client,
meta: config.MetaConfig,
networkConfig: config.NetworkConfig,
logger: logger,
client: client,
meta: config.MetaConfig,
networkConfig: config.NetworkConfig,
sendBackRawResponse: config.SendBackRawResponse,
}, nil
}

Expand Down Expand Up @@ -836,14 +838,15 @@ func (provider *BedrockProvider) TextCompletion(ctx context.Context, model strin
return nil, err
}

// Parse raw response
var rawResponse interface{}
if err := sonic.Unmarshal(body, &rawResponse); err != nil {
return nil, newBifrostOperationError("error parsing raw response", err, schemas.Bedrock)
// Parse raw response if enabled
if provider.sendBackRawResponse {
var rawResponse interface{}
if err := sonic.Unmarshal(body, &rawResponse); err != nil {
return nil, newBifrostOperationError("error parsing raw response", err, schemas.Bedrock)
}
bifrostResponse.ExtraFields.RawResponse = rawResponse
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

bifrostResponse.ExtraFields.RawResponse = rawResponse

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}
Expand Down Expand Up @@ -955,7 +958,7 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model strin
response := acquireBedrockChatResponse()
defer releaseBedrockChatResponse(response)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
if bifrostErr != nil {
return nil, bifrostErr
}
Expand Down Expand Up @@ -1033,12 +1036,16 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model strin
},
Model: model,
ExtraFields: schemas.BifrostResponseExtraFields{
Latency: &latency,
Provider: schemas.Bedrock,
RawResponse: rawResponse,
Latency: &latency,
Provider: schemas.Bedrock,
},
}

// Set raw response if enabled
if provider.sendBackRawResponse {
bifrostResponse.ExtraFields.RawResponse = rawResponse
}

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}
Expand Down
25 changes: 15 additions & 10 deletions core/providers/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,11 @@ type CohereEmbeddingResponse struct {

// CohereProvider implements the Provider interface for Cohere.
type CohereProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
streamClient *http.Client // HTTP client for streaming requests
networkConfig schemas.NetworkConfig // Network configuration including extra headers
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
streamClient *http.Client // HTTP client for streaming requests
networkConfig schemas.NetworkConfig // Network configuration including extra headers
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
}

// CohereStreamStartEvent represents the start of a stream event.
Expand Down Expand Up @@ -169,10 +170,11 @@ func NewCohereProvider(config *schemas.ProviderConfig, logger schemas.Logger) *C
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")

return &CohereProvider{
logger: logger,
client: client,
streamClient: streamClient,
networkConfig: config.NetworkConfig,
logger: logger,
client: client,
streamClient: streamClient,
networkConfig: config.NetworkConfig,
sendBackRawResponse: config.SendBackRawResponse,
}
}

Expand Down Expand Up @@ -256,7 +258,7 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model string
response := acquireCohereResponse()
defer releaseCohereResponse(response)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
if bifrostErr != nil {
return nil, bifrostErr
}
Expand Down Expand Up @@ -327,10 +329,13 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model string
CompletionTokens: float64Ptr(response.Meta.BilledUnits.OutputTokens),
},
ChatHistory: convertChatHistory(response.ChatHistory),
RawResponse: rawResponse,
},
}

if provider.sendBackRawResponse {
bifrostResponse.ExtraFields.RawResponse = rawResponse
}

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}
Expand Down
Loading