Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
1a5ec3f
Remove sync apis
smahdavi4 Oct 2, 2025
b83f930
Remove sync apis
smahdavi4 Oct 2, 2025
c5bb015
Remove sync apis
smahdavi4 Oct 2, 2025
98fb075
add responses api type to generate
smahdavi4 Oct 2, 2025
836f580
responses fix
smahdavi4 Oct 2, 2025
aa22312
responses fix
smahdavi4 Oct 2, 2025
2946fac
fix
smahdavi4 Oct 2, 2025
50044a5
complete tool calls
smahdavi4 Oct 2, 2025
dbb1817
debug
smahdavi4 Oct 3, 2025
c7974cf
debug
smahdavi4 Oct 3, 2025
eb6f8a3
debug
smahdavi4 Oct 3, 2025
3b31189
debug
smahdavi4 Oct 3, 2025
b486f23
debug
smahdavi4 Oct 3, 2025
988bc0b
debug
smahdavi4 Oct 3, 2025
3b46f64
debug
smahdavi4 Oct 3, 2025
42bb1f3
debug
smahdavi4 Oct 3, 2025
7c12f93
debug
smahdavi4 Oct 3, 2025
0a715ef
debug
smahdavi4 Oct 3, 2025
7a2e5d7
debug
smahdavi4 Oct 3, 2025
10ac8f9
fix
smahdavi4 Oct 3, 2025
d0d324f
bump up litellm version
smahdavi4 Oct 3, 2025
4886d73
debug
smahdavi4 Oct 3, 2025
f4263fa
cancel async tasks
smahdavi4 Oct 3, 2025
f30b8cb
cancel async tasks
smahdavi4 Oct 3, 2025
6909820
cancel async tasks
smahdavi4 Oct 3, 2025
57f33ec
debug
smahdavi4 Oct 3, 2025
1d412cc
debug
smahdavi4 Oct 3, 2025
f6b8a5f
debug
smahdavi4 Oct 3, 2025
833d94a
debug
smahdavi4 Oct 3, 2025
7ba7529
debug
smahdavi4 Oct 3, 2025
230fcbc
debug
smahdavi4 Oct 3, 2025
b570f02
debug
smahdavi4 Oct 3, 2025
513fc5e
debug
smahdavi4 Oct 3, 2025
1ee53ee
debug
smahdavi4 Oct 3, 2025
18f94f9
debug
smahdavi4 Oct 3, 2025
84eec57
debug
smahdavi4 Oct 3, 2025
323d05b
debug
smahdavi4 Oct 3, 2025
1aa1557
fix bfcl
smahdavi4 Oct 3, 2025
e807a26
fix
smahdavi4 Oct 3, 2025
a729523
cleanup
smahdavi4 Oct 3, 2025
1d505f3
print
smahdavi4 Oct 3, 2025
e4144db
fix
smahdavi4 Oct 4, 2025
baf4610
fix
smahdavi4 Oct 4, 2025
4122c41
debug bfcl
smahdavi4 Oct 4, 2025
86a703b
disable litellm log
smahdavi4 Oct 4, 2025
ec935b4
disable litellm log
smahdavi4 Oct 4, 2025
ea84af3
better fix for litellm
smahdavi4 Oct 4, 2025
b75ed7f
debug bfcl
smahdavi4 Oct 4, 2025
aac23ea
debug bfcl
smahdavi4 Oct 4, 2025
03fd088
suppress log warning
smahdavi4 Oct 4, 2025
3c5ebcb
litellm comment
smahdavi4 Oct 5, 2025
39877e3
litellm comment
smahdavi4 Oct 5, 2025
225acee
litellm comment
smahdavi4 Oct 5, 2025
31d6e90
litellm comment
smahdavi4 Oct 5, 2025
daa3742
litellm comment
smahdavi4 Oct 5, 2025
51f7d80
extract function from tool call
smahdavi4 Oct 6, 2025
77fc5c9
logging instead of print
smahdavi4 Oct 6, 2025
3c15220
hard deprecate use_completions_api
smahdavi4 Oct 6, 2025
0f3d7f2
log available tools
smahdavi4 Oct 6, 2025
9f4b79b
make tool strict
smahdavi4 Oct 6, 2025
5046adb
more comments + strict function calling
smahdavi4 Oct 7, 2025
c225fb6
rename completion type to endpoint type
smahdavi4 Oct 7, 2025
4b6a9d3
implement responses for openai
smahdavi4 Oct 7, 2025
3a5b96b
merge with main
smahdavi4 Oct 7, 2025
3383408
more comments on litellm failure
smahdavi4 Oct 7, 2025
fc8aca5
Merge branch 'main' of github.com:NVIDIA/NeMo-Skills into smahdavi/re…
smahdavi4 Oct 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/basics/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ Click on :material-plus-circle: symbols in the snippet below to learn more detai
```python
from nemo_skills.inference.model import get_model
from nemo_skills.prompt.utils import get_prompt
import asyncio

llm = get_model(model="meta-llama/Llama-3.1-8B-Instruct", server_type="vllm") # localhost by default
prompt_obj = get_prompt('generic/default') # (1)!
prompt = prompt_obj.fill({'question': "What's 2 + 2?"})
print(prompt) # (2)!
output = llm.generate_sync(prompt=prompt)
output = asyncio.run(llm.generate_async(prompt=prompt))
print(output["generation"]) # (3)!
```

Expand Down Expand Up @@ -69,6 +70,7 @@ Click on :material-plus-circle: symbols in the snippet below to learn more detai
```python
from nemo_skills.inference.model import get_model
from nemo_skills.prompt.utils import get_prompt
import asyncio

llm = get_model( # (1)!
server_type="openai", # NIM models are using OpenAI API
Expand All @@ -80,7 +82,7 @@ Click on :material-plus-circle: symbols in the snippet below to learn more detai
prompt = prompt_obj.fill({'question': "What's 2 + 2?"})

print(prompt) # (3)!
output = llm.generate_sync(prompt=prompt)
output = asyncio.run(llm.generate_async(prompt=prompt))
print(output["generation"]) # (4)!
```

Expand Down
2 changes: 1 addition & 1 deletion docs/basics/prompt-format.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ which outputs

#### Example 2 - Prompt formatted as a string

If you want to use completions API, you can set `++use_completions_api=True`. This will use model's tokenizer to format
If you want to use completions API, you can set `++inference.endpoint_type=text`. This will use model's tokenizer to format
messages as a string (you can specify a custom tokenizer with `++tokenizer=...` argument).

Here is an example of the input to completions api
Expand Down
11 changes: 7 additions & 4 deletions docs/pipelines/generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ ns generate \
--input_file=/nemo_run/code/nemo_skills/dataset/math/train.jsonl \
++prompt_config=generic/math-base \
++examples_type=math_text_detailed \
++use_completions_api=True \
++inference.endpoint_type=text \
++tokenizer=meta-llama/Llama-3.1-405B \
++stop_phrase='\\n\\n\\n\\n\\n\\n'
```
Expand Down Expand Up @@ -366,6 +366,7 @@ We also support automatic trimming of generation budget or context when using vl

from nemo_skills.prompt.utils import get_prompt
from nemo_skills.inference.model import get_model
import asyncio

prompt = get_prompt(
"generic/math",
Expand All @@ -382,7 +383,7 @@ We also support automatic trimming of generation budget or context when using vl

# The 1M generation budget is well beyond the 40960 context window size of Qwen/Qwen3-0.6B
# We will automatically reduce the generation budget to fit in the context window
output_dict = llm.generate_sync(input_prompt, tokens_to_generate=1_000_000)
output_dict = asyncio.run(llm.generate_async(input_prompt, tokens_to_generate=1_000_000))
```
To specify this setting for the generation or eval pipeline use
```bash
Expand All @@ -395,6 +396,7 @@ We also support automatic trimming of generation budget or context when using vl
```python hl_lines="15-16"
from nemo_skills.prompt.utils import get_prompt
from nemo_skills.inference.model import get_model
import asyncio

prompt = get_prompt(
"generic/math",
Expand All @@ -413,7 +415,7 @@ We also support automatic trimming of generation budget or context when using vl

# We will automatically reduce the prompt from the start to fit in the context window
# Note that this requires the `tokens_to_generate` budget to be specified
output_dict = llm.generate_sync(prompt=input_prompt, tokens_to_generate=1024)
output_dict = asyncio.run(llm.generate_async(prompt=input_prompt, tokens_to_generate=1024))
```
To specify this setting for the generation or eval pipeline use
```bash
Expand All @@ -427,6 +429,7 @@ We also support automatic trimming of generation budget or context when using vl

from nemo_skills.prompt.utils import get_prompt
from nemo_skills.inference.model import get_model
import asyncio

prompt = get_prompt(
"generic/math",
Expand All @@ -445,7 +448,7 @@ We also support automatic trimming of generation budget or context when using vl

# We will automatically reduce the prompt from the end to fit in the context window
# Note that this requires the `tokens_to_generate` budget to be specified
output_dict = llm.generate_sync(prompt=input_prompt, tokens_to_generate=1024)
output_dict = asyncio.run(llm.generate_async(prompt=input_prompt, tokens_to_generate=1024))
```
To specify this setting for the generation or eval pipeline use
```bash
Expand Down
12 changes: 6 additions & 6 deletions docs/releases/openmathinstruct2/dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ ns generate \
--input_file=/nemo_run/code/nemo_skills/dataset/math/train.jsonl \
++prompt_config=generic/math-base \
++examples_type=math_text_detailed \
++use_completions_api=True \
++inference.endpoint_type=text \
++tokenizer=meta-llama/Llama-3.1-405B \
++stop_phrase='\\n\\n\\n\\n\\n\\n'
```
Expand All @@ -53,7 +53,7 @@ ns generate \
--input_file=/nemo_run/code/nemo_skills/dataset/gsm8k/train.jsonl \
++prompt_config=generic/math-base \
++examples_type=gsm8k_text_detailed \
++use_completions_api=True \
++inference.endpoint_type=text \
++tokenizer=meta-llama/Llama-3.1-405B \
++stop_phrase='\\n\\n\\n\\n\\n\\n'
```
Expand All @@ -76,7 +76,7 @@ ns generate \
++prompt_config=generic/problem-augmentation \
++examples_type=math_problem_augmentation \
++generation_key=problem \
++use_completions_api=True \
++inference.endpoint_type=text \
++tokenizer=meta-llama/Llama-3.1-405B \
++stop_phrase='\\n\\n\\n\\n\\n\\n'
```
Expand All @@ -96,7 +96,7 @@ ns generate \
++prompt_config=generic/problem-augmentation-similar \
++examples_type=gsm8k_problem_augmentation \
++generation_key=problem \
++use_completions_api=True \
++inference.endpoint_type=text \
++tokenizer=meta-llama/Llama-3.1-405B \
++stop_phrase='\\n\\n\\n\\n\\n\\n'
```
Expand Down Expand Up @@ -128,7 +128,7 @@ for i in range(80):
ctx=wrap_arguments(
f"++prompt_config=generic/math-base "
f"++examples_type=math_text_detailed "
f"++use_completions_api=True "
f"++inference.endpoint_type=text "
f"++tokenizer=meta-llama/Llama-3.1-405B "
f"++stop_phrase='\n\n\n\n\n\n' "
),
Expand All @@ -155,7 +155,7 @@ for i in range(10):
ctx=wrap_arguments(
f"++prompt_config=generic/math-base "
f"++examples_type=gsm8k_text_detailed "
f"++use_completions_api=True "
f"++inference.endpoint_type=text "
f"++tokenizer=meta-llama/Llama-3.1-405B "
f"++stop_phrase='\n\n\n\n\n\n' "
),
Expand Down
4 changes: 2 additions & 2 deletions docs/releases/openmathreasoning/evaluation.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ ns eval \
--with_sandbox \
++code_tags=openmath \
++prompt_config=openmath/tir \
++use_completions_api=True \
++inference.endpoint_type=text \
++inference.tokens_to_generate=32768 \
++inference.temperature=0.6 \
++code_execution=true \
Expand All @@ -127,7 +127,7 @@ ns eval \
--with_sandbox \
++code_tags=openmath \
++prompt_config=generic/math \
++use_completions_api=True \
++inference.endpoint_type=text \
++inference.tokens_to_generate=32768 \
++inference.temperature=0.6 \
++code_execution=true
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/posts/gpt-oss-python.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ eval(
# we currently implement native Python code tool through text completions API
# as we found alternative implementations to have issues.
# We will switch to the official responses API when the support is added
"++use_completions_api=true "
"++inference.endpoint_type=text "
"++code_tags=gpt-oss "
# gpt-oss generates a lot of code, so need to set max_code_executions high!
# you can also add ++server.code_execution.code_execution_timeout=120 to match
Expand Down Expand Up @@ -219,7 +219,7 @@ generate(
# we currently implement native Python code tool through text completions API
# as we found alternative implementations to have issues.
# We will switch to the official responses API when the support is added
"++use_completions_api=true "
"++inference.endpoint_type=text "
"++code_tags=gpt-oss "
# gpt-oss generates a lot of code, so need to set max_code_executions high!
# you can also add ++server.code_execution.code_execution_timeout=120 to match
Expand Down
2 changes: 1 addition & 1 deletion nemo_skills/dataset/ruler/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"++inference.tokens_to_generate={tokens_to_generate} "
# ruler is adding prefix for assistant response, so it has to go through completions api
"++start_assistant_response_key=generation "
"++use_completions_api=True "
"++inference.endpoint_type=text "
)
"""
TOKENS_TO_GENERATE = {"niah": 128, "vt": 30, "cwe": 120, "fwe": 50, "qa": 32}
Expand Down
17 changes: 10 additions & 7 deletions nemo_skills/inference/chat_interface/chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import asyncio
import logging
from typing import Iterator

Expand Down Expand Up @@ -56,13 +57,15 @@ def stream_chat(
raise RuntimeError(f"Error preparing prompt: {e}") from e

extra_params = prompt_obj.get_code_execution_args() if use_code else {}
stream_iter_list = llm.generate_sync(
prompt=prompt_filled,
tokens_to_generate=int(tokens_to_generate),
temperature=float(temperature),
stream=True,
stop_phrases=prompt_obj.stop_phrases or [],
**extra_params,
stream_iter_list = asyncio.run(
llm.generate_async(
prompt=prompt_filled,
tokens_to_generate=int(tokens_to_generate),
temperature=float(temperature),
stream=True,
stop_phrases=prompt_obj.stop_phrases or [],
**extra_params,
)
)
if not stream_iter_list:
raise RuntimeError("LLM did not return a stream iterator.")
Expand Down
15 changes: 13 additions & 2 deletions nemo_skills/inference/eval/bfcl.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
InferenceConfig,
)
from nemo_skills.inference.model import server_params
from nemo_skills.inference.model.base import EndpointType
from nemo_skills.inference.model.utils import is_context_window_exceeded_error
from nemo_skills.prompt.utils import get_token_count
from nemo_skills.utils import (
Expand Down Expand Up @@ -129,11 +130,21 @@ def _validate_and_setup_client_parsing(self):
self.message_formatter = partial(tokenizer.apply_chat_template, tokenize=False, add_generation_prompt=True)

def construct_input_dict(self, messages: list[dict], tools: list[dict]):
fmted_prompt = self.message_formatter(messages, tools=tools)
try:
fmted_prompt = self.message_formatter(messages, tools=tools)
except Exception as e:
# Sometimes the parsed tool-call is a string, which is not JSON serializable
# Putting a debugging here in case it happens in the future and we need to address it.
LOG.info(f"Messages: {messages}, Tools: {tools}")
LOG.error(f"Error formatting prompt: {e}")
raise e
kwargs = asdict(self.cfg.inference)
# Replace the completion type with text
kwargs["endpoint_type"] = EndpointType.text
return {
"prompt": fmted_prompt,
"include_response": True,
**asdict(self.cfg.inference),
**kwargs,
}

def parse_output_dict(self, output_dict: dict):
Expand Down
29 changes: 21 additions & 8 deletions nemo_skills/inference/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
get_tool_calling_model,
server_params,
)
from nemo_skills.inference.model.base import EndpointType
from nemo_skills.prompt.utils import get_prompt, get_token_count
from nemo_skills.utils import (
chunk_data,
Expand All @@ -56,6 +57,13 @@

@nested_dataclass(kw_only=True)
class InferenceConfig:
# Type of completion to generate when using OpenAI
# "chat": used by default
# "text": for text completions, in this case we will
# take the tokenizer from the model and apply it to the prompt before sending it.
# You can override tokenizer with tokenizer parameter.
# "responses": for responses api format.
endpoint_type: EndpointType = EndpointType.chat
temperature: float = 0.0 # Temperature of 0 means greedy decoding
top_k: int = -1
top_p: float = 0.95
Expand All @@ -76,10 +84,10 @@ class GenerateSolutionsConfig:
input_file: str # Path to the input file with data
output_file: str # Where to save the generations
prompt_config: str | None = None # How to format the data into prompts
# by default we use chat completions, set this to True to use completions API. In that case we will take the
# tokenizer from the model and apply it to the prompt before sending it. You can override tokenizer with
# tokenizer parameter

# Deprecated, please use endpoint_type in the InferenceConfig instead
use_completions_api: bool = False

# path or name of the tokenizer to use for completions API. By default uses server.model
tokenizer: str | None = None
# extra parameters to pass to the tokenizer's apply_chat_template method
Expand Down Expand Up @@ -179,6 +187,7 @@ def __post_init__(self):
self._post_init_validate_data()
self._post_init_validate_server()
self._post_init_validate_params()
self._post_init_deprecated_params()

def _post_init_validate_data(self):
if isinstance(self.total_code_executions_in_prompt, ListConfig):
Expand All @@ -199,7 +208,7 @@ def _post_init_validate_server(self):
"Megatron server doesn't support chat completions and we can't infer tokenizer from model name. "
"Please provide it with an explicit `tokenizer` parameter."
)
self.use_completions_api = True
self.inference.endpoint_type = EndpointType.text
LOG.warning("Megatron inference is extremely slow. It's highly recommended to use other server types!")

def _post_init_validate_params(self):
Expand All @@ -215,6 +224,10 @@ def _post_init_validate_params(self):
if getattr(self, param) != default_value:
raise ValueError(f"{param} must be {default_value}")

def _post_init_deprecated_params(self):
if self.use_completions_api:
raise ValueError("use_completions_api is deprecated, please use ++inference.endpoint_type=text instead.")

def _get_disallowed_params(self):
"""Returns a list of parameters with their default values to check that they are not changed from the defaults"""
return []
Expand Down Expand Up @@ -261,7 +274,7 @@ def __init__(self, cfg: GenerateSolutionsConfig):

# chat template kwargs goes either into extra body of inference or as a prompt parameter
if self.cfg.chat_template_kwargs:
if not self.cfg.use_completions_api:
if self.cfg.inference.endpoint_type != EndpointType.text:
if "chat_template_kwargs" in self.cfg.inference.extra_body:
raise ValueError(
"chat_template_kwargs is provided in both inference.extra_body and as a separate argument. "
Expand All @@ -273,7 +286,7 @@ def __init__(self, cfg: GenerateSolutionsConfig):

# Setup tokenizer
if (
self.cfg.use_completions_api
self.cfg.inference.endpoint_type == EndpointType.text
or self.cfg.server.get("enable_soft_fail", False)
or self.cfg.count_prompt_tokens
):
Expand All @@ -285,7 +298,7 @@ def __init__(self, cfg: GenerateSolutionsConfig):
# Setup litellm cache
self.setup_litellm_cache()

if self.cfg.use_completions_api and self.cfg.inference.tokens_to_generate is None:
if self.cfg.inference.endpoint_type == EndpointType.text and self.cfg.inference.tokens_to_generate is None:
raise ValueError("When using completions API, tokens_to_generate must be specified!")

# Setup prompt formatter and LLM
Expand Down Expand Up @@ -345,7 +358,7 @@ def setup_prompt(self):

prompt = get_prompt(
prompt_config=self.cfg.prompt_config,
tokenizer=self.tokenizer if self.cfg.use_completions_api else None,
tokenizer=self.tokenizer if self.cfg.inference.endpoint_type == EndpointType.text else None,
code_tags=self.cfg.code_tags,
examples_type=self.cfg.examples_type,
system_message=self.cfg.system_message,
Expand Down
Loading