Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions nemo_skills/inference/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,25 @@ class GenerateSolutionsConfig:
# - Set an ExampleTool server-only arg:
# ++tool_overrides.ExampleTool.foo_argument='[TEST] '
tool_overrides: dict | None = field(default_factory=dict)
#
# Schema overrides allow customizing tool schemas shown to the model.
# Dict keyed by provider class name (like tool_overrides), then tool name.
# Format: ProviderClassName -> tool_name -> (name, description, parameters)
#
# Example YAML configuration (config.yaml):
# schema_overrides:
# PythonTool:
# stateful_python_code_exec:
# name: "python_executor"
# description: "Evaluate Python code interactively"
# parameters:
# code:
# name: "script"
# description: "Python code to execute"
#
# To use this config with Hydra, launch your script with:
# --config-path /path/to/configs --config-name config
schema_overrides: dict | None = field(default_factory=dict)

# if True, will move full generation to _full_generation key and keep cfg.generation_key without thinking tokens
# IMPORTANT: do not set this for non-reasoning models as it will make the generations empty!
Expand Down Expand Up @@ -387,6 +406,7 @@ def setup_llm(self):
**self.cfg.server,
tool_modules=self.cfg.tool_modules,
tool_overrides=self.cfg.tool_overrides,
schema_overrides=self.cfg.schema_overrides,
tokenizer=self.tokenizer,
additional_config={"sandbox": self.cfg.sandbox},
)
Expand Down
2 changes: 2 additions & 0 deletions nemo_skills/inference/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def get_tool_calling_model(
additional_config=None,
tool_modules: list[str] | None = None,
tool_overrides: dict | None = None,
schema_overrides: dict | None = None,
**kwargs,
):
if isinstance(model, str):
Expand All @@ -131,6 +132,7 @@ def get_tool_calling_model(
tool_modules=tool_modules,
tool_overrides=tool_overrides,
additional_config=additional_config,
schema_overrides=schema_overrides,
)


Expand Down
22 changes: 17 additions & 5 deletions nemo_skills/inference/model/tool_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
format_tool_list_by_endpoint_type,
format_tool_response_by_endpoint_type,
get_tool_details_by_endpoint_type,
load_schema_overrides,
remap_tool_call,
)
from nemo_skills.mcp.tool_manager import ToolManager
from nemo_skills.utils import get_logger_name
Expand All @@ -46,27 +48,29 @@ def __init__(
tool_modules: list[str] | None = None,
tool_overrides: dict | None = None,
additional_config: dict | None = None,
schema_overrides: dict | None = None,
):
self.model = model
additional_config = additional_config or {}

self.tool_manager = None

# Module-based tool loading only
assert tool_modules, "tool_modules must be provided for tool calling"
self.tool_manager = ToolManager(
module_specs=tool_modules,
overrides=tool_overrides or {},
context=additional_config,
)

self.schema_overrides = load_schema_overrides(schema_overrides)
self.schema_mappings = {} # Built when tools are listed

async def _execute_tool_call(self, tool_call, request_id: str, endpoint_type: EndpointType):
## TODO(sanyamk): The correct key format needs to be cohesive with other formatters.
tool_name, tool_args = get_tool_details_by_endpoint_type(tool_call, endpoint_type)

##
# TODO(sanyamk): Not all tool arguments might necessarily be in JSON format.
# Kept here to handle errors for now.

try:
tool_args = json.loads(tool_args)
except json.decoder.JSONDecodeError as e:
Expand All @@ -75,9 +79,14 @@ async def _execute_tool_call(self, tool_call, request_id: str, endpoint_type: En
return {"error": "Tool argument parsing failed."}

## TODO(sanyamk): Only exceptions related to tool execution here, all others must fail.
# Remap model's tool name/args back to original schema
original_tool_name, tool_args = remap_tool_call(tool_name, tool_args, self.schema_mappings)

try:
# Allow providers to specify extra_args behavior internally if needed in the future
result = await self.tool_manager.execute_tool(tool_name, tool_args, extra_args={"request_id": request_id})
result = await self.tool_manager.execute_tool(
original_tool_name, tool_args, extra_args={"request_id": request_id}
)
except Exception as e:
LOG.exception(e)
return {"error": "Tool execution failed."}
Expand Down Expand Up @@ -109,7 +118,9 @@ async def generate_async(

# This assumes that the available tools do not change during the generation.
raw_tools = await self.tool_manager.list_all_tools(use_cache=True)
tools = format_tool_list_by_endpoint_type(raw_tools, endpoint_type)
tools, self.schema_mappings = format_tool_list_by_endpoint_type(
raw_tools, endpoint_type, schema_overrides=self.schema_overrides
)
LOG.info("Available Tools: %s", tools)

result_steps = defaultdict(list)
Expand Down Expand Up @@ -156,5 +167,6 @@ async def generate_async(
result_steps["num_generated_tokens"] = sum(result_steps["num_generated_tokens"])
result_steps["num_tool_calls"] = sum(result_steps["num_tool_calls"])
result_steps["conversation"] = conversation
result_steps["tools"] = tools # Schema sent to model (with overrides applied)

return result_steps
129 changes: 124 additions & 5 deletions nemo_skills/mcp/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
from abc import ABC, abstractmethod
from typing import Any, Dict

from litellm.types.utils import ChatCompletionMessageToolCall
from omegaconf import DictConfig, OmegaConf

from nemo_skills.inference.model.base import EndpointType

Expand Down Expand Up @@ -48,9 +51,123 @@ def format(self, tool_call: ChatCompletionMessageToolCall, result: dict) -> dict
# ==============================


def format_tool_list_by_endpoint_type(tools, endpoint_type: EndpointType):
def load_schema_overrides(schema_overrides: dict | None) -> Dict[str, Dict[str, Dict[str, Any]]]:
"""
Normalize schema overrides dict from Hydra/OmegaConf.

Args:
schema_overrides: Dict keyed by provider class name, then tool name, or None.
Format: ProviderClassName -> tool_name -> (name, description, parameters)

Returns:
Normalized dict ready for use with format_tool_list_by_endpoint_type
"""
if schema_overrides is None:
return {}

if isinstance(schema_overrides, DictConfig):
schema_overrides = OmegaConf.to_container(schema_overrides, resolve=True)

if not isinstance(schema_overrides, dict):
raise ValueError(f"schema_overrides must be dict or None, got {type(schema_overrides)}")

normalized = {}
for provider_class, provider_overrides in schema_overrides.items():
if not isinstance(provider_overrides, dict):
raise ValueError(f"Override for provider '{provider_class}' must be a dict")

normalized[provider_class] = {}
for tool_name, cfg in provider_overrides.items():
if not isinstance(cfg, dict):
raise ValueError(f"Override for tool '{tool_name}' in '{provider_class}' must be a dict")
normalized[provider_class][tool_name] = {
"name": cfg.get("name"),
"description": cfg.get("description"),
"parameters": cfg.get("parameters"),
}

return normalized


def apply_schema_overrides(
tool: Dict[str, Any], override_config: Dict[str, Any] | None
) -> tuple[Dict[str, Any], Dict[str, str]]:
"""Apply schema overrides to a tool. Returns (transformed_tool, {new_param: orig_param})."""
if not override_config:
return tool, {}

transformed = copy.deepcopy(tool)
for key in ("name", "description"):
if override_config.get(key) is not None:
transformed[key] = override_config[key]

param_overrides = override_config.get("parameters", {})
if not param_overrides:
return transformed, {}

schema = transformed.get("input_schema", {})
props, required = schema.get("properties", {}), set(schema.get("required", []))

for name, cfg in param_overrides.items():
if name not in props:
raise ValueError(f"Parameter '{name}' not in schema")
if not isinstance(cfg, dict):
raise ValueError(f"Override for '{name}' must be a dict")

new_props, new_required, mapping = {}, [], {}
for orig, param in props.items():
ovr = param_overrides.get(orig, {})
new = ovr.get("name", orig)
new_props[new] = {**param, **{k: v for k, v in ovr.items() if k != "name"}}
if new != orig:
mapping[new] = orig
if orig in required:
new_required.append(new)

transformed["input_schema"] = {**schema, "properties": new_props, "required": new_required}
return transformed, mapping


def remap_tool_call(tool_name: str, args: dict, mappings: dict) -> tuple[str, dict]:
    """Remap a tool call from model-facing names back to the original tool schema names.

    Args:
        tool_name: Tool name as emitted by the model.
        args: Parsed tool arguments. Expected to be a dict; any other value
            (the model may emit valid JSON that is not an object) is returned
            unchanged so the caller's normal tool-execution error handling
            deals with it.
        mappings: Dict with optional ``"tool_names"`` ({model_name: original_name})
            and ``"parameters"`` ({model_tool_name: {model_param: original_param}})
            entries. Empty or None means "no renames".

    Returns:
        ``(original_tool_name, original_args)``. Names without a mapping are
        passed through as-is.
    """
    mappings = mappings or {}
    original_tool = mappings.get("tool_names", {}).get(tool_name, tool_name)

    # Only JSON objects have parameter names to translate.
    if not isinstance(args, dict):
        return original_tool, args

    # Parameter mappings are keyed by the model-facing tool name.
    param_mapping = mappings.get("parameters", {}).get(tool_name, {})
    original_args = {param_mapping.get(k, k): v for k, v in args.items()}
    return original_tool, original_args


def format_tool_list_by_endpoint_type(
tools, endpoint_type: EndpointType, schema_overrides: Dict[str, Dict[str, Dict[str, Any]]] | None = None
) -> tuple[list[Dict[str, Any]], Dict[str, Any]]:
"""
Format tool list for the given endpoint type, applying schema overrides.

Returns:
Tuple of (formatted_tools, mappings_dict) where mappings_dict has:
- "tool_names": {model_name: original_name}
- "parameters": {tool_name: {model_param: original_param}}
"""
schema_overrides = schema_overrides or {}
mappings = {"tool_names": {}, "parameters": {}}
transformed_tools = []

for tool in tools:
original_name = tool["name"]
provider = schema_overrides.get(tool.get("server")) or {}
override = provider.get(original_name)

transformed, param_mapping = apply_schema_overrides(tool, override)
transformed_tools.append(transformed)

new_name = transformed["name"]
if new_name != original_name:
mappings["tool_names"][new_name] = original_name
if param_mapping:
mappings["parameters"][new_name] = param_mapping

# Format for endpoint type
if endpoint_type == EndpointType.chat:
return [
formatted = [
{
"type": "function",
"function": {
Expand All @@ -59,22 +176,24 @@ def format_tool_list_by_endpoint_type(tools, endpoint_type: EndpointType):
"parameters": t["input_schema"],
},
}
for t in tools
for t in transformed_tools
]
elif endpoint_type == EndpointType.responses:
return [
formatted = [
{
"type": "function",
"name": t["name"],
"description": t["description"],
"parameters": t["input_schema"],
"strict": True, # Less vllm errors through structured output
}
for t in tools
for t in transformed_tools
]
else:
raise ValueError(f"Unsupported completion type for tool list: {endpoint_type}")

return formatted, mappings


class OpenAICallInterpreter(ToolCallInterpreter):
def parse(self, tool_call):
Expand Down
49 changes: 49 additions & 0 deletions tests/test_mcp_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,52 @@ async def __aexit__(self, exc_type, exc, tb):
client = MCPStreamableHttpClient(base_url="https://example.com/mcp", enabled_tools=["only_t2"]) # not including t1
with pytest.raises(PermissionError):
await client.call_tool("t1", {})


@pytest.mark.asyncio
async def test_tool_manager_with_schema_overrides():
    """Schema overrides rename a ToolManager tool and its parameter and record the reverse mappings."""
    from nemo_skills.inference.model.base import EndpointType
    from nemo_skills.mcp.adapters import format_tool_list_by_endpoint_type, load_schema_overrides

    manager = ToolManager(module_specs=[f"{__name__}::DummyTool"], overrides={}, context={})
    raw_tools = await manager.list_all_tools(use_cache=False)

    # Expose 'execute' to the model as 'renamed_execute', with 'code' shown as 'script'.
    execute_override = {
        "name": "renamed_execute",
        "parameters": {"code": {"name": "script"}},
    }
    overrides = load_schema_overrides({"DummyTool": {"execute": execute_override}})

    formatted_tools, mappings = format_tool_list_by_endpoint_type(
        raw_tools,
        EndpointType.chat,
        schema_overrides=overrides,
    )

    # The renamed tool must be present in the chat-formatted tool list.
    renamed_tool = None
    for candidate in formatted_tools:
        if candidate["function"]["name"] == "renamed_execute":
            renamed_tool = candidate
            break
    assert renamed_tool is not None

    # Only the model-facing parameter name is visible in the schema.
    properties = renamed_tool["function"]["parameters"]["properties"]
    assert "script" in properties
    assert "code" not in properties

    # Reverse mappings translate model-facing names back to the originals.
    assert mappings["parameters"]["renamed_execute"] == {"script": "code"}
    assert mappings["tool_names"]["renamed_execute"] == "execute"


def test_schema_override_nonexistent_param_fails():
    """An override that targets a parameter missing from the schema is rejected with ValueError.

    As noted by the original author, the same error path covers hidden args:
    if hide_args has already removed a parameter from the schema by the time
    overrides are applied, overriding that parameter fails in the same way.
    """
    from nemo_skills.mcp.adapters import apply_schema_overrides

    # The tool exposes a single parameter, 'code'.
    input_schema = {
        "type": "object",
        "properties": {"code": {"type": "string"}},
        "required": [],
    }
    tool = {"name": "test", "description": "Test", "input_schema": input_schema}

    # 'script' is not a parameter of the tool, so renaming it must fail.
    bad_override = {"parameters": {"script": {"name": "renamed"}}}
    with pytest.raises(ValueError, match="Parameter 'script' not in schema"):
        apply_schema_overrides(tool, bad_override)