|
6 | 6 | from litellm import ContextWindowExceededError |
7 | 7 |
|
8 | 8 | import dspy |
9 | | -from dspy.adapters.types.tool import Tool |
| 9 | +from dspy.adapters.types.tool import Tool, ToolCalls |
10 | 10 | from dspy.primitives.module import Module |
11 | 11 | from dspy.signatures.signature import ensure_signature |
12 | 12 |
|
@@ -81,7 +81,7 @@ def get_weather(city: str) -> str: |
81 | 81 | dspy.Signature({**signature.input_fields}, "\n".join(instr)) |
82 | 82 | .append("trajectory", dspy.InputField(), type_=str) |
83 | 83 | .append("next_thought", dspy.OutputField(), type_=str) |
84 | | - .append("next_tool_calls", dspy.OutputField(), type_=list[dict[str, Any]]) |
| 84 | + .append("next_tool_calls", dspy.OutputField(), type_=ToolCalls) |
85 | 85 | ) |
86 | 86 |
|
87 | 87 | fallback_signature = dspy.Signature( |
@@ -173,8 +173,12 @@ async def aforward(self, **input_args): |
173 | 173 | def _parse_tool_calls(self, tool_calls_data): |
174 | 174 | """Parse tool calls from the prediction output. |
175 | 175 |
|
176 | | - Handles both the new list format and provides backward compatibility. |
| 176 | + Handles both ToolCalls objects and list formats for backward compatibility. |
177 | 177 | """ |
| 178 | + # If it's a ToolCalls object, extract the list of tool calls |
| 179 | + if isinstance(tool_calls_data, ToolCalls): |
| 180 | + return [{"name": tc.name, "args": tc.args} for tc in tool_calls_data.tool_calls] |
| 181 | + |
178 | 182 | # If it's already a list of dicts with 'name' and 'args', use it directly |
179 | 183 | if isinstance(tool_calls_data, list): |
180 | 184 | return tool_calls_data |
|
0 commit comments