Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions nemo_skills/mcp/servers/tavily_search_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import os
from dataclasses import dataclass
from typing import Annotated, Any

import httpx
from mcp.server.fastmcp import FastMCP
from pydantic import Field

from nemo_skills.mcp.tool_providers import MCPClientTool

logger = logging.getLogger(__name__)


@dataclass
class ExecutionResult:
error: str | None = None
result: str | None = None


# FastMCP server instance named "tavily"; tools in this module register on it via @mcp.tool.
mcp = FastMCP(name="tavily")

# Tavily API key. None at import time; main() assigns it from the --api-key CLI
# argument, whose default is the TAVILY_API_KEY environment variable.
TAVILY_API_KEY: str | None = None

# NOTE(review): nothing in the visible part of this module assigns or reads this
# name -- confirm whether it is still needed.
EXCLUDE_DOMAINS: list[str] | None = None


## See docs https://docs.tavily.com/documentation/api-reference/endpoint/search
## There is also a hosted MCP that can be used instead of this tool: https://github.com/tavily-ai/tavily-mcp?tab=readme-ov-file#remote-mcp-server
def _tavily_error_message(body: Any, fallback: str) -> str:
    """Return the error text found in a decoded Tavily response body, else *fallback*.

    Both a top-level ``"error"`` entry and one nested under ``"detail"`` are
    accepted, so the lookup does not depend on a single error-body layout.
    """
    if isinstance(body, dict):
        detail = body.get("detail")
        nested = detail.get("error") if isinstance(detail, dict) else detail
        for candidate in (body.get("error"), nested):
            if candidate:
                return str(candidate)
    return fallback


@mcp.tool(name="tavily-search")
async def answer(
    query: Annotated[str, Field(description="Search query.")],
    # The default list is never mutated below, so sharing it between calls is harmless.
    # NOTE(review): it is kept as `[]` on the assumption that the tool's input schema
    # is derived from this signature -- confirm before changing it.
    exclude_domains: Annotated[list[str], Field(description="Domains to exclude from the search.")] = [],
):
    """Get a summary of search results from the web using Tavily."""
    # NOTE(review): the docstring above is kept verbatim because it is presumably
    # surfaced as the tool description; further documentation lives in comments.
    #
    # Returns the "answer" string of the Tavily response on success, or a
    # {"error": <message>} dict when the request is rejected or the response
    # body cannot be interpreted. Transport-level httpx errors still propagate.

    api_url = "https://api.tavily.com/search"

    headers = {
        "Authorization": f"Bearer {TAVILY_API_KEY}",
        "Content-Type": "application/json",
    }

    payload = {
        "query": query,
        "search_depth": "basic",
        "include_answer": "basic",  # or "advanced"
        # Callers are expected to pass a fixed exclusion list here.
        "exclude_domains": exclude_domains,
    }

    async with httpx.AsyncClient() as client:
        response = await client.post(api_url, headers=headers, json=payload)

    # Decode the body once; a non-JSON body must not crash the tool.
    try:
        body = response.json()
    except ValueError:  # includes json.JSONDecodeError and UnicodeDecodeError
        body = None

    if response.status_code != 200:
        fallback = f"Tavily request failed with HTTP {response.status_code}: {response.text}"
        return {"error": _tavily_error_message(body, fallback)}

    result = body.get("answer") if isinstance(body, dict) else None
    if result is None:
        return {"error": "Tavily response did not contain an answer."}

    return result
Comment on lines +70 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add error handling for JSON parsing and missing keys.

If the Tavily API returns a non-JSON error response or an unexpected JSON structure, the current code will raise unhandled exceptions (JSONDecodeError or KeyError).

Consider wrapping the response parsing in try-except:

     async with httpx.AsyncClient(timeout=30.0) as client:
         response = await client.post(api_url, headers=headers, json=payload)
+        try:
+            data = response.json()
+        except Exception as e:
+            return ExecutionResult(error=f"Failed to parse response: {e}")
+        
         if response.status_code != 200:
-            error_detail = response.json().get("error", response.text)
+            error_detail = data.get("error", response.text)
             return ExecutionResult(error=str(error_detail))
 
-        result = response.json()["answer"]
+        result = data.get("answer")
+        if result is None:
+            return ExecutionResult(error="No answer in response")
 
     return ExecutionResult(result=result)

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In nemo_skills/mcp/servers/tavily_search_tool.py around lines 64 to 69, the code
assumes that response.json() succeeds and that the "error" and "answer" keys
exist. Parse the JSON once inside a try/except that catches JSONDecodeError (or
check the Content-Type header first), and reuse the parsed body for both the
non-200 and the 200 case. When parsing fails or an expected key is missing,
return a clear error dict (e.g., one that includes response.text and
status_code); otherwise validate that "answer" is present and return
result = parsed.get("answer").


return result


def _parse_exclude_domains(exclude_config: dict) -> list[str]:
    """Collect the ``value`` of every ``"domain"``-typed property in *exclude_config*.

    The config must have the shape ``{"notices": [{"properties": [{...}, ...]}, ...]}``.
    Lookups are deliberately strict: a missing ``"notices"``, ``"properties"`` or
    ``"value"`` key raises ``KeyError`` so a malformed file is noticed.
    """
    domains: list[str] = []
    for notice in exclude_config["notices"]:
        domains.extend(
            entry["value"]
            for entry in notice["properties"]
            if entry.get("type") == "domain"
        )
    return domains


class TavilySearchTool(MCPClientTool):
    """Client-side provider for the ``tavily-search`` MCP tool.

    The default configuration points ``nemo_skills.mcp.clients.MCPStdioClient``
    at ``python -m nemo_skills.mcp.servers.tavily_search_tool``. The domains to
    exclude are loaded from the JSON file named by the ``exclude_domains_config``
    setting and attached to every call as the ``exclude_domains`` extra argument.
    """

    def __init__(self) -> None:
        super().__init__()
        self.apply_config_updates(
            {
                "client": "nemo_skills.mcp.clients.MCPStdioClient",
                "client_params": {
                    "command": "python",
                    "args": ["-m", "nemo_skills.mcp.servers.tavily_search_tool"],
                },
                # NOTE(review): presumably keeps `exclude_domains` out of the
                # model-facing tool schema -- confirm against MCPClientTool.
                "hide_args": {
                    "tavily-search": ["exclude_domains"],
                },
                # Must be overridden with a path to the exclusion JSON file.
                "exclude_domains_config": None,
            }
        )

    def post_configure(self) -> None:
        """Load the exclusion list from the file named by ``exclude_domains_config``.

        Raises:
            ValueError: if ``exclude_domains_config`` is not set. The exclusion list
                is required so that a search never runs against all domains by accident.
        """
        config_path = self._config.get("exclude_domains_config")
        if config_path is None:
            raise ValueError("exclude_domains_config is not set")

        # Decode as UTF-8 explicitly instead of relying on the platform's locale encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            exclude_config = json.load(f)
        self.exclude_domains = _parse_exclude_domains(exclude_config)

    async def execute(self, tool_name: str, arguments: dict[str, Any], extra_args: dict[str, Any] | None = None):
        """Forward the call to the MCP client with ``exclude_domains`` added to the extra args.

        Args:
            tool_name: name of the tool to call.
            arguments: tool arguments; copied before being forwarded.
            extra_args: optional extra arguments; copied, then ``exclude_domains`` is set on the copy.

        Returns:
            Whatever ``self._client.call_tool`` returns.

        Raises:
            ValueError: if the exclusion list has not been loaded by ``post_configure``.
        """
        if not hasattr(self, "exclude_domains"):
            raise ValueError("exclude_domains_config is not set")

        # Work on copies so the caller's dicts are never modified.
        arguments = dict(arguments)
        merged_extra = dict(extra_args or {})
        merged_extra["exclude_domains"] = self.exclude_domains
        return await self._client.call_tool(tool=tool_name, args=arguments, extra_args=merged_extra)


def main():
    """Read the Tavily API key from the CLI or environment and serve the MCP tool over stdio.

    Raises:
        ValueError: if no API key is given via ``--api-key`` or ``TAVILY_API_KEY``.
    """
    global TAVILY_API_KEY

    cli = argparse.ArgumentParser(description="MCP server for Tavily web search tool")
    cli.add_argument("--api-key", type=str, default=os.getenv("TAVILY_API_KEY"), help="Tavily API Key")
    api_key = cli.parse_args().api_key

    if not api_key:
        raise ValueError("Missing Tavily API key.")

    # Store the key at module level, where the tool function reads it.
    TAVILY_API_KEY = api_key

    mcp.run(transport="stdio")


if __name__ == "__main__":
    main()
4 changes: 4 additions & 0 deletions nemo_skills/mcp/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ async def execute(
async def shutdown(self) -> None:
    """Optional hook; this default does nothing and returns None."""
    return

def post_configure(self) -> None:
    """Hook with a default no-op body; it takes no arguments and returns None."""
    return


class ToolManager:
"""Registry/Router for module-based tools.
Expand Down Expand Up @@ -98,6 +101,7 @@ def __init__(
raise ValueError(f"Duplicate tool class registered: '{provider_key}'")

tool.configure((overrides.get(provider_key) if overrides else None), context)
tool.post_configure()
self._tools[provider_key] = tool

async def shutdown(self) -> None:
Expand Down
3 changes: 3 additions & 0 deletions nemo_skills/mcp/tool_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def _resolve_maybe_callable(self, value: Any):
return value
return value

def post_configure(self) -> None:
    """Default no-op; does nothing and returns None."""
    return None

def configure(self, overrides: Dict[str, Any] | None = None, context: Dict[str, Any] | None = None) -> None:
cfg = dict(self._config)
if overrides:
Expand Down