32
32
from .. import EVENT_LOGGER_NAME
33
33
from ..base import Handoff as HandoffBase
34
34
from ..base import Response
35
+ from ..memory ._base_memory import Memory
35
36
from ..messages import (
36
37
AgentEvent ,
37
38
ChatMessage ,
44
45
)
45
46
from ..state import AssistantAgentState
46
47
from ._base_chat_agent import BaseChatAgent
47
- from ..memory ._base_memory import Memory
48
48
49
49
event_logger = logging .getLogger (EVENT_LOGGER_NAME )
50
50
@@ -245,8 +245,7 @@ def __init__(
245
245
name : str ,
246
246
model_client : ChatCompletionClient ,
247
247
* ,
248
- tools : List [Tool | Callable [..., Any ] |
249
- Callable [..., Awaitable [Any ]]] | None = None ,
248
+ tools : List [Tool | Callable [..., Any ] | Callable [..., Awaitable [Any ]]] | None = None ,
250
249
handoffs : List [HandoffBase | str ] | None = None ,
251
250
model_context : ChatCompletionContext | None = None ,
252
251
description : str = "An agent that provides assistance with ability to use tools." ,
@@ -266,20 +265,19 @@ def __init__(
266
265
elif isinstance (memory , list ):
267
266
self ._memory = memory
268
267
else :
269
- raise TypeError (
270
- f"Expected Memory, List[Memory], or None, got { type (memory )} " )
268
+ raise TypeError (f"Expected Memory, List[Memory], or None, got { type (memory )} " )
271
269
272
- self ._system_messages : List [SystemMessage | UserMessage |
273
- AssistantMessage | FunctionExecutionResultMessage ] = []
270
+ self ._system_messages : List [
271
+ SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage
272
+ ] = []
274
273
if system_message is None :
275
274
self ._system_messages = []
276
275
else :
277
276
self ._system_messages = [SystemMessage (content = system_message )]
278
277
self ._tools : List [Tool ] = []
279
278
if tools is not None :
280
279
if model_client .model_info ["function_calling" ] is False :
281
- raise ValueError (
282
- "The model does not support function calling." )
280
+ raise ValueError ("The model does not support function calling." )
283
281
for tool in tools :
284
282
if isinstance (tool , Tool ):
285
283
self ._tools .append (tool )
@@ -288,8 +286,7 @@ def __init__(
288
286
description = tool .__doc__
289
287
else :
290
288
description = ""
291
- self ._tools .append (FunctionTool (
292
- tool , description = description ))
289
+ self ._tools .append (FunctionTool (tool , description = description ))
293
290
else :
294
291
raise ValueError (f"Unsupported tool type: { type (tool )} " )
295
292
# Check if tool names are unique.
@@ -301,22 +298,19 @@ def __init__(
301
298
self ._handoffs : Dict [str , HandoffBase ] = {}
302
299
if handoffs is not None :
303
300
if model_client .model_info ["function_calling" ] is False :
304
- raise ValueError (
305
- "The model does not support function calling, which is needed for handoffs." )
301
+ raise ValueError ("The model does not support function calling, which is needed for handoffs." )
306
302
for handoff in handoffs :
307
303
if isinstance (handoff , str ):
308
304
handoff = HandoffBase (target = handoff )
309
305
if isinstance (handoff , HandoffBase ):
310
306
self ._handoff_tools .append (handoff .handoff_tool )
311
307
self ._handoffs [handoff .name ] = handoff
312
308
else :
313
- raise ValueError (
314
- f"Unsupported handoff type: { type (handoff )} " )
309
+ raise ValueError (f"Unsupported handoff type: { type (handoff )} " )
315
310
# Check if handoff tool names are unique.
316
311
handoff_tool_names = [tool .name for tool in self ._handoff_tools ]
317
312
if len (handoff_tool_names ) != len (set (handoff_tool_names )):
318
- raise ValueError (
319
- f"Handoff names must be unique: { handoff_tool_names } " )
313
+ raise ValueError (f"Handoff names must be unique: { handoff_tool_names } " )
320
314
# Check if handoff tool names not in tool names.
321
315
if any (name in tool_names for name in handoff_tool_names ):
322
316
raise ValueError (
@@ -344,8 +338,7 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
344
338
async for message in self .on_messages_stream (messages , cancellation_token ):
345
339
if isinstance (message , Response ):
346
340
return message
347
- raise AssertionError (
348
- "The stream should have returned the final result." )
341
+ raise AssertionError ("The stream should have returned the final result." )
349
342
350
343
async def on_messages_stream (
351
344
self , messages : Sequence [ChatMessage ], cancellation_token : CancellationToken
@@ -377,26 +370,22 @@ async def on_messages_stream(
377
370
# Check if the response is a string and return it.
378
371
if isinstance (result .content , str ):
379
372
yield Response (
380
- chat_message = TextMessage (
381
- content = result .content , source = self .name , models_usage = result .usage ),
373
+ chat_message = TextMessage (content = result .content , source = self .name , models_usage = result .usage ),
382
374
inner_messages = inner_messages ,
383
375
)
384
376
return
385
377
386
378
# Process tool calls.
387
- assert isinstance (result .content , list ) and all (
388
- isinstance (item , FunctionCall ) for item in result .content )
389
- tool_call_msg = ToolCallRequestEvent (
390
- content = result .content , source = self .name , models_usage = result .usage )
379
+ assert isinstance (result .content , list ) and all (isinstance (item , FunctionCall ) for item in result .content )
380
+ tool_call_msg = ToolCallRequestEvent (content = result .content , source = self .name , models_usage = result .usage )
391
381
event_logger .debug (tool_call_msg )
392
382
# Add the tool call message to the output.
393
383
inner_messages .append (tool_call_msg )
394
384
yield tool_call_msg
395
385
396
386
# Execute the tool calls.
397
387
results = await asyncio .gather (* [self ._execute_tool_call (call , cancellation_token ) for call in result .content ])
398
- tool_call_result_msg = ToolCallExecutionEvent (
399
- content = results , source = self .name )
388
+ tool_call_result_msg = ToolCallExecutionEvent (content = results , source = self .name )
400
389
event_logger .debug (tool_call_result_msg )
401
390
await self ._model_context .add_message (FunctionExecutionResultMessage (content = results ))
402
391
inner_messages .append (tool_call_result_msg )
@@ -416,8 +405,7 @@ async def on_messages_stream(
416
405
)
417
406
# Return the output messages to signal the handoff.
418
407
yield Response (
419
- chat_message = HandoffMessage (
420
- content = handoffs [0 ].message , target = handoffs [0 ].target , source = self .name ),
408
+ chat_message = HandoffMessage (content = handoffs [0 ].message , target = handoffs [0 ].target , source = self .name ),
421
409
inner_messages = inner_messages ,
422
410
)
423
411
return
@@ -431,8 +419,7 @@ async def on_messages_stream(
431
419
await self ._model_context .add_message (AssistantMessage (content = result .content , source = self .name ))
432
420
# Yield the response.
433
421
yield Response (
434
- chat_message = TextMessage (
435
- content = result .content , source = self .name , models_usage = result .usage ),
422
+ chat_message = TextMessage (content = result .content , source = self .name , models_usage = result .usage ),
436
423
inner_messages = inner_messages ,
437
424
)
438
425
else :
@@ -448,8 +435,7 @@ async def on_messages_stream(
448
435
)
449
436
tool_call_summary = "\n " .join (tool_call_summaries )
450
437
yield Response (
451
- chat_message = ToolCallSummaryMessage (
452
- content = tool_call_summary , source = self .name ),
438
+ chat_message = ToolCallSummaryMessage (content = tool_call_summary , source = self .name ),
453
439
inner_messages = inner_messages ,
454
440
)
455
441
@@ -460,11 +446,9 @@ async def _execute_tool_call(
460
446
try :
461
447
if not self ._tools + self ._handoff_tools :
462
448
raise ValueError ("No tools are available." )
463
- tool = next ((t for t in self ._tools +
464
- self ._handoff_tools if t .name == tool_call .name ), None )
449
+ tool = next ((t for t in self ._tools + self ._handoff_tools if t .name == tool_call .name ), None )
465
450
if tool is None :
466
- raise ValueError (
467
- f"The tool '{ tool_call .name } ' is not available." )
451
+ raise ValueError (f"The tool '{ tool_call .name } ' is not available." )
468
452
arguments = json .loads (tool_call .arguments )
469
453
result = await tool .run_json (arguments , cancellation_token )
470
454
result_as_str = tool .return_value_as_string (result )
0 commit comments