From 91ed8104f148188a8fd58e9fe15e7b824dcecc5d Mon Sep 17 00:00:00 2001 From: Pratham Mishra <99235987+Pratham-Mishra04@users.noreply.github.com> Date: Thu, 10 Jul 2025 23:29:53 +0530 Subject: [PATCH] feat: error handling added for transport integrations with their test cases --- core/bifrost.go | 26 ++ core/schemas/bifrost.go | 13 +- tests/transports-integrations/README.md | 7 +- .../tests/integrations/test_anthropic.py | 19 ++ .../tests/integrations/test_google.py | 17 ++ .../tests/integrations/test_litellm.py | 19 ++ .../tests/integrations/test_openai.py | 18 ++ .../tests/utils/common.py | 276 +++++++++++++++--- .../integrations/anthropic/router.go | 3 + .../integrations/anthropic/types.go | 40 +++ .../bifrost-http/integrations/genai/router.go | 3 + .../bifrost-http/integrations/genai/types.go | 38 +++ .../integrations/litellm/router.go | 14 + .../integrations/openai/router.go | 3 + .../bifrost-http/integrations/openai/types.go | 67 +++++ transports/bifrost-http/integrations/utils.go | 49 ++-- 16 files changed, 551 insertions(+), 61 deletions(-) diff --git a/core/bifrost.go b/core/bifrost.go index 096c8c94dc..2d0552b789 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -192,6 +192,7 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) { // If the primary provider fails, it will try each fallback provider in order until one succeeds. func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { if err := validateRequest(req); err != nil { + err.Provider = req.Provider return nil, err } @@ -202,12 +203,14 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas. } if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { + primaryErr.Provider = req.Provider return nil, primaryErr } // Check if this is a short-circuit error that doesn't allow fallbacks // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { + primaryErr.Provider = req.Provider return nil, primaryErr } @@ -234,6 +237,7 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas. return result, nil } if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { + fallbackErr.Provider = fallback.Provider return nil, fallbackErr } @@ -241,6 +245,8 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas. } } + primaryErr.Provider = req.Provider + // All providers failed, return the original error return nil, primaryErr } @@ -250,6 +256,7 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas. // If the primary provider fails, it will try each fallback provider in order until one succeeds. func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { if err := validateRequest(req); err != nil { + err.Provider = req.Provider return nil, err } @@ -259,9 +266,15 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. return primaryResult, nil } + if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { + primaryErr.Provider = req.Provider + return nil, primaryErr + } + // Check if this is a short-circuit error that doesn't allow fallbacks // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { + primaryErr.Provider = req.Provider return nil, primaryErr } @@ -288,6 +301,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. return result, nil } if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { + fallbackErr.Provider = fallback.Provider return nil, fallbackErr } @@ -295,6 +309,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. } } + primaryErr.Provider = req.Provider + // All providers failed, return the original error return nil, primaryErr } @@ -304,6 +320,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. // If the primary provider fails, it will try each fallback provider in order until one succeeds. func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { if err := validateRequest(req); err != nil { + err.Provider = req.Provider return nil, err } @@ -317,9 +334,15 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.Bifro return primaryResult, nil } + if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { + primaryErr.Provider = req.Provider + return nil, primaryErr + } + // Check if this is a short-circuit error that doesn't allow fallbacks // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { + primaryErr.Provider = req.Provider return nil, primaryErr } @@ -345,6 +368,7 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.Bifro return result, nil } if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { + fallbackErr.Provider = fallback.Provider return nil, fallbackErr } @@ -352,6 +376,8 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.Bifro } } + primaryErr.Provider = req.Provider + // All providers failed, return the original error return nil, primaryErr } diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index a6274deca6..33d12afde4 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -431,12 +431,13 @@ const ( // - AllowFallbacks = &false: Bifrost will return this error immediately, no fallbacks // - AllowFallbacks = nil: Treated as true by default (fallbacks allowed for resilience) type BifrostError struct { - EventID *string `json:"event_id,omitempty"` - Type *string `json:"type,omitempty"` - IsBifrostError bool `json:"is_bifrost_error"` - StatusCode *int `json:"status_code,omitempty"` - Error ErrorField `json:"error"` - AllowFallbacks *bool `json:"-"` // Optional: Controls fallback behavior (nil = true by default) + Provider ModelProvider `json:"-"` + EventID *string `json:"event_id,omitempty"` + Type *string `json:"type,omitempty"` + IsBifrostError bool `json:"is_bifrost_error"` + StatusCode *int `json:"status_code,omitempty"` + Error ErrorField `json:"error"` + AllowFallbacks *bool `json:"-"` // Optional: Controls fallback behavior (nil = true by default) } // ErrorField represents detailed error information. diff --git a/tests/transports-integrations/README.md b/tests/transports-integrations/README.md index ddb6760f8f..0d1be2511e 100644 --- a/tests/transports-integrations/README.md +++ b/tests/transports-integrations/README.md @@ -34,7 +34,7 @@ The Bifrost integration tests use a centralized configuration system that routes ## 📋 Test Categories -Our test suite covers 11 comprehensive scenarios for each integration: +Our test suite covers 12 comprehensive scenarios for each integration: 1. **Simple Chat** - Basic single-message conversations 2. **Multi-turn Conversation** - Conversation history and context retention @@ -47,6 +47,7 @@ Our test suite covers 11 comprehensive scenarios for each integration: 9. **Multiple Images** - Multi-image analysis and comparison 10. **Complex End-to-End** - Comprehensive multimodal workflows 11. **Integration-Specific Features** - Integration-unique capabilities +12. **Error Handling** - Invalid request error processing and propagation ## 📁 Directory Structure @@ -179,6 +180,10 @@ python run_integration_tests.py litellm # Option 3: Using pytest directly pytest tests/integrations/test_openai.py -v + +# Run specific test categories +pytest tests/integrations/ -k "error_handling" -v # Run only error handling tests +pytest tests/integrations/ -k "test_12" -v # Run all 12th test cases (error handling) ``` #### Makefile Commands diff --git a/tests/transports-integrations/tests/integrations/test_anthropic.py b/tests/transports-integrations/tests/integrations/test_anthropic.py index 0e9c8a9907..5e5f4fc66b 100644 --- a/tests/transports-integrations/tests/integrations/test_anthropic.py +++ b/tests/transports-integrations/tests/integrations/test_anthropic.py @@ -35,12 +35,16 @@ MULTIPLE_TOOL_CALL_MESSAGES, IMAGE_URL, BASE64_IMAGE, + INVALID_ROLE_MESSAGES, WEATHER_TOOL, CALCULATOR_TOOL, + ALL_TOOLS, mock_tool_response, assert_valid_chat_response, assert_has_tool_calls, assert_valid_image_response, + assert_valid_error_response, + assert_error_propagation, extract_tool_calls, get_api_key, skip_if_no_api_key, @@ -544,6 +548,21 @@ def test_11_integration_specific_features(self, anthropic_client, test_config): # Should prefer calculator for math question assert tool_calls[0]["name"] == "calculate" + @skip_if_no_api_key("anthropic") + def test_12_error_handling_invalid_roles(self, anthropic_client, test_config): + """Test Case 12: Error handling for invalid roles""" + with pytest.raises(Exception) as exc_info: + anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=INVALID_ROLE_MESSAGES, + max_tokens=100, + ) + + # Verify the error is properly caught and contains role-related information + error = exc_info.value + assert_valid_error_response(error, "tester") + assert_error_propagation(error, "anthropic") + # Additional helper functions specific to Anthropic def extract_anthropic_tool_calls(response: Any) -> List[Dict[str, Any]]: diff --git a/tests/transports-integrations/tests/integrations/test_google.py b/tests/transports-integrations/tests/integrations/test_google.py index 236a2fed19..02b96d2736 100644 --- a/tests/transports-integrations/tests/integrations/test_google.py +++ b/tests/transports-integrations/tests/integrations/test_google.py @@ -31,15 +31,19 @@ MULTIPLE_TOOL_CALL_MESSAGES, IMAGE_URL, BASE64_IMAGE, + INVALID_ROLE_MESSAGES, WEATHER_TOOL, CALCULATOR_TOOL, assert_valid_chat_response, assert_valid_image_response, + assert_valid_error_response, + assert_error_propagation, get_api_key, skip_if_no_api_key, COMPARISON_KEYWORDS, WEATHER_KEYWORDS, LOCATION_KEYWORDS, + GENAI_INVALID_ROLE_CONTENT, ) from ..utils.config_loader import get_model @@ -422,6 +426,19 @@ def test_11_integration_specific_features(self, google_client, test_config): assert_valid_chat_response(response3) + @skip_if_no_api_key("google") + def test_12_error_handling_invalid_roles(self, google_client, test_config): + """Test Case 12: Error handling for invalid roles""" + with pytest.raises(Exception) as exc_info: + google_client.models.generate_content( + model=get_model("google", "chat"), contents=GENAI_INVALID_ROLE_CONTENT + ) + + # Verify the error is properly caught and contains role-related information + error = exc_info.value + assert_valid_error_response(error, "tester") + assert_error_propagation(error, "google") + # Additional helper functions specific to Google GenAI def extract_google_function_calls(response: Any) -> List[Dict[str, Any]]: diff --git a/tests/transports-integrations/tests/integrations/test_litellm.py b/tests/transports-integrations/tests/integrations/test_litellm.py index 0c9ff41e6f..f06e048ffb 100644 --- a/tests/transports-integrations/tests/integrations/test_litellm.py +++ b/tests/transports-integrations/tests/integrations/test_litellm.py @@ -36,13 +36,18 @@ IMAGE_BASE64_MESSAGES, MULTIPLE_IMAGES_MESSAGES, COMPLEX_E2E_MESSAGES, + INVALID_ROLE_MESSAGES, WEATHER_TOOL, CALCULATOR_TOOL, mock_tool_response, assert_valid_chat_response, assert_has_tool_calls, assert_valid_image_response, + assert_valid_error_response, + assert_error_propagation, extract_tool_calls, + get_api_key, + skip_if_no_api_key, COMPARISON_KEYWORDS, WEATHER_KEYWORDS, LOCATION_KEYWORDS, @@ -339,6 +344,20 @@ def test_11_integration_specific_features(self, test_config): assert_valid_chat_response(response3) + def test_12_error_handling_invalid_roles(self, test_config): + """Test Case 12: Error handling for invalid roles""" + with pytest.raises(Exception) as exc_info: + litellm.completion( + model=get_model("litellm", "chat"), + messages=INVALID_ROLE_MESSAGES, + max_tokens=100, + ) + + # Verify the error is properly caught and contains role-related information + error = exc_info.value + assert_valid_error_response(error, "tester") + assert_error_propagation(error, "litellm") + # Additional helper functions specific to LiteLLM def extract_litellm_tool_calls(response: Any) -> List[Dict[str, Any]]: diff --git a/tests/transports-integrations/tests/integrations/test_openai.py b/tests/transports-integrations/tests/integrations/test_openai.py index 09aa81ee51..40ec386db9 100644 --- a/tests/transports-integrations/tests/integrations/test_openai.py +++ b/tests/transports-integrations/tests/integrations/test_openai.py @@ -36,12 +36,15 @@ IMAGE_BASE64_MESSAGES, MULTIPLE_IMAGES_MESSAGES, COMPLEX_E2E_MESSAGES, + INVALID_ROLE_MESSAGES, WEATHER_TOOL, CALCULATOR_TOOL, mock_tool_response, assert_valid_chat_response, assert_has_tool_calls, assert_valid_image_response, + assert_valid_error_response, + assert_error_propagation, extract_tool_calls, get_api_key, skip_if_no_api_key, @@ -389,3 +392,18 @@ def test_11_integration_specific_features(self, openai_client, test_config): ) assert_valid_chat_response(response3) + + @skip_if_no_api_key("openai") + def test_12_error_handling_invalid_roles(self, openai_client, test_config): + """Test Case 12: Error handling for invalid roles""" + with pytest.raises(Exception) as exc_info: + openai_client.chat.completions.create( + model=get_model("openai", "chat"), + messages=INVALID_ROLE_MESSAGES, + max_tokens=100, + ) + + # Verify the error is properly caught and contains role-related information + error = exc_info.value + assert_valid_error_response(error, "tester") + assert_error_propagation(error, "openai") diff --git a/tests/transports-integrations/tests/utils/common.py b/tests/transports-integrations/tests/utils/common.py index 64fea1cfbf..bd36e29095 100644 --- a/tests/transports-integrations/tests/utils/common.py +++ b/tests/transports-integrations/tests/utils/common.py @@ -151,6 +151,75 @@ class Config: }, ] +# Common keyword arrays for flexible assertions +COMPARISON_KEYWORDS = [ + "compare", + "comparison", + "different", + "difference", + "differences", + "both", + "two", + "first", + "second", + "images", + "image", + "versus", + "vs", + "contrast", + "unlike", + "while", + "whereas", +] + +WEATHER_KEYWORDS = [ + "weather", + "temperature", + "sunny", + "cloudy", + "rain", + "snow", + "celsius", + "fahrenheit", + "degrees", + "hot", + "cold", + "warm", + "cool", +] + +LOCATION_KEYWORDS = ["boston", "san francisco", "new york", "city", "location", "place"] + +# Error test data for invalid role testing +INVALID_ROLE_MESSAGES = [ + {"role": "tester", "content": "Hello! This should fail due to invalid role."} +] + +# GenAI-specific invalid role content that passes SDK validation but fails at Bifrost +GENAI_INVALID_ROLE_CONTENT = [ + { + "role": "tester", # Invalid role that should be caught by Bifrost + "parts": [ + {"text": "Hello! This should fail due to invalid role in GenAI format."} + ], + } +] + +# Error keywords for validating error messages +ERROR_KEYWORDS = [ + "invalid", + "error", + "role", + "tester", + "unsupported", + "unknown", + "bad", + "incorrect", + "not allowed", + "not supported", + "forbidden", +] + # Helper Functions def safe_eval_arithmetic(expression: str) -> float: @@ -392,44 +461,178 @@ def assert_valid_image_response(response: Any): ), f"Response should reference the image content. Got: {content}" -# Common keyword arrays for flexible assertions -COMPARISON_KEYWORDS = [ - "compare", - "comparison", - "different", - "difference", - "differences", - "both", - "two", - "first", - "second", - "images", - "image", - "versus", - "vs", - "contrast", - "unlike", - "while", - "whereas", -] +def assert_valid_error_response( + response_or_exception: Any, expected_invalid_role: str = "tester" +): + """ + Assert that an error response or exception properly indicates an invalid role error. -WEATHER_KEYWORDS = [ - "weather", - "temperature", - "sunny", - "cloudy", - "rain", - "snow", - "celsius", - "fahrenheit", - "degrees", - "hot", - "cold", - "warm", - "cool", -] + Args: + response_or_exception: Either an HTTP error response or a raised exception + expected_invalid_role: The invalid role that should be mentioned in the error + """ + error_message = "" + error_type = "" + status_code = None -LOCATION_KEYWORDS = ["boston", "san francisco", "new york", "city", "location", "place"] + # Handle different error response formats + if hasattr(response_or_exception, "response"): + # This is likely a requests.HTTPError or similar + try: + error_data = response_or_exception.response.json() + status_code = response_or_exception.response.status_code + + # Extract error message from various formats + if isinstance(error_data, dict): + if "error" in error_data: + if isinstance(error_data["error"], dict): + error_message = error_data["error"].get( + "message", str(error_data["error"]) + ) + error_type = error_data["error"].get("type", "") + else: + error_message = str(error_data["error"]) + else: + error_message = error_data.get("message", str(error_data)) + else: + error_message = str(error_data) + except: + error_message = str(response_or_exception) + + elif hasattr(response_or_exception, "message"): + # Direct error object + error_message = response_or_exception.message + + elif hasattr(response_or_exception, "args") and response_or_exception.args: + # Exception with args + error_message = str(response_or_exception.args[0]) + + else: + # Fallback to string representation + error_message = str(response_or_exception) + + # Convert to lowercase for case-insensitive matching + error_message_lower = error_message.lower() + error_type_lower = error_type.lower() + + # Validate that error message indicates role-related issue + role_error_indicators = [ + expected_invalid_role.lower(), + "role", + "invalid", + "unsupported", + "unknown", + "not allowed", + "not supported", + "bad request", + "invalid_request", + ] + + has_role_error = any( + indicator in error_message_lower or indicator in error_type_lower + for indicator in role_error_indicators + ) + + assert has_role_error, ( + f"Error message should indicate invalid role '{expected_invalid_role}'. " + f"Got error message: '{error_message}', error type: '{error_type}'" + ) + + # Validate status code if available (should be 4xx for client errors) + if status_code is not None: + assert ( + 400 <= status_code < 500 + ), f"Expected 4xx status code for invalid role error, got {status_code}" + + return True + + +def assert_error_propagation(error_response: Any, integration: str): + """ + Assert that error is properly propagated through Bifrost to the integration. + + Args: + error_response: The error response from the integration + integration: The integration name (openai, anthropic, etc.) + """ + # Check that we got an error response (not a success) + assert error_response is not None, "Should have received an error response" + + # Integration-specific error format validation + if integration.lower() == "openai": + # OpenAI format: should have top-level 'type', 'event_id' and 'error' field with nested structure + if hasattr(error_response, "response"): + error_data = error_response.response.json() + assert "error" in error_data, "OpenAI error should have 'error' field" + assert ( + "type" in error_data + ), "OpenAI error should have top-level 'type' field" + assert ( + "event_id" in error_data + ), "OpenAI error should have top-level 'event_id' field" + assert isinstance( + error_data["type"], str + ), "OpenAI error type should be a string" + assert isinstance( + error_data["event_id"], str + ), "OpenAI error event_id should be a string" + + # Check nested error structure + error_obj = error_data["error"] + assert ( + "message" in error_obj + ), "OpenAI error.error should have 'message' field" + assert "type" in error_obj, "OpenAI error.error should have 'type' field" + assert "code" in error_obj, "OpenAI error.error should have 'code' field" + assert ( + "event_id" in error_obj + ), "OpenAI error.error should have 'event_id' field" + + elif integration.lower() == "anthropic": + # Anthropic format: should have 'type' and 'error' with 'type' and 'message' + if hasattr(error_response, "response"): + error_data = error_response.response.json() + assert "type" in error_data, "Anthropic error should have 'type' field" + # Type field can be empty string if not set in original error + assert isinstance( + error_data["type"], str + ), "Anthropic error type should be a string" + assert "error" in error_data, "Anthropic error should have 'error' field" + assert ( + "type" in error_data["error"] + ), "Anthropic error.error should have 'type' field" + assert ( + "message" in error_data["error"] + ), "Anthropic error.error should have 'message' field" + + elif integration.lower() in ["google", "gemini", "genai"]: + # Gemini format: follows Google API design guidelines with error.code, error.message, error.status + if hasattr(error_response, "response"): + error_data = error_response.response.json() + assert "error" in error_data, "Gemini error should have 'error' field" + + # Check Google API standard error structure + error_obj = error_data["error"] + assert ( + "code" in error_obj + ), "Gemini error.error should have 'code' field (HTTP status code)" + assert isinstance( + error_obj["code"], int + ), "Gemini error.error.code should be an integer" + assert ( + "message" in error_obj + ), "Gemini error.error should have 'message' field" + assert isinstance( + error_obj["message"], str + ), "Gemini error.error.message should be a string" + assert ( + "status" in error_obj + ), "Gemini error.error should have 'status' field" + assert isinstance( + error_obj["status"], str + ), "Gemini error.error.status should be a string" + + return True # Test Categories @@ -447,6 +650,7 @@ class TestCategories: MULTIPLE_IMAGES = "multiple_images" COMPLEX_E2E = "complex_e2e" INTEGRATION_SPECIFIC = "integration_specific" + ERROR_HANDLING = "error_handling" # Environment helpers diff --git a/transports/bifrost-http/integrations/anthropic/router.go b/transports/bifrost-http/integrations/anthropic/router.go index 81d2275997..f205193282 100644 --- a/transports/bifrost-http/integrations/anthropic/router.go +++ b/transports/bifrost-http/integrations/anthropic/router.go @@ -32,6 +32,9 @@ func NewAnthropicRouter(client *bifrost.Bifrost) *AnthropicRouter { ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { return DeriveAnthropicFromBifrostResponse(resp), nil }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveAnthropicErrorFromBifrostError(err) + }, }, } diff --git a/transports/bifrost-http/integrations/anthropic/types.go b/transports/bifrost-http/integrations/anthropic/types.go index 8ebae7008b..9240896d90 100644 --- a/transports/bifrost-http/integrations/anthropic/types.go +++ b/transports/bifrost-http/integrations/anthropic/types.go @@ -93,6 +93,18 @@ type AnthropicUsage struct { OutputTokens int `json:"output_tokens"` } +// AnthropicMessageError represents an Anthropic messages API error response +type AnthropicMessageError struct { + Type string `json:"type"` // always "error" + Error AnthropicMessageErrorStruct `json:"error"` // Error details +} + +// AnthropicMessageErrorStruct represents the error structure of an Anthropic messages API error response +type AnthropicMessageErrorStruct struct { + Type string `json:"type"` // Error type + Message string `json:"message"` // Error message +} + // MarshalJSON implements custom JSON marshalling for MessageContent. // It marshals either ContentStr or ContentBlocks directly without wrapping. func (mc AnthropicContent) MarshalJSON() ([]byte, error) { @@ -464,3 +476,31 @@ func DeriveAnthropicFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *A anthropicResp.Content = content return anthropicResp } + +// DeriveAnthropicErrorFromBifrostError derives a AnthropicMessageError from a BifrostError +func DeriveAnthropicErrorFromBifrostError(bifrostErr *schemas.BifrostError) *AnthropicMessageError { + if bifrostErr == nil { + return nil + } + + // Provide blank strings for nil pointer fields + errorType := "" + if bifrostErr.Type != nil { + errorType = *bifrostErr.Type + } + + // Handle nested error fields with nil checks + errorStruct := AnthropicMessageErrorStruct{ + Type: "", + Message: bifrostErr.Error.Message, + } + + if bifrostErr.Error.Type != nil { + errorStruct.Type = *bifrostErr.Error.Type + } + + return &AnthropicMessageError{ + Type: errorType, + Error: errorStruct, + } +} diff --git a/transports/bifrost-http/integrations/genai/router.go b/transports/bifrost-http/integrations/genai/router.go index 8f0b470a9b..675d489029 100644 --- a/transports/bifrost-http/integrations/genai/router.go +++ b/transports/bifrost-http/integrations/genai/router.go @@ -34,6 +34,9 @@ func NewGenAIRouter(client *bifrost.Bifrost) *GenAIRouter { ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { return DeriveGenAIFromBifrostResponse(resp), nil }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveGeminiErrorFromBifrostError(err) + }, PreCallback: extractAndSetModelFromURL, }, } diff --git a/transports/bifrost-http/integrations/genai/types.go b/transports/bifrost-http/integrations/genai/types.go index 5a71e47bae..01dd15d06f 100644 --- a/transports/bifrost-http/integrations/genai/types.go +++ b/transports/bifrost-http/integrations/genai/types.go @@ -136,6 +136,18 @@ type GeminiChatRequest struct { ResponseModalities []string `json:"responseModalities,omitempty"` } +// GeminiChatRequestError represents a Gemini chat completion error response +type GeminiChatRequestError struct { + Error GeminiChatRequestErrorStruct `json:"error"` // Error details following Google API format +} + +// GeminiChatRequestErrorStruct represents the error structure of a Gemini chat completion error response +type GeminiChatRequestErrorStruct struct { + Code int `json:"code"` // HTTP status code + Message string `json:"message"` // Error message + Status string `json:"status"` // Error status string (e.g., "INVALID_REQUEST") +} + func (r *GeminiChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { provider, model := integrations.ParseModelString(r.Model, schemas.Vertex) @@ -567,6 +579,32 @@ func DeriveGenAIFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *genai return genaiResp } +// DeriveGeminiErrorFromBifrostError derives a GeminiChatRequestError from a BifrostError +func DeriveGeminiErrorFromBifrostError(bifrostErr *schemas.BifrostError) *GeminiChatRequestError { + if bifrostErr == nil { + return nil + } + + code := 500 + status := "" + + if bifrostErr.Error.Type != nil { + status = *bifrostErr.Error.Type + } + + if bifrostErr.StatusCode != nil { + code = *bifrostErr.StatusCode + } + + return &GeminiChatRequestError{ + Error: GeminiChatRequestErrorStruct{ + Code: code, + Message: bifrostErr.Error.Message, + Status: status, + }, + } +} + // isImageMimeType checks if a MIME type represents an image format func isImageMimeType(mimeType string) bool { if mimeType == "" { diff --git a/transports/bifrost-http/integrations/litellm/router.go b/transports/bifrost-http/integrations/litellm/router.go index f8d2c25464..e7403514d5 100644 --- a/transports/bifrost-http/integrations/litellm/router.go +++ b/transports/bifrost-http/integrations/litellm/router.go @@ -140,6 +140,19 @@ func NewLiteLLMRouter(client *bifrost.Bifrost) *LiteLLMRouter { } } + errorConverter := func(err *schemas.BifrostError) interface{} { + switch err.Provider { + case schemas.OpenAI, schemas.Azure: + return openai.DeriveOpenAIErrorFromBifrostError(err) + case schemas.Anthropic: + return anthropic.DeriveAnthropicErrorFromBifrostError(err) + case schemas.Vertex: + return genai.DeriveGeminiErrorFromBifrostError(err) + default: + return err + } + } + routes := []integrations.RouteConfig{} for _, path := range paths { routes = append(routes, integrations.RouteConfig{ @@ -148,6 +161,7 @@ func NewLiteLLMRouter(client *bifrost.Bifrost) *LiteLLMRouter { GetRequestTypeInstance: getRequestTypeInstance, RequestConverter: requestConverter, ResponseConverter: responseConverter, + ErrorConverter: errorConverter, PreCallback: preHook, }) } diff --git a/transports/bifrost-http/integrations/openai/router.go b/transports/bifrost-http/integrations/openai/router.go index 7371781f0b..d31f6b6599 100644 --- a/transports/bifrost-http/integrations/openai/router.go +++ b/transports/bifrost-http/integrations/openai/router.go @@ -32,6 +32,9 @@ func NewOpenAIRouter(client *bifrost.Bifrost) *OpenAIRouter { ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { return DeriveOpenAIFromBifrostResponse(resp), nil }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return DeriveOpenAIErrorFromBifrostError(err) + }, }, } diff --git a/transports/bifrost-http/integrations/openai/types.go b/transports/bifrost-http/integrations/openai/types.go index 970788a1be..1d5b8f7bc0 100644 --- a/transports/bifrost-http/integrations/openai/types.go +++ b/transports/bifrost-http/integrations/openai/types.go @@ -39,6 +39,28 @@ type OpenAIChatResponse struct { SystemFingerprint *string `json:"system_fingerprint,omitempty"` } +// OpenAIChatError represents an OpenAI chat completion error response +type OpenAIChatError struct { + EventID string `json:"event_id"` // Unique identifier for the error event + Type string `json:"type"` // Type of error + Error struct { + Type string `json:"type"` // Error type + Code string `json:"code"` // Error code + Message string `json:"message"` // Error message + Param interface{} `json:"param"` // Parameter that caused the error + EventID string `json:"event_id"` // Event ID for tracking + } `json:"error"` +} + +// OpenAIChatErrorStruct represents the error structure of an OpenAI chat completion error response +type OpenAIChatErrorStruct struct { + Type string `json:"type"` // Error type + Code string `json:"code"` // Error code + Message string `json:"message"` // Error message + Param interface{} `json:"param"` // Parameter that caused the error + EventID string `json:"event_id"` // Event ID for tracking +} + // ConvertToBifrostRequest converts an OpenAI chat request to Bifrost format func (r *OpenAIChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { provider, model := integrations.ParseModelString(r.Model, schemas.OpenAI) @@ -130,3 +152,48 @@ func DeriveOpenAIFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *Open return openaiResp } + +// DeriveOpenAIErrorFromBifrostError derives a OpenAIChatError from a BifrostError +func DeriveOpenAIErrorFromBifrostError(bifrostErr *schemas.BifrostError) *OpenAIChatError { + if bifrostErr == nil { + return nil + } + + // Provide blank strings for nil pointer fields + eventID := "" + if bifrostErr.EventID != nil { + eventID = *bifrostErr.EventID + } + + errorType := "" + if bifrostErr.Type != nil { + errorType = *bifrostErr.Type + } + + // Handle nested error fields with nil checks + errorStruct := OpenAIChatErrorStruct{ + Type: "", + Code: "", + Message: bifrostErr.Error.Message, + Param: bifrostErr.Error.Param, + EventID: eventID, + } + + if bifrostErr.Error.Type != nil { + errorStruct.Type = *bifrostErr.Error.Type + } + + if bifrostErr.Error.Code != nil { + errorStruct.Code = *bifrostErr.Error.Code + } + + if bifrostErr.Error.EventID != nil { + errorStruct.EventID = *bifrostErr.Error.EventID + } + + return &OpenAIChatError{ + EventID: eventID, + Type: errorType, + Error: errorStruct, + } +} diff --git a/transports/bifrost-http/integrations/utils.go b/transports/bifrost-http/integrations/utils.go index 1350ec2b31..84ba65e743 100644 --- a/transports/bifrost-http/integrations/utils.go +++ b/transports/bifrost-http/integrations/utils.go @@ -27,6 +27,10 @@ type RequestConverter func(req interface{}) (*schemas.BifrostRequest, error) // It takes a BifrostResponse and returns the format expected by the specific integration. type ResponseConverter func(*schemas.BifrostResponse) (interface{}, error) +// ErrorConverter is a function that converts BifrostError to integration-specific format. +// It takes a BifrostError and returns the format expected by the specific integration. +type ErrorConverter func(*schemas.BifrostError) interface{} + // PreRequestCallback is called before processing the request. // It can be used to modify the request object (e.g., extract model from URL parameters) // or perform validation. If it returns an error, the request processing stops. @@ -45,6 +49,7 @@ type RouteConfig struct { GetRequestTypeInstance func() interface{} // Factory function to create request instance (SHOULD NOT BE NIL) RequestConverter RequestConverter // Function to convert request to BifrostRequest (SHOULD NOT BE NIL) ResponseConverter ResponseConverter // Function to convert BifrostResponse to integration format (SHOULD NOT BE NIL) + ErrorConverter ErrorConverter // Function to convert BifrostError to integration format (SHOULD NOT BE NIL) PreCallback PreRequestCallback // Optional: called before request processing PostCallback PostRequestCallback // Optional: called after request processing } @@ -83,6 +88,10 @@ func (g *GenericRouter) RegisterRoutes(r *router.Router) { log.Println("[WARN] route configuration is invalid: ResponseConverter cannot be nil for route " + route.Path) continue } + if route.ErrorConverter == nil { + log.Println("[WARN] route configuration is invalid: ErrorConverter cannot be nil for route " + route.Path) + continue + } // Test that GetRequestTypeInstance returns a valid instance if testInstance := route.GetRequestTypeInstance(); testInstance == nil { @@ -127,7 +136,7 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle body := ctx.Request.Body() if len(body) > 0 { if err := json.Unmarshal(body, req); err != nil { - g.sendError(ctx, newBifrostError(err, "Invalid JSON")) + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "Invalid JSON")) return } } @@ -138,7 +147,7 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle // or performing request-specific validation if config.PreCallback != nil { if err := config.PreCallback(ctx, req); err != nil { - g.sendError(ctx, newBifrostError(err, "failed to execute pre-request callback")) + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "failed to execute pre-request callback")) return } } @@ -146,15 +155,15 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle // Convert the integration-specific request to Bifrost format bifrostReq, err := config.RequestConverter(req) if err != nil { - g.sendError(ctx, newBifrostError(err, "failed to convert request to Bifrost format")) + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "failed to convert request to Bifrost format")) return } if bifrostReq == nil { - g.sendError(ctx, newBifrostError(nil, "Invalid request")) + g.sendError(ctx, config.ErrorConverter, newBifrostError(nil, "Invalid request")) return } if bifrostReq.Model == "" { - g.sendError(ctx, newBifrostError(nil, "Model parameter is required")) + g.sendError(ctx, config.ErrorConverter, newBifrostError(nil, "Model parameter is required")) return } @@ -162,57 +171,61 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle bifrostCtx := lib.ConvertToBifrostContext(ctx) result, bifrostErr := g.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) if bifrostErr != nil { - g.sendError(ctx, bifrostErr) + g.sendError(ctx, config.ErrorConverter, bifrostErr) return } // Execute post-request callback if configured // This is typically used for response modification or additional processing if config.PostCallback != nil { if err := config.PostCallback(ctx, req, result); err != nil { - g.sendError(ctx, newBifrostError(err, "failed to execute post-request callback")) + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "failed to execute post-request callback")) return } } if result == nil { - g.sendError(ctx, newBifrostError(nil, "Bifrost response is nil after post-request callback")) + g.sendError(ctx, config.ErrorConverter, newBifrostError(nil, "Bifrost response is nil after post-request callback")) return } // Convert Bifrost response to integration-specific format and send response, err := config.ResponseConverter(result) if err != nil { - g.sendError(ctx, newBifrostError(err, "failed to encode response")) + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "failed to encode response")) return } - g.sendSuccess(ctx, response) + g.sendSuccess(ctx, config.ErrorConverter, response) } } // sendError sends an error response with the appropriate status code and JSON body. // It handles different error types (string, error interface, or arbitrary objects). -func (g *GenericRouter) sendError(ctx *fasthttp.RequestCtx, err *schemas.BifrostError) { - if err.StatusCode != nil { - ctx.SetStatusCode(*err.StatusCode) +func (g *GenericRouter) sendError(ctx *fasthttp.RequestCtx, errorConverter ErrorConverter, bifrostErr *schemas.BifrostError) { + if bifrostErr.StatusCode != nil { + ctx.SetStatusCode(*bifrostErr.StatusCode) } else { ctx.SetStatusCode(fasthttp.StatusInternalServerError) } - ctx.SetContentType("application/json") - if encodeErr := json.NewEncoder(ctx).Encode(err); encodeErr != nil { + + errorBody, err := json.Marshal(errorConverter(bifrostErr)) + if err != nil { ctx.SetStatusCode(fasthttp.StatusInternalServerError) - ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", encodeErr)) + ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", err)) + return } + + ctx.SetBody(errorBody) } // sendSuccess sends a successful response with HTTP 200 status and JSON body. -func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, response interface{}) { +func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, errorConverter ErrorConverter, response interface{}) { ctx.SetStatusCode(fasthttp.StatusOK) ctx.SetContentType("application/json") responseBody, err := json.Marshal(response) if err != nil { - g.sendError(ctx, newBifrostError(err, "failed to encode response")) + g.sendError(ctx, errorConverter, newBifrostError(err, "failed to encode response")) return }