diff --git a/docs/my-website/docs/proxy/tag_routing.md b/docs/my-website/docs/proxy/tag_routing.md index 399c43d2c0f..a1ae52e5e45 100644 --- a/docs/my-website/docs/proxy/tag_routing.md +++ b/docs/my-website/docs/proxy/tag_routing.md @@ -209,6 +209,106 @@ Expect to see the following response header when this works x-litellm-model-id: default-model ``` +## Regex-based tag routing (`tag_regex`) + +Use `tag_regex` to route requests based on regex patterns matched against request headers, without requiring clients to pass a tag explicitly. This is useful when clients already send a recognisable header, such as `User-Agent`. + +**Use case: route all Claude Code traffic to dedicated AWS accounts** + +Claude Code always sends `User-Agent: claude-code/`. With `tag_regex` you can route that traffic to a dedicated deployment automatically — no per-developer configuration needed. + +### 1. Config + +```yaml +model_list: + # Claude Code traffic → dedicated deployment, matched by User-Agent + - model_name: claude-sonnet + litellm_params: + model: bedrock/converse/anthropic-claude-sonnet-4-6 + aws_region_name: us-east-1 + aws_role_name: arn:aws:iam::111122223333:role/LiteLLMClaudeCode + tag_regex: + - "^User-Agent: claude-code\\/" # matches claude-code/1.x, 2.x, etc. + model_info: + id: claude-code-deployment + + # All other traffic falls back to the default deployment + - model_name: claude-sonnet + litellm_params: + model: bedrock/converse/anthropic-claude-sonnet-4-6 + aws_region_name: us-east-1 + aws_role_name: arn:aws:iam::444455556666:role/LiteLLMDefault + tags: + - default + model_info: + id: regular-deployment + +router_settings: + enable_tag_filtering: true + tag_filtering_match_any: true + +general_settings: + master_key: sk-1234 +``` + +### 2. 
Verify routing + +Claude Code sets `User-Agent: claude-code/` automatically — no client config needed: + +```shell +# Claude Code request (User-Agent set automatically by Claude Code) +curl http://localhost:4000/v1/chat/completions \ + -H "Authorization: Bearer sk-1234" \ + -H "User-Agent: claude-code/1.2.3" \ + -d '{"model": "claude-sonnet", "messages": [{"role": "user", "content": "hi"}]}' +# → x-litellm-model-id: claude-code-deployment + +# Any other client (no matching User-Agent) → default deployment +curl http://localhost:4000/v1/chat/completions \ + -H "Authorization: Bearer sk-1234" \ + -d '{"model": "claude-sonnet", "messages": [{"role": "user", "content": "hi"}]}' +# → x-litellm-model-id: regular-deployment +``` + +### How matching works + +| Priority | Condition | Result | +|----------|-----------|--------| +| 1 | Request has `tags` AND deployment has `tags` | Exact tag match (respects `match_any` setting) | +| 2 | Deployment has `tag_regex` AND request has a `User-Agent` | Regex match (always OR logic — any pattern match suffices) | +| 3 | Deployment has `tags: [default]` | Default fallback | +| 4 | No default set | All healthy deployments returned | + +Within a deployment's `tag_regex` list, matching always uses OR semantics — `tag_filtering_match_any=False` never requires every pattern to match. One exception: when `tag_filtering_match_any=False` and both the request and the deployment carry plain `tags`, a failed exact-tag check is final and `tag_regex` is not evaluated for that deployment. + +### Observability + +When a deployment matches (by exact tag or by `tag_regex`), `tag_routing` is written into request metadata for the first matching deployment and flows to SpendLogs. The stored object also carries a `matched_deployment` field (the deployment's `model_name`); a regex match looks like this: + +```json +{ + "tag_routing": { + "matched_via": "tag_regex", + "matched_value": "^User-Agent: claude-code\\/", + "user_agent": "claude-code/1.2.3", + "request_tags": [] + } +} +``` + +### Security note + +:::caution + +**`User-Agent` is a client-supplied header and can be set to any value.** Any API consumer can send `User-Agent: claude-code/1.0` regardless of whether they are actually using Claude Code. + +Do not rely on `tag_regex` routing to enforce access controls or spend limits — use [team/key-based routing](./users) for that. 
`tag_regex` is a **traffic classification hint** (useful for billing visibility, capacity planning, and routing convenience), not a security boundary. + +::: + + +--- + ## ✨ Team based tag routing (Enterprise) LiteLLM Proxy supports team-based tag routing, allowing you to associate specific tags with teams and route requests accordingly. Example **Team A can access gpt-4 deployment A, Team B can access gpt-4 deployment B** (LLM Access Control For Teams) diff --git a/litellm/completion_extras/litellm_responses_transformation/transformation.py b/litellm/completion_extras/litellm_responses_transformation/transformation.py index 4b31bcfc285..53ffd3647bd 100644 --- a/litellm/completion_extras/litellm_responses_transformation/transformation.py +++ b/litellm/completion_extras/litellm_responses_transformation/transformation.py @@ -398,6 +398,7 @@ def _convert_response_output_to_choices( ResponseOutputMessage, ResponseReasoningItem, ) + try: from openai.types.responses.response_output_item import ( ResponseApplyPatchToolCall, @@ -460,7 +461,9 @@ def _convert_response_output_to_choices( accumulated_tool_calls.append(tool_call_dict) tool_call_index += 1 - elif ResponseApplyPatchToolCall is not None and isinstance(item, ResponseApplyPatchToolCall): + elif ResponseApplyPatchToolCall is not None and isinstance( + item, ResponseApplyPatchToolCall + ): from litellm.responses.litellm_completion_transformation.transformation import ( LiteLLMCompletionResponsesConfig, ) diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py index ea1f81f9b36..47272b38ad6 100644 --- a/litellm/litellm_core_utils/prompt_templates/factory.py +++ b/litellm/litellm_core_utils/prompt_templates/factory.py @@ -2680,7 +2680,9 @@ def anthropic_messages_pt( # noqa: PLR0915 _content_is_list = "content" in assistant_content_block and isinstance( assistant_content_block["content"], list ) - _content_list = assistant_content_block.get("content") 
if _content_is_list else None + _content_list = ( + assistant_content_block.get("content") if _content_is_list else None + ) _list_has_thinking = False if _content_is_list and _content_list is not None: for _item in _content_list: diff --git a/litellm/llms/anthropic/files/transformation.py b/litellm/llms/anthropic/files/transformation.py index 0691742bb08..98a548a1369 100644 --- a/litellm/llms/anthropic/files/transformation.py +++ b/litellm/llms/anthropic/files/transformation.py @@ -79,7 +79,9 @@ def get_error_class( return AnthropicError( status_code=status_code, message=error_message, - headers=cast(httpx.Headers, headers) if isinstance(headers, dict) else headers, + headers=cast(httpx.Headers, headers) + if isinstance(headers, dict) + else headers, ) def validate_environment( diff --git a/litellm/llms/base_llm/base_model_iterator.py b/litellm/llms/base_llm/base_model_iterator.py index 18551d97142..cf1fd6f786e 100644 --- a/litellm/llms/base_llm/base_model_iterator.py +++ b/litellm/llms/base_llm/base_model_iterator.py @@ -144,9 +144,7 @@ def __next__(self): # Skip empty lines (common in SSE streams between events). # Only apply to str chunks — non-string objects (e.g. Pydantic # BaseModel events from the Responses API) must pass through. - if isinstance(str_line, str) and ( - not str_line or not str_line.strip() - ): + if isinstance(str_line, str) and (not str_line or not str_line.strip()): continue # chunk is a str at this point @@ -184,9 +182,7 @@ async def __anext__(self): # Skip empty lines (common in SSE streams between events). # Only apply to str chunks — non-string objects (e.g. Pydantic # BaseModel events from the Responses API) must pass through. 
- if isinstance(str_line, str) and ( - not str_line or not str_line.strip() - ): + if isinstance(str_line, str) and (not str_line or not str_line.strip()): continue # chunk is a str at this point diff --git a/litellm/llms/bedrock/base_aws_llm.py b/litellm/llms/bedrock/base_aws_llm.py index 697fccd268b..b159d62367d 100644 --- a/litellm/llms/bedrock/base_aws_llm.py +++ b/litellm/llms/bedrock/base_aws_llm.py @@ -1268,7 +1268,8 @@ def get_request_headers( # Add back all original headers (including forwarded ones) after signature calculation for header_name, header_value in headers.items(): - request.headers[header_name] = header_value + if header_value is not None: + request.headers[header_name] = header_value if ( extra_headers is not None and "Authorization" in extra_headers @@ -1298,6 +1299,8 @@ def _filter_headers_for_aws_signature(self, headers: dict) -> dict: } for header_name, header_value in headers.items(): + if header_value is None: + continue header_lower = header_name.lower() if ( header_lower in aws_headers @@ -1393,7 +1396,8 @@ def _sign_request( # Add back original headers after signing. Only headers in SignedHeaders # are integrity-protected; forwarded headers (x-forwarded-*) must remain unsigned. for header_name, header_value in headers.items(): - request_headers_dict[header_name] = header_value + if header_value is not None: + request_headers_dict[header_name] = header_value if ( headers is not None and "Authorization" in headers ): # prevent sigv4 from overwriting the auth header diff --git a/litellm/llms/bedrock/files/transformation.py b/litellm/llms/bedrock/files/transformation.py index 096371749b5..3007b54808c 100644 --- a/litellm/llms/bedrock/files/transformation.py +++ b/litellm/llms/bedrock/files/transformation.py @@ -173,7 +173,12 @@ def get_complete_file_url( "S3 bucket_name is required. 
Set 's3_bucket_name' in litellm_params or AWS_S3_BUCKET_NAME env var" ) - aws_region_name = self._get_aws_region_name(optional_params, model) + s3_region_name = litellm_params.get("s3_region_name") or optional_params.get( + "s3_region_name" + ) + aws_region_name = s3_region_name or self._get_aws_region_name( + optional_params, model + ) file_data = data.get("file") purpose = data.get("purpose") @@ -398,6 +403,15 @@ def transform_create_file_request( data=create_file_data, ) + # s3_region_name always wins for S3 operations (same priority as in + # get_complete_file_url above). Overwrite aws_region_name unconditionally + # so the SigV4 region matches the URL region, avoiding SignatureDoesNotMatch. + s3_region_name = litellm_params.get("s3_region_name") or optional_params.get( + "s3_region_name" + ) + if s3_region_name: + optional_params = {**optional_params, "aws_region_name": s3_region_name} + # Sign the request and return a pre-signed request object signed_headers, signed_body = self._sign_s3_request( content=file_content, diff --git a/litellm/llms/black_forest_labs/image_edit/transformation.py b/litellm/llms/black_forest_labs/image_edit/transformation.py index 63413787c0d..c6d8e8298e3 100644 --- a/litellm/llms/black_forest_labs/image_edit/transformation.py +++ b/litellm/llms/black_forest_labs/image_edit/transformation.py @@ -201,7 +201,9 @@ def _read_image_bytes( return image elif isinstance(image, list): # If it's a list, take the first image - return self._read_image_bytes(image[0], depth=depth + 1, max_depth=max_depth) + return self._read_image_bytes( + image[0], depth=depth + 1, max_depth=max_depth + ) elif isinstance(image, str): if image.startswith(("http://", "https://")): # Download image from URL diff --git a/litellm/llms/perplexity/responses/transformation.py b/litellm/llms/perplexity/responses/transformation.py index c7ec1313566..e09dc01f1c1 100644 --- a/litellm/llms/perplexity/responses/transformation.py +++ 
b/litellm/llms/perplexity/responses/transformation.py @@ -71,7 +71,9 @@ def _ensure_message_type( result: List[Any] = [] for item in input: if isinstance(item, dict) and "type" not in item: - new_item = dict(item) # convert to plain dict to avoid TypedDict checking + new_item = dict( + item + ) # convert to plain dict to avoid TypedDict checking new_item["type"] = "message" result.append(new_item) else: diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index 307caa2fbc8..1bec0d23c91 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -378,9 +378,7 @@ async def _list_tools_for_single_server( # Resolve a server name to its UUID if needed _name_resolved = None if server_id not in allowed_server_ids: - _name_resolved = global_mcp_server_manager.get_mcp_server_by_name( - server_id - ) + _name_resolved = global_mcp_server_manager.get_mcp_server_by_name(server_id) if _name_resolved is not None and _name_resolved.server_id in set( allowed_server_ids ): @@ -442,9 +440,7 @@ async def _list_tools_for_single_server( extra_headers=user_oauth_extra_headers, ) except Exception as e: - verbose_logger.exception( - f"Error getting tools from {server.name}: {e}" - ) + verbose_logger.exception(f"Error getting tools from {server.name}: {e}") return { "tools": [], "error": "server_error", @@ -473,7 +469,9 @@ async def _list_tools_for_single_server( _name_resolved = None if server_id not in allowed_server_ids: _name_resolved = global_mcp_server_manager.get_mcp_server_by_name(server_id) - if _name_resolved is not None and _name_resolved.server_id in set(allowed_server_ids): + if _name_resolved is not None and _name_resolved.server_id in set( + allowed_server_ids + ): server_id = _name_resolved.server_id if server_id not in allowed_server_ids: @@ -518,7 +516,9 @@ async def _list_tools_for_single_server( server_auth_header = 
_get_server_auth_header( server, mcp_server_auth_headers, mcp_auth_header ) - user_oauth_extra_headers = await _get_user_oauth_extra_headers(server, user_api_key_dict) + user_oauth_extra_headers = await _get_user_oauth_extra_headers( + server, user_api_key_dict + ) try: list_tools_result = await _get_tools_for_single_server( @@ -529,9 +529,7 @@ async def _list_tools_for_single_server( extra_headers=user_oauth_extra_headers, ) except Exception as e: - verbose_logger.exception( - f"Error getting tools from {server.name}: {e}" - ) + verbose_logger.exception(f"Error getting tools from {server.name}: {e}") return { "tools": [], "error": "server_error", @@ -905,7 +903,9 @@ async def _execute_with_mcp_client( try: client_id, client_secret, scopes = _extract_credentials(request) - _oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = ( + _oauth2_flow: Optional[ + Literal["client_credentials", "authorization_code"] + ] = ( "client_credentials" if client_id and client_secret and request.token_url else None diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index c4adecdab44..376048e7a13 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -64,7 +64,11 @@ populate_request_with_path_params, ) from litellm.proxy.common_utils.realtime_utils import _realtime_request_body -from litellm.proxy.utils import PrismaClient, ProxyLogging, normalize_route_for_root_path +from litellm.proxy.utils import ( + PrismaClient, + ProxyLogging, + normalize_route_for_root_path, +) from litellm.secret_managers.main import get_secret_bool from litellm.types.services import ServiceTypes diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py index 48ffab39bd0..0f4ebbd4880 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py +++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py @@ -106,9 +106,13 @@ def 
__init__( if (self.output_parse_pii or self.apply_to_output) and not logging_only: current_hook = self.event_hook if isinstance(current_hook, str) and current_hook != "post_call": - self.event_hook = cast(List[GuardrailEventHooks], [current_hook, "post_call"]) + self.event_hook = cast( + List[GuardrailEventHooks], [current_hook, "post_call"] + ) elif isinstance(current_hook, list) and "post_call" not in current_hook: - self.event_hook = cast(List[GuardrailEventHooks], current_hook + ["post_call"]) + self.event_hook = cast( + List[GuardrailEventHooks], current_hook + ["post_call"] + ) self.pii_entities_config: Dict[Union[PiiEntityType, str], PiiAction] = ( pii_entities_config or {} ) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index db1a089ff7b..3049a9f43cf 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -1838,9 +1838,7 @@ async def _validate_update_key_data( # Check team limits if key has a team_id (from request or existing key) team_obj: Optional[LiteLLM_TeamTableCachedObj] = None - _team_id_to_check = data.team_id or getattr( - existing_key_row, "team_id", None - ) + _team_id_to_check = data.team_id or getattr(existing_key_row, "team_id", None) if _team_id_to_check is not None: team_obj = await get_team_object( team_id=_team_id_to_check, @@ -1910,9 +1908,7 @@ async def _validate_update_key_data( if team_obj is None: raise HTTPException( status_code=500, - detail={ - "error": "Team object not found for team change validation" - }, + detail={"error": "Team object not found for team change validation"}, ) await validate_key_team_change( key=existing_key_row, diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index d55aa85a9b8..daf1d6f1316 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ 
b/litellm/proxy/management_endpoints/ui_sso.py @@ -846,12 +846,16 @@ def response_convertor(response, client): verbose_proxy_logger.debug("calling generic_sso.verify_and_process") additional_generic_sso_headers_dict = _parse_generic_sso_headers() - code_verifier: Optional[str] = None # assigned inside try; initialized for type tracking + code_verifier: Optional[ + str + ] = None # assigned inside try; initialized for type tracking try: - token_exchange_params = await SSOAuthenticationHandler.prepare_token_exchange_parameters( - request=request, - generic_include_client_id=generic_include_client_id, + token_exchange_params = ( + await SSOAuthenticationHandler.prepare_token_exchange_parameters( + request=request, + generic_include_client_id=generic_include_client_id, + ) ) # Extract code_verifier (and the cache key for deferred deletion) before calling fastapi-sso @@ -915,7 +919,9 @@ def response_convertor(response, client): # Assign directly rather than relying on nonlocal mutation so that Pyright # can track that received_response is non-None from this point on. received_response = { - k: v for k, v in combined_response.items() if k not in _OAUTH_TOKEN_FIELDS + k: v + for k, v in combined_response.items() + if k not in _OAUTH_TOKEN_FIELDS } # In the PKCE path verify_and_process is skipped, so generic_sso.access_token # is never set. Read the token directly from the exchange response instead so @@ -2598,7 +2604,9 @@ async def prepare_token_exchange_parameters( state, ) else: - verbose_proxy_logger.debug("PKCE code_verifier retrieved from cache") + verbose_proxy_logger.debug( + "PKCE code_verifier retrieved from cache" + ) elif isinstance(cached_data, str): # Handle legacy format (plain string) for backward compatibility code_verifier = cached_data @@ -2647,7 +2655,9 @@ async def _handle_missing_pkce_verifier( In strict mode (PKCE_STRICT_CACHE_MISS=true) raises ProxyException. Otherwise logs a warning and returns (token exchange proceeds without verifier). 
""" - active_cache = redis_usage_cache if redis_usage_cache is not None else user_api_key_cache + active_cache = ( + redis_usage_cache if redis_usage_cache is not None else user_api_key_cache + ) strict_cache_miss = ( os.getenv("PKCE_STRICT_CACHE_MISS", "false").lower() == "true" ) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index b9a1bfb9062..01a0f55aac7 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -5233,7 +5233,7 @@ def normalize_route_for_root_path(route: str) -> Optional[str]: root_path = get_server_root_path() if root_path and root_path != "/": if route.startswith(root_path + "/"): - return route[len(root_path):] + return route[len(root_path) :] return None return route diff --git a/litellm/responses/litellm_completion_transformation/transformation.py b/litellm/responses/litellm_completion_transformation/transformation.py index e9333c7dfab..71fa88fb751 100644 --- a/litellm/responses/litellm_completion_transformation/transformation.py +++ b/litellm/responses/litellm_completion_transformation/transformation.py @@ -415,7 +415,9 @@ def _transform_response_input_param_to_chat_completion_message( if isinstance(new_msg, dict) else getattr(new_msg, "tool_calls", None) ) - new_tcs: list = _raw_tcs if isinstance(_raw_tcs, list) else [] + new_tcs: list = ( + _raw_tcs if isinstance(_raw_tcs, list) else [] + ) for tc in new_tcs: LiteLLMCompletionResponsesConfig._add_tool_call_to_assistant( last_msg, tc diff --git a/litellm/router.py b/litellm/router.py index a63d5d91d1e..295e937bfb6 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -14,6 +14,7 @@ import inspect import json import logging +import re import threading import time import traceback @@ -1551,7 +1552,9 @@ async def _run_silent_completion(): # Drain any fire-and-forget tasks (e.g. alerting hooks) # scheduled via asyncio.create_task during acompletion. 
pending = asyncio.all_tasks() - pending.discard(asyncio.current_task()) + current = asyncio.current_task() + if current is not None: + pending.discard(current) if pending: await asyncio.gather(*pending, return_exceptions=True) @@ -6542,6 +6545,18 @@ def _create_deployment( ) return None + # Validate tag_regex patterns BEFORE adding the deployment so we never + # have partially-initialised router state if a pattern is invalid. + _tag_regex = deployment.litellm_params.get("tag_regex") or [] + for pattern in _tag_regex: + try: + re.compile(pattern) + except re.error as exc: + raise ValueError( + f"Invalid regex in tag_regex for model '{deployment.model_name}': " + f"{pattern!r} — {exc}" + ) from exc + deployment = self._add_deployment(deployment=deployment) model = deployment.to_json(exclude_none=True) diff --git a/litellm/router_strategy/tag_based_routing.py b/litellm/router_strategy/tag_based_routing.py index e7156bf1282..1309846c102 100644 --- a/litellm/router_strategy/tag_based_routing.py +++ b/litellm/router_strategy/tag_based_routing.py @@ -6,6 +6,7 @@ - If no default_deployments are set, return all deployments """ +import re from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union from litellm._logging import verbose_logger @@ -19,6 +20,29 @@ LitellmRouter = Any +def _is_valid_deployment_tag_regex( + tag_regexes: List[str], + header_strings: List[str], +) -> Optional[str]: + """ + Test compiled regex patterns against "Header-Name: value" strings. + + Returns the first matching pattern string, or None if nothing matches. + Compiles each pattern once (re's LRU cache) and logs invalid patterns once + per pattern, not once per header string. 
+ """ + for pattern in tag_regexes: + try: + compiled = re.compile(pattern) + except re.error: + verbose_logger.warning("tag_regex: invalid pattern %r — skipping", pattern) + continue + for header_str in header_strings: + if compiled.search(header_str): + return pattern + return None + + def is_valid_deployment_tag( deployment_tags: List[str], request_tags: List[str], match_any: bool = True ) -> bool: @@ -47,6 +71,54 @@ def is_valid_deployment_tag( return False +def _match_deployment( + deployment: Any, + request_tags: Optional[List[str]], + header_strings: List[str], + match_any: bool, +) -> Optional[Dict[str, str]]: + """ + Determine whether *deployment* matches the current request. + + Returns {"matched_via": ..., "matched_value": ...} if the deployment + should be included, or None if it should be excluded. + + Priority: + 1. Exact tag match (respects match_any semantics). + 2. Regex match — skipped when match_any=False and the tag check already + ran and failed, so the regex cannot override strict-tag policy. + """ + litellm_params = deployment.get("litellm_params", {}) + deployment_tags: Optional[List[str]] = litellm_params.get("tags") + deployment_tag_regex: Optional[List[str]] = litellm_params.get("tag_regex") + + # 1. Exact tag match (existing behaviour). + if deployment_tags and request_tags: + if is_valid_deployment_tag(deployment_tags, request_tags, match_any): + matched_value = next( + (t for t in deployment_tags if t in set(request_tags)), + deployment_tags[0], + ) + return {"matched_via": "tags", "matched_value": matched_value} + + # 2. Regex match against request headers. + # When match_any=False and the deployment has both plain tags and tag_regex, + # the strict tag check has already failed (step 1 returned None). Allow + # the regex to fire only when the deployment has NO plain tags, so we never + # use regex as a backdoor around the operator's strict-tag policy. 
+ strict_tag_check_failed = ( + not match_any and bool(deployment_tags) and bool(request_tags) + ) + if deployment_tag_regex and header_strings and not strict_tag_check_failed: + regex_match = _is_valid_deployment_tag_regex( + deployment_tag_regex, header_strings + ) + if regex_match is not None: + return {"matched_via": "tag_regex", "matched_value": regex_match} + + return None + + async def get_deployments_for_tag( llm_router_instance: LitellmRouter, model: str, # used to raise the correct error @@ -83,30 +155,63 @@ async def get_deployments_for_tag( request_tags = metadata.get("tags") match_any = llm_router_instance.tag_filtering_match_any - new_healthy_deployments = [] - default_deployments = [] - if request_tags: + # Build header strings for regex matching from what the proxy already stores. + # Currently we match against User-Agent; format matches "^User-Agent: claude-code/..." + user_agent = metadata.get("user_agent", "") + header_strings: List[str] = [f"User-Agent: {user_agent}"] if user_agent else [] + + new_healthy_deployments: List[Any] = [] + default_deployments: List[Any] = [] + + # Only activate header-based regex filtering when at least one deployment in + # the candidate set has tag_regex configured. This preserves existing + # behaviour for operators who use plain tags: a request that carries a + # User-Agent (all proxy requests do) but targets deployments with no + # tag_regex will continue to use the original tag-only code path. 
+ has_regex_deployments = any( + d.get("litellm_params", {}).get("tag_regex") for d in healthy_deployments + ) + has_tag_filter = bool(request_tags) or ( + bool(header_strings) and has_regex_deployments + ) + if has_tag_filter: verbose_logger.debug( - "get_deployments_for_tag routing: router_keys: %s", request_tags + "get_deployments_for_tag routing: request_tags=%s user_agent=%s", + request_tags, + user_agent, ) - # example this can be router_keys=["free", "custom"] for deployment in healthy_deployments: - deployment_litellm_params = deployment.get("litellm_params") - deployment_tags = deployment_litellm_params.get("tags") + deployment_tags = deployment.get("litellm_params", {}).get("tags") - verbose_logger.debug( - "deployment: %s, deployment_router_keys: %s", - deployment, - deployment_tags, + match_result = _match_deployment( + deployment=deployment, + request_tags=request_tags, + header_strings=header_strings, + match_any=match_any, ) - if deployment_tags is None: - continue - - if is_valid_deployment_tag(deployment_tags, request_tags, match_any): + if match_result is not None: + verbose_logger.debug( + "tag routing match: deployment=%s matched_via=%s matched_value=%s", + deployment.get("model_name"), + match_result["matched_via"], + match_result["matched_value"], + ) + # Record provenance in metadata so it flows to SpendLogs. + # Written only for the first match — load balancer selects one + # deployment from new_healthy_deployments, so overwriting on + # subsequent matches would produce misleading observability data. 
+ if "tag_routing" not in metadata: + metadata["tag_routing"] = { + "matched_deployment": deployment.get("model_name"), + "matched_via": match_result["matched_via"], + "matched_value": match_result["matched_value"], + "request_tags": request_tags or [], + "user_agent": user_agent, + } new_healthy_deployments.append(deployment) - if "default" in deployment_tags: + if deployment_tags and "default" in deployment_tags: default_deployments.append(deployment) if len(new_healthy_deployments) == 0 and len(default_deployments) == 0: diff --git a/litellm/types/router.py b/litellm/types/router.py index f0c1ea5e32a..e8ff2115ff5 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -198,6 +198,11 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams): model_info: Optional[Dict] = None mock_response: Optional[Union[str, ModelResponse, Exception, Any]] = None + # tag-based routing + tags: Optional[List[str]] = None + # regex patterns matched against request headers for tag routing + tag_regex: Optional[List[str]] = None + # auto-router params auto_router_config_path: Optional[str] = None auto_router_config: Optional[str] = None @@ -334,6 +339,8 @@ class LiteLLMParamsTypedDict(TypedDict, total=False): # routing params # use this for tag-based routing tags: Optional[List[str]] + # regex patterns matched against request headers (e.g. 
"^User-Agent:\\s*claude-code\\/") + tag_regex: Optional[List[str]] # deployment budgets max_budget: Optional[float] diff --git a/tests/test_litellm/llms/bedrock/files/test_bedrock_files_transformation.py b/tests/test_litellm/llms/bedrock/files/test_bedrock_files_transformation.py index 88cac84e438..40a17c12118 100644 --- a/tests/test_litellm/llms/bedrock/files/test_bedrock_files_transformation.py +++ b/tests/test_litellm/llms/bedrock/files/test_bedrock_files_transformation.py @@ -272,6 +272,164 @@ def test_anthropic_still_works_after_nova_fix(self): assert "messages" in model_input assert "max_tokens" in model_input + def test_get_complete_file_url_respects_s3_region_name(self): + """ + s3_region_name in litellm_params must be used when building the S3 URL. + Previously the code fell back to us-west-2 even when s3_region_name was set, + breaking GovCloud (us-gov-west-1) deployments. + """ + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + + jsonl_content = json.dumps( + { + "custom_id": "req-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "bedrock/amazon.nova-pro-v1:0", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ).encode() + + create_file_data = { + "file": ("batch.jsonl", jsonl_content, "application/jsonl"), + "purpose": "batch", + } + + litellm_params = { + "s3_bucket_name": "litellm-batch-352026", + "s3_region_name": "us-gov-west-1", + } + + url = config.get_complete_file_url( + api_base=None, + api_key=None, + model="amazon.nova-pro-v1:0", + optional_params={}, + litellm_params=litellm_params, + data=create_file_data, + ) + + assert "us-gov-west-1" in url, ( + f"Expected us-gov-west-1 in URL but got: {url}" + ) + assert "us-west-2" not in url, ( + f"us-west-2 must not appear when s3_region_name is set, got: {url}" + ) + assert "litellm-batch-352026" in url + + def 
test_transform_create_file_request_injects_s3_region_for_signing(self): + """ + When s3_region_name is provided, transform_create_file_request must pass + that region to _sign_s3_request so SigV4 signatures use the correct region. + """ + from unittest.mock import patch + + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + + jsonl_content = json.dumps( + { + "custom_id": "req-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "bedrock/amazon.nova-pro-v1:0", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ).encode() + + create_file_data = { + "file": ("batch.jsonl", jsonl_content, "application/jsonl"), + "purpose": "batch", + } + + litellm_params = { + "s3_bucket_name": "litellm-batch-352026", + "s3_region_name": "us-gov-west-1", + } + + captured_optional_params: dict = {} + + def fake_sign(content, api_base, optional_params): + captured_optional_params.update(optional_params) + return {"Authorization": "fake"}, content + + with patch.object(config, "_sign_s3_request", side_effect=fake_sign): + config.transform_create_file_request( + model="amazon.nova-pro-v1:0", + create_file_data=create_file_data, + optional_params={}, + litellm_params=litellm_params, + ) + + assert captured_optional_params.get("aws_region_name") == "us-gov-west-1", ( + "s3_region_name must be forwarded as aws_region_name for SigV4 signing" + ) + + def test_s3_region_name_wins_over_aws_region_name_for_signing(self): + """ + When both s3_region_name and aws_region_name are set to different values, + s3_region_name must win for signing (same as for the URL). Otherwise the + SigV4 signature would be computed against a different region than the URL, + causing SignatureDoesNotMatch from AWS. 
+ """ + from unittest.mock import patch + + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + + jsonl_content = json.dumps( + { + "custom_id": "req-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "bedrock/amazon.nova-pro-v1:0", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ).encode() + + create_file_data = { + "file": ("batch.jsonl", jsonl_content, "application/jsonl"), + "purpose": "batch", + } + + litellm_params = { + "s3_bucket_name": "litellm-batch-352026", + "s3_region_name": "us-gov-west-1", + } + # aws_region_name set to something different — s3_region_name must still win + optional_params = {"aws_region_name": "us-east-1"} + + captured_optional_params: dict = {} + + def fake_sign(content, api_base, optional_params): + captured_optional_params.update(optional_params) + return {"Authorization": "fake"}, content + + with patch.object(config, "_sign_s3_request", side_effect=fake_sign): + config.transform_create_file_request( + model="amazon.nova-pro-v1:0", + create_file_data=create_file_data, + optional_params=optional_params, + litellm_params=litellm_params, + ) + + assert captured_optional_params.get("aws_region_name") == "us-gov-west-1", ( + "s3_region_name must override aws_region_name for SigV4 signing" + ) + def test_openai_passthrough_still_works(self): """ Regression test: ensure OpenAI-compatible models (e.g. 
gpt-oss) diff --git a/tests/test_litellm/llms/bedrock/test_base_aws_llm.py b/tests/test_litellm/llms/bedrock/test_base_aws_llm.py index 18fc7c6173e..29ed345d2de 100644 --- a/tests/test_litellm/llms/bedrock/test_base_aws_llm.py +++ b/tests/test_litellm/llms/bedrock/test_base_aws_llm.py @@ -14,15 +14,16 @@ from typing import Any, Dict from unittest.mock import MagicMock, patch +from botocore.awsrequest import AWSPreparedRequest, AWSRequest from botocore.credentials import Credentials -from botocore.awsrequest import AWSRequest, AWSPreparedRequest + import litellm +from litellm.caching.caching import DualCache from litellm.llms.bedrock.base_aws_llm import ( AwsAuthError, BaseAWSLLM, Boto3CredentialsInfo, ) -from litellm.caching.caching import DualCache # Global variable for the base_aws_llm.py file path @@ -1519,6 +1520,83 @@ def test_is_already_running_as_role_invalid_target_arn(): assert base_aws_llm._is_already_running_as_role("not-a-valid-arn") is False +def test_filter_headers_skips_none_values(): + """ + Test that _filter_headers_for_aws_signature skips headers with None values. + + Reproduces the issue where botocore's SigV4Auth crashes with + 'NoneType' object has no attribute 'split' when a header value is None. 
+ """ + llm = BaseAWSLLM() + + headers = { + "Content-Type": "application/json", + "x-amz-security-token": None, + "x-amzn-bedrock-kb-session-id": None, + "host": None, + "x-amz-date": "20240101T000000Z", + "x-custom-header": None, + } + + filtered = llm._filter_headers_for_aws_signature(headers) + + assert filtered["Content-Type"] == "application/json" + assert filtered["x-amz-date"] == "20240101T000000Z" + assert "x-amz-security-token" not in filtered + assert "x-amzn-bedrock-kb-session-id" not in filtered + assert "host" not in filtered + # Non-AWS headers are excluded regardless + assert "x-custom-header" not in filtered + + +def test_sign_request_with_none_header_values(): + """ + End-to-end test that _sign_request does not crash when headers contain + None values for x-amz-* keys. + + This reproduces the Bedrock KB GovCloud issue where SigV4 signing failed + with 'NoneType' object has no attribute 'split'. + + Also verifies that None-valued headers are NOT re-merged into the + returned headers dict (which would cause downstream HTTP client failures). 
+ """ + llm = BaseAWSLLM() + + mock_credentials = Credentials("test_key", "test_secret") + + headers_with_nones = { + "Content-Type": "application/json", + "x-amzn-trace-id": None, + "x-forwarded-for": None, + } + + with patch.object( + llm, "get_credentials", return_value=mock_credentials + ), patch.object( + llm, "_get_aws_region_name", return_value="us-gov-west-1" + ): + result_headers, result_body = llm._sign_request( + service_name="bedrock", + headers=headers_with_nones, + optional_params={ + "aws_access_key_id": "test_key", + "aws_secret_access_key": "test_secret", + "aws_region_name": "us-gov-west-1", + }, + request_data={"retrievalQuery": {"text": "test query"}}, + api_base="https://bedrock-agent-runtime.us-gov-west-1.amazonaws.com/knowledgebases/KB123/retrieve", + ) + + assert "Authorization" in result_headers + assert result_body is not None + + # None-valued headers must NOT appear in the returned headers + for header_name, header_value in result_headers.items(): + assert header_value is not None, ( + f"Header '{header_name}' has None value in returned headers" + ) + + def test_is_already_running_as_role_ssl_verify_passed(): """ Test that ssl_verify parameter is correctly passed to the STS client. diff --git a/tests/test_litellm/router_strategy/test_router_tag_regex_routing.py b/tests/test_litellm/router_strategy/test_router_tag_regex_routing.py new file mode 100644 index 00000000000..6c7cfa61b58 --- /dev/null +++ b/tests/test_litellm/router_strategy/test_router_tag_regex_routing.py @@ -0,0 +1,375 @@ +""" +Unit tests for tag_regex routing. + +Tests _is_valid_deployment_tag_regex() and get_deployments_for_tag() with tag_regex +patterns, verifying that regex-based header matching works correctly alongside +existing tag-based routing. 
+""" + +import os +import sys + +import pytest + +sys.path.insert(0, os.path.abspath("../..")) + +from unittest.mock import MagicMock + +from litellm.router_strategy import tag_based_routing +from litellm.router_strategy.tag_based_routing import get_deployments_for_tag + +_is_valid_deployment_tag_regex = tag_based_routing._is_valid_deployment_tag_regex + + +# --------------------------------------------------------------------------- +# _is_valid_deployment_tag_regex unit tests +# --------------------------------------------------------------------------- + + +def test_regex_matches_claude_code_user_agent(): + """^User-Agent: claude-code/ matches a claude-code UA string.""" + result = _is_valid_deployment_tag_regex( + tag_regexes=[r"^User-Agent: claude-code\/"], + header_strings=["User-Agent: claude-code/1.2.3"], + ) + assert result == r"^User-Agent: claude-code\/" + + +def test_regex_no_match_for_other_ua(): + """Pattern does not match a non-claude-code User-Agent.""" + result = _is_valid_deployment_tag_regex( + tag_regexes=[r"^User-Agent: claude-code\/"], + header_strings=["User-Agent: Mozilla/5.0 (browser)"], + ) + assert result is None + + +def test_regex_returns_first_matching_pattern(): + """When multiple patterns are provided, returns the first match.""" + result = _is_valid_deployment_tag_regex( + tag_regexes=[r"^User-Agent: cursor\/", r"^User-Agent: claude-code\/"], + header_strings=["User-Agent: claude-code/2.0.0"], + ) + assert result == r"^User-Agent: claude-code\/" + + +def test_regex_empty_inputs_return_none(): + """Empty lists return None without errors.""" + assert _is_valid_deployment_tag_regex([], ["User-Agent: claude-code/1.0"]) is None + assert _is_valid_deployment_tag_regex([r"^User-Agent: claude-code\/"], []) is None + + +def test_invalid_regex_skipped_does_not_raise(): + """An invalid regex pattern is skipped (warning logged) — no exception raised.""" + result = _is_valid_deployment_tag_regex( + tag_regexes=["[invalid(regex"], + 
header_strings=["User-Agent: claude-code/1.0"], + ) + assert result is None + + +def test_regex_matches_version_range(): + """Semver-aware pattern matches multiple versions.""" + pattern = r"^User-Agent: claude-code\/\d" + for ua in ["claude-code/1.0", "claude-code/2.0.0-beta.1", "claude-code/99.0"]: + result = _is_valid_deployment_tag_regex( + tag_regexes=[pattern], + header_strings=[f"User-Agent: {ua}"], + ) + assert result == pattern, f"Expected match for UA: {ua}" + + +# --------------------------------------------------------------------------- +# get_deployments_for_tag integration tests +# --------------------------------------------------------------------------- + +CLAUDE_CODE_DEPLOYMENT = { + "model_name": "claude-sonnet", + "litellm_params": { + "model": "openai/claude-code-deployment", + "api_key": "fake", + "mock_response": "cc", + "tag_regex": [r"^User-Agent: claude-code\/"], + }, + "model_info": {"id": "claude-code-deployment"}, +} + +REGULAR_DEPLOYMENT = { + "model_name": "claude-sonnet", + "litellm_params": { + "model": "openai/regular-deployment", + "api_key": "fake", + "mock_response": "regular", + "tags": ["default"], + }, + "model_info": {"id": "regular-deployment"}, +} + +ALL_DEPLOYMENTS = [CLAUDE_CODE_DEPLOYMENT, REGULAR_DEPLOYMENT] + + +def _make_router_mock(enable_tag_filtering=True, match_any=True): + mock = MagicMock() + mock.enable_tag_filtering = enable_tag_filtering + mock.tag_filtering_match_any = match_any + return mock + + +@pytest.mark.asyncio +async def test_claude_code_ua_routes_to_cc_deployment(): + """claude-code/x.y.z UA → claude-code-deployment via tag_regex.""" + router = _make_router_mock() + result = await get_deployments_for_tag( + llm_router_instance=router, + model="claude-sonnet", + healthy_deployments=ALL_DEPLOYMENTS, + request_kwargs={"metadata": {"user_agent": "claude-code/1.2.3"}}, + ) + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "claude-code-deployment" + + +@pytest.mark.asyncio +async def 
test_regular_ua_routes_to_default_deployment(): + """Mozilla UA → regular-deployment via default tag fallback.""" + router = _make_router_mock() + result = await get_deployments_for_tag( + llm_router_instance=router, + model="claude-sonnet", + healthy_deployments=ALL_DEPLOYMENTS, + request_kwargs={"metadata": {"user_agent": "Mozilla/5.0 (browser)"}}, + ) + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "regular-deployment" + + +@pytest.mark.asyncio +async def test_no_ua_routes_to_default_deployment(): + """No User-Agent → default deployment.""" + router = _make_router_mock() + result = await get_deployments_for_tag( + llm_router_instance=router, + model="claude-sonnet", + healthy_deployments=ALL_DEPLOYMENTS, + request_kwargs={"metadata": {}}, + ) + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "regular-deployment" + + +@pytest.mark.asyncio +async def test_tag_routing_metadata_written_for_regex_match(): + """tag_routing metadata block is populated when regex matches.""" + router = _make_router_mock() + metadata: dict = {"user_agent": "claude-code/2.0.0-beta.1"} + await get_deployments_for_tag( + llm_router_instance=router, + model="claude-sonnet", + healthy_deployments=ALL_DEPLOYMENTS, + request_kwargs={"metadata": metadata}, + ) + assert "tag_routing" in metadata + tr = metadata["tag_routing"] + assert tr["matched_via"] == "tag_regex" + assert tr["matched_value"] == r"^User-Agent: claude-code\/" + assert tr["user_agent"] == "claude-code/2.0.0-beta.1" + + +@pytest.mark.asyncio +async def test_tag_filtering_disabled_returns_all_deployments(): + """When enable_tag_filtering is False, all deployments returned regardless of UA.""" + router = _make_router_mock(enable_tag_filtering=False) + result = await get_deployments_for_tag( + llm_router_instance=router, + model="claude-sonnet", + healthy_deployments=ALL_DEPLOYMENTS, + request_kwargs={"metadata": {"user_agent": "claude-code/1.0"}}, + ) + assert result == ALL_DEPLOYMENTS + + 
+@pytest.mark.asyncio +async def test_explicit_tag_match_takes_precedence_over_regex(): + """A deployment with both tags and tag_regex: exact tag match fires first.""" + deployment_with_both = { + "model_name": "claude-sonnet", + "litellm_params": { + "model": "openai/both-deployment", + "api_key": "fake", + "tags": ["premium"], + "tag_regex": [r"^User-Agent: claude-code\/"], + }, + "model_info": {"id": "both-deployment"}, + } + router = _make_router_mock() + metadata: dict = { + "tags": ["premium"], + "user_agent": "claude-code/1.0", + } + result = await get_deployments_for_tag( + llm_router_instance=router, + model="claude-sonnet", + healthy_deployments=[deployment_with_both], + request_kwargs={"metadata": metadata}, + ) + assert len(result) == 1 + tr = metadata.get("tag_routing", {}) + assert tr.get("matched_via") == "tags" + + +@pytest.mark.asyncio +async def test_user_agent_present_no_tag_regex_deployments_does_not_raise(): + """ + Backwards-compat: a request that carries a User-Agent but targets plain-tag + deployments (no tag_regex) must NOT raise ValueError — it should fall + through to the default/all-deployments path just like before. + """ + plain_tag_only_deployments = [ + { + "model_name": "gpt-4", + "litellm_params": { + "model": "openai/premium-deployment", + "api_key": "fake", + "tags": ["premium"], + }, + "model_info": {"id": "premium-deployment"}, + }, + { + "model_name": "gpt-4", + "litellm_params": { + "model": "openai/free-deployment", + "api_key": "fake", + "tags": ["free"], + }, + "model_info": {"id": "free-deployment"}, + }, + ] + router = _make_router_mock() + # The request has a User-Agent (as all proxy requests do) but NO tags and + # neither deployment has tag_regex — must not raise, must return all. 
+ result = await get_deployments_for_tag( + llm_router_instance=router, + model="gpt-4", + healthy_deployments=plain_tag_only_deployments, + request_kwargs={"metadata": {"user_agent": "Mozilla/5.0 (any-client)"}}, + ) + # Falls through to "return healthy_deployments" path unchanged + assert result == plain_tag_only_deployments + + +@pytest.mark.asyncio +async def test_tag_routing_metadata_not_overwritten_for_multiple_matches(): + """ + When multiple deployments match, tag_routing records only the first match + so the provenance reflects what the load balancer likely selected. + """ + deployment_a = { + "model_name": "claude-sonnet", + "litellm_params": { + "model": "openai/cc-deployment-a", + "api_key": "fake", + "tag_regex": [r"^User-Agent: claude-code\/"], + }, + "model_info": {"id": "cc-deployment-a"}, + } + deployment_b = { + "model_name": "claude-sonnet", + "litellm_params": { + "model": "openai/cc-deployment-b", + "api_key": "fake", + "tag_regex": [r"^User-Agent: claude-code\/"], + }, + "model_info": {"id": "cc-deployment-b"}, + } + router = _make_router_mock() + metadata: dict = {"user_agent": "claude-code/1.0"} + result = await get_deployments_for_tag( + llm_router_instance=router, + model="claude-sonnet", + healthy_deployments=[deployment_a, deployment_b], + request_kwargs={"metadata": metadata}, + ) + assert len(result) == 2 + # tag_routing recorded once and reflects the first match + tr = metadata.get("tag_routing", {}) + assert tr.get("matched_deployment") == "claude-sonnet" + assert tr.get("matched_via") == "tag_regex" + + +@pytest.mark.asyncio +async def test_match_any_false_strict_tag_check_blocks_regex_fallback(): + """ + When match_any=False and a deployment has both tags and tag_regex: + if the strict tag check fails (request has a tag NOT present on the + deployment, so req_set is NOT a subset of dep_set), the regex fallback + must NOT fire — that would violate the operator's strict-filtering intent. 
+ + Semantics of match_any=False: req_set.issubset(dep_set), i.e. every + request tag must appear on the deployment. A request with tags ["vip"] + against a deployment with tags ["premium"] fails because "vip" ∉ dep_set. + """ + deployment_strict = { + "model_name": "claude-sonnet", + "litellm_params": { + "model": "openai/strict-deployment", + "api_key": "fake", + "tags": ["premium"], + "tag_regex": [r"^User-Agent: claude-code\/"], + }, + "model_info": {"id": "strict-deployment"}, + } + default_deployment = { + "model_name": "claude-sonnet", + "litellm_params": { + "model": "openai/default-deployment", + "api_key": "fake", + "tags": ["default"], + }, + "model_info": {"id": "default-deployment"}, + } + # match_any=False: req_set must be a subset of dep_set. + # Request has "vip" which is NOT in ["premium"], so tag check fails. + # Even though UA matches tag_regex, the deployment must NOT be selected. + router = _make_router_mock(enable_tag_filtering=True, match_any=False) + metadata: dict = { + "tags": ["vip"], # "vip" not in deployment tags → strict check fails + "user_agent": "claude-code/1.0", + } + result = await get_deployments_for_tag( + llm_router_instance=router, + model="claude-sonnet", + healthy_deployments=[deployment_strict, default_deployment], + request_kwargs={"metadata": metadata}, + ) + ids = [d["model_info"]["id"] for d in result] + assert "strict-deployment" not in ids, ( + "strict-deployment should not be selected: strict tag check failed " + "and regex must not override the strict policy" + ) + + +@pytest.mark.asyncio +async def test_match_any_false_regex_only_deployment_still_matches(): + """ + When match_any=False and a deployment has ONLY tag_regex (no plain tags), + there is no strict tag policy to violate, so the regex check must still fire. 
+ """ + regex_only_deployment = { + "model_name": "claude-sonnet", + "litellm_params": { + "model": "openai/regex-only-deployment", + "api_key": "fake", + "tag_regex": [r"^User-Agent: claude-code\/"], + # no "tags" key at all + }, + "model_info": {"id": "regex-only-deployment"}, + } + router = _make_router_mock(enable_tag_filtering=True, match_any=False) + result = await get_deployments_for_tag( + llm_router_instance=router, + model="claude-sonnet", + healthy_deployments=[regex_only_deployment], + request_kwargs={"metadata": {"user_agent": "claude-code/1.0"}}, + ) + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "regex-only-deployment" diff --git a/tests/test_litellm/test_model_cost_aliases.py b/tests/test_litellm/test_model_cost_aliases.py index 6e30cbfe157..f9f92a85cb2 100644 --- a/tests/test_litellm/test_model_cost_aliases.py +++ b/tests/test_litellm/test_model_cost_aliases.py @@ -5,8 +5,9 @@ entries, creating shared dict references for alias entries at load time. """ -import logging +from unittest.mock import patch +from litellm import verbose_logger from litellm.litellm_core_utils.get_model_cost_map import _expand_model_aliases @@ -118,7 +119,7 @@ def test_empty_aliases_list(self): class TestAliasConflicts: """Tests for alias conflict detection and handling.""" - def test_alias_conflicts_with_canonical_entry(self, caplog): + def test_alias_conflicts_with_canonical_entry(self): """Alias that matches an existing canonical entry is skipped with a warning.""" model_cost = { "model-latest": { @@ -133,14 +134,17 @@ def test_alias_conflicts_with_canonical_entry(self, caplog): "mode": "chat", }, } - with caplog.at_level(logging.WARNING, logger="LiteLLM"): + with patch.object(verbose_logger, "warning") as mock_warn: result = _expand_model_aliases(model_cost) # The canonical "model-dated" entry is preserved, not overwritten assert "model-dated" in result - assert "alias conflict" in caplog.text.lower() + # Verify a warning about the alias conflict was 
logged + mock_warn.assert_called() + warning_messages = " ".join(str(c) for c in mock_warn.call_args_list) + assert "alias conflict" in warning_messages.lower() - def test_duplicate_alias_across_entries(self, caplog): + def test_duplicate_alias_across_entries(self): """Same alias claimed by two different entries: second one is skipped.""" model_cost = { "model-a": { @@ -156,13 +160,16 @@ def test_duplicate_alias_across_entries(self, caplog): "mode": "chat", }, } - with caplog.at_level(logging.WARNING, logger="LiteLLM"): + with patch.object(verbose_logger, "warning") as mock_warn: result = _expand_model_aliases(model_cost) # "shared-alias" should point to model-a (first one wins) assert "shared-alias" in result assert result["shared-alias"]["input_cost_per_token"] == 1e-06 - assert "alias conflict" in caplog.text.lower() + # Verify a warning about the alias conflict was logged + mock_warn.assert_called() + warning_messages = " ".join(str(c) for c in mock_warn.call_args_list) + assert "alias conflict" in warning_messages.lower() def test_canonical_entry_not_overwritten_by_alias(self): """An alias must never overwrite an existing canonical entry's data."""