@@ -394,15 +394,15 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method)
394394 await run_agent_async (runner_method , agent , list_input , session = session )
395395
396396 # Verify the error message explains the issue
397- assert "You must specify the `session_input_handling ` in" in str (exc_info .value )
397+ assert "You must specify the `session_input_callback ` in" in str (exc_info .value )
398398 assert "manually manage conversation history" in str (exc_info .value )
399399
400400 session .close ()
401401
402402
403403@pytest .mark .parametrize ("runner_method" , ["run" , "run_sync" , "run_streamed" ])
404404@pytest .mark .asyncio
405- async def test_session_memory_append_list (runner_method ):
405+ async def test_session_callback_prepared_input (runner_method ):
406406 """Test if the user passes a list of items and want to append them."""
407407 with tempfile .TemporaryDirectory () as temp_dir :
408408 db_path = Path (temp_dir ) / "test_memory.db"
@@ -414,106 +414,35 @@ async def test_session_memory_append_list(runner_method):
414414 session_id = "session_1"
415415 session = SQLiteSession (session_id , db_path )
416416
417- model .set_next_output ([get_text_message ("I like cats" )])
418- _ = await run_agent_async (runner_method , agent , "I like cats" , session = session )
419-
420- append_input = [
421- {"role" : "user" , "content" : "Some random user text" },
422- {"role" : "assistant" , "content" : "You're right" },
423- {"role" : "user" , "content" : "What did I say I like?" },
417+ # Add first messages manually
418+ initial_history : list [TResponseInputItem ] = [
419+ {"role" : "user" , "content" : "Hello there." },
420+ {"role" : "assistant" , "content" : "Hi, I'm here to assist you." },
424421 ]
425- second_model_response = {"role" : "assistant" , "content" : "Yes, you mentioned cats" }
426- model .set_next_output ([get_text_message (second_model_response .get ("content" , "" ))])
427-
428- _ = await run_agent_async (
429- runner_method ,
430- agent ,
431- append_input ,
432- session = session ,
433- run_config = RunConfig (session_input_handling = "append" ),
434- )
435-
436- session_items = await session .get_items ()
437-
438- # Check the items has been appended
439- assert len (session_items ) == 6
440-
441- # Check the items are the last 4 elements
442- append_input .append (second_model_response )
443- for sess_item , orig_item in zip (session_items [- 4 :], append_input ):
444- assert sess_item .get ("role" ) == orig_item .get ("role" )
445-
446- sess_content = sess_item .get ("content" )
447- # Narrow to list or str for mypy
448- assert isinstance (sess_content , (list , str ))
449-
450- if isinstance (sess_content , list ):
451- # now mypy knows `content: list[Any]`
452- assert isinstance (sess_content [0 ], dict ) and "text" in sess_content [0 ]
453- val_sess = sess_content [0 ]["text" ]
454- else :
455- # here content is str
456- val_sess = sess_content
457-
458- assert val_sess == orig_item ["content" ]
459-
460- session .close ()
461-
462-
463- @pytest .mark .parametrize ("runner_method" , ["run" , "run_sync" , "run_streamed" ])
464- @pytest .mark .asyncio
465- async def test_session_memory_replace_list (runner_method ):
466- """Test if the user passes a list of items and want to replace the history."""
467- with tempfile .TemporaryDirectory () as temp_dir :
468- db_path = Path (temp_dir ) / "test_memory.db"
422+ await session .add_items (initial_history )
469423
470- model = FakeModel ()
471- agent = Agent (name = "test" , model = model )
424+ def filter_assistant_messages (history , new_input ):
425+ # Only include user messages from history
426+ return [item for item in history if item ["role" ] == "user" ] + new_input
472427
473- # Session
474- session_id = "session_1"
475- session = SQLiteSession (session_id , db_path )
428+ new_turn_input = [{"role" : "user" , "content" : "What your name?" }]
429+ model .set_next_output ([get_text_message ("I'm gpt-4o" )])
476430
477- model .set_next_output ([get_text_message ("I like cats" )])
478- _ = await run_agent_async (runner_method , agent , "I like cats" , session = session )
479-
480- replace_input = [
481- {"role" : "user" , "content" : "Some random user text" },
482- {"role" : "assistant" , "content" : "You're right" },
483- {"role" : "user" , "content" : "What did I say I like?" },
484- ]
485- second_model_response = {"role" : "assistant" , "content" : "Yes, you mentioned cats" }
486- model .set_next_output ([get_text_message (second_model_response .get ("content" , "" ))])
487-
488- _ = await run_agent_async (
431+ # Run the agent with the callable
432+ await run_agent_async (
489433 runner_method ,
490434 agent ,
491- replace_input ,
435+ new_turn_input ,
492436 session = session ,
493- run_config = RunConfig (session_input_handling = "replace" ),
437+ run_config = RunConfig (session_input_callback = filter_assistant_messages ),
494438 )
495439
496- session_items = await session .get_items ()
497-
498- # Check the new items replaced the history
499- assert len (session_items ) == 4
500-
501- # Check the items are the last 4 elements
502- replace_input .append (second_model_response )
503- for sess_item , orig_item in zip (session_items , replace_input ):
504- assert sess_item .get ("role" ) == orig_item .get ("role" )
505- sess_content = sess_item .get ("content" )
506- # Narrow to list or str for mypy
507- assert isinstance (sess_content , (list , str ))
508-
509- if isinstance (sess_content , list ):
510- # now mypy knows `content: list[Any]`
511- assert isinstance (sess_content [0 ], dict ) and "text" in sess_content [0 ]
512- val_sess = sess_content [0 ]["text" ]
513- else :
514- # here content is str
515- val_sess = sess_content
440+ expected_model_input = [
441+ initial_history [0 ], # From history
442+ new_turn_input [0 ], # New input
443+ ]
516444
517- assert val_sess == orig_item ["content" ]
445+ assert len (model .last_turn_args ["input" ]) == 2
446+ assert model .last_turn_args ["input" ] == expected_model_input
518447
519448 session .close ()
0 commit comments