Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions nemo_skills/inference/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import dataclasses

from nemo_skills.mcp.utils import locate
from nemo_skills.utils import python_doc_to_cmd_help

# NIM models (speech)
Expand All @@ -32,6 +33,7 @@
from .parallel_thinking import ParallelThinkingConfig, ParallelThinkingTask

# Tool Calling
from .sglang import SGLangModel
from .tool_call import ToolCallingWrapper
from .tts_nim import TTSNIMModel

Expand All @@ -49,19 +51,33 @@
"azureopenai": AzureOpenAIModel,
"gemini": GeminiModel,
"vllm": VLLMModel,
"sglang": VLLMModel,
"sglang": SGLangModel,
"tts_nim": TTSNIMModel,
"asr_nim": ASRNIMModel,
}


def get_model(server_type, tokenizer=None, **kwargs):
"""A helper function to make it easier to set server through cmd."""
model_class = models[server_type.lower()]
def get_model(server_type, tokenizer=None, model_class: str | None = None, **kwargs):
    """A helper function to make it easier to set server through cmd.

    Args:
        server_type: The type of server (vllm, sglang, openai, etc.). Case-insensitive.
        tokenizer: Optional tokenizer path.
        model_class: Optional custom model class path to override the default for server_type.
            Supports dotted module paths (e.g., 'nemo_skills.inference.model.sglang.SGLangModel')
            or double-colon syntax (e.g., 'nemo_skills.inference.model.sglang::SGLangModel').
            Useful for models with specific requirements (e.g., Kimi-K2 requires tool_choice='auto').
        **kwargs: Additional arguments passed to the model constructor.

    Raises:
        KeyError: If server_type is unknown and no model_class override is given.
        ValueError: If context_limit_retry_strategy is combined with trtllm soft-fail mode.
    """
    # Normalize once so both the registry lookup and the trtllm guard below are
    # case-insensitive (previously "TRTLLM" would bypass the soft-fail validation).
    server_type = server_type.lower()
    if model_class is not None:
        loaded_class = locate(model_class)
    else:
        loaded_class = models[server_type]

    if server_type == "trtllm" and kwargs.get("enable_soft_fail", False):
        if kwargs.get("context_limit_retry_strategy", None) is not None:
            raise ValueError("context_limit_retry_strategy is not supported for trtllm")
    return loaded_class(tokenizer=tokenizer, **kwargs)


def get_code_execution_model(server_type, tokenizer=None, code_execution=None, sandbox=None, **kwargs):
Expand Down
63 changes: 63 additions & 0 deletions nemo_skills/inference/model/sglang.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .vllm import VLLMModel


class SGLangModel(VLLMModel):
    """SGLang model that extends VLLMModel with proper tool_choice handling.

    SGLang requires "tool_choice": "auto" in the request body when tools are provided,
    unlike VLLM which uses a server argument (--enable-auto-tool-choice).
    """

    def _build_chat_request_params(
        self,
        messages: list[dict],
        stream: bool,
        tokens_to_generate: int = 512,
        temperature: float = 0.0,
        top_p: float = 0.95,
        top_k: int = -1,
        min_p: float = 0.0,
        repetition_penalty: float = 1.0,
        random_seed: int = 0,
        stop_phrases: list[str] | None = None,
        timeout: int | None = None,
        top_logprobs: int | None = None,
        reasoning_effort: str | None = None,
        tools: list[dict] | None = None,
        extra_body: dict | None = None,  # was `dict = None`: annotate the implicit Optional explicitly
    ) -> dict:
        """Build the chat request dict, adding tool_choice='auto' when tools are given.

        All arguments are forwarded verbatim to VLLMModel._build_chat_request_params;
        the only behavioral difference is the tool_choice key added afterwards.
        """
        request = super()._build_chat_request_params(
            messages=messages,
            stream=stream,
            tokens_to_generate=tokens_to_generate,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            repetition_penalty=repetition_penalty,
            random_seed=random_seed,
            stop_phrases=stop_phrases,
            timeout=timeout,
            top_logprobs=top_logprobs,
            reasoning_effort=reasoning_effort,
            tools=tools,
            extra_body=extra_body,
        )
        # SGLang requires tool_choice in the request body when tools are provided
        if tools is not None:
            request["tool_choice"] = "auto"
        return request
3 changes: 3 additions & 0 deletions tests/gpu-tests/run_qwen.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ pytest tests/gpu-tests/test_nemo_evaluator.py -s -x
export NEMO_SKILLS_TEST_HF_MODEL=Qwen/Qwen3-4B-Instruct-2507
pytest tests/gpu-tests/test_contamination.py -s -x

# Tool calling tests (uses same Qwen3-4B-Instruct model)
pytest tests/gpu-tests/test_tool_calling.py -s -x

# TODO: Add fast context retry tests
# pytest tests/gpu-tests/test_context_retry.py -s -x

Expand Down
133 changes: 133 additions & 0 deletions tests/gpu-tests/test_tool_calling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import subprocess
import tempfile
from pathlib import Path

import pytest
from utils import require_env_var

from tests.conftest import docker_rm

# NOTE: Tool calling behavior is model-specific. Some models (e.g., Qwen) work with standard
# tool call parsers without requiring `tool_choice: "auto"` in the request body, while others
# (e.g., Kimi-K2) have non-standard tool_call_id formats that may require custom handling.
# See: https://huggingface.co/moonshotai/Kimi-K2-Instruct/discussions/48
# For models that require `tool_choice: "auto"`, use a custom model class via the `model_class`
# parameter (e.g., ++server.model_class=nemo_skills.inference.model.sglang::SGLangModel).

# Test prompts designed to strongly encourage tool use.
# NOTE(review): each record carries only a "problem" field — presumably the field
# consumed by the generic/math prompt config used in the test command; verify.
TEST_PROMPTS = [
    {"problem": "Use the python tool to calculate 7 * 8 + 15 and verify your result is correct."},
    {"problem": "Use python to compute the factorial of 10 and verify your answer."},
    {"problem": "Write python code to check if 17 is a prime number and confirm the result."},
]


def _create_test_input_file():
"""Create a temporary input file with test prompts."""
fd, path = tempfile.mkstemp(suffix=".jsonl")
with os.fdopen(fd, "w") as f:
for prompt in TEST_PROMPTS:
f.write(json.dumps(prompt) + "\n")
return path


def _run_tool_calling_test(server_type: str, server_args: str, output_dir: str):
    """Shared driver: launch `ns generate` with a tool module and verify tool usage.

    Runs the generation pipeline against the given server type, then checks that
    every prompt produced a generation and that at least one sample invoked a tool.
    """
    hf_model = require_env_var("NEMO_SKILLS_TEST_HF_MODEL")

    docker_rm([output_dir])

    # Temporary .jsonl file holding the test prompts
    prompts_file = _create_test_input_file()

    try:
        cmd = "".join(
            [
                "ns generate ",
                f" --cluster test-local --config_dir {Path(__file__).absolute().parent} ",
                f" --model {hf_model} ",
                f" --output_dir {output_dir} ",
                f" --server_type {server_type} ",
                " --server_gpus 1 ",
                " --server_nodes 1 ",
                f" --server_args '{server_args}' ",
                " --with_sandbox ",
                f" --input_file {prompts_file} ",
                " ++tool_modules=[nemo_skills.mcp.servers.python_tool::PythonTool] ",
                " ++prompt_config=generic/math ",
                " ++inference.tokens_to_generate=4096 ",
                " ++inference.temperature=0.6 ",
                " ++skip_filled=False ",
            ]
        )
        subprocess.run(cmd, shell=True, check=True)

        # The pipeline must have produced an output file plus its completion marker
        result_file = f"{output_dir}/output.jsonl"
        print(f"\n=== Output file location: {result_file} ===")
        assert os.path.exists(result_file), f"Output file not found: {result_file}"
        assert os.path.exists(f"{result_file}.done"), "Done marker not found"

        with open(result_file) as fin:
            lines = fin.readlines()

        assert len(lines) == len(TEST_PROMPTS), f"Expected {len(TEST_PROMPTS)} lines, got {len(lines)}"

        # Parse each record, checking the mandatory field as we go
        records = []
        for line in lines:
            record = json.loads(line)
            assert "generation" in record, "Missing 'generation' field in output"
            records.append(record)

        # Count samples that actually invoked a tool at least once
        samples_with_tool_calls = sum(1 for record in records if record.get("num_tool_calls", 0) > 0)

        # At least some samples should have made tool calls
        assert samples_with_tool_calls > 0, (
            "No samples made tool calls. Expected tool usage for prompts that explicitly request it."
        )

    finally:
        # Clean up temp file
        if os.path.exists(prompts_file):
            os.remove(prompts_file)


@pytest.mark.gpu
def test_vllm_tool_calling():
    """Test that VLLM properly makes tool calls with --enable-auto-tool-choice."""
    model_type = require_env_var("NEMO_SKILLS_TEST_MODEL_TYPE")

    # VLLM enables tool calling via server-side flags rather than the request body
    vllm_args = "--enforce-eager --max-model-len 8192 --enable-auto-tool-choice --tool-call-parser hermes"
    _run_tool_calling_test(
        server_type="vllm",
        server_args=vllm_args,
        output_dir=f"/tmp/nemo-skills-tests/{model_type}/vllm-tool-calling/generation",
    )


@pytest.mark.gpu
def test_sglang_tool_calling():
    """Test that SGLang properly makes tool calls with tool_choice='auto' in request body."""
    model_type = require_env_var("NEMO_SKILLS_TEST_MODEL_TYPE")

    # SGLang only needs the parser flag; tool_choice is injected per-request by SGLangModel
    sglang_args = "--context-length 8192 --tool-call-parser qwen25"
    _run_tool_calling_test(
        server_type="sglang",
        server_args=sglang_args,
        output_dir=f"/tmp/nemo-skills-tests/{model_type}/sglang-tool-calling/generation",
    )