diff --git a/nemo_gym/cli.py b/nemo_gym/cli.py
index 40e1ac447..c7cf269b6 100644
--- a/nemo_gym/cli.py
+++ b/nemo_gym/cli.py
@@ -49,6 +49,10 @@
     GlobalConfigDictParserConfig,
     get_global_config_dict,
 )
+from nemo_gym.ray_utils import (
+    _NeMoGymRayGPUSchedulingHelperActorProxy,
+    _start_global_ray_gpu_scheduling_helper,
+)
 from nemo_gym.server_utils import (
     HEAD_SERVER_KEY_NAME,
     HeadServer,
@@ -162,6 +166,7 @@ class ServerInstanceDisplayConfig(BaseModel):
 
 class RunHelper:  # pragma: no cover
     _head_server: uvicorn.Server
     _head_server_thread: Thread
+    _head_ray_gpu_helper: _NeMoGymRayGPUSchedulingHelperActorProxy
     _processes: Dict[str, Popen]
     _server_instance_display_configs: List[ServerInstanceDisplayConfig]
@@ -174,6 +179,8 @@ def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig)
         # Note: This function will modify the global config dict - update `ray_head_node_address`
         initialize_ray()
 
+        self._head_ray_gpu_helper = _start_global_ray_gpu_scheduling_helper()
+
         # Assume Nemo Gym Run is for a single agent.
         escaped_config_dict_yaml_str = shlex.quote(OmegaConf.to_yaml(global_config_dict))
diff --git a/nemo_gym/global_config.py b/nemo_gym/global_config.py
index d69b3025d..99bc8b570 100644
--- a/nemo_gym/global_config.py
+++ b/nemo_gym/global_config.py
@@ -45,6 +45,10 @@
 DISALLOWED_PORTS_KEY_NAME = "disallowed_ports"
 HEAD_SERVER_DEPS_KEY_NAME = "head_server_deps"
 PYTHON_VERSION_KEY_NAME = "python_version"
+RAY_HEAD_NODE_ADDRESS_KEY_NAME = "ray_head_node_address"
+RAY_NAMESPACE_KEY_NAME = "ray_namespace"
+RAY_GPU_NODES_KEY_NAME = "ray_gpu_nodes"
+RAY_NUM_GPUS_PER_NODE_KEY_NAME = "ray_num_gpus_per_node"
 USE_ABSOLUTE_IP = "use_absolute_ip"
 NEMO_GYM_RESERVED_TOP_LEVEL_KEYS = [
     CONFIG_PATHS_KEY_NAME,
@@ -54,6 +58,10 @@
     DISALLOWED_PORTS_KEY_NAME,
     HEAD_SERVER_DEPS_KEY_NAME,
     PYTHON_VERSION_KEY_NAME,
+    RAY_HEAD_NODE_ADDRESS_KEY_NAME,
+    RAY_NAMESPACE_KEY_NAME,
+    RAY_GPU_NODES_KEY_NAME,
+    RAY_NUM_GPUS_PER_NODE_KEY_NAME,
     USE_ABSOLUTE_IP,
 ]
diff --git a/nemo_gym/ray_utils.py b/nemo_gym/ray_utils.py
new file mode 100644
index 000000000..f8fd14193
--- /dev/null
+++ b/nemo_gym/ray_utils.py
@@ -0,0 +1,172 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
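+
+# Utilities for scheduling GPU actors onto specific Ray nodes: a named singleton
+# helper actor tracks per-node GPU capacity and hands out node IDs so callers
+# can pin workers via NodeAffinitySchedulingStrategy. A sketch of intended
+# usage (hypothetical; `MyWorker` stands in for any @ray.remote-decorated class):
+#
+#     worker = spinup_single_ray_gpu_node_worker(MyWorker, num_gpus=4)
+#     result = ray.get(worker.run.remote())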
+import os
+import sys
+from collections import defaultdict
+from time import sleep
+from typing import Dict, Optional
+
+import ray
+import ray.util.state
+from ray.actor import ActorClass, ActorProxy
+from ray.util import get_node_ip_address
+from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
+
+from nemo_gym.global_config import (
+    RAY_GPU_NODES_KEY_NAME,
+    RAY_NUM_GPUS_PER_NODE_KEY_NAME,
+    get_global_config_dict,
+)
+
+
+class _NeMoGymRayGPUSchedulingHelper:  # pragma: no cover
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.avail_gpus_dict = defaultdict(int)
+        self.used_gpus_dict = defaultdict(int)
+
+    @ray.method
+    def _post_init(self) -> None:
+        # If the value of RAY_GPU_NODES_KEY_NAME is None, Gym will use all Ray GPU nodes
+        # for scheduling GPU actors.
+        # Otherwise, if the value of RAY_GPU_NODES_KEY_NAME is a list, Gym will only use
+        # the listed Ray GPU nodes for scheduling GPU actors.
+        allowed_gpu_nodes = self.cfg.get(RAY_GPU_NODES_KEY_NAME, None)
+        if allowed_gpu_nodes is not None:
+            allowed_gpu_nodes = set(allowed_gpu_nodes)
+
+        head = self.cfg["ray_head_node_address"]
+        node_states = ray.util.state.list_nodes(head, detail=True)
+        for state in node_states:
+            assert state.node_id is not None
+            avail_num_gpus = state.resources_total.get("GPU", 0)
+            if allowed_gpu_nodes is not None and state.node_id not in allowed_gpu_nodes:
+                continue
+            self.avail_gpus_dict[state.node_id] += avail_num_gpus
+
+    def alloc_gpu_node(self, num_gpus: int) -> Optional[str]:
+        for node_id, avail_num_gpus in self.avail_gpus_dict.items():
+            used_num_gpus = self.used_gpus_dict[node_id]
+            if used_num_gpus + num_gpus <= avail_num_gpus:
+                self.used_gpus_dict[node_id] += num_gpus
+                return node_id
+        return None
+
+
+_NeMoGymRayGPUSchedulingHelperActor: ActorClass[_NeMoGymRayGPUSchedulingHelper] = ray.remote(
+    _NeMoGymRayGPUSchedulingHelper
+)  # pragma: no cover
+_NeMoGymRayGPUSchedulingHelperActorProxy = ActorProxy[_NeMoGymRayGPUSchedulingHelper]  # pragma: no cover
+
+
+def _start_global_ray_gpu_scheduling_helper(
+    node_id: Optional[str] = None,
+) -> _NeMoGymRayGPUSchedulingHelperActorProxy:  # pragma: no cover
+    cfg = get_global_config_dict()
+    helper_options = {
+        "name": "_NeMoGymRayGPUSchedulingHelper",
+        "num_cpus": 0,
+    }
+    if node_id is not None:
+        helper_options["scheduling_strategy"] = NodeAffinitySchedulingStrategy(
+            node_id=node_id,
+            soft=True,
+        )
+    helper = _NeMoGymRayGPUSchedulingHelperActor.options(**helper_options).remote(cfg)
+    ray.get(helper._post_init.remote())
+    return helper
+
+
+def get_global_ray_gpu_scheduling_helper() -> _NeMoGymRayGPUSchedulingHelperActorProxy:  # pragma: no cover
+    cfg = get_global_config_dict()
+    while True:
+        try:
+            get_actor_args = {
+                "name": "_NeMoGymRayGPUSchedulingHelper",
+            }
+            ray_namespace = cfg.get("ray_namespace", None)
+            if ray_namespace is None:
+                ray_namespace = "nemo_gym"
+            get_actor_args["namespace"] = ray_namespace
+            worker = ray.get_actor(**get_actor_args)
+            return worker
+        except ValueError:
+            sleep(3)
+
+
+def lookup_ray_node_id_to_ip_dict() -> Dict[str, str]:  # pragma: no cover
+    cfg = get_global_config_dict()
+    head = cfg["ray_head_node_address"]
+    id_to_ip = {}
+    node_states = ray.util.state.list_nodes(head)
+    for state in node_states:
+        id_to_ip[state.node_id] = state.node_ip
+    return id_to_ip
+
+
+def lookup_current_ray_node_id() -> str:  # pragma: no cover
+    return ray.get_runtime_context().get_node_id()
+
+
+def lookup_current_ray_node_ip() -> str:  # pragma: no cover
+    return get_node_ip_address()
+
+
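+# Build the environment for a freshly spawned GPU worker: copy the current
+# process environment but drop CUDA visibility and Ray bookkeeping variables,
+# so Ray assigns the new worker its own GPU visibility instead of an inherited one.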
+def _prepare_ray_worker_env_vars() -> Dict[str, str]:  # pragma: no cover
+    worker_env_vars = {
+        **os.environ,
+    }
+    pop_env_vars = [
+        "CUDA_VISIBLE_DEVICES",
+        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
+        "RAY_JOB_ID",
+        "RAY_RAYLET_PID",
+    ]
+    for k in pop_env_vars:
+        worker_env_vars.pop(k, None)
+    return worker_env_vars
+
+
+def spinup_single_ray_gpu_node_worker(
+    worker_cls: ActorClass,
+    num_gpus: int,
+    *worker_args,
+    **worker_kwargs,
+) -> ActorProxy:  # pragma: no cover
+    cfg = get_global_config_dict()
+
+    num_gpus_per_node = cfg.get(RAY_NUM_GPUS_PER_NODE_KEY_NAME, 8)
+    assert num_gpus >= 1, f"Must request at least 1 GPU for spinning up {worker_cls}"
+    assert num_gpus <= num_gpus_per_node, (
+        f"Requested {num_gpus} GPUs, but each node has only {num_gpus_per_node}, for spinning up {worker_cls}"
+    )
+
+    helper = get_global_ray_gpu_scheduling_helper()
+    node_id = ray.get(helper.alloc_gpu_node.remote(num_gpus))
+    if node_id is None:
+        raise RuntimeError(f"Cannot find an available Ray node with {num_gpus} GPUs to spin up {worker_cls}")
+
+    worker_options = {}
+    worker_options["num_gpus"] = num_gpus
+    worker_options["scheduling_strategy"] = NodeAffinitySchedulingStrategy(
+        node_id=node_id,
+        soft=False,
+    )
+    worker_options["runtime_env"] = {
+        "py_executable": sys.executable,
+        "env_vars": _prepare_ray_worker_env_vars(),
+    }
+    worker = worker_cls.options(**worker_options).remote(*worker_args, **worker_kwargs)
+    return worker
diff --git a/nemo_gym/server_utils.py b/nemo_gym/server_utils.py
index 88466daef..cc67c8e82 100644
--- a/nemo_gym/server_utils.py
+++ b/nemo_gym/server_utils.py
@@ -350,6 +350,7 @@ def initialize_ray() -> None:
     global_config_dict = get_global_config_dict()
 
     ray_head_node_address = global_config_dict.get("ray_head_node_address")
+    ray_namespace = global_config_dict.get("ray_namespace", None)
 
     ray_init_kwargs = dict(ignore_reinit_error=True)
     if ray_head_node_address:
@@ -358,6 +359,11 @@ def initialize_ray() -> None:
     else:
         print("Starting Ray cluster...")
 
+    if ray_namespace is None:
+        ray_namespace = "nemo_gym"
+    print(f"Ray namespace: {ray_namespace}")
+    ray_init_kwargs["namespace"] = ray_namespace
+
     ray.init(**ray_init_kwargs)
 
     if not ray_head_node_address:
diff --git a/responses_api_models/sglang_model/README.md b/responses_api_models/sglang_model/README.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/responses_api_models/sglang_model/app.py b/responses_api_models/sglang_model/app.py
new file mode 100644
index 000000000..26efaedc6
--- /dev/null
+++ b/responses_api_models/sglang_model/app.py
@@ -0,0 +1,411 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
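+
+# Responses API model backed by SGLang. With `spinup_server` enabled, SGLang
+# HTTP servers are launched on Ray GPU nodes (one per router DP rank); the
+# vLLM converter and heartbeat helpers are reused until SGLang-specific ones exist.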
+import os
+from multiprocessing import Process
+from time import sleep, time
+from typing import Any, Dict, List, Optional, Union
+from uuid import uuid4
+
+import ray
+from aiohttp.client_exceptions import ClientResponseError
+from fastapi import Request
+
+from nemo_gym.base_responses_api_model import (
+    BaseResponsesAPIModelConfig,
+    Body,
+    SimpleResponsesAPIModel,
+)
+from nemo_gym.global_config import find_open_port
+from nemo_gym.openai_utils import (
+    NeMoGymAsyncOpenAI,
+    NeMoGymChatCompletion,
+    NeMoGymChatCompletionCreateParamsNonStreaming,
+    NeMoGymChatCompletionMessage,
+    NeMoGymChoice,
+    NeMoGymResponse,
+    NeMoGymResponseCreateParamsNonStreaming,
+)
+from nemo_gym.ray_utils import (
+    lookup_current_ray_node_ip,
+    spinup_single_ray_gpu_node_worker,
+)
+from nemo_gym.server_utils import SESSION_ID_KEY
+from responses_api_models.vllm_model.app import (
+    VLLMConverter,
+    _vllm_server_heartbeat,
+)
+
+
+class SGLangModelConfig(BaseResponsesAPIModelConfig):
+    base_url: Union[str, List[str]]
+    api_key: str
+    model: str
+    return_token_id_information: bool
+
+    uses_reasoning_parser: bool
+    replace_developer_role_with_system: bool = False
+
+    spinup_server: bool = False
+    server_args: Optional[Dict[str, Any]] = None
+
+    router_dp_size: int = 1
+
+    def model_post_init(self, context):
+        if isinstance(self.base_url, str):
+            self.base_url = [self.base_url]
+        return super().model_post_init(context)
+
+
+def _start_sglang_server(config: SGLangModelConfig, server_host: str, server_port: int, router_dp_rank: int) -> None:
+    import sglang.srt.entrypoints.http_server
+    import sglang.srt.server_args
+
+    argv = []
+    argv.append("--model-path")
+    argv.append(config.model)
+    argv.append("--host")
+    argv.append(server_host)
+    argv.append("--port")
+    argv.append(f"{server_port}")
+    for k, v in (config.server_args or {}).items():
+        k2 = k.replace("_", "-")
+        if v is None:
+            pass
+        elif isinstance(v, bool):
+            if not v:
+                arg_key = f"--no-{k2}"
+            else:
+                arg_key = f"--{k2}"
+            argv.append(arg_key)
+        else:
+            arg_key = f"--{k2}"
+            argv.append(arg_key)
+            argv.append(f"{v}")
+
+    server_args = sglang.srt.server_args.prepare_server_args(argv)
+    sglang.srt.entrypoints.http_server.launch_server(server_args)
+
+
+@ray.remote
+class SGLangServerSpinupWorker:
+    def __init__(self, config: SGLangModelConfig, working_dir: Optional[str], router_dp_rank: int):
+        self.config = config
+        self.working_dir = working_dir
+        self.router_dp_rank = router_dp_rank
+        self._server_host = lookup_current_ray_node_ip()
+        self._server_port = find_open_port()
+
+        if self.working_dir is not None:
+            os.chdir(self.working_dir)
+
+        server_proc = Process(
+            target=_start_sglang_server,
+            args=(
+                self.config,
+                self._server_host,
+                self._server_port,
+                self.router_dp_rank,
+            ),
+            daemon=False,
+        )
+        server_proc.start()
+        self._server_proc = server_proc
+
+    def _get_ip(self) -> str:
+        return self._server_host
+
+    def _get_port(self) -> int:
+        return self._server_port
+
+
+class SGLangModel(SimpleResponsesAPIModel):
+    config: SGLangModelConfig
+
+    def model_post_init(self, context):
+        working_dir = os.getcwd()
+
+        if self.config.spinup_server:
+            self._server_urls = []
+            self._server_workers = []
+            self._clients = []
+
+            # TODO: support for other parallel sizes.
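+            # Only tensor parallelism within a single node is supported for now:
+            # data_parallel_size must stay 1, and `router_dp_size` controls how many
+            # independent single-node servers are launched below.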
+            server_tp_size = (self.config.server_args or {}).get("tensor_parallel_size", 1)
+            server_dp_size = (self.config.server_args or {}).get("data_parallel_size", 1)
+
+            assert server_dp_size == 1
+
+            router_dp_size = max(1, self.config.router_dp_size)
+
+            for router_dp_rank in range(router_dp_size):
+                server_worker = spinup_single_ray_gpu_node_worker(
+                    SGLangServerSpinupWorker,
+                    server_tp_size,
+                    config=self.config,
+                    working_dir=working_dir,
+                    router_dp_rank=router_dp_rank,
+                )
+
+                server_ip = ray.get(server_worker._get_ip.remote())
+                server_port = ray.get(server_worker._get_port.remote())
+                server_url = f"http://{server_ip}:{server_port}/v1"
+
+                self._server_urls.append(server_url)
+                self._server_workers.append(server_worker)
+
+                self._clients.append(
+                    NeMoGymAsyncOpenAI(
+                        base_url=server_url,
+                        api_key=self.config.api_key,
+                    )
+                )
+
+            for server_url in self._server_urls:
+                while True:
+                    try:
+                        _vllm_server_heartbeat(server_url)
+                        break
+                    except Exception:
+                        sleep(3)
+                        continue
+
+        else:
+            self._server_urls = None
+            self._server_workers = None
+            self._clients = [
+                NeMoGymAsyncOpenAI(
+                    base_url=base_url,
+                    api_key=self.config.api_key,
+                )
+                for base_url in self.config.base_url
+            ]
+
+        self._session_id_to_client: Dict[str, NeMoGymAsyncOpenAI] = dict()
+
+        # TODO: sglang converter.
+        self._converter = VLLMConverter(
+            return_token_id_information=self.config.return_token_id_information,
+        )
+
+        return super().model_post_init(context)
+
+    async def responses(
+        self, request: Request, body: NeMoGymResponseCreateParamsNonStreaming = Body()
+    ) -> NeMoGymResponse:
+        # Response Create Params -> Chat Completion Create Params
+        chat_completion_create_params = self._converter.responses_to_chat_completion_create_params(body)
+        body.model = self.config.model
+
+        # Chat Completion Create Params -> Chat Completion
+        chat_completion_response = await self.chat_completions(request, chat_completion_create_params)
+
+        choice = chat_completion_response.choices[0]
+
+        response_output = self._converter.postprocess_chat_response(choice)
+        response_output_dicts = [item.model_dump() for item in response_output]
+
+        # Chat Completion -> Response
+        return NeMoGymResponse(
+            id=f"resp_{uuid4().hex}",
+            created_at=int(time()),
+            model=body.model,
+            object="response",
+            output=response_output_dicts,
+            tool_choice=body.tool_choice if "tool_choice" in body else "auto",
+            parallel_tool_calls=body.parallel_tool_calls,
+            tools=body.tools,
+            temperature=body.temperature,
+            top_p=body.top_p,
+            background=body.background,
+            max_output_tokens=body.max_output_tokens,
+            max_tool_calls=body.max_tool_calls,
+            previous_response_id=body.previous_response_id,
+            prompt=body.prompt,
+            reasoning=body.reasoning,
+            service_tier=body.service_tier,
+            text=body.text,
+            top_logprobs=body.top_logprobs,
+            truncation=body.truncation,
+            metadata=body.metadata,
+            instructions=body.instructions,
+            user=body.user,
+        )
+
+    async def chat_completions(
+        self, request: Request, body: NeMoGymChatCompletionCreateParamsNonStreaming = Body()
+    ) -> NeMoGymChatCompletion:
+        if self.config.replace_developer_role_with_system:
+            for message in body.messages:
+                if message["role"] == "developer":
+                    message["role"] = "system"
+
+        body_dict = body.model_dump(exclude_unset=True)
+        body_dict["model"] = self.config.model
+
+        session_id = request.session[SESSION_ID_KEY]
+        if session_id not in self._session_id_to_client:
+            # There is probably a better way to select the endpoint for this request. But this will do for now.
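+            # New sessions are assigned to backend clients round-robin; each session
+            # then sticks to the same backend for all of its subsequent requests.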
+            client_idx = len(self._session_id_to_client) % len(self._clients)
+            client = self._clients[client_idx]
+            self._session_id_to_client[session_id] = client
+        client = self._session_id_to_client[session_id]
+
+        create_params = body_dict
+
+        if self.config.return_token_id_information:
+            create_params |= dict(
+                logprobs=True,
+                # Typically passed via OpenAI client extra_body.
+                return_tokens_as_token_ids=True,
+                # TODO add this when NeMo RL upgrades to vLLM 0.10.2 support for prompt token ids
+                # For prompt and generation token IDs
+                # return_token_ids=True,
+                # For prompt token IDs
+                # prompt_logprobs=0,
+            )
+
+        if self.config.uses_reasoning_parser:
+            for message_dict in body_dict["messages"]:
+                if message_dict.get("role") != "assistant" or "content" not in message_dict:
+                    continue
+
+                content = message_dict["content"]
+                if isinstance(content, str):
+                    reasoning_matches, remaining_content = self._converter._extract_reasoning_from_content(content)
+                    message_dict["content"] = remaining_content
+                    if reasoning_matches:
+                        message_dict["reasoning_content"] = reasoning_matches[0]
+                elif isinstance(content, list):
+                    reasoning_content = None
+                    for content_item_dict in content:
+                        reasoning_matches, remaining_content = self._converter._extract_reasoning_from_content(
+                            content_item_dict["text"]
+                        )
+                        assert reasoning_content is None or not reasoning_matches, (
+                            f"Found multiple reasoning matches in a single assistant message content item list!\nMessage: {message_dict}"
+                        )
+
+                        # Even though we set the reasoning content already here, we still loop through all the content item dicts for the assert above.
+                        content_item_dict["text"] = remaining_content
+                        if reasoning_matches:
+                            reasoning_content = reasoning_matches[0]
+                            message_dict["reasoning_content"] = reasoning_content
+                elif not content:
+                    # No content or content None is a no-op
+                    pass
+                else:
+                    raise NotImplementedError
+
+        try:
+            chat_completion_dict = await client.create_chat_completion(**create_params)
+        except ClientResponseError as e:
+            """
+            Example messages for out of context length:
+
+            1. https://github.com/vllm-project/vllm/blob/685c99ee77b4818dcdd15b30fe0e0eff0d5d22ec/vllm/entrypoints/openai/serving_engine.py#L914
+            ```json
+            {"object":"error","message":"This model\'s maximum context length is 32768 tokens. However, you requested 32818 tokens in the messages, Please reduce the length of the messages. None","type":"BadRequestError","param":null,"code":400}
+            ```
+            2. https://github.com/vllm-project/vllm/blob/685c99ee77b4818dcdd15b30fe0e0eff0d5d22ec/vllm/entrypoints/openai/serving_engine.py#L940
+            3. https://github.com/vllm-project/vllm/blob/685c99ee77b4818dcdd15b30fe0e0eff0d5d22ec/vllm/entrypoints/openai/serving_engine.py#L948
+            4. https://github.com/vllm-project/vllm/blob/685c99ee77b4818dcdd15b30fe0e0eff0d5d22ec/vllm/sampling_params.py#L463
+            """
+            result_content_str = e.response_content.decode()
+
+            is_out_of_context_length = e.status == 400 and (
+                "context length" in result_content_str or "max_tokens" in result_content_str
+            )
+            if is_out_of_context_length:
+                return NeMoGymChatCompletion(
+                    id="chtcmpl-123",
+                    object="chat.completion",
+                    created=int(time()),
+                    model=self.config.model,
+                    choices=[
+                        NeMoGymChoice(
+                            index=0,
+                            finish_reason="stop",
+                            message=NeMoGymChatCompletionMessage(
+                                role="assistant",
+                                content=None,
+                                tool_calls=None,
+                            ),
+                        )
+                    ],
+                )
+            else:
+                raise e
+
+        choice_dict = chat_completion_dict["choices"][0]
+        if self.config.uses_reasoning_parser:
+            reasoning_content = choice_dict["message"].get("reasoning_content")
+            if reasoning_content:
+                choice_dict["message"].pop("reasoning_content")
+
+                # We wrap this here in think tags for Gym's sake and to return a valid OpenAI Chat Completions response.
+                choice_dict["message"]["content"] = self._converter._wrap_reasoning_in_think_tags(
+                    [reasoning_content]
+                ) + (choice_dict["message"]["content"] or "")
+        else:
+            assert not choice_dict["message"].get("reasoning_content"), (
+                "Please do not use a reasoning parser in SGLang! There is one source of truth for handling data (including reasoning), which is NeMo Gym!"
+            )
+
+        if self.config.return_token_id_information:
+            log_probs = choice_dict["logprobs"]["content"]
+            generation_log_probs = [log_prob["logprob"] for log_prob in log_probs]
+
+            """
+            START TODO remove this when NeMo RL upgrades to vLLM 0.10.2 support for prompt token ids
+            """
+            # Looks like `"token_id:151667"`
+            generation_token_ids = [log_prob["token"].removeprefix("token_id:") for log_prob in log_probs]
+
+            # The tokenize endpoint doesn't accept any sampling parameters.
+            # The only relevant params are model, messages, and tools.
+            tokenize_body_dict = dict()
+            for key in ("model", "messages", "tools"):
+                if key in body_dict:
+                    tokenize_body_dict[key] = body_dict[key]
+
+            # The base url has /v1 at the end but vLLM's tokenize endpoint does not have v1, hence the ..
+            # Surprisingly, the relative path resolves correctly.
+            tokenize_response = await client.create_tokenize(**tokenize_body_dict)
+            """
+            END
+            """
+
+            message_dict = choice_dict["message"]
+            message_dict.update(
+                dict(
+                    # TODO add this when NeMo RL upgrades to vLLM 0.10.2 support for prompt token ids
+                    # prompt_token_ids=chat_completion_dict["prompt_token_ids"],
+                    prompt_token_ids=tokenize_response["tokens"],
+                    # generation_token_ids=choice_dict["token_ids"],
+                    generation_token_ids=generation_token_ids,
+                    generation_log_probs=generation_log_probs,
+                )
+            )
+
+            # Clean up the duplicated information
+            choice_dict.pop("logprobs")
+            # TODO add this when NeMo RL upgrades to vLLM 0.10.2 support for prompt token ids
+            # chat_completion_dict.pop("prompt_token_ids")
+            # choice_dict.pop("token_ids")
+
+        return NeMoGymChatCompletion.model_validate(chat_completion_dict)
+
+
+if __name__ == "__main__":
+    SGLangModel.run_webserver()
diff --git a/responses_api_models/sglang_model/pyproject.toml b/responses_api_models/sglang_model/pyproject.toml
new file mode 100644
index 000000000..6b82c62c2
--- /dev/null
+++ b/responses_api_models/sglang_model/pyproject.toml
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+[project]
+name = "sglang-model"
+version = "0.2.0rc0"
+requires-python = ">=3.12"
+dependencies = [
+    "nemo-gym[dev]",
+
+    # We specifically pin the sglang dependency because we have tested on this version.
+    # Updated Thu Dec 11 2025 with sglang[all]==0.5.2
+    # License: Apache 2.0 https://github.com/sgl-project/sglang/blob/62a4a339ebc1b2a9ecf5deac10ebf1de9108bca3/LICENSE
+    "sglang[all]==0.5.2",
+]
+
+[build-system]
+build-backend = "setuptools.build_meta"
+requires = ["setuptools>=61", "setuptools-scm"]
+
+[tool.setuptools.packages.find]
+where = [".."]
+include = ["sglang_model"]
+
+[tool.uv.sources]
+nemo-gym = { path = "../..", editable = true }
diff --git a/responses_api_models/vllm_model/app.py b/responses_api_models/vllm_model/app.py
index b263dd965..429495c26 100644
--- a/responses_api_models/vllm_model/app.py
+++ b/responses_api_models/vllm_model/app.py
@@ -12,11 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
+import os
 import re
-from time import time
-from typing import ClassVar, Dict, List, Optional, Tuple, Union
+import urllib.request
+from multiprocessing import Process
+from time import sleep, time
+from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
 from uuid import uuid4
 
+import ray
 from aiohttp.client_exceptions import ClientResponseError
 from fastapi import Request
 from pydantic import BaseModel, Field
@@ -26,6 +31,7 @@
     Body,
     SimpleResponsesAPIModel,
 )
+from nemo_gym.global_config import find_open_port
 from nemo_gym.openai_utils import (
     RESPONSES_TO_TRAIN,
     NeMoGymAsyncOpenAI,
@@ -54,6 +60,10 @@
     NeMoGymSummary,
     TokenIDLogProbMixin,
 )
+from nemo_gym.ray_utils import (
+    lookup_current_ray_node_ip,
+    spinup_single_ray_gpu_node_worker,
+)
 from nemo_gym.server_utils import SESSION_ID_KEY
 
 
@@ -66,23 +76,189 @@ class VLLMModelConfig(BaseResponsesAPIModelConfig):
     uses_reasoning_parser: bool
     replace_developer_role_with_system: bool = False
 
+    spinup_server: bool = False
+    server_args: Optional[Dict[str, Any]] = None
+
+    router_dp_size: int = 1
+
     def model_post_init(self, context):
         if isinstance(self.base_url, str):
             self.base_url = [self.base_url]
         return super().model_post_init(context)
 
 
+def _start_vllm_server(config: VLLMModelConfig, server_host: str, server_port: int, router_dp_rank: int) -> None:
+    import uvloop
+    import vllm.engine.arg_utils
+    import vllm.entrypoints.openai.api_server
+    import vllm.entrypoints.openai.cli_args
+    import vllm.utils.argparse_utils
+
+    argv = []
+    argv.append("--model")
+    argv.append(config.model)
+    argv.append("--host")
+    argv.append(server_host)
+    argv.append("--port")
+    argv.append(f"{server_port}")
+    argv.append("--distributed-executor-backend")
+    argv.append("mp")
+    for k, v in (config.server_args or {}).items():
+        k2 = k.replace("_", "-")
+        if v is None:
+            pass
+        elif isinstance(v, bool):
+            if not v:
+                arg_key = f"--no-{k2}"
+            else:
+                arg_key = f"--{k2}"
+            argv.append(arg_key)
+        elif isinstance(v, dict):
+            # Dict values must be passed as JSON strings to the vLLM CLI
+            arg_key = f"--{k2}"
+            argv.append(arg_key)
+            argv.append(json.dumps(v))
+        else:
+            arg_key = f"--{k2}"
+            argv.append(arg_key)
+            argv.append(f"{v}")
+
+    server_args = vllm.utils.argparse_utils.FlexibleArgumentParser()
+    server_args = vllm.entrypoints.openai.cli_args.make_arg_parser(server_args)
+    server_args = server_args.parse_args(argv)
+    vllm.entrypoints.openai.cli_args.validate_parsed_serve_args(server_args)
+
+    uvloop.run(vllm.entrypoints.openai.api_server.run_server(server_args))
+
+
+@ray.remote
+class VLLMServerSpinupWorker:
+    def __init__(self, config: VLLMModelConfig, working_dir: Optional[str], router_dp_rank: int):
+        self.config = config
+        self.working_dir = working_dir
+        self.router_dp_rank = router_dp_rank
+        self._server_host = lookup_current_ray_node_ip()
+        self._server_port = find_open_port()
+
+        if self.working_dir is not None:
+            os.chdir(self.working_dir)
+
+        server_proc = Process(
+            target=_start_vllm_server,
+            args=(
+                self.config,
+                self._server_host,
+                self._server_port,
+                self.router_dp_rank,
+            ),
+            daemon=False,
+        )
+        server_proc.start()
+        self._server_proc = server_proc
+
+    def _get_ip(self) -> str:
+        return self._server_host
+
+    def _get_port(self) -> int:
+        return self._server_port
+
+
+# Use this to query the vLLM servers during spinup without having to start an
+# asyncio event loop for the async client.
+def _vllm_server_heartbeat(base_url: str):
+    req_headers = {
+        "Content-Type": "application/json",
+        "Accept": "application/json",
+    }
+    req_body = {
+        "messages": [
+            {
+                "role": "user",
+                "content": "hi",
+            }
+        ],
+        "max_tokens": 8,
+        "temperature": 1.0,
+    }
+    req_data = json.dumps(req_body).encode("utf-8")
+    req_url = f"{base_url}/chat/completions"
+    req = urllib.request.Request(
+        req_url,
+        headers=req_headers,
+        data=req_data,
+    )
+    with urllib.request.urlopen(req, timeout=5) as out:
+        out_status = out.status
+        out_data = out.read()
+        output = out_data.decode("utf-8")
+    return {
+        "_status": out_status,
+        "output": output,
+        "except": None,
+    }
+
+
 class VLLMModel(SimpleResponsesAPIModel):
     config: VLLMModelConfig
 
     def model_post_init(self, context):
-        self._clients = [
-            NeMoGymAsyncOpenAI(
-                base_url=base_url,
-                api_key=self.config.api_key,
-            )
-            for base_url in self.config.base_url
-        ]
+        working_dir = os.getcwd()
+
+        if self.config.spinup_server:
+            self._server_urls = []
+            self._server_workers = []
+            self._clients = []
+
+            # TODO: support for other parallel sizes.
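+            # Only tensor parallelism within a single node is supported for now:
+            # data_parallel_size must stay 1, and `router_dp_size` controls how many
+            # independent single-node vLLM servers are launched below.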
+            server_tp_size = (self.config.server_args or {}).get("tensor_parallel_size", 1)
+            server_dp_size = (self.config.server_args or {}).get("data_parallel_size", 1)
+
+            assert server_dp_size == 1
+
+            router_dp_size = max(1, self.config.router_dp_size)
+
+            for router_dp_rank in range(router_dp_size):
+                server_worker = spinup_single_ray_gpu_node_worker(
+                    VLLMServerSpinupWorker,
+                    server_tp_size,
+                    config=self.config,
+                    working_dir=working_dir,
+                    router_dp_rank=router_dp_rank,
+                )
+
+                server_ip = ray.get(server_worker._get_ip.remote())
+                server_port = ray.get(server_worker._get_port.remote())
+                server_url = f"http://{server_ip}:{server_port}/v1"
+
+                self._server_urls.append(server_url)
+                self._server_workers.append(server_worker)
+
+                self._clients.append(
+                    NeMoGymAsyncOpenAI(
+                        base_url=server_url,
+                        api_key=self.config.api_key,
+                    )
+                )
+
+            for server_url in self._server_urls:
+                while True:
+                    try:
+                        _vllm_server_heartbeat(server_url)
+                        break
+                    except Exception:
+                        sleep(3)
+                        continue
+
+        else:
+            self._server_urls = None
+            self._server_workers = None
+            self._clients = [
+                NeMoGymAsyncOpenAI(
+                    base_url=base_url,
+                    api_key=self.config.api_key,
+                )
+                for base_url in self.config.base_url
+            ]
 
         self._session_id_to_client: Dict[str, NeMoGymAsyncOpenAI] = dict()
 
diff --git a/tests/unit_tests/test_server_utils.py b/tests/unit_tests/test_server_utils.py
index dfd39da65..7313db752 100644
--- a/tests/unit_tests/test_server_utils.py
+++ b/tests/unit_tests/test_server_utils.py
@@ -194,7 +194,9 @@ def test_initialize_ray_with_address(self, monkeypatch: MonkeyPatch) -> None:
 
         ray_is_initialized_mock.assert_called_once()
         get_global_config_dict_mock.assert_called_once()
-        ray_init_mock.assert_called_once_with(address="ray://test-address:10001", ignore_reinit_error=True)
+        ray_init_mock.assert_called_once_with(
+            address="ray://test-address:10001", ignore_reinit_error=True, namespace="nemo_gym"
+        )
 
     def test_initialize_ray_without_address(self, monkeypatch: MonkeyPatch) -> None:
         ray_is_initialized_mock = self._mock_ray_return_value(monkeypatch, False)
@@ -217,5 +219,5 @@ def test_initialize_ray_without_address(self, monkeypatch: MonkeyPatch) -> None:
 
         ray_is_initialized_mock.assert_called_once()
         get_global_config_dict_mock.assert_called_once()
-        ray_init_mock.assert_called_once_with(ignore_reinit_error=True)
+        ray_init_mock.assert_called_once_with(ignore_reinit_error=True, namespace="nemo_gym")
         ray_get_runtime_context_mock.assert_called_once()