diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 5e21fe8b..30d1ee89 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -44,6 +44,7 @@ UnsupportedOperationError, ) from a2a.utils.errors import ServerError +from a2a.utils.task import apply_history_length from a2a.utils.telemetry import SpanKind, trace_class @@ -118,25 +119,7 @@ async def on_get_task( raise ServerError(error=TaskNotFoundError()) # Apply historyLength parameter if specified - if params.history_length is not None and task.history: - # Limit history to the most recent N messages - limited_history = ( - task.history[-params.history_length :] - if params.history_length > 0 - else [] - ) - # Create a new task instance with limited history - task = Task( - id=task.id, - context_id=task.context_id, - status=task.status, - artifacts=task.artifacts, - history=limited_history, - metadata=task.metadata, - kind=task.kind, - ) - - return task + return apply_history_length(task, params.history_length) async def on_cancel_task( self, params: TaskIdParams, context: ServerCallContext | None = None @@ -363,6 +346,10 @@ async def push_notification_callback() -> None: if isinstance(result, Task): self._validate_task_id_match(task_id, result.id) + if params.configuration: + result = apply_history_length( + result, params.configuration.history_length + ) await self._send_push_notification_if_needed(task_id, result_aggregator) diff --git a/src/a2a/utils/task.py b/src/a2a/utils/task.py index 22556cde..5c5f3f07 100644 --- a/src/a2a/utils/task.py +++ b/src/a2a/utils/task.py @@ -70,3 +70,25 @@ def completed_task( artifacts=artifacts, history=history, ) + + +def apply_history_length(task: Task, history_length: int | None) -> Task: + """Applies history_length parameter on task and returns a new task object. + + Args: + task: The original task object with complete history + history_length: History length configuration value + + Returns: + A new task object with limited history + """ + # Apply historyLength parameter if specified + if history_length is not None and task.history: + # Limit history to the most recent N messages + limited_history = ( + task.history[-history_length:] if history_length > 0 else [] + ) + # Create a new task instance with limited history + return task.model_copy(update={'history': limited_history}) + + return task diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 6765000c..5268af11 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -836,6 +836,85 @@ async def test_on_message_send_non_blocking(): assert task.status.state == TaskState.completed +@pytest.mark.asyncio +async def test_on_message_send_limit_history(): + task_store = InMemoryTaskStore() + push_store = InMemoryPushNotificationConfigStore() + + request_handler = DefaultRequestHandler( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, + ) + params = MessageSendParams( + message=Message( + role=Role.user, + message_id='msg_push', + parts=[Part(root=TextPart(text='Hi'))], + ), + configuration=MessageSendConfiguration( + blocking=True, + accepted_output_modes=['text/plain'], + history_length=0, + ), + ) + + result = await request_handler.on_message_send( + params, create_server_call_context() + ) + + # verify that history_length is honored + assert result is not None + assert isinstance(result, Task) + assert result.history is not None and len(result.history) == 0 + assert result.status.state == TaskState.completed + + # verify that history is still persisted to the store + task = await task_store.get(result.id) + assert task is not None + assert task.history is not None and len(task.history) > 0 + + +@pytest.mark.asyncio +async def test_on_task_get_limit_history(): + task_store = InMemoryTaskStore() + push_store = InMemoryPushNotificationConfigStore() + + request_handler = DefaultRequestHandler( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, + ) + params = MessageSendParams( + message=Message( + role=Role.user, + message_id='msg_push', + parts=[Part(root=TextPart(text='Hi'))], + ), + configuration=MessageSendConfiguration( + blocking=True, accepted_output_modes=['text/plain'] + ), + ) + + result = await request_handler.on_message_send( + params, create_server_call_context() + ) + + assert result is not None + assert isinstance(result, Task) + + get_task_result = await request_handler.on_get_task( + TaskQueryParams(id=result.id, history_length=0), + create_server_call_context(), + ) + assert get_task_result is not None + assert isinstance(get_task_result, Task) + assert ( + get_task_result.history is not None + and len(get_task_result.history) == 0 + ) + + @pytest.mark.asyncio async def test_on_message_send_interrupted_flow(): """Test on_message_send when flow is interrupted (e.g., auth_required)."""