Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 25 additions & 68 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,7 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB
// Marshal the request body
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{
Message: schemas.ErrProviderJSONMarshaling,
Error: err,
},
}
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Anthropic)
}

// Create the request with the JSON body
Expand Down Expand Up @@ -834,25 +828,13 @@ func handleAnthropicStreaming(

jsonBody, err := json.Marshal(requestBody)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{
Message: schemas.ErrProviderJSONMarshaling,
Error: err,
},
}
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerType)
}

// Create HTTP request for streaming
req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(jsonBody)))
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{
Message: "failed to create HTTP request",
Error: err,
},
}
return nil, newBifrostOperationError("failed to create HTTP request", err, providerType)
}

// Set headers
Expand All @@ -866,27 +848,14 @@ func handleAnthropicStreaming(
// Make the request
resp, err := httpClient.Do(req)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: schemas.ErrProviderRequest,
Error: err,
},
}
return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerType)
}

// Check for HTTP errors
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, &schemas.BifrostError{
IsBifrostError: false,
StatusCode: &resp.StatusCode,
Error: schemas.ErrorField{
Message: fmt.Sprintf("HTTP error from %s: %d", providerType, resp.StatusCode),
Error: fmt.Errorf("%s", string(body)),
},
}
return nil, newProviderAPIError(fmt.Sprintf("HTTP error from %s: %d", providerType, resp.StatusCode), fmt.Errorf("%s", string(body)), resp.StatusCode, providerType, nil, nil)
}

// Create response channel
Expand Down Expand Up @@ -989,7 +958,7 @@ func handleAnthropicStreaming(
}

// Use utility function to process and send response
ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
}
default:
thought := ""
Expand Down Expand Up @@ -1027,7 +996,7 @@ func handleAnthropicStreaming(
}

// Use utility function to process and send response
ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
}
}

Expand Down Expand Up @@ -1068,7 +1037,7 @@ func handleAnthropicStreaming(
}

// Use utility function to process and send response
ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
}

case "input_json_delta":
Expand Down Expand Up @@ -1106,7 +1075,7 @@ func handleAnthropicStreaming(
}

// Use utility function to process and send response
ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
}

case "thinking_delta":
Expand Down Expand Up @@ -1137,7 +1106,7 @@ func handleAnthropicStreaming(
}

// Use utility function to process and send response
ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
}

case "signature_delta":
Expand Down Expand Up @@ -1186,7 +1155,7 @@ func handleAnthropicStreaming(
}

// Use utility function to process and send response
ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
}

case "message_stop":
Expand Down Expand Up @@ -1225,7 +1194,7 @@ func handleAnthropicStreaming(
}

// Use utility function to process and send response
ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan)
return

case "ping":
Expand All @@ -1239,20 +1208,23 @@ func handleAnthropicStreaming(
continue
}
if event.Error != nil {

// Send error through channel before closing
errorResponse := &schemas.BifrostStream{
BifrostError: &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Type: &event.Error.Type,
Message: event.Error.Message,
},
bifrostError := &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Type: &event.Error.Type,
Message: event.Error.Message,
},
}

processedResponse, processedError := postHookRunner(&ctx, nil, bifrostError)
bifrostError = processedError

select {
case responseChan <- errorResponse:
case responseChan <- &schemas.BifrostStream{
BifrostResponse: processedResponse,
BifrostError: bifrostError,
}:
case <-ctx.Done():
}
}
Expand All @@ -1272,22 +1244,7 @@ func handleAnthropicStreaming(

if err := scanner.Err(); err != nil {
logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerType, err))

// Send scanner error through channel
errorResponse := &schemas.BifrostStream{
BifrostError: &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{
Message: "Error reading stream",
Error: err,
},
},
}

select {
case responseChan <- errorResponse:
case <-ctx.Done():
}
processAndSendError(ctx, postHookRunner, err, responseChan)
}
}()

Expand Down
91 changes: 13 additions & 78 deletions core/providers/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,46 +165,25 @@ func (provider *AzureProvider) GetProviderKey() schemas.ModelProvider {
// Returns the response body or an error if the request fails.
func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key schemas.Key, model string) ([]byte, *schemas.BifrostError) {
if key.AzureKeyConfig == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "azure key config not set",
},
}
return nil, newConfigurationError("azure key config not set", schemas.Azure)
}

// Marshal the request body
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{
Message: schemas.ErrProviderJSONMarshaling,
Error: err,
},
}
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Azure)
}

if key.AzureKeyConfig.Endpoint == "" {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "endpoint not set",
},
}
return nil, newConfigurationError("endpoint not set", schemas.Azure)
}

url := key.AzureKeyConfig.Endpoint

if key.AzureKeyConfig.Deployments != nil {
deployment := key.AzureKeyConfig.Deployments[model]
if deployment == "" {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: fmt.Sprintf("deployment if not found for model %s", model),
},
}
return nil, newConfigurationError(fmt.Sprintf("deployment not found for model %s", model), schemas.Azure)
}

apiVersion := key.AzureKeyConfig.APIVersion
Expand All @@ -214,12 +193,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody

url = fmt.Sprintf("%s/openai/deployments/%s/%s?api-version=%s", url, deployment, path, *apiVersion)
} else {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "deployments not set",
},
}
return nil, newConfigurationError("deployments not set", schemas.Azure)
}

// Create the request with the JSON body
Expand Down Expand Up @@ -389,10 +363,7 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, model string,
// Returns a BifrostResponse containing the embedding(s) and any error that occurred.
func (provider *AzureProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
if len(input.Texts) == 0 {
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{Message: "no input text provided for embedding"},
}
return nil, newBifrostOperationError("no input text provided for embedding", nil, schemas.Azure)
}

// Prepare request body - Azure uses deployment-scoped URLs, so model is not needed in body
Expand Down Expand Up @@ -422,13 +393,7 @@ func (provider *AzureProvider) Embedding(ctx context.Context, model string, key
// Parse response
var response AzureEmbeddingResponse
if err := json.Unmarshal(responseBody, &response); err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{
Message: schemas.ErrProviderResponseUnmarshal,
Error: err,
},
}
return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Azure)
}

bifrostResponse := &schemas.BifrostResponse{
Expand Down Expand Up @@ -464,22 +429,12 @@ func (provider *AzureProvider) Embedding(ctx context.Context, model string, key
if num, ok := v[j].(float64); ok {
floatArray[j] = float32(num)
} else {
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{
Message: fmt.Sprintf("unsupported number type in embedding array: %T", v[j]),
},
}
return nil, newBifrostOperationError(fmt.Sprintf("unsupported number type in embedding array: %T", v[j]), nil, schemas.Azure)
}
}
embeddings[i] = floatArray
default:
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{
Message: fmt.Sprintf("unsupported embedding type: %T", data.Embedding),
},
}
return nil, newBifrostOperationError(fmt.Sprintf("unsupported embedding type: %T", data.Embedding), nil, schemas.Azure)
}
}
bifrostResponse.Embedding = embeddings
Expand All @@ -500,12 +455,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params)

if key.AzureKeyConfig == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "azure key config not set",
},
}
return nil, newConfigurationError("azure key config not set", schemas.Azure)
}

// Merge additional parameters and set stream to true
Expand All @@ -517,12 +467,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo

// Construct Azure-specific URL with deployment
if key.AzureKeyConfig.Endpoint == "" {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "endpoint not set",
},
}
return nil, newConfigurationError("endpoint not set", schemas.Azure)
}

baseURL := key.AzureKeyConfig.Endpoint
Expand All @@ -531,12 +476,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
if key.AzureKeyConfig.Deployments != nil {
deployment := key.AzureKeyConfig.Deployments[model]
if deployment == "" {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: fmt.Sprintf("deployment not found for model %s", model),
},
}
return nil, newConfigurationError(fmt.Sprintf("deployment not found for model %s", model), schemas.Azure)
}

apiVersion := key.AzureKeyConfig.APIVersion
Expand All @@ -546,12 +486,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo

fullURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", baseURL, deployment, *apiVersion)
} else {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "deployments not set",
},
}
return nil, newConfigurationError("deployments not set", schemas.Azure)
}

// Prepare Azure-specific headers
Expand Down
Loading