diff --git a/docs/basics/inference.md b/docs/basics/inference.md index 1acfebfb4b..08a8f51697 100644 --- a/docs/basics/inference.md +++ b/docs/basics/inference.md @@ -34,12 +34,13 @@ Click on :material-plus-circle: symbols in the snippet below to learn more detai ```python from nemo_skills.inference.model import get_model from nemo_skills.prompt.utils import get_prompt + import asyncio llm = get_model(model="meta-llama/Llama-3.1-8B-Instruct", server_type="vllm") # localhost by default prompt_obj = get_prompt('generic/default') # (1)! prompt = prompt_obj.fill({'question': "What's 2 + 2?"}) print(prompt) # (2)! - output = llm.generate_sync(prompt=prompt) + output = asyncio.run(llm.generate_async(prompt=prompt)) print(output["generation"]) # (3)! ``` @@ -69,6 +70,7 @@ Click on :material-plus-circle: symbols in the snippet below to learn more detai ```python from nemo_skills.inference.model import get_model from nemo_skills.prompt.utils import get_prompt + import asyncio llm = get_model( # (1)! server_type="openai", # NIM models are using OpenAI API @@ -80,7 +82,7 @@ Click on :material-plus-circle: symbols in the snippet below to learn more detai prompt = prompt_obj.fill({'question': "What's 2 + 2?"}) print(prompt) # (3)! - output = llm.generate_sync(prompt=prompt) + output = asyncio.run(llm.generate_async(prompt=prompt)) print(output["generation"]) # (4)! ``` diff --git a/docs/basics/prompt-format.md b/docs/basics/prompt-format.md index 95b74aed62..37e3eb5f93 100644 --- a/docs/basics/prompt-format.md +++ b/docs/basics/prompt-format.md @@ -108,7 +108,7 @@ which outputs #### Example 2 - Prompt formatted as a string -If you want to use completions API, you can set `++use_completions_api=True`. This will use model's tokenizer to format +If you want to use completions API, you can set `++inference.endpoint_type=text`. This will use model's tokenizer to format messages as a string (you can specify a custom tokenizer with `++tokenizer=...` argument). 
Here is an example of the input to completions api diff --git a/docs/pipelines/generation.md b/docs/pipelines/generation.md index 8a062acc21..81a37a0dd8 100644 --- a/docs/pipelines/generation.md +++ b/docs/pipelines/generation.md @@ -126,7 +126,7 @@ ns generate \ --input_file=/nemo_run/code/nemo_skills/dataset/math/train.jsonl \ ++prompt_config=generic/math-base \ ++examples_type=math_text_detailed \ - ++use_completions_api=True \ + ++inference.endpoint_type=text \ ++tokenizer=meta-llama/Llama-3.1-405B \ ++stop_phrase='\\n\\n\\n\\n\\n\\n' ``` @@ -366,6 +366,7 @@ We also support automatic trimming of generation budget or context when using vl from nemo_skills.prompt.utils import get_prompt from nemo_skills.inference.model import get_model + import asyncio prompt = get_prompt( "generic/math", @@ -382,7 +383,7 @@ We also support automatic trimming of generation budget or context when using vl # The 1M generation budget is well beyond the 40960 context window size of Qwen/Qwen3-0.6B # We will automatically reduce the generation budget to fit in the context window - output_dict = llm.generate_sync(input_prompt, tokens_to_generate=1_000_000) + output_dict = asyncio.run(llm.generate_async(input_prompt, tokens_to_generate=1_000_000)) ``` To specify this setting for the generation or eval pipeline use ```bash @@ -395,6 +396,7 @@ We also support automatic trimming of generation budget or context when using vl ```python hl_lines="15-16" from nemo_skills.prompt.utils import get_prompt from nemo_skills.inference.model import get_model + import asyncio prompt = get_prompt( "generic/math", @@ -413,7 +415,7 @@ We also support automatic trimming of generation budget or context when using vl # We will automatically reduce the prompt from the start to fit in the context window # Note that this requires the `tokens_to_generate` budget to be specified - output_dict = llm.generate_sync(prompt=input_prompt, tokens_to_generate=1024) + output_dict = 
asyncio.run(llm.generate_async(prompt=input_prompt, tokens_to_generate=1024)) ``` To specify this setting for the generation or eval pipeline use ```bash @@ -427,6 +429,7 @@ We also support automatic trimming of generation budget or context when using vl from nemo_skills.prompt.utils import get_prompt from nemo_skills.inference.model import get_model + import asyncio prompt = get_prompt( "generic/math", @@ -445,7 +448,7 @@ We also support automatic trimming of generation budget or context when using vl # We will automatically reduce the prompt from the end to fit in the context window # Note that this requires the `tokens_to_generate` budget to be specified - output_dict = llm.generate_sync(prompt=input_prompt, tokens_to_generate=1024) + output_dict = asyncio.run(llm.generate_async(prompt=input_prompt, tokens_to_generate=1024)) ``` To specify this setting for the generation or eval pipeline use ```bash diff --git a/docs/releases/openmathinstruct2/dataset.md b/docs/releases/openmathinstruct2/dataset.md index ec0ea26034..e3ee0b74ee 100644 --- a/docs/releases/openmathinstruct2/dataset.md +++ b/docs/releases/openmathinstruct2/dataset.md @@ -33,7 +33,7 @@ ns generate \ --input_file=/nemo_run/code/nemo_skills/dataset/math/train.jsonl \ ++prompt_config=generic/math-base \ ++examples_type=math_text_detailed \ - ++use_completions_api=True \ + ++inference.endpoint_type=text \ ++tokenizer=meta-llama/Llama-3.1-405B \ ++stop_phrase='\\n\\n\\n\\n\\n\\n' ``` @@ -53,7 +53,7 @@ ns generate \ --input_file=/nemo_run/code/nemo_skills/dataset/gsm8k/train.jsonl \ ++prompt_config=generic/math-base \ ++examples_type=gsm8k_text_detailed \ - ++use_completions_api=True \ + ++inference.endpoint_type=text \ ++tokenizer=meta-llama/Llama-3.1-405B \ ++stop_phrase='\\n\\n\\n\\n\\n\\n' ``` @@ -76,7 +76,7 @@ ns generate \ ++prompt_config=generic/problem-augmentation \ ++examples_type=math_problem_augmentation \ ++generation_key=problem \ - ++use_completions_api=True \ + 
++inference.endpoint_type=text \ ++tokenizer=meta-llama/Llama-3.1-405B \ ++stop_phrase='\\n\\n\\n\\n\\n\\n' ``` @@ -96,7 +96,7 @@ ns generate \ ++prompt_config=generic/problem-augmentation-similar \ ++examples_type=gsm8k_problem_augmentation \ ++generation_key=problem \ - ++use_completions_api=True \ + ++inference.endpoint_type=text \ ++tokenizer=meta-llama/Llama-3.1-405B \ ++stop_phrase='\\n\\n\\n\\n\\n\\n' ``` @@ -128,7 +128,7 @@ for i in range(80): ctx=wrap_arguments( f"++prompt_config=generic/math-base " f"++examples_type=math_text_detailed " - f"++use_completions_api=True " + f"++inference.endpoint_type=text " f"++tokenizer=meta-llama/Llama-3.1-405B " f"++stop_phrase='\n\n\n\n\n\n' " ), @@ -155,7 +155,7 @@ for i in range(10): ctx=wrap_arguments( f"++prompt_config=generic/math-base " f"++examples_type=gsm8k_text_detailed " - f"++use_completions_api=True " + f"++inference.endpoint_type=text " f"++tokenizer=meta-llama/Llama-3.1-405B " f"++stop_phrase='\n\n\n\n\n\n' " ), diff --git a/docs/releases/openmathreasoning/evaluation.md b/docs/releases/openmathreasoning/evaluation.md index 3e534a3512..9e6ffcbe25 100644 --- a/docs/releases/openmathreasoning/evaluation.md +++ b/docs/releases/openmathreasoning/evaluation.md @@ -104,7 +104,7 @@ ns eval \ --with_sandbox \ ++code_tags=openmath \ ++prompt_config=openmath/tir \ - ++use_completions_api=True \ + ++inference.endpoint_type=text \ ++inference.tokens_to_generate=32768 \ ++inference.temperature=0.6 \ ++code_execution=true \ @@ -127,7 +127,7 @@ ns eval \ --with_sandbox \ ++code_tags=openmath \ ++prompt_config=generic/math \ - ++use_completions_api=True \ + ++inference.endpoint_type=text \ ++inference.tokens_to_generate=32768 \ ++inference.temperature=0.6 \ ++code_execution=true diff --git a/docs/tutorials/posts/gpt-oss-python.md b/docs/tutorials/posts/gpt-oss-python.md index eed6024872..43f1685778 100644 --- a/docs/tutorials/posts/gpt-oss-python.md +++ b/docs/tutorials/posts/gpt-oss-python.md @@ -68,7 +68,7 @@ eval( # we 
currently implement native Python code tool through text completions API # as we found alternative implementations to have issues. # We will switch to the official responses API when the support is added - "++use_completions_api=true " + "++inference.endpoint_type=text " "++code_tags=gpt-oss " # gpt-oss generates a lot of code, so need to set max_code_executions high! # you can also add ++server.code_execution.code_execution_timeout=120 to match @@ -219,7 +219,7 @@ generate( # we currently implement native Python code tool through text completions API # as we found alternative implementations to have issues. # We will switch to the official responses API when the support is added - "++use_completions_api=true " + "++inference.endpoint_type=text " "++code_tags=gpt-oss " # gpt-oss generates a lot of code, so need to set max_code_executions high! # you can also add ++server.code_execution.code_execution_timeout=120 to match diff --git a/nemo_skills/dataset/ruler/prepare.py b/nemo_skills/dataset/ruler/prepare.py index 0bc70c991a..9f15cd7f5f 100644 --- a/nemo_skills/dataset/ruler/prepare.py +++ b/nemo_skills/dataset/ruler/prepare.py @@ -31,7 +31,7 @@ "++inference.tokens_to_generate={tokens_to_generate} " # ruler is adding prefix for assistant response, so it has to go through completions api "++start_assistant_response_key=generation " - "++use_completions_api=True " + "++inference.endpoint_type=text " ) """ TOKENS_TO_GENERATE = {"niah": 128, "vt": 30, "cwe": 120, "fwe": 50, "qa": 32} diff --git a/nemo_skills/inference/chat_interface/chat_service.py b/nemo_skills/inference/chat_interface/chat_service.py index 72b3452ed2..b17213a486 100644 --- a/nemo_skills/inference/chat_interface/chat_service.py +++ b/nemo_skills/inference/chat_interface/chat_service.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import logging from typing import Iterator @@ -56,13 +57,15 @@ def stream_chat( raise RuntimeError(f"Error preparing prompt: {e}") from e extra_params 
= prompt_obj.get_code_execution_args() if use_code else {} - stream_iter_list = llm.generate_sync( - prompt=prompt_filled, - tokens_to_generate=int(tokens_to_generate), - temperature=float(temperature), - stream=True, - stop_phrases=prompt_obj.stop_phrases or [], - **extra_params, + stream_iter_list = asyncio.run( + llm.generate_async( + prompt=prompt_filled, + tokens_to_generate=int(tokens_to_generate), + temperature=float(temperature), + stream=True, + stop_phrases=prompt_obj.stop_phrases or [], + **extra_params, + ) ) if not stream_iter_list: raise RuntimeError("LLM did not return a stream iterator.") diff --git a/nemo_skills/inference/eval/bfcl.py b/nemo_skills/inference/eval/bfcl.py index 37aeb1334b..fc9f397793 100644 --- a/nemo_skills/inference/eval/bfcl.py +++ b/nemo_skills/inference/eval/bfcl.py @@ -38,6 +38,7 @@ InferenceConfig, ) from nemo_skills.inference.model import server_params +from nemo_skills.inference.model.base import EndpointType from nemo_skills.inference.model.utils import is_context_window_exceeded_error from nemo_skills.prompt.utils import get_token_count from nemo_skills.utils import ( @@ -129,11 +130,21 @@ def _validate_and_setup_client_parsing(self): self.message_formatter = partial(tokenizer.apply_chat_template, tokenize=False, add_generation_prompt=True) def construct_input_dict(self, messages: list[dict], tools: list[dict]): - fmted_prompt = self.message_formatter(messages, tools=tools) + try: + fmted_prompt = self.message_formatter(messages, tools=tools) + except Exception as e: + # Sometimes the parsed tool-call is a string, which is not JSON serializable + # Putting a debugging here in case it happens in the future and we need to address it. 
+ LOG.info(f"Messages: {messages}, Tools: {tools}") + LOG.error(f"Error formatting prompt: {e}") + raise e + kwargs = asdict(self.cfg.inference) + # Replace the completion type with text + kwargs["endpoint_type"] = EndpointType.text return { "prompt": fmted_prompt, "include_response": True, - **asdict(self.cfg.inference), + **kwargs, } def parse_output_dict(self, output_dict: dict): diff --git a/nemo_skills/inference/generate.py b/nemo_skills/inference/generate.py index e6061a4060..b6793d8acb 100644 --- a/nemo_skills/inference/generate.py +++ b/nemo_skills/inference/generate.py @@ -40,6 +40,7 @@ get_tool_calling_model, server_params, ) +from nemo_skills.inference.model.base import EndpointType from nemo_skills.prompt.utils import get_prompt, get_token_count from nemo_skills.utils import ( chunk_data, @@ -56,6 +57,13 @@ @nested_dataclass(kw_only=True) class InferenceConfig: + # Type of completion to generate when using OpenAI + # "chat": used by default + # "text": for text completions, in this case we will + # take the tokenizer from the model and apply it to the prompt before sending it. + # You can override tokenizer with tokenizer parameter. + # "responses": for responses api format. + endpoint_type: EndpointType = EndpointType.chat temperature: float = 0.0 # Temperature of 0 means greedy decoding top_k: int = -1 top_p: float = 0.95 @@ -76,10 +84,10 @@ class GenerateSolutionsConfig: input_file: str # Path to the input file with data output_file: str # Where to save the generations prompt_config: str | None = None # How to format the data into prompts - # by default we use chat completions, set this to True to use completions API. In that case we will take the - # tokenizer from the model and apply it to the prompt before sending it. You can override tokenizer with - # tokenizer parameter + + # Deprecated, please use endpoint_type in the InferenceConfig instead use_completions_api: bool = False + # path or name of the tokenizer to use for completions API. 
By default uses server.model tokenizer: str | None = None # extra parameters to pass to the tokenizer's apply_chat_template method @@ -179,6 +187,7 @@ def __post_init__(self): self._post_init_validate_data() self._post_init_validate_server() self._post_init_validate_params() + self._post_init_deprecated_params() def _post_init_validate_data(self): if isinstance(self.total_code_executions_in_prompt, ListConfig): @@ -199,7 +208,7 @@ def _post_init_validate_server(self): "Megatron server doesn't support chat completions and we can't infer tokenizer from model name. " "Please provide it with an explicit `tokenizer` parameter." ) - self.use_completions_api = True + self.inference.endpoint_type = EndpointType.text LOG.warning("Megatron inference is extremely slow. It's highly recommended to use other server types!") def _post_init_validate_params(self): @@ -215,6 +224,10 @@ def _post_init_validate_params(self): if getattr(self, param) != default_value: raise ValueError(f"{param} must be {default_value}") + def _post_init_deprecated_params(self): + if self.use_completions_api: + raise ValueError("use_completions_api is deprecated, please use ++inference.endpoint_type=text instead.") + def _get_disallowed_params(self): """Returns a list of parameters with their default values to check that they are not changed from the defaults""" return [] @@ -261,7 +274,7 @@ def __init__(self, cfg: GenerateSolutionsConfig): # chat template kwargs goes either into extra body of inference or as a prompt parameter if self.cfg.chat_template_kwargs: - if not self.cfg.use_completions_api: + if self.cfg.inference.endpoint_type != EndpointType.text: if "chat_template_kwargs" in self.cfg.inference.extra_body: raise ValueError( "chat_template_kwargs is provided in both inference.extra_body and as a separate argument. 
" @@ -273,7 +286,7 @@ def __init__(self, cfg: GenerateSolutionsConfig): # Setup tokenizer if ( - self.cfg.use_completions_api + self.cfg.inference.endpoint_type == EndpointType.text or self.cfg.server.get("enable_soft_fail", False) or self.cfg.count_prompt_tokens ): @@ -285,7 +298,7 @@ def __init__(self, cfg: GenerateSolutionsConfig): # Setup litellm cache self.setup_litellm_cache() - if self.cfg.use_completions_api and self.cfg.inference.tokens_to_generate is None: + if self.cfg.inference.endpoint_type == EndpointType.text and self.cfg.inference.tokens_to_generate is None: raise ValueError("When using completions API, tokens_to_generate must be specified!") # Setup prompt formatter and LLM @@ -345,7 +358,7 @@ def setup_prompt(self): prompt = get_prompt( prompt_config=self.cfg.prompt_config, - tokenizer=self.tokenizer if self.cfg.use_completions_api else None, + tokenizer=self.tokenizer if self.cfg.inference.endpoint_type == EndpointType.text else None, code_tags=self.cfg.code_tags, examples_type=self.cfg.examples_type, system_message=self.cfg.system_message, diff --git a/nemo_skills/inference/model/base.py b/nemo_skills/inference/model/base.py index 94aa096450..435d96ce6b 100644 --- a/nemo_skills/inference/model/base.py +++ b/nemo_skills/inference/model/base.py @@ -11,16 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import abc import logging import os +from enum import Enum from typing import Union import httpx import litellm import openai +from nemo_skills.inference.patch_litellm_logging import patch_litellm_logging_worker from nemo_skills.utils import get_logger_name from .context_retry import ContextLimitRetryConfig, with_context_retry @@ -28,6 +29,16 @@ LOG = logging.getLogger(get_logger_name(__file__)) +# The logging worker sometimes does not stop. We patch it to disable its functionality. 
+# TODO: Remove this once LiteLLM fixes it. +patch_litellm_logging_worker() + + +class EndpointType(str, Enum): + text = "text" + chat = "chat" + responses = "responses" + class BaseModel: """Base model class for handling requests to the inference server. @@ -124,6 +135,7 @@ def __init__( max_retries=max_retries, api_key=api_key, base_url=self.base_url, + api_base=self.base_url, # Used in later versions with responses API ) httpx_limits = httpx.Limits(max_keepalive_connections=2048, max_connections=2048) litellm.client_session = httpx.Client(limits=httpx_limits) @@ -190,19 +202,14 @@ def _build_chat_request_params(self, **kwargs) -> dict: def _build_completion_request_params(self, **kwargs) -> dict: pass - def _build_request_params(self, prompt: str | list[dict], stream: bool, **kwargs) -> dict: - if isinstance(prompt, str): - return self._build_completion_request_params(prompt=prompt, stream=stream, **kwargs) - elif isinstance(prompt, list): - request_params = self._build_chat_request_params(messages=prompt, stream=stream, **kwargs) - return request_params - else: - raise ValueError("Either prompt or messages must be provided") + def _build_responses_request_params(self, **kwargs) -> dict: + raise NotImplementedError("Responses completion is not supported or implemented for this model.") @with_context_retry async def generate_async( self, prompt: str | list[dict], + endpoint_type: EndpointType = None, tokens_to_generate: int | None = None, temperature: float = 0.0, top_p: float = 0.95, @@ -220,8 +227,9 @@ async def generate_async( include_response: bool = False, extra_body: dict = None, ) -> dict: - """Native async version of generate for single prompt.""" - + if endpoint_type is None: + # Inferring completion type from prompt + endpoint_type = EndpointType.chat if isinstance(prompt, list) else EndpointType.text # Check tool calls are a list of dict if tools is not None: for tool in tools: @@ -251,7 +259,8 @@ while retry_count <= 
max_retries: try: - if isinstance(prompt, list): + if endpoint_type == EndpointType.chat: + assert isinstance(prompt, list), "Chat completion requests must be a list of messages." request_params = self._build_chat_request_params(messages=prompt, stream=stream, **kwargs) response = await litellm.acompletion(**request_params, **self.litellm_kwargs) if stream: @@ -260,16 +269,26 @@ async def generate_async( result = self._parse_chat_completion_response( response, include_response=include_response, **kwargs ) - - elif isinstance(prompt, str): + elif endpoint_type == EndpointType.text: + assert isinstance(prompt, str), "Text completion requests must be a string." request_params = self._build_completion_request_params(prompt=prompt, stream=stream, **kwargs) response = await litellm.atext_completion(**request_params, **self.litellm_kwargs) if stream: result = self._stream_completion_chunks_async(response) else: result = self._parse_completion_response(response, include_response=include_response, **kwargs) + elif endpoint_type == EndpointType.responses: + assert isinstance(prompt, list), "Responses completion requests must be a list." 
+ request_params = self._build_responses_request_params(input=prompt, stream=stream, **kwargs) + response = await litellm.aresponses(**request_params, **self.litellm_kwargs) + if stream: + raise NotImplementedError("Streaming responses is not supported yet.") + else: + result = self._parse_responses_completion_response( + response, include_response=include_response, **kwargs + ) else: - raise TypeError(f"Unsupported prompt type: {type(prompt)}") + raise TypeError(f"Unsupported completion type: {endpoint_type}") if not stream: self._maybe_apply_stop_phrase_removal(result, remove_stop_phrases, stop_phrases) return result @@ -288,73 +307,6 @@ async def generate_async( return result - @with_context_retry - def generate_sync( - self, - prompt: str | list[dict], - tokens_to_generate: int | None = None, - temperature: float = 0.0, - top_p: float = 0.95, - top_k: int = -1, - min_p: float = 0.0, - repetition_penalty: float = 1.0, - random_seed: int = None, - stop_phrases: list[str] | None = None, - top_logprobs: int | None = None, - timeout: float | int | None = 14400, # None is 10min - remove_stop_phrases: bool = True, - stream: bool = False, - reasoning_effort: str | None = None, - tools: list[dict] | None = None, - include_response: bool = False, - extra_body: dict = None, - ) -> dict: - """ - Synchronous version of generate for single prompt. - See generate_async for full list of parameters. 
- """ - # Check tool calls are a list of dict - if tools is not None: - for tool in tools: - # TODO: We may want to add additional checks for tools in the future - if not isinstance(tool, dict): - raise ValueError(f"Tool must be a dictionary, got {type(tool)}") - - kwargs = { - "tokens_to_generate": tokens_to_generate, - "temperature": temperature, - "top_p": top_p, - "top_k": top_k, - "min_p": min_p, - "repetition_penalty": repetition_penalty, - "random_seed": random_seed, - "stop_phrases": stop_phrases, - "top_logprobs": top_logprobs, - "timeout": timeout, - "reasoning_effort": reasoning_effort, - "tools": tools, - "extra_body": extra_body, - } - request_params = self._build_request_params(prompt=prompt, stream=stream, **kwargs) - if isinstance(prompt, list): - response = litellm.completion(**request_params, **self.litellm_kwargs) - if stream: - result = self._stream_chat_chunks_sync(response) - else: - result = self._parse_chat_completion_response(response, include_response=include_response, **kwargs) - - elif isinstance(prompt, str): - response = litellm.text_completion(**request_params, **self.litellm_kwargs) - if stream: - result = self._stream_completion_chunks_sync(response) - else: - result = self._parse_completion_response(response, include_response=include_response, **kwargs) - else: - raise TypeError(f"Unsupported prompt type: {type(prompt)}") - - self._maybe_apply_stop_phrase_removal(result, remove_stop_phrases, stop_phrases) - return result - def _parse_completion_response( self, response: "openai.types.Completion", include_response: bool = False, **kwargs ) -> dict: @@ -415,6 +367,7 @@ def _parse_chat_completion_response(self, response, include_response: bool = Fal result["finish_reason"] = choice.finish_reason if hasattr(choice.message, "tool_calls") and choice.message.tool_calls: result["tool_calls"] = choice.message.tool_calls + result["serialized_output"] = self._serialize_output(response) if include_response: result["response"] = response @@ 
-482,31 +435,74 @@ def _process_chat_chunk(self, chunk): return [result] - def _stream_completion_chunks_sync(self, response): - """Synchronous version of stream completion chunks.""" - emitted_so_far = [] - for chunk in response: - results = self._process_completion_chunk(chunk, emitted_so_far) - for result in results: - yield result - - def _stream_chat_chunks_sync(self, response): - """Synchronous version of stream chat chunks.""" - for chunk in response: - results = self._process_chat_chunk(chunk) - for result in results: - yield result - async def _stream_completion_chunks_async(self, response): - """Async version of stream completion chunks.""" emitted_so_far = [] async for chunk in response: results = self._process_completion_chunk(chunk, emitted_so_far) for result in results: yield result + def _parse_responses_completion_response(self, response, include_response: bool = False, **kwargs) -> dict: + """Public method for parsing responses API responses""" + result = {"generation": "", "num_generated_tokens": 0} + + if hasattr(response, "usage"): + result["num_generated_tokens"] = response.usage.output_tokens + + tool_calls = [] + reasoning_content = "" + generation_text = "" + + if hasattr(response, "output") and response.output: + for output_item in response.output: + # Handle reasoning content + if output_item.type == "reasoning": + if output_item.content: + for content_item in output_item.content: + if content_item.text: + reasoning_content += content_item.text + "\n" + + # Handle function calls + elif output_item.type == "function_call": + tool_calls.append(output_item) + + # Handle message content + elif output_item.type == "message": + if output_item.content: + for content_item in output_item.content: + if content_item.text: + generation_text += content_item.text + + if tool_calls: + result["tool_calls"] = tool_calls + result["generation"] = "" # No text generation when there are tool calls + else: + result["generation"] = generation_text + if 
reasoning_content: + result["reasoning_content"] = reasoning_content.strip() + + result["finish_reason"] = response.status + result["serialized_output"] = self._serialize_output(response) + if include_response: + result["response"] = response + + return result + + def _serialize_output(self, response): + """Serialize response output objects using model_dump() for conversation history.""" + serialized_output = [] + + if hasattr(response, "output") and response.output: + for output_item in response.output: + serialized_output.append(output_item.model_dump()) + elif hasattr(response, "choices") and response.choices: + for choice in response.choices: + serialized_output.append(choice.model_dump()["message"]) + else: + raise ValueError(f"Unsupported response type: {type(response)}") + return serialized_output + async def _stream_chat_chunks_async(self, response): - """Async version of stream chat chunks.""" async for chunk in response: results = self._process_chat_chunk(chunk) for result in results: diff --git a/nemo_skills/inference/model/code_execution.py b/nemo_skills/inference/model/code_execution.py index 633090b2c5..e9026976c1 100644 --- a/nemo_skills/inference/model/code_execution.py +++ b/nemo_skills/inference/model/code_execution.py @@ -22,7 +22,7 @@ from nemo_skills.code_execution.sandbox import Sandbox from nemo_skills.utils import get_logger_name, nested_dataclass -from .base import BaseModel +from .base import BaseModel, EndpointType LOG = logging.getLogger(get_logger_name(__file__)) @@ -65,6 +65,7 @@ async def _generate_single( max_code_executions: int | None = None, # if not None, will override self.config.max_code_executions stream: bool = False, extra_body: dict = None, + endpoint_type: EndpointType = None, ): # Handle OpenAI-style dictionary prompts is_openai_format = not isinstance(prompt, str) @@ -75,6 +76,7 @@ async def _generate_single( if stream: return self._stream_single( prompt=prompt, + endpoint_type=endpoint_type, code_begin=code_begin, 
code_end=code_end, code_output_begin=code_output_begin, @@ -108,6 +110,7 @@ async def _generate_single( stop_phrases = stop_phrases or [] request = { + "endpoint_type": endpoint_type, "prompt": new_prompt, "tokens_to_generate": tokens_to_generate, "temperature": temperature, @@ -257,6 +260,7 @@ async def generate_async( max_code_executions: int | None = None, stream: bool = False, extra_body: dict = None, + endpoint_type: EndpointType = None, ) -> list[dict]: """For any generation parameter you can specify a list of values that needs to match the number of prompts. @@ -266,6 +270,7 @@ async def generate_async( raise NotImplementedError("top_logprobs is not supported yet.") kwargs = { + "endpoint_type": endpoint_type, "code_begin": code_begin, "code_end": code_end, "code_output_begin": code_output_begin, @@ -313,6 +318,7 @@ async def _stream_single( timeout: float | int | None = 14400, # None is 10min, max_code_executions: int | None = None, extra_body: dict = None, + endpoint_type: EndpointType = None, ): """ Helper method, that implements streaming generation. 
@@ -327,6 +333,7 @@ async def _stream_single( stop_phrases = stop_phrases or [] request = { + "endpoint_type": endpoint_type, "temperature": temperature, "top_p": top_p, "top_k": top_k, diff --git a/nemo_skills/inference/model/openai.py b/nemo_skills/inference/model/openai.py index a7deff928d..9f468c93c5 100644 --- a/nemo_skills/inference/model/openai.py +++ b/nemo_skills/inference/model/openai.py @@ -151,3 +151,10 @@ def _build_chat_request_params( params["top_p"] = top_p return params + + def _build_responses_request_params(self, input, **kwargs) -> dict: + # Remapping variables to match responses API + responses_params = self._build_chat_request_params(messages=input, **kwargs) + responses_params["input"] = responses_params.pop("messages") + responses_params["max_output_tokens"] = responses_params.pop("max_completion_tokens") + return responses_params diff --git a/nemo_skills/inference/model/parallel_thinking.py b/nemo_skills/inference/model/parallel_thinking.py index 6a63713444..959f22b05a 100644 --- a/nemo_skills/inference/model/parallel_thinking.py +++ b/nemo_skills/inference/model/parallel_thinking.py @@ -27,7 +27,7 @@ from nemo_skills.prompt.utils import get_prompt from nemo_skills.utils import get_logger_name, nested_dataclass, remove_thinking -from .base import BaseModel +from .base import BaseModel, EndpointType LOG = logging.getLogger(get_logger_name(__file__)) @@ -52,7 +52,7 @@ class ParallelThinkingConfig: remove_thinking: bool = True # Remove thinking tokens from the solution key thinking_begin: str = "" thinking_end: str = "" - use_completions_api: bool = False + endpoint_type: EndpointType = EndpointType.chat tokenizer: str | None = None chat_template_kwargs: dict | None = None # extra parameters to pass to the tokenizer's apply_chat_template method start_assistant_response_key: str | None = None # whether to start assistant response with this key diff --git a/nemo_skills/inference/model/tool_call.py b/nemo_skills/inference/model/tool_call.py index 
bba2f64f9c..ffbd1a9921 100644 --- a/nemo_skills/inference/model/tool_call.py +++ b/nemo_skills/inference/model/tool_call.py @@ -21,14 +21,14 @@ from typing import Dict, List from nemo_skills.mcp.adapters import ( - ChatCompletionCallInterpreter, - ChatCompletionResponseFormatter, - OpenAISchemaAdapter, + format_tool_list_by_endpoint_type, + format_tool_response_by_endpoint_type, + get_tool_details_by_endpoint_type, ) from nemo_skills.mcp.tool_manager import ToolManager from nemo_skills.utils import get_logger_name -from .base import BaseModel +from .base import BaseModel, EndpointType LOG = logging.getLogger(get_logger_name(__file__)) @@ -59,15 +59,10 @@ def __init__( overrides=tool_overrides or {}, context=additional_config, ) - # Use sensible defaults for adapters in module-based mode - self.schema_adapter = OpenAISchemaAdapter() - self.call_interpreter = ChatCompletionCallInterpreter() - self.response_formatter = ChatCompletionResponseFormatter() - async def _execute_tool_call(self, tool_call, request_id: str): + async def _execute_tool_call(self, tool_call, request_id: str, endpoint_type: EndpointType): ## TODO(sanyamk): The correct key format needs to be cohesive with other formatters. - tool_name = tool_call["function"]["name"] - tool_args = tool_call["function"]["arguments"] + tool_name, tool_args = get_tool_details_by_endpoint_type(tool_call, endpoint_type) ## # TODO(sanyamk): Not all tool arguments might necessarily be in JSON format. 
@@ -75,6 +70,7 @@ async def _execute_tool_call(self, tool_call, request_id: str): try: tool_args = json.loads(tool_args) except json.decoder.JSONDecodeError as e: + LOG.error(f"Tool arguments are not in JSON format: {tool_args}") LOG.exception(e) return {"error": "Tool argument parsing failed."} @@ -88,17 +84,21 @@ async def _execute_tool_call(self, tool_call, request_id: str): return result - async def _execute_tool_calls(self, tool_calls: List, request_id: str): - tasks = [self._execute_tool_call(tool_call, request_id=request_id) for tool_call in tool_calls] + async def _execute_tool_calls(self, tool_calls: List, request_id: str, endpoint_type: EndpointType): + tasks = [ + self._execute_tool_call(tool_call, request_id=request_id, endpoint_type=endpoint_type) + for tool_call in tool_calls + ] tool_results = await asyncio.gather(*tasks) return [ - self.response_formatter.format(tool_call, tool_result) + format_tool_response_by_endpoint_type(tool_call, tool_result, endpoint_type) for tool_call, tool_result in zip(tool_calls, tool_results) ] async def generate_async( self, prompt: List, + endpoint_type: EndpointType, tools: List[dict] = None, tokens_to_generate: int = None, **generation_kwargs, @@ -109,7 +109,8 @@ async def generate_async( # This assumes that the available tools do not change during the generation. 
raw_tools = await self.tool_manager.list_all_tools(use_cache=True) - tools = self.schema_adapter.convert(raw_tools) + tools = format_tool_list_by_endpoint_type(raw_tools, endpoint_type) + LOG.info("Available Tools: %s", tools) result_steps = defaultdict(list) conversation = copy.deepcopy(prompt) @@ -124,6 +125,7 @@ async def generate_async( prompt=conversation, tools=tools, tokens_to_generate=tokens_to_generate, + endpoint_type=endpoint_type, **generation_kwargs, ) if isinstance(tokens_to_generate, int): @@ -133,18 +135,15 @@ async def generate_async( if k in generation: result_steps[k].append(generation[k]) - conversation.append({"role": "assistant", "content": generation["generation"]}) - if "reasoning_content" in generation: - conversation[-1]["reasoning_content"] = generation["reasoning_content"] + conversation.extend(generation["serialized_output"]) tool_calls = generation.get("tool_calls", []) if tool_calls: - tool_calls_message = self.call_interpreter.parse(tool_calls) - conversation[-1].update(tool_calls_message) - + tool_calls = [tool_call.model_dump() for tool_call in tool_calls] tool_calls_output_messages = await self._execute_tool_calls( - tool_calls_message["tool_calls"], request_id=request_id + tool_calls, request_id=request_id, endpoint_type=endpoint_type ) + LOG.info("Sending tool calls: %s", tool_calls_output_messages) conversation.extend(tool_calls_output_messages) result_steps["num_tool_calls"].append(len(tool_calls)) diff --git a/nemo_skills/inference/model/vllm.py b/nemo_skills/inference/model/vllm.py index e70149c88b..e9a2146520 100644 --- a/nemo_skills/inference/model/vllm.py +++ b/nemo_skills/inference/model/vllm.py @@ -138,3 +138,11 @@ def _build_chat_request_params( request["allowed_openai_params"] = ["reasoning_effort"] request["reasoning_effort"] = reasoning_effort return request + + def _build_responses_request_params(self, input, **kwargs) -> dict: + # Parameters are the same as chat completion request params + # For now, we hack this 
by renaming messages to input + Until we need more parameters for responses API + responses_params = self._build_chat_request_params(messages=input, **kwargs) + responses_params["input"] = responses_params.pop("messages") + return responses_params diff --git a/nemo_skills/inference/patch_litellm_logging.py new file mode 100644 index 0000000000..bb20c70b67 --- /dev/null +++ b/nemo_skills/inference/patch_litellm_logging.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Patch for litellm.litellm_core_utils.logging_worker.LoggingWorker +to disable its functionality and make all methods no-op. + +Currently, the async_loop function in generate.py sometimes gets stuck forever because some of the coroutines in the litellm logging worker are not finished. +Debugging why the logger is not finished is non-trivial, so we just patch it to disable its functionality. +The behavior is that it keeps slurm jobs from exiting and we waste gpus. +It always happens in docker containers, but does not happen locally.
+""" + +from typing import Coroutine + + +class NoOpLoggingWorker: + """No-op implementation of LoggingWorker that disables all functionality.""" + + def __init__(self, *args, **kwargs): + pass + + def _ensure_queue(self) -> None: + pass + + def start(self) -> None: + pass + + async def _worker_loop(self) -> None: + pass + + def enqueue(self, coroutine: Coroutine) -> None: + if coroutine is not None: + coroutine.close() + + def ensure_initialized_and_enqueue(self, async_coroutine: Coroutine): + if async_coroutine is not None: + async_coroutine.close() + + async def stop(self) -> None: + pass + + async def flush(self) -> None: + pass + + async def clear_queue(self): + pass + + +def patch_litellm_logging_worker(): + """ + Patches the litellm LoggingWorker to disable its functionality. + This prevents any logging worker from keeping the client alive. + """ + try: + import litellm.litellm_core_utils.logging_worker as logging_worker_module + + logging_worker_module.LoggingWorker = NoOpLoggingWorker + logging_worker_module.GLOBAL_LOGGING_WORKER = NoOpLoggingWorker() + except ModuleNotFoundError: + # Ensure compatibility with different litellm versions + pass diff --git a/nemo_skills/mcp/adapters.py b/nemo_skills/mcp/adapters.py index 6e385d9757..a927000403 100644 --- a/nemo_skills/mcp/adapters.py +++ b/nemo_skills/mcp/adapters.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import json from abc import ABC, abstractmethod -from typing import List from litellm.types.utils import ChatCompletionMessageToolCall +from nemo_skills.inference.model.base import EndpointType + # ============================== # ADAPTER INTERFACES @@ -48,9 +48,8 @@ def format(self, tool_call: ChatCompletionMessageToolCall, result: dict) -> dict # ============================== -class OpenAISchemaAdapter(ToolSchemaAdapter): - # https://platform.openai.com/docs/guides/function-calling#defining-functions - def convert(self, tools): +def format_tool_list_by_endpoint_type(tools, endpoint_type: EndpointType): + if endpoint_type == EndpointType.chat: return [ { "type": "function", @@ -62,6 +61,19 @@ def convert(self, tools): } for t in tools ] + elif endpoint_type == EndpointType.responses: + return [ + { + "type": "function", + "name": t["name"], + "description": t["description"], + "parameters": t["input_schema"], + "strict": True, # Less vllm errors through structured output + } + for t in tools + ] + else: + raise ValueError(f"Unsupported completion type for tool list: {endpoint_type}") class OpenAICallInterpreter(ToolCallInterpreter): @@ -81,38 +93,32 @@ def format(self, tool_call: ChatCompletionMessageToolCall, result): } -class ChatCompletionCallInterpreter(ToolCallInterpreter): - """Convert tool calls to a chat message item. - - Should be broadly compatible with any OpenAI-like APIs, - and HuggingFace Chat templates. - - NOTE(sanyamk): For error handling, delay JSON parsing of arguments to the model class. - """ - - def parse(self, tool_calls: List[ChatCompletionMessageToolCall]): - tool_calls = [ - { - "type": tool_call.type, - "id": tool_call.id, - "function": {"name": tool_call.function.name, "arguments": tool_call.function.arguments}, - } - for tool_call in tool_calls - ] - - return {"role": "assistant", "tool_calls": tool_calls} - - -class ChatCompletionResponseFormatter(ToolResponseFormatter): - """Convert tool call result to chat message item. 
- - Use in conjunction with `ChatCompletionCallInterpreter`. - """ - - def format(self, tool_call, result): +def format_tool_response_by_endpoint_type(tool_call, result, endpoint_type: EndpointType): + if endpoint_type == EndpointType.chat: return { "role": "tool", "name": tool_call["function"]["name"], "tool_call_id": tool_call["id"], "content": json.dumps(result) if not isinstance(result, str) else result, } + elif endpoint_type == EndpointType.responses: + return { + "type": "function_call_output", + "call_id": tool_call["call_id"], + "output": json.dumps(result) if not isinstance(result, str) else result, + } + else: + raise ValueError(f"Unsupported completion type for tool call: {endpoint_type}") + + +def get_tool_details_by_endpoint_type(tool_call, endpoint_type: EndpointType): + if endpoint_type == EndpointType.chat: + tool_name = tool_call["function"]["name"] + tool_args = tool_call["function"]["arguments"] + elif endpoint_type == EndpointType.responses: + assert tool_call["type"] == "function_call", "Tool call must be a function call" + tool_name = tool_call["name"] + tool_args = tool_call["arguments"] + else: + raise ValueError(f"Unsupported completion type for tool call: {endpoint_type}") + return tool_name, tool_args diff --git a/nemo_skills/training/openrlhf/math_reward.py b/nemo_skills/training/openrlhf/math_reward.py index ce259f1e31..b0dd87ddc9 100644 --- a/nemo_skills/training/openrlhf/math_reward.py +++ b/nemo_skills/training/openrlhf/math_reward.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import json import os @@ -57,7 +58,9 @@ def reward_func(queries: list[str], prompts: list[str], prompt_metadata: list[di judge_prompts = [prompt.fill(dp) for dp in data_points] if len(judge_prompts) > 0: # Too slow, but we are no longer supporting openrlhf anyways. 
- outputs = [llm.generate_sync(prompt=jp, stop_phrases=prompt.stop_phrases) for jp in judge_prompts] + outputs = [ + asyncio.run(llm.generate_async(prompt=jp, stop_phrases=prompt.stop_phrases)) for jp in judge_prompts + ] else: outputs = [] judgements = [] diff --git a/requirements/main.txt b/requirements/main.txt index 27fd526960..c796657121 100644 --- a/requirements/main.txt +++ b/requirements/main.txt @@ -27,7 +27,7 @@ hydra-core ipython iso639-lang langcodes -litellm[caching] < 1.75.0 # some bug with asyncio.run hanging forever +litellm[caching] == 1.77.5 # Requires patching the logging worker (See nemo_skills/inference/patch_litellm_logging.py) math-verify[antlr4_9_3] mcp nemo_run @ git+https://github.com/NVIDIA/NeMo-Run diff --git a/tests/slurm-tests/gpt_oss_python_aime25/run_test.py b/tests/slurm-tests/gpt_oss_python_aime25/run_test.py index d0d5e48431..866f604f17 100644 --- a/tests/slurm-tests/gpt_oss_python_aime25/run_test.py +++ b/tests/slurm-tests/gpt_oss_python_aime25/run_test.py @@ -24,7 +24,7 @@ def eval_gpt_oss_python(workspace, cluster, expname_prefix, wandb_project): "++inference.temperature=1.0 " "++inference.top_p=1.0 " "++prompt_config=gpt-oss/math " - "++use_completions_api=true " + "++inference.endpoint_type=text " "++code_tags=gpt-oss " "++code_execution=true " "++server.code_execution.max_code_executions=100 "