Skip to content

Commit 6c3ca71

Browse files
Use a callback function to mix history and input
1 parent 7a4e740 commit 6c3ca71

File tree

4 files changed

+82
-124
lines changed

4 files changed

+82
-124
lines changed

src/agents/memory/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .session import Session, SQLiteSession
2+
from .util import SessionInputHandler, SessionMixerCallable
23

3-
__all__ = ["Session", "SQLiteSession"]
4+
__all__ = ["Session", "SessionInputHandler", "SessionMixerCallable", "SQLiteSession"]

src/agents/memory/util.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from __future__ import annotations
2+
3+
from typing import Callable, Union
4+
5+
from ..items import TResponseInputItem
6+
from ..util._types import MaybeAwaitable
7+
8+
SessionMixerCallable = Callable[
9+
[list[TResponseInputItem], list[TResponseInputItem]],
10+
MaybeAwaitable[list[TResponseInputItem]],
11+
]
12+
"""A function that combines session history with new input items.
13+
14+
Args:
15+
history_items: The list of items from the session history.
16+
new_items: The list of new input items for the current turn.
17+
18+
Returns:
19+
A list of combined items to be used as input for the agent. Can be sync or async.
20+
"""
21+
22+
23+
SessionInputHandler = Union[SessionMixerCallable, None]
24+
"""Defines how to handle session history when new input is provided.
25+
26+
- `None` (default): The new input is appended to the session history.
27+
- `SessionMixerCallable`: A custom function that receives the history and new input, and
28+
returns the desired combined list of items.
29+
"""

src/agents/run.py

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import copy
55
import inspect
66
from dataclasses import dataclass, field
7-
from typing import Any, Generic, Literal, cast
7+
from typing import Any, Generic, cast
88

99
from openai.types.responses import ResponseCompletedEvent
1010
from openai.types.responses.response_prompt_param import (
@@ -44,7 +44,7 @@
4444
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
4545
from .lifecycle import RunHooks
4646
from .logger import logger
47-
from .memory import Session
47+
from .memory import Session, SessionInputHandler
4848
from .model_settings import ModelSettings
4949
from .models.interface import Model, ModelProvider
5050
from .models.multi_provider import MultiProvider
@@ -139,9 +139,12 @@ class RunConfig:
139139
An optional dictionary of additional metadata to include with the trace.
140140
"""
141141

142-
session_input_handling: Literal["replace", "append"] | None = None
143-
"""If a custom input list is given together with the Session, it will
144-
be appended to the session messages or it will replace them.
142+
session_input_callback: SessionInputHandler = None
143+
"""Defines how to handle session history when new input is provided.
144+
145+
- `None` (default): The new input is appended to the session history.
146+
- `SessionMixerCallable`: A custom function that receives the history and new input, and
147+
returns the desired combined list of items.
145148
"""
146149

147150

@@ -349,7 +352,7 @@ async def run(
349352

350353
# Prepare input with session if enabled
351354
prepared_input = await self._prepare_input_with_session(
352-
input, session, run_config.session_input_handling
355+
input, session, run_config.session_input_callback
353356
)
354357

355358
tool_use_tracker = AgentToolUseTracker()
@@ -475,9 +478,7 @@ async def run(
475478
)
476479

477480
# Save the conversation to session if enabled
478-
await self._save_result_to_session(
479-
session, input, result, run_config.session_input_handling
480-
)
481+
await self._save_result_to_session(session, input, result)
481482

482483
return result
483484
elif isinstance(turn_result.next_step, NextStepHandoff):
@@ -672,7 +673,7 @@ async def _start_streaming(
672673
try:
673674
# Prepare input with session if enabled
674675
prepared_input = await AgentRunner._prepare_input_with_session(
675-
starting_input, session, run_config.session_input_handling
676+
starting_input, session, run_config.session_input_callback
676677
)
677678

678679
# Update the streamed result with the prepared input
@@ -792,7 +793,7 @@ async def _start_streaming(
792793
context_wrapper=context_wrapper,
793794
)
794795
await AgentRunner._save_result_to_session(
795-
session, starting_input, temp_result, run_config.session_input_handling
796+
session, starting_input, temp_result
796797
)
797798

798799
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1202,16 +1203,16 @@ async def _prepare_input_with_session(
12021203
cls,
12031204
input: str | list[TResponseInputItem],
12041205
session: Session | None,
1205-
session_input_handling: Literal["replace", "append"] | None,
1206+
session_input_callback: SessionInputHandler,
12061207
) -> str | list[TResponseInputItem]:
12071208
"""Prepare input by combining it with session history if enabled."""
12081209
if session is None:
12091210
return input
12101211

12111212
# If the user doesn't explicitly specify a mode, raise an error
1212-
if isinstance(input, list) and not session_input_handling:
1213+
if isinstance(input, list) and not session_input_callback:
12131214
raise UserError(
1214-
"You must specify the `session_input_handling` in the `RunConfig`. "
1215+
"You must specify the `session_input_callback` in the `RunConfig`. "
12151216
"Otherwise, when using session memory, provide only a string input to append to "
12161217
"the conversation, or use session=None and provide a list to manually manage "
12171218
"conversation history."
@@ -1223,36 +1224,34 @@ async def _prepare_input_with_session(
12231224
# Convert input to list format
12241225
new_input_list = ItemHelpers.input_to_new_input_list(input)
12251226

1226-
if session_input_handling == "append" or session_input_handling is None:
1227-
# Append new input to history
1228-
combined_input = history + new_input_list
1229-
elif session_input_handling == "replace":
1230-
# Replace history with new input
1231-
combined_input = new_input_list
1227+
if session_input_callback is None:
1228+
return history + new_input_list
1229+
elif callable(session_input_callback):
1230+
res = session_input_callback(history, new_input_list)
1231+
if inspect.isawaitable(res):
1232+
return await res
1233+
return res
12321234
else:
12331235
raise UserError(
1234-
"The specified `session_input_handling` is not available. "
1235-
"Choose between `append`, `replace` or `None`."
1236+
f"Invalid `session_input_callback` value: {session_input_callback}. "
1237+
"Choose between `None` or a custom callable function."
12361238
)
12371239

1238-
return combined_input
1239-
12401240
@classmethod
12411241
async def _save_result_to_session(
12421242
cls,
12431243
session: Session | None,
12441244
original_input: str | list[TResponseInputItem],
12451245
result: RunResult,
1246-
saving_mode: Literal["replace", "append"] | None = None,
12471246
) -> None:
1248-
"""Save the conversation turn to session."""
1247+
"""
1248+
Save the conversation turn to session.
1249+
It does not account for any filtering or modification performed by
1250+
`RunConfig.session_input_handling`.
1251+
"""
12491252
if session is None:
12501253
return
12511254

1252-
# Remove old history
1253-
if saving_mode == "replace":
1254-
await session.clear_session()
1255-
12561255
# Convert original input to list format if needed
12571256
input_list = ItemHelpers.input_to_new_input_list(original_input)
12581257

tests/test_session.py

Lines changed: 22 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -394,15 +394,15 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method)
394394
await run_agent_async(runner_method, agent, list_input, session=session)
395395

396396
# Verify the error message explains the issue
397-
assert "You must specify the `session_input_handling` in" in str(exc_info.value)
397+
assert "You must specify the `session_input_callback` in" in str(exc_info.value)
398398
assert "manually manage conversation history" in str(exc_info.value)
399399

400400
session.close()
401401

402402

403403
@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
404404
@pytest.mark.asyncio
405-
async def test_session_memory_append_list(runner_method):
405+
async def test_session_callback_prepared_input(runner_method):
406406
"""Test if the user passes a list of items and want to append them."""
407407
with tempfile.TemporaryDirectory() as temp_dir:
408408
db_path = Path(temp_dir) / "test_memory.db"
@@ -414,106 +414,35 @@ async def test_session_memory_append_list(runner_method):
414414
session_id = "session_1"
415415
session = SQLiteSession(session_id, db_path)
416416

417-
model.set_next_output([get_text_message("I like cats")])
418-
_ = await run_agent_async(runner_method, agent, "I like cats", session=session)
419-
420-
append_input = [
421-
{"role": "user", "content": "Some random user text"},
422-
{"role": "assistant", "content": "You're right"},
423-
{"role": "user", "content": "What did I say I like?"},
417+
# Add first messages manually
418+
initial_history: list[TResponseInputItem] = [
419+
{"role": "user", "content": "Hello there."},
420+
{"role": "assistant", "content": "Hi, I'm here to assist you."},
424421
]
425-
second_model_response = {"role": "assistant", "content": "Yes, you mentioned cats"}
426-
model.set_next_output([get_text_message(second_model_response.get("content", ""))])
427-
428-
_ = await run_agent_async(
429-
runner_method,
430-
agent,
431-
append_input,
432-
session=session,
433-
run_config=RunConfig(session_input_handling="append"),
434-
)
435-
436-
session_items = await session.get_items()
437-
438-
# Check the items has been appended
439-
assert len(session_items) == 6
440-
441-
# Check the items are the last 4 elements
442-
append_input.append(second_model_response)
443-
for sess_item, orig_item in zip(session_items[-4:], append_input):
444-
assert sess_item.get("role") == orig_item.get("role")
445-
446-
sess_content = sess_item.get("content")
447-
# Narrow to list or str for mypy
448-
assert isinstance(sess_content, (list, str))
449-
450-
if isinstance(sess_content, list):
451-
# now mypy knows `content: list[Any]`
452-
assert isinstance(sess_content[0], dict) and "text" in sess_content[0]
453-
val_sess = sess_content[0]["text"]
454-
else:
455-
# here content is str
456-
val_sess = sess_content
457-
458-
assert val_sess == orig_item["content"]
459-
460-
session.close()
461-
462-
463-
@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
464-
@pytest.mark.asyncio
465-
async def test_session_memory_replace_list(runner_method):
466-
"""Test if the user passes a list of items and want to replace the history."""
467-
with tempfile.TemporaryDirectory() as temp_dir:
468-
db_path = Path(temp_dir) / "test_memory.db"
422+
await session.add_items(initial_history)
469423

470-
model = FakeModel()
471-
agent = Agent(name="test", model=model)
424+
def filter_assistant_messages(history, new_input):
425+
# Only include user messages from history
426+
return [item for item in history if item["role"] == "user"] + new_input
472427

473-
# Session
474-
session_id = "session_1"
475-
session = SQLiteSession(session_id, db_path)
428+
new_turn_input = [{"role": "user", "content": "What your name?"}]
429+
model.set_next_output([get_text_message("I'm gpt-4o")])
476430

477-
model.set_next_output([get_text_message("I like cats")])
478-
_ = await run_agent_async(runner_method, agent, "I like cats", session=session)
479-
480-
replace_input = [
481-
{"role": "user", "content": "Some random user text"},
482-
{"role": "assistant", "content": "You're right"},
483-
{"role": "user", "content": "What did I say I like?"},
484-
]
485-
second_model_response = {"role": "assistant", "content": "Yes, you mentioned cats"}
486-
model.set_next_output([get_text_message(second_model_response.get("content", ""))])
487-
488-
_ = await run_agent_async(
431+
# Run the agent with the callable
432+
await run_agent_async(
489433
runner_method,
490434
agent,
491-
replace_input,
435+
new_turn_input,
492436
session=session,
493-
run_config=RunConfig(session_input_handling="replace"),
437+
run_config=RunConfig(session_input_callback=filter_assistant_messages),
494438
)
495439

496-
session_items = await session.get_items()
497-
498-
# Check the new items replaced the history
499-
assert len(session_items) == 4
500-
501-
# Check the items are the last 4 elements
502-
replace_input.append(second_model_response)
503-
for sess_item, orig_item in zip(session_items, replace_input):
504-
assert sess_item.get("role") == orig_item.get("role")
505-
sess_content = sess_item.get("content")
506-
# Narrow to list or str for mypy
507-
assert isinstance(sess_content, (list, str))
508-
509-
if isinstance(sess_content, list):
510-
# now mypy knows `content: list[Any]`
511-
assert isinstance(sess_content[0], dict) and "text" in sess_content[0]
512-
val_sess = sess_content[0]["text"]
513-
else:
514-
# here content is str
515-
val_sess = sess_content
440+
expected_model_input = [
441+
initial_history[0], # From history
442+
new_turn_input[0], # New input
443+
]
516444

517-
assert val_sess == orig_item["content"]
445+
assert len(model.last_turn_args["input"]) == 2
446+
assert model.last_turn_args["input"] == expected_model_input
518447

519448
session.close()

0 commit comments

Comments
 (0)