diff --git a/nemo_skills/inference/generate.py b/nemo_skills/inference/generate.py index aae36c7351..136375db46 100644 --- a/nemo_skills/inference/generate.py +++ b/nemo_skills/inference/generate.py @@ -177,6 +177,25 @@ class GenerateSolutionsConfig: # - Set an ExampleTool server-only arg: # ++tool_overrides.ExampleTool.foo_argument='[TEST] ' tool_overrides: dict | None = field(default_factory=dict) + # + # Schema overrides allow customizing tool schemas shown to the model. + # Dict keyed by provider class name (like tool_overrides), then tool name. + # Format: ProviderClassName -> tool_name -> (name, description, parameters) + # + # Example YAML configuration (config.yaml): + # schema_overrides: + # PythonTool: + # stateful_python_code_exec: + # name: "python_executor" + # description: "Evaluate Python code interactively" + # parameters: + # code: + # name: "script" + # description: "Python code to execute" + # + # To use this config with Hydra, launch your script with: + # --config-path /path/to/configs --config-name config + schema_overrides: dict | None = field(default_factory=dict) # if True, will move full generation to _full_generation key and keep cfg.generation_key without thinking tokens # IMPORTANT: do not set this for non-reasoning models as it will make the generations empty! @@ -387,6 +406,7 @@ def setup_llm(self): **self.cfg.server, tool_modules=self.cfg.tool_modules, tool_overrides=self.cfg.tool_overrides, + schema_overrides=self.cfg.schema_overrides, tokenizer=self.tokenizer, additional_config={"sandbox": self.cfg.sandbox}, ) diff --git a/nemo_skills/inference/model/__init__.py b/nemo_skills/inference/model/__init__.py index bd2e246499..164d92fcc8 100644 --- a/nemo_skills/inference/model/__init__.py +++ b/nemo_skills/inference/model/__init__.py @@ -122,6 +122,7 @@ def get_tool_calling_model( additional_config=None, tool_modules: list[str] | None = None, tool_overrides: dict | None = None, + schema_overrides: dict | None = None, **kwargs, ): if isinstance(model, str): @@ -131,6 +132,7 @@ def get_tool_calling_model( tool_modules=tool_modules, tool_overrides=tool_overrides, additional_config=additional_config, + schema_overrides=schema_overrides, ) diff --git a/nemo_skills/inference/model/tool_call.py b/nemo_skills/inference/model/tool_call.py index ffbd1a9921..2891389bd8 100644 --- a/nemo_skills/inference/model/tool_call.py +++ b/nemo_skills/inference/model/tool_call.py @@ -24,6 +24,8 @@ format_tool_list_by_endpoint_type, format_tool_response_by_endpoint_type, get_tool_details_by_endpoint_type, + load_schema_overrides, + remap_tool_call, ) from nemo_skills.mcp.tool_manager import ToolManager from nemo_skills.utils import get_logger_name @@ -46,13 +48,11 @@ def __init__( tool_modules: list[str] | None = None, tool_overrides: dict | None = None, additional_config: dict | None = None, + schema_overrides: dict | None = None, ): self.model = model additional_config = additional_config or {} - self.tool_manager = None - - # Module-based tool loading only assert tool_modules, "tool_modules must be provided for tool calling" self.tool_manager = ToolManager( module_specs=tool_modules, @@ -60,6 +60,9 @@ def __init__( context=additional_config, ) + self.schema_overrides = load_schema_overrides(schema_overrides) + self.schema_mappings = {} # Built when tools are listed + async def _execute_tool_call(self, tool_call, request_id: str, endpoint_type: EndpointType): ## TODO(sanyamk): The correct key format needs to be cohesive with other formatters. tool_name, tool_args = get_tool_details_by_endpoint_type(tool_call, endpoint_type) @@ -67,6 +70,7 @@ async def _execute_tool_call(self, tool_call, request_id: str, endpoint_type: En ## # TODO(sanyamk): Not all tool arguments might necessarily be in JSON format. # Kept here to handle errors for now. + try: tool_args = json.loads(tool_args) except json.decoder.JSONDecodeError as e: @@ -75,9 +79,14 @@ async def _execute_tool_call(self, tool_call, request_id: str, endpoint_type: En return {"error": "Tool argument parsing failed."} ## TODO(sanyamk): Only exceptions related to tool execution here, all others must fail. + # Remap model's tool name/args back to original schema + original_tool_name, tool_args = remap_tool_call(tool_name, tool_args, self.schema_mappings) + try: # Allow providers to specify extra_args behavior internally if needed in the future - result = await self.tool_manager.execute_tool(tool_name, tool_args, extra_args={"request_id": request_id}) + result = await self.tool_manager.execute_tool( + original_tool_name, tool_args, extra_args={"request_id": request_id} + ) except Exception as e: LOG.exception(e) return {"error": "Tool execution failed."} @@ -109,7 +118,9 @@ async def generate_async( # This assumes that the available tools do not change during the generation. raw_tools = await self.tool_manager.list_all_tools(use_cache=True) - tools = format_tool_list_by_endpoint_type(raw_tools, endpoint_type) + tools, self.schema_mappings = format_tool_list_by_endpoint_type( + raw_tools, endpoint_type, schema_overrides=self.schema_overrides + ) LOG.info("Available Tools: %s", tools) result_steps = defaultdict(list) @@ -156,5 +167,6 @@ async def generate_async( result_steps["num_generated_tokens"] = sum(result_steps["num_generated_tokens"]) result_steps["num_tool_calls"] = sum(result_steps["num_tool_calls"]) result_steps["conversation"] = conversation + result_steps["tools"] = tools # Schema sent to model (with overrides applied) return result_steps diff --git a/nemo_skills/mcp/adapters.py b/nemo_skills/mcp/adapters.py index a927000403..2619c783f2 100644 --- a/nemo_skills/mcp/adapters.py +++ b/nemo_skills/mcp/adapters.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import json from abc import ABC, abstractmethod +from typing import Any, Dict from litellm.types.utils import ChatCompletionMessageToolCall +from omegaconf import DictConfig, OmegaConf from nemo_skills.inference.model.base import EndpointType @@ -48,9 +51,123 @@ def format(self, tool_call: ChatCompletionMessageToolCall, result: dict) -> dict # ============================== -def format_tool_list_by_endpoint_type(tools, endpoint_type: EndpointType): +def load_schema_overrides(schema_overrides: dict | None) -> Dict[str, Dict[str, Dict[str, Any]]]: + """ + Normalize schema overrides dict from Hydra/OmegaConf. + + Args: + schema_overrides: Dict keyed by provider class name, then tool name, or None. + Format: ProviderClassName -> tool_name -> (name, description, parameters) + + Returns: + Normalized dict ready for use with format_tool_list_by_endpoint_type + """ + if schema_overrides is None: + return {} + + if isinstance(schema_overrides, DictConfig): + schema_overrides = OmegaConf.to_container(schema_overrides, resolve=True) + + if not isinstance(schema_overrides, dict): + raise ValueError(f"schema_overrides must be dict or None, got {type(schema_overrides)}") + + normalized = {} + for provider_class, provider_overrides in schema_overrides.items(): + if not isinstance(provider_overrides, dict): + raise ValueError(f"Override for provider '{provider_class}' must be a dict") + + normalized[provider_class] = {} + for tool_name, cfg in provider_overrides.items(): + if not isinstance(cfg, dict): + raise ValueError(f"Override for tool '{tool_name}' in '{provider_class}' must be a dict") + normalized[provider_class][tool_name] = { + "name": cfg.get("name"), + "description": cfg.get("description"), + "parameters": cfg.get("parameters"), + } + + return normalized + + +def apply_schema_overrides( + tool: Dict[str, Any], override_config: Dict[str, Any] | None +) -> tuple[Dict[str, Any], Dict[str, str]]: + """Apply schema overrides to a tool. Returns (transformed_tool, {new_param: orig_param}).""" + if not override_config: + return tool, {} + + transformed = copy.deepcopy(tool) + for key in ("name", "description"): + if override_config.get(key) is not None: + transformed[key] = override_config[key] + + param_overrides = override_config.get("parameters", {}) + if not param_overrides: + return transformed, {} + + schema = transformed.get("input_schema", {}) + props, required = schema.get("properties", {}), set(schema.get("required", [])) + + for name, cfg in param_overrides.items(): + if name not in props: + raise ValueError(f"Parameter '{name}' not in schema") + if not isinstance(cfg, dict): + raise ValueError(f"Override for '{name}' must be a dict") + + new_props, new_required, mapping = {}, [], {} + for orig, param in props.items(): + ovr = param_overrides.get(orig, {}) + new = ovr.get("name", orig) + new_props[new] = {**param, **{k: v for k, v in ovr.items() if k != "name"}} + if new != orig: + mapping[new] = orig + if orig in required: + new_required.append(new) + + transformed["input_schema"] = {**schema, "properties": new_props, "required": new_required} + return transformed, mapping + + +def remap_tool_call(tool_name: str, args: dict, mappings: dict) -> tuple[str, dict]: + """Remap a tool call from model names back to original tool schema names.""" + original_tool = mappings.get("tool_names", {}).get(tool_name, tool_name) + param_mapping = mappings.get("parameters", {}).get(tool_name, {}) + original_args = {param_mapping.get(k, k): v for k, v in args.items()} + return original_tool, original_args + + +def format_tool_list_by_endpoint_type( + tools, endpoint_type: EndpointType, schema_overrides: Dict[str, Dict[str, Dict[str, Any]]] | None = None +) -> tuple[list[Dict[str, Any]], Dict[str, Any]]: + """ + Format tool list for the given endpoint type, applying schema overrides. + + Returns: + Tuple of (formatted_tools, mappings_dict) where mappings_dict has: + - "tool_names": {model_name: original_name} + - "parameters": {tool_name: {model_param: original_param}} + """ + schema_overrides = schema_overrides or {} + mappings = {"tool_names": {}, "parameters": {}} + transformed_tools = [] + + for tool in tools: + original_name = tool["name"] + provider = schema_overrides.get(tool.get("server")) or {} + override = provider.get(original_name) + + transformed, param_mapping = apply_schema_overrides(tool, override) + transformed_tools.append(transformed) + + new_name = transformed["name"] + if new_name != original_name: + mappings["tool_names"][new_name] = original_name + if param_mapping: + mappings["parameters"][new_name] = param_mapping + + # Format for endpoint type if endpoint_type == EndpointType.chat: - return [ + formatted = [ { "type": "function", "function": { @@ -59,10 +176,10 @@ def format_tool_list_by_endpoint_type(tools, endpoint_type: EndpointType): "parameters": t["input_schema"], }, } - for t in tools + for t in transformed_tools ] elif endpoint_type == EndpointType.responses: - return [ + formatted = [ { "type": "function", "name": t["name"], @@ -70,11 +187,13 @@ def format_tool_list_by_endpoint_type(tools, endpoint_type: EndpointType): "parameters": t["input_schema"], "strict": True, # Less vllm errors through structured output } - for t in tools + for t in transformed_tools ] else: raise ValueError(f"Unsupported completion type for tool list: {endpoint_type}") + return formatted, mappings + class OpenAICallInterpreter(ToolCallInterpreter): def parse(self, tool_call): diff --git a/tests/test_mcp_clients.py b/tests/test_mcp_clients.py index bf893e1e19..a26c6df277 100644 --- a/tests/test_mcp_clients.py +++ b/tests/test_mcp_clients.py @@ -556,3 +556,52 @@ async def __aexit__(self, exc_type, exc, tb): client = MCPStreamableHttpClient(base_url="https://example.com/mcp", enabled_tools=["only_t2"]) # not including t1 with pytest.raises(PermissionError): await client.call_tool("t1", {}) + + +@pytest.mark.asyncio +async def test_tool_manager_with_schema_overrides(): + """Test ToolManager integration with schema overrides.""" + from nemo_skills.inference.model.base import EndpointType + from nemo_skills.mcp.adapters import format_tool_list_by_endpoint_type, load_schema_overrides + + tm = ToolManager(module_specs=[f"{__name__}::DummyTool"], overrides={}, context={}) + tools = await tm.list_all_tools(use_cache=False) + + schema_overrides = { + "DummyTool": { + "execute": { + "name": "renamed_execute", + "parameters": {"code": {"name": "script"}}, # rename 'code' -> 'script' for model + } + } + } + loaded_overrides = load_schema_overrides(schema_overrides) + formatted_tools, mappings = format_tool_list_by_endpoint_type( + tools, EndpointType.chat, schema_overrides=loaded_overrides + ) + + renamed_tool = next((t for t in formatted_tools if t["function"]["name"] == "renamed_execute"), None) + assert renamed_tool is not None + assert "script" in renamed_tool["function"]["parameters"]["properties"] + assert "code" not in renamed_tool["function"]["parameters"]["properties"] + assert mappings["parameters"]["renamed_execute"] == {"script": "code"} + assert mappings["tool_names"]["renamed_execute"] == "execute" + + +def test_schema_override_nonexistent_param_fails(): + """Overriding a parameter that doesn't exist in the schema must fail early. + + This also covers the hidden-arg case: when hide_args removes a param from the + schema before overrides are applied, attempting to override that (now-missing) + param will trigger the same error. + """ + from nemo_skills.mcp.adapters import apply_schema_overrides + + tool = { + "name": "test", + "description": "Test", + "input_schema": {"type": "object", "properties": {"code": {"type": "string"}}, "required": []}, + } + # Try to override 'script' which doesn't exist (tool only has 'code') + with pytest.raises(ValueError, match="Parameter 'script' not in schema"): + apply_schema_overrides(tool, {"parameters": {"script": {"name": "renamed"}}})