diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index 321d63b29e..b7c34a2b1e 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -265,13 +265,7 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB // Marshal the request body jsonData, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Anthropic) } // Create the request with the JSON body @@ -834,25 +828,13 @@ func handleAnthropicStreaming( jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerType) } // Create HTTP request for streaming req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(jsonBody))) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to create HTTP request", - Error: err, - }, - } + return nil, newBifrostOperationError("failed to create HTTP request", err, providerType) } // Set headers @@ -866,27 +848,14 @@ func handleAnthropicStreaming( // Make the request resp, err := httpClient.Do(req) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerType) } // Check for HTTP errors if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - return nil, &schemas.BifrostError{ - IsBifrostError: false, - StatusCode: 
&resp.StatusCode, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("HTTP error from %s: %d", providerType, resp.StatusCode), - Error: fmt.Errorf("%s", string(body)), - }, - } + return nil, newProviderAPIError(fmt.Sprintf("HTTP error from %s: %d", providerType, resp.StatusCode), fmt.Errorf("%s", string(body)), resp.StatusCode, providerType, nil, nil) } // Create response channel @@ -989,7 +958,7 @@ func handleAnthropicStreaming( } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) } default: thought := "" @@ -1027,7 +996,7 @@ func handleAnthropicStreaming( } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) } } @@ -1068,7 +1037,7 @@ func handleAnthropicStreaming( } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) } case "input_json_delta": @@ -1106,7 +1075,7 @@ func handleAnthropicStreaming( } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) } case "thinking_delta": @@ -1137,7 +1106,7 @@ func handleAnthropicStreaming( } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) } case "signature_delta": @@ -1186,7 +1155,7 @@ func handleAnthropicStreaming( } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) + 
processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) } case "message_stop": @@ -1225,7 +1194,7 @@ func handleAnthropicStreaming( } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) return case "ping": @@ -1239,20 +1208,23 @@ func handleAnthropicStreaming( continue } if event.Error != nil { - // Send error through channel before closing - errorResponse := &schemas.BifrostStream{ - BifrostError: &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Type: &event.Error.Type, - Message: event.Error.Message, - }, + bifrostError := &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Type: &event.Error.Type, + Message: event.Error.Message, }, } + processedResponse, processedError := postHookRunner(&ctx, nil, bifrostError) + bifrostError = processedError + select { - case responseChan <- errorResponse: + case responseChan <- &schemas.BifrostStream{ + BifrostResponse: processedResponse, + BifrostError: bifrostError, + }: case <-ctx.Done(): } } @@ -1272,22 +1244,7 @@ func handleAnthropicStreaming( if err := scanner.Err(); err != nil { logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerType, err)) - - // Send scanner error through channel - errorResponse := &schemas.BifrostStream{ - BifrostError: &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "Error reading stream", - Error: err, - }, - }, - } - - select { - case responseChan <- errorResponse: - case <-ctx.Done(): - } + processAndSendError(ctx, postHookRunner, err, responseChan) } }() diff --git a/core/providers/azure.go b/core/providers/azure.go index ba11046add..2be8db4da3 100644 --- a/core/providers/azure.go +++ b/core/providers/azure.go @@ -165,33 +165,17 @@ func (provider *AzureProvider) GetProviderKey() schemas.ModelProvider { // Returns the 
response body or an error if the request fails. func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key schemas.Key, model string) ([]byte, *schemas.BifrostError) { if key.AzureKeyConfig == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "azure key config not set", - }, - } + return nil, newConfigurationError("azure key config not set", schemas.Azure) } // Marshal the request body jsonData, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Azure) } if key.AzureKeyConfig.Endpoint == "" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "endpoint not set", - }, - } + return nil, newConfigurationError("endpoint not set", schemas.Azure) } url := key.AzureKeyConfig.Endpoint @@ -199,12 +183,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody if key.AzureKeyConfig.Deployments != nil { deployment := key.AzureKeyConfig.Deployments[model] if deployment == "" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("deployment if not found for model %s", model), - }, - } + return nil, newConfigurationError(fmt.Sprintf("deployment not found for model %s", model), schemas.Azure) } apiVersion := key.AzureKeyConfig.APIVersion @@ -214,12 +193,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody url = fmt.Sprintf("%s/openai/deployments/%s/%s?api-version=%s", url, deployment, path, *apiVersion) } else { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "deployments not set", - }, - } + return 
nil, newConfigurationError("deployments not set", schemas.Azure) } // Create the request with the JSON body @@ -389,10 +363,7 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, model string, // Returns a BifrostResponse containing the embedding(s) and any error that occurred. func (provider *AzureProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { if len(input.Texts) == 0 { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{Message: "no input text provided for embedding"}, - } + return nil, newConfigurationError("no input text provided for embedding", schemas.Azure) } // Prepare request body - Azure uses deployment-scoped URLs, so model is not needed in body @@ -422,13 +393,7 @@ func (provider *AzureProvider) Embedding(ctx context.Context, model string, key // Parse response var response AzureEmbeddingResponse if err := json.Unmarshal(responseBody, &response); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderResponseUnmarshal, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Azure) } bifrostResponse := &schemas.BifrostResponse{ @@ -464,22 +429,12 @@ func (provider *AzureProvider) Embedding(ctx context.Context, model string, key if num, ok := v[j].(float64); ok { floatArray[j] = float32(num) } else { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("unsupported number type in embedding array: %T", v[j]), - }, - } + return nil, newBifrostOperationError(fmt.Sprintf("unsupported number type in embedding array: %T", v[j]), nil, schemas.Azure) } } embeddings[i] = floatArray default: - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: 
schemas.ErrorField{ - Message: fmt.Sprintf("unsupported embedding type: %T", data.Embedding), - }, - } + return nil, newBifrostOperationError(fmt.Sprintf("unsupported embedding type: %T", data.Embedding), nil, schemas.Azure) } } bifrostResponse.Embedding = embeddings @@ -500,12 +455,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) if key.AzureKeyConfig == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "azure key config not set", - }, - } + return nil, newConfigurationError("azure key config not set", schemas.Azure) } // Merge additional parameters and set stream to true @@ -517,12 +467,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo // Construct Azure-specific URL with deployment if key.AzureKeyConfig.Endpoint == "" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "endpoint not set", - }, - } + return nil, newConfigurationError("endpoint not set", schemas.Azure) } baseURL := key.AzureKeyConfig.Endpoint @@ -531,12 +476,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo if key.AzureKeyConfig.Deployments != nil { deployment := key.AzureKeyConfig.Deployments[model] if deployment == "" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("deployment not found for model %s", model), - }, - } + return nil, newConfigurationError(fmt.Sprintf("deployment not found for model %s", model), schemas.Azure) } apiVersion := key.AzureKeyConfig.APIVersion @@ -546,12 +486,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo fullURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", baseURL, deployment, *apiVersion) } else { - return nil, &schemas.BifrostError{ - 
IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "deployments not set", - }, - } + return nil, newConfigurationError("deployments not set", schemas.Azure) } // Prepare Azure-specific headers diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 8ec4b94dad..ba6de3b957 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -456,12 +456,7 @@ func (provider *BedrockProvider) getTextCompletionResult(result []byte, model st }, nil } - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("invalid model choice: %s", model), - }, - } + return nil, newConfigurationError(fmt.Sprintf("invalid model choice: %s", model), schemas.Bedrock) } // parseBedrockAnthropicMessageToolCallContent parses the content of a tool call message. @@ -753,12 +748,7 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema return body, nil } - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("invalid model choice: %s", model), - }, - } + return nil, newConfigurationError(fmt.Sprintf("invalid model choice: %s", model), schemas.Bedrock) } // GetChatCompletionTools prepares tool specifications for Bedrock's API. 
@@ -849,13 +839,7 @@ func (provider *BedrockProvider) TextCompletion(ctx context.Context, model strin // Parse raw response var rawResponse interface{} if err := json.Unmarshal(body, &rawResponse); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "error parsing raw response", - Error: err, - }, - } + return nil, newBifrostOperationError("error parsing raw response", err, schemas.Bedrock) } bifrostResponse.ExtraFields.RawResponse = rawResponse @@ -1077,13 +1061,7 @@ func signAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken if req.Body != nil { bodyBytes, err := io.ReadAll(req.Body) if err != nil { - return &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "error reading request body", - Error: err, - }, - } + return newBifrostOperationError("error reading request body", err, schemas.Bedrock) } // Restore the body for subsequent reads req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) @@ -1110,13 +1088,7 @@ func signAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken })), ) if err != nil { - return &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to load aws config", - Error: err, - }, - } + return newBifrostOperationError("failed to load aws config", err, schemas.Bedrock) } // Create the AWS signer @@ -1125,24 +1097,12 @@ func signAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken // Get credentials creds, err := cfg.Credentials.Retrieve(context.TODO()) if err != nil { - return &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to retrieve aws credentials", - Error: err, - }, - } + return newBifrostOperationError("failed to retrieve aws credentials", err, schemas.Bedrock) } // Sign the request with AWS Signature V4 if err := signer.SignHTTP(context.TODO(), creds, req, bodyHash, service, region, time.Now()); err != 
nil { - return &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to sign request", - Error: err, - }, - } + return newBifrostOperationError("failed to sign request", err, schemas.Bedrock) } return nil @@ -1157,10 +1117,7 @@ func (provider *BedrockProvider) Embedding(ctx context.Context, model string, ke case strings.HasPrefix(model, "cohere.embed"): return provider.handleCohereEmbedding(ctx, model, key.Value, input, params) default: - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{Message: "embedding is not supported for this Bedrock model"}, - } + return nil, newConfigurationError("embedding is not supported for this Bedrock model", schemas.Bedrock) } } @@ -1168,16 +1125,10 @@ func (provider *BedrockProvider) Embedding(ctx context.Context, model string, ke func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { // Titan Text Embeddings V1/V2 - only supports single text input if len(input.Texts) == 0 { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{Message: "no input text provided for embedding"}, - } + return nil, newConfigurationError("no input text provided for embedding", schemas.Bedrock) } if len(input.Texts) > 1 { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{Message: "Amazon Titan embedding models support only single text input, received multiple texts"}, - } + return nil, newConfigurationError("Amazon Titan embedding models support only single text input, received multiple texts", schemas.Bedrock) } requestBody := map[string]interface{}{ @@ -1187,10 +1138,7 @@ func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model if params != nil { // Titan models do not support the dimensions parameter - they have fixed 
dimensions if params.Dimensions != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{Message: "Amazon Titan embedding models do not support custom dimensions parameter"}, - } + return nil, newConfigurationError("Amazon Titan embedding models do not support custom dimensions parameter", schemas.Bedrock) } if params.ExtraParams != nil { for k, v := range params.ExtraParams { @@ -1212,13 +1160,7 @@ func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model InputTextTokenCount int `json:"inputTextTokenCount"` } if err := json.Unmarshal(rawResponse, &titanResp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "error parsing Titan embedding response", - Error: err, - }, - } + return nil, newBifrostOperationError("error parsing Titan embedding response", err, schemas.Bedrock) } bifrostResponse := &schemas.BifrostResponse{ @@ -1244,10 +1186,7 @@ func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model // handleCohereEmbedding handles embedding requests for Cohere models on Bedrock. 
func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { if len(input.Texts) == 0 { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{Message: "no input text provided for embedding"}, - } + return nil, newConfigurationError("no input text provided for embedding", schemas.Bedrock) } requestBody := map[string]interface{}{ @@ -1272,13 +1211,7 @@ func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, mode Texts []string `json:"texts"` } if err := json.Unmarshal(rawResponse, &cohereResp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "error parsing Cohere embedding response", - Error: err, - }, - } + return nil, newBifrostOperationError("error parsing Cohere embedding response", err, schemas.Bedrock) } // Calculate token usage based on input texts (approximation since Cohere doesn't provide this) @@ -1350,12 +1283,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH } if provider.meta == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "meta config for bedrock is not provided", - }, - } + return nil, newConfigurationError("meta config for bedrock is not provided", schemas.Bedrock) } region := "us-east-1" @@ -1366,25 +1294,13 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH // Create the streaming request jsonBody, jsonErr := json.Marshal(requestBody) if jsonErr != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: jsonErr, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, jsonErr, schemas.Bedrock) } // Create HTTP request for 
streaming req, reqErr := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), strings.NewReader(string(jsonBody))) if reqErr != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "error creating request", - Error: reqErr, - }, - } + return nil, newBifrostOperationError("error creating request", reqErr, schemas.Bedrock) } // Set any extra headers from network config @@ -1396,38 +1312,20 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH return nil, signErr } } else { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "secret access key not set", - }, - } + return nil, newConfigurationError("secret access key not set", schemas.Bedrock) } // Make the request resp, respErr := provider.client.Do(req) if respErr != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: respErr, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderRequest, respErr, schemas.Bedrock) } // Check for HTTP errors if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - return nil, &schemas.BifrostError{ - IsBifrostError: false, - StatusCode: &resp.StatusCode, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("HTTP error from Bedrock: %d", resp.StatusCode), - Error: fmt.Errorf("%s", string(body)), - }, - } + return nil, newProviderAPIError(fmt.Sprintf("HTTP error from Bedrock: %d", resp.StatusCode), fmt.Errorf("%s", string(body)), resp.StatusCode, schemas.Bedrock, nil, nil) } // Create response channel @@ -1536,7 +1434,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) + 
processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) } case delta["toolUse"] != nil: @@ -1593,7 +1491,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) } } @@ -1627,7 +1525,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) } case event["stopReason"] != nil: @@ -1657,7 +1555,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, finalResponse, responseChan) + processAndSendResponse(ctx, postHookRunner, finalResponse, responseChan) return } @@ -1706,13 +1604,14 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, usageResponse, responseChan) + processAndSendResponse(ctx, postHookRunner, usageResponse, responseChan) } } } if err := scanner.Err(); err != nil { provider.logger.Warn(fmt.Sprintf("Error reading Bedrock stream: %v", err)) + processAndSendError(ctx, postHookRunner, err, responseChan) } }() diff --git a/core/providers/cohere.go b/core/providers/cohere.go index eccf4dfc20..8710efd7e3 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -593,10 +593,7 @@ func convertChatHistory(history []struct { // Supports Cohere's embedding models and returns a BifrostResponse containing the embedding(s). 
func (provider *CohereProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { if len(input.Texts) == 0 { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{Message: "no input text provided for embedding"}, - } + return nil, newConfigurationError("no input text provided for embedding", schemas.Cohere) } // Prepare request body with default values @@ -612,12 +609,7 @@ func (provider *CohereProvider) Embedding(ctx context.Context, model string, key // Validate encoding format - Cohere API supports float, int8, uint8, binary, ubinary, but our provider only implements float if params.EncodingFormat != nil { if *params.EncodingFormat != "float" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("Cohere provider currently only supports 'float' encoding format, received: %s", *params.EncodingFormat), - }, - } + return nil, newConfigurationError(fmt.Sprintf("Cohere provider currently only supports 'float' encoding format, received: %s", *params.EncodingFormat), schemas.Cohere) } // Override default with the specified format requestBody["embedding_types"] = []string{*params.EncodingFormat} @@ -634,13 +626,7 @@ func (provider *CohereProvider) Embedding(ctx context.Context, model string, key // Marshal request body jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Cohere) } // Create request @@ -679,25 +665,13 @@ func (provider *CohereProvider) Embedding(ctx context.Context, model string, key // Parse response var cohereResp CohereEmbeddingResponse if err := json.Unmarshal(resp.Body(), 
&cohereResp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "error parsing Cohere embedding response", - Error: err, - }, - } + return nil, newBifrostOperationError("error parsing Cohere embedding response", err, schemas.Cohere) } // Parse raw response for consistent format var rawResponse interface{} if err := json.Unmarshal(resp.Body(), &rawResponse); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "error parsing raw response for Cohere embedding", - Error: err, - }, - } + return nil, newBifrostOperationError("error parsing raw response for Cohere embedding", err, schemas.Cohere) } // Calculate token usage approximation (since Cohere doesn't provide this for embeddings) @@ -732,36 +706,18 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo // Prepare request body using shared function requestBody, err := prepareCohereChatRequest(messages, params, model, true) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to prepare Cohere chat request", - Error: err, - }, - } + return nil, newBifrostOperationError("failed to prepare Cohere chat request", err, schemas.Cohere) } jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Cohere) } // Create HTTP request for streaming req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/v1/chat", strings.NewReader(string(jsonBody))) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to create HTTP request", - Error: err, - }, - } + return nil, 
newBifrostOperationError("failed to create HTTP request", err, schemas.Cohere) } // Set headers @@ -789,14 +745,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - return nil, &schemas.BifrostError{ - IsBifrostError: false, - StatusCode: &resp.StatusCode, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("HTTP error from Cohere: %d", resp.StatusCode), - Error: fmt.Errorf("%s", string(body)), - }, - } + return nil, newProviderAPIError(fmt.Sprintf("HTTP error from Cohere: %d", resp.StatusCode), fmt.Errorf("%s", string(body)), resp.StatusCode, schemas.Cohere, nil, nil) } // Create response channel @@ -870,7 +819,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) + processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) case "text-generation": var textEvent CohereStreamTextEvent @@ -905,7 +854,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, response, responseChan) + processAndSendResponse(ctx, postHookRunner, response, responseChan) case "tool-calls-chunk": var toolEvent CohereStreamToolCallEvent @@ -949,7 +898,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo } // Use utility function to process and send response - ProcessAndSendResponse(ctx, postHookRunner, response, responseChan) + processAndSendResponse(ctx, postHookRunner, response, responseChan) case "stream-end": var stopEvent CohereStreamStopEvent @@ -1005,7 +954,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo } // Use utility function to process and send response - ProcessAndSendResponse(ctx, 
postHookRunner, response, responseChan) + processAndSendResponse(ctx, postHookRunner, response, responseChan) return // End of stream @@ -1018,6 +967,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo if err := scanner.Err(); err != nil { provider.logger.Warn(fmt.Sprintf("Error reading Cohere stream: %v", err)) + processAndSendError(ctx, postHookRunner, err, responseChan) } }() diff --git a/core/providers/groq.go b/core/providers/groq.go index 9bf66a72b0..82b0dc0316 100644 --- a/core/providers/groq.go +++ b/core/providers/groq.go @@ -115,13 +115,7 @@ func (provider *GroqProvider) ChatCompletion(ctx context.Context, model string, jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Groq) } // Create request diff --git a/core/providers/mistral.go b/core/providers/mistral.go index 7c103ed799..59d89b69fe 100644 --- a/core/providers/mistral.go +++ b/core/providers/mistral.go @@ -129,13 +129,7 @@ func (provider *MistralProvider) ChatCompletion(ctx context.Context, model strin jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Mistral) } // Create request @@ -207,10 +201,7 @@ func (provider *MistralProvider) ChatCompletion(ctx context.Context, model strin // Supports Mistral's embedding models and returns a BifrostResponse containing the embedding(s). 
func (provider *MistralProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { if len(input.Texts) == 0 { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{Message: "no input text provided for embedding"}, - } + return nil, newConfigurationError("no input text provided for embedding", schemas.Mistral) } // Prepare request body with base parameters @@ -224,12 +215,7 @@ func (provider *MistralProvider) Embedding(ctx context.Context, model string, ke // Validate encoding format - Mistral API supports multiple formats, but our provider only implements float if params.EncodingFormat != nil { if *params.EncodingFormat != "float" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("Mistral provider currently only supports 'float' encoding format, received: %s", *params.EncodingFormat), - }, - } + return nil, newConfigurationError(fmt.Sprintf("Mistral provider currently only supports 'float' encoding format, received: %s", *params.EncodingFormat), schemas.Mistral) } // Map to Mistral's parameter name requestBody["output_dtype"] = *params.EncodingFormat @@ -250,13 +236,7 @@ func (provider *MistralProvider) Embedding(ctx context.Context, model string, ke jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Mistral) } // Create request @@ -297,25 +277,13 @@ func (provider *MistralProvider) Embedding(ctx context.Context, model string, ke // Parse into structured response var mistralResp MistralEmbeddingResponse if err := json.Unmarshal(rawMessage, &mistralResp); err != nil { - return nil, 
&schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "error parsing Mistral embedding response", - Error: err, - }, - } + return nil, newBifrostOperationError("error parsing Mistral embedding response", err, schemas.Mistral) } // Parse raw response for consistent format var rawResponse interface{} if err := json.Unmarshal(rawMessage, &rawResponse); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "error parsing raw response for Mistral embedding", - Error: err, - }, - } + return nil, newBifrostOperationError("error parsing raw response for Mistral embedding", err, schemas.Mistral) } // Convert data to embeddings array diff --git a/core/providers/ollama.go b/core/providers/ollama.go index 2e190933a3..1f52787af4 100644 --- a/core/providers/ollama.go +++ b/core/providers/ollama.go @@ -116,13 +116,7 @@ func (provider *OllamaProvider) ChatCompletion(ctx context.Context, model string jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Ollama) } // Create request diff --git a/core/providers/openai.go b/core/providers/openai.go index 9064af6fd1..66cda30dea 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -147,13 +147,7 @@ func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model string jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenAI) } // Create request @@ -273,12 +267,7 @@ func prepareOpenAIChatRequest(messages 
[]schemas.BifrostMessage, params *schemas func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { // Validate input texts are not empty if len(input.Texts) == 0 { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "input texts cannot be empty", - }, - } + return nil, newBifrostOperationError("input texts cannot be empty", nil, schemas.OpenAI) } // Prepare request body with base parameters @@ -306,13 +295,7 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenAI) } // Create request @@ -346,13 +329,7 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key // Parse response var response OpenAIResponse if err := json.Unmarshal(resp.Body(), &response); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderResponseUnmarshal, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.OpenAI) } // Create final response @@ -383,12 +360,7 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key if num, ok := v[j].(float64); ok { floatArray[j] = float32(num) } else { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("unsupported number type in embedding array: %T", v[j]), - }, - } + return nil, newBifrostOperationError(fmt.Sprintf("unsupported number type in embedding array: %T", v[j]), nil, 
schemas.OpenAI) } } embeddings[i] = floatArray @@ -396,24 +368,13 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key // Decode base64 string into float32 array decodedData, err := base64.StdEncoding.DecodeString(v) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to decode base64 embedding", - Error: err, - }, - } + return nil, newBifrostOperationError("failed to decode base64 embedding", err, schemas.OpenAI) } // Validate that decoded data length is divisible by 4 (size of float32) const sizeOfFloat32 = 4 if len(decodedData)%sizeOfFloat32 != 0 { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "malformed base64 embedding data: length not divisible by 4", - }, - } + return nil, newBifrostOperationError("malformed base64 embedding data: length not divisible by 4", nil, schemas.OpenAI) } floats := make([]float32, len(decodedData)/sizeOfFloat32) @@ -422,12 +383,7 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key } embeddings[i] = floats default: - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("unsupported embedding type: %T", data.Embedding), - }, - } + return nil, newBifrostOperationError(fmt.Sprintf("unsupported embedding type: %T", data.Embedding), nil, schemas.OpenAI) } } bifrostResponse.Embedding = embeddings @@ -492,25 +448,13 @@ func handleOpenAIStreaming( jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenAI) } // Create HTTP request for streaming req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(jsonBody))) if err != 
nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to create HTTP request", - Error: err, - }, - } + return nil, newBifrostOperationError("failed to create HTTP request", err, schemas.OpenAI) } // Set headers @@ -524,27 +468,12 @@ func handleOpenAIStreaming( // Make the request resp, err := httpClient.Do(req) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, schemas.OpenAI) } // Check for HTTP errors if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - resp.Body.Close() - return nil, &schemas.BifrostError{ - IsBifrostError: false, - StatusCode: &resp.StatusCode, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("HTTP error from %s: %d", providerType, resp.StatusCode), - Error: fmt.Errorf("%s", string(body)), - }, - } + return nil, parseStreamOpenAIError(resp) } // Create response channel @@ -642,7 +571,7 @@ func handleOpenAIStreaming( } response.ExtraFields.Provider = providerType - ProcessAndSendResponse(ctx, postHookRunner, &response, responseChan) + processAndSendResponse(ctx, postHookRunner, &response, responseChan) continue } @@ -660,7 +589,7 @@ func handleOpenAIStreaming( } response.ExtraFields.Provider = providerType - ProcessAndSendResponse(ctx, postHookRunner, &response, responseChan) + processAndSendResponse(ctx, postHookRunner, &response, responseChan) // End stream processing after finish reason break @@ -673,29 +602,14 @@ func handleOpenAIStreaming( } response.ExtraFields.Provider = providerType - ProcessAndSendResponse(ctx, postHookRunner, &response, responseChan) + processAndSendResponse(ctx, postHookRunner, &response, responseChan) } } // Handle scanner errors if err := scanner.Err(); err != nil { logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) - - // Send scanner 
error through channel - errorResponse := &schemas.BifrostStream{ - BifrostError: &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "Error reading stream", - Error: err, - }, - }, - } - - select { - case responseChan <- errorResponse: - case <-ctx.Done(): - } + processAndSendError(ctx, postHookRunner, err, responseChan) } }() @@ -725,13 +639,7 @@ func (provider *OpenAIProvider) Speech(ctx context.Context, model string, key sc jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenAI) } // Create request @@ -810,13 +718,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenAI) } // Prepare OpenAI headers @@ -830,13 +732,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner // Create HTTP request for streaming req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/v1/audio/speech", strings.NewReader(string(jsonBody))) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to create HTTP request", - Error: err, - }, - } + return nil, newBifrostOperationError("failed to create HTTP request", err, schemas.OpenAI) } // Set headers @@ -850,27 +746,12 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner // Make the request resp, err := provider.streamClient.Do(req) if err != nil { - return nil, 
&schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, schemas.OpenAI) } // Check for HTTP errors if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - resp.Body.Close() - return nil, &schemas.BifrostError{ - IsBifrostError: false, - StatusCode: &resp.StatusCode, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("HTTP error from %s: %d", schemas.OpenAI, resp.StatusCode), - Error: fmt.Errorf("%s", string(body)), - }, - } + return nil, parseStreamOpenAIError(resp) } // Create response channel @@ -973,28 +854,13 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner response.ExtraFields.Params = *params } - ProcessAndSendResponse(ctx, postHookRunner, &response, responseChan) + processAndSendResponse(ctx, postHookRunner, &response, responseChan) } // Handle scanner errors if err := scanner.Err(); err != nil { provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) - - // Send scanner error through channel - errorResponse := &schemas.BifrostStream{ - BifrostError: &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "Error reading stream", - Error: err, - }, - }, - } - - select { - case responseChan <- errorResponse: - case <-ctx.Done(): - } + processAndSendError(ctx, postHookRunner, err, responseChan) } }() @@ -1009,132 +875,8 @@ func (provider *OpenAIProvider) Transcription(ctx context.Context, model string, var body bytes.Buffer writer := multipart.NewWriter(&body) - // Add file field - fileWriter, err := writer.CreateFormFile("file", "audio.mp3") // OpenAI requires a filename - if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to create form file", - Error: err, - }, - } - } - if _, err := fileWriter.Write(input.File); err != nil { - return 
nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to write file data", - Error: err, - }, - } - } - - // Add model field - if err := writer.WriteField("model", model); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to write model field", - Error: err, - }, - } - } - - // Add optional fields - if input.Language != nil { - if err := writer.WriteField("language", *input.Language); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to write language field", - Error: err, - }, - } - } - } - - if input.Prompt != nil { - if err := writer.WriteField("prompt", *input.Prompt); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to write prompt field", - Error: err, - }, - } - } - } - - if input.ResponseFormat != nil { - if err := writer.WriteField("response_format", *input.ResponseFormat); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to write response_format field", - Error: err, - }, - } - } - } - - // Note: Temperature and TimestampGranularities can be added via params.ExtraParams if needed - - // Add extra params if provided - if params != nil && params.ExtraParams != nil { - for key, value := range params.ExtraParams { - // Handle array parameters specially for OpenAI's form data format - switch v := value.(type) { - case []string: - // For arrays like timestamp_granularities[] or include[] - for _, item := range v { - if err := writer.WriteField(key+"[]", item); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("failed to write array param %s", key), - Error: err, - }, - } - } - } - case []interface{}: - // Handle generic interface arrays - for _, item := 
range v { - if err := writer.WriteField(key+"[]", fmt.Sprintf("%v", item)); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("failed to write array param %s", key), - Error: err, - }, - } - } - } - default: - // Handle non-array parameters normally - if err := writer.WriteField(key, fmt.Sprintf("%v", value)); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("failed to write extra param %s", key), - Error: err, - }, - } - } - } - } - } - - // Close the multipart writer - if err := writer.Close(); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to close multipart writer", - Error: err, - }, - } + if bifrostErr := parseTranscriptionFormDataBody(writer, input, model, params); bifrostErr != nil { + return nil, bifrostErr } // Create request @@ -1173,25 +915,13 @@ func (provider *OpenAIProvider) Transcription(ctx context.Context, model string, } if err := json.Unmarshal(responseBody, transcribeResponse); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderResponseUnmarshal, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.OpenAI) } // Parse raw response for RawResponse field var rawResponse interface{} if err := json.Unmarshal(responseBody, &rawResponse); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderDecodeRaw, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderDecodeRaw, err, schemas.OpenAI) } // Create final response @@ -1213,150 +943,17 @@ func (provider *OpenAIProvider) Transcription(ctx context.Context, model string, } -// TranscriptionStream handles streaming for transcription. 
-// It creates a multipart form, adds fields, creates HTTP request, and uses shared streaming logic. -// Returns a channel for streaming responses and any error that occurred. func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Create multipart form var body bytes.Buffer writer := multipart.NewWriter(&body) - // Add file field - fileWriter, err := writer.CreateFormFile("file", "audio.mp3") // OpenAI requires a filename - if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to create form file", - Error: err, - }, - } - } - if _, err := fileWriter.Write(input.File); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to write file data", - Error: err, - }, - } - } - - // Add model field - if err := writer.WriteField("model", model); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to write model field", - Error: err, - }, - } - } - - // Add optional fields - if input.Language != nil { - if err := writer.WriteField("language", *input.Language); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to write language field", - Error: err, - }, - } - } - } - - if input.Prompt != nil { - if err := writer.WriteField("prompt", *input.Prompt); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to write prompt field", - Error: err, - }, - } - } - } - - if input.ResponseFormat != nil { - if err := writer.WriteField("response_format", *input.ResponseFormat); err != nil { - return nil, &schemas.BifrostError{ - 
IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to write response_format field", - Error: err, - }, - } - } - } - - // Note: Temperature and TimestampGranularities can be added via params.ExtraParams if needed - - // Add extra params if provided - if params != nil && params.ExtraParams != nil { - for key, value := range params.ExtraParams { - // Handle array parameters specially for OpenAI's form data format - switch v := value.(type) { - case []string: - // For arrays like timestamp_granularities[] or include[] - for _, item := range v { - if err := writer.WriteField(key+"[]", item); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("failed to write array param %s", key), - Error: err, - }, - } - } - } - case []interface{}: - // Handle generic interface arrays - for _, item := range v { - if err := writer.WriteField(key+"[]", fmt.Sprintf("%v", item)); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("failed to write array param %s", key), - Error: err, - }, - } - } - } - default: - // Handle non-array parameters normally - if err := writer.WriteField(key, fmt.Sprintf("%v", value)); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("failed to write extra param %s", key), - Error: err, - }, - } - } - } - } - } - if err := writer.WriteField("stream", "true"); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to write stream field", - Error: err, - }, - } + return nil, newBifrostOperationError("failed to write stream field", err, schemas.OpenAI) } - // Close the multipart writer - if err := writer.Close(); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to close multipart 
writer", - Error: err, - }, - } + if bifrostErr := parseTranscriptionFormDataBody(writer, input, model, params); bifrostErr != nil { + return nil, bifrostErr } // Prepare OpenAI headers @@ -1370,13 +967,7 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHoo // Create HTTP request for streaming req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/v1/audio/transcriptions", &body) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "failed to create HTTP request", - Error: err, - }, - } + return nil, newBifrostOperationError("failed to create HTTP request", err, schemas.OpenAI) } // Set headers @@ -1390,26 +981,12 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHoo // Make the request resp, err := provider.streamClient.Do(req) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, schemas.OpenAI) } // Check for HTTP errors if resp.StatusCode != http.StatusOK { - //TODO: proper openAI error handling - resp.Body.Close() - return nil, &schemas.BifrostError{ - IsBifrostError: false, - StatusCode: &resp.StatusCode, - Error: schemas.ErrorField{ - Message: fmt.Sprintf("HTTP error from %s: %d", schemas.OpenAI, resp.StatusCode), - }, - } + return nil, parseStreamOpenAIError(resp) } // Create response channel @@ -1436,7 +1013,6 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHoo } var jsonData string - // Parse SSE data if strings.HasPrefix(line, "data: ") { jsonData = strings.TrimPrefix(line, "data: ") @@ -1511,32 +1087,89 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHoo response.ExtraFields.Params = *params } - ProcessAndSendResponse(ctx, postHookRunner, &response, 
responseChan) + processAndSendResponse(ctx, postHookRunner, &response, responseChan) } // Handle scanner errors if err := scanner.Err(); err != nil { provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) + processAndSendError(ctx, postHookRunner, err, responseChan) + } + }() - // Send scanner error through channel - errorResponse := &schemas.BifrostStream{ - BifrostError: &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: "Error reading stream", - Error: err, - }, - }, - } + return responseChan, nil +} + +func parseTranscriptionFormDataBody(writer *multipart.Writer, input *schemas.TranscriptionInput, model string, params *schemas.ModelParameters) *schemas.BifrostError { + // Add file field + fileWriter, err := writer.CreateFormFile("file", "audio.mp3") // OpenAI requires a filename + if err != nil { + return newBifrostOperationError("failed to create form file", err, schemas.OpenAI) + } + if _, err := fileWriter.Write(input.File); err != nil { + return newBifrostOperationError("failed to write file data", err, schemas.OpenAI) + } + + // Add model field + if err := writer.WriteField("model", model); err != nil { + return newBifrostOperationError("failed to write model field", err, schemas.OpenAI) + } + + // Add optional fields + if input.Language != nil { + if err := writer.WriteField("language", *input.Language); err != nil { + return newBifrostOperationError("failed to write language field", err, schemas.OpenAI) + } + } - select { - case responseChan <- errorResponse: - case <-ctx.Done(): + if input.Prompt != nil { + if err := writer.WriteField("prompt", *input.Prompt); err != nil { + return newBifrostOperationError("failed to write prompt field", err, schemas.OpenAI) + } + } + + if input.ResponseFormat != nil { + if err := writer.WriteField("response_format", *input.ResponseFormat); err != nil { + return newBifrostOperationError("failed to write response_format field", err, schemas.OpenAI) + } + } + + // Note: 
Temperature and TimestampGranularities can be added via params.ExtraParams if needed + + // Add extra params if provided + if params != nil && params.ExtraParams != nil { + for key, value := range params.ExtraParams { + // Handle array parameters specially for OpenAI's form data format + switch v := value.(type) { + case []string: + // For arrays like timestamp_granularities[] or include[] + for _, item := range v { + if err := writer.WriteField(key+"[]", item); err != nil { + return newBifrostOperationError(fmt.Sprintf("failed to write array param %s", key), err, schemas.OpenAI) + } + } + case []interface{}: + // Handle generic interface arrays + for _, item := range v { + if err := writer.WriteField(key+"[]", fmt.Sprintf("%v", item)); err != nil { + return newBifrostOperationError(fmt.Sprintf("failed to write array param %s", key), err, schemas.OpenAI) + } + } + default: + // Handle non-array parameters normally + if err := writer.WriteField(key, fmt.Sprintf("%v", value)); err != nil { + return newBifrostOperationError(fmt.Sprintf("failed to write extra param %s", key), err, schemas.OpenAI) + } } } - }() + } - return responseChan, nil + // Close the multipart writer + if err := writer.Close(); err != nil { + return newBifrostOperationError("failed to close multipart writer", err, schemas.OpenAI) + } + + return nil } func parseOpenAIError(resp *fasthttp.Response) *schemas.BifrostError { @@ -1557,3 +1190,41 @@ func parseOpenAIError(resp *fasthttp.Response) *schemas.BifrostError { return bifrostErr } + +func parseStreamOpenAIError(resp *http.Response) *schemas.BifrostError { + var errorResp OpenAIError + + statusCode := resp.StatusCode + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + + if err := json.Unmarshal(body, &errorResp); err != nil { + return &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: &statusCode, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderResponseUnmarshal, + Error: err, + }, + } + } + + bifrostErr := 
&schemas.BifrostError{ + IsBifrostError: false, + StatusCode: &statusCode, + Error: schemas.ErrorField{}, + } + + if errorResp.EventID != "" { + bifrostErr.EventID = &errorResp.EventID + } + bifrostErr.Error.Type = &errorResp.Error.Type + bifrostErr.Error.Code = &errorResp.Error.Code + bifrostErr.Error.Message = errorResp.Error.Message + bifrostErr.Error.Param = errorResp.Error.Param + if errorResp.Error.EventID != "" { + bifrostErr.Error.EventID = &errorResp.Error.EventID + } + + return bifrostErr +} diff --git a/core/providers/utils.go b/core/providers/utils.go index 856233139b..8080195b40 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -553,6 +553,48 @@ func newUnsupportedOperationError(operation string, providerName string) *schema } } +// newConfigurationError creates a standardized error for configuration errors. +// This helper reduces code duplication across providers that have configuration errors. +func newConfigurationError(message string, providerType schemas.ModelProvider) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + Provider: providerType, + Error: schemas.ErrorField{ + Message: message, + }, + } +} + +// newBifrostOperationError creates a standardized error for bifrost operation errors. +// This helper reduces code duplication across providers that have bifrost operation errors. +func newBifrostOperationError(message string, err error, providerType schemas.ModelProvider) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: true, + Provider: providerType, + Error: schemas.ErrorField{ + Message: message, + Error: err, + }, + } +} + +// newProviderAPIError creates a standardized error for provider API errors. +// This helper reduces code duplication across providers that have provider API errors. 
+func newProviderAPIError(message string, err error, statusCode int, providerType schemas.ModelProvider, errorType *string, eventID *string) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + Provider: providerType, + StatusCode: &statusCode, + Type: errorType, + EventID: eventID, + Error: schemas.ErrorField{ + Message: message, + Error: err, + Type: errorType, + }, + } +} + // approximateTokenCount provides a rough approximation of token count for text. // WARNING: This is a best-effort approximation using 1 token per 4 characters. // This heuristic is particularly inaccurate for: @@ -580,7 +622,7 @@ func approximateTokenCount(texts []string) int { // This utility reduces code duplication across streaming implementations by encapsulating // the common pattern of running post hooks, handling errors, and sending responses with // proper context cancellation handling. -func ProcessAndSendResponse( +func processAndSendResponse( ctx context.Context, postHookRunner schemas.PostHookRunner, response *schemas.BifrostResponse, @@ -611,3 +653,33 @@ func ProcessAndSendResponse( return } } + +// processAndSendError handles post-hook processing and sends the error to the channel. +// This utility reduces code duplication across streaming implementations by encapsulating +// the common pattern of running post hooks, handling errors, and sending responses with +// proper context cancellation handling. 
+func processAndSendError( + ctx context.Context, + postHookRunner schemas.PostHookRunner, + err error, + responseChan chan *schemas.BifrostStream, +) { + // Send scanner error through channel + bifrostError := + &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: fmt.Sprintf("Error reading stream: %v", err), + Error: err, + }, + } + processedResponse, processedError := postHookRunner(&ctx, nil, bifrostError) + errorResponse := &schemas.BifrostStream{ + BifrostResponse: processedResponse, + BifrostError: processedError, + } + select { + case responseChan <- errorResponse: + case <-ctx.Done(): + } +} diff --git a/core/providers/vertex.go b/core/providers/vertex.go index 2388f218e7..b1c58170c3 100644 --- a/core/providers/vertex.go +++ b/core/providers/vertex.go @@ -131,12 +131,7 @@ func (provider *VertexProvider) TextCompletion(ctx context.Context, model string // Returns a BifrostResponse containing the completion results or an error if the request fails. 
func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { if key.VertexKeyConfig == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "vertex key config is not set", - }, - } + return nil, newConfigurationError("vertex key config is not set", schemas.Vertex) } // Format messages for Vertex API @@ -166,33 +161,17 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string jsonBody, err := json.Marshal(requestBody) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Vertex) } projectID := key.VertexKeyConfig.ProjectID if projectID == "" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "project ID is not set", - }, - } + return nil, newConfigurationError("project ID is not set", schemas.Vertex) } region := key.VertexKeyConfig.Region if region == "" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "region is not set in meta config", - }, - } + return nil, newConfigurationError("region is not set in meta config", schemas.Vertex) } url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) @@ -222,13 +201,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string if err != nil { // Remove client from pool if auth client creation fails removeVertexClient(key.VertexKeyConfig.AuthCredentials) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "error 
creating auth client",
-			Error:   err,
-		},
-	}
+	return nil, newBifrostOperationError("error creating auth client", err, schemas.Vertex)
 	}
 
 	// Make request
@@ -246,13 +219,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string
 		}
 		// Remove client from pool for non-context errors (could be auth/network issues)
 		removeVertexClient(key.VertexKeyConfig.AuthCredentials)
-		return nil, &schemas.BifrostError{
-			IsBifrostError: false,
-			Error: schemas.ErrorField{
-				Message: "error creating request",
-				Error:   err,
-			},
-		}
+		return nil, newBifrostOperationError("error creating request", err, schemas.Vertex)
 	}
 	defer resp.Body.Close()
 
@@ -260,13 +227,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string
 	// Read response body
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return nil, &schemas.BifrostError{
-			IsBifrostError: true,
-			Error: schemas.ErrorField{
-				Message: "error reading request",
-				Error:   err,
-			},
-		}
+		return nil, newBifrostOperationError("error reading response", err, schemas.Vertex)
 	}
 
 	if resp.StatusCode != http.StatusOK {
@@ -283,33 +244,15 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string
 		if err := json.Unmarshal(body, &openAIErr); err != nil {
 			// Try Vertex error format if OpenAI format fails
 			if err := json.Unmarshal(body, &vertexErr); err != nil {
-				return nil, &schemas.BifrostError{
-					IsBifrostError: true,
-					StatusCode:     &resp.StatusCode,
-					Error: schemas.ErrorField{
-						Message: schemas.ErrProviderResponseUnmarshal,
-						Error:   err,
-					},
-				}
+				return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex)
 			}
 
 			if len(vertexErr) > 0 {
-				return nil, &schemas.BifrostError{
-					StatusCode: &resp.StatusCode,
-					Type:       &vertexErr[0].Error.Status,
-					Error: schemas.ErrorField{
-						Message: vertexErr[0].Error.Message,
-					},
-				}
+				return nil, newProviderAPIError(vertexErr[0].Error.Message, nil, resp.StatusCode, schemas.Vertex, &vertexErr[0].Error.Status, nil)
 			}
 		}
 
-		return 
nil, &schemas.BifrostError{ - StatusCode: &resp.StatusCode, - Error: schemas.ErrorField{ - Message: openAIErr.Error.Message, - }, - } + return nil, newProviderAPIError(openAIErr.Error.Message, nil, resp.StatusCode, schemas.Vertex, nil, nil) } if strings.Contains(model, "claude") { @@ -385,45 +328,24 @@ func (provider *VertexProvider) Embedding(ctx context.Context, model string, key // Returns a channel of BifrostResponse objects for streaming results or an error if the request fails. func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { if key.VertexKeyConfig == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "vertex key config is not set", - }, - } + return nil, newConfigurationError("vertex key config is not set", schemas.Vertex) } projectID := key.VertexKeyConfig.ProjectID if projectID == "" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "project ID is not set", - }, - } + return nil, newConfigurationError("project ID is not set", schemas.Vertex) } region := key.VertexKeyConfig.Region if region == "" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "region is not set in meta config", - }, - } + return nil, newConfigurationError("region is not set in meta config", schemas.Vertex) } client, err := getAuthClient(key) if err != nil { // Remove client from pool if auth client creation fails removeVertexClient(key.VertexKeyConfig.AuthCredentials) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "error creating auth client", - Error: err, - }, - } + return nil, newBifrostOperationError("error creating auth client", err, schemas.Vertex) } 
if strings.Contains(model, "claude") {