From d9cf489219cb0f3532f1a806a091aa482f17efef Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Sun, 23 Nov 2025 17:25:51 +0000 Subject: [PATCH 01/29] sglang support:initial commit Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- nemo_rl/models/generation/sglang/__init__.py | 0 nemo_rl/models/generation/sglang/config.py | 91 ++++++ .../generation/sglang/sglang_generation.py | 297 ++++++++++++++++++ .../models/generation/sglang/sglang_worker.py | 260 +++++++++++++++ 4 files changed, 648 insertions(+) create mode 100644 nemo_rl/models/generation/sglang/__init__.py create mode 100644 nemo_rl/models/generation/sglang/config.py create mode 100644 nemo_rl/models/generation/sglang/sglang_generation.py create mode 100644 nemo_rl/models/generation/sglang/sglang_worker.py diff --git a/nemo_rl/models/generation/sglang/__init__.py b/nemo_rl/models/generation/sglang/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_rl/models/generation/sglang/config.py b/nemo_rl/models/generation/sglang/config.py new file mode 100644 index 0000000000..12e99ad82b --- /dev/null +++ b/nemo_rl/models/generation/sglang/config.py @@ -0,0 +1,91 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, NotRequired, TypedDict + +from nemo_rl.models.generation.interfaces import GenerationConfig + + +class SGLangConfig(): + """Configuration for SGLang runtime. 
Refer to: + https://github.com/sgl-project/sglang for detailed documentation. + """ + + model_path: str = "" + random_seed: int = 1 + skip_tokenizer_init: bool = False + disable_cuda_graph: bool = False + disable_radix_cache: bool = True + disable_cuda_graph_padding: bool = False + enable_nccl_nvls: bool = False + disable_outlines_disk_cache: bool = False + disable_custom_all_reduce: bool = False + disable_overlap_schedule: bool = False + enable_mixed_chunk: bool = False + enable_dp_attention: bool = False + enable_ep_moe: bool = False + enable_torch_compile: bool = False + torch_compile_max_bs: int = 32 + cuda_graph_max_bs: int | None = None + cuda_graph_bs: list[int] | None = None + torchao_config: str = "" + enable_nan_detection: bool = False + enable_p2p_check: bool = False + triton_attention_reduce_in_fp32: bool = False + triton_attention_num_kv_splits: int = 8 + num_continuous_decode_steps: int = 1 + enable_memory_saver: bool = False + allow_auto_truncate: bool = False + attention_backend: str | None = "fa3" + enable_multimodal: bool = False + sampling_backend: str | None = None + context_length: int | None = 32768 + mem_fraction_static: float | None = 0.9 + max_running_requests: int | None = None + # NOTE: chunked_prefill_size is by default 8192 on GPUs with 80GB mem in SGLang, + # but we disable it to avoid precision issues + chunked_prefill_size: int | None = -1 + max_prefill_tokens: int = 32768 + schedule_policy: str = "lpm" + schedule_conservativeness: float = 1.0 + cpu_offload_gb: int = 0 + dtype: str = "bfloat16" + kv_cache_dtype: str = "auto" + dp_size: int = 1 # only used for dp attention + ep_size: int = 1 + # lora + enable_lora: bool | None = None + max_lora_rank: int | None = None + lora_target_modules: list[str] | None = None + lora_paths: list[str] | None = None + max_loaded_loras: int = 1 + max_loras_per_batch: int = 1 + lora_backend: str = "triton" + # logging + log_level: str = "warning" + log_level_http: str | None = "warning" + 
log_requests: bool = False + log_requests_level: int = 0 + show_time_cost: bool = False + enable_metrics: bool = True # Exports Prometheus-like metrics + # The interval (in decoding iterations) to log throughput + # and update prometheus metrics + decode_log_interval: int = 1 + # Extra loader arguments + # NOTE: These arguments will be parsed into a dict json-string + # and passed as `model_loader_extra_config` to SGLang. + enable_multithread_load: bool = False + enable_fast_load: bool = False + + \ No newline at end of file diff --git a/nemo_rl/models/generation/sglang/sglang_generation.py b/nemo_rl/models/generation/sglang/sglang_generation.py new file mode 100644 index 0000000000..f4bc4433f7 --- /dev/null +++ b/nemo_rl/models/generation/sglang/sglang_generation.py @@ -0,0 +1,297 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import os +from collections import defaultdict +from typing import ( + Any, + AsyncGenerator, + Optional, + Union, +) + +import numpy as np +import ray +from ray.util.placement_group import PlacementGroup + +from nemo_rl.distributed.batched_data_dict import BatchedDataDict, SlicedDataDict +from nemo_rl.distributed.named_sharding import NamedSharding +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup +from nemo_rl.models.generation.interfaces import ( + GenerationDatumSpec, + GenerationInterface, + GenerationOutputSpec, +) +from nemo_rl.models.generation.sglang.config import SGLangConfig + +# Global thresholds for top_k and top_p validation. +# While top-k/p are not supported, these values allow for token filtering while the logprobs should be compatible. +# See https://github.com/NVIDIA-NeMo/RL/issues/69 and https://github.com/NVIDIA-NeMo/RL/issues/237 for more details. +TOP_K_THRESHOLD = 8000 # Allow top_k >= 8000 (effectively no filtering) +TOP_P_THRESHOLD = 0.99 # Allow top_p >= 0.99 (close to 1.0) + + +class SGLangGeneration(GenerationInterface): + def __init__( + self, + cluster: RayVirtualCluster, + config: SGLangConfig, + name_prefix: str = "sglang_policy", + workers_per_node: Optional[Union[int, list[int]]] = None, + ): + """Initialize a SGLang policy with distributed workers. + + SGLang server manages TP/PP internally, but we still need to: + 1. Manage data parallel distribution across multiple servers + 2. Assign GPU bundles to each server + + Each server will see logical GPUs 0-N (via CUDA_VISIBLE_DEVICES set by Ray), + so we just need to tell SGLang how many GPUs to use (tp_size). 
+ """ + # Store config + self.cfg = config + + # Get number of GPUs per server from config + # For SGLang, this is typically the tensor parallel size + # TODO: Add proper config field, hardcoded to 4 for now + gpus_per_server = self.cfg.get("gpus_per_server", None) + if gpus_per_server is None: + gpus_per_server = 4 + + # Calculate number of servers based on available resources + total_gpus = cluster.world_size() + num_servers = total_gpus // gpus_per_server + + if num_servers == 0: + raise ValueError( + f"Not enough GPUs. Need at least {gpus_per_server} GPUs per server, " + f"but only have {total_gpus} GPUs total." + ) + + if total_gpus % gpus_per_server != 0: + print( + f"[WARNING] Total GPUs ({total_gpus}) is not divisible by GPUs per server ({gpus_per_server}). " + f"Will use {num_servers} servers, leaving {total_gpus % gpus_per_server} GPUs unused." + ) + + self.dp_size = num_servers + self.gpus_per_server = gpus_per_server + + # Create sharding annotations with only data_parallel dimension + # Each server is independent, so we only need DP sharding + self.sharding_annotations = NamedSharding( + layout=np.arange(num_servers).reshape(num_servers), + names=["data_parallel"], + ) + + # Initialize placement groups + # For SGLang, we use PACK strategy to keep bundles together + strategy = None if self.cfg.get("colocated", {}).get("enabled", False) else "PACK" + cluster._init_placement_groups( + strategy=strategy, + use_unified_pg=False, # SGLang servers don't need cross-node model parallelism + ) + + # Create worker builder for SGLangGenerationWorker + worker_cls = "nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker" + worker_builder = RayWorkerBuilder(worker_cls, config) + + env_vars = {} + + # Allocate bundles for each server + # Each server gets consecutive bundles + bundle_indices_list = self._allocate_bundles_for_servers( + cluster, num_servers, gpus_per_server + ) + + # Create worker group with explicit bundle allocation + self.worker_group 
= RayWorkerGroup( + cluster, + worker_builder, + name_prefix=name_prefix, + bundle_indices_list=bundle_indices_list, + sharding_annotations=self.sharding_annotations, + env_vars=env_vars, + ) + + # Verify data parallel size matches + assert self.dp_size == self.worker_group.dp_size, ( + f"Data parallel size mismatch. Expected {self.dp_size}, got {self.worker_group.dp_size}" + ) + + # Used to track the round-robin selection of worker groups for generate_async + self.current_generate_dp_shard_idx = 0 + + def _allocate_bundles_for_servers( + self, + cluster: RayVirtualCluster, + num_servers: int, + gpus_per_server: int, + ) -> list[tuple[int, list[int]]]: + """Allocate GPU bundles to each SGLang server. + + Each server gets consecutive bundles within the same placement group (node). + Ray will automatically set CUDA_VISIBLE_DEVICES so each server sees logical GPUs 0, 1, 2, ..., gpus_per_server-1. + + Args: + cluster: The Ray virtual cluster + num_servers: Total number of SGLang servers to create + gpus_per_server: Number of GPUs each server needs + + Returns: + List of (node_idx, [bundle_indices]) tuples for each server + """ + placement_groups = cluster.get_placement_groups() + + if not placement_groups: + raise ValueError("No placement groups available in the cluster") + + bundle_indices_list = [] + + # Each server's bundles must be within the same placement group (node) + server_idx = 0 + for pg_idx, pg in enumerate(placement_groups): + if pg.bundle_count == 0: + continue + + # Calculate how many servers can fit in this placement group + num_servers_in_pg = pg.bundle_count // gpus_per_server + + # Allocate servers within this placement group + for local_server_idx in range(num_servers_in_pg): + if server_idx >= num_servers: + break + + # Calculate which bundles this server gets (consecutive within the PG) + start_bundle = local_server_idx * gpus_per_server + server_bundles = list(range(start_bundle, start_bundle + gpus_per_server)) + + # Each server gets a tuple of 
(node_idx, [local_bundle_indices]) + bundle_indices_list.append((pg_idx, server_bundles)) + server_idx += 1 + + if server_idx >= num_servers: + break + + if len(bundle_indices_list) < num_servers: + total_available = sum( + pg.bundle_count // gpus_per_server + for pg in placement_groups + if pg.bundle_count > 0 + ) + raise ValueError( + f"Not enough bundles to allocate all {num_servers} servers. " + f"Only {total_available} servers can be allocated " + f"(each server needs {gpus_per_server} GPUs)." + ) + + return bundle_indices_list + + + def generate( + self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + ) -> BatchedDataDict[GenerationOutputSpec]: + """Generate a batch of data using SGLang.""" + assert isinstance(data, BatchedDataDict), ( + f"data must be a BatchedDataDict, got type: {type(data)}" + ) + assert "input_ids" in data and "input_lengths" in data, ( + "input_ids and input_lengths are required in data for SGLang generation" + ) + + # Shard the data across the data parallel servers + dp_size = self.sharding_annotations.get_axis_size("data_parallel") + sharded_data: list[SlicedDataDict] = data.shard_by_batch_size( + dp_size, allow_uneven_shards=True + ) + future_bundle = self.worker_group.run_all_workers_sharded_data( + "generate", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=None, + output_is_replicated=None, + common_kwargs={"greedy": greedy}, + ) + + # Get results from the workers + results = self.worker_group.get_all_worker_results(future_bundle) + + # Combine results from all servers + combined: BatchedDataDict[GenerationOutputSpec] = BatchedDataDict.from_batches( + results, pad_value_dict={"output_ids": self.cfg["_pad_token_id"]} + ) + + # Verify the output has all required fields + required_keys = [ + "output_ids", + "generation_lengths", + "unpadded_sequence_lengths", + "logprobs", + ] + missing_keys = [key for key in required_keys if key not in combined] + if missing_keys: + raise ValueError( + 
f"Missing required keys for GenerationOutputSpec: {missing_keys}" + ) + + return combined + + def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: + """Wake workers up for colocated inference.""" + pass + + def finish_generation(self, *args: Any, **kwargs: Any) -> bool: + """Sleep workers and reset prefix cache.""" + pass + + def shutdown(self) -> bool: + """Shut down all SGLang workers and clean up resources.""" + try: + # Use the worker group's shutdown method with the worker's cleanup method + return self.worker_group.shutdown(cleanup_method="shutdown") + except Exception as e: + print(f"Error during SGLang policy shutdown: {e}") + return False + + def __del__(self) -> None: + """Shuts down the worker groups when the object is deleted or is garbage collected. + + This is an extra safety net in case the user forgets to call shutdown() and the pointer to + the object is lost due to leaving a function scope. It's always recommended that the + user calls shutdown(). + """ + self.shutdown() + + def invalidate_kv_cache(self) -> bool: + """Invalidate KV cache after weight updates. + + For SGLang, this might need to call a different method or might not be needed + if the server handles it automatically. + """ + try: + # For SGLang, we can call a method on each worker if it exists + futures = [] + for worker in self.worker_group.workers: + if hasattr(worker, "invalidate_kv_cache"): + futures.append(worker.invalidate_kv_cache.remote()) + + if futures: + results = ray.get(futures) + return all(result for result in results if result is not None) + return True + except Exception as e: + print(f"Error invalidating SGLang caches: {e}") + return False diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py new file mode 100644 index 0000000000..2ea03f5e63 --- /dev/null +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -0,0 +1,260 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import gc +import os +import sys +from typing import Any, Optional, cast +import requests + +import time +import ray +import torch +import multiprocessing + +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches +from nemo_rl.models.generation.interfaces import ( + GenerationDatumSpec, + GenerationOutputSpec, + verify_right_padding, +) +from nemo_rl.models.generation.sglang.config import SGLangConfig +from nemo_rl.models.huggingface.common import ModelFlag +from nemo_rl.utils.nsys import wrap_with_nvtx_name + +try: + from sglang.srt.entrypoints.http_server import launch_server + from sglang.srt.server_args import ServerArgs + from sglang.srt.utils import kill_process_tree +except ImportError: + # SGLang may not be installed, but we still want the code to be importable + launch_server = None + ServerArgs = None + kill_process_tree = None + + + + +@ray.remote( + runtime_env={**get_nsight_config_if_pattern_matches("sglang_generation_worker")} +) # pragma: no cover +class SGLangGenerationWorker: + def __repr__(self) -> str: + """Customizes the actor's prefix in the Ray logs. + + This makes it easier to identify which worker is producing specific log messages. 
+ """ + return f"{self.__class__.__name__}" + + @staticmethod + def configure_worker( + num_gpus: int | float, bundle_indices: Optional[tuple[int, list[int]]] = None + ) -> tuple[dict[str, Any], dict[str, str], dict[str, Any]]: + """Provides complete worker configuration for SGLang server. + + This method configures the worker based on bundle_indices which tells us + how many GPUs this server should use. + + Args: + num_gpus: Original GPU allocation for this worker based on the placement group + bundle_indices: Tuple of (node_idx, local_bundle_indices) for this server + + Returns: + tuple with complete worker configuration: + - 'resources': Resource allocation (e.g., num_gpus) + - 'env_vars': Environment variables for this worker + - 'init_kwargs': Parameters to pass to __init__ of the worker + """ + # Initialize configuration + resources: dict[str, Any] = {"num_gpus": num_gpus} + init_kwargs: dict[str, Any] = {} + env_vars: dict[str, str] = {} + + local_bundle_indices = None + if bundle_indices is not None: + node_idx = bundle_indices[0] + local_bundle_indices = bundle_indices[1] + init_kwargs["bundle_indices"] = local_bundle_indices + + # Calculate a unique seed from node_idx and bundle_indices + if len(local_bundle_indices) == 1: + seed = node_idx * 1024 + local_bundle_indices[0] + else: + bundle_id = local_bundle_indices[0] // len(local_bundle_indices) + seed = node_idx * 1024 + bundle_id + + init_kwargs["seed"] = seed + + # For SGLang, Ray manages GPU assignment via CUDA_VISIBLE_DEVICES + # We set num_gpus to 0 and let Ray handle it + if local_bundle_indices is not None and len(local_bundle_indices) > 1: + resources["num_gpus"] = 0 + env_vars["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" + init_kwargs["fraction_of_gpus"] = num_gpus + + return resources, env_vars, init_kwargs + + def __init__( + self, + config: SGLangConfig, + bundle_indices: Optional[list[int]] = None, + fraction_of_gpus: float = 1.0, + seed: Optional[int] = None, + ): + """Initialize a 
SGLang worker for distributed inference. + + Args: + config: Configuration dictionary for the policy + bundle_indices: List of local bundle indices for this server. + The length of this list determines tp_size (number of GPUs per server). + Only needed for the first worker in each server group (model owner). + fraction_of_gpus: Fraction of GPUs to use for this worker + seed: Random seed for initialization + """ + self.cfg = config + self.is_model_owner = bundle_indices is not None + + if not self.is_model_owner: + return + + # Determine tp_size from bundle_indices length + # Ray sets CUDA_VISIBLE_DEVICES so each server sees logical GPUs 0, 1, 2, ..., tp_size-1 + tp_size = len(bundle_indices) if bundle_indices else 1 + + # Build SGLang server arguments + # Ray automatically sets CUDA_VISIBLE_DEVICES, so base_gpu_id should be 0 + # and gpu_id_step should be 1 + kwargs = { + "model_path": self.cfg.get("model_path", ""), + "trust_remote_code": True, + "random_seed": seed if seed is not None else self.cfg.get("random_seed", 1), + # Memory settings + "enable_memory_saver": self.cfg.get("enable_memory_saver", False), + # GPU settings - Ray handles CUDA_VISIBLE_DEVICES, so we use logical GPU 0 + "gpu_id_step": 1, + "base_gpu_id": 0, # Always 0 because Ray sets CUDA_VISIBLE_DEVICES + # Parallel settings + "tp_size": tp_size, + "dp_size": self.cfg.get("dp_size", 1), + "pp_size": self.cfg.get("pp_size", 1), + "ep_size": self.cfg.get("ep_size", 1), + # Always skip warmup to prevent warmup timeout + "skip_server_warmup": True, + } + + # Add other config fields if they exist + for key in [ + "dtype", "kv_cache_dtype", "context_length", "max_running_requests", + "chunked_prefill_size", "max_prefill_tokens", "schedule_policy", + "schedule_conservativeness", "cpu_offload_gb", "log_level", + ]: + if key in self.cfg: + kwargs[key] = self.cfg[key] + + server_args = ServerArgs(**kwargs) + self.server_process = self._launch_server_process(server_args) + + + def _merge_stop_strings(self, 
batch_stop_strings):
+        pass
+
+    def _build_sampling_params(
+        self,
+        *,
+        greedy: bool,
+        stop_strings,
+        max_new_tokens: Optional[int] = None,
+    ):
+        pass
+
+    def _launch_server_process(self, server_args: ServerArgs) -> multiprocessing.Process:
+        """Launch the SGLang server process and wait for it to be ready."""
+        p = multiprocessing.Process(target=launch_server, args=(server_args,))
+        p.start()
+
+        if server_args.node_rank != 0:
+            return
+
+        base_url = server_args.url()
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {server_args.api_key}",
+        }
+
+        with requests.Session() as session:
+            while True:
+                try:
+                    response = session.get(f"{base_url}/health_generate", headers=headers)
+                    if response.status_code == 200:
+                        break
+                except requests.RequestException:
+                    pass
+
+                if not p.is_alive():
+                    raise Exception("Server process terminated unexpectedly.")
+
+                time.sleep(2)
+        return p
+
+
+
+
+    @wrap_with_nvtx_name("sglang_generation_worker/generate")
+    def generate(
+        self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
+    ) -> BatchedDataDict[GenerationOutputSpec]:
+        """Generate a batch of data using SGLang generation.
+
+        Args:
+            data: BatchedDataDict containing input_ids and input_lengths tensors
+            greedy: Whether to use greedy decoding instead of sampling
+
+        Returns:
+            BatchedDataDict conforming to GenerationOutputSpec:
+            - output_ids: input + generated token IDs with proper padding
+            - logprobs: Log probabilities for tokens
+            - generation_lengths: Lengths of each response
+            - unpadded_sequence_lengths: Lengths of each input + generated sequence
+        """
+        pass
+
+    def sleep(self):
+        pass
+
+    def wake_up(self, **kwargs):
+        pass
+
+    def shutdown(self) -> bool:
+        pass
+
+    def _make_request(self, endpoint: str, payload: Optional[dict] = None):
+        """Make a POST request to the specified endpoint with the given payload.
+ + Args: + endpoint: The API endpoint to call + payload: The JSON payload to send (default: empty dict) + + Returns: + The JSON response from the server + """ + if self.node_rank != 0: + return + + url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}" + response = requests.post(url, json=payload or {}) + response.raise_for_status() + return response.json() \ No newline at end of file From 3eace5f4d64c3f9651c54f02bbb8d0492abf66e7 Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Mon, 24 Nov 2025 01:43:07 +0000 Subject: [PATCH 02/29] sglang:manually set cuda visible to let localran=0 to manage gpus of a server Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- .../generation/sglang/sglang_generation.py | 14 +++++--- .../models/generation/sglang/sglang_worker.py | 34 ++++++++++++++----- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/nemo_rl/models/generation/sglang/sglang_generation.py b/nemo_rl/models/generation/sglang/sglang_generation.py index f4bc4433f7..d920935870 100644 --- a/nemo_rl/models/generation/sglang/sglang_generation.py +++ b/nemo_rl/models/generation/sglang/sglang_generation.py @@ -89,12 +89,16 @@ def __init__( self.dp_size = num_servers self.gpus_per_server = gpus_per_server - - # Create sharding annotations with only data_parallel dimension - # Each server is independent, so we only need DP sharding + + # Create sharding annotations + # Even though SGLang manages TP internally, we include it in the layout to support + # RayWorkerGroup's worker management (which creates one worker per GPU bundle). + # The TP dimension becomes a "free axis" in run_all_workers_sharded_data, ensuring + # only the primary workers (TP rank 0) are called. 
+ total_workers = num_servers * gpus_per_server self.sharding_annotations = NamedSharding( - layout=np.arange(num_servers).reshape(num_servers), - names=["data_parallel"], + layout=np.arange(total_workers).reshape(num_servers, gpus_per_server), + names=["data_parallel", "tensor_parallel"], ) # Initialize placement groups diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index 2ea03f5e63..1c6caa1ab6 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -98,9 +98,15 @@ def configure_worker( init_kwargs["seed"] = seed - # For SGLang, Ray manages GPU assignment via CUDA_VISIBLE_DEVICES - # We set num_gpus to 0 and let Ray handle it - if local_bundle_indices is not None and len(local_bundle_indices) > 1: + # Check if this worker is part of a parallel group (multiple GPUs per server). + # A worker with local rank =0 owns the server(local_bundle_indices is not None ) + # otherwise it is a placeholder for Ray's resource management (local_bundle_indices is None). 
+ is_part_of_parallel_workers = ( + local_bundle_indices is not None and len(local_bundle_indices) > 1 + ) or local_bundle_indices is None + + if is_part_of_parallel_workers: + # For parallel workers, we manage GPU assignment manually via CUDA_VISIBLE_DEVICES resources["num_gpus"] = 0 env_vars["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" init_kwargs["fraction_of_gpus"] = num_gpus @@ -126,17 +132,27 @@ def __init__( """ self.cfg = config self.is_model_owner = bundle_indices is not None - + + # Only the primary worker (local_rank=0) in each server group starts the SGLang server + # Secondary workers (local_rank!=0) just returns if not self.is_model_owner: return + # Set CUDA_VISIBLE_DEVICES to allow SGLang server to see the correct GPUs + # bundle_indices contains the node-local GPU indices (e.g., [0,1,2,3] or [4,5,6,7]) + # Since we set RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1, Ray won't override this + gpu_ids = ",".join(str(idx) for idx in bundle_indices) + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids + # Determine tp_size from bundle_indices length - # Ray sets CUDA_VISIBLE_DEVICES so each server sees logical GPUs 0, 1, 2, ..., tp_size-1 - tp_size = len(bundle_indices) if bundle_indices else 1 - + tp_size = len(bundle_indices) + + print( + f"[SGLang Server] Node {os.environ.get('NODE_RANK', '?')}: " + f"Setting CUDA_VISIBLE_DEVICES={gpu_ids} (tp_size={tp_size})" + ) + # Build SGLang server arguments - # Ray automatically sets CUDA_VISIBLE_DEVICES, so base_gpu_id should be 0 - # and gpu_id_step should be 1 kwargs = { "model_path": self.cfg.get("model_path", ""), "trust_remote_code": True, From 6fbbbb741e680ee2d020d9f73063aa831a8f7e9d Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Tue, 25 Nov 2025 21:14:33 +0000 Subject: [PATCH 03/29] sglang: add sglang setup in grpo.py, add find available port to set up servers Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- nemo_rl/algorithms/grpo.py | 59 ++++++++++++++++++ .../ray_actor_environment_registry.py 
| 4 ++ nemo_rl/distributed/virtual_cluster.py | 2 + nemo_rl/models/generation/sglang/__init__.py | 23 +++++++ .../generation/sglang/sglang_generation.py | 18 ++++++ .../models/generation/sglang/sglang_worker.py | 61 +++++++++++-------- 6 files changed, 143 insertions(+), 24 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index d79b6d2fac..ab0033575b 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -62,6 +62,7 @@ ) from nemo_rl.models.generation.interfaces import GenerationInterface from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration +from nemo_rl.models.generation.sglang import SGLangConfig, SGLangGeneration from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface from nemo_rl.models.policy.lm_policy import Policy @@ -482,6 +483,13 @@ def init_vllm(): pg.finish_generation() return pg, time.perf_counter() - t0 + def init_sglang(): + """Initialize SGLang generation workers.""" + t0 = time.perf_counter() + pg = SGLangGeneration(cluster=inference_cluster, config=generation_config) + pg.finish_generation() + return pg, time.perf_counter() - t0 + # Handle backend-specific setup if backend == "megatron": # Megatron backend: policy_generation is None, only initialize policy @@ -568,6 +576,57 @@ def init_vllm(): flush=True, ) + elif backend == "sglang": + # Set model_name and model_path + generation_config["model_name"] = policy_config["model_name"] + if "model_path" not in generation_config or not generation_config.get("model_path"): + generation_config["model_path"] = policy_config["model_name"] + + # Determine if parallel initialization is possible (non-colocated mode) + use_parallel_init = not colocated_inference + + if use_parallel_init: + # Parallel initialization: SGLang and Policy can initialize simultaneously + print( + " ⚡ Using parallel worker initialization (non-colocated mode)", + flush=True, + ) + + # Execute both 
initializations in parallel + parallel_start_time = time.perf_counter() + with ThreadPoolExecutor(max_workers=2) as executor: + sglang_future = executor.submit(init_sglang) + policy_future = executor.submit(init_policy) + policy_generation, sglang_time = sglang_future.result() + policy, policy_time = policy_future.result() + parallel_wall_time = time.perf_counter() - parallel_start_time + + # Store timing metrics + worker_init_timing_metrics["sglang_init_time_s"] = sglang_time + worker_init_timing_metrics["policy_init_time_s"] = policy_time + worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time + worker_init_timing_metrics["parallel_init_enabled"] = True + + else: + # Sequential initialization: colocated mode (GPU memory requires SGLang first) + print( + " ⚙️ Using sequential worker initialization (colocated mode)", + flush=True, + ) + + # Initialize SGLang first (clean GPU memory), then policy + policy_generation, sglang_time = init_sglang() + worker_init_timing_metrics["sglang_init_time_s"] = sglang_time + + policy, policy_time = init_policy() + worker_init_timing_metrics["policy_init_time_s"] = policy_time + worker_init_timing_metrics["parallel_init_enabled"] = 0.0 + + print( + f" ✓ Using SGLang backend for generation with {policy_config['model_name']}", + flush=True, + ) + # Record when worker initialization completes (for calculating other setup time) worker_init_complete_time = time.perf_counter() - setup_start_time diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py index 8d233185a4..fb95d73e95 100644 --- a/nemo_rl/distributed/ray_actor_environment_registry.py +++ b/nemo_rl/distributed/ray_actor_environment_registry.py @@ -20,6 +20,9 @@ VLLM_EXECUTABLE = ( PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.VLLM ) +SGLANG_EXECUTABLE = ( + PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.SGLANG +) MCORE_EXECUTABLE = ( 
 PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.MCORE
 )
 
@@ -27,6 +30,7 @@
 ACTOR_ENVIRONMENT_REGISTRY: dict[str, str] = {
     "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker": VLLM_EXECUTABLE,
     "nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker": VLLM_EXECUTABLE,
+    "nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker": SGLANG_EXECUTABLE,
     # Temporary workaround for the coupled implementation of DTensorPolicyWorker and vLLM.
     # This will be reverted to PY_EXECUTABLES.BASE once https://github.com/NVIDIA-NeMo/RL/issues/501 is resolved.
     "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": VLLM_EXECUTABLE,
diff --git a/nemo_rl/distributed/virtual_cluster.py b/nemo_rl/distributed/virtual_cluster.py
index 3021b760e4..4c42054455 100644
--- a/nemo_rl/distributed/virtual_cluster.py
+++ b/nemo_rl/distributed/virtual_cluster.py
@@ -57,6 +57,8 @@ class PY_EXECUTABLES:
     # Use NeMo-Gym dependencies
     NEMO_GYM = f"uv run --locked --extra nemo_gym --directory {git_root}"
 
+    # Use NeMo-RL direct dependencies and SGLang.
+    SGLANG = f"uv run --locked --extra sglang --directory {git_root}"
 
 
 @ray.remote  # pragma: no cover
diff --git a/nemo_rl/models/generation/sglang/__init__.py b/nemo_rl/models/generation/sglang/__init__.py
index e69de29bb2..55ce57084d 100644
--- a/nemo_rl/models/generation/sglang/__init__.py
+++ b/nemo_rl/models/generation/sglang/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
def init_collective(
    self, ip: str, port: int, world_size: int, *, train_world_size: int
) -> list[ray.ObjectRef]:
    """Initialize the collective communication.

    No-op for the SGLang backend: NCCL-based weight refit is not wired up
    yet, so there are no worker futures to wait on.

    TODO: if weight updates via NCCL are needed in the future.
    """
    return []

def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
    """No-op: the SGLang backend does not consume refit metadata yet."""
    pass

def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]:
    """No-op IPC/ZMQ weight update; returns no pending Ray futures."""
    return []

def update_weights_from_collective(self) -> list[ray.ObjectRef]:
    """No-op collective weight update; returns no pending Ray futures."""
    return []
def _launch_server_process(self, server_args: ServerArgs) -> multiprocessing.Process:
    """Spawn the SGLang HTTP server in a child process and block until healthy.

    Polls ``/health_generate`` every 2 seconds; raises if the child process
    dies before it ever reports healthy.

    NOTE(review): there is no overall startup timeout — if the server stays
    alive but never becomes healthy this loops forever. Preserved as-is.
    """
    proc = multiprocessing.Process(target=launch_server, args=(server_args,))
    proc.start()

    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {server_args.api_key}",
    }
    health_url = f"{self.base_url}/health_generate"

    # Poll until the health endpoint answers 200; bail out early if the
    # child exits first.
    with requests.Session() as http:
        ready = False
        while not ready:
            try:
                reply = http.get(health_url, headers=headers)
                ready = reply.status_code == 200
            except requests.RequestException:
                ready = False
            if ready:
                print(f"[SGLang Server] Rank {self.global_rank} Server is ready at {self.base_url}")
            else:
                if not proc.is_alive():
                    raise Exception(f"[SGLang Server] Rank {self.global_rank} Server process terminated unexpectedly.")
                time.sleep(2)
    return proc
def generate( pass def sleep(self): + # TODO pass def wake_up(self, **kwargs): + # TODO pass def shutdown(self) -> bool: - pass + """Shutdown the SGLang server process. + + Returns: + bool: True if shutdown was successful, False otherwise + """ + if not self.is_model_owner: + return True + + if not hasattr(self, "server_process") or self.server_process is None: + return True + + try: + print( + f"[SGLang Worker] Rank {self.global_rank} Shutting down server at {self.base_url}..." + ) + + if self.server_process.is_alive(): + kill_process_tree(self.server_process.pid) + + # Wait for the process to terminate + self.server_process.join(timeout=5.0) + + if self.server_process.is_alive(): + return False + return True + + except Exception as e: + print( + f"[SGLang Worker] Rank {self.global_rank} Error during shutdown: {e}" + ) + return False def _make_request(self, endpoint: str, payload: Optional[dict] = None): """Make a POST request to the specified endpoint with the given payload. From a3d8ad6bb0d99fed592e03859f647526d4e7c7af Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Fri, 28 Nov 2025 18:17:03 +0000 Subject: [PATCH 05/29] sglang server: fix gpu allocation when tp =1 Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- nemo_rl/models/generation/sglang/sglang_worker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index 3442a42603..4ccba0f957 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -103,6 +103,8 @@ def configure_worker( resources["num_gpus"] = 0 env_vars["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" init_kwargs["fraction_of_gpus"] = num_gpus + else: + env_vars["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" return resources, env_vars, init_kwargs From 88971e3e4ed1c5b86203bd8f52d16e29d2a485ac Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Tue, 25 Nov 2025 23:53:05 
def _generate_single_sample(
    self,
    input_ids: list[int],
    sampling_params: dict[str, Any],
    stop_string: Optional[str] = None,
) -> tuple[list[int], list[float]]:
    """Generate one completion through the SGLang HTTP API.

    Args:
        input_ids: Prompt token IDs with padding stripped.
        sampling_params: Sampling options (temperature, top_p, max_new_tokens, ...).
        stop_string: Optional stop string applied to this sample only.

    Returns:
        Tuple ``(generated_tokens, logprobs)`` for the newly generated tokens.
    """
    # SGLang expects stop criteria inside sampling_params; build a copy so
    # the caller's dict (shared across the batch) is not mutated.
    params = sampling_params if stop_string is None else {**sampling_params, "stop": stop_string}

    payload = {
        "sampling_params": params,
        "return_logprob": True,
        "input_ids": input_ids,
    }

    print(f"[SGLang Worker] Rank {self.global_rank} payload: {payload}")
    response = self._make_request("generate", payload)

    # meta_info.output_token_logprobs holds (logprob, token_id, ...) entries.
    entries = response.get("meta_info", {}).get("output_token_logprobs", [])
    if not entries:
        # Server returned no per-token logprobs; report an empty generation.
        return [], []
    return [e[1] for e in entries], [e[0] for e in entries]

def generate(
    self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
) -> BatchedDataDict[GenerationOutputSpec]:
    """Generate text for a batch of right-padded prompts.

    NOTE(review): intermediate implementation — only sample 0 is actually
    sent to the server; rows 1..N-1 receive padded dummy outputs. A later
    revision fans out the whole batch.

    Returns:
        BatchedDataDict with "output_ids", "logprobs",
        "generation_lengths" and "unpadded_sequence_lengths".
    """
    # Degenerate empty batch: emit correctly-typed empty tensors.
    if len(data["input_ids"]) == 0:
        return BatchedDataDict[GenerationOutputSpec](
            {
                "output_ids": torch.zeros((0, 0), dtype=torch.long),
                "logprobs": torch.zeros((0, 0), dtype=torch.float),
                "generation_lengths": torch.zeros(0, dtype=torch.long),
                "unpadded_sequence_lengths": torch.zeros(0, dtype=torch.long),
            }
        )

    input_ids = data["input_ids"]
    input_lengths = data["input_lengths"]
    stop_strings = data.get("stop_strings", [None] * len(input_lengths))
    num_samples = len(input_lengths)
    pad_token_id = self.cfg.get("_pad_token_id", 0)

    # The slicing below assumes right padding; fail loudly otherwise.
    verify_right_padding(data, pad_value=pad_token_id)

    padded_input_length = input_ids.size(1)

    print(f"[SGLang Worker] Rank {self.global_rank} batch_size: {num_samples}, padded_input_length: {padded_input_length}")

    # Sampling parameters shared by the whole batch.
    sampling_params = {
        "temperature": 0.0 if greedy else self.cfg.get("temperature", 1.0),
        "top_p": self.cfg.get("top_p", 1.0),
        "max_new_tokens": self.cfg.get("max_new_tokens", 512),
    }
    top_k = self.cfg.get("top_k", None)
    if top_k is not None:
        sampling_params["top_k"] = top_k

    if num_samples == 0:
        raise ValueError("Empty batch received")

    # TEST: only the first request is materialized for now.
    i = 0
    input_len = input_lengths[i].item()
    valid_input_ids = input_ids[i, :input_len].tolist()

    print(f"[SGLang Worker] Rank {self.global_rank} Processing sample {i}, input_len: {input_len}")

    new_tokens, new_logprobs = self._generate_single_sample(
        input_ids=valid_input_ids,
        sampling_params=sampling_params,
        stop_string=stop_strings[i],
    )

    print(f"[SGLang Worker] Rank {self.global_rank} Generated {len(new_tokens)} tokens")

    generation_length = len(new_tokens)
    total_length = padded_input_length + generation_length

    # Row 0: real prompt followed by the generated continuation.
    first_row = torch.full((total_length,), pad_token_id, dtype=input_ids.dtype)
    first_row[:input_len] = input_ids[i][:input_len]
    if new_tokens:
        first_row[input_len : input_len + len(new_tokens)] = torch.tensor(
            new_tokens, dtype=input_ids.dtype
        )

    # Logprobs are zero over the prompt, actual values for generated tokens.
    first_logprobs = torch.zeros(total_length, dtype=torch.float32)
    for offset, lp in enumerate(new_logprobs):
        first_logprobs[input_len + offset] = lp

    rows = [first_row]
    logprob_rows = [first_logprobs]
    gen_lengths = [generation_length]
    unpadded_lengths = [input_len + generation_length]

    # Remaining rows: padded dummies so the batch keeps a uniform shape.
    for j in range(1, num_samples):
        rows.append(torch.full((total_length,), pad_token_id, dtype=input_ids.dtype))
        logprob_rows.append(torch.zeros(total_length, dtype=torch.float32))
        gen_lengths.append(0)
        unpadded_lengths.append(input_lengths[j].item())

    return BatchedDataDict[GenerationOutputSpec](
        {
            "output_ids": torch.stack(rows),
            "generation_lengths": torch.tensor(gen_lengths, dtype=torch.long),
            "unpadded_sequence_lengths": torch.tensor(unpadded_lengths, dtype=torch.long),
            "logprobs": torch.stack(logprob_rows),
        }
    )
deletions(-) diff --git a/nemo_rl/models/generation/sglang/sglang_generation.py b/nemo_rl/models/generation/sglang/sglang_generation.py index 2a42ac9409..19f208304a 100644 --- a/nemo_rl/models/generation/sglang/sglang_generation.py +++ b/nemo_rl/models/generation/sglang/sglang_generation.py @@ -114,6 +114,11 @@ def __init__( worker_builder = RayWorkerBuilder(worker_cls, config) env_vars = {} + global_cvd = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if global_cvd: + # Explicitly pass CUDA_VISIBLE_DEVICES to workers via env_vars + # This ensures all workers see the same global value, even though + env_vars["CUDA_VISIBLE_DEVICES"] = global_cvd # Allocate bundles for each server # Each server gets consecutive bundles diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index bec8c273cf..24cbf6932b 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -99,7 +99,9 @@ def configure_worker( ) or local_bundle_indices is None if is_part_of_parallel_workers: - # For parallel workers, we manage GPU assignment manually via CUDA_VISIBLE_DEVICES + # For parallel workers, we manage GPU assignment via base_gpu_id + # All workers see the same global CUDA_VISIBLE_DEVICES, but use different + # logical GPU ranges via base_gpu_id resources["num_gpus"] = 0 env_vars["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" init_kwargs["fraction_of_gpus"] = num_gpus @@ -136,18 +138,19 @@ def __init__( if not self.is_model_owner: return - # Set CUDA_VISIBLE_DEVICES to allow SGLang server to see the correct GPUs - # bundle_indices contains the node-local GPU indices (e.g., [0,1,2,3] or [4,5,6,7]) - # Since we set RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1, Ray won't override this - gpu_ids = ",".join(str(idx) for idx in bundle_indices) - os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids - # Determine tp_size from bundle_indices length tp_size = len(bundle_indices) - 
+ + base_gpu_id = bundle_indices[0] if bundle_indices else 0 + + # Get the global CUDA_VISIBLE_DEVICES (all engines see the same global value) + global_cvd = os.environ.get("CUDA_VISIBLE_DEVICES", None) + + print( - f"[SGLang Server] Node {os.environ.get('NODE_RANK', '?')}: " - f"Setting CUDA_VISIBLE_DEVICES={gpu_ids} (tp_size={tp_size})" + f"[SGLang Server] Rank {self.global_rank}: " + f"base_gpu_id={base_gpu_id}, tp_size={tp_size}, " + f"bundle_indices={bundle_indices}, global_cvd={global_cvd}" ) # Get current node IP and a free port for the server @@ -161,9 +164,8 @@ def __init__( "random_seed": seed if seed is not None else self.cfg.get("random_seed", 1), # Memory settings "enable_memory_saver": self.cfg.get("enable_memory_saver", False), - # GPU settings - Ray handles CUDA_VISIBLE_DEVICES, so we use logical GPU 0 "gpu_id_step": 1, - "base_gpu_id": 0, # Always 0 because Ray sets CUDA_VISIBLE_DEVICES + "base_gpu_id": base_gpu_id, # Parallel settings "tp_size": tp_size, "dp_size": self.cfg.get("dp_size", 1), @@ -191,7 +193,7 @@ def __init__( self.server_args = server_args self.base_url = f"http://{node_ip}:{free_port}" - print(f"[SGLang Server] Rank {self.global_rank} Starting on {self.base_url}") + print(f"[SGLang Worker] Rank {self.global_rank} Starting on {self.base_url}, CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}, base_gpu_id: {base_gpu_id}") self.server_process = self._launch_server_process(server_args) From dd0e54f8ef0b38a6f9e71a9809c3b0a1a9ad528b Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Wed, 26 Nov 2025 02:57:16 +0000 Subject: [PATCH 08/29] asyncio to roolout all saples Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- .../models/generation/sglang/sglang_worker.py | 147 +++++++++++------- 1 file changed, 89 insertions(+), 58 deletions(-) diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index 24cbf6932b..0ddeded3a0 100644 --- 
async def _generate_single_sample(
    self,
    input_ids: list[int],
    sampling_params: dict[str, Any],
    stop_string: Optional[str] = None,
) -> tuple[list[int], list[float]]:
    """Asynchronously request one completion from the SGLang server.

    Mirrors the synchronous variant: prompt token IDs in,
    ``(generated_tokens, per-token logprobs)`` out.
    """
    # Stop criteria live inside sampling_params; copy so the shared
    # batch-level dict is not mutated.
    params = sampling_params if stop_string is None else {**sampling_params, "stop": stop_string}

    payload = {
        "sampling_params": params,
        "return_logprob": True,
        "input_ids": input_ids,
    }

    url = f"{self.base_url}/generate"
    headers = {
        "Content-Type": "application/json; charset=utf-8",
    }

    # One short-lived session per request (replaced by a pooled session in
    # a later revision).
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload, headers=headers) as response:
            response.raise_for_status()
            result = await response.json()

    # meta_info.output_token_logprobs holds (logprob, token_id, ...) entries.
    entries = result.get("meta_info", {}).get("output_token_logprobs", [])
    if not entries:
        return [], []
    return [e[1] for e in entries], [e[0] for e in entries]

async def _generate_async(self, tasks: list) -> list:
    """Execute all async generation tasks concurrently.

    Args:
        tasks: List of async coroutines for generating samples

    Returns:
        List of (tokens, logprobs) tuples for all samples
    """
    return await asyncio.gather(*tasks)
generation_length = len(new_tokens) + unpadded_length = input_len + generation_length + max_length = max(max_length, unpadded_length) - # Calculate total length: padded_input_length + max_generated_length - # For now, since we only process one sample, max_length = generation_length - max_length = generation_length total_length = padded_input_length + max_length - # Create output tensor - full_output = torch.full( - (total_length,), pad_token_id, dtype=input_ids.dtype - ) - - # Copy original input (with padding) into the beginning - full_output[:input_len] = input_ids[i][:input_len] - - # Add generated tokens after the original input - if new_tokens: - full_output[input_len : input_len + len(new_tokens)] = ( - torch.tensor(new_tokens, dtype=input_ids.dtype) + for i, (new_tokens, new_logprobs) in enumerate(all_results): + input_len = input_lengths[i].item() + generation_length = len(new_tokens) + unpadded_length = input_len + generation_length + + full_output = torch.full( + (total_length,), pad_token_id, dtype=input_ids.dtype ) - - full_logprobs = torch.zeros(total_length, dtype=torch.float32) - if new_logprobs: - for idx, logprob in enumerate(new_logprobs): - position = input_len + idx - full_logprobs[position] = logprob - - unpadded_length = input_len + generation_length - - # For other samples, create dummy outputs (same shape as first sample) - output_ids_list = [full_output] - logprobs_list = [full_logprobs] - generation_lengths_list = [generation_length] - unpadded_sequence_lengths_list = [unpadded_length] - - for j in range(1, batch_size): - dummy_output = torch.full((total_length,), pad_token_id, dtype=input_ids.dtype) - dummy_logprobs = torch.zeros(total_length, dtype=torch.float32) - output_ids_list.append(dummy_output) - logprobs_list.append(dummy_logprobs) - generation_lengths_list.append(0) - unpadded_sequence_lengths_list.append(input_lengths[j].item()) + full_output[:input_len] = input_ids[i][:input_len] + + # Add generated tokens after the original 
class AsyncLoopThread:
    """A background event loop thread for running async operations in Ray actors.

    Creates a dedicated thread running its own asyncio event loop so that
    synchronous Ray actor methods can execute coroutines via
    ``asyncio.run_coroutine_threadsafe`` (which requires the loop to live in
    a thread other than the caller's).
    """

    def __init__(self):
        self.loop = asyncio.new_event_loop()
        self._ready = threading.Event()
        self._thread = threading.Thread(target=self._start_loop, daemon=True)
        self._thread.start()
        # Fail fast if the loop thread never comes up.
        if not self._ready.wait(timeout=5.0):
            raise RuntimeError("Event loop thread failed to start within 5 seconds")

    def _start_loop(self):
        """Run the event loop in the background thread."""
        asyncio.set_event_loop(self.loop)
        # FIX: signal readiness from inside the running loop. The original
        # set the event *before* run_forever(), so a caller could observe
        # _ready set while is_running() was still False and run() would
        # spuriously raise.
        self.loop.call_soon(self._ready.set)
        self.loop.run_forever()

    def run(self, coro):
        """Schedule a coroutine onto the loop and block until it's done.

        Args:
            coro: The coroutine to execute

        Returns:
            The result of the coroutine

        Raises:
            RuntimeError: If the background loop is not running.
        """
        if not self.loop.is_running():
            raise RuntimeError("Event loop is not running")
        future = asyncio.run_coroutine_threadsafe(coro, self.loop)
        return future.result()

    def shutdown(self):
        """Stop the event loop, join the thread, and close the loop.

        FIX: the original closed the loop only ``if self.loop.is_running()``
        — which is never true after ``stop()`` + join (and closing a
        *running* loop raises RuntimeError), so the loop was leaked. Close
        it once it has actually stopped.
        """
        if self.loop.is_running():
            self.loop.call_soon_threadsafe(self.loop.stop)
        self._thread.join(timeout=2.0)
        if not self.loop.is_running() and not self.loop.is_closed():
            self.loop.close()
async def _ensure_session(self):
    """Lazily create the pooled aiohttp session shared by all requests."""
    if self.session is None:
        # Connection pool sized for large rollout fan-out.
        self.connector = aiohttp.TCPConnector(limit=512, limit_per_host=512)
        # Cap each request at 5 minutes.
        timeout = aiohttp.ClientTimeout(total=300)
        self.session = aiohttp.ClientSession(connector=self.connector, timeout=timeout)
    return self.session

async def _generate_async(self, tasks):
    """Await all generation coroutines, returning results in input order.

    Completion progress is logged every 50 finished tasks (and at the end).
    """
    async def indexed(pos, coro):
        # Tag each result with its original position so results can be
        # reordered after out-of-order completion.
        outcome = await coro
        return pos, outcome

    pending = [indexed(i, t) for i, t in enumerate(tasks)]
    ordered = [None] * len(tasks)
    done = 0

    for fut in asyncio.as_completed(pending):
        pos, outcome = await fut
        ordered[pos] = outcome
        done += 1
        if done % 50 == 0 or done == len(tasks):
            print(f"[SGLang Worker] Rank {self.global_rank} Completed {done}/{len(tasks)} tasks")

    return ordered
BatchedDataDict[GenerationOutputSpec]( { "output_ids": output_ids, @@ -463,18 +539,37 @@ def wake_up(self, **kwargs): pass def shutdown(self) -> bool: - """Shutdown the SGLang server process. + """Shutdown the SGLang server process and cleanup async resources. Returns: bool: True if shutdown was successful, False otherwise """ - if not self.is_model_owner: - return True + if hasattr(self, "async_loop_thread"): + try: + self.async_loop_thread.shutdown() + print(f"[SGLang Worker] Rank {self.global_rank} Async loop thread shut down.") + except Exception as e: + print(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}") - if not hasattr(self, "server_process") or self.server_process is None: + if not self.is_model_owner: return True try: + if hasattr(self, "session") and self.session is not None: + try: + async def close_session(): + await self.session.close() + if self.connector is not None: + await self.connector.close() + + self.async_loop_thread.run(close_session()) + print(f"[SGLang Worker] Rank {self.global_rank} aiohttp session closed.") + except Exception as e: + print(f"[SGLang Worker] Rank {self.global_rank} Error closing aiohttp session: {e}") + + if not hasattr(self, "server_process") or self.server_process is None: + return True + print( f"[SGLang Worker] Rank {self.global_rank} Shutting down server at {self.base_url}..." 
) From 5e24fab0d285092d9fc8ab9d067d164826dd17b2 Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Wed, 26 Nov 2025 03:41:16 +0000 Subject: [PATCH 10/29] added mem_fraction Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- nemo_rl/models/generation/sglang/sglang_worker.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index 40a730bc41..dd0118aea8 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -237,6 +237,7 @@ def __init__( "dtype", "kv_cache_dtype", "context_length", "max_running_requests", "chunked_prefill_size", "max_prefill_tokens", "schedule_policy", "schedule_conservativeness", "cpu_offload_gb", "log_level", + "mem_fraction_static", ]: if key in self.cfg: kwargs[key] = self.cfg[key] @@ -382,8 +383,6 @@ def _launch_server_process(self, server_args: ServerArgs) -> multiprocessing.Pro raise Exception(f"[SGLang Server] Rank {self.global_rank} Server process terminated unexpectedly.") time.sleep(2) - # response = session.get(f"{self.base_url}/get_model_info", headers=headers) - # print(f"[SGLang Worker] Rank {self.global_rank} model_info: {response.json()}") return p From 50189a9c3fc0911037e5449ad2aa9be07c596201 Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Fri, 28 Nov 2025 22:27:27 +0000 Subject: [PATCH 11/29] modified build_sampling_paras and stop token handling Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- .../models/generation/sglang/sglang_worker.py | 123 +++++++++++++++--- 1 file changed, 104 insertions(+), 19 deletions(-) diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index dd0118aea8..1eb1453a01 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -186,6 +186,9 @@ def __init__( # there will be issues if we use the event 
loop in the main thread self.async_loop_thread = AsyncLoopThread() + # Maximum concurrent requests per server to avoid overloading + # Default to 8 concurrent requests per server + self.max_concurrent_requests = config.get("max_concurrent_requests", 16) # Only the primary worker (local_rank=0) in each server group starts the SGLang server # Secondary workers (local_rank!=0) just returns if not self.is_model_owner: @@ -256,7 +259,33 @@ def __init__( def _merge_stop_strings(self, batch_stop_strings): - pass + """Merge stop strings from config and batch. + + Args: + batch_stop_strings: List of stop strings from batch (one per sample) + + Returns: + List of merged stop strings (one per sample) + """ + stop_set: set[str] = set() + + # Add stop strings from config + if self.cfg.get("stop_strings"): + stop_set.update(self.cfg["stop_strings"]) + + # Merge stop strings from batch + merged_stop_strings = [] + for sample_ss in batch_stop_strings: + sample_stop_set = stop_set.copy() + if sample_ss: + if isinstance(sample_ss, str): + sample_stop_set.add(sample_ss) + elif isinstance(sample_ss, list): + sample_stop_set.update(sample_ss) + + merged_stop_strings.append(list(sample_stop_set) if sample_stop_set else None) + + return merged_stop_strings def _build_sampling_params( self, @@ -264,8 +293,60 @@ def _build_sampling_params( greedy: bool, stop_strings, max_new_tokens: Optional[int] = None, - ): - pass + input_len: Optional[int] = None, + context_length: Optional[int] = None, + sample_index: Optional[int] = None, + ) -> dict[str, Any]: + """Build sampling parameters dictionary for SGLang API. 
+ + Args: + greedy: Whether to use greedy decoding (temperature=0.0) + stop_strings: Merged stop strings (not used here, handled per sample) + max_new_tokens: Override max_new_tokens from config if provided + input_len: Input length for this sample (used for context_length adjustment) + context_length: Maximum context length (if provided, adjusts max_new_tokens) + sample_index: Sample index (used for warning messages, 0-indexed) + + Returns: + Dictionary of sampling parameters compatible with SGLang API + """ + top_k_cfg = self.cfg.get("top_k") + top_k_val = 1 if greedy else (top_k_cfg if top_k_cfg is not None else -1) + temperature = 0.0 if greedy else self.cfg.get("temperature", 1.0) + + base_max_tokens = ( + max_new_tokens if max_new_tokens is not None else self.cfg.get("max_new_tokens", 512) + ) + + # TODO: check if this is needed + final_max_tokens = base_max_tokens + if context_length is not None and input_len is not None: + max_allowed_new_tokens = max(0, context_length - input_len - 1) + if base_max_tokens > max_allowed_new_tokens: + final_max_tokens = max_allowed_new_tokens + if sample_index == 0: + print( + f"[SGLang Worker] Rank {self.global_rank} Warning: " + f"Sample {sample_index} input length ({input_len}) + max_new_tokens ({base_max_tokens}) " + f"would exceed context_length ({context_length}). " + f"Reducing max_new_tokens to {final_max_tokens} for this sample." 
+ ) + + # Build sampling params dict + sampling_params = { + "temperature": temperature, + "top_p": self.cfg.get("top_p", 1.0), + "max_new_tokens": final_max_tokens, + } + + if top_k_val != -1: + sampling_params["top_k"] = top_k_val + + stop_token_ids = self.cfg.get("stop_token_ids") + if stop_token_ids is not None: + sampling_params["stop_token_ids"] = stop_token_ids + + return sampling_params async def _ensure_session(self): if self.session is None: @@ -418,7 +499,8 @@ def generate( input_ids = data["input_ids"] input_lengths = data["input_lengths"] - stop_strings = data.get("stop_strings", [None] * len(input_lengths)) + batch_stop_strings = data.get("stop_strings", [None] * len(input_lengths)) + stop_strings = self._merge_stop_strings(batch_stop_strings) batch_size = len(input_lengths) pad_token_id = self.cfg.get("_pad_token_id", 0) @@ -430,33 +512,36 @@ def generate( print(f"[SGLang Worker] Rank {self.global_rank} batch_size: {batch_size}, padded_input_length: {padded_input_length}") - # Get generation parameters from config - max_new_tokens = self.cfg.get("max_new_tokens", 512) - temperature = 0.0 if greedy else self.cfg.get("temperature", 1.0) - top_p = self.cfg.get("top_p", 1.0) - top_k = self.cfg.get("top_k", None) - - sampling_params = { - "temperature": temperature, - "top_p": top_p, - "max_new_tokens": max_new_tokens, - } - if top_k is not None: - sampling_params["top_k"] = top_k - if batch_size == 0: raise ValueError("Empty batch received") + context_length = self.cfg.get("context_length", None) + # Create async tasks for all samples tasks = [] for i in range(batch_size): input_len = input_lengths[i].item() + + # Truncate input if it exceeds context_length + if context_length is not None and input_len >= context_length: + input_len = context_length - 1 + valid_input_ids = input_ids[i, :input_len].tolist() + # Build sampling params for this sample (with context_length adjustment) + sample_sampling_params = self._build_sampling_params( + greedy=greedy, + 
stop_strings=stop_strings, + max_new_tokens=None, + input_len=input_len, + context_length=context_length, + sample_index=i, + ) + tasks.append( self._generate_single_sample( input_ids=valid_input_ids, - sampling_params=sampling_params, + sampling_params=sample_sampling_params, stop_string=stop_strings[i], ) ) From ec35b6baec8dee8ba658414bec144bc1706948b9 Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Fri, 28 Nov 2025 21:07:14 +0000 Subject: [PATCH 12/29] temp: prevent server overlaod with semaphore Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- .../models/generation/sglang/sglang_worker.py | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index 1eb1453a01..00dafef3ce 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -186,9 +186,11 @@ def __init__( # there will be issues if we use the event loop in the main thread self.async_loop_thread = AsyncLoopThread() - # Maximum concurrent requests per server to avoid overloading - # Default to 8 concurrent requests per server - self.max_concurrent_requests = config.get("max_concurrent_requests", 16) + # + # temp: Maximum concurrent requests per server + # we may remove this limit in the future + self.max_concurrent_requests = config.get("max_concurrent_requests", 999999) + # Only the primary worker (local_rank=0) in each server group starts the SGLang server # Secondary workers (local_rank!=0) just returns if not self.is_model_owner: @@ -240,7 +242,7 @@ def __init__( "dtype", "kv_cache_dtype", "context_length", "max_running_requests", "chunked_prefill_size", "max_prefill_tokens", "schedule_policy", "schedule_conservativeness", "cpu_offload_gb", "log_level", - "mem_fraction_static", + "mem_fraction_static", "allow_auto_truncate", ]: if key in self.cfg: kwargs[key] = self.cfg[key] @@ -418,13 +420,20 @@ async def 
_generate_single_sample( return new_tokens, new_logprobs async def _generate_async(self, tasks): + """Execute generation tasks with concurrency control. + + TEMP: Uses a semaphore to limit the number of concurrent requests per server, preventing server overload. + A router based solution is preffered in the future. + """ + semaphore = asyncio.Semaphore(self.max_concurrent_requests) async def wrap(idx, coro): - try: - result = await coro - return idx, result - except Exception as e: - raise + async with semaphore: + try: + result = await coro + return idx, result + except Exception as e: + raise wrapped = [wrap(i, t) for i, t in enumerate(tasks)] results = [None] * len(tasks) From f099caa1a59e7d4cf8bbad2f7f8ffabe9807dee0 Mon Sep 17 00:00:00 2001 From: Ryan Date: Sun, 30 Nov 2025 13:58:49 -0500 Subject: [PATCH 13/29] sglang: refactor, move async loop position Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- .../models/generation/sglang/sglang_worker.py | 133 +++++++++++------- nemo_rl/models/generation/sglang/utils.py | 63 +++++++++ 2 files changed, 145 insertions(+), 51 deletions(-) create mode 100644 nemo_rl/models/generation/sglang/utils.py diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index 00dafef3ce..d47e32635f 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -20,7 +20,6 @@ import requests import asyncio import aiohttp -import threading import time import ray @@ -36,6 +35,7 @@ verify_right_padding, ) from nemo_rl.models.generation.sglang.config import SGLangConfig +from nemo_rl.models.generation.sglang.utils import AsyncLoopThread from nemo_rl.models.huggingface.common import ModelFlag from nemo_rl.utils.nsys import wrap_with_nvtx_name @@ -44,52 +44,6 @@ from sglang.srt.utils import kill_process_tree -class AsyncLoopThread: - """A background event loop thread for running async operations in Ray actors. 
- - This class creates a dedicated thread with its own event loop, allowing - synchronous Ray actor methods to execute async coroutines without blocking - the main actor thread. This is necessary because run_coroutine_threadsafe - requires the event loop to be in a different thread. - """ - def __init__(self): - self.loop = asyncio.new_event_loop() - self._ready = threading.Event() - self._thread = threading.Thread(target=self._start_loop, daemon=True) - self._thread.start() - if not self._ready.wait(timeout=5.0): - raise RuntimeError("Event loop thread failed to start within 5 seconds") - - def _start_loop(self): - """Run the event loop in the background thread.""" - asyncio.set_event_loop(self.loop) - self._ready.set() - self.loop.run_forever() - - def run(self, coro): - """Schedule a coroutine onto the loop and block until it's done. - - Args: - coro: The coroutine to execute - - Returns: - The result of the coroutine - """ - if not self.loop.is_running(): - raise RuntimeError("Event loop is not running") - future = asyncio.run_coroutine_threadsafe(coro, self.loop) - result = future.result() - return result - - def shutdown(self): - """Shutdown the event loop and wait for the thread to finish.""" - if self.loop.is_running(): - self.loop.call_soon_threadsafe(self.loop.stop) - self._thread.join(timeout=2.0) - if self.loop.is_running(): - self.loop.close() - - @ray.remote( runtime_env={**get_nsight_config_if_pattern_matches("sglang_generation_worker")} ) # pragma: no cover @@ -178,19 +132,16 @@ def __init__( """ self.cfg = config self.is_model_owner = bundle_indices is not None - - # This is the global worker rank across all workers self.global_rank = int(os.environ.get("RANK", "0")) # Create a dedicated event loop thread for async operations # there will be issues if we use the event loop in the main thread self.async_loop_thread = AsyncLoopThread() - # # temp: Maximum concurrent requests per server # we may remove this limit in the future 
self.max_concurrent_requests = config.get("max_concurrent_requests", 999999) - + # Only the primary worker (local_rank=0) in each server group starts the SGLang server # Secondary workers (local_rank!=0) just returns if not self.is_model_owner: @@ -259,6 +210,85 @@ def __init__( self.server_process = self._launch_server_process(server_args) + def get_base_url(self) -> str: + """Get the base URL of this SGLang server.""" + return self.base_url + + def invalidate_kv_cache(self) -> bool: + """Invalidate KV cache before weight updates (Megatron-style). + + This flushes the cache before weight updates to clear stale cache. + Uses retry logic to handle cases where there are pending requests. + + Returns: + bool: True if flush was successful, False otherwise + """ + if not self.is_model_owner: + return True + + url = f"{self.base_url}/flush_cache" + max_attempts = 60 + connection_retry_limit = 5 + + # flush_cache will not return status_code 200 when there are pending requests + for attempt in range(max_attempts): + try: + response = requests.get(url, timeout=10) + if response.status_code == 200: + if attempt > 0: + print( + f"[SGLang Worker] Rank {self.global_rank} Cache flushed successfully " + f"(attempt {attempt + 1})", + flush=True + ) + return True + except requests.exceptions.ConnectionError: + # Server might not be ready yet - only retry for first few attempts + if attempt >= connection_retry_limit: + print( + f"[SGLang Worker] Rank {self.global_rank} Connection failed after " + f"{connection_retry_limit} attempts", + flush=True + ) + return False + except Exception as e: + # For other errors, log and retry (except on last attempt) + if attempt == max_attempts - 1: + print( + f"[SGLang Worker] Rank {self.global_rank} Failed to flush cache after " + f"{max_attempts} attempts: {e}", + flush=True + ) + return False + + time.sleep(1) + + # All attempts exhausted without success + print( + f"[SGLang Worker] Rank {self.global_rank} Timeout: Cache flush failed after " + 
f"{max_attempts} attempts. Server may have pending requests.", + flush=True + ) + return False + + def get_gpu_uuids(self) -> list[str]: + """Get list of GPU UUIDs used by this SGLang server. + + Returns: + List of GPU UUIDs (e.g., ["GPU-xxxxx", "GPU-yyyyy"]) + """ + from nemo_rl.utils.nvml import get_device_uuid + + # Get all GPU UUIDs used by this server + # SGLang server uses GPUs starting from base_gpu_id with tp_size GPUs + gpu_uuids = [] + for i in range(self.server_args.tp_size): + gpu_id = self.server_args.base_gpu_id + i + uuid = get_device_uuid(gpu_id) + gpu_uuids.append(uuid) + + return gpu_uuids + def _merge_stop_strings(self, batch_stop_strings): """Merge stop strings from config and batch. @@ -379,6 +409,7 @@ async def _generate_single_sample( """ # Prepare payload for SGLang API # Note: stop should be in sampling_params, not in payload top level + # TODO: double check this if stop_string is not None: # stop can be a string or list of strings sampling_params = sampling_params.copy() # Don't modify the original diff --git a/nemo_rl/models/generation/sglang/utils.py b/nemo_rl/models/generation/sglang/utils.py new file mode 100644 index 0000000000..3b56037891 --- /dev/null +++ b/nemo_rl/models/generation/sglang/utils.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import threading + + +class AsyncLoopThread: + """A background event loop thread for running async operations in Ray actors. + + This class creates a dedicated thread with its own event loop, allowing + synchronous Ray actor methods to execute async coroutines without blocking + the main actor thread. This is necessary because run_coroutine_threadsafe + requires the event loop to be in a different thread. + """ + def __init__(self): + self.loop = asyncio.new_event_loop() + self._ready = threading.Event() + self._thread = threading.Thread(target=self._start_loop, daemon=True) + self._thread.start() + if not self._ready.wait(timeout=5.0): + raise RuntimeError("Event loop thread failed to start within 5 seconds") + + def _start_loop(self): + """Run the event loop in the background thread.""" + asyncio.set_event_loop(self.loop) + self._ready.set() + self.loop.run_forever() + + def run(self, coro): + """Schedule a coroutine onto the loop and block until it's done. + + Args: + coro: The coroutine to execute + + Returns: + The result of the coroutine + """ + if not self.loop.is_running(): + raise RuntimeError("Event loop is not running") + future = asyncio.run_coroutine_threadsafe(coro, self.loop) + result = future.result() + return result + + def shutdown(self): + """Shutdown the event loop and wait for the thread to finish.""" + if self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + self._thread.join(timeout=2.0) + if self.loop.is_running(): + self.loop.close() + From a03eba861203ea25518edba7c97d23a17b3f379a Mon Sep 17 00:00:00 2001 From: Ryan Date: Sun, 30 Nov 2025 13:59:54 -0500 Subject: [PATCH 14/29] sglang: fix total length in generate Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- nemo_rl/models/generation/sglang/sglang_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index 
d47e32635f..1aba513047 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -609,7 +609,7 @@ def generate( unpadded_length = input_len + generation_length max_length = max(max_length, unpadded_length) - total_length = padded_input_length + max_length + total_length = max(max_length, padded_input_length) for i, (new_tokens, new_logprobs) in enumerate(all_results): input_len = input_lengths[i].item() From e08cfd69d1a80555d0bee21bb4bce2fc327b0c9c Mon Sep 17 00:00:00 2001 From: Ryan Date: Sat, 29 Nov 2025 23:36:57 -0500 Subject: [PATCH 15/29] sglang: env setup sglang: add 1B example Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- examples/configs/grpo_math_1B_sglang.yaml | 285 ++++++++++++++++++++++ pyproject.toml | 20 ++ run.sh | 20 ++ 3 files changed, 325 insertions(+) create mode 100644 examples/configs/grpo_math_1B_sglang.yaml create mode 100755 run.sh diff --git a/examples/configs/grpo_math_1B_sglang.yaml b/examples/configs/grpo_math_1B_sglang.yaml new file mode 100644 index 0000000000..c9e28f9cff --- /dev/null +++ b/examples/configs/grpo_math_1B_sglang.yaml @@ -0,0 +1,285 @@ +# GRPO Algorithm Configuration +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_rollout_turns: 1 + max_num_epochs: 1 + max_num_steps: 1000000 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 2 + val_at_start: false + overlong_filtering: false + max_val_samples: 256 + val_batch_size: 128 + seed: 42 + use_dynamic_sampling: false + dynamic_sampling_max_gen_batches: 10 + batch_multiplier: 1 + reward_shaping: + enabled: false + overlong_buffer_length: 128 + overlong_buffer_penalty: 1 + max_response_length: ${policy.max_total_sequence_length} + reward_scaling: + enabled: false + source_min: 0.0 + source_max: 1.0 + target_min: 0.0 + target_max: 1.0 + + async_grpo: + enabled: false # Set to true to enable async training mode + # Max age (in training steps) for trajectories used 
in training + max_trajectory_age_steps: 1 + in_flight_weight_updates: false # Set to true to enable in-flight weight updates + recompute_kv_cache_after_weight_updates: false # Set to true to recompute kv cache after in-flight-weight-updates + +loss_fn: + reference_policy_kl_penalty: 0.01 + # Can be set to k1, k2, k3 + # For more details, see http://joschu.net/blog/kl-approx.html + reference_policy_kl_type: "k3" + kl_input_clamp_value: 20.0 + kl_output_clamp_value: 10.0 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + ratio_clip_c: null + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + # Async GRPO requires importance sampling correction enabled + # Set to true when async_grpo.enabled is true + use_importance_sampling_correction: false + truncated_importance_sampling_ratio: null + sequence_level_importance_ratios: false + token_level_loss: true + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo" + metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name + higher_is_better: true + keep_top_k: 3 + save_period: 10 + checkpoint_must_save_by: null + model_save_format: "safetensors" + save_consolidated: false + +policy: + model_name: "Qwen/Qwen2.5-1.5B" + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + chat_template_kwargs: null # can be used to pass kwargs to the chat template, e.g., enable_thinking=true + hf_config_overrides: {} + train_global_batch_size: 512 + train_micro_batch_size: 4 + generation_batch_size: 32 # Only used when generating using HF backend + logprob_batch_size: 4 + max_total_sequence_length: 512 + precision: "bfloat16" + logprob_chunk_size: null + offload_optimizer_for_logprob: false # Only useful for non-colocated generation since colocated generation will always offload optimizer to cuda before refit + + dtensor_cfg: + _v2: true + enabled: true + cpu_offload: False + 
sequence_parallel: false + activation_checkpointing: false + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + + megatron_cfg: + enabled: false + empty_unused_memory_level: 1 # 1 is the minimum recommendation for RL since we almost always need to offload before beginning generation. Setting to 0 is faster, but you are more likely to run out of GPU memory. + activation_checkpointing: false + converter_type: "Qwen2ForCausalLM" + tensor_model_parallel_size: 1 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + context_parallel_size: 1 + pipeline_dtype: ${policy.precision} + sequence_parallel: false + freeze_moe_router: true + moe_router_dtype: "fp64" + moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo + moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo + moe_permute_fusion: false + #gives ~20% training perf speedup with sequence packing + apply_rope_fusion: True + # gives ~25% training perf speedup with sequence packing and apply_rope_fusion + bias_activation_fusion: True + defer_fp32_logits: False + + optimizer: + optimizer: "adam" + lr: 5.0e-6 + min_lr: 5.0e-7 + weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: "float32" + + #adam + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + + #sgd + sgd_momentum: 0.9 + + #distributed optimizer + use_distributed_optimizer: true + use_precision_aware_optimizer: true + + clip_grad: ${policy.max_grad_norm} + + # optimizer cpu offload + optimizer_cpu_offload: false + optimizer_offload_fraction: 0.0 + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: 1000 + lr_warmup_iters: 13 + lr_warmup_init: 
5.0e-7 + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" + + fp8_cfg: null + + env_vars: null + + # See docs/design-docs/sequence-packing-and-dynamic-batching.md + # for more details on dynamic batching and sequence packing. + dynamic_batching: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: True + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 50 + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [50] + + generation: + backend: "sglang" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + # SGLang specific configuration + model_path: ${policy.model_name} # Model path for SGLang server + gpus_per_server: 1 # Number of GPUs per SGLang server (tensor parallel size) + 
dtype: ${policy.precision} # Model precision (bfloat16, float16, etc.) + context_length: 512 # Maximum context length + allow_auto_truncate: true + enable_memory_saver: false + max_running_requests: null + mem_fraction_static: 0.5 + skip_server_warmup: true # Skip server warmup to prevent timeout + colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources + enabled: true + # only relevant when enabled is false + resources: + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation + +data: + max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len + prompt_file: "examples/prompts/cot.txt" + system_prompt_file: null + shuffle: true + num_workers: 1 + + dataset_name: "OpenMathInstruct-2" + # You can use custom response datasets for training and validation. For example: + # data: + # dataset_name: ResponseDataset + # train_data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace) + # val_data_path: + # input_key: , default is "input" + # output_key: , default is "output" + # train_split: , default is None # used for HuggingFace datasets + # val_split: , default is None # used for HuggingFace datasets + # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#datasets for more details. 
+ +env: + math: + num_workers: 8 + math_verify_impl: "hf_math_verify" + ## unused in this config but needed for DAPO recipe + dapo: + num_workers: 8 + math_verify_impl: "dapo_math_verify" + +logger: + log_dir: "logs" # Base directory for all logs + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false # Disable MLflow logging + swanlab_enabled: false # Disable SwanLab logging + monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "grpo-dev" + name: "grpo-dev-logger" + tensorboard: {} + mlflow: + experiment_name: "grpo-dev" + run_name: "grpo-dev-logger" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 1 + num_nodes: 1 diff --git a/pyproject.toml b/pyproject.toml index a5a9881ea4..f668a896b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,26 @@ vllm = [ # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved "causal-conv1d", ] +sglang = [ + "sglang>=0.4.1", + "pybase64", + "orjson", + "uvloop", + "requests", + "openai", + "partial-json-parser", + "sentencepiece", + "sgl-kernel==0.3.17.post1", + "compressed-tensors", + "msgspec", + "python-multipart", + "torchao", + "xgrammar", + "interegular", + "openai-harmony", + "torch-memory-saver", + "einops", +] mcore = [ # also need cudnn (https://developer.nvidia.com/cudnn-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=deb_network) # wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb diff --git a/run.sh b/run.sh new file mode 100755 index 0000000000..fcea74f835 --- /dev/null +++ b/run.sh @@ -0,0 +1,20 @@ +#!/bin/bash +set -e + +VENV_NAME=".venv_test" 
+CONFIG_FILE="examples/configs/grpo_math_1B_sglang.yaml" + +if [ -d "$VENV_NAME" ]; then + echo "Removing existing virtual environment..." + rm -rf "$VENV_NAME" +fi + +uv venv "$VENV_NAME" +source "$VENV_NAME/bin/activate" +uv pip install -e ".[sglang]" + +echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES" + + +python examples/run_grpo_math.py --config "$CONFIG_FILE" + From ccc66f6b8664a6e072bc18e67891e0e8c3e0e11b Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Thu, 27 Nov 2025 21:52:19 +0000 Subject: [PATCH 16/29] from tensor: Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- nemo_rl/algorithms/grpo.py | 36 ++- .../generation/sglang/sglang_generation.py | 50 ++++ nemo_rl/models/policy/interfaces.py | 12 + nemo_rl/models/policy/lm_policy.py | 14 + nemo_rl/models/policy/utils.py | 245 ++++++++++++++++++ .../workers/dtensor_policy_worker_v2.py | 44 ++++ 6 files changed, 393 insertions(+), 8 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index ab0033575b..4a830a269b 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -979,8 +979,11 @@ def refit_policy_generation( timer: Optional Timer used to time the prepare/transfer/update phase kv_scales: Optional dictionary of KV cache scales for FP8 quantization. 
""" + print("[sglang refit] Starting refit process...", flush=True) if colocated_inference: + print("[sglang refit] Offloading optimizer before refit...", flush=True) policy.offload_before_refit() + print("[sglang refit] Preparing generation interface for weights...", flush=True) policy_generation.prepare_for_generation(tags=["weights"]) # Create a context manager that does nothing when timer is None @@ -1004,14 +1007,27 @@ def refit_policy_generation( policy.get_free_memory_bytes() * float(memory_ratio) ) - futures_train = policy.stream_weights_via_ipc_zmq( - buffer_size_bytes=buffer_size_bytes, kv_scales=kv_scales - ) - futures_inference = policy_generation.update_weights_via_ipc_zmq() - # wait for all futures to complete - ray.get(futures_train) - results = ray.get(futures_inference) - update_success = all(result for result in results if result is not None) + if isinstance(policy_generation, SGLangGeneration): + # Get SGLang server URL to GPU UUIDs mapping + sglang_url_to_gpu_uuids = policy_generation.get_sglang_url_to_gpu_uuids() + + futures_train = policy.stream_weights_via_http( + sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, + ) + # Wait for all workers to complete + ray.get(futures_train) + update_success = True + else: + # Original ZMQ IPC path for vLLM + print("[sglang refit] Using ZMQ IPC path for vLLM", flush=True) + futures_train = policy.stream_weights_via_ipc_zmq( + buffer_size_bytes=buffer_size_bytes + ) + futures_inference = policy_generation.update_weights_via_ipc_zmq() + # wait for all futures to complete + ray.get(futures_train) + results = ray.get(futures_inference) + update_success = all(result for result in results if result is not None) else: # update weights through nccl futures_train = policy.broadcast_weights_for_collective(kv_scales=kv_scales) @@ -1029,11 +1045,14 @@ def refit_policy_generation( f"This often indicates an issue with {error_tag} or " "a problem within the generation backend (e.g., vLLM worker).\n" ) + print(f"[sglang 
refit] {error_message}", flush=True) raise RuntimeError(error_message) if colocated_inference: + print("[sglang refit] Offloading after refit and preparing for generation...", flush=True) policy.offload_after_refit() policy_generation.prepare_for_generation(tags=["kv_cache"]) + print("[sglang refit] Refit process completed successfully", flush=True) # =============================================================================== @@ -1200,6 +1219,7 @@ def grpo_train( kv_scales=kv_scales_cache if sync_kv_scales else None, ) POLICY_GENERATION_STALE = False + print("[sglang refit] Policy generation refit completed, stale flag cleared", flush=True) else: if colocated_inference: policy.offload_after_refit() # unload optimizer to make space for generation diff --git a/nemo_rl/models/generation/sglang/sglang_generation.py b/nemo_rl/models/generation/sglang/sglang_generation.py index 19f208304a..6f538831d6 100644 --- a/nemo_rl/models/generation/sglang/sglang_generation.py +++ b/nemo_rl/models/generation/sglang/sglang_generation.py @@ -275,6 +275,56 @@ def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: def update_weights_from_collective(self) -> list[ray.ObjectRef]: return [] + + def get_sglang_server_urls(self) -> list[str]: + """Get base URLs of all SGLang servers. 
+ + Returns: + List of base URLs (e.g., ["http://localhost:30000", "http://localhost:30001"]) + """ + if not self.worker_group or not self.worker_group.workers: + raise RuntimeError("Worker group is not initialized") + + # Get base URLs from all workers (only primary workers, TP rank 0) + # Use run_rank_0_only_axes to only get URLs from primary workers + futures = self.worker_group.run_all_workers_single_data( + "get_base_url", + run_rank_0_only_axes=["tensor_parallel"], + ) + urls = ray.get(futures) + # Filter out None values and return unique URLs + return list(set(url for url in urls if url is not None)) + + def get_sglang_url_to_gpu_uuids(self) -> dict[str, list[str]]: + """Get mapping from SGLang server URL to list of GPU UUIDs it uses. + + Returns: + Dict mapping server URL to list of GPU UUIDs + e.g., {"http://localhost:30000": ["GPU-aaa", "GPU-bbb"], ...} + """ + if not self.worker_group or not self.worker_group.workers: + raise RuntimeError("Worker group is not initialized") + + # Get base URLs and GPU UUIDs from all primary workers (TP rank 0) + futures_url = self.worker_group.run_all_workers_single_data( + "get_base_url", + run_rank_0_only_axes=["tensor_parallel"], + ) + futures_uuids = self.worker_group.run_all_workers_single_data( + "get_gpu_uuids", + run_rank_0_only_axes=["tensor_parallel"], + ) + + urls = ray.get(futures_url) + uuids_list = ray.get(futures_uuids) + + # Create mapping + url_to_uuids = {} + for url, uuids in zip(urls, uuids_list): + if url is not None and uuids is not None: + url_to_uuids[url] = uuids + + return url_to_uuids def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: """Wake workers up for colocated inference.""" diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index 144b0c517d..10b34e5ae0 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -182,6 +182,18 @@ def stream_weights_via_ipc_zmq( ) -> list[ray.ObjectRef]: pass + def 
stream_weights_via_http( + self, sglang_url_to_gpu_uuids: dict[str, list[str]] + ) -> None: + """Stream model weights to SGLang servers via HTTP API. + + Args: + sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses + """ + raise NotImplementedError( + "stream_weights_via_http is not implemented for this policy worker" + ) + @abstractmethod def broadcast_weights_for_collective( self, kv_scales: Optional[dict[str, float]] = None diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 434f850423..b2a0dd60b8 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -758,6 +758,20 @@ def stream_weights_via_ipc_zmq( ) return futures + def stream_weights_via_http( + self, sglang_url_to_gpu_uuids: dict[str, list[str]] + ) -> list[ray.ObjectRef]: + """Send the weights to SGLang servers via HTTP API. + + Args: + sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses + """ + futures = self.worker_group.run_all_workers_single_data( + "stream_weights_via_http", + sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, + ) + return futures + def broadcast_weights_for_collective( self, kv_scales: Optional[dict[str, float]] = None ) -> list[ray.ObjectRef]: diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 7ad6d99849..f8b1f9f38a 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -12,16 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import base64 import gc import importlib import os +import pickle import traceback from enum import Enum from typing import Any, Dict, Optional +import requests import torch +import torch.distributed as dist import zmq from torch.multiprocessing.reductions import rebuild_cuda_tensor + +from sglang.srt.utils import MultiprocessingSerializer +from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions from transformers import ( AutoModelForCausalLM, AutoModelForImageTextToText, @@ -488,3 +495,241 @@ def rebuild_cuda_tensor_from_ipc( list_args = list(args) list_args[6] = device_id return func(*list_args) + + +def stream_weights_via_http_impl( + params_generator, + sglang_url_to_gpu_uuids: dict[str, list[str]], + rank: int, + worker_name: str, + current_device_uuid: str, +) -> None: + """Stream weights to SGLang servers via HTTP API (update_weights_from_tensor). + + Flow: Each rank creates IPC handler → gather handlers in rank order → send list → SGLang matches by tp_rank index + + Key points: + - Each rank creates handler on its own GPU + - Handlers are gathered in rank order: [rank0_handler, rank1_handler, ...] + - List index = rank = GPU ID + - SGLang automatically matches: handler = serialized_handlers[tp_rank] + + Args: + params_generator: Generator yielding (name, tensor) pairs + sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses + rank: Worker rank for logging + worker_name: Name of the worker for logging + current_device_uuid: UUID of the current training worker's GPU + """ + monkey_patch_torch_reductions() + + target_urls = [ + url for url, uuids in sglang_url_to_gpu_uuids.items() + if current_device_uuid in uuids + ] + + if not target_urls: + raise RuntimeError( + f"{worker_name} (rank {rank}): No matching SGLang server found for GPU UUID {current_device_uuid}. 
" + f"Available servers: {list(sglang_url_to_gpu_uuids.keys())}" + ) + + if len(target_urls) > 1: + print( + f"[WARNING] {worker_name} (rank {rank}): GPU UUID {current_device_uuid} matches multiple SGLang servers: {target_urls}. " + f"Using the first one: {target_urls[0]}" + ) + target_urls = [target_urls[0]] + + base_url = target_urls[0] + url = f"{base_url}/update_weights_from_tensor" + sglang_gpu_uuids = sglang_url_to_gpu_uuids[base_url] + + ipc_gather_group, ipc_gather_src = _setup_ipc_gather_group( + rank, current_device_uuid, sglang_gpu_uuids, sglang_url_to_gpu_uuids + ) + + tensor_count = 0 + + try: + for name, tensor in params_generator: + torch.cuda.current_stream().synchronize() + tensor = tensor.contiguous().cuda() + + serialized_handler = MultiprocessingSerializer.serialize( + tensor, + output_str=True + ) + + gathered_handlers = _gather_ipc_handlers( + serialized_handler, ipc_gather_group, ipc_gather_src, rank + ) + + if rank == ipc_gather_src: + _send_tensor_to_sglang( + url, name, gathered_handlers, tensor.shape, str(tensor.dtype) + ) + tensor_count += 1 + + del tensor, serialized_handler + if rank == ipc_gather_src: + del gathered_handlers + torch.cuda.empty_cache() + + if rank == ipc_gather_src: + completion_payload = {"complete": True} + try: + response = requests.post(url, json=completion_payload, timeout=120) + response.raise_for_status() + except Exception as e: + raise RuntimeError( + f"{worker_name} (rank {rank}): Failed to send completion to {url}: {e}" + ) from e + + if rank == 0: + print( + f"[sglang refit] {worker_name}: Sent {tensor_count} tensors to SGLang server: {base_url}", + flush=True + ) + + except Exception as e: + print( + f"{worker_name} (rank {rank}): Error during HTTP weight streaming: {e}.\n" + f"{traceback.format_exc()}" + ) + raise + + finally: + gc.collect() + torch.cuda.empty_cache() + + +def _setup_ipc_gather_group( + rank: int, + current_device_uuid: str, + sglang_gpu_uuids: list[str], + sglang_url_to_gpu_uuids: 
dict[str, list[str]], +) -> tuple[Optional[dist.ProcessGroup], Optional[int]]: + """Setup Gloo group for gathering IPC handlers from ranks in the same SGLang server. + + Returns: + Tuple of (gather_group, gather_src_rank) or (None, None) if not needed + """ + if not dist.is_initialized(): + return None, None + + world_size = dist.get_world_size() + my_rank = dist.get_rank() + + all_ranks_uuids = [None] * world_size + dist.all_gather_object(all_ranks_uuids, current_device_uuid) + + matching_ranks = [ + r for r, uuid in enumerate(all_ranks_uuids) + if uuid in sglang_gpu_uuids + ] + + if len(matching_ranks) == 0: + return None, None + + matching_ranks = sorted(matching_ranks) + gather_src = matching_ranks[0] + + if my_rank in matching_ranks: + gather_group = dist.new_group(ranks=matching_ranks, backend="gloo") + return gather_group, gather_src + else: + return None, None + + +def _gather_ipc_handlers( + serialized_handler: str, + gather_group: Optional[dist.ProcessGroup], + gather_src: Optional[int], + rank: int, +) -> Optional[list[str]]: + """Gather IPC handlers from all ranks in the group to gather_src rank. 
+ + Key: dist.gather_object automatically arranges by rank order + Result: gathered_handlers[0] = rank0_handler, gathered_handlers[1] = rank1_handler + Index = rank = GPU ID, automatically matched by SGLang tp_rank + + Returns: + List of serialized handlers in rank order (only on gather_src rank), None otherwise + """ + if gather_group is None or gather_src is None: + return None + + if not dist.is_initialized(): + return None + + world_size = dist.get_world_size(gather_group) + + if rank == gather_src: + gathered_handlers = [None] * world_size + else: + gathered_handlers = None + + dist.gather_object( + obj=serialized_handler, + object_gather_list=gathered_handlers, + dst=gather_src, + group=gather_group, + ) + + return gathered_handlers + + +def _send_tensor_to_sglang( + url: str, + tensor_name: str, + gathered_handlers: list[str], + shape: torch.Size, + dtype: str, +) -> None: + """Send gathered IPC handlers to SGLang server via HTTP. + + Key: gathered_handlers are in rank order [rank0, rank1, ...] 
+    SGLang will automatically match: handler = serialized_handlers[tp_rank]
+
+    Args:
+        url: SGLang server URL
+        tensor_name: Name of the tensor
+        gathered_handlers: List of serialized IPC handlers in rank order
+        shape: Tensor shape
+        dtype: Tensor dtype
+    """
+    encoded_handlers = [
+        base64.b64encode(handler.encode('utf-8')).decode('utf-8')
+        for handler in gathered_handlers
+    ]
+
+    payload = {
+        "tensor_name": tensor_name,
+        "shape": list(shape),
+        "dtype": dtype,
+        "serialized_handlers": encoded_handlers,
+    }
+
+    try:
+        response = requests.post(
+            url,
+            json=payload,
+            headers={"Content-Type": "application/json"},
+            timeout=120,
+        )
+        response.raise_for_status()
+    except requests.exceptions.HTTPError as e:
+        error_msg = f"Failed to send tensor '{tensor_name}' to {url}: {e}"
+        try:
+            error_detail = response.text
+            error_msg += f"\nResponse status: {response.status_code}"
+            error_msg += f"\nResponse body: {error_detail[:500]}"
+        except Exception:
+            pass
+        print(f"[sglang refit] {error_msg}", flush=True)
+        raise RuntimeError(error_msg) from e
+    except Exception as e:
+        raise RuntimeError(
+            f"Failed to send tensor '{tensor_name}' to {url}: {e}"
+        ) from e
diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
index 4b8bf56d42..fcf3ba4b6f 100644
--- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
+++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
@@ -1692,6 +1692,50 @@ def dtensor_params_generator():
             worker_name=str(self),
         )
 
+    @torch.no_grad()
+    @wrap_with_nvtx_name("dtensor_policy_worker_v2/stream_weights_via_http")
+    def stream_weights_via_http(
+        self,
+        sglang_url_to_gpu_uuids: dict[str, list[str]],
+    ) -> None:
+        """Stream model weights to SGLang servers via HTTP API. 
+ + Args: + sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses + """ + # Manually move model to cuda for cpu offload case + if self.cpu_offload: + self.model = self.move_to_cuda(self.model) + + from nemo_rl.models.policy.utils import stream_weights_via_http_impl + + # Get current GPU UUID + current_device_uuid = self.report_device_id() + + def dtensor_params_generator(): + """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.""" + for name, tensor in self.model.state_dict().items(): + if isinstance(tensor, DTensor): + # Convert DTensor to full tensor for streaming + full_tensor = tensor.full_tensor() + # Convert to target dtype + yield ( + name, + full_tensor.to(self.dtype, non_blocking=True).contiguous(), + ) + else: + # Convert to target dtype + yield name, tensor.to(self.dtype, non_blocking=True).contiguous() + + # Use the HTTP implementation + stream_weights_via_http_impl( + params_generator=dtensor_params_generator(), + sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, + rank=self.rank, + worker_name=str(self), + current_device_uuid=current_device_uuid, + ) + @torch.no_grad() def broadcast_weights_for_collective( self, kv_scales: Optional[dict[str, float]] = None From 2ce928bbbcddfd75784dc69b289029247fdc7b82 Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Thu, 27 Nov 2025 23:07:14 +0000 Subject: [PATCH 17/29] sglang refit: fix sglang import Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- nemo_rl/algorithms/grpo.py | 1 - .../ray_actor_environment_registry.py | 1 + nemo_rl/distributed/virtual_cluster.py | 4 ++ nemo_rl/models/policy/utils.py | 43 ++++++++----------- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 4a830a269b..477dbc13a9 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1010,7 +1010,6 @@ def refit_policy_generation( if isinstance(policy_generation, SGLangGeneration): # Get 
SGLang server URL to GPU UUIDs mapping sglang_url_to_gpu_uuids = policy_generation.get_sglang_url_to_gpu_uuids() - futures_train = policy.stream_weights_via_http( sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, ) diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py index fb95d73e95..3d6e38abc0 100644 --- a/nemo_rl/distributed/ray_actor_environment_registry.py +++ b/nemo_rl/distributed/ray_actor_environment_registry.py @@ -66,3 +66,4 @@ def get_actor_python_env(actor_class_fqn: str) -> str: "adding a new generation framework or training backend), you'll need to specify the " "appropriate environment. See uv.md for more details." ) + diff --git a/nemo_rl/distributed/virtual_cluster.py b/nemo_rl/distributed/virtual_cluster.py index 4c42054455..979f1e3e77 100644 --- a/nemo_rl/distributed/virtual_cluster.py +++ b/nemo_rl/distributed/virtual_cluster.py @@ -52,6 +52,9 @@ class PY_EXECUTABLES: # Use NeMo-RL direct dependencies and nemo-automodel. AUTOMODEL = f"uv run --locked --extra automodel --directory {git_root}" + # Use NeMo-RL direct dependencies, nemo-automodel, and SGLang. + AUTOMODEL_SGLANG = "uv run --locked --extra automodel --extra sglang" + # Use NeMo-RL direct dependencies and Megatron. MCORE = f"uv run --locked --extra mcore --directory {git_root}" @@ -505,3 +508,4 @@ def __del__(self) -> None: user calls shutdown(). 
""" self.shutdown() + \ No newline at end of file diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index f8b1f9f38a..c3f1ad47cb 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -27,8 +27,6 @@ import zmq from torch.multiprocessing.reductions import rebuild_cuda_tensor -from sglang.srt.utils import MultiprocessingSerializer -from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions from transformers import ( AutoModelForCausalLM, AutoModelForImageTextToText, @@ -521,6 +519,12 @@ def stream_weights_via_http_impl( worker_name: Name of the worker for logging current_device_uuid: UUID of the current training worker's GPU """ + from sglang.srt.utils import MultiprocessingSerializer + try: + from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions + except ImportError: + from sglang.srt.patch_torch import monkey_patch_torch_reductions + monkey_patch_torch_reductions() target_urls = [ @@ -552,12 +556,16 @@ def stream_weights_via_http_impl( tensor_count = 0 try: - for name, tensor in params_generator: + tensor_list = list(params_generator) + total_tensors = len(tensor_list) + + for idx, (name, tensor) in enumerate(tensor_list): torch.cuda.current_stream().synchronize() tensor = tensor.contiguous().cuda() + named_tensors = [(name, tensor)] serialized_handler = MultiprocessingSerializer.serialize( - tensor, + named_tensors, output_str=True ) @@ -566,8 +574,10 @@ def stream_weights_via_http_impl( ) if rank == ipc_gather_src: + is_last = (idx == total_tensors - 1) _send_tensor_to_sglang( - url, name, gathered_handlers, tensor.shape, str(tensor.dtype) + url, name, gathered_handlers, tensor.shape, str(tensor.dtype), + flush_cache=is_last ) tensor_count += 1 @@ -576,16 +586,6 @@ def stream_weights_via_http_impl( del gathered_handlers torch.cuda.empty_cache() - if rank == ipc_gather_src: - completion_payload = {"complete": True} - try: - response = requests.post(url, json=completion_payload, 
timeout=120) - response.raise_for_status() - except Exception as e: - raise RuntimeError( - f"{worker_name} (rank {rank}): Failed to send completion to {url}: {e}" - ) from e - if rank == 0: print( f"[sglang refit] {worker_name}: Sent {tensor_count} tensors to SGLang server: {base_url}", @@ -686,6 +686,7 @@ def _send_tensor_to_sglang( gathered_handlers: list[str], shape: torch.Size, dtype: str, + flush_cache: bool = False, ) -> None: """Send gathered IPC handlers to SGLang server via HTTP. @@ -698,17 +699,11 @@ def _send_tensor_to_sglang( gathered_handlers: List of serialized IPC handlers in rank order shape: Tensor shape dtype: Tensor dtype + flush_cache: Whether to flush cache after this tensor (for last tensor) """ - encoded_handlers = [ - base64.b64encode(handler.encode('utf-8')).decode('utf-8') - for handler in gathered_handlers - ] - payload = { - "tensor_name": tensor_name, - "shape": list(shape), - "dtype": dtype, - "serialized_handlers": encoded_handlers, + "serialized_named_tensors": gathered_handlers, + "flush_cache": flush_cache, } try: From 4aa1e74eff496b3ef0511c5955d2d91098cb69c5 Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Fri, 28 Nov 2025 18:24:17 +0000 Subject: [PATCH 18/29] fix: match fsdp ranks correctly with sglang Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- nemo_rl/models/policy/utils.py | 61 +++++++++++++++++----------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index c3f1ad47cb..214974c87e 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -549,10 +549,10 @@ def stream_weights_via_http_impl( url = f"{base_url}/update_weights_from_tensor" sglang_gpu_uuids = sglang_url_to_gpu_uuids[base_url] - ipc_gather_group, ipc_gather_src = _setup_ipc_gather_group( + ipc_gather_group, ipc_gather_src, matching_ranks = _setup_ipc_gather_group( rank, current_device_uuid, sglang_gpu_uuids, sglang_url_to_gpu_uuids ) - + 
print(f"[sglang refit] {worker_name} (rank {rank}): ipc_gather_group={ipc_gather_group}, ipc_gather_src={ipc_gather_src}, matching_ranks={matching_ranks}") tensor_count = 0 try: @@ -570,7 +570,7 @@ def stream_weights_via_http_impl( ) gathered_handlers = _gather_ipc_handlers( - serialized_handler, ipc_gather_group, ipc_gather_src, rank + serialized_handler, ipc_gather_group, ipc_gather_src, rank, matching_ranks ) if rank == ipc_gather_src: @@ -609,14 +609,17 @@ def _setup_ipc_gather_group( current_device_uuid: str, sglang_gpu_uuids: list[str], sglang_url_to_gpu_uuids: dict[str, list[str]], -) -> tuple[Optional[dist.ProcessGroup], Optional[int]]: - """Setup Gloo group for gathering IPC handlers from ranks in the same SGLang server. +) -> tuple[Optional[dist.ProcessGroup], Optional[int], Optional[list[int]]]: + """Setup gather configuration for IPC handlers. Returns: - Tuple of (gather_group, gather_src_rank) or (None, None) if not needed + Tuple of (gather_group, gather_src_rank, matching_ranks) + - gather_group: None (use default FSDP group) + - gather_src_rank: The rank that will collect and send to SGLang server + - matching_ranks: List of ranks that belong to the same SGLang server """ if not dist.is_initialized(): - return None, None + return None, None, None world_size = dist.get_world_size() my_rank = dist.get_rank() @@ -630,16 +633,12 @@ def _setup_ipc_gather_group( ] if len(matching_ranks) == 0: - return None, None + return None, None, None matching_ranks = sorted(matching_ranks) gather_src = matching_ranks[0] - if my_rank in matching_ranks: - gather_group = dist.new_group(ranks=matching_ranks, backend="gloo") - return gather_group, gather_src - else: - return None, None + return None, gather_src, matching_ranks def _gather_ipc_handlers( @@ -647,37 +646,37 @@ def _gather_ipc_handlers( gather_group: Optional[dist.ProcessGroup], gather_src: Optional[int], rank: int, + matching_ranks: Optional[list[int]] = None, ) -> Optional[list[str]]: - """Gather IPC 
handlers from all ranks in the group to gather_src rank. + """Gather IPC handlers from all ranks in the default FSDP group, then filter by server. - Key: dist.gather_object automatically arranges by rank order - Result: gathered_handlers[0] = rank0_handler, gathered_handlers[1] = rank1_handler - Index = rank = GPU ID, automatically matched by SGLang tp_rank + Args: + serialized_handler: Serialized IPC handler from this rank + gather_group: Process group (None means use default FSDP group) + gather_src: Rank that will collect and filter handlers + rank: Current rank + matching_ranks: List of ranks that belong to the same SGLang server Returns: List of serialized handlers in rank order (only on gather_src rank), None otherwise + The list contains handlers from matching_ranks only, in rank order """ - if gather_group is None or gather_src is None: + if gather_src is None: return None if not dist.is_initialized(): return None - world_size = dist.get_world_size(gather_group) - - if rank == gather_src: - gathered_handlers = [None] * world_size - else: - gathered_handlers = None + world_size = dist.get_world_size() - dist.gather_object( - obj=serialized_handler, - object_gather_list=gathered_handlers, - dst=gather_src, - group=gather_group, - ) + all_handlers = [None] * world_size + dist.all_gather_object(all_handlers, serialized_handler) - return gathered_handlers + if rank == gather_src and matching_ranks is not None: + filtered_handlers = [all_handlers[r] for r in matching_ranks] + return filtered_handlers + else: + return None def _send_tensor_to_sglang( From 9098077e7cb2c2c7d77d5812aac35946505553c6 Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Fri, 28 Nov 2025 21:52:16 +0000 Subject: [PATCH 19/29] flush cache before update begins Signed-off-by: Ryan Signed-off-by: Zhuoran Yin --- nemo_rl/algorithms/grpo.py | 11 ++----- .../generation/sglang/sglang_generation.py | 33 +++++++++++-------- nemo_rl/models/policy/utils.py | 22 ++++++++++--- 
.../workers/dtensor_policy_worker_v2.py | 7 ++-- 4 files changed, 43 insertions(+), 30 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 477dbc13a9..e96f335dd2 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -979,11 +979,8 @@ def refit_policy_generation( timer: Optional Timer used to time the prepare/transfer/update phase kv_scales: Optional dictionary of KV cache scales for FP8 quantization. """ - print("[sglang refit] Starting refit process...", flush=True) if colocated_inference: - print("[sglang refit] Offloading optimizer before refit...", flush=True) policy.offload_before_refit() - print("[sglang refit] Preparing generation interface for weights...", flush=True) policy_generation.prepare_for_generation(tags=["weights"]) # Create a context manager that does nothing when timer is None @@ -1008,8 +1005,9 @@ def refit_policy_generation( ) if isinstance(policy_generation, SGLangGeneration): - # Get SGLang server URL to GPU UUIDs mapping sglang_url_to_gpu_uuids = policy_generation.get_sglang_url_to_gpu_uuids() + # Stream weights via HTTP + flush_success = policy_generation.invalidate_kv_cache() futures_train = policy.stream_weights_via_http( sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, ) @@ -1018,7 +1016,6 @@ def refit_policy_generation( update_success = True else: # Original ZMQ IPC path for vLLM - print("[sglang refit] Using ZMQ IPC path for vLLM", flush=True) futures_train = policy.stream_weights_via_ipc_zmq( buffer_size_bytes=buffer_size_bytes ) @@ -1044,14 +1041,11 @@ def refit_policy_generation( f"This often indicates an issue with {error_tag} or " "a problem within the generation backend (e.g., vLLM worker).\n" ) - print(f"[sglang refit] {error_message}", flush=True) raise RuntimeError(error_message) if colocated_inference: - print("[sglang refit] Offloading after refit and preparing for generation...", flush=True) policy.offload_after_refit() 
policy_generation.prepare_for_generation(tags=["kv_cache"]) - print("[sglang refit] Refit process completed successfully", flush=True) # =============================================================================== @@ -1218,7 +1212,6 @@ def grpo_train( kv_scales=kv_scales_cache if sync_kv_scales else None, ) POLICY_GENERATION_STALE = False - print("[sglang refit] Policy generation refit completed, stale flag cleared", flush=True) else: if colocated_inference: policy.offload_after_refit() # unload optimizer to make space for generation diff --git a/nemo_rl/models/generation/sglang/sglang_generation.py b/nemo_rl/models/generation/sglang/sglang_generation.py index 6f538831d6..ff062a79ff 100644 --- a/nemo_rl/models/generation/sglang/sglang_generation.py +++ b/nemo_rl/models/generation/sglang/sglang_generation.py @@ -353,22 +353,27 @@ def __del__(self) -> None: self.shutdown() def invalidate_kv_cache(self) -> bool: - """Invalidate KV cache after weight updates. + """Invalidate KV cache before weight updates (Megatron-style). - For SGLang, this might need to call a different method or might not be needed - if the server handles it automatically. + This flushes the cache before weight updates to clear stale cache. + Only primary workers (TP rank 0, model owners) will flush their cache. 
+ + Returns: + bool: True if all caches were flushed successfully, False otherwise """ try: - # For SGLang, we can call a method on each worker if it exists - futures = [] - for worker in self.worker_group.workers: - if hasattr(worker, "invalidate_kv_cache"): - futures.append(worker.invalidate_kv_cache.remote()) - - if futures: - results = ray.get(futures) - return all(result for result in results if result is not None) - return True + futures = self.worker_group.run_all_workers_single_data( + "invalidate_kv_cache", + run_rank_0_only_axes=["tensor_parallel"], + ) + results = ray.get(futures) + results = [r for r in results if r is not None] + success = all(result for result in results) if results else True + if success: + print("[sglang refit] All SGLang server caches flushed successfully", flush=True) + else: + print("[sglang refit] WARNING - Some SGLang server caches failed to flush", flush=True) + return success except Exception as e: - print(f"Error invalidating SGLang caches: {e}") + print(f"[sglang refit] Error flushing SGLang caches: {e}", flush=True) return False diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 214974c87e..019f67bb8f 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -524,6 +524,7 @@ def stream_weights_via_http_impl( from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions except ImportError: from sglang.srt.patch_torch import monkey_patch_torch_reductions + print(f"[sglang refit details] entering stream_weights_via_http_impl") monkey_patch_torch_reductions() @@ -559,6 +560,13 @@ def stream_weights_via_http_impl( tensor_list = list(params_generator) total_tensors = len(tensor_list) + if rank == ipc_gather_src: + print( + f"[sglang refit details] {worker_name}: Starting weight update - " + f"Total parameters to update: {total_tensors}", + flush=True + ) + for idx, (name, tensor) in enumerate(tensor_list): torch.cuda.current_stream().synchronize() tensor = 
tensor.contiguous().cuda() @@ -574,10 +582,9 @@ def stream_weights_via_http_impl( ) if rank == ipc_gather_src: - is_last = (idx == total_tensors - 1) _send_tensor_to_sglang( url, name, gathered_handlers, tensor.shape, str(tensor.dtype), - flush_cache=is_last + flush_cache=False ) tensor_count += 1 @@ -586,11 +593,18 @@ def stream_weights_via_http_impl( del gathered_handlers torch.cuda.empty_cache() - if rank == 0: + if rank == ipc_gather_src: print( - f"[sglang refit] {worker_name}: Sent {tensor_count} tensors to SGLang server: {base_url}", + f"[sglang refit details] {worker_name}: Weight update completed - " + f"Successfully updated {tensor_count}/{total_tensors} parameters to SGLang server: {base_url}", flush=True ) + if tensor_count != total_tensors: + print( + f"[sglang refit details] {worker_name}: WARNING - Expected {total_tensors} tensors, " + f"but only sent {tensor_count}", + flush=True + ) except Exception as e: print( diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index fcf3ba4b6f..c6ba034c11 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -1713,8 +1713,10 @@ def stream_weights_via_http( current_device_uuid = self.report_device_id() def dtensor_params_generator(): - """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.""" - for name, tensor in self.model.state_dict().items(): + """Generator that yields (name, tensor) pairs, converting DTensors to local tensors. 
+ """ + state_dict_items = sorted(self.model.state_dict().items(), key=lambda x: x[0]) + for name, tensor in state_dict_items: if isinstance(tensor, DTensor): # Convert DTensor to full tensor for streaming full_tensor = tensor.full_tensor() @@ -1726,7 +1728,6 @@ def dtensor_params_generator(): else: # Convert to target dtype yield name, tensor.to(self.dtype, non_blocking=True).contiguous() - # Use the HTTP implementation stream_weights_via_http_impl( params_generator=dtensor_params_generator(), From 9900a3363328079832bae7dad593052205b4cc25 Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Mon, 1 Dec 2025 20:03:10 +0000 Subject: [PATCH 20/29] Fix SGLang compatibility: add hasattr checks for vLLM-specific methods Signed-off-by: Zhuoran Yin --- nemo_rl/algorithms/grpo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index e96f335dd2..dc5b0ecf3e 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -2124,7 +2124,6 @@ def async_grpo_train( trajectory_collector.resume.remote() print("✅ All setup complete, starting buffer wait...") - # Clear vLLM logger metrics after at start of training if policy_generation is not None and hasattr( policy_generation, "clear_vllm_logger_metrics" From 5cb78e34f91db9410ed3b7672ec43f4f3af4205b Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Mon, 1 Dec 2025 20:56:38 +0000 Subject: [PATCH 21/29] sglang: modified config (increase mem_fration, enable wandb) Signed-off-by: Zhuoran Yin --- examples/configs/grpo_math_1B_sglang.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/configs/grpo_math_1B_sglang.yaml b/examples/configs/grpo_math_1B_sglang.yaml index c9e28f9cff..97d6f38a56 100644 --- a/examples/configs/grpo_math_1B_sglang.yaml +++ b/examples/configs/grpo_math_1B_sglang.yaml @@ -7,7 +7,7 @@ grpo: max_num_steps: 1000000 normalize_rewards: true use_leave_one_out_baseline: true - val_period: 2 + val_period: 10 val_at_start: false 
overlong_filtering: false max_val_samples: 256 @@ -222,7 +222,7 @@ policy: allow_auto_truncate: true enable_memory_saver: false max_running_requests: null - mem_fraction_static: 0.5 + mem_fraction_static: 0.7 skip_server_warmup: true # Skip server warmup to prevent timeout colocated: # true: generation shares training GPUs @@ -264,7 +264,7 @@ env: logger: log_dir: "logs" # Base directory for all logs num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal - wandb_enabled: false + wandb_enabled: true tensorboard_enabled: false mlflow_enabled: false # Disable MLflow logging swanlab_enabled: false # Disable SwanLab logging From 03d9d0c3ab1a5db08fccc3496328dbf81e2721e8 Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Tue, 2 Dec 2025 18:25:46 +0000 Subject: [PATCH 22/29] refactor(grpo): extract init logic for generation backends Signed-off-by: Zhuoran Yin --- nemo_rl/algorithms/grpo.py | 160 ++++++++++++++++++------------------- 1 file changed, 79 insertions(+), 81 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index dc5b0ecf3e..6744dd499b 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -490,9 +490,71 @@ def init_sglang(): pg.finish_generation() return pg, time.perf_counter() - t0 - # Handle backend-specific setup + def initialize_generation_with_policy( + init_generation_fn, + generation_name: str, + init_time_key: str, + colocated_inference: bool, + worker_init_timing_metrics: dict, + ): + """ + Generic function to initialize a generation engine (vLLM or SGLang) along with policy. 
+ + Args: + init_generation_fn: Function that initializes the generation engine (init_vllm or init_sglang) + generation_name: Name of the generation engine ("vLLM" or "SGLang") + init_time_key: Key name for storing initialization time in metrics ("vllm_init_time_s" or "sglang_init_time_s") + colocated_inference: Whether inference is colocated with training + worker_init_timing_metrics: Dictionary to store timing metrics + + Returns: + Tuple of (policy_generation, policy) + """ + # Determine if parallel initialization is possible (non-colocated mode) + use_parallel_init = not colocated_inference + + if use_parallel_init: + # Parallel initialization: Generation engine and Policy can initialize simultaneously + print( + " ⚡ Using parallel worker initialization (non-colocated mode)", + flush=True, + ) + + # Execute both initializations in parallel + parallel_start_time = time.perf_counter() + with ThreadPoolExecutor(max_workers=2) as executor: + generation_future = executor.submit(init_generation_fn) + policy_future = executor.submit(init_policy) + policy_generation, generation_time = generation_future.result() + policy, policy_time = policy_future.result() + parallel_wall_time = time.perf_counter() - parallel_start_time + + # Store timing metrics + worker_init_timing_metrics[init_time_key] = generation_time + worker_init_timing_metrics["policy_init_time_s"] = policy_time + worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time + worker_init_timing_metrics["parallel_init_enabled"] = True + + else: + # Sequential initialization: colocated mode (GPU memory requires generation engine first) + print( + " ⚙️ Using sequential worker initialization (colocated mode)", + flush=True, + ) + + # Initialize generation engine first (clean GPU memory), then policy + policy_generation, generation_time = init_generation_fn() + worker_init_timing_metrics[init_time_key] = generation_time + + policy, policy_time = init_policy() + 
worker_init_timing_metrics["policy_init_time_s"] = policy_time + worker_init_timing_metrics["parallel_init_enabled"] = 0.0 + + return policy_generation, policy + + # Handle generation-specific setup if backend == "megatron": - # Megatron backend: policy_generation is None, only initialize policy + # Megatron generation: policy_generation is None, only initialize policy policy_generation = None print( f" ✓ Using {backend} backend for generation with {policy_config['model_name']}", @@ -503,7 +565,7 @@ def init_sglang(): worker_init_timing_metrics["policy_init_time_s"] = policy_time elif backend == "vllm": - # vLLM backend: setup config, then decide parallel vs sequential init + # vLLM generation: setup config, then initialize with policy generation_config = cast(VllmConfig, generation_config) if generation_config["vllm_cfg"]["precision"] == "fp8": assert loss_config["use_importance_sampling_correction"] is True, ( @@ -531,45 +593,13 @@ def init_sglang(): "hf_config_overrides", {} ) - # Determine if parallel initialization is possible (non-colocated mode) - use_parallel_init = not colocated_inference - - if use_parallel_init: - # Parallel initialization: vLLM and Policy can initialize simultaneously - print( - " ⚡ Using parallel worker initialization (non-colocated mode)", - flush=True, - ) - - # Execute both initializations in parallel - parallel_start_time = time.perf_counter() - with ThreadPoolExecutor(max_workers=2) as executor: - vllm_future = executor.submit(init_vllm) - policy_future = executor.submit(init_policy) - policy_generation, vllm_time = vllm_future.result() - policy, policy_time = policy_future.result() - parallel_wall_time = time.perf_counter() - parallel_start_time - - # Store timing metrics - worker_init_timing_metrics["vllm_init_time_s"] = vllm_time - worker_init_timing_metrics["policy_init_time_s"] = policy_time - worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time - worker_init_timing_metrics["parallel_init_enabled"] = True - 
- else: - # Sequential initialization: colocated mode (GPU memory requires vLLM first) - print( - " ⚙️ Using sequential worker initialization (colocated mode)", - flush=True, - ) - - # Initialize vLLM first (clean GPU memory), then policy - policy_generation, vllm_time = init_vllm() - worker_init_timing_metrics["vllm_init_time_s"] = vllm_time - - policy, policy_time = init_policy() - worker_init_timing_metrics["policy_init_time_s"] = policy_time - worker_init_timing_metrics["parallel_init_enabled"] = 0.0 + policy_generation, policy = initialize_generation_with_policy( + init_generation_fn=init_vllm, + generation_name="vLLM", + init_time_key="vllm_init_time_s", + colocated_inference=colocated_inference, + worker_init_timing_metrics=worker_init_timing_metrics, + ) print( f" ✓ Using vLLM backend for generation with {policy_config['model_name']}", @@ -582,45 +612,13 @@ def init_sglang(): if "model_path" not in generation_config or not generation_config.get("model_path"): generation_config["model_path"] = policy_config["model_name"] - # Determine if parallel initialization is possible (non-colocated mode) - use_parallel_init = not colocated_inference - - if use_parallel_init: - # Parallel initialization: SGLang and Policy can initialize simultaneously - print( - " ⚡ Using parallel worker initialization (non-colocated mode)", - flush=True, - ) - - # Execute both initializations in parallel - parallel_start_time = time.perf_counter() - with ThreadPoolExecutor(max_workers=2) as executor: - sglang_future = executor.submit(init_sglang) - policy_future = executor.submit(init_policy) - policy_generation, sglang_time = sglang_future.result() - policy, policy_time = policy_future.result() - parallel_wall_time = time.perf_counter() - parallel_start_time - - # Store timing metrics - worker_init_timing_metrics["sglang_init_time_s"] = sglang_time - worker_init_timing_metrics["policy_init_time_s"] = policy_time - worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time 
- worker_init_timing_metrics["parallel_init_enabled"] = True - - else: - # Sequential initialization: colocated mode (GPU memory requires SGLang first) - print( - " ⚙️ Using sequential worker initialization (colocated mode)", - flush=True, - ) - - # Initialize SGLang first (clean GPU memory), then policy - policy_generation, sglang_time = init_sglang() - worker_init_timing_metrics["sglang_init_time_s"] = sglang_time - - policy, policy_time = init_policy() - worker_init_timing_metrics["policy_init_time_s"] = policy_time - worker_init_timing_metrics["parallel_init_enabled"] = 0.0 + policy_generation, policy = initialize_generation_with_policy( + init_generation_fn=init_sglang, + generation_name="SGLang", + init_time_key="sglang_init_time_s", + colocated_inference=colocated_inference, + worker_init_timing_metrics=worker_init_timing_metrics, + ) print( f" ✓ Using SGLang backend for generation with {policy_config['model_name']}", From 7ca9776b31882b1325ca3250c79e2b515e8c0a6a Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Tue, 2 Dec 2025 18:53:54 +0000 Subject: [PATCH 23/29] refactor SGLangConfig - Convert SGLangConfig from regular class to TypedDict inheriting GenerationConfig - Align structure with VllmConfig pattern for consistency - Mark all fields as NotRequired for backward compatibility - Add sglang_kwargs field for additional ServerArgs parameters - Add type casting in grpo.py for type safety This maintains backward compatibility while aligning with the existing generation config structure pattern. 
Signed-off-by: Zhuoran Yin --- nemo_rl/algorithms/grpo.py | 1 + nemo_rl/models/generation/sglang/config.py | 127 +++++++++--------- .../generation/sglang/sglang_generation.py | 7 +- 3 files changed, 68 insertions(+), 67 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 6744dd499b..4d3c7ee8bc 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -607,6 +607,7 @@ def initialize_generation_with_policy( ) elif backend == "sglang": + generation_config = cast(SGLangConfig, generation_config) # Set model_name and model_path generation_config["model_name"] = policy_config["model_name"] if "model_path" not in generation_config or not generation_config.get("model_path"): diff --git a/nemo_rl/models/generation/sglang/config.py b/nemo_rl/models/generation/sglang/config.py index 12e99ad82b..9c82c7583b 100644 --- a/nemo_rl/models/generation/sglang/config.py +++ b/nemo_rl/models/generation/sglang/config.py @@ -17,75 +17,76 @@ from nemo_rl.models.generation.interfaces import GenerationConfig -class SGLangConfig(): - """Configuration for SGLang runtime. Refer to: - https://github.com/sgl-project/sglang for detailed documentation. +class SGLangConfig(GenerationConfig): + """Configuration for SGLang runtime. + + Most fields below map directly to SGLang's ServerArgs (see: + https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py). 
""" - model_path: str = "" - random_seed: int = 1 - skip_tokenizer_init: bool = False - disable_cuda_graph: bool = False - disable_radix_cache: bool = True - disable_cuda_graph_padding: bool = False - enable_nccl_nvls: bool = False - disable_outlines_disk_cache: bool = False - disable_custom_all_reduce: bool = False - disable_overlap_schedule: bool = False - enable_mixed_chunk: bool = False - enable_dp_attention: bool = False - enable_ep_moe: bool = False - enable_torch_compile: bool = False - torch_compile_max_bs: int = 32 - cuda_graph_max_bs: int | None = None - cuda_graph_bs: list[int] | None = None - torchao_config: str = "" - enable_nan_detection: bool = False - enable_p2p_check: bool = False - triton_attention_reduce_in_fp32: bool = False - triton_attention_num_kv_splits: int = 8 - num_continuous_decode_steps: int = 1 - enable_memory_saver: bool = False - allow_auto_truncate: bool = False - attention_backend: str | None = "fa3" - enable_multimodal: bool = False - sampling_backend: str | None = None - context_length: int | None = 32768 - mem_fraction_static: float | None = 0.9 - max_running_requests: int | None = None - # NOTE: chunked_prefill_size is by default 8192 on GPUs with 80GB mem in SGLang, - # but we disable it to avoid precision issues - chunked_prefill_size: int | None = -1 - max_prefill_tokens: int = 32768 - schedule_policy: str = "lpm" - schedule_conservativeness: float = 1.0 - cpu_offload_gb: int = 0 - dtype: str = "bfloat16" - kv_cache_dtype: str = "auto" - dp_size: int = 1 # only used for dp attention - ep_size: int = 1 + model_path: NotRequired[str] + gpus_per_server: NotRequired[int] + random_seed: NotRequired[int] + skip_tokenizer_init: NotRequired[bool] + disable_cuda_graph: NotRequired[bool] + disable_radix_cache: NotRequired[bool] + disable_cuda_graph_padding: NotRequired[bool] + enable_nccl_nvls: NotRequired[bool] + disable_outlines_disk_cache: NotRequired[bool] + disable_custom_all_reduce: NotRequired[bool] + disable_overlap_schedule: 
NotRequired[bool] + enable_mixed_chunk: NotRequired[bool] + enable_dp_attention: NotRequired[bool] + enable_ep_moe: NotRequired[bool] + enable_torch_compile: NotRequired[bool] + torch_compile_max_bs: NotRequired[int] + cuda_graph_max_bs: NotRequired[int | None] + cuda_graph_bs: NotRequired[list[int] | None] + torchao_config: NotRequired[str] + enable_nan_detection: NotRequired[bool] + enable_p2p_check: NotRequired[bool] + triton_attention_reduce_in_fp32: NotRequired[bool] + triton_attention_num_kv_splits: NotRequired[int] + num_continuous_decode_steps: NotRequired[int] + enable_memory_saver: NotRequired[bool] + allow_auto_truncate: NotRequired[bool] + attention_backend: NotRequired[str | None] + enable_multimodal: NotRequired[bool] + sampling_backend: NotRequired[str | None] + context_length: NotRequired[int | None] + mem_fraction_static: NotRequired[float | None] + max_running_requests: NotRequired[int | None] + chunked_prefill_size: NotRequired[int | None] + max_prefill_tokens: NotRequired[int] + schedule_policy: NotRequired[str] + schedule_conservativeness: NotRequired[float] + cpu_offload_gb: NotRequired[int] + dtype: NotRequired[str] + kv_cache_dtype: NotRequired[str] + dp_size: NotRequired[int] # only used for dp attention + ep_size: NotRequired[int] # lora - enable_lora: bool | None = None - max_lora_rank: int | None = None - lora_target_modules: list[str] | None = None - lora_paths: list[str] | None = None - max_loaded_loras: int = 1 - max_loras_per_batch: int = 1 - lora_backend: str = "triton" + enable_lora: NotRequired[bool | None] + max_lora_rank: NotRequired[int | None] + lora_target_modules: NotRequired[list[str] | None] + lora_paths: NotRequired[list[str] | None] + max_loaded_loras: NotRequired[int] + max_loras_per_batch: NotRequired[int] + lora_backend: NotRequired[str] # logging - log_level: str = "warning" - log_level_http: str | None = "warning" - log_requests: bool = False - log_requests_level: int = 0 - show_time_cost: bool = False - 
enable_metrics: bool = True # Exports Prometheus-like metrics + log_level: NotRequired[str] + log_level_http: NotRequired[str | None] + log_requests: NotRequired[bool] + log_requests_level: NotRequired[int] + show_time_cost: NotRequired[bool] + enable_metrics: NotRequired[bool] # Exports Prometheus-like metrics # The interval (in decoding iterations) to log throughput # and update prometheus metrics - decode_log_interval: int = 1 + decode_log_interval: NotRequired[int] # Extra loader arguments - # NOTE: These arguments will be parsed into a dict json-string - # and passed as `model_loader_extra_config` to SGLang. - enable_multithread_load: bool = False - enable_fast_load: bool = False + enable_multithread_load: NotRequired[bool] + enable_fast_load: NotRequired[bool] + # Additional ServerArgs fields can be passed via this generic kwargs dict + sglang_kwargs: NotRequired[dict[str, Any]] \ No newline at end of file diff --git a/nemo_rl/models/generation/sglang/sglang_generation.py b/nemo_rl/models/generation/sglang/sglang_generation.py index ff062a79ff..47065aa557 100644 --- a/nemo_rl/models/generation/sglang/sglang_generation.py +++ b/nemo_rl/models/generation/sglang/sglang_generation.py @@ -64,12 +64,11 @@ def __init__( # Store config self.cfg = config - # Get number of GPUs per server from config - # For SGLang, this is typically the tensor parallel size - # TODO: Add proper config field, hardcoded to 4 for now gpus_per_server = self.cfg.get("gpus_per_server", None) if gpus_per_server is None: - gpus_per_server = 4 + raise ValueError( + "gpus_per_server must be set in SGLangConfig. 
" + ) # Calculate number of servers based on available resources total_gpus = cluster.world_size() From f1c26dd182adf93be3a82e9bafedf427569f995e Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Thu, 4 Dec 2025 19:24:02 +0000 Subject: [PATCH 24/29] refactor: generalize logger metrics for all generation backends Signed-off-by: Zhuoran Yin --- nemo_rl/algorithms/grpo.py | 65 +++++++---------- nemo_rl/algorithms/utils.py | 69 ++++++++++--------- nemo_rl/models/generation/interfaces.py | 19 +++++ .../models/generation/vllm/vllm_generation.py | 8 +++ 4 files changed, 86 insertions(+), 75 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 4d3c7ee8bc..5b6518589b 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1218,11 +1218,10 @@ def grpo_train( dynamic_sampling_num_gen_batches += 1 with timer.time("generation"): - # Clear vLLM logger metrics for each generation step - if policy_generation is not None and hasattr( - policy_generation, "clear_vllm_logger_metrics" - ): - policy_generation.clear_vllm_logger_metrics() + + # Clear logger metrics for each generation step + if policy_generation is not None: + policy_generation.clear_logger_metrics() # Use NeMo-Gym rollouts if enabled. We cascade NeMo-Gym first since NeMo-Gym requires async rollouts. 
if _should_use_nemo_gym(master_config): generation_config = master_config["policy"]["generation"] @@ -1272,16 +1271,10 @@ def grpo_train( greedy=False, ) policy_generation.finish_generation() - # Collect vLLM logger metrics for performance reporting after each generation step - # inflight batch sizes and num pending samples are collected from each vLLM worker - if policy_generation is not None and hasattr( - policy_generation, "get_vllm_logger_metrics" - ): - vllm_logger_metrics = ( - policy_generation.get_vllm_logger_metrics() - ) - else: - vllm_logger_metrics = {} + # Collect generation logger metrics for performance reporting after each generation step + # inflight batch sizes and num pending samples are collected from each worker + if policy_generation is not None: + generation_logger_metrics = policy_generation.get_logger_metrics() repeated_batch = scale_rewards( repeated_batch, master_config["grpo"]["reward_scaling"] @@ -1530,7 +1523,7 @@ def grpo_train( metrics[k] = np.sum(v).item() metrics.update(rollout_metrics) - metrics["vllm_logger_metrics"] = vllm_logger_metrics + metrics["generation_logger_metrics"] = generation_logger_metrics total_valid_tokens += metrics["global_valid_toks"] ## Checkpointing @@ -1653,7 +1646,7 @@ def grpo_train( "enable_vllm_metrics_logger", False ) and master_config.get("logger", {}).get("wandb_enabled", False): log_generation_metrics_to_wandb( - vllm_logger_metrics, + generation_logger_metrics, total_steps + 1, master_config["policy"]["generation"]["vllm_cfg"][ "vllm_metrics_logger_interval" @@ -2123,11 +2116,9 @@ def async_grpo_train( trajectory_collector.resume.remote() print("✅ All setup complete, starting buffer wait...") - # Clear vLLM logger metrics after at start of training - if policy_generation is not None and hasattr( - policy_generation, "clear_vllm_logger_metrics" - ): - policy_generation.clear_vllm_logger_metrics() + # Clear logger metrics at start of training + if policy_generation is not None: + 
policy_generation.clear_logger_metrics() # Wait for initial buffer fill print( @@ -2367,23 +2358,17 @@ def async_grpo_train( train_results = policy.train(train_data, loss_fn) print("🔄 Synchronizing policy weights to trajectory collector…") - vllm_logger_metrics = None + generation_logger_metrics = None if NEED_REFIT: # Measure pending-generation wait as exposed_generation time print("🔄 Coordinating with trajectory collector before refit...") with timer.time("exposed_generation"): ray.get(trajectory_collector.prepare_for_refit.remote()) - # Collect vLLM logger metrics for performance reporting - # inflight batch sizes and num pending samples are collected from each vLLM worker - if policy_generation is not None and hasattr( - policy_generation, "get_vllm_logger_metrics" - ): - vllm_logger_metrics = ( - policy_generation.get_vllm_logger_metrics() - ) - else: - vllm_logger_metrics = {} + # Collect generation logger metrics for performance reporting + # inflight batch sizes and num pending samples are collected from each worker + if policy_generation is not None: + generation_logger_metrics = policy_generation.get_logger_metrics() # Only the actual refit/weight transfer should be counted as weight_sync print("🔄 Performing policy generation refit...") @@ -2398,11 +2383,9 @@ def async_grpo_train( trajectory_collector.set_weight_version.remote(weight_version) trajectory_collector.resume_after_refit.remote() - # Clear vLLM logger metrics after each refit (weight sync), starting a new logging cycle - if policy_generation is not None and hasattr( - policy_generation, "clear_vllm_logger_metrics" - ): - policy_generation.clear_vllm_logger_metrics() + # Clear logger metrics after each refit (weight sync), starting a new logging cycle + if policy_generation is not None: + policy_generation.clear_logger_metrics() # Validation val_metrics, validation_timings = None, None @@ -2495,8 +2478,8 @@ def async_grpo_train( else: metrics[k] = np.sum(v).item() metrics.update(rollout_metrics) 
- if vllm_logger_metrics is not None: - metrics["vllm_logger_metrics"] = vllm_logger_metrics + if generation_logger_metrics is not None: + metrics["generation_logger_metrics"] = generation_logger_metrics total_valid_tokens += metrics["global_valid_toks"] # Checkpointing (same as sync version) @@ -2603,7 +2586,7 @@ def async_grpo_train( "enable_vllm_metrics_logger", False ) and master_config.get("logger", {}).get("wandb_enabled", False): log_generation_metrics_to_wandb( - vllm_logger_metrics, + generation_logger_metrics, step + 1, master_config["policy"]["generation"]["vllm_cfg"][ "vllm_metrics_logger_interval" diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 17c69e479a..428252e1f2 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -521,46 +521,47 @@ def visualize_per_worker_timeline( "generation" ].get("vllm_cfg", {}).get("async_engine", False) if is_vllm_metrics_logger_enabled: - vllm_logger_metrics = metrics["vllm_logger_metrics"] - # vllm_logger_me trics: dict[str (metric_name), dict[int (dp_idx), list[int] (metric_values)]] + vllm_logger_metrics = metrics.get("generation_logger_metrics", {}) + # vllm_logger_metrics: dict[str (metric_name), dict[int (dp_idx), list[int] (metric_values)]] # metric_name: "inflight_batch_sizes" or "num_pending_samples" - assert "inflight_batch_sizes" in vllm_logger_metrics, ( - "inflight_batch_sizes not found in vllm_logger_metrics" - ) - assert "num_pending_samples" in vllm_logger_metrics, ( - "num_pending_samples not found in vllm_logger_metrics" - ) - assert isinstance(vllm_logger_metrics["inflight_batch_sizes"], dict), ( - "inflight_batch_sizes must be a dictionary" - ) - assert isinstance(vllm_logger_metrics["num_pending_samples"], dict), ( - "num_pending_samples must be a dictionary" - ) - - vllm_metrics_logger_interval = master_config["policy"]["generation"][ - "vllm_cfg" - ]["vllm_metrics_logger_interval"] - print(" • vLLM Logger Metrics:") - # Visualize the inflight 
batch sizes timeline - if len(vllm_logger_metrics["inflight_batch_sizes"].values()) > 0: - visualize_per_worker_timeline( - vllm_logger_metrics["inflight_batch_sizes"], - "Inflight Batch Sizes", - vllm_metrics_logger_interval, + if vllm_logger_metrics: + assert "inflight_batch_sizes" in vllm_logger_metrics, ( + "inflight_batch_sizes not found in vllm_logger_metrics" ) - if len(vllm_logger_metrics["num_pending_samples"].values()) > 0: - max_num_pending_samples = max( - (max(v) if v else 0) - for v in vllm_logger_metrics["num_pending_samples"].values() + assert "num_pending_samples" in vllm_logger_metrics, ( + "num_pending_samples not found in vllm_logger_metrics" ) - # If there is at least one pending sample, visualize the timeline - if max_num_pending_samples > 0: + assert isinstance(vllm_logger_metrics["inflight_batch_sizes"], dict), ( + "inflight_batch_sizes must be a dictionary" + ) + assert isinstance(vllm_logger_metrics["num_pending_samples"], dict), ( + "num_pending_samples must be a dictionary" + ) + + vllm_metrics_logger_interval = master_config["policy"]["generation"][ + "vllm_cfg" + ]["vllm_metrics_logger_interval"] + print(" • vLLM Logger Metrics:") + # Visualize the inflight batch sizes timeline + if len(vllm_logger_metrics["inflight_batch_sizes"].values()) > 0: visualize_per_worker_timeline( - vllm_logger_metrics["num_pending_samples"], - "Num Pending Samples", - None, + vllm_logger_metrics["inflight_batch_sizes"], + "Inflight Batch Sizes", + vllm_metrics_logger_interval, ) + if len(vllm_logger_metrics["num_pending_samples"].values()) > 0: + max_num_pending_samples = max( + (max(v) if v else 0) + for v in vllm_logger_metrics["num_pending_samples"].values() + ) + # If there is at least one pending sample, visualize the timeline + if max_num_pending_samples > 0: + visualize_per_worker_timeline( + vllm_logger_metrics["num_pending_samples"], + "Num Pending Samples", + None, + ) # ===================================================== # Throughputs diff 
--git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index d134027bdf..7ec3c14576 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -257,3 +257,22 @@ def update_weights_from_collective(self) -> list[ray.ObjectRef]: # (e.g., vLLM prefix/KV caches) after weight updates. def invalidate_kv_cache(self) -> bool: return False + + def clear_logger_metrics(self) -> None: + """Clear logger metrics for performance reporting. + + This is an optional method that backends can implement to clear + telemetry metrics. Default implementation does nothing. + """ + pass + + def get_logger_metrics(self) -> dict[str, Any]: + """Get logger metrics for performance reporting. + + This is an optional method that backends can implement to collect + telemetry metrics. Default implementation returns empty dict. + + Returns: + Dictionary of metrics. Format may vary by backend. + """ + return {} diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 93540ebe82..1366ce28c5 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -876,6 +876,14 @@ def clear_vllm_logger_metrics(self) -> None: ) ray.get(futures) + def clear_logger_metrics(self) -> None: + """Clear logger metrics for performance reporting.""" + self.clear_vllm_logger_metrics() + + def get_logger_metrics(self) -> dict[str, Any]: + """Get logger metrics for performance reporting.""" + return self.get_vllm_logger_metrics() + def __del__(self) -> None: """Shuts down the worker groups when the object is deleted or is garbage collected. 
From 255dcc675e58ab82cbf2fb3edbfb100a263132d7 Mon Sep 17 00:00:00 2001
From: Zhuoran Yin
Date: Thu, 4 Dec 2025 20:40:38 +0000
Subject: [PATCH 25/29] refactor sglang config loading to make it consistent with other backends

Signed-off-by: Zhuoran Yin
---
 examples/configs/grpo_math_1B_sglang.yaml | 21 +++++++++---------
 nemo_rl/algorithms/grpo.py | 8 +++----
 nemo_rl/models/generation/sglang/config.py | 14 ++++++++----
 .../generation/sglang/sglang_generation.py | 6 +++--
 .../models/generation/sglang/sglang_worker.py | 22 +++++++++----------
 5 files changed, 40 insertions(+), 31 deletions(-)

diff --git a/examples/configs/grpo_math_1B_sglang.yaml b/examples/configs/grpo_math_1B_sglang.yaml
index 97d6f38a56..e31310e202 100644
--- a/examples/configs/grpo_math_1B_sglang.yaml
+++ b/examples/configs/grpo_math_1B_sglang.yaml
@@ -214,16 +214,17 @@ policy:
 top_k: null
 stop_token_ids: null
 stop_strings: null
- # SGLang specific configuration
- model_path: ${policy.model_name} # Model path for SGLang server
- gpus_per_server: 1 # Number of GPUs per SGLang server (tensor parallel size)
- dtype: ${policy.precision} # Model precision (bfloat16, float16, etc.) 
- context_length: 512 # Maximum context length - allow_auto_truncate: true - enable_memory_saver: false - max_running_requests: null - mem_fraction_static: 0.7 - skip_server_warmup: true # Skip server warmup to prevent timeout + sglang_cfg: + # SGLang specific configuration + model_path: ${policy.model_name} + gpus_per_server: 1 + dtype: ${policy.precision} + context_length: 512 # Maximum context length + allow_auto_truncate: true + enable_memory_saver: false + max_running_requests: null + mem_fraction_static: 0.7 + skip_server_warmup: true colocated: # true: generation shares training GPUs # false: uses dedicated generation resources diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 5b6518589b..7b54936c33 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -608,10 +608,10 @@ def initialize_generation_with_policy( elif backend == "sglang": generation_config = cast(SGLangConfig, generation_config) - # Set model_name and model_path - generation_config["model_name"] = policy_config["model_name"] - if "model_path" not in generation_config or not generation_config.get("model_path"): - generation_config["model_path"] = policy_config["model_name"] + + # Set model_path if not already set + if "model_path" not in generation_config["sglang_cfg"]: + generation_config["sglang_cfg"]["model_path"] = policy_config["model_name"] policy_generation, policy = initialize_generation_with_policy( init_generation_fn=init_sglang, diff --git a/nemo_rl/models/generation/sglang/config.py b/nemo_rl/models/generation/sglang/config.py index 9c82c7583b..a401243a6d 100644 --- a/nemo_rl/models/generation/sglang/config.py +++ b/nemo_rl/models/generation/sglang/config.py @@ -17,13 +17,12 @@ from nemo_rl.models.generation.interfaces import GenerationConfig -class SGLangConfig(GenerationConfig): - """Configuration for SGLang runtime. +class SglangSpecificArgs(TypedDict): + """SGLang-specific configuration arguments. 
Most fields below map directly to SGLang's ServerArgs (see: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py). """ - model_path: NotRequired[str] gpus_per_server: NotRequired[int] random_seed: NotRequired[int] @@ -64,6 +63,7 @@ class SGLangConfig(GenerationConfig): dtype: NotRequired[str] kv_cache_dtype: NotRequired[str] dp_size: NotRequired[int] # only used for dp attention + pp_size: NotRequired[int] # pipeline parallel size ep_size: NotRequired[int] # lora enable_lora: NotRequired[bool | None] @@ -86,7 +86,13 @@ class SGLangConfig(GenerationConfig): # Extra loader arguments enable_multithread_load: NotRequired[bool] enable_fast_load: NotRequired[bool] - # Additional ServerArgs fields can be passed via this generic kwargs dict + # Server warmup + skip_server_warmup: NotRequired[bool] + + +class SGLangConfig(GenerationConfig): + """Configuration for SGLang runtime.""" + sglang_cfg: SglangSpecificArgs sglang_kwargs: NotRequired[dict[str, Any]] \ No newline at end of file diff --git a/nemo_rl/models/generation/sglang/sglang_generation.py b/nemo_rl/models/generation/sglang/sglang_generation.py index 47065aa557..b63acedfdf 100644 --- a/nemo_rl/models/generation/sglang/sglang_generation.py +++ b/nemo_rl/models/generation/sglang/sglang_generation.py @@ -63,11 +63,12 @@ def __init__( """ # Store config self.cfg = config + self.sglang_cfg = config["sglang_cfg"] - gpus_per_server = self.cfg.get("gpus_per_server", None) + gpus_per_server = self.sglang_cfg.get("gpus_per_server", None) if gpus_per_server is None: raise ValueError( - "gpus_per_server must be set in SGLangConfig. " + "gpus_per_server must be set in SGLangConfig.sglang_cfg." 
) # Calculate number of servers based on available resources @@ -102,6 +103,7 @@ def __init__( # Initialize placement groups # For SGLang, we use PACK strategy to keep bundles together + # colocated is always at top level, not in sglang_cfg strategy = None if self.cfg.get("colocated", {}).get("enabled", False) else "PACK" cluster._init_placement_groups( strategy=strategy, diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index 1aba513047..2be5399880 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -133,6 +133,7 @@ def __init__( self.cfg = config self.is_model_owner = bundle_indices is not None self.global_rank = int(os.environ.get("RANK", "0")) + self.sglang_cfg = config["sglang_cfg"] # Create a dedicated event loop thread for async operations # there will be issues if we use the event loop in the main thread @@ -168,35 +169,34 @@ def __init__( # Build SGLang server arguments kwargs = { - "model_path": self.cfg.get("model_path", ""), + "model_path": self.sglang_cfg.get("model_path", ""), "trust_remote_code": True, - "random_seed": seed if seed is not None else self.cfg.get("random_seed", 1), + "random_seed": seed if seed is not None else self.sglang_cfg.get("random_seed", 1), # Memory settings - "enable_memory_saver": self.cfg.get("enable_memory_saver", False), + "enable_memory_saver": self.sglang_cfg.get("enable_memory_saver", False), "gpu_id_step": 1, "base_gpu_id": base_gpu_id, # Parallel settings "tp_size": tp_size, - "dp_size": self.cfg.get("dp_size", 1), - "pp_size": self.cfg.get("pp_size", 1), - "ep_size": self.cfg.get("ep_size", 1), + "dp_size": self.sglang_cfg.get("dp_size", 1), + "pp_size": self.sglang_cfg.get("pp_size", 1), + "ep_size": self.sglang_cfg.get("ep_size", 1), # Always skip warmup to prevent warmup timeout - "skip_server_warmup": True, + "skip_server_warmup": self.sglang_cfg.get("skip_server_warmup", True), 
# Server network settings - listen on all interfaces, use the free port we found "host": "0.0.0.0", "port": free_port, "torchao_config": "", } - # Add other config fields if they exist for key in [ "dtype", "kv_cache_dtype", "context_length", "max_running_requests", "chunked_prefill_size", "max_prefill_tokens", "schedule_policy", "schedule_conservativeness", "cpu_offload_gb", "log_level", "mem_fraction_static", "allow_auto_truncate", ]: - if key in self.cfg: - kwargs[key] = self.cfg[key] + if key in self.sglang_cfg: + kwargs[key] = self.sglang_cfg[key] server_args = ServerArgs(**kwargs) # Save server_args and base_url for use in generate() and _make_request() @@ -555,7 +555,7 @@ def generate( if batch_size == 0: raise ValueError("Empty batch received") - context_length = self.cfg.get("context_length", None) + context_length = self.sglang_cfg.get("context_length", None) # Create async tasks for all samples tasks = [] From ee01f913ea7313e735d488fef13456e1bd47baef Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Sat, 6 Dec 2025 21:31:36 +0000 Subject: [PATCH 26/29] resolved ai comments Signed-off-by: Zhuoran Yin --- nemo_rl/algorithms/grpo.py | 6 +++- .../models/generation/sglang/sglang_worker.py | 34 +++++++++++++------ nemo_rl/models/generation/sglang/utils.py | 2 +- 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 7b54936c33..73b49c45a0 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1006,7 +1006,11 @@ def refit_policy_generation( if isinstance(policy_generation, SGLangGeneration): sglang_url_to_gpu_uuids = policy_generation.get_sglang_url_to_gpu_uuids() # Stream weights via HTTP - flush_success = policy_generation.invalidate_kv_cache() + flush_success = policy_generation.invalidate_kv_cache() + if not flush_success: + print( + "SGLang KV cache invalidation failed before weight update. 
" + ) futures_train = policy.stream_weights_via_http( sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, ) diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index 2be5399880..56bdc704b7 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -490,10 +490,17 @@ def _launch_server_process(self, server_args: ServerArgs) -> multiprocessing.Pro "Content-Type": "application/json; charset=utf-8", } + max_wait_time = 300 # 5 minutes timeout + start_time = time.time() with requests.Session() as session: while True: + if time.time() - start_time > max_wait_time: + kill_process_tree(p.pid) + raise TimeoutError( + f"[SGLang Server] Rank {self.global_rank} Server failed to start within {max_wait_time}s" + ) try: - response = session.get(f"{self.base_url}/health_generate", headers=headers) + response = session.get(f"{self.base_url}/health_generate", headers=headers, timeout=10) if response.status_code == 200: print(f"[SGLang Server] Rank {self.global_rank} Server is ready at {self.base_url}") break @@ -501,7 +508,7 @@ def _launch_server_process(self, server_args: ServerArgs) -> multiprocessing.Pro pass if not p.is_alive(): - raise Exception(f"[SGLang Server] Rank {self.global_rank} Server process terminated unexpectedly.") + raise RuntimeError(f"[SGLang Server] Rank {self.global_rank} Server process terminated unexpectedly.") time.sleep(2) return p @@ -668,14 +675,13 @@ def shutdown(self) -> bool: Returns: bool: True if shutdown was successful, False otherwise """ - if hasattr(self, "async_loop_thread"): - try: - self.async_loop_thread.shutdown() - print(f"[SGLang Worker] Rank {self.global_rank} Async loop thread shut down.") - except Exception as e: - print(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}") - if not self.is_model_owner: + if hasattr(self, "async_loop_thread"): + try: + 
self.async_loop_thread.shutdown() + print(f"[SGLang Worker] Rank {self.global_rank} Async loop thread shut down.") + except Exception as e: + print(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}") return True try: @@ -691,6 +697,14 @@ async def close_session(): except Exception as e: print(f"[SGLang Worker] Rank {self.global_rank} Error closing aiohttp session: {e}") + # Shutdown async loop thread after session cleanup + if hasattr(self, "async_loop_thread"): + try: + self.async_loop_thread.shutdown() + print(f"[SGLang Worker] Rank {self.global_rank} Async loop thread shut down.") + except Exception as e: + print(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}") + if not hasattr(self, "server_process") or self.server_process is None: return True @@ -729,6 +743,6 @@ def _make_request(self, endpoint: str, payload: Optional[dict] = None): headers = { "Content-Type": "application/json; charset=utf-8", } - response = requests.post(url, json=payload or {}, headers=headers) + response = requests.post(url, json=payload or {}, headers=headers, timeout=60) response.raise_for_status() return response.json() \ No newline at end of file diff --git a/nemo_rl/models/generation/sglang/utils.py b/nemo_rl/models/generation/sglang/utils.py index 3b56037891..469d3bb79e 100644 --- a/nemo_rl/models/generation/sglang/utils.py +++ b/nemo_rl/models/generation/sglang/utils.py @@ -58,6 +58,6 @@ def shutdown(self): if self.loop.is_running(): self.loop.call_soon_threadsafe(self.loop.stop) self._thread.join(timeout=2.0) - if self.loop.is_running(): + if not self.loop.is_closed(): self.loop.close() From e25e57300d530cc0acee0236ee7b254ae15de66e Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Sat, 6 Dec 2025 21:42:57 +0000 Subject: [PATCH 27/29] changed print to using loging Signed-off-by: Zhuoran Yin --- .../generation/sglang/sglang_generation.py | 13 +++-- .../models/generation/sglang/sglang_worker.py | 55 
+++++++++---------- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/nemo_rl/models/generation/sglang/sglang_generation.py b/nemo_rl/models/generation/sglang/sglang_generation.py index b63acedfdf..dbd1f3afb0 100644 --- a/nemo_rl/models/generation/sglang/sglang_generation.py +++ b/nemo_rl/models/generation/sglang/sglang_generation.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import logging import os from collections import defaultdict from typing import ( @@ -43,6 +44,8 @@ TOP_K_THRESHOLD = 8000 # Allow top_k >= 8000 (effectively no filtering) TOP_P_THRESHOLD = 0.99 # Allow top_p >= 0.99 (close to 1.0) +logger = logging.getLogger(__name__) + class SGLangGeneration(GenerationInterface): def __init__( @@ -82,7 +85,7 @@ def __init__( ) if total_gpus % gpus_per_server != 0: - print( + logger.warning( f"[WARNING] Total GPUs ({total_gpus}) is not divisible by GPUs per server ({gpus_per_server}). " f"Will use {num_servers} servers, leaving {total_gpus % gpus_per_server} GPUs unused." 
) @@ -341,7 +344,7 @@ def shutdown(self) -> bool: # Use the worker group's shutdown method with the worker's cleanup method return self.worker_group.shutdown(cleanup_method="shutdown") except Exception as e: - print(f"Error during SGLang policy shutdown: {e}") + logger.error(f"Error during SGLang policy shutdown: {e}") return False def __del__(self) -> None: @@ -371,10 +374,10 @@ def invalidate_kv_cache(self) -> bool: results = [r for r in results if r is not None] success = all(result for result in results) if results else True if success: - print("[sglang refit] All SGLang server caches flushed successfully", flush=True) + logger.info("[sglang refit] All SGLang server caches flushed successfully") else: - print("[sglang refit] WARNING - Some SGLang server caches failed to flush", flush=True) + logger.warning("[sglang refit] WARNING - Some SGLang server caches failed to flush") return success except Exception as e: - print(f"[sglang refit] Error flushing SGLang caches: {e}", flush=True) + logger.error(f"[sglang refit] Error flushing SGLang caches: {e}") return False diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index 56bdc704b7..4cf15fc0e7 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -14,6 +14,7 @@ import copy import gc +import logging import os import sys from typing import Any, Optional, cast @@ -43,6 +44,8 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_process_tree +logger = logging.getLogger(__name__) + @ray.remote( runtime_env={**get_nsight_config_if_pattern_matches("sglang_generation_worker")} @@ -157,7 +160,7 @@ def __init__( global_cvd = os.environ.get("CUDA_VISIBLE_DEVICES", None) - print( + logger.info( f"[SGLang Server] Rank {self.global_rank}: " f"base_gpu_id={base_gpu_id}, tp_size={tp_size}, " f"bundle_indices={bundle_indices}, global_cvd={global_cvd}" @@ -203,7 +206,7 @@ 
def __init__( self.server_args = server_args self.base_url = f"http://{node_ip}:{free_port}" - print(f"[SGLang Worker] Rank {self.global_rank} Starting on {self.base_url}, CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}, base_gpu_id: {base_gpu_id}") + logger.info(f"[SGLang Worker] Rank {self.global_rank} Starting on {self.base_url}, CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}, base_gpu_id: {base_gpu_id}") self.session = None self.connector = None @@ -236,38 +239,34 @@ def invalidate_kv_cache(self) -> bool: response = requests.get(url, timeout=10) if response.status_code == 200: if attempt > 0: - print( + logger.info( f"[SGLang Worker] Rank {self.global_rank} Cache flushed successfully " - f"(attempt {attempt + 1})", - flush=True + f"(attempt {attempt + 1})" ) return True except requests.exceptions.ConnectionError: # Server might not be ready yet - only retry for first few attempts if attempt >= connection_retry_limit: - print( + logger.warning( f"[SGLang Worker] Rank {self.global_rank} Connection failed after " - f"{connection_retry_limit} attempts", - flush=True + f"{connection_retry_limit} attempts" ) return False except Exception as e: # For other errors, log and retry (except on last attempt) if attempt == max_attempts - 1: - print( + logger.error( f"[SGLang Worker] Rank {self.global_rank} Failed to flush cache after " - f"{max_attempts} attempts: {e}", - flush=True + f"{max_attempts} attempts: {e}" ) return False time.sleep(1) # All attempts exhausted without success - print( + logger.error( f"[SGLang Worker] Rank {self.global_rank} Timeout: Cache flush failed after " - f"{max_attempts} attempts. Server may have pending requests.", - flush=True + f"{max_attempts} attempts. Server may have pending requests." 
) return False @@ -357,7 +356,7 @@ def _build_sampling_params( if base_max_tokens > max_allowed_new_tokens: final_max_tokens = max_allowed_new_tokens if sample_index == 0: - print( + logger.warning( f"[SGLang Worker] Rank {self.global_rank} Warning: " f"Sample {sample_index} input length ({input_len}) + max_new_tokens ({base_max_tokens}) " f"would exceed context_length ({context_length}). " @@ -433,7 +432,7 @@ async def _generate_single_sample( response.raise_for_status() result = await response.json() except Exception as e: - print(f"[SGLang Worker] Rank {self.global_rank} Request failed for input_len={len(input_ids)}: {e}") + logger.error(f"[SGLang Worker] Rank {self.global_rank} Request failed for input_len={len(input_ids)}: {e}") raise # Extract generated tokens and logprobs @@ -475,7 +474,7 @@ async def wrap(idx, coro): results[idx] = value count += 1 if count % 50 == 0 or count == len(tasks): - print(f"[SGLang Worker] Rank {self.global_rank} Completed {count}/{len(tasks)} tasks") + logger.debug(f"[SGLang Worker] Rank {self.global_rank} Completed {count}/{len(tasks)} tasks") return results @@ -502,7 +501,7 @@ def _launch_server_process(self, server_args: ServerArgs) -> multiprocessing.Pro try: response = session.get(f"{self.base_url}/health_generate", headers=headers, timeout=10) if response.status_code == 200: - print(f"[SGLang Server] Rank {self.global_rank} Server is ready at {self.base_url}") + logger.info(f"[SGLang Server] Rank {self.global_rank} Server is ready at {self.base_url}") break except requests.RequestException: pass @@ -557,7 +556,7 @@ def generate( # Original input length with padding padded_input_length = input_ids.size(1) - print(f"[SGLang Worker] Rank {self.global_rank} batch_size: {batch_size}, padded_input_length: {padded_input_length}") + logger.debug(f"[SGLang Worker] Rank {self.global_rank} batch_size: {batch_size}, padded_input_length: {padded_input_length}") if batch_size == 0: raise ValueError("Empty batch received") @@ -651,7 
+650,7 @@ def generate( logprobs = torch.stack(logprobs_list) generation_lengths = torch.tensor(generation_lengths_list, dtype=torch.long) unpadded_sequence_lengths = torch.tensor(unpadded_sequence_lengths_list, dtype=torch.long) - print(f"[SGLang Worker] Rank {self.global_rank} Generated {total_generated_tokens} tokens across {batch_size} samples (avg: {avg_generation_length:.1f} tokens/sample)") + logger.debug(f"[SGLang Worker] Rank {self.global_rank} Generated {total_generated_tokens} tokens across {batch_size} samples (avg: {avg_generation_length:.1f} tokens/sample)") return BatchedDataDict[GenerationOutputSpec]( { "output_ids": output_ids, @@ -679,9 +678,9 @@ def shutdown(self) -> bool: if hasattr(self, "async_loop_thread"): try: self.async_loop_thread.shutdown() - print(f"[SGLang Worker] Rank {self.global_rank} Async loop thread shut down.") + logger.info(f"[SGLang Worker] Rank {self.global_rank} Async loop thread shut down.") except Exception as e: - print(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}") + logger.error(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}") return True try: @@ -693,22 +692,22 @@ async def close_session(): await self.connector.close() self.async_loop_thread.run(close_session()) - print(f"[SGLang Worker] Rank {self.global_rank} aiohttp session closed.") + logger.info(f"[SGLang Worker] Rank {self.global_rank} aiohttp session closed.") except Exception as e: - print(f"[SGLang Worker] Rank {self.global_rank} Error closing aiohttp session: {e}") + logger.error(f"[SGLang Worker] Rank {self.global_rank} Error closing aiohttp session: {e}") # Shutdown async loop thread after session cleanup if hasattr(self, "async_loop_thread"): try: self.async_loop_thread.shutdown() - print(f"[SGLang Worker] Rank {self.global_rank} Async loop thread shut down.") + logger.info(f"[SGLang Worker] Rank {self.global_rank} Async loop thread shut down.") except Exception as e: - 
print(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}") + logger.error(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}") if not hasattr(self, "server_process") or self.server_process is None: return True - print( + logger.info( f"[SGLang Worker] Rank {self.global_rank} Shutting down server at {self.base_url}..." ) @@ -723,7 +722,7 @@ async def close_session(): return True except Exception as e: - print( + logger.error( f"[SGLang Worker] Rank {self.global_rank} Error during shutdown: {e}" ) return False From 85d6a92b0ac9bd4272d47a5d3c93ad515dca1ee9 Mon Sep 17 00:00:00 2001 From: Night <32424487+PrinsYin@users.noreply.github.com> Date: Wed, 17 Dec 2025 12:28:28 -0500 Subject: [PATCH 28/29] Update nemo_rl/models/generation/sglang/sglang_worker.py Co-authored-by: Terry Kong Signed-off-by: Night <32424487+PrinsYin@users.noreply.github.com> --- nemo_rl/models/generation/sglang/sglang_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index 4cf15fc0e7..2ecff6a6e2 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -131,7 +131,7 @@ def __init__( The length of this list determines tp_size (number of GPUs per server). Only needed for the first worker in each server group (model owner). 
fraction_of_gpus: Fraction of GPUs to use for this worker - seed: Random seed for initialization + seed: Random seed for initialization, if None, then defaults to the config's seed """ self.cfg = config self.is_model_owner = bundle_indices is not None From ede624f7e29a2c46040b25e4bce4632e5d3db371 Mon Sep 17 00:00:00 2001 From: PrinsYin Date: Wed, 17 Dec 2025 17:38:14 +0000 Subject: [PATCH 29/29] fix comments about config defaults --- examples/configs/grpo_math_1B_sglang.yaml | 271 +----------------- nemo_rl/distributed/virtual_cluster.py | 2 +- .../generation/sglang/sglang_generation.py | 2 +- .../models/generation/sglang/sglang_worker.py | 16 +- 4 files changed, 15 insertions(+), 276 deletions(-) diff --git a/examples/configs/grpo_math_1B_sglang.yaml b/examples/configs/grpo_math_1B_sglang.yaml index e31310e202..17b30f3ef5 100644 --- a/examples/configs/grpo_math_1B_sglang.yaml +++ b/examples/configs/grpo_math_1B_sglang.yaml @@ -1,219 +1,11 @@ -# GRPO Algorithm Configuration +defaults: grpo_math_1B.yaml + grpo: - num_prompts_per_step: 32 - num_generations_per_prompt: 16 - max_rollout_turns: 1 - max_num_epochs: 1 - max_num_steps: 1000000 - normalize_rewards: true - use_leave_one_out_baseline: true - val_period: 10 - val_at_start: false - overlong_filtering: false - max_val_samples: 256 val_batch_size: 128 - seed: 42 - use_dynamic_sampling: false - dynamic_sampling_max_gen_batches: 10 - batch_multiplier: 1 - reward_shaping: - enabled: false - overlong_buffer_length: 128 - overlong_buffer_penalty: 1 - max_response_length: ${policy.max_total_sequence_length} - reward_scaling: - enabled: false - source_min: 0.0 - source_max: 1.0 - target_min: 0.0 - target_max: 1.0 - - async_grpo: - enabled: false # Set to true to enable async training mode - # Max age (in training steps) for trajectories used in training - max_trajectory_age_steps: 1 - in_flight_weight_updates: false # Set to true to enable in-flight weight updates - recompute_kv_cache_after_weight_updates: false # Set 
to true to recompute kv cache after in-flight-weight-updates - -loss_fn: - reference_policy_kl_penalty: 0.01 - # Can be set to k1, k2, k3 - # For more details, see http://joschu.net/blog/kl-approx.html - reference_policy_kl_type: "k3" - kl_input_clamp_value: 20.0 - kl_output_clamp_value: 10.0 - ratio_clip_min: 0.2 - ratio_clip_max: 0.2 - ratio_clip_c: null - # (default off) loss formulation improvements (docs/guides/grpo.md#loss) - use_on_policy_kl_approximation: false - # Async GRPO requires importance sampling correction enabled - # Set to true when async_grpo.enabled is true - use_importance_sampling_correction: false - truncated_importance_sampling_ratio: null - sequence_level_importance_ratios: false - token_level_loss: true - -checkpointing: - enabled: true - checkpoint_dir: "results/grpo" - metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name - higher_is_better: true - keep_top_k: 3 - save_period: 10 - checkpoint_must_save_by: null - model_save_format: "safetensors" - save_consolidated: false policy: - model_name: "Qwen/Qwen2.5-1.5B" - tokenizer: - name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default - chat_template_kwargs: null # can be used to pass kwargs to the chat template, e.g., enable_thinking=true - hf_config_overrides: {} - train_global_batch_size: 512 - train_micro_batch_size: 4 - generation_batch_size: 32 # Only used when generating using HF backend - logprob_batch_size: 4 - max_total_sequence_length: 512 - precision: "bfloat16" - logprob_chunk_size: null - offload_optimizer_for_logprob: false # Only useful for non-colocated generation since colocated generation will always offload optimizer to cuda before refit - - dtensor_cfg: - _v2: true - enabled: true - cpu_offload: False - sequence_parallel: false - activation_checkpointing: false - tensor_parallel_size: 1 - context_parallel_size: 1 - custom_parallel_plan: null - - megatron_cfg: - enabled: false - 
empty_unused_memory_level: 1 # 1 is the minimum recommendation for RL since we almost always need to offload before beginning generation. Setting to 0 is faster, but you are more likely to run out of GPU memory. - activation_checkpointing: false - converter_type: "Qwen2ForCausalLM" - tensor_model_parallel_size: 1 - expert_tensor_parallel_size: 1 - expert_model_parallel_size: 1 - pipeline_model_parallel_size: 1 - num_layers_in_first_pipeline_stage: null - num_layers_in_last_pipeline_stage: null - context_parallel_size: 1 - pipeline_dtype: ${policy.precision} - sequence_parallel: false - freeze_moe_router: true - moe_router_dtype: "fp64" - moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo - moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo - moe_permute_fusion: false - #gives ~20% training perf speedup with sequence packing - apply_rope_fusion: True - # gives ~25% training perf speedup with sequence packing and apply_rope_fusion - bias_activation_fusion: True - defer_fp32_logits: False - - optimizer: - optimizer: "adam" - lr: 5.0e-6 - min_lr: 5.0e-7 - weight_decay: 0.01 - bf16: true - fp16: false - params_dtype: "float32" - - #adam - adam_beta1: 0.9 - adam_beta2: 0.999 - adam_eps: 1e-8 - - #sgd - sgd_momentum: 0.9 - - #distributed optimizer - use_distributed_optimizer: true - use_precision_aware_optimizer: true - - clip_grad: ${policy.max_grad_norm} - - # optimizer cpu offload - optimizer_cpu_offload: false - optimizer_offload_fraction: 0.0 - - scheduler: - start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} - end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} - weight_decay_incr_style: "constant" - lr_decay_style: "constant" - lr_decay_iters: 1000 - lr_warmup_iters: 13 - lr_warmup_init: 5.0e-7 - - distributed_data_parallel_config: - grad_reduce_in_fp32: false - overlap_grad_reduce: true - overlap_param_gather: true - use_custom_fsdp: false - 
data_parallel_sharding_strategy: "optim_grads_params" - - fp8_cfg: null - - env_vars: null - - # See docs/design-docs/sequence-packing-and-dynamic-batching.md - # for more details on dynamic batching and sequence packing. - dynamic_batching: - enabled: False - train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} - logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} - sequence_length_round: 64 - - sequence_packing: - enabled: True - train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} - logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} - algorithm: "modified_first_fit_decreasing" - sequence_length_round: 64 - - # makes the training sequence length divisible by the tensor parallel size - # this is useful for sequence parallel training - make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} - max_grad_norm: 1.0 - - optimizer: - name: "torch.optim.AdamW" - kwargs: - lr: 5.0e-6 - weight_decay: 0.01 - betas: [0.9, 0.999] - eps: 1e-8 - # when using Dtensor, we need to set foreach - # and fused to False - foreach: False - fused: False - - scheduler: - - name: "torch.optim.lr_scheduler.LinearLR" - kwargs: - start_factor: 0.1 - end_factor: 1.0 - total_iters: 50 - - name: "torch.optim.lr_scheduler.ConstantLR" - kwargs: - factor: 1.0 - total_iters: 10000000000 - - milestones: [50] - generation: backend: "sglang" - max_new_tokens: ${policy.max_total_sequence_length} - temperature: 1.0 - top_p: 1.0 - top_k: null - stop_token_ids: null - stop_strings: null sglang_cfg: # SGLang specific configuration model_path: ${policy.model_name} @@ -222,65 +14,12 @@ policy: context_length: 512 # Maximum context length allow_auto_truncate: true enable_memory_saver: false + dp_size: 1 + pp_size: 1 + ep_size: 1 max_running_requests: null mem_fraction_static: 0.7 skip_server_warmup: true - colocated: - # true: 
generation shares training GPUs - # false: uses dedicated generation resources - enabled: true - # only relevant when enabled is false - resources: - gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 - num_nodes: null # Decides number of nodes to be dedicated to generation - -data: - max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len - prompt_file: "examples/prompts/cot.txt" - system_prompt_file: null - shuffle: true - num_workers: 1 - - dataset_name: "OpenMathInstruct-2" - # You can use custom response datasets for training and validation. For example: - # data: - # dataset_name: ResponseDataset - # train_data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace) - # val_data_path: - # input_key: , default is "input" - # output_key: , default is "output" - # train_split: , default is None # used for HuggingFace datasets - # val_split: , default is None # used for HuggingFace datasets - # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#datasets for more details. 
- -env: - math: - num_workers: 8 - math_verify_impl: "hf_math_verify" - ## unused in this config but needed for DAPO recipe - dapo: - num_workers: 8 - math_verify_impl: "dapo_math_verify" logger: - log_dir: "logs" # Base directory for all logs - num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal wandb_enabled: true - tensorboard_enabled: false - mlflow_enabled: false # Disable MLflow logging - swanlab_enabled: false # Disable SwanLab logging - monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard - wandb: - project: "grpo-dev" - name: "grpo-dev-logger" - tensorboard: {} - mlflow: - experiment_name: "grpo-dev" - run_name: "grpo-dev-logger" - gpu_monitoring: - collection_interval: 10 # How often to collect GPU usage metrics (in seconds) - flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) - -cluster: - gpus_per_node: 1 - num_nodes: 1 diff --git a/nemo_rl/distributed/virtual_cluster.py b/nemo_rl/distributed/virtual_cluster.py index 979f1e3e77..53662a37a6 100644 --- a/nemo_rl/distributed/virtual_cluster.py +++ b/nemo_rl/distributed/virtual_cluster.py @@ -53,7 +53,7 @@ class PY_EXECUTABLES: AUTOMODEL = f"uv run --locked --extra automodel --directory {git_root}" # Use NeMo-RL direct dependencies, nemo-automodel, and SGLang. - AUTOMODEL_SGLANG = "uv run --locked --extra automodel --extra sglang" + AUTOMODEL_SGLANG = f"uv run --locked --extra automodel --extra sglang --directory {git_root}" # Use NeMo-RL direct dependencies and Megatron. 
MCORE = f"uv run --locked --extra mcore --directory {git_root}" diff --git a/nemo_rl/models/generation/sglang/sglang_generation.py b/nemo_rl/models/generation/sglang/sglang_generation.py index dbd1f3afb0..99d2bd8bb7 100644 --- a/nemo_rl/models/generation/sglang/sglang_generation.py +++ b/nemo_rl/models/generation/sglang/sglang_generation.py @@ -107,7 +107,7 @@ def __init__( # Initialize placement groups # For SGLang, we use PACK strategy to keep bundles together # colocated is always at top level, not in sglang_cfg - strategy = None if self.cfg.get("colocated", {}).get("enabled", False) else "PACK" + strategy = None if self.cfg["colocated"]["enabled"] else "PACK" cluster._init_placement_groups( strategy=strategy, use_unified_pg=False, # SGLang servers don't need cross-node model parallelism diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py index 2ecff6a6e2..64b188e55d 100644 --- a/nemo_rl/models/generation/sglang/sglang_worker.py +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -172,18 +172,18 @@ def __init__( # Build SGLang server arguments kwargs = { - "model_path": self.sglang_cfg.get("model_path", ""), + "model_path": self.sglang_cfg["model_path"], "trust_remote_code": True, "random_seed": seed if seed is not None else self.sglang_cfg.get("random_seed", 1), # Memory settings - "enable_memory_saver": self.sglang_cfg.get("enable_memory_saver", False), + "enable_memory_saver": self.sglang_cfg["enable_memory_saver"], "gpu_id_step": 1, "base_gpu_id": base_gpu_id, # Parallel settings "tp_size": tp_size, - "dp_size": self.sglang_cfg.get("dp_size", 1), - "pp_size": self.sglang_cfg.get("pp_size", 1), - "ep_size": self.sglang_cfg.get("ep_size", 1), + "dp_size": self.sglang_cfg["dp_size"], + "pp_size": self.sglang_cfg["pp_size"], + "ep_size": self.sglang_cfg["ep_size"], # Always skip warmup to prevent warmup timeout "skip_server_warmup": self.sglang_cfg.get("skip_server_warmup", True), # Server 
network settings - listen on all interfaces, use the free port we found @@ -343,10 +343,10 @@ def _build_sampling_params( """ top_k_cfg = self.cfg.get("top_k") top_k_val = 1 if greedy else (top_k_cfg if top_k_cfg is not None else -1) - temperature = 0.0 if greedy else self.cfg.get("temperature", 1.0) + temperature = 0.0 if greedy else self.cfg["temperature"] base_max_tokens = ( - max_new_tokens if max_new_tokens is not None else self.cfg.get("max_new_tokens", 512) + max_new_tokens if max_new_tokens is not None else self.cfg["max_new_tokens"] ) # TODO: check if this is needed @@ -548,7 +548,7 @@ def generate( batch_stop_strings = data.get("stop_strings", [None] * len(input_lengths)) stop_strings = self._merge_stop_strings(batch_stop_strings) batch_size = len(input_lengths) - pad_token_id = self.cfg.get("_pad_token_id", 0) + pad_token_id = self.cfg["_pad_token_id"] # Verify inputs have correct padding verify_right_padding(data, pad_value=pad_token_id)