diff --git a/nemoguardrails/actions/action_dispatcher.py b/nemoguardrails/actions/action_dispatcher.py
index bd78a7248..a794b26a9 100644
--- a/nemoguardrails/actions/action_dispatcher.py
+++ b/nemoguardrails/actions/action_dispatcher.py
@@ -27,7 +27,7 @@
 from langchain_core.runnables import Runnable
 
 from nemoguardrails import utils
-from nemoguardrails.actions.llm.utils import LLMCallException
+from nemoguardrails.exceptions import LLMCallException
 from nemoguardrails.logging.callbacks import logging_callbacks
 
 log = logging.getLogger(__name__)
diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py
index c6f8439c5..0c81429ea 100644
--- a/nemoguardrails/actions/llm/utils.py
+++ b/nemoguardrails/actions/llm/utils.py
@@ -32,21 +32,23 @@
     reasoning_trace_var,
     tool_calls_var,
 )
+from nemoguardrails.exceptions import LLMCallException
 from nemoguardrails.integrations.langchain.message_utils import dicts_to_messages
 from nemoguardrails.logging.callbacks import logging_callbacks
 from nemoguardrails.logging.explain import LLMCallInfo
 
-
-class LLMCallException(Exception):
-    """A wrapper around the LLM call invocation exception.
-
-    This is used to propagate the exception out of the `generate_async` call (the default behavior is to
-    catch it and return an "Internal server error." message.
-    """
-
-    def __init__(self, inner_exception: Any):
-        super().__init__(f"LLM Call Exception: {str(inner_exception)}")
-        self.inner_exception = inner_exception
+# Since different providers have different attributes for the base URL, we'll use this list
+# to attempt to extract the base URL from a `BaseLanguageModel` instance.
+BASE_URL_ATTRIBUTES = [
+    "api_base",
+    "api_host",
+    "azure_endpoint",
+    "base_url",
+    "endpoint",
+    "endpoint_url",
+    "openai_api_base",
+    "server_url",
+]
 
 
 def _infer_provider_from_module(llm: BaseLanguageModel) -> Optional[str]:
@@ -209,6 +211,58 @@
     return logging_callbacks
 
 
+def _raise_llm_call_exception(
+    exception: Exception,
+    llm: Union[BaseLanguageModel, Runnable],
+) -> None:
+    """Raise an LLMCallException with enriched context about the failed invocation.
+
+    Args:
+        exception: The original exception that occurred
+        llm: The LLM instance that was being invoked
+
+    Raises:
+        LLMCallException with context message including model name and endpoint
+    """
+    # Extract model name from context
+    llm_call_info = llm_call_info_var.get()
+    model_name = (
+        llm_call_info.llm_model_name
+        if llm_call_info
+        else _infer_model_name(llm)
+        if isinstance(llm, BaseLanguageModel)
+        else ""
+    )
+
+    # Extract endpoint URL from the LLM instance
+    endpoint_url = None
+    for attr in BASE_URL_ATTRIBUTES:
+        if hasattr(llm, attr):
+            value = getattr(llm, attr, None)
+            if value:
+                endpoint_url = str(value)
+                break
+
+    # If we didn't find endpoint URL, check the nested client object.
+    if not endpoint_url and hasattr(llm, "client"):
+        client = getattr(llm, "client", None)
+        if client and hasattr(client, "base_url"):
+            endpoint_url = str(client.base_url)
+
+    # Build context message with model and endpoint info
+    context_parts = []
+    if model_name:
+        context_parts.append(f"model={model_name}")
+    if endpoint_url:
+        context_parts.append(f"endpoint={endpoint_url}")
+
+    if context_parts:
+        context_message = f"Error invoking LLM ({', '.join(context_parts)})"
+        raise LLMCallException(exception, context_message=context_message)
+    else:
+        raise LLMCallException(exception)
+
+
 async def _invoke_with_string_prompt(
     llm: Union[BaseLanguageModel, Runnable],
     prompt: str,
@@ -218,7 +272,7 @@ async def _invoke_with_string_prompt(
     try:
         return await llm.ainvoke(prompt, config=RunnableConfig(callbacks=callbacks))
     except Exception as e:
-        raise LLMCallException(e)
+        _raise_llm_call_exception(e, llm)
 
 
 async def _invoke_with_message_list(
@@ -232,7 +286,7 @@
     try:
         return await llm.ainvoke(messages, config=RunnableConfig(callbacks=callbacks))
     except Exception as e:
-        raise LLMCallException(e)
+        _raise_llm_call_exception(e, llm)
 
 
 def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
diff --git a/nemoguardrails/exceptions.py b/nemoguardrails/exceptions.py
new file mode 100644
index 000000000..abe5b143e
--- /dev/null
+++ b/nemoguardrails/exceptions.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Optional
+
+__all__ = [
+    "ConfigurationError",
+    "InvalidModelConfigurationError",
+    "InvalidRailsConfigurationError",
+    "LLMCallException",
+]
+
+
+class ConfigurationError(ValueError):
+    """
+    Base class for Guardrails Configuration validation errors.
+    """
+
+    pass
+
+
+class InvalidModelConfigurationError(ConfigurationError):
+    """Raised when a guardrail configuration's model is invalid."""
+
+    pass
+
+
+class InvalidRailsConfigurationError(ConfigurationError):
+    """Raised when rails configuration is invalid.
+
+    Examples:
+        - Input/output rail references a model that doesn't exist in config
+        - Rail references a flow that doesn't exist
+        - Missing required prompt template
+        - Invalid rail parameters
+    """
+
+    pass
+
+
+class LLMCallException(Exception):
+    """A wrapper around the LLM call invocation exception.
+
+    This is used to propagate the exception out of the `generate_async` call. The default behavior is to
+    catch it and return an "Internal server error." message.
+    """
+
+    def __init__(self, inner_exception: Any, context_message: Optional[str] = None):
+        """Initialize LLMCallException.
+
+        Args:
+            inner_exception: The original exception that occurred
+            context_message: Optional context to prepend (for example, the model name or endpoint)
+        """
+        message = f"{context_message or 'LLM Call Exception'}: {str(inner_exception)}"
+        super().__init__(message)
+
+        self.inner_exception = inner_exception
+        self.context_message = context_message
diff --git a/nemoguardrails/llm/models/langchain_initializer.py b/nemoguardrails/llm/models/langchain_initializer.py
index e789ba5c7..ab298e6cc 100644
--- a/nemoguardrails/llm/models/langchain_initializer.py
+++ b/nemoguardrails/llm/models/langchain_initializer.py
@@ -142,13 +142,13 @@ def init_langchain_model(
     initializers: list[ModelInitializer] = [
         # Try special case handlers first (handles both chat and text)
         ModelInitializer(_handle_model_special_cases, ["chat", "text"]),
+        # FIXME: is text and chat a good idea?
+        # For text mode, use text completion, we are using both text and chat as the last resort
+        ModelInitializer(_init_text_completion_model, ["text", "chat"]),
         # For chat mode, first try the standard chat completion API
        ModelInitializer(_init_chat_completion_model, ["chat"]),
         # For chat mode, fall back to community chat models
         ModelInitializer(_init_community_chat_models, ["chat"]),
-        # FIXME: is text and chat a good idea?
-        # For text mode, use text completion, we are using both text and chat as the last resort
-        ModelInitializer(_init_text_completion_model, ["text", "chat"]),
     ]
 
     # Track the last exception for better error reporting
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py
index 90d24bdc7..b7807432e 100644
--- a/nemoguardrails/rails/llm/config.py
+++ b/nemoguardrails/rails/llm/config.py
@@ -38,6 +38,10 @@
 from nemoguardrails.colang.v1_0.runtime.flows import _normalize_flow_id
 from nemoguardrails.colang.v2_x.lang.utils import format_colang_parsing_error_message
 from nemoguardrails.colang.v2_x.runtime.errors import ColangParsingError
+from nemoguardrails.exceptions import (
+    InvalidModelConfigurationError,
+    InvalidRailsConfigurationError,
+)
 from nemoguardrails.llm.types import Task
 
 log = logging.getLogger(__name__)
@@ -144,8 +148,8 @@ def set_and_validate_model(cls, data: Any) -> Any:
             model_from_params = parameters.get("model_name") or parameters.get("model")
 
             if model_field and model_from_params:
-                raise ValueError(
-                    "Model name must be specified in exactly one place: either in the 'model' field or in parameters, not both."
+                raise InvalidModelConfigurationError(
+                    f"Model name must be specified in exactly one place: either the `model` field, or in `parameters` (`parameters.model` or `parameters.model_name`).",
                 )
             if not model_field and model_from_params:
                 data["model"] = model_from_params
@@ -162,8 +166,8 @@
     def model_must_be_none_empty(self) -> "Model":
         """Validate that a model name is present either directly or in parameters."""
         if not self.model or not self.model.strip():
-            raise ValueError(
-                "Model name must be specified either directly in the 'model' field or through 'model_name'/'model' in parameters"
+            raise InvalidModelConfigurationError(
+                f"Model name must be specified in exactly one place: either the `model` field, or in `parameters` (`parameters.model` or `parameters.model_name`)."
             )
         return self
 
@@ -349,10 +353,14 @@ class TaskPrompt(BaseModel):
     @root_validator(pre=True, allow_reuse=True)
     def check_fields(cls, values):
         if not values.get("content") and not values.get("messages"):
-            raise ValueError("One of `content` or `messages` must be provided.")
+            raise InvalidRailsConfigurationError(
+                "One of `content` or `messages` must be provided."
+            )
 
         if values.get("content") and values.get("messages"):
-            raise ValueError("Only one of `content` or `messages` must be provided.")
+            raise InvalidRailsConfigurationError(
+                "Only one of `content` or `messages` must be provided."
+            )
 
         return values
 
@@ -1476,8 +1484,14 @@ def check_model_exists_for_input_rails(cls, values):
             if not flow_model:
                 continue
             if flow_model not in model_types:
-                raise ValueError(
-                    f"No `{flow_model}` model provided for input flow `{_normalize_flow_id(flow)}`"
+                flow_id = _normalize_flow_id(flow)
+                available_types = (
+                    ", ".join(f"'{str(t)}'" for t in sorted(model_types))
+                    if model_types
+                    else "none"
+                )
+                raise InvalidRailsConfigurationError(
+                    f"Input flow '{flow_id}' references model type '{flow_model}' that is not defined in the configuration. Detected model types: {available_types}."
                 )
 
         return values
@@ -1505,8 +1519,14 @@ def check_model_exists_for_output_rails(cls, values):
             if not flow_model:
                 continue
             if flow_model not in model_types:
-                raise ValueError(
-                    f"No `{flow_model}` model provided for output flow `{_normalize_flow_id(flow)}`"
+                flow_id = _normalize_flow_id(flow)
+                available_types = (
+                    ", ".join(f"'{str(t)}'" for t in sorted(model_types))
+                    if model_types
+                    else "none"
+                )
+                raise InvalidRailsConfigurationError(
+                    f"Output flow '{flow_id}' references model type '{flow_model}' that is not defined in the configuration. Detected model types: {available_types}."
                 )
 
         return values
@@ -1527,13 +1547,15 @@ def check_prompt_exist_for_self_check_rails(cls, values):
             "self check input" in enabled_input_rails
             and "self_check_input" not in provided_task_prompts
         ):
-            raise ValueError("You must provide a `self_check_input` prompt template.")
+            raise InvalidRailsConfigurationError(
+                f"Missing a `self_check_input` prompt template, which is required for the `self check input` rail."
+            )
         if (
             "llama guard check input" in enabled_input_rails
             and "llama_guard_check_input" not in provided_task_prompts
         ):
-            raise ValueError(
-                "You must provide a `llama_guard_check_input` prompt template."
+            raise InvalidRailsConfigurationError(
+                f"Missing a `llama_guard_check_input` prompt template, which is required for the `llama guard check input` rail."
             )
 
         # Only content-safety and topic-safety include a $model reference in the rail flow text
@@ -1551,27 +1573,31 @@
             "self check output" in enabled_output_rails
             and "self_check_output" not in provided_task_prompts
         ):
-            raise ValueError("You must provide a `self_check_output` prompt template.")
+            raise InvalidRailsConfigurationError(
+                f"Missing a `self_check_output` prompt template, which is required for the `self check output` rail."
+            )
        if (
             "llama guard check output" in enabled_output_rails
             and "llama_guard_check_output" not in provided_task_prompts
         ):
-            raise ValueError(
-                "You must provide a `llama_guard_check_output` prompt template."
+            raise InvalidRailsConfigurationError(
+                f"Missing a `llama_guard_check_output` prompt template, which is required for the `llama guard check output` rail."
             )
 
         if (
             "patronus lynx check output hallucination" in enabled_output_rails
             and "patronus_lynx_check_output_hallucination" not in provided_task_prompts
         ):
-            raise ValueError(
-                "You must provide a `patronus_lynx_check_output_hallucination` prompt template."
+            raise InvalidRailsConfigurationError(
+                f"Missing a `patronus_lynx_check_output_hallucination` prompt template, which is required for the `patronus lynx check output hallucination` rail."
             )
         if (
             "self check facts" in enabled_output_rails
             and "self_check_facts" not in provided_task_prompts
         ):
-            raise ValueError("You must provide a `self_check_facts` prompt template.")
+            raise InvalidRailsConfigurationError(
+                f"Missing a `self_check_facts` prompt template, which is required for the `self check facts` rail."
+            )
 
         # Only content-safety and topic-safety include a $model reference in the rail flow text
         # Need to match rails with flow_id (excluding $model reference) and match prompts
@@ -1638,7 +1664,7 @@ def validate_models_api_key_env_var(cls, models):
         api_keys = [m.api_key_env_var for m in models]
         for api_key in api_keys:
             if api_key and not os.environ.get(api_key):
-                raise ValueError(
+                raise InvalidRailsConfigurationError(
                     f"Model API Key environment variable '{api_key}' not set."
                 )
         return models
@@ -1931,6 +1957,6 @@
         prompt_flow_id = flow_id.replace(" ", "_")
         expected_prompt = f"{prompt_flow_id} $model={flow_model}"
         if expected_prompt not in prompts:
-            raise ValueError(
-                f"You must provide a `{expected_prompt}` prompt template."
+            raise InvalidRailsConfigurationError(
+                f"Missing a `{expected_prompt}` prompt template, which is required for the `{validation_rail}` rail."
             )
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index 187300aa2..e7a7c2f7c 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -69,6 +69,10 @@
 from nemoguardrails.embeddings.index import EmbeddingsIndex
 from nemoguardrails.embeddings.providers import register_embedding_provider
 from nemoguardrails.embeddings.providers.base import EmbeddingModel
+from nemoguardrails.exceptions import (
+    InvalidModelConfigurationError,
+    InvalidRailsConfigurationError,
+)
 from nemoguardrails.kb.kb import KnowledgeBase
 from nemoguardrails.llm.cache import CacheInterface, LFUCache
 from nemoguardrails.llm.models.initializer import (
@@ -234,13 +238,19 @@ def __init__(
                 spec.loader.exec_module(config_module)
                 config_modules.append(config_module)
 
+        colang_version_to_runtime: Dict[str, Type[Runtime]] = {
+            "1.0": RuntimeV1_0,
+            "2.x": RuntimeV2_x,
+        }
+        if config.colang_version not in colang_version_to_runtime:
+            raise InvalidRailsConfigurationError(
+                f"Unsupported colang version: {config.colang_version}. Supported versions: {list(colang_version_to_runtime.keys())}"
+            )
+
         # First, we initialize the runtime.
-        if config.colang_version == "1.0":
-            self.runtime = RuntimeV1_0(config=config, verbose=verbose)
-        elif config.colang_version == "2.x":
-            self.runtime = RuntimeV2_x(config=config, verbose=verbose)
-        else:
-            raise ValueError(f"Unsupported colang version: {config.colang_version}.")
+        self.runtime = colang_version_to_runtime[config.colang_version](
+            config=config, verbose=verbose
+        )
 
         # If we have a config_modules with an `init` function, we call it.
         # We need to call this here because the `init` might register additional
@@ -328,26 +338,26 @@
             # content safety check input/output flows are special as they have parameters
             flow_name = _normalize_flow_id(flow_name)
             if flow_name not in existing_flows_names:
-                raise ValueError(
+                raise InvalidRailsConfigurationError(
                     f"The provided input rail flow `{flow_name}` does not exist"
                 )
 
         for flow_name in self.config.rails.output.flows:
             flow_name = _normalize_flow_id(flow_name)
             if flow_name not in existing_flows_names:
-                raise ValueError(
+                raise InvalidRailsConfigurationError(
                     f"The provided output rail flow `{flow_name}` does not exist"
                 )
 
         for flow_name in self.config.rails.retrieval.flows:
             if flow_name not in existing_flows_names:
-                raise ValueError(
+                raise InvalidRailsConfigurationError(
                     f"The provided retrieval rail flow `{flow_name}` does not exist"
                 )
 
         # If both passthrough mode and single call mode are specified, we raise an exception.
         if self.config.passthrough and self.config.rails.dialog.single_call.enabled:
-            raise ValueError(
+            raise InvalidRailsConfigurationError(
                 "The passthrough mode and the single call dialog rails mode can't be used at the same time. "
                 "The single call mode needs to use an altered prompt when prompting the LLM. "
             )
@@ -491,7 +501,9 @@ def _init_llms(self):
             try:
                 model_name = llm_config.model
                 if not model_name:
-                    raise ValueError("LLM Config model field not set")
+                    raise InvalidModelConfigurationError(
+                        f"`model` field must be set in model configuration: {llm_config.model_dump_json()}"
+                    )
 
                 provider_name = llm_config.engine
                 kwargs = self._prepare_model_kwargs(llm_config)
@@ -1248,11 +1260,11 @@ def _validate_streaming_with_output_rails(self) -> None:
                 not self.config.rails.output.streaming
                 or not self.config.rails.output.streaming.enabled
             ):
-                raise ValueError(
-                    "stream_async() cannot be used when output rails are configured but "
-                    "rails.output.streaming.enabled is False. Either set "
-                    "rails.output.streaming.enabled to True in your configuration, or use "
-                    "generate_async() instead of stream_async()."
+                raise InvalidRailsConfigurationError(
+                    f"stream_async() cannot be used when output rails are configured but "
+                    f"rails.output.streaming.enabled is False. Either set "
+                    f"rails.output.streaming.enabled to True in your configuration, or use "
+                    f"generate_async() instead of stream_async()."
                 )
 
     def stream_async(
diff --git a/tests/llm_providers/test_langchain_initializer.py b/tests/llm_providers/test_langchain_initializer.py
index 2252570a6..0cdb87f43 100644
--- a/tests/llm_providers/test_langchain_initializer.py
+++ b/tests/llm_providers/test_langchain_initializer.py
@@ -70,25 +70,27 @@ def test_special_case_called_first(mock_initializers):
 
 def test_chat_completion_called(mock_initializers):
     mock_initializers["special"].return_value = None
+    mock_initializers["text"].return_value = None
     mock_initializers["chat"].return_value = "chat_model"
     result = init_langchain_model("chat-model", "provider", "chat", {})
     assert result == "chat_model"
     mock_initializers["special"].assert_called_once()
+    mock_initializers["text"].assert_called_once()
     mock_initializers["chat"].assert_called_once()
     mock_initializers["community"].assert_not_called()
-    mock_initializers["text"].assert_not_called()
 
 
 def test_community_chat_called(mock_initializers):
     mock_initializers["special"].return_value = None
+    mock_initializers["text"].return_value = None
     mock_initializers["chat"].return_value = None
     mock_initializers["community"].return_value = "community_model"
     result = init_langchain_model("community-chat", "provider", "chat", {})
     assert result == "community_model"
     mock_initializers["special"].assert_called_once()
+    mock_initializers["text"].assert_called_once()
     mock_initializers["chat"].assert_called_once()
     mock_initializers["community"].assert_called_once()
-    mock_initializers["text"].assert_not_called()
 
 
 def test_text_completion_called(mock_initializers):
@@ -154,36 +156,39 @@ def test_all_initializers_raise_exceptions(mock_initializers):
 
 def test_duplicate_modes_in_initializer(mock_initializers):
     mock_initializers["special"].return_value = None
+    mock_initializers["text"].return_value = None
     mock_initializers["chat"].return_value = "chat_model"
     result = init_langchain_model("chat-model", "provider", "chat", {})
     assert result == "chat_model"
     mock_initializers["special"].assert_called_once()
+    mock_initializers["text"].assert_called_once()
     mock_initializers["chat"].assert_called_once()
     mock_initializers["community"].assert_not_called()
-    mock_initializers["text"].assert_not_called()
 
 
 def test_chat_completion_called_when_special_returns_none(mock_initializers):
     mock_initializers["special"].return_value = None
+    mock_initializers["text"].return_value = None
     mock_initializers["chat"].return_value = "chat_model"
     result = init_langchain_model("chat-model", "provider", "chat", {})
     assert result == "chat_model"
     mock_initializers["special"].assert_called_once()
+    mock_initializers["text"].assert_called_once()
     mock_initializers["chat"].assert_called_once()
     mock_initializers["community"].assert_not_called()
-    mock_initializers["text"].assert_not_called()
 
 
 def test_community_chat_called_when_previous_fail(mock_initializers):
     mock_initializers["special"].return_value = None
+    mock_initializers["text"].return_value = None
     mock_initializers["chat"].return_value = None
     mock_initializers["community"].return_value = "community_model"
     result = init_langchain_model("community-chat", "provider", "chat", {})
     assert result == "community_model"
     mock_initializers["special"].assert_called_once()
+    mock_initializers["text"].assert_called_once()
     mock_initializers["chat"].assert_called_once()
     mock_initializers["community"].assert_called_once()
-    mock_initializers["text"].assert_not_called()
 
 
 def test_text_completion_called_when_previous_fail(mock_initializers):
@@ -201,12 +206,11 @@ def test_text_completion_called_when_previous_fail(mock_initializers):
 
 def test_text_completion_supports_chat_mode(mock_initializers):
     mock_initializers["special"].return_value = None
-    mock_initializers["chat"].return_value = None
-    mock_initializers["community"].return_value = None
     mock_initializers["text"].return_value = "text_model"
     result = init_langchain_model("text-model", "provider", "chat", {})
     assert result == "text_model"
     mock_initializers["special"].assert_called_once()
-    mock_initializers["chat"].assert_called_once()
-    mock_initializers["community"].assert_called_once()
     mock_initializers["text"].assert_called_once()
+    # Since text returns a value, chat and community are not called
+    mock_initializers["chat"].assert_not_called()
+    mock_initializers["community"].assert_not_called()
diff --git a/tests/test_actions_llm_utils.py b/tests/test_actions_llm_utils.py
index 8f0accbd2..7b786eac0 100644
--- a/tests/test_actions_llm_utils.py
+++ b/tests/test_actions_llm_utils.py
@@ -13,12 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import cast
+from unittest.mock import AsyncMock
+
+import pytest
+from langchain_core.language_models import BaseLanguageModel
+
 from nemoguardrails.actions.llm.utils import (
     _extract_and_remove_think_tags,
     _infer_provider_from_module,
     _store_reasoning_traces,
+    llm_call,
 )
 from nemoguardrails.context import reasoning_trace_var
+from nemoguardrails.exceptions import LLMCallException
 
 
 class MockOpenAILLM:
@@ -49,6 +57,24 @@ class MockPatchedNVIDIA(MockNVIDIAOriginal):
     __module__ = "nemoguardrails.llm.providers._langchain_nvidia_ai_endpoints_patch"
 
 
+class MockTRTLLM:
+    __module__ = "nemoguardrails.llm.providers.trtllm.llm"
+
+
+class MockAzureLLM:
+    __module__ = "langchain_openai.chat_models"
+
+
+class MockLLMWithClient:
+    __module__ = "langchain_openai.chat_models"
+
+    class _MockClient:
+        base_url = "https://custom.endpoint.com/v1"
+
+    def __init__(self):
+        self.client = self._MockClient()
+
+
 def test_infer_provider_openai():
     llm = MockOpenAILLM()
     provider = _infer_provider_from_module(llm)
@@ -304,3 +330,88 @@ def test_extract_and_remove_think_tags_wrong_order():
 
     assert result is None
     assert response.content == " text here "
+
+
+@pytest.mark.asyncio
+async def test_llm_call_exception_enrichment_with_model_and_endpoint():
+    """Test that LLM invocation errors include model and endpoint context."""
+    mock_llm = MockOpenAILLM()
+    mock_llm.model_name = "gpt-4"
+    mock_llm.base_url = "https://api.openai.com/v1"
+    mock_llm.ainvoke = AsyncMock(side_effect=ConnectionError("Connection refused"))
+
+    with pytest.raises(LLMCallException) as exc_info:
+        await llm_call(cast(BaseLanguageModel, mock_llm), "test prompt")
+
+    exc_str = str(exc_info.value)
+    assert "gpt-4" in exc_str
+    assert "https://api.openai.com/v1" in exc_str
+    assert "Connection refused" in exc_str
+    assert isinstance(exc_info.value.inner_exception, ConnectionError)
+
+
+@pytest.mark.asyncio
+async def test_llm_call_exception_without_endpoint():
+    """Test exception enrichment when endpoint URL is not available."""
+    mock_llm = AsyncMock()
+    mock_llm.__module__ = "langchain_openai.chat_models"
+    mock_llm.model_name = "custom-model"
+    # No base_url attribute
+    mock_llm.ainvoke = AsyncMock(side_effect=ValueError("Invalid request"))
+
+    with pytest.raises(LLMCallException) as exc_info:
+        await llm_call(mock_llm, "test prompt")
+
+    # Should still have model name but no endpoint
+    assert "custom-model" in str(exc_info.value)
+    assert "Invalid request" in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_llm_call_exception_extracts_azure_endpoint():
+    """Test that Azure-style endpoint URLs are extracted."""
+    mock_llm = MockAzureLLM()
+    mock_llm.model_name = "gpt-4"
+    mock_llm.azure_endpoint = "https://example.openai.azure.com"
+    mock_llm.ainvoke = AsyncMock(side_effect=Exception("Azure error"))
+
+    with pytest.raises(LLMCallException) as exc_info:
+        await llm_call(cast(BaseLanguageModel, mock_llm), "test prompt")
+
+    exc_str = str(exc_info.value)
+    assert "https://example.openai.azure.com" in exc_str
+    assert "gpt-4" in exc_str
+    assert "Azure error" in exc_str
+
+
+@pytest.mark.asyncio
+async def test_llm_call_exception_extracts_server_url():
+    """Test that TRT-style server_url is extracted."""
+    mock_llm = MockTRTLLM()
+    mock_llm.model_name = "llama-2-70b"
+    mock_llm.server_url = "https://triton.example.com:8000"
+    mock_llm.ainvoke = AsyncMock(side_effect=Exception("Triton server error"))
+
+    with pytest.raises(LLMCallException) as exc_info:
+        await llm_call(cast(BaseLanguageModel, mock_llm), "test prompt")
+
+    exc_str = str(exc_info.value)
+    assert "https://triton.example.com:8000" in exc_str
+    assert "llama-2-70b" in exc_str
+    assert "Triton server error" in exc_str
+
+
+@pytest.mark.asyncio
+async def test_llm_call_exception_extracts_nested_client_base_url():
+    """Test that nested client.base_url is extracted."""
+    mock_llm = MockLLMWithClient()
+    mock_llm.model_name = "gpt-4-turbo"
+    mock_llm.ainvoke = AsyncMock(side_effect=Exception("Client error"))
+
+    with pytest.raises(LLMCallException) as exc_info:
+        await llm_call(cast(BaseLanguageModel, mock_llm), "test prompt")
+
+    exc_str = str(exc_info.value)
+    assert "https://custom.endpoint.com/v1" in exc_str
+    assert "gpt-4-turbo" in exc_str
+    assert "Client error" in exc_str
diff --git a/tests/test_config_validation.py b/tests/test_config_validation.py
index 3e0bf62d7..a73216695 100644
--- a/tests/test_config_validation.py
+++ b/tests/test_config_validation.py
@@ -75,7 +75,7 @@ def test_self_check_input_prompt_exception():
         )
         LLMRails(config=config)
 
-    assert "You must provide a `self_check_input` prompt" in str(exc_info.value)
+    assert "Missing a `self_check_input` prompt template" in str(exc_info.value)
 
 
 def test_self_check_output_prompt_exception():
@@ -90,7 +90,7 @@ def test_self_check_output_prompt_exception():
         )
         LLMRails(config=config)
 
-    assert "You must provide a `self_check_output` prompt" in str(exc_info.value)
+    assert "Missing a `self_check_output` prompt template" in str(exc_info.value)
 
 
 def test_passthrough_and_single_call_incompatibility():
diff --git a/tests/test_embeddings_only_user_messages.py b/tests/test_embeddings_only_user_messages.py
index 8f6e6109d..b02730ccf 100644
--- a/tests/test_embeddings_only_user_messages.py
+++ b/tests/test_embeddings_only_user_messages.py
@@ -18,7 +18,7 @@
 import pytest
 
 from nemoguardrails import LLMRails, RailsConfig
-from nemoguardrails.actions.llm.utils import LLMCallException
+from nemoguardrails.exceptions import LLMCallException
 from nemoguardrails.llm.filters import colang
 from tests.utils import TestChat
 
diff --git a/tests/test_rails_config.py b/tests/test_rails_config.py
index 9f5f3e7c7..159c39221 100644
--- a/tests/test_rails_config.py
+++ b/tests/test_rails_config.py
@@ -98,7 +98,7 @@ def test_check_prompt_exist_for_self_check_rails():
         ],
     }
    with pytest.raises(
-        ValueError, match="You must provide a `self_check_output` prompt template"
+        ValueError, match="Missing a `self_check_output` prompt template"
    ):
         RailsConfig.check_prompt_exist_for_self_check_rails(values)
 
@@ -353,7 +353,7 @@
    def test_validate_rail_prompts_wrong_flow_id_raises(self):
        with pytest.raises(
            ValueError,
-            match="You must provide a `content_safety_check_input \$model=content_safety` prompt template.",
+            match="Missing a `content_safety_check_input \$model=content_safety` prompt template",
        ):
            _validate_rail_prompts(
                ["content safety check input $model=content_safety"],
@@ -366,7 +366,7 @@
    def test_validate_rail_prompts_wrong_model_raises(self):
        with pytest.raises(
            ValueError,
-            match="You must provide a `content_safety_check_input \$model=content_safety` prompt template.",
+            match="Missing a `content_safety_check_input \$model=content_safety` prompt template",
        ):
            _validate_rail_prompts(
                ["content safety check input $model=content_safety"],
@@ -379,7 +379,7 @@
    def test_validate_rail_prompts_no_prompt_raises(self):
        with pytest.raises(
            ValueError,
-            match="You must provide a `content_safety_check_input \$model=content_safety` prompt template.",
+            match="Missing a `content_safety_check_input \$model=content_safety` prompt template",
        ):
            _validate_rail_prompts(
                ["content safety check input $model=content_safety"],
@@ -395,7 +395,7 @@ def test_content_safety_input_missing_prompt_raises(self):
        """Check Content Safety output rail raises ValueError if we don't have a prompt"""
        with pytest.raises(
            ValueError,
-            match="You must provide a `content_safety_check_input \$model=content_safety` prompt template.",
+            match="Missing a `content_safety_check_input \$model=content_safety` prompt template",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -415,7 +415,7 @@ def test_content_safety_output_missing_prompt_raises(self):
        """Check Content Safety output rail raises ValueError if we don't have a prompt"""
        with pytest.raises(
            ValueError,
-            match="You must provide a `content_safety_check_output \$model=content_safety` prompt template.",
+            match="Missing a `content_safety_check_output \$model=content_safety` prompt template",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -531,7 +531,7 @@
    def test_input_content_safety_no_model_raises(self):
        with pytest.raises(
            ValueError,
-            match="No `content_safety` model provided for input flow `content safety check input`",
+            match="Input flow 'content safety check input' references model type 'content_safety' that is not defined",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -556,7 +556,7 @@
    def test_input_content_safety_wrong_model_raises(self):
        with pytest.raises(
            ValueError,
-            match="No `content_safety` model provided for input flow `content safety check input",
+            match="Input flow 'content safety check input' references model type 'content_safety' that is not defined",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -581,7 +581,7 @@
    def test_output_content_safety_no_model_raises(self):
        with pytest.raises(
            ValueError,
-            match="No `content_safety` model provided for output flow `content safety check output`",
+            match="Output flow 'content safety check output' references model type 'content_safety' that is not defined",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -606,7 +606,7 @@
    def test_output_content_safety_wrong_model_raises(self):
        with pytest.raises(
            ValueError,
-            match="You must provide a `content_safety_check_output \$model=content_safety` prompt template",
+            match="Missing a `content_safety_check_output \$model=content_safety` prompt template",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -664,7 +664,7 @@
    def test_topic_safety_no_prompt_raises(self):
        with pytest.raises(
            ValueError,
-            match="You must provide a `topic_safety_check_input \$model=topic_control` prompt template",
+            match="Missing a `topic_safety_check_input \$model=topic_control` prompt template",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -688,7 +688,7 @@ def test_topic_safety_no_model_raises(self):
        """Check if we don't provide a topic-safety model we raise a ValueError"""
        with pytest.raises(
            ValueError,
-            match="No `topic_control` model provided for input flow `topic safety check input`",
+            match="Input flow 'topic safety check input' references model type 'topic_control' that is not defined",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -712,7 +712,7 @@ def test_topic_safety_no_model_no_prompt_raises(self):
        """Check a missing model and prompt raises ValueError"""
        with pytest.raises(
            ValueError,
-            match="You must provide a `topic_safety_check_input \$model=topic_control` prompt template",
+            match="Missing a `topic_safety_check_input \$model=topic_control` prompt template",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -741,7 +741,7 @@
    def test_hero_separate_models_no_prompts_raises(self):
        with pytest.raises(
            ValueError,
-            match="You must provide a `content_safety_check_input \$model=my_content_safety` prompt template",
+            match="Missing a `content_safety_check_input \$model=my_content_safety` prompt template",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -883,7 +883,7 @@ def test_hero_no_prompts_raises(self):
        """Create hero workflow with no prompts. Expect Content Safety input prompt check to fail"""
        with pytest.raises(
            ValueError,
-            match="You must provide a `content_safety_check_input \$model=content_safety` prompt template",
+            match="Missing a `content_safety_check_input \$model=content_safety` prompt template",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -923,7 +923,7 @@ def test_hero_no_output_content_safety_prompt_raises(self):
        """Create hero workflow with no prompts. Expect Content Safety input prompt check to fail"""
        with pytest.raises(
            ValueError,
-            match="You must provide a `topic_safety_check_input \$model=your_topic_control` prompt template",
+            match="Missing a `topic_safety_check_input \$model=your_topic_control` prompt template",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -967,7 +967,7 @@ def test_hero_no_topic_safety_prompt_raises(self):
        """Create hero workflow with no prompts. Expect Content Safety input prompt check to fail"""
        with pytest.raises(
            ValueError,
-            match="You must provide a `topic_safety_check_input \$model=your_topic_control` prompt template",
+            match="Missing a `topic_safety_check_input \$model=your_topic_control` prompt template",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
@@ -1013,7 +1013,7 @@ def test_hero_topic_safety_prompt_raises(self):
        """Create hero workflow with no prompts. Expect Content Safety input prompt check to fail"""
        with pytest.raises(
            ValueError,
-            match="You must provide a `content_safety_check_input \$model=content_safety` prompt template",
+            match="Missing a `content_safety_check_input \$model=content_safety` prompt template",
        ):
            _ = RailsConfig.from_content(
                yaml_content="""
diff --git a/tests/test_tool_calling_utils.py b/tests/test_tool_calling_utils.py
index 3a34eab82..5fa794257 100644
--- a/tests/test_tool_calling_utils.py
+++ b/tests/test_tool_calling_utils.py
@@ -19,7 +19,6 @@
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
 
 from nemoguardrails.actions.llm.utils import (
-    LLMCallException,
     _convert_messages_to_langchain_format,
     _extract_content,
     _store_tool_calls,
@@ -27,6 +26,7 @@
     llm_call,
 )
 from nemoguardrails.context import tool_calls_var
+from nemoguardrails.exceptions import LLMCallException
 from nemoguardrails.rails.llm.llmrails import GenerationResponse
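Usage note (illustrative only, not part of the diff): a minimal sketch of how a caller could surface the enriched error context introduced above. It assumes only what this change shows — the `llm_call(llm, prompt)` helper exercised in the tests and the `context_message` / `inner_exception` attributes on `LLMCallException`; the `ChatOpenAI` construction is a hypothetical provider choice, not something this PR prescribes.

import asyncio

from langchain_openai import ChatOpenAI  # hypothetical provider choice for this sketch

from nemoguardrails.actions.llm.utils import llm_call
from nemoguardrails.exceptions import LLMCallException


async def main() -> None:
    llm = ChatOpenAI(model="gpt-4", base_url="https://api.openai.com/v1")
    try:
        print(await llm_call(llm, "Hello!"))
    except LLMCallException as exc:
        # With the change above, str(exc) reads something like:
        # "Error invoking LLM (model=gpt-4, endpoint=https://api.openai.com/v1): <original error>"
        print(f"LLM call failed: {exc}")
        print(f"Context: {exc.context_message}")
        print(f"Original error: {exc.inner_exception!r}")
        raise


if __name__ == "__main__":
    asyncio.run(main())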