diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py index 89a708077f3..6eb4ee74490 100644 --- a/litellm/litellm_core_utils/prompt_templates/factory.py +++ b/litellm/litellm_core_utils/prompt_templates/factory.py @@ -4468,9 +4468,10 @@ def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]: defs = parameters.pop("$defs", {}) defs_copy = copy.deepcopy(defs) - # flatten the defs - for _, value in defs_copy.items(): - unpack_defs(value, defs_copy) + # Expand $ref references in parameters using the definitions + # Note: We don't pre-flatten defs as that causes exponential memory growth + # with circular references (see issue #19098). unpack_defs handles nested + # refs recursively and correctly detects/skips circular references. unpack_defs(parameters, defs_copy) tool_input_schema = BedrockToolInputSchemaBlock( json=BedrockToolJsonSchemaBlock( diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py index 2aa6a00c72b..8550de504b8 100644 --- a/litellm/llms/vertex_ai/common_utils.py +++ b/litellm/llms/vertex_ai/common_utils.py @@ -453,9 +453,10 @@ def _build_vertex_schema(parameters: dict, add_property_ordering: bool = False): valid_schema_fields = set(get_type_hints(Schema).keys()) defs = parameters.pop("$defs", {}) - # flatten the defs - for name, value in defs.items(): - unpack_defs(value, defs) + # Expand $ref references in parameters using the definitions + # Note: We don't pre-flatten defs as that causes exponential memory growth + # with circular references (see issue #19098). unpack_defs handles nested + # refs recursively and correctly detects/skips circular references. unpack_defs(parameters, defs) # 5. Nullable fields: diff --git a/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py b/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py index 42a2b5d0971..a22fe13798f 100644 --- a/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py +++ b/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py @@ -1392,3 +1392,134 @@ def test_anthropic_messages_pt_server_tool_use_passthrough(): b for b in assistant_msg["content"] if b.get("type") == "text" ) assert text_block["text"] == "I found the time tool. How can I help you?" + + +def test_bedrock_tools_unpack_defs_no_oom_with_nested_refs(): + """ + Regression test for issue #19098: unpack_defs() causes OOM with nested tool schemas. + + The old implementation had a "flatten defs" loop that would pre-expand each def + using unpack_defs(), but since defs often reference each other, each subsequent + call would copy already-expanded content, causing exponential memory growth. + + This test creates a schema with multiple nested $defs that reference each other + to verify the fix prevents memory explosion while still correctly resolving refs. + """ + import sys + import copy + + from litellm.litellm_core_utils.prompt_templates.factory import _bedrock_tools_pt + + # Schema with multiple nested $defs that reference each other + # This pattern would cause OOM with the old "flatten defs" loop + complex_nested_schema = { + "type": "object", + "properties": { + "query": {"$ref": "#/$defs/Expression"}, + }, + "$defs": { + "Expression": { + "type": "object", + "properties": { + "type": {"type": "string", "enum": ["and", "or", "not", "comparison"]}, + "left": {"$ref": "#/$defs/Operand"}, + "right": {"$ref": "#/$defs/Operand"}, + "operator": {"$ref": "#/$defs/Operator"}, + }, + }, + "Operand": { + "type": "object", + "anyOf": [ + {"$ref": "#/$defs/Literal"}, + {"$ref": "#/$defs/FieldRef"}, + {"$ref": "#/$defs/Expression"}, # Circular: Operand -> Expression -> Operand + ], + }, + "Literal": { + "type": "object", + "properties": { + "type": {"type": "string", "const": "literal"}, + "value": {"$ref": "#/$defs/LiteralValue"}, + }, + }, + "LiteralValue": { + "oneOf": [ + {"type": "string"}, + {"type": "number"}, + {"type": "boolean"}, + {"type": "null"}, + ], + }, + "FieldRef": { + "type": "object", + "properties": { + "type": {"type": "string", "const": "field"}, + "name": {"type": "string"}, + "table": {"$ref": "#/$defs/TableRef"}, + }, + }, + "TableRef": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "alias": {"type": "string"}, + }, + }, + "Operator": { + "type": "string", + "enum": ["=", "!=", "<", ">", "<=", ">=", "LIKE", "IN"], + }, + }, + } + + tools = [ + { + "type": "function", + "function": { + "name": "execute_query", + "description": "Execute a query with complex expressions", + "parameters": complex_nested_schema, + }, + } + ] + + # Measure initial size + def get_size(obj, seen=None): + size = sys.getsizeof(obj) + if seen is None: + seen = set() + obj_id = id(obj) + if obj_id in seen: + return 0 + seen.add(obj_id) + if isinstance(obj, dict): + size += sum([get_size(v, seen) for v in obj.values()]) + size += sum([get_size(k, seen) for k in obj.keys()]) + elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, bytearray)): + size += sum([get_size(i, seen) for i in obj]) + return size + + initial_size = get_size(tools) + + # Process through _bedrock_tools_pt - this should complete without OOM + tools_copy = copy.deepcopy(tools) + result = _bedrock_tools_pt(tools=tools_copy) + + final_size = get_size(result) + + # The expansion factor should be reasonable (< 100x), not exponential (35000x as in #19098) + expansion_factor = final_size / initial_size + assert expansion_factor < 100, ( + f"Memory expansion factor {expansion_factor:.1f}x is too high. " + f"Initial: {initial_size} bytes, Final: {final_size} bytes" + ) + + # Verify the result is valid Bedrock tools format + assert isinstance(result, list) + assert len(result) == 1 + assert "toolSpec" in result[0] + assert result[0]["toolSpec"]["name"] == "execute_query" + + # Verify $defs have been removed (Bedrock doesn't support them) + tool_schema = result[0]["toolSpec"].get("inputSchema", {}).get("json", {}) + assert "$defs" not in tool_schema, "$defs should be removed after expansion"