diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index fbbf2c00b8..0a16142757 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -476,6 +476,15 @@ async def _preprocess_async( # If it's a toolset, process it first if isinstance(tool_union, BaseToolset): + # Generate preprocessing events (e.g., authentication requests) + async with Aclosing( + tool_union.generate_preprocessing_events( + tool_context=tool_context, llm_request=llm_request + ) + ) as agen: + async for event in agen: + yield event + await tool_union.process_llm_request( tool_context=tool_context, llm_request=llm_request ) diff --git a/src/google/adk/tools/base_toolset.py b/src/google/adk/tools/base_toolset.py index 201eec9087..e4c9577a14 100644 --- a/src/google/adk/tools/base_toolset.py +++ b/src/google/adk/tools/base_toolset.py @@ -17,6 +17,7 @@ from abc import ABC from abc import abstractmethod import copy +from typing import AsyncGenerator from typing import final from typing import List from typing import Optional @@ -31,6 +32,7 @@ from .base_tool import BaseTool if TYPE_CHECKING: + from ..events.event import Event from ..models.llm_request import LlmRequest from .tool_configs import ToolArgsConfig from .tool_context import ToolContext @@ -204,3 +206,31 @@ async def process_llm_request( llm_request: The outgoing LLM request, mutable this method. """ pass + + async def generate_preprocessing_events( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + """Generates events during the preprocessing phase. + + This method allows toolsets to generate events (such as authentication + requests) before tool discovery occurs. It has access to the full + ToolContext with authentication capabilities. + + Use cases: + - OAuth2 authentication flows before tool discovery + - User confirmation requests for sensitive toolsets + - Dynamic configuration based on user context + - Pre-flight checks that require user interaction + + Args: + tool_context: The context of the tool with full authentication capabilities. + llm_request: The outgoing LLM request, mutable by this method. + + Yields: + Event: Events for user interaction (e.g., authentication requests). + """ + # Default implementation yields nothing (backward compatibility) + # Subclasses can override to yield authentication or other events + if False: # This ensures the method is an AsyncGenerator + yield # Required for AsyncGenerator type hint + return diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 81ef925a39..9a5848d36d 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -14,6 +14,7 @@ """Unit tests for BaseLlmFlow toolset integration.""" +from typing import AsyncGenerator from unittest import mock from unittest.mock import AsyncMock @@ -26,6 +27,7 @@ from google.adk.plugins.base_plugin import BasePlugin from google.adk.tools.base_toolset import BaseToolset from google.adk.tools.google_search_tool import GoogleSearchTool +from google.adk.tools.tool_context import ToolContext from google.genai import types import pytest @@ -91,6 +93,156 @@ async def close(self): assert mock_toolset.process_llm_request_called +@pytest.mark.asyncio +async def test_preprocess_calls_toolset_generate_preprocessing_events(): + """Test that _preprocess_async calls generate_preprocessing_events on toolsets.""" + + # Create a mock toolset that tracks if generate_preprocessing_events was called + class _MockToolset(BaseToolset): + + def __init__(self): + super().__init__() + self.generate_preprocessing_events_called = False + self.generated_events = [] + + async def generate_preprocessing_events( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + self.generate_preprocessing_events_called = True + # Generate a mock authentication event + auth_event = Event( + author='system', + invocation_id='test_invocation', + content=types.Content( + role='model', + parts=[types.Part(text='Mock authentication request')], + ), + ) + self.generated_events.append(auth_event) + yield auth_event + + async def get_tools(self, readonly_context=None): + return [] + + async def close(self): + pass + + mock_toolset = _MockToolset() + + # Create a mock model that returns a simple response + mock_response = LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text='Test response')] + ), + partial=False, + ) + + mock_model = testing_utils.MockModel.create(responses=[mock_response]) + + # Create agent with the mock toolset + agent = Agent(name='test_agent', model=mock_model, tools=[mock_toolset]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + + flow = BaseLlmFlowForTesting() + + # Call _preprocess_async + llm_request = LlmRequest() + events = [] + async for event in flow._preprocess_async(invocation_context, llm_request): + events.append(event) + + # Verify that generate_preprocessing_events was called on the toolset + assert mock_toolset.generate_preprocessing_events_called + + # Verify that the generated event was yielded + assert len(events) == 1 + assert events[0].author == 'system' + assert events[0].content.parts[0].text == 'Mock authentication request' + + +@pytest.mark.asyncio +async def test_preprocess_calls_both_generate_events_and_process_request(): + """Test that _preprocess_async calls both generate_preprocessing_events and process_llm_request.""" + + # Create a mock toolset that tracks both method calls + class _MockToolset(BaseToolset): + + def __init__(self): + super().__init__() + self.generate_preprocessing_events_called = False + self.process_llm_request_called = False + self.call_order = [] + + async def generate_preprocessing_events( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + self.generate_preprocessing_events_called = True + self.call_order.append('generate_preprocessing_events') + # Generate a mock event + yield Event( + author='system', + invocation_id='test_invocation', + content=types.Content( + role='model', parts=[types.Part(text='Mock event')] + ), + ) + + async def process_llm_request( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> None: + self.process_llm_request_called = True + self.call_order.append('process_llm_request') + + async def get_tools(self, readonly_context=None): + return [] + + async def close(self): + pass + + mock_toolset = _MockToolset() + + # Create a mock model that returns a simple response + mock_response = LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text='Test response')] + ), + partial=False, + ) + + mock_model = testing_utils.MockModel.create(responses=[mock_response]) + + # Create agent with the mock toolset + agent = Agent(name='test_agent', model=mock_model, tools=[mock_toolset]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + + flow = BaseLlmFlowForTesting() + + # Call _preprocess_async + llm_request = LlmRequest() + events = [] + async for event in flow._preprocess_async(invocation_context, llm_request): + events.append(event) + + # Verify that both methods were called + assert mock_toolset.generate_preprocessing_events_called + assert mock_toolset.process_llm_request_called + + # Verify the correct call order (generate_preprocessing_events first) + assert mock_toolset.call_order == [ + 'generate_preprocessing_events', + 'process_llm_request', + ] + + # Verify that the generated event was yielded + assert len(events) == 1 + assert events[0].author == 'system' + assert events[0].content.parts[0].text == 'Mock event' + + @pytest.mark.asyncio async def test_preprocess_handles_mixed_tools_and_toolsets(): """Test that _preprocess_async properly handles both tools and toolsets."""