diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py index aff5be1552..8d9e4b0875 100644 --- a/src/google/adk/tools/vertex_ai_search_tool.py +++ b/src/google/adk/tools/vertex_ai_search_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import logging from typing import Optional from typing import TYPE_CHECKING @@ -25,6 +26,8 @@ from .base_tool import BaseTool from .tool_context import ToolContext +_logger = logging.getLogger(__name__) + if TYPE_CHECKING: from ..models import LlmRequest @@ -37,6 +40,26 @@ class VertexAiSearchTool(BaseTool): search_engine_id: The Vertex AI search engine resource ID. """ + @staticmethod + def _extract_resource_id(resource_path: str, resource_type: str) -> str: + """Extracts the resource ID from a full resource path. + + Args: + resource_path: The full resource path (e.g., "projects/p/locations/l/collections/c/engines/e") + resource_type: The type of resource to extract (e.g., 'engines', 'dataStores') + + Returns: + The extracted resource ID + """ + parts = resource_path.split('/') + try: + resource_index = parts.index(resource_type) + if resource_index + 1 < len(parts): + return parts[resource_index + 1] + except ValueError: + pass + return resource_path # Return original if pattern not matched + def __init__( self, *, @@ -83,6 +106,11 @@ def __init__( self.data_store_id = data_store_id self.data_store_specs = data_store_specs self.search_engine_id = search_engine_id + self._search_engine_name = ( + self._extract_resource_id(search_engine_id, 'engines') + if search_engine_id + else None + ) self.filter = filter self.max_results = max_results self.bypass_multi_tools_limit = bypass_multi_tools_limit @@ -102,6 +130,15 @@ async def process_llm_request( ) llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] + _logger.debug( + 'Adding Vertex AI Search tool config to LLM request: datastore=%s,' + ' engine=%s, filter=%s, max_results=%s, data_store_specs=%s', + self.data_store_id, + self._search_engine_name or self.search_engine_id, + self.filter, + self.max_results, + self.data_store_specs, + ) llm_request.config.tools.append( types.Tool( retrieval=types.Retrieval( diff --git a/tests/unittests/tools/test_vertex_ai_search_tool.py b/tests/unittests/tools/test_vertex_ai_search_tool.py index 0df19288a3..c4a3f15101 100644 --- a/tests/unittests/tools/test_vertex_ai_search_tool.py +++ b/tests/unittests/tools/test_vertex_ai_search_tool.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.sequential_agent import SequentialAgent from google.adk.models.llm_request import LlmRequest @@ -24,6 +26,8 @@ from google.genai import types import pytest +VERTEX_SEARCH_TOOL_LOGGER_NAME = 'google.adk.tools.vertex_ai_search_tool' + async def _create_tool_context() -> ToolContext: session_service = InMemorySessionService() @@ -121,12 +125,32 @@ def test_init_with_data_store_id(self): tool = VertexAiSearchTool(data_store_id='test_data_store') assert tool.data_store_id == 'test_data_store' assert tool.search_engine_id is None + assert tool.data_store_specs is None def test_init_with_search_engine_id(self): """Test initialization with search engine ID.""" tool = VertexAiSearchTool(search_engine_id='test_search_engine') assert tool.search_engine_id == 'test_search_engine' assert tool.data_store_id is None + assert tool.data_store_specs is None + + def test_init_with_engine_and_specs(self): + """Test initialization with search engine ID and specs.""" + specs = [ + types.VertexAISearchDataStoreSpec( + dataStore='projects/p/locations/l/collections/default_collection/dataStores/spec_store_id' + ) + ] + tool = VertexAiSearchTool( + search_engine_id='projects/p/locations/l/collections/default_collection/engines/test_search_engine', + data_store_specs=specs, + ) + assert ( + tool.search_engine_id + == 'projects/p/locations/l/collections/default_collection/engines/test_search_engine' + ) + assert tool.data_store_id is None + assert tool.data_store_specs == specs def test_init_with_neither_raises_error(self): """Test that initialization without either ID raises ValueError.""" @@ -146,10 +170,27 @@ def test_init_with_both_raises_error(self): data_store_id='test_data_store', search_engine_id='test_search_engine' ) + def test_init_with_specs_but_no_engine_raises_error(self): + """Test that specs without engine ID raises ValueError.""" + specs = [ + types.VertexAISearchDataStoreSpec( + dataStore='projects/p/locations/l/collections/default_collection/dataStores/spec_store_id' + ) + ] + with pytest.raises( + ValueError, + match='Either data_store_id or search_engine_id must be specified', + ): + VertexAiSearchTool(data_store_specs=specs) + @pytest.mark.asyncio - async def test_process_llm_request_with_simple_gemini_model(self): + async def test_process_llm_request_with_simple_gemini_model(self, caplog): """Test processing LLM request with simple Gemini model name.""" - tool = VertexAiSearchTool(data_store_id='test_data_store') + caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME) + + tool = VertexAiSearchTool( + data_store_id='test_data_store', filter='f', max_results=5 + ) tool_context = await _create_tool_context() llm_request = LlmRequest( @@ -162,17 +203,50 @@ async def test_process_llm_request_with_simple_gemini_model(self): assert llm_request.config.tools is not None assert len(llm_request.config.tools) == 1 - assert llm_request.config.tools[0].retrieval is not None - assert llm_request.config.tools[0].retrieval.vertex_ai_search is not None + retrieval_tool = llm_request.config.tools[0] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + assert ( + retrieval_tool.retrieval.vertex_ai_search.datastore == 'test_data_store' + ) + assert retrieval_tool.retrieval.vertex_ai_search.engine is None + assert retrieval_tool.retrieval.vertex_ai_search.filter == 'f' + assert retrieval_tool.retrieval.vertex_ai_search.max_results == 5 + assert retrieval_tool.retrieval.vertex_ai_search.data_store_specs is None + + # Check for debug log message and its components + debug_records = [ + r for r in caplog.records if 'Adding Vertex AI Search tool config' in r.message + ] + assert len(debug_records) == 1 + log_message = debug_records[0].getMessage() + assert 'Adding Vertex AI Search tool config to LLM request' in log_message + assert 'datastore=test_data_store' in log_message + assert 'engine=None' in log_message + assert 'filter=f' in log_message + assert 'max_results=5' in log_message + assert 'data_store_specs=None' in log_message @pytest.mark.asyncio - async def test_process_llm_request_with_path_based_gemini_model(self): + async def test_process_llm_request_with_path_based_gemini_model(self, caplog): """Test processing LLM request with path-based Gemini model name.""" - tool = VertexAiSearchTool(data_store_id='test_data_store') + caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME) + + specs = [ + types.VertexAISearchDataStoreSpec( + dataStore='projects/p/locations/l/collections/default_collection/dataStores/spec_store_id' + ) + ] + tool = VertexAiSearchTool( + search_engine_id='projects/p/locations/l/collections/default_collection/engines/test_engine', + data_store_specs=specs, + filter='f2', + max_results=10, + ) tool_context = await _create_tool_context() llm_request = LlmRequest( - model='projects/265104255505/locations/us-central1/publishers/google/models/gemini-2.0-flash-001', + model='projects/p/locations/l/publishers/g/models/gemini-2.0-flash-001', config=types.GenerateContentConfig(), ) @@ -182,8 +256,30 @@ async def test_process_llm_request_with_path_based_gemini_model(self): assert llm_request.config.tools is not None assert len(llm_request.config.tools) == 1 - assert llm_request.config.tools[0].retrieval is not None - assert llm_request.config.tools[0].retrieval.vertex_ai_search is not None + retrieval_tool = llm_request.config.tools[0] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + assert retrieval_tool.retrieval.vertex_ai_search.datastore is None + assert ( + retrieval_tool.retrieval.vertex_ai_search.engine + == 'projects/p/locations/l/collections/default_collection/engines/test_engine' + ) + assert retrieval_tool.retrieval.vertex_ai_search.filter == 'f2' + assert retrieval_tool.retrieval.vertex_ai_search.max_results == 10 + assert retrieval_tool.retrieval.vertex_ai_search.data_store_specs == specs + + # Check for debug log message and its components + debug_messages = [ + r.message for r in caplog.records if r.levelno == logging.DEBUG + ] + debug_message = '\n'.join(debug_messages) + assert 'Adding Vertex AI Search tool config to LLM request' in debug_message + assert 'datastore=None' in debug_message + assert 'engine=test_engine' in debug_message + assert 'filter=f2' in debug_message + assert 'max_results=10' in debug_message + assert 'data_store_specs=' in debug_message + assert 'spec_store_id' in debug_message @pytest.mark.asyncio async def test_process_llm_request_with_gemini_1_and_other_tools_raises_error( @@ -230,7 +326,9 @@ async def test_process_llm_request_with_path_based_gemini_1_and_other_tools_rais ) llm_request = LlmRequest( - model='projects/265104255505/locations/us-central1/publishers/google/models/gemini-1.5-pro-preview', + model=( + 'projects/p/locations/l/publishers/g/models/gemini-1.5-pro-preview' + ), config=types.GenerateContentConfig(tools=[existing_tool]), ) @@ -273,7 +371,9 @@ async def test_process_llm_request_with_path_based_non_gemini_model_raises_error tool = VertexAiSearchTool(data_store_id='test_data_store') tool_context = await _create_tool_context() - non_gemini_path = 'projects/265104255505/locations/us-central1/publishers/google/models/claude-3-sonnet' + non_gemini_path = ( + 'projects/p/locations/l/publishers/g/models/claude-3-sonnet' + ) llm_request = LlmRequest( model=non_gemini_path, config=types.GenerateContentConfig() ) @@ -291,9 +391,11 @@ async def test_process_llm_request_with_path_based_non_gemini_model_raises_error @pytest.mark.asyncio async def test_process_llm_request_with_gemini_2_and_other_tools_succeeds( - self, + self, caplog ): """Test that Gemini 2.x with other tools succeeds.""" + caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME) + tool = VertexAiSearchTool(data_store_id='test_data_store') tool_context = await _create_tool_context() @@ -316,5 +418,21 @@ async def test_process_llm_request_with_gemini_2_and_other_tools_succeeds( assert llm_request.config.tools is not None assert len(llm_request.config.tools) == 2 assert llm_request.config.tools[0] == existing_tool - assert llm_request.config.tools[1].retrieval is not None - assert llm_request.config.tools[1].retrieval.vertex_ai_search is not None + retrieval_tool = llm_request.config.tools[1] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + assert ( + retrieval_tool.retrieval.vertex_ai_search.datastore == 'test_data_store' + ) + + debug_records = [ + r for r in caplog.records if 'Adding Vertex AI Search tool config' in r.message + ] + assert len(debug_records) == 1 + log_message = debug_records[0].getMessage() + assert 'Adding Vertex AI Search tool config to LLM request' in log_message + assert 'datastore=test_data_store' in log_message + assert 'engine=None' in log_message + assert 'filter=None' in log_message + assert 'max_results=None' in log_message + assert 'data_store_specs=None' in log_message