diff --git a/nemo_skills/mcp/servers/tavily_search_tool.py b/nemo_skills/mcp/servers/tavily_search_tool.py index ef4ca54ff5..f72f6761dd 100644 --- a/nemo_skills/mcp/servers/tavily_search_tool.py +++ b/nemo_skills/mcp/servers/tavily_search_tool.py @@ -40,18 +40,28 @@ class ExecutionResult: TAVILY_API_KEY: str | None = None EXCLUDE_DOMAINS: list[str] | None = None +MAX_NUM_RESULTS: int = 20 ## See docs https://docs.tavily.com/documentation/api-reference/endpoint/search ## There is also a hosted MCP that can be used instead of this tool: https://github.com/tavily-ai/tavily-mcp?tab=readme-ov-file#remote-mcp-server -@mcp.tool(name="tavily-search") +@mcp.tool(name="web-search") async def answer( query: Annotated[str, Field(description="Search query.")], exclude_domains: Annotated[list[str], Field(description="Domains to exclude from the search.")] = [], + num_results: Annotated[int, Field(description="Number of results to return.")] = 10, + answer_type: Annotated[ + str, + Field( + description='Type of results to return. Choose "answer" for a concise answer or "results" for a list of results.' + ), + ] = "answer", ): - """Get a summary of search results from the web using Tavily.""" + """Search the web for a query""" api_url = "https://api.tavily.com/search" + assert answer_type in ["answer", "results"], "Invalid answer type. Choose 'answer' or 'results'." + assert num_results <= MAX_NUM_RESULTS, f"Number of results must be less than or equal to {MAX_NUM_RESULTS}." headers = { "Authorization": f"Bearer {TAVILY_API_KEY}", @@ -63,6 +73,7 @@ async def answer( # "auto_parameters": False, "search_depth": "basic", "include_answer": "basic", ## or advanced. + "num_results": num_results, # this should be statically set to the domains we want to exclude "exclude_domains": exclude_domains, } @@ -72,7 +83,7 @@ async def answer( if response.status_code != 200: return {"error": response.json()["error"]} - result = response.json()["answer"] + result = response.json()[answer_type] return result @@ -99,7 +110,7 @@ def __init__(self) -> None: "args": ["-m", "nemo_skills.mcp.servers.tavily_search_tool"], }, "hide_args": { - "tavily-search": ["exclude_domains"], + "web-search": ["exclude_domains", "num_results", "answer_type"], }, "exclude_domains_config": None, } @@ -120,6 +131,9 @@ async def execute(self, tool_name: str, arguments: dict[str, Any], extra_args: d if not hasattr(self, "exclude_domains"): raise ValueError("exclude_domains_config is not set") merged_extra["exclude_domains"] = self.exclude_domains + for key in ["num_results", "answer_type"]: + if key in self._config: + merged_extra[key] = self._config[key] result = await self._client.call_tool(tool=tool_name, args=arguments, extra_args=merged_extra) return result