1414import os
1515import random
1616from concurrent .futures import ThreadPoolExecutor
17- from typing import Any , AsyncIterator , Callable , Generator , Mapping , Optional , Type , TypeVar , Union , cast
17+ from typing import Any , AsyncIterator , Callable , Generator , List , Mapping , Optional , Type , TypeVar , Union , cast
1818
1919from opentelemetry import trace
2020from pydantic import BaseModel
3131from ..types .content import ContentBlock , Message , Messages
3232from ..types .exceptions import ContextWindowOverflowException
3333from ..types .models import Model
34- from ..types .tools import ToolConfig
34+ from ..types .tools import ToolConfig , ToolResult , ToolUse
3535from ..types .traces import AttributeValue
3636from .agent_result import AgentResult
3737from .conversation_manager import (
@@ -97,104 +97,56 @@ def __getattr__(self, name: str) -> Callable[..., Any]:
9797 AttributeError: If no tool with the given name exists or if multiple tools match the given name.
9898 """
9999
100- def find_normalized_tool_name () -> Optional [str ]:
101- """Lookup the tool represented by name, replacing characters with underscores as necessary."""
102- tool_registry = self ._agent .tool_registry .registry
103-
104- if tool_registry .get (name , None ):
105- return name
106-
107- # If the desired name contains underscores, it might be a placeholder for characters that can't be
108- # represented as python identifiers but are valid as tool names, such as dashes. In that case, find
109- # all tools that can be represented with the normalized name
110- if "_" in name :
111- filtered_tools = [
112- tool_name for (tool_name , tool ) in tool_registry .items () if tool_name .replace ("-" , "_" ) == name
113- ]
114-
115- # The registry itself defends against similar names, so we can just take the first match
116- if filtered_tools :
117- return filtered_tools [0 ]
118-
119- raise AttributeError (f"Tool '{ name } ' not found" )
120-
121- def caller (** kwargs : Any ) -> Any :
100+ def caller (
101+ user_message_override : Optional [str ] = None ,
102+ record_direct_tool_call : Optional [bool ] = None ,
103+ ** kwargs : Any ,
104+ ) -> Any :
122105 """Call a tool directly by name.
123106
124107 Args:
108+ user_message_override: Optional custom message to record instead of default
109+ record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class
110+ attribute if provided.
125111 **kwargs: Keyword arguments to pass to the tool.
126112
127- - user_message_override: Custom message to record instead of default
128- - tool_execution_handler: Custom handler for tool execution
129- - event_loop_metrics: Custom metrics collector
130- - messages: Custom message history to use
131- - tool_config: Custom tool configuration
132- - callback_handler: Custom callback handler
133- - record_direct_tool_call: Whether to record this call in history
134-
135113 Returns:
136114 The result returned by the tool.
137115
138116 Raises:
139117 AttributeError: If the tool doesn't exist.
140118 """
141- normalized_name = find_normalized_tool_name ( )
119+ normalized_name = self . _find_normalized_tool_name ( name )
142120
143121 # Create unique tool ID and set up the tool request
144122 tool_id = f"tooluse_{ name } _{ random .randint (100000000 , 999999999 )} "
145- tool_use = {
123+ tool_use : ToolUse = {
146124 "toolUseId" : tool_id ,
147125 "name" : normalized_name ,
148126 "input" : kwargs .copy (),
149127 }
150128
151- # Extract tool execution parameters
152- user_message_override = kwargs .get ("user_message_override" , None )
153- tool_execution_handler = kwargs .get ("tool_execution_handler" , self ._agent .thread_pool_wrapper )
154- event_loop_metrics = kwargs .get ("event_loop_metrics" , self ._agent .event_loop_metrics )
155- messages = kwargs .get ("messages" , self ._agent .messages )
156- tool_config = kwargs .get ("tool_config" , self ._agent .tool_config )
157- callback_handler = kwargs .get ("callback_handler" , self ._agent .callback_handler )
158- record_direct_tool_call = kwargs .get ("record_direct_tool_call" , self ._agent .record_direct_tool_call )
159-
160- # Process tool call
161- handler_kwargs = {
162- k : v
163- for k , v in kwargs .items ()
164- if k
165- not in [
166- "tool_execution_handler" ,
167- "event_loop_metrics" ,
168- "messages" ,
169- "tool_config" ,
170- "callback_handler" ,
171- "tool_handler" ,
172- "system_prompt" ,
173- "model" ,
174- "model_id" ,
175- "user_message_override" ,
176- "agent" ,
177- "record_direct_tool_call" ,
178- ]
179- }
180-
181129 # Execute the tool
182130 tool_result = self ._agent .tool_handler .process (
183131 tool = tool_use ,
184132 model = self ._agent .model ,
185133 system_prompt = self ._agent .system_prompt ,
186- messages = messages ,
187- tool_config = tool_config ,
188- callback_handler = callback_handler ,
189- tool_execution_handler = tool_execution_handler ,
190- event_loop_metrics = event_loop_metrics ,
191- agent = self ._agent ,
192- ** handler_kwargs ,
134+ messages = self ._agent .messages ,
135+ tool_config = self ._agent .tool_config ,
136+ callback_handler = self ._agent .callback_handler ,
137+ kwargs = kwargs ,
193138 )
194139
195- if record_direct_tool_call :
140+ if record_direct_tool_call is not None :
141+ should_record_direct_tool_call = record_direct_tool_call
142+ else :
143+ should_record_direct_tool_call = self ._agent .record_direct_tool_call
144+
145+ if should_record_direct_tool_call :
196146 # Create a record of this tool execution in the message history
197- self ._agent ._record_tool_execution (tool_use , tool_result , user_message_override , messages )
147+ self ._agent ._record_tool_execution (
148+ tool_use , tool_result , user_message_override , self ._agent .messages
149+ )
198150
199151 # Apply window management
200152 self ._agent .conversation_manager .apply_management (self ._agent )
@@ -203,6 +155,27 @@ def caller(**kwargs: Any) -> Any:
203155
204156 return caller
205157
158+ def _find_normalized_tool_name (self , name : str ) -> str :
159+ """Lookup the tool represented by name, replacing characters with underscores as necessary."""
160+ tool_registry = self ._agent .tool_registry .registry
161+
162+ if tool_registry .get (name , None ):
163+ return name
164+
165+ # If the desired name contains underscores, it might be a placeholder for characters that can't be
166+ # represented as python identifiers but are valid as tool names, such as dashes. In that case, find
167+ # all tools that can be represented with the normalized name
168+ if "_" in name :
169+ filtered_tools = [
170+ tool_name for (tool_name , tool ) in tool_registry .items () if tool_name .replace ("-" , "_" ) == name
171+ ]
172+
173+ # The registry itself defends against similar names, so we can just take the first match
174+ if filtered_tools :
175+ return filtered_tools [0 ]
176+
177+ raise AttributeError (f"Tool '{ name } ' not found" )
178+
206179 def __init__ (
207180 self ,
208181 model : Union [Model , str , None ] = None ,
@@ -371,7 +344,7 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
371344
372345 Args:
373346 prompt: The natural language prompt from the user.
374- **kwargs: Additional parameters to pass to the event loop.
347+ **kwargs: Additional parameters to pass through the event loop.
375348
376349 Returns:
377350 Result object containing:
@@ -514,44 +487,35 @@ def _execute_event_loop_cycle(
514487 Yields:
515488 Events of the loop cycle.
516489 """
517- # Extract parameters with fallbacks to instance values
518- system_prompt = kwargs .pop ("system_prompt" , self .system_prompt )
519- model = kwargs .pop ("model" , self .model )
520- tool_execution_handler = kwargs .pop ("tool_execution_handler" , self .thread_pool_wrapper )
521- event_loop_metrics = kwargs .pop ("event_loop_metrics" , self .event_loop_metrics )
522- callback_handler_override = kwargs .pop ("callback_handler" , callback_handler )
523- tool_handler = kwargs .pop ("tool_handler" , self .tool_handler )
524- messages = kwargs .pop ("messages" , self .messages )
525- tool_config = kwargs .pop ("tool_config" , self .tool_config )
526- kwargs .pop ("agent" , None ) # Remove agent to avoid conflicts
490+ # Add `Agent` to kwargs to keep backwards-compatibility
491+ kwargs ["agent" ] = self
527492
528493 try :
529494 # Execute the main event loop cycle
530495 yield from event_loop_cycle (
531- model = model ,
532- system_prompt = system_prompt ,
533- messages = messages , # will be modified by event_loop_cycle
534- tool_config = tool_config ,
535- callback_handler = callback_handler_override ,
536- tool_handler = tool_handler ,
537- tool_execution_handler = tool_execution_handler ,
538- event_loop_metrics = event_loop_metrics ,
539- agent = self ,
496+ model = self .model ,
497+ system_prompt = self .system_prompt ,
498+ messages = self .messages , # will be modified by event_loop_cycle
499+ tool_config = self .tool_config ,
500+ callback_handler = callback_handler ,
501+ tool_handler = self .tool_handler ,
502+ tool_execution_handler = self .thread_pool_wrapper ,
503+ event_loop_metrics = self .event_loop_metrics ,
540504 event_loop_parent_span = self .trace_span ,
541- ** kwargs ,
505+ kwargs = kwargs ,
542506 )
543507
544508 except ContextWindowOverflowException as e :
545509 # Try reducing the context size and retrying
546510 self .conversation_manager .reduce_context (self , e = e )
547- yield from self ._execute_event_loop_cycle (callback_handler_override , kwargs )
511+ yield from self ._execute_event_loop_cycle (callback_handler , kwargs )
548512
549513 def _record_tool_execution (
550514 self ,
551- tool : dict [ str , Any ] ,
552- tool_result : dict [ str , Any ] ,
515+ tool : ToolUse ,
516+ tool_result : ToolResult ,
553517 user_message_override : Optional [str ],
554- messages : list [ dict [ str , Any ]] ,
518+ messages : Messages ,
555519 ) -> None :
556520 """Record a tool execution in the message history.
557521
@@ -569,7 +533,7 @@ def _record_tool_execution(
569533 messages: The message history to append to.
570534 """
571535 # Create user message describing the tool call
572- user_msg_content = [
536+ user_msg_content : List [ ContentBlock ] = [
573537 {"text" : (f"agent.tool.{ tool ['name' ]} direct tool call.\n Input parameters: { json .dumps (tool ['input' ])} \n " )}
574538 ]
575539
@@ -578,19 +542,19 @@ def _record_tool_execution(
578542 user_msg_content .insert (0 , {"text" : f"{ user_message_override } \n " })
579543
580544 # Create the message sequence
581- user_msg = {
545+ user_msg : Message = {
582546 "role" : "user" ,
583547 "content" : user_msg_content ,
584548 }
585- tool_use_msg = {
549+ tool_use_msg : Message = {
586550 "role" : "assistant" ,
587551 "content" : [{"toolUse" : tool }],
588552 }
589- tool_result_msg = {
553+ tool_result_msg : Message = {
590554 "role" : "user" ,
591555 "content" : [{"toolResult" : tool_result }],
592556 }
593- assistant_msg = {
557+ assistant_msg : Message = {
594558 "role" : "assistant" ,
595559 "content" : [{"text" : f"agent.{ tool ['name' ]} was called" }],
596560 }
0 commit comments