diff --git a/examples/configs/grpo_math_1B_sglang.yaml b/examples/configs/grpo_math_1B_sglang.yaml new file mode 100644 index 0000000000..17b30f3ef5 --- /dev/null +++ b/examples/configs/grpo_math_1B_sglang.yaml @@ -0,0 +1,25 @@ +defaults: grpo_math_1B.yaml + +grpo: + val_batch_size: 128 + +policy: + generation: + backend: "sglang" + sglang_cfg: + # SGLang specific configuration + model_path: ${policy.model_name} + gpus_per_server: 1 + dtype: ${policy.precision} + context_length: 512 # Maximum context length + allow_auto_truncate: true + enable_memory_saver: false + dp_size: 1 + pp_size: 1 + ep_size: 1 + max_running_requests: null + mem_fraction_static: 0.7 + skip_server_warmup: true + +logger: + wandb_enabled: true diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 8ab62d00fb..f0c4002071 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -62,6 +62,7 @@ ) from nemo_rl.models.generation.interfaces import GenerationInterface from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration +from nemo_rl.models.generation.sglang import SGLangConfig, SGLangGeneration from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface from nemo_rl.models.policy.lm_policy import Policy @@ -482,9 +483,78 @@ def init_vllm(): pg.finish_generation() return pg, time.perf_counter() - t0 - # Handle backend-specific setup + def init_sglang(): + """Initialize SGLang generation workers.""" + t0 = time.perf_counter() + pg = SGLangGeneration(cluster=inference_cluster, config=generation_config) + pg.finish_generation() + return pg, time.perf_counter() - t0 + + def initialize_generation_with_policy( + init_generation_fn, + generation_name: str, + init_time_key: str, + colocated_inference: bool, + worker_init_timing_metrics: dict, + ): + """ + Generic function to initialize a generation engine (vLLM or SGLang) along with policy. 
+ + Args: + init_generation_fn: Function that initializes the generation engine (init_vllm or init_sglang) + generation_name: Name of the generation engine ("vLLM" or "SGLang") + init_time_key: Key name for storing initialization time in metrics ("vllm_init_time_s" or "sglang_init_time_s") + colocated_inference: Whether inference is colocated with training + worker_init_timing_metrics: Dictionary to store timing metrics + + Returns: + Tuple of (policy_generation, policy) + """ + # Determine if parallel initialization is possible (non-colocated mode) + use_parallel_init = not colocated_inference + + if use_parallel_init: + # Parallel initialization: Generation engine and Policy can initialize simultaneously + print( + " ⚡ Using parallel worker initialization (non-colocated mode)", + flush=True, + ) + + # Execute both initializations in parallel + parallel_start_time = time.perf_counter() + with ThreadPoolExecutor(max_workers=2) as executor: + generation_future = executor.submit(init_generation_fn) + policy_future = executor.submit(init_policy) + policy_generation, generation_time = generation_future.result() + policy, policy_time = policy_future.result() + parallel_wall_time = time.perf_counter() - parallel_start_time + + # Store timing metrics + worker_init_timing_metrics[init_time_key] = generation_time + worker_init_timing_metrics["policy_init_time_s"] = policy_time + worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time + worker_init_timing_metrics["parallel_init_enabled"] = True + + else: + # Sequential initialization: colocated mode (GPU memory requires generation engine first) + print( + " ⚙️ Using sequential worker initialization (colocated mode)", + flush=True, + ) + + # Initialize generation engine first (clean GPU memory), then policy + policy_generation, generation_time = init_generation_fn() + worker_init_timing_metrics[init_time_key] = generation_time + + policy, policy_time = init_policy() + 
worker_init_timing_metrics["policy_init_time_s"] = policy_time + worker_init_timing_metrics["parallel_init_enabled"] = 0.0 + + return policy_generation, policy + + # Handle generation-specific setup if backend == "megatron": - # Megatron backend: policy_generation is None, only initialize policy + # Megatron generation: policy_generation is None, only initialize policy policy_generation = None print( f" ✓ Using {backend} backend for generation with {policy_config['model_name']}", @@ -495,7 +565,7 @@ def init_vllm(): worker_init_timing_metrics["policy_init_time_s"] = policy_time elif backend == "vllm": - # vLLM backend: setup config, then decide parallel vs sequential init + # vLLM generation: setup config, then initialize with policy generation_config = cast(VllmConfig, generation_config) if generation_config["vllm_cfg"]["precision"] == "fp8": assert loss_config["use_importance_sampling_correction"] is True, ( @@ -523,48 +593,36 @@ def init_vllm(): "hf_config_overrides", {} ) - # Determine if parallel initialization is possible (non-colocated mode) - use_parallel_init = not colocated_inference - - if use_parallel_init: - # Parallel initialization: vLLM and Policy can initialize simultaneously - print( - " ⚡ Using parallel worker initialization (non-colocated mode)", - flush=True, - ) - - # Execute both initializations in parallel - parallel_start_time = time.perf_counter() - with ThreadPoolExecutor(max_workers=2) as executor: - vllm_future = executor.submit(init_vllm) - policy_future = executor.submit(init_policy) - policy_generation, vllm_time = vllm_future.result() - policy, policy_time = policy_future.result() - parallel_wall_time = time.perf_counter() - parallel_start_time - - # Store timing metrics - worker_init_timing_metrics["vllm_init_time_s"] = vllm_time - worker_init_timing_metrics["policy_init_time_s"] = policy_time - worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time - worker_init_timing_metrics["parallel_init_enabled"] = True - - 
else: - # Sequential initialization: colocated mode (GPU memory requires vLLM first) - print( - " ⚙️ Using sequential worker initialization (colocated mode)", - flush=True, - ) + policy_generation, policy = initialize_generation_with_policy( + init_generation_fn=init_vllm, + generation_name="vLLM", + init_time_key="vllm_init_time_s", + colocated_inference=colocated_inference, + worker_init_timing_metrics=worker_init_timing_metrics, + ) - # Initialize vLLM first (clean GPU memory), then policy - policy_generation, vllm_time = init_vllm() - worker_init_timing_metrics["vllm_init_time_s"] = vllm_time + print( + f" ✓ Using vLLM backend for generation with {policy_config['model_name']}", + flush=True, + ) - policy, policy_time = init_policy() - worker_init_timing_metrics["policy_init_time_s"] = policy_time - worker_init_timing_metrics["parallel_init_enabled"] = 0.0 + elif backend == "sglang": + generation_config = cast(SGLangConfig, generation_config) + + # Set model_path if not already set + if "model_path" not in generation_config["sglang_cfg"]: + generation_config["sglang_cfg"]["model_path"] = policy_config["model_name"] + + policy_generation, policy = initialize_generation_with_policy( + init_generation_fn=init_sglang, + generation_name="SGLang", + init_time_key="sglang_init_time_s", + colocated_inference=colocated_inference, + worker_init_timing_metrics=worker_init_timing_metrics, + ) print( - f" ✓ Using vLLM backend for generation with {policy_config['model_name']}", + f" ✓ Using SGLang backend for generation with {policy_config['model_name']}", flush=True, ) @@ -945,14 +1003,30 @@ def refit_policy_generation( policy.get_free_memory_bytes() * float(memory_ratio) ) - futures_train = policy.stream_weights_via_ipc_zmq( - buffer_size_bytes=buffer_size_bytes, kv_scales=kv_scales - ) - futures_inference = policy_generation.update_weights_via_ipc_zmq() - # wait for all futures to complete - ray.get(futures_train) - results = ray.get(futures_inference) - update_success = 
all(result for result in results if result is not None) + if isinstance(policy_generation, SGLangGeneration): + sglang_url_to_gpu_uuids = policy_generation.get_sglang_url_to_gpu_uuids() + # Stream weights via HTTP + flush_success = policy_generation.invalidate_kv_cache() + if not flush_success: + print( + "SGLang KV cache invalidation failed before weight update. " + ) + futures_train = policy.stream_weights_via_http( + sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, + ) + # Wait for all workers to complete + ray.get(futures_train) + update_success = True + else: + # Original ZMQ IPC path for vLLM + futures_train = policy.stream_weights_via_ipc_zmq( + buffer_size_bytes=buffer_size_bytes + ) + futures_inference = policy_generation.update_weights_via_ipc_zmq() + # wait for all futures to complete + ray.get(futures_train) + results = ray.get(futures_inference) + update_success = all(result for result in results if result is not None) else: # update weights through nccl futures_train = policy.broadcast_weights_for_collective(kv_scales=kv_scales) @@ -1148,11 +1222,10 @@ def grpo_train( dynamic_sampling_num_gen_batches += 1 with timer.time("generation"): - # Clear vLLM logger metrics for each generation step - if policy_generation is not None and hasattr( - policy_generation, "clear_vllm_logger_metrics" - ): - policy_generation.clear_vllm_logger_metrics() + + # Clear logger metrics for each generation step + if policy_generation is not None: + policy_generation.clear_logger_metrics() # Use NeMo-Gym rollouts if enabled. We cascade NeMo-Gym first since NeMo-Gym requires async rollouts. 
if _should_use_nemo_gym(master_config): generation_config = master_config["policy"]["generation"] @@ -1202,16 +1275,10 @@ def grpo_train( greedy=False, ) policy_generation.finish_generation() - # Collect vLLM logger metrics for performance reporting after each generation step - # inflight batch sizes and num pending samples are collected from each vLLM worker - if policy_generation is not None and hasattr( - policy_generation, "get_vllm_logger_metrics" - ): - vllm_logger_metrics = ( - policy_generation.get_vllm_logger_metrics() - ) - else: - vllm_logger_metrics = {} + # Collect generation logger metrics for performance reporting after each generation step + # inflight batch sizes and num pending samples are collected from each worker + if policy_generation is not None: + generation_logger_metrics = policy_generation.get_logger_metrics() repeated_batch = scale_rewards( repeated_batch, master_config["grpo"]["reward_scaling"] @@ -1460,7 +1527,7 @@ def grpo_train( metrics[k] = np.sum(v).item() metrics.update(rollout_metrics) - metrics["vllm_logger_metrics"] = vllm_logger_metrics + metrics["generation_logger_metrics"] = generation_logger_metrics total_valid_tokens += metrics["global_valid_toks"] ## Checkpointing @@ -1583,7 +1650,7 @@ def grpo_train( "enable_vllm_metrics_logger", False ) and master_config.get("logger", {}).get("wandb_enabled", False): log_generation_metrics_to_wandb( - vllm_logger_metrics, + generation_logger_metrics, total_steps + 1, master_config["policy"]["generation"]["vllm_cfg"][ "vllm_metrics_logger_interval" @@ -2051,12 +2118,9 @@ def async_grpo_train( trajectory_collector.resume.remote() print("✅ All setup complete, starting buffer wait...") - - # Clear vLLM logger metrics after at start of training - if policy_generation is not None and hasattr( - policy_generation, "clear_vllm_logger_metrics" - ): - policy_generation.clear_vllm_logger_metrics() + # Clear logger metrics at start of training + if policy_generation is not None: + 
policy_generation.clear_logger_metrics() # Wait for initial buffer fill print( @@ -2296,23 +2360,17 @@ def async_grpo_train( train_results = policy.train(train_data, loss_fn) print("🔄 Synchronizing policy weights to trajectory collector…") - vllm_logger_metrics = None + generation_logger_metrics = None if NEED_REFIT: # Measure pending-generation wait as exposed_generation time print("🔄 Coordinating with trajectory collector before refit...") with timer.time("exposed_generation"): ray.get(trajectory_collector.prepare_for_refit.remote()) - # Collect vLLM logger metrics for performance reporting - # inflight batch sizes and num pending samples are collected from each vLLM worker - if policy_generation is not None and hasattr( - policy_generation, "get_vllm_logger_metrics" - ): - vllm_logger_metrics = ( - policy_generation.get_vllm_logger_metrics() - ) - else: - vllm_logger_metrics = {} + # Collect generation logger metrics for performance reporting + # inflight batch sizes and num pending samples are collected from each worker + if policy_generation is not None: + generation_logger_metrics = policy_generation.get_logger_metrics() # Only the actual refit/weight transfer should be counted as weight_sync print("🔄 Performing policy generation refit...") @@ -2327,11 +2385,9 @@ def async_grpo_train( trajectory_collector.set_weight_version.remote(weight_version) trajectory_collector.resume_after_refit.remote() - # Clear vLLM logger metrics after each refit (weight sync), starting a new logging cycle - if policy_generation is not None and hasattr( - policy_generation, "clear_vllm_logger_metrics" - ): - policy_generation.clear_vllm_logger_metrics() + # Clear logger metrics after each refit (weight sync), starting a new logging cycle + if policy_generation is not None: + policy_generation.clear_logger_metrics() # Validation val_metrics, validation_timings = None, None @@ -2424,8 +2480,8 @@ def async_grpo_train( else: metrics[k] = np.sum(v).item() metrics.update(rollout_metrics) 
- if vllm_logger_metrics is not None: - metrics["vllm_logger_metrics"] = vllm_logger_metrics + if generation_logger_metrics is not None: + metrics["generation_logger_metrics"] = generation_logger_metrics total_valid_tokens += metrics["global_valid_toks"] # Checkpointing (same as sync version) @@ -2532,7 +2588,7 @@ def async_grpo_train( "enable_vllm_metrics_logger", False ) and master_config.get("logger", {}).get("wandb_enabled", False): log_generation_metrics_to_wandb( - vllm_logger_metrics, + generation_logger_metrics, step + 1, master_config["policy"]["generation"]["vllm_cfg"][ "vllm_metrics_logger_interval" diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 17c69e479a..428252e1f2 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -521,46 +521,47 @@ def visualize_per_worker_timeline( "generation" ].get("vllm_cfg", {}).get("async_engine", False) if is_vllm_metrics_logger_enabled: - vllm_logger_metrics = metrics["vllm_logger_metrics"] - # vllm_logger_me trics: dict[str (metric_name), dict[int (dp_idx), list[int] (metric_values)]] + vllm_logger_metrics = metrics.get("generation_logger_metrics", {}) + # vllm_logger_metrics: dict[str (metric_name), dict[int (dp_idx), list[int] (metric_values)]] # metric_name: "inflight_batch_sizes" or "num_pending_samples" - assert "inflight_batch_sizes" in vllm_logger_metrics, ( - "inflight_batch_sizes not found in vllm_logger_metrics" - ) - assert "num_pending_samples" in vllm_logger_metrics, ( - "num_pending_samples not found in vllm_logger_metrics" - ) - assert isinstance(vllm_logger_metrics["inflight_batch_sizes"], dict), ( - "inflight_batch_sizes must be a dictionary" - ) - assert isinstance(vllm_logger_metrics["num_pending_samples"], dict), ( - "num_pending_samples must be a dictionary" - ) - - vllm_metrics_logger_interval = master_config["policy"]["generation"][ - "vllm_cfg" - ]["vllm_metrics_logger_interval"] - print(" • vLLM Logger Metrics:") - # Visualize the inflight 
batch sizes timeline - if len(vllm_logger_metrics["inflight_batch_sizes"].values()) > 0: - visualize_per_worker_timeline( - vllm_logger_metrics["inflight_batch_sizes"], - "Inflight Batch Sizes", - vllm_metrics_logger_interval, + if vllm_logger_metrics: + assert "inflight_batch_sizes" in vllm_logger_metrics, ( + "inflight_batch_sizes not found in vllm_logger_metrics" ) - if len(vllm_logger_metrics["num_pending_samples"].values()) > 0: - max_num_pending_samples = max( - (max(v) if v else 0) - for v in vllm_logger_metrics["num_pending_samples"].values() + assert "num_pending_samples" in vllm_logger_metrics, ( + "num_pending_samples not found in vllm_logger_metrics" ) - # If there is at least one pending sample, visualize the timeline - if max_num_pending_samples > 0: + assert isinstance(vllm_logger_metrics["inflight_batch_sizes"], dict), ( + "inflight_batch_sizes must be a dictionary" + ) + assert isinstance(vllm_logger_metrics["num_pending_samples"], dict), ( + "num_pending_samples must be a dictionary" + ) + + vllm_metrics_logger_interval = master_config["policy"]["generation"][ + "vllm_cfg" + ]["vllm_metrics_logger_interval"] + print(" • vLLM Logger Metrics:") + # Visualize the inflight batch sizes timeline + if len(vllm_logger_metrics["inflight_batch_sizes"].values()) > 0: visualize_per_worker_timeline( - vllm_logger_metrics["num_pending_samples"], - "Num Pending Samples", - None, + vllm_logger_metrics["inflight_batch_sizes"], + "Inflight Batch Sizes", + vllm_metrics_logger_interval, ) + if len(vllm_logger_metrics["num_pending_samples"].values()) > 0: + max_num_pending_samples = max( + (max(v) if v else 0) + for v in vllm_logger_metrics["num_pending_samples"].values() + ) + # If there is at least one pending sample, visualize the timeline + if max_num_pending_samples > 0: + visualize_per_worker_timeline( + vllm_logger_metrics["num_pending_samples"], + "Num Pending Samples", + None, + ) # ===================================================== # Throughputs diff 
--git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py index 4190062ec6..636da32316 100644 --- a/nemo_rl/distributed/ray_actor_environment_registry.py +++ b/nemo_rl/distributed/ray_actor_environment_registry.py @@ -20,6 +20,9 @@ VLLM_EXECUTABLE = ( PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.VLLM ) +SGLANG_EXECUTABLE = ( + PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.SGLANG +) MCORE_EXECUTABLE = ( PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.MCORE ) @@ -27,6 +30,7 @@ ACTOR_ENVIRONMENT_REGISTRY: dict[str, str] = { "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker": VLLM_EXECUTABLE, "nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker": VLLM_EXECUTABLE, + "nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker": SGLANG_EXECUTABLE, # Temporary workaround for the coupled implementation of DTensorPolicyWorker and vLLM. # This will be reverted to PY_EXECUTABLES.BASE once https://github.com/NVIDIA-NeMo/RL/issues/501 is resolved. "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": VLLM_EXECUTABLE, @@ -63,3 +67,4 @@ def get_actor_python_env(actor_class_fqn: str) -> str: "adding a new generation framework or training backend), you'll need to specify the " "appropriate environment. See uv.md for more details." ) + diff --git a/nemo_rl/distributed/virtual_cluster.py b/nemo_rl/distributed/virtual_cluster.py index 3021b760e4..53662a37a6 100644 --- a/nemo_rl/distributed/virtual_cluster.py +++ b/nemo_rl/distributed/virtual_cluster.py @@ -52,11 +52,16 @@ class PY_EXECUTABLES: # Use NeMo-RL direct dependencies and nemo-automodel. AUTOMODEL = f"uv run --locked --extra automodel --directory {git_root}" + # Use NeMo-RL direct dependencies, nemo-automodel, and SGLang. 
+ AUTOMODEL_SGLANG = f"uv run --locked --extra automodel --extra sglang --directory {git_root}" + # Use NeMo-RL direct dependencies and Megatron. MCORE = f"uv run --locked --extra mcore --directory {git_root}" # Use NeMo-Gym dependencies NEMO_GYM = f"uv run --locked --extra nemo_gym --directory {git_root}" + # Use NeMo-RL direct dependencies and SGLang. + SGLANG = "uv run --locked --extra sglang --directory {git_root}" @ray.remote # pragma: no cover @@ -503,3 +508,4 @@ def __del__(self) -> None: user calls shutdown(). """ self.shutdown() + \ No newline at end of file diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index d134027bdf..7ec3c14576 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -257,3 +257,22 @@ def update_weights_from_collective(self) -> list[ray.ObjectRef]: # (e.g., vLLM prefix/KV caches) after weight updates. def invalidate_kv_cache(self) -> bool: return False + + def clear_logger_metrics(self) -> None: + """Clear logger metrics for performance reporting. + + This is an optional method that backends can implement to clear + telemetry metrics. Default implementation does nothing. + """ + pass + + def get_logger_metrics(self) -> dict[str, Any]: + """Get logger metrics for performance reporting. + + This is an optional method that backends can implement to collect + telemetry metrics. Default implementation returns empty dict. + + Returns: + Dictionary of metrics. Format may vary by backend. + """ + return {} diff --git a/nemo_rl/models/generation/sglang/__init__.py b/nemo_rl/models/generation/sglang/__init__.py new file mode 100644 index 0000000000..55ce57084d --- /dev/null +++ b/nemo_rl/models/generation/sglang/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OR WARRANTIES OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nemo_rl.models.generation.sglang.config import SGLangConfig +from nemo_rl.models.generation.sglang.sglang_generation import SGLangGeneration +from nemo_rl.models.generation.sglang.sglang_worker import SGLangGenerationWorker + +__all__ = [ + "SGLangConfig", + "SGLangGeneration", + "SGLangGenerationWorker", +] + diff --git a/nemo_rl/models/generation/sglang/config.py b/nemo_rl/models/generation/sglang/config.py new file mode 100644 index 0000000000..a401243a6d --- /dev/null +++ b/nemo_rl/models/generation/sglang/config.py @@ -0,0 +1,98 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, NotRequired, TypedDict + +from nemo_rl.models.generation.interfaces import GenerationConfig + + +class SglangSpecificArgs(TypedDict): + """SGLang-specific configuration arguments. + + Most fields below map directly to SGLang's ServerArgs (see: + https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py). 
+ """ + model_path: NotRequired[str] + gpus_per_server: NotRequired[int] + random_seed: NotRequired[int] + skip_tokenizer_init: NotRequired[bool] + disable_cuda_graph: NotRequired[bool] + disable_radix_cache: NotRequired[bool] + disable_cuda_graph_padding: NotRequired[bool] + enable_nccl_nvls: NotRequired[bool] + disable_outlines_disk_cache: NotRequired[bool] + disable_custom_all_reduce: NotRequired[bool] + disable_overlap_schedule: NotRequired[bool] + enable_mixed_chunk: NotRequired[bool] + enable_dp_attention: NotRequired[bool] + enable_ep_moe: NotRequired[bool] + enable_torch_compile: NotRequired[bool] + torch_compile_max_bs: NotRequired[int] + cuda_graph_max_bs: NotRequired[int | None] + cuda_graph_bs: NotRequired[list[int] | None] + torchao_config: NotRequired[str] + enable_nan_detection: NotRequired[bool] + enable_p2p_check: NotRequired[bool] + triton_attention_reduce_in_fp32: NotRequired[bool] + triton_attention_num_kv_splits: NotRequired[int] + num_continuous_decode_steps: NotRequired[int] + enable_memory_saver: NotRequired[bool] + allow_auto_truncate: NotRequired[bool] + attention_backend: NotRequired[str | None] + enable_multimodal: NotRequired[bool] + sampling_backend: NotRequired[str | None] + context_length: NotRequired[int | None] + mem_fraction_static: NotRequired[float | None] + max_running_requests: NotRequired[int | None] + chunked_prefill_size: NotRequired[int | None] + max_prefill_tokens: NotRequired[int] + schedule_policy: NotRequired[str] + schedule_conservativeness: NotRequired[float] + cpu_offload_gb: NotRequired[int] + dtype: NotRequired[str] + kv_cache_dtype: NotRequired[str] + dp_size: NotRequired[int] # only used for dp attention + pp_size: NotRequired[int] # pipeline parallel size + ep_size: NotRequired[int] + # lora + enable_lora: NotRequired[bool | None] + max_lora_rank: NotRequired[int | None] + lora_target_modules: NotRequired[list[str] | None] + lora_paths: NotRequired[list[str] | None] + max_loaded_loras: NotRequired[int] + 
max_loras_per_batch: NotRequired[int] + lora_backend: NotRequired[str] + # logging + log_level: NotRequired[str] + log_level_http: NotRequired[str | None] + log_requests: NotRequired[bool] + log_requests_level: NotRequired[int] + show_time_cost: NotRequired[bool] + enable_metrics: NotRequired[bool] # Exports Prometheus-like metrics + # The interval (in decoding iterations) to log throughput + # and update prometheus metrics + decode_log_interval: NotRequired[int] + # Extra loader arguments + enable_multithread_load: NotRequired[bool] + enable_fast_load: NotRequired[bool] + # Server warmup + skip_server_warmup: NotRequired[bool] + + +class SGLangConfig(GenerationConfig): + """Configuration for SGLang runtime.""" + sglang_cfg: SglangSpecificArgs + sglang_kwargs: NotRequired[dict[str, Any]] + + \ No newline at end of file diff --git a/nemo_rl/models/generation/sglang/sglang_generation.py b/nemo_rl/models/generation/sglang/sglang_generation.py new file mode 100644 index 0000000000..99d2bd8bb7 --- /dev/null +++ b/nemo_rl/models/generation/sglang/sglang_generation.py @@ -0,0 +1,383 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import logging +import os +from collections import defaultdict +from typing import ( + Any, + AsyncGenerator, + Optional, + Union, +) + +import numpy as np +import ray +from ray.util.placement_group import PlacementGroup + +from nemo_rl.distributed.batched_data_dict import BatchedDataDict, SlicedDataDict +from nemo_rl.distributed.named_sharding import NamedSharding +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup +from nemo_rl.models.generation.interfaces import ( + GenerationDatumSpec, + GenerationInterface, + GenerationOutputSpec, +) +from nemo_rl.models.generation.sglang.config import SGLangConfig + +# Global thresholds for top_k and top_p validation. +# While top-k/p are not supported, these values allow for token filtering while the logprobs should be compatible. +# See https://github.com/NVIDIA-NeMo/RL/issues/69 and https://github.com/NVIDIA-NeMo/RL/issues/237 for more details. +TOP_K_THRESHOLD = 8000 # Allow top_k >= 8000 (effectively no filtering) +TOP_P_THRESHOLD = 0.99 # Allow top_p >= 0.99 (close to 1.0) + +logger = logging.getLogger(__name__) + + +class SGLangGeneration(GenerationInterface): + def __init__( + self, + cluster: RayVirtualCluster, + config: SGLangConfig, + name_prefix: str = "sglang_policy", + workers_per_node: Optional[Union[int, list[int]]] = None, + ): + """Initialize a SGLang policy with distributed workers. + + SGLang server manages TP/PP internally, but we still need to: + 1. Manage data parallel distribution across multiple servers + 2. Assign GPU bundles to each server + + Each server will see logical GPUs 0-N (via CUDA_VISIBLE_DEVICES set by Ray), + so we just need to tell SGLang how many GPUs to use (tp_size). 
+ """ + # Store config + self.cfg = config + self.sglang_cfg = config["sglang_cfg"] + + gpus_per_server = self.sglang_cfg.get("gpus_per_server", None) + if gpus_per_server is None: + raise ValueError( + "gpus_per_server must be set in SGLangConfig.sglang_cfg." + ) + + # Calculate number of servers based on available resources + total_gpus = cluster.world_size() + num_servers = total_gpus // gpus_per_server + + if num_servers == 0: + raise ValueError( + f"Not enough GPUs. Need at least {gpus_per_server} GPUs per server, " + f"but only have {total_gpus} GPUs total." + ) + + if total_gpus % gpus_per_server != 0: + logger.warning( + f"[WARNING] Total GPUs ({total_gpus}) is not divisible by GPUs per server ({gpus_per_server}). " + f"Will use {num_servers} servers, leaving {total_gpus % gpus_per_server} GPUs unused." + ) + + self.dp_size = num_servers + self.gpus_per_server = gpus_per_server + + # Create sharding annotations + # Even though SGLang manages TP internally, we include it in the layout to support + # RayWorkerGroup's worker management (which creates one worker per GPU bundle). + # The TP dimension becomes a "free axis" in run_all_workers_sharded_data, ensuring + # only the primary workers (TP rank 0) are called. 
+ total_workers = num_servers * gpus_per_server + self.sharding_annotations = NamedSharding( + layout=np.arange(total_workers).reshape(num_servers, gpus_per_server), + names=["data_parallel", "tensor_parallel"], + ) + + # Initialize placement groups + # For SGLang, we use PACK strategy to keep bundles together + # colocated is always at top level, not in sglang_cfg + strategy = None if self.cfg["colocated"]["enabled"] else "PACK" + cluster._init_placement_groups( + strategy=strategy, + use_unified_pg=False, # SGLang servers don't need cross-node model parallelism + ) + + # Create worker builder for SGLangGenerationWorker + worker_cls = "nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker" + worker_builder = RayWorkerBuilder(worker_cls, config) + + env_vars = {} + global_cvd = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if global_cvd: + # Explicitly pass CUDA_VISIBLE_DEVICES to workers via env_vars + # This ensures all workers see the same global value, even though + env_vars["CUDA_VISIBLE_DEVICES"] = global_cvd + + # Allocate bundles for each server + # Each server gets consecutive bundles + bundle_indices_list = self._allocate_bundles_for_servers( + cluster, num_servers, gpus_per_server + ) + + # Create worker group with explicit bundle allocation + self.worker_group = RayWorkerGroup( + cluster, + worker_builder, + name_prefix=name_prefix, + bundle_indices_list=bundle_indices_list, + sharding_annotations=self.sharding_annotations, + env_vars=env_vars, + ) + + # Verify data parallel size matches + assert self.dp_size == self.worker_group.dp_size, ( + f"Data parallel size mismatch. 
Expected {self.dp_size}, got {self.worker_group.dp_size}" + ) + + # Used to track the round-robin selection of worker groups for generate_async + self.current_generate_dp_shard_idx = 0 + + def _allocate_bundles_for_servers( + self, + cluster: RayVirtualCluster, + num_servers: int, + gpus_per_server: int, + ) -> list[tuple[int, list[int]]]: + """Allocate GPU bundles to each SGLang server. + + Each server gets consecutive bundles within the same placement group (node). + Ray will automatically set CUDA_VISIBLE_DEVICES so each server sees logical GPUs 0, 1, 2, ..., gpus_per_server-1. + + Args: + cluster: The Ray virtual cluster + num_servers: Total number of SGLang servers to create + gpus_per_server: Number of GPUs each server needs + + Returns: + List of (node_idx, [bundle_indices]) tuples for each server + """ + placement_groups = cluster.get_placement_groups() + + if not placement_groups: + raise ValueError("No placement groups available in the cluster") + + bundle_indices_list = [] + + # Each server's bundles must be within the same placement group (node) + server_idx = 0 + for pg_idx, pg in enumerate(placement_groups): + if pg.bundle_count == 0: + continue + + # Calculate how many servers can fit in this placement group + num_servers_in_pg = pg.bundle_count // gpus_per_server + + # Allocate servers within this placement group + for local_server_idx in range(num_servers_in_pg): + if server_idx >= num_servers: + break + + # Calculate which bundles this server gets (consecutive within the PG) + start_bundle = local_server_idx * gpus_per_server + server_bundles = list(range(start_bundle, start_bundle + gpus_per_server)) + + # Each server gets a tuple of (node_idx, [local_bundle_indices]) + bundle_indices_list.append((pg_idx, server_bundles)) + server_idx += 1 + + if server_idx >= num_servers: + break + + if len(bundle_indices_list) < num_servers: + total_available = sum( + pg.bundle_count // gpus_per_server + for pg in placement_groups + if pg.bundle_count > 0 + ) + 
raise ValueError( + f"Not enough bundles to allocate all {num_servers} servers. " + f"Only {total_available} servers can be allocated " + f"(each server needs {gpus_per_server} GPUs)." + ) + + return bundle_indices_list + + def init_collective( + self, ip: str, port: int, world_size: int, *, train_world_size: int + ) -> list[ray.ObjectRef]: + """Initialize the collective communication. + + + TODO: if weight updates via NCCL are needed in the future. + """ + return [] + + def generate( + self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + ) -> BatchedDataDict[GenerationOutputSpec]: + """Generate a batch of data using SGLang.""" + assert isinstance(data, BatchedDataDict), ( + f"data must be a BatchedDataDict, got type: {type(data)}" + ) + assert "input_ids" in data and "input_lengths" in data, ( + "input_ids and input_lengths are required in data for SGLang generation" + ) + + # Shard the data across the data parallel servers + dp_size = self.sharding_annotations.get_axis_size("data_parallel") + sharded_data: list[SlicedDataDict] = data.shard_by_batch_size( + dp_size, allow_uneven_shards=True + ) + future_bundle = self.worker_group.run_all_workers_sharded_data( + "generate", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=None, + output_is_replicated=None, + common_kwargs={"greedy": greedy}, + ) + + # Get results from the workers + results = self.worker_group.get_all_worker_results(future_bundle) + + # Combine results from all servers + combined: BatchedDataDict[GenerationOutputSpec] = BatchedDataDict.from_batches( + results, pad_value_dict={"output_ids": self.cfg["_pad_token_id"]} + ) + + # Verify the output has all required fields + required_keys = [ + "output_ids", + "generation_lengths", + "unpadded_sequence_lengths", + "logprobs", + ] + missing_keys = [key for key in required_keys if key not in combined] + if missing_keys: + raise ValueError( + f"Missing required keys for GenerationOutputSpec: {missing_keys}" + 
) + + return combined + + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: + pass + + def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: + return [] + + def update_weights_from_collective(self) -> list[ray.ObjectRef]: + return [] + + def get_sglang_server_urls(self) -> list[str]: + """Get base URLs of all SGLang servers. + + Returns: + List of base URLs (e.g., ["http://localhost:30000", "http://localhost:30001"]) + """ + if not self.worker_group or not self.worker_group.workers: + raise RuntimeError("Worker group is not initialized") + + # Get base URLs from all workers (only primary workers, TP rank 0) + # Use run_rank_0_only_axes to only get URLs from primary workers + futures = self.worker_group.run_all_workers_single_data( + "get_base_url", + run_rank_0_only_axes=["tensor_parallel"], + ) + urls = ray.get(futures) + # Filter out None values and return unique URLs + return list(set(url for url in urls if url is not None)) + + def get_sglang_url_to_gpu_uuids(self) -> dict[str, list[str]]: + """Get mapping from SGLang server URL to list of GPU UUIDs it uses. 
+ + Returns: + Dict mapping server URL to list of GPU UUIDs + e.g., {"http://localhost:30000": ["GPU-aaa", "GPU-bbb"], ...} + """ + if not self.worker_group or not self.worker_group.workers: + raise RuntimeError("Worker group is not initialized") + + # Get base URLs and GPU UUIDs from all primary workers (TP rank 0) + futures_url = self.worker_group.run_all_workers_single_data( + "get_base_url", + run_rank_0_only_axes=["tensor_parallel"], + ) + futures_uuids = self.worker_group.run_all_workers_single_data( + "get_gpu_uuids", + run_rank_0_only_axes=["tensor_parallel"], + ) + + urls = ray.get(futures_url) + uuids_list = ray.get(futures_uuids) + + # Create mapping + url_to_uuids = {} + for url, uuids in zip(urls, uuids_list): + if url is not None and uuids is not None: + url_to_uuids[url] = uuids + + return url_to_uuids + + def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: + """Wake workers up for colocated inference.""" + pass + + def finish_generation(self, *args: Any, **kwargs: Any) -> bool: + """Sleep workers and reset prefix cache.""" + pass + + def shutdown(self) -> bool: + """Shut down all SGLang workers and clean up resources.""" + try: + # Use the worker group's shutdown method with the worker's cleanup method + return self.worker_group.shutdown(cleanup_method="shutdown") + except Exception as e: + logger.error(f"Error during SGLang policy shutdown: {e}") + return False + + def __del__(self) -> None: + """Shuts down the worker groups when the object is deleted or is garbage collected. + + This is an extra safety net in case the user forgets to call shutdown() and the pointer to + the object is lost due to leaving a function scope. It's always recommended that the + user calls shutdown(). + """ + self.shutdown() + + def invalidate_kv_cache(self) -> bool: + """Invalidate KV cache before weight updates (Megatron-style). + + This flushes the cache before weight updates to clear stale cache. 
+ Only primary workers (TP rank 0, model owners) will flush their cache. + + Returns: + bool: True if all caches were flushed successfully, False otherwise + """ + try: + futures = self.worker_group.run_all_workers_single_data( + "invalidate_kv_cache", + run_rank_0_only_axes=["tensor_parallel"], + ) + results = ray.get(futures) + results = [r for r in results if r is not None] + success = all(result for result in results) if results else True + if success: + logger.info("[sglang refit] All SGLang server caches flushed successfully") + else: + logger.warning("[sglang refit] WARNING - Some SGLang server caches failed to flush") + return success + except Exception as e: + logger.error(f"[sglang refit] Error flushing SGLang caches: {e}") + return False diff --git a/nemo_rl/models/generation/sglang/sglang_worker.py b/nemo_rl/models/generation/sglang/sglang_worker.py new file mode 100644 index 0000000000..64b188e55d --- /dev/null +++ b/nemo_rl/models/generation/sglang/sglang_worker.py @@ -0,0 +1,747 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import gc +import logging +import os +import sys +from typing import Any, Optional, cast +import requests +import asyncio +import aiohttp + +import time +import ray +import torch +import multiprocessing + +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import _get_node_ip_local, _get_free_port_local +from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches +from nemo_rl.models.generation.interfaces import ( + GenerationDatumSpec, + GenerationOutputSpec, + verify_right_padding, +) +from nemo_rl.models.generation.sglang.config import SGLangConfig +from nemo_rl.models.generation.sglang.utils import AsyncLoopThread +from nemo_rl.models.huggingface.common import ModelFlag +from nemo_rl.utils.nsys import wrap_with_nvtx_name + +from sglang.srt.entrypoints.http_server import launch_server +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import kill_process_tree + +logger = logging.getLogger(__name__) + + +@ray.remote( + runtime_env={**get_nsight_config_if_pattern_matches("sglang_generation_worker")} +) # pragma: no cover +class SGLangGenerationWorker: + def __repr__(self) -> str: + """Customizes the actor's prefix in the Ray logs. + + This makes it easier to identify which worker is producing specific log messages. + """ + return f"{self.__class__.__name__}" + + @staticmethod + def configure_worker( + num_gpus: int | float, bundle_indices: Optional[tuple[int, list[int]]] = None + ) -> tuple[dict[str, Any], dict[str, str], dict[str, Any]]: + """Provides complete worker configuration for SGLang server. + + This method configures the worker based on bundle_indices which tells us + how many GPUs this server should use. 
+ + Args: + num_gpus: Original GPU allocation for this worker based on the placement group + bundle_indices: Tuple of (node_idx, local_bundle_indices) for this server + + Returns: + tuple with complete worker configuration: + - 'resources': Resource allocation (e.g., num_gpus) + - 'env_vars': Environment variables for this worker + - 'init_kwargs': Parameters to pass to __init__ of the worker + """ + # Initialize configuration + resources: dict[str, Any] = {"num_gpus": num_gpus} + init_kwargs: dict[str, Any] = {} + env_vars: dict[str, str] = {} + + local_bundle_indices = None + if bundle_indices is not None: + node_idx = bundle_indices[0] + local_bundle_indices = bundle_indices[1] + init_kwargs["bundle_indices"] = local_bundle_indices + + # Calculate a unique seed from node_idx and bundle_indices + if len(local_bundle_indices) == 1: + seed = node_idx * 1024 + local_bundle_indices[0] + else: + bundle_id = local_bundle_indices[0] // len(local_bundle_indices) + seed = node_idx * 1024 + bundle_id + + init_kwargs["seed"] = seed + + # Check if this worker is part of a parallel group (multiple GPUs per server). + # A worker with local rank = 0 owns the server (local_bundle_indices is not None); + # otherwise it is a placeholder for Ray's resource management (local_bundle_indices is None). 
+ is_part_of_parallel_workers = ( + local_bundle_indices is not None and len(local_bundle_indices) > 1 + ) or local_bundle_indices is None + + if is_part_of_parallel_workers: + # For parallel workers, we manage GPU assignment via base_gpu_id + # All workers see the same global CUDA_VISIBLE_DEVICES, but use different + # logical GPU ranges via base_gpu_id + resources["num_gpus"] = 0 + env_vars["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" + init_kwargs["fraction_of_gpus"] = num_gpus + else: + env_vars["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" + + return resources, env_vars, init_kwargs + + def __init__( + self, + config: SGLangConfig, + bundle_indices: Optional[list[int]] = None, + fraction_of_gpus: float = 1.0, + seed: Optional[int] = None, + ): + """Initialize a SGLang worker for distributed inference. + + Args: + config: Configuration dictionary for the policy + bundle_indices: List of local bundle indices for this server. + The length of this list determines tp_size (number of GPUs per server). + Only needed for the first worker in each server group (model owner). 
+ fraction_of_gpus: Fraction of GPUs to use for this worker + seed: Random seed for initialization, if None, then defaults to the config's seed + """ + self.cfg = config + self.is_model_owner = bundle_indices is not None + self.global_rank = int(os.environ.get("RANK", "0")) + self.sglang_cfg = config["sglang_cfg"] + + # Create a dedicated event loop thread for async operations + # there will be issues if we use the event loop in the main thread + self.async_loop_thread = AsyncLoopThread() + + # temp: Maximum concurrent requests per server + # we may remove this limit in the future + self.max_concurrent_requests = config.get("max_concurrent_requests", 999999) + + # Only the primary worker (local_rank=0) in each server group starts the SGLang server + # Secondary workers (local_rank!=0) just returns + if not self.is_model_owner: + return + + # Determine tp_size from bundle_indices length + tp_size = len(bundle_indices) + + base_gpu_id = bundle_indices[0] if bundle_indices else 0 + + # Get the global CUDA_VISIBLE_DEVICES (all engines see the same global value) + global_cvd = os.environ.get("CUDA_VISIBLE_DEVICES", None) + + + logger.info( + f"[SGLang Server] Rank {self.global_rank}: " + f"base_gpu_id={base_gpu_id}, tp_size={tp_size}, " + f"bundle_indices={bundle_indices}, global_cvd={global_cvd}" + ) + + # Get current node IP and a free port for the server + node_ip = _get_node_ip_local() + free_port = _get_free_port_local() + + # Build SGLang server arguments + kwargs = { + "model_path": self.sglang_cfg["model_path"], + "trust_remote_code": True, + "random_seed": seed if seed is not None else self.sglang_cfg.get("random_seed", 1), + # Memory settings + "enable_memory_saver": self.sglang_cfg["enable_memory_saver"], + "gpu_id_step": 1, + "base_gpu_id": base_gpu_id, + # Parallel settings + "tp_size": tp_size, + "dp_size": self.sglang_cfg["dp_size"], + "pp_size": self.sglang_cfg["pp_size"], + "ep_size": self.sglang_cfg["ep_size"], + # Always skip warmup to prevent warmup 
timeout + "skip_server_warmup": self.sglang_cfg.get("skip_server_warmup", True), + # Server network settings - listen on all interfaces, use the free port we found + "host": "0.0.0.0", + "port": free_port, + "torchao_config": "", + } + + for key in [ + "dtype", "kv_cache_dtype", "context_length", "max_running_requests", + "chunked_prefill_size", "max_prefill_tokens", "schedule_policy", + "schedule_conservativeness", "cpu_offload_gb", "log_level", + "mem_fraction_static", "allow_auto_truncate", + ]: + if key in self.sglang_cfg: + kwargs[key] = self.sglang_cfg[key] + + server_args = ServerArgs(**kwargs) + # Save server_args and base_url for use in generate() and _make_request() + self.server_args = server_args + self.base_url = f"http://{node_ip}:{free_port}" + + logger.info(f"[SGLang Worker] Rank {self.global_rank} Starting on {self.base_url}, CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}, base_gpu_id: {base_gpu_id}") + + self.session = None + self.connector = None + + self.server_process = self._launch_server_process(server_args) + + def get_base_url(self) -> str: + """Get the base URL of this SGLang server.""" + return self.base_url + + def invalidate_kv_cache(self) -> bool: + """Invalidate KV cache before weight updates (Megatron-style). + + This flushes the cache before weight updates to clear stale cache. + Uses retry logic to handle cases where there are pending requests. 
+ + Returns: + bool: True if flush was successful, False otherwise + """ + if not self.is_model_owner: + return True + + url = f"{self.base_url}/flush_cache" + max_attempts = 60 + connection_retry_limit = 5 + + # flush_cache will not return status_code 200 when there are pending requests + for attempt in range(max_attempts): + try: + response = requests.get(url, timeout=10) + if response.status_code == 200: + if attempt > 0: + logger.info( + f"[SGLang Worker] Rank {self.global_rank} Cache flushed successfully " + f"(attempt {attempt + 1})" + ) + return True + except requests.exceptions.ConnectionError: + # Server might not be ready yet - only retry for first few attempts + if attempt >= connection_retry_limit: + logger.warning( + f"[SGLang Worker] Rank {self.global_rank} Connection failed after " + f"{connection_retry_limit} attempts" + ) + return False + except Exception as e: + # For other errors, log and retry (except on last attempt) + if attempt == max_attempts - 1: + logger.error( + f"[SGLang Worker] Rank {self.global_rank} Failed to flush cache after " + f"{max_attempts} attempts: {e}" + ) + return False + + time.sleep(1) + + # All attempts exhausted without success + logger.error( + f"[SGLang Worker] Rank {self.global_rank} Timeout: Cache flush failed after " + f"{max_attempts} attempts. Server may have pending requests." + ) + return False + + def get_gpu_uuids(self) -> list[str]: + """Get list of GPU UUIDs used by this SGLang server. + + Returns: + List of GPU UUIDs (e.g., ["GPU-xxxxx", "GPU-yyyyy"]) + """ + from nemo_rl.utils.nvml import get_device_uuid + + # Get all GPU UUIDs used by this server + # SGLang server uses GPUs starting from base_gpu_id with tp_size GPUs + gpu_uuids = [] + for i in range(self.server_args.tp_size): + gpu_id = self.server_args.base_gpu_id + i + uuid = get_device_uuid(gpu_id) + gpu_uuids.append(uuid) + + return gpu_uuids + + + def _merge_stop_strings(self, batch_stop_strings): + """Merge stop strings from config and batch. 
+ + Args: + batch_stop_strings: List of stop strings from batch (one per sample) + + Returns: + List of merged stop strings (one per sample) + """ + stop_set: set[str] = set() + + # Add stop strings from config + if self.cfg.get("stop_strings"): + stop_set.update(self.cfg["stop_strings"]) + + # Merge stop strings from batch + merged_stop_strings = [] + for sample_ss in batch_stop_strings: + sample_stop_set = stop_set.copy() + if sample_ss: + if isinstance(sample_ss, str): + sample_stop_set.add(sample_ss) + elif isinstance(sample_ss, list): + sample_stop_set.update(sample_ss) + + merged_stop_strings.append(list(sample_stop_set) if sample_stop_set else None) + + return merged_stop_strings + + def _build_sampling_params( + self, + *, + greedy: bool, + stop_strings, + max_new_tokens: Optional[int] = None, + input_len: Optional[int] = None, + context_length: Optional[int] = None, + sample_index: Optional[int] = None, + ) -> dict[str, Any]: + """Build sampling parameters dictionary for SGLang API. 
+ + Args: + greedy: Whether to use greedy decoding (temperature=0.0) + stop_strings: Merged stop strings (not used here, handled per sample) + max_new_tokens: Override max_new_tokens from config if provided + input_len: Input length for this sample (used for context_length adjustment) + context_length: Maximum context length (if provided, adjusts max_new_tokens) + sample_index: Sample index (used for warning messages, 0-indexed) + + Returns: + Dictionary of sampling parameters compatible with SGLang API + """ + top_k_cfg = self.cfg.get("top_k") + top_k_val = 1 if greedy else (top_k_cfg if top_k_cfg is not None else -1) + temperature = 0.0 if greedy else self.cfg["temperature"] + + base_max_tokens = ( + max_new_tokens if max_new_tokens is not None else self.cfg["max_new_tokens"] + ) + + # TODO: check if this is needed + final_max_tokens = base_max_tokens + if context_length is not None and input_len is not None: + max_allowed_new_tokens = max(0, context_length - input_len - 1) + if base_max_tokens > max_allowed_new_tokens: + final_max_tokens = max_allowed_new_tokens + if sample_index == 0: + logger.warning( + f"[SGLang Worker] Rank {self.global_rank} Warning: " + f"Sample {sample_index} input length ({input_len}) + max_new_tokens ({base_max_tokens}) " + f"would exceed context_length ({context_length}). " + f"Reducing max_new_tokens to {final_max_tokens} for this sample." 
+ ) + + # Build sampling params dict + sampling_params = { + "temperature": temperature, + "top_p": self.cfg.get("top_p", 1.0), + "max_new_tokens": final_max_tokens, + } + + if top_k_val != -1: + sampling_params["top_k"] = top_k_val + + stop_token_ids = self.cfg.get("stop_token_ids") + if stop_token_ids is not None: + sampling_params["stop_token_ids"] = stop_token_ids + + return sampling_params + + async def _ensure_session(self): + if self.session is None: + # Create connector with connection pool limit + self.connector = aiohttp.TCPConnector(limit=512, limit_per_host=512) + # Create session with timeout + timeout = aiohttp.ClientTimeout(total=300) # 5 minutes timeout + self.session = aiohttp.ClientSession(connector=self.connector, timeout=timeout) + return self.session + + async def _generate_single_sample( + self, + input_ids: list[int], + sampling_params: dict[str, Any], + stop_string: Optional[str] = None, + ) -> tuple[list[int], list[float]]: + """Generate a single sample using SGLang API (async function). + + Args: + input_ids: List of input token IDs (without padding) + sampling_params: Dictionary of sampling parameters (temperature, top_p, max_new_tokens, etc.) 
stop_string: Optional stop string for this sample + + Returns: + Tuple of (generated_tokens, logprobs): + - generated_tokens: List of generated token IDs + - logprobs: List of log probabilities for generated tokens + """ + # Prepare payload for SGLang API + # Note: stop should be in sampling_params, not in payload top level + # TODO: double check this + if stop_string is not None: + # stop can be a string or list of strings + sampling_params = sampling_params.copy() # Don't modify the original + sampling_params["stop"] = stop_string + + payload = { + "sampling_params": sampling_params, + "return_logprob": True, + "input_ids": input_ids, + } + + url = f"{self.base_url}/generate" + headers = { + "Content-Type": "application/json; charset=utf-8", + } + + session = await self._ensure_session() + + try: + async with session.post(url, json=payload, headers=headers) as response: + response.raise_for_status() + result = await response.json() + except Exception as e: + logger.error(f"[SGLang Worker] Rank {self.global_rank} Request failed for input_len={len(input_ids)}: {e}") + raise + + # Extract generated tokens and logprobs + meta_info = result.get("meta_info", {}) + output_token_logprobs = meta_info.get("output_token_logprobs", []) + + if output_token_logprobs: + new_tokens = [item[1] for item in output_token_logprobs] + new_logprobs = [item[0] for item in output_token_logprobs] + else: + # Fallback: empty if token logprobs not available + new_tokens = [] + new_logprobs = [] + + return new_tokens, new_logprobs + + async def _generate_async(self, tasks): + """Execute generation tasks with concurrency control. + + TEMP: Uses a semaphore to limit the number of concurrent requests per server, preventing server overload. + A router-based solution is preferred in the future. 
+ """ + semaphore = asyncio.Semaphore(self.max_concurrent_requests) + + async def wrap(idx, coro): + async with semaphore: + try: + result = await coro + return idx, result + except Exception as e: + raise + + wrapped = [wrap(i, t) for i, t in enumerate(tasks)] + results = [None] * len(tasks) + count = 0 + + for fut in asyncio.as_completed(wrapped): + idx, value = await fut + results[idx] = value + count += 1 + if count % 50 == 0 or count == len(tasks): + logger.debug(f"[SGLang Worker] Rank {self.global_rank} Completed {count}/{len(tasks)} tasks") + + return results + + def _launch_server_process(self, server_args: ServerArgs) -> multiprocessing.Process: + """Launch the SGLang server process and wait for it to be ready.""" + p = multiprocessing.Process(target=launch_server, args=(server_args,)) + p.start() + + # Wait for server to be ready by checking health endpoint + # Use the base_url we stored earlier + headers = { + "Content-Type": "application/json; charset=utf-8", + } + + max_wait_time = 300 # 5 minutes timeout + start_time = time.time() + with requests.Session() as session: + while True: + if time.time() - start_time > max_wait_time: + kill_process_tree(p.pid) + raise TimeoutError( + f"[SGLang Server] Rank {self.global_rank} Server failed to start within {max_wait_time}s" + ) + try: + response = session.get(f"{self.base_url}/health_generate", headers=headers, timeout=10) + if response.status_code == 200: + logger.info(f"[SGLang Server] Rank {self.global_rank} Server is ready at {self.base_url}") + break + except requests.RequestException: + pass + + if not p.is_alive(): + raise RuntimeError(f"[SGLang Server] Rank {self.global_rank} Server process terminated unexpectedly.") + + time.sleep(2) + return p + + + + + @wrap_with_nvtx_name("sglang_genertion_worker/generate") + def generate( + self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + ) -> BatchedDataDict[GenerationOutputSpec]: + """Generate a batch of data using SGLang generation. 
+ + Args: + data: BatchedDataDict containing input_ids and input_lengths tensors + greedy: Whether to use greedy decoding instead of sampling + + Returns: + BatchedDataDict conforming to GenerationOutputSpec: + - output_ids: input + generated token IDs with proper padding + - logprobs: Log probabilities for tokens + - generation_lengths: Lengths of each response + - unpadded_sequence_lengths: Lengths of each input + generated sequence + """ + # Handle empty input case + if len(data["input_ids"]) == 0: + return BatchedDataDict[GenerationOutputSpec]( + { + "output_ids": torch.zeros((0, 0), dtype=torch.long), + "logprobs": torch.zeros((0, 0), dtype=torch.float), + "generation_lengths": torch.zeros(0, dtype=torch.long), + "unpadded_sequence_lengths": torch.zeros(0, dtype=torch.long), + } + ) + + input_ids = data["input_ids"] + input_lengths = data["input_lengths"] + batch_stop_strings = data.get("stop_strings", [None] * len(input_lengths)) + stop_strings = self._merge_stop_strings(batch_stop_strings) + batch_size = len(input_lengths) + pad_token_id = self.cfg["_pad_token_id"] + + # Verify inputs have correct padding + verify_right_padding(data, pad_value=pad_token_id) + + # Original input length with padding + padded_input_length = input_ids.size(1) + + logger.debug(f"[SGLang Worker] Rank {self.global_rank} batch_size: {batch_size}, padded_input_length: {padded_input_length}") + + if batch_size == 0: + raise ValueError("Empty batch received") + + context_length = self.sglang_cfg.get("context_length", None) + + # Create async tasks for all samples + tasks = [] + for i in range(batch_size): + input_len = input_lengths[i].item() + + # Truncate input if it exceeds context_length + if context_length is not None and input_len >= context_length: + input_len = context_length - 1 + + valid_input_ids = input_ids[i, :input_len].tolist() + + # Build sampling params for this sample (with context_length adjustment) + sample_sampling_params = self._build_sampling_params( + 
greedy=greedy, + stop_strings=stop_strings, + max_new_tokens=None, + input_len=input_len, + context_length=context_length, + sample_index=i, + ) + + tasks.append( + self._generate_single_sample( + input_ids=valid_input_ids, + sampling_params=sample_sampling_params, + stop_string=stop_strings[i], + ) + ) + + # Execute all requests concurrently using the dedicated event loop thread + try: + all_results = self.async_loop_thread.run(self._generate_async(tasks)) + except Exception as e: + raise + + total_generated_tokens = sum(len(tokens) for tokens, _ in all_results) + avg_generation_length = total_generated_tokens / batch_size if batch_size > 0 else 0 + + # Process results + output_ids_list = [] + logprobs_list = [] + generation_lengths_list = [] + unpadded_sequence_lengths_list = [] + max_length = 0 + + # First pass: calculate max_length + for i, (new_tokens, new_logprobs) in enumerate(all_results): + input_len = input_lengths[i].item() + generation_length = len(new_tokens) + unpadded_length = input_len + generation_length + max_length = max(max_length, unpadded_length) + + total_length = max(max_length, padded_input_length) + + for i, (new_tokens, new_logprobs) in enumerate(all_results): + input_len = input_lengths[i].item() + generation_length = len(new_tokens) + unpadded_length = input_len + generation_length + + full_output = torch.full( + (total_length,), pad_token_id, dtype=input_ids.dtype + ) + full_output[:input_len] = input_ids[i][:input_len] + + # Add generated tokens after the original input + if new_tokens: + full_output[input_len : input_len + len(new_tokens)] = ( + torch.tensor(new_tokens, dtype=input_ids.dtype) + ) + + # Construct logprobs: zeros for input tokens, actual logprobs for generated tokens + full_logprobs = torch.zeros(total_length, dtype=torch.float32) + if new_logprobs: + for idx, logprob in enumerate(new_logprobs): + position = input_len + idx + full_logprobs[position] = logprob + + output_ids_list.append(full_output) + 
logprobs_list.append(full_logprobs) + generation_lengths_list.append(generation_length) + unpadded_sequence_lengths_list.append(unpadded_length) + + # Stack into tensors + output_ids = torch.stack(output_ids_list) + logprobs = torch.stack(logprobs_list) + generation_lengths = torch.tensor(generation_lengths_list, dtype=torch.long) + unpadded_sequence_lengths = torch.tensor(unpadded_sequence_lengths_list, dtype=torch.long) + logger.debug(f"[SGLang Worker] Rank {self.global_rank} Generated {total_generated_tokens} tokens across {batch_size} samples (avg: {avg_generation_length:.1f} tokens/sample)") + return BatchedDataDict[GenerationOutputSpec]( + { + "output_ids": output_ids, + "generation_lengths": generation_lengths, + "unpadded_sequence_lengths": unpadded_sequence_lengths, + "logprobs": logprobs, + } + ) + + def sleep(self): + # TODO + pass + + def wake_up(self, **kwargs): + # TODO + pass + + def shutdown(self) -> bool: + """Shutdown the SGLang server process and cleanup async resources. 
+ + Returns: + bool: True if shutdown was successful, False otherwise + """ + if not self.is_model_owner: + if hasattr(self, "async_loop_thread"): + try: + self.async_loop_thread.shutdown() + logger.info(f"[SGLang Worker] Rank {self.global_rank} Async loop thread shut down.") + except Exception as e: + logger.error(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}") + return True + + try: + if hasattr(self, "session") and self.session is not None: + try: + async def close_session(): + await self.session.close() + if self.connector is not None: + await self.connector.close() + + self.async_loop_thread.run(close_session()) + logger.info(f"[SGLang Worker] Rank {self.global_rank} aiohttp session closed.") + except Exception as e: + logger.error(f"[SGLang Worker] Rank {self.global_rank} Error closing aiohttp session: {e}") + + # Shutdown async loop thread after session cleanup + if hasattr(self, "async_loop_thread"): + try: + self.async_loop_thread.shutdown() + logger.info(f"[SGLang Worker] Rank {self.global_rank} Async loop thread shut down.") + except Exception as e: + logger.error(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}") + + if not hasattr(self, "server_process") or self.server_process is None: + return True + + logger.info( + f"[SGLang Worker] Rank {self.global_rank} Shutting down server at {self.base_url}..." + ) + + if self.server_process.is_alive(): + kill_process_tree(self.server_process.pid) + + # Wait for the process to terminate + self.server_process.join(timeout=5.0) + + if self.server_process.is_alive(): + return False + return True + + except Exception as e: + logger.error( + f"[SGLang Worker] Rank {self.global_rank} Error during shutdown: {e}" + ) + return False + + def _make_request(self, endpoint: str, payload: Optional[dict] = None): + """Make a POST request to the specified endpoint with the given payload. 
+ + Args: + endpoint: The API endpoint to call + payload: The JSON payload to send (default: empty dict) + + Returns: + The JSON response from the server + """ + # Use the stored base_url instead of constructing from server_args + url = f"{self.base_url}/{endpoint}" + headers = { + "Content-Type": "application/json; charset=utf-8", + } + response = requests.post(url, json=payload or {}, headers=headers, timeout=60) + response.raise_for_status() + return response.json() \ No newline at end of file diff --git a/nemo_rl/models/generation/sglang/utils.py b/nemo_rl/models/generation/sglang/utils.py new file mode 100644 index 0000000000..469d3bb79e --- /dev/null +++ b/nemo_rl/models/generation/sglang/utils.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import threading + + +class AsyncLoopThread: + """A background event loop thread for running async operations in Ray actors. + + This class creates a dedicated thread with its own event loop, allowing + synchronous Ray actor methods to execute async coroutines without blocking + the main actor thread. This is necessary because run_coroutine_threadsafe + requires the event loop to be in a different thread. 
+ """ + def __init__(self): + self.loop = asyncio.new_event_loop() + self._ready = threading.Event() + self._thread = threading.Thread(target=self._start_loop, daemon=True) + self._thread.start() + if not self._ready.wait(timeout=5.0): + raise RuntimeError("Event loop thread failed to start within 5 seconds") + + def _start_loop(self): + """Run the event loop in the background thread.""" + asyncio.set_event_loop(self.loop) + self._ready.set() + self.loop.run_forever() + + def run(self, coro): + """Schedule a coroutine onto the loop and block until it's done. + + Args: + coro: The coroutine to execute + + Returns: + The result of the coroutine + """ + if not self.loop.is_running(): + raise RuntimeError("Event loop is not running") + future = asyncio.run_coroutine_threadsafe(coro, self.loop) + result = future.result() + return result + + def shutdown(self): + """Shutdown the event loop and wait for the thread to finish.""" + if self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + self._thread.join(timeout=2.0) + if not self.loop.is_closed(): + self.loop.close() + diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 93540ebe82..1366ce28c5 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -876,6 +876,14 @@ def clear_vllm_logger_metrics(self) -> None: ) ray.get(futures) + def clear_logger_metrics(self) -> None: + """Clear logger metrics for performance reporting.""" + self.clear_vllm_logger_metrics() + + def get_logger_metrics(self) -> dict[str, Any]: + """Get logger metrics for performance reporting.""" + return self.get_vllm_logger_metrics() + def __del__(self) -> None: """Shuts down the worker groups when the object is deleted or is garbage collected. 
diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index 144b0c517d..10b34e5ae0 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -182,6 +182,18 @@ def stream_weights_via_ipc_zmq( ) -> list[ray.ObjectRef]: pass + def stream_weights_via_http( + self, sglang_url_to_gpu_uuids: dict[str, list[str]] + ) -> None: + """Stream model weights to SGLang servers via HTTP API. + + Args: + sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses + """ + raise NotImplementedError( + "stream_weights_via_http is not implemented for this policy worker" + ) + @abstractmethod def broadcast_weights_for_collective( self, kv_scales: Optional[dict[str, float]] = None diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 21558768b4..eb711883c8 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -762,6 +762,20 @@ def stream_weights_via_ipc_zmq( ) return futures + def stream_weights_via_http( + self, sglang_url_to_gpu_uuids: dict[str, list[str]] + ) -> list[ray.ObjectRef]: + """Send the weights to SGLang servers via HTTP API. + + Args: + sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses + """ + futures = self.worker_group.run_all_workers_single_data( + "stream_weights_via_http", + sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, + ) + return futures + def broadcast_weights_for_collective( self, kv_scales: Optional[dict[str, float]] = None ) -> list[ray.ObjectRef]: diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 283b980e72..dacac5b25f 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -12,15 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import base64 import gc import os +import pickle import traceback from enum import Enum from typing import Any, Dict, Optional +import requests import torch +import torch.distributed as dist import zmq from torch.multiprocessing.reductions import rebuild_cuda_tensor + from transformers import ( AutoModelForCausalLM, AutoModelForImageTextToText, @@ -473,3 +478,251 @@ def rebuild_cuda_tensor_from_ipc( list_args = list(args) list_args[6] = device_id return func(*list_args) + + +def stream_weights_via_http_impl( + params_generator, + sglang_url_to_gpu_uuids: dict[str, list[str]], + rank: int, + worker_name: str, + current_device_uuid: str, +) -> None: + """Stream weights to SGLang servers via HTTP API (update_weights_from_tensor). + + Flow: Each rank creates IPC handler → gather handlers in rank order → send list → SGLang matches by tp_rank index + + Key points: + - Each rank creates handler on its own GPU + - Handlers are gathered in rank order: [rank0_handler, rank1_handler, ...] + - List index = rank = GPU ID + - SGLang automatically matches: handler = serialized_handlers[tp_rank] + + Args: + params_generator: Generator yielding (name, tensor) pairs + sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses + rank: Worker rank for logging + worker_name: Name of the worker for logging + current_device_uuid: UUID of the current training worker's GPU + """ + from sglang.srt.utils import MultiprocessingSerializer + try: + from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions + except ImportError: + from sglang.srt.patch_torch import monkey_patch_torch_reductions + print(f"[sglang refit details] entering stream_weights_via_http_impl") + + monkey_patch_torch_reductions() + + target_urls = [ + url for url, uuids in sglang_url_to_gpu_uuids.items() + if current_device_uuid in uuids + ] + + if not target_urls: + raise RuntimeError( + f"{worker_name} (rank {rank}): No matching SGLang server found for GPU UUID {current_device_uuid}. 
" + f"Available servers: {list(sglang_url_to_gpu_uuids.keys())}" + ) + + if len(target_urls) > 1: + print( + f"[WARNING] {worker_name} (rank {rank}): GPU UUID {current_device_uuid} matches multiple SGLang servers: {target_urls}. " + f"Using the first one: {target_urls[0]}" + ) + target_urls = [target_urls[0]] + + base_url = target_urls[0] + url = f"{base_url}/update_weights_from_tensor" + sglang_gpu_uuids = sglang_url_to_gpu_uuids[base_url] + + ipc_gather_group, ipc_gather_src, matching_ranks = _setup_ipc_gather_group( + rank, current_device_uuid, sglang_gpu_uuids, sglang_url_to_gpu_uuids + ) + print(f"[sglang refit] {worker_name} (rank {rank}): ipc_gather_group={ipc_gather_group}, ipc_gather_src={ipc_gather_src}, matching_ranks={matching_ranks}") + tensor_count = 0 + + try: + tensor_list = list(params_generator) + total_tensors = len(tensor_list) + + if rank == ipc_gather_src: + print( + f"[sglang refit details] {worker_name}: Starting weight update - " + f"Total parameters to update: {total_tensors}", + flush=True + ) + + for idx, (name, tensor) in enumerate(tensor_list): + torch.cuda.current_stream().synchronize() + tensor = tensor.contiguous().cuda() + + named_tensors = [(name, tensor)] + serialized_handler = MultiprocessingSerializer.serialize( + named_tensors, + output_str=True + ) + + gathered_handlers = _gather_ipc_handlers( + serialized_handler, ipc_gather_group, ipc_gather_src, rank, matching_ranks + ) + + if rank == ipc_gather_src: + _send_tensor_to_sglang( + url, name, gathered_handlers, tensor.shape, str(tensor.dtype), + flush_cache=False + ) + tensor_count += 1 + + del tensor, serialized_handler + if rank == ipc_gather_src: + del gathered_handlers + torch.cuda.empty_cache() + + if rank == ipc_gather_src: + print( + f"[sglang refit details] {worker_name}: Weight update completed - " + f"Successfully updated {tensor_count}/{total_tensors} parameters to SGLang server: {base_url}", + flush=True + ) + if tensor_count != total_tensors: + print( + f"[sglang 
refit details] {worker_name}: WARNING - Expected {total_tensors} tensors, " + f"but only sent {tensor_count}", + flush=True + ) + + except Exception as e: + print( + f"{worker_name} (rank {rank}): Error during HTTP weight streaming: {e}.\n" + f"{traceback.format_exc()}" + ) + raise + + finally: + gc.collect() + torch.cuda.empty_cache() + + +def _setup_ipc_gather_group( + rank: int, + current_device_uuid: str, + sglang_gpu_uuids: list[str], + sglang_url_to_gpu_uuids: dict[str, list[str]], +) -> tuple[Optional[dist.ProcessGroup], Optional[int], Optional[list[int]]]: + """Setup gather configuration for IPC handlers. + + Returns: + Tuple of (gather_group, gather_src_rank, matching_ranks) + - gather_group: None (use default FSDP group) + - gather_src_rank: The rank that will collect and send to SGLang server + - matching_ranks: List of ranks that belong to the same SGLang server + """ + if not dist.is_initialized(): + return None, None, None + + world_size = dist.get_world_size() + my_rank = dist.get_rank() + + all_ranks_uuids = [None] * world_size + dist.all_gather_object(all_ranks_uuids, current_device_uuid) + + matching_ranks = [ + r for r, uuid in enumerate(all_ranks_uuids) + if uuid in sglang_gpu_uuids + ] + + if len(matching_ranks) == 0: + return None, None, None + + matching_ranks = sorted(matching_ranks) + gather_src = matching_ranks[0] + + return None, gather_src, matching_ranks + + +def _gather_ipc_handlers( + serialized_handler: str, + gather_group: Optional[dist.ProcessGroup], + gather_src: Optional[int], + rank: int, + matching_ranks: Optional[list[int]] = None, +) -> Optional[list[str]]: + """Gather IPC handlers from all ranks in the default FSDP group, then filter by server. 
+ + Args: + serialized_handler: Serialized IPC handler from this rank + gather_group: Process group (None means use default FSDP group) + gather_src: Rank that will collect and filter handlers + rank: Current rank + matching_ranks: List of ranks that belong to the same SGLang server + + Returns: + List of serialized handlers in rank order (only on gather_src rank), None otherwise + The list contains handlers from matching_ranks only, in rank order + """ + if gather_src is None: + return None + + if not dist.is_initialized(): + return None + + world_size = dist.get_world_size() + + all_handlers = [None] * world_size + dist.all_gather_object(all_handlers, serialized_handler) + + if rank == gather_src and matching_ranks is not None: + filtered_handlers = [all_handlers[r] for r in matching_ranks] + return filtered_handlers + else: + return None + + +def _send_tensor_to_sglang( + url: str, + tensor_name: str, + gathered_handlers: list[str], + shape: torch.Size, + dtype: str, + flush_cache: bool = False, +) -> None: + """Send gathered IPC handlers to SGLang server via HTTP. + + Key: gathered_handlers are in rank order [rank0, rank1, ...] 
+ SGLang will automatically match: handler = serialized_handlers[tp_rank] + + Args: + url: SGLang server URL + tensor_name: Name of the tensor + gathered_handlers: List of serialized IPC handlers in rank order + shape: Tensor shape + dtype: Tensor dtype + flush_cache: Whether to flush cache after this tensor (for last tensor) + """ + payload = { + "serialized_named_tensors": gathered_handlers, + "flush_cache": flush_cache, + } + + try: + response = requests.post( + url, + json=payload, + headers={"Content-Type": "application/json"}, + timeout=120, + ) + response.raise_for_status() + except requests.exceptions.HTTPError as e: + error_msg = f"Failed to send tensor '{tensor_name}' to {url}: {e}" + try: + error_detail = response.text + error_msg += f"\nResponse status: {response.status_code}" + error_msg += f"\nResponse body: {error_detail[:500]}" + except: + pass + print(f"[sglang refit] {error_msg}", flush=True) + raise RuntimeError(error_msg) from e + except Exception as e: + raise RuntimeError( + f"Failed to send tensor '{tensor_name}' to {url}: {e}" + ) from e diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index c84617b09e..c1c1fb0beb 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -1727,6 +1727,51 @@ def dtensor_params_generator(): worker_name=str(self), ) + @torch.no_grad() + @wrap_with_nvtx_name("dtensor_policy_worker_v2/stream_weights_via_http") + def stream_weights_via_http( + self, + sglang_url_to_gpu_uuids: dict[str, list[str]], + ) -> None: + """Stream model weights to SGLang servers via HTTP API. 
+ + Args: + sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses + """ + # Manually move model to cuda for cpu offload case + if self.cpu_offload: + self.model = self.move_to_cuda(self.model) + + from nemo_rl.models.policy.utils import stream_weights_via_http_impl + + # Get current GPU UUID + current_device_uuid = self.report_device_id() + + def dtensor_params_generator(): + """Generator that yields (name, tensor) pairs, converting DTensors to local tensors. + """ + state_dict_items = sorted(self.model.state_dict().items(), key=lambda x: x[0]) + for name, tensor in state_dict_items: + if isinstance(tensor, DTensor): + # Convert DTensor to full tensor for streaming + full_tensor = tensor.full_tensor() + # Convert to target dtype + yield ( + name, + full_tensor.to(self.dtype, non_blocking=True).contiguous(), + ) + else: + # Convert to target dtype + yield name, tensor.to(self.dtype, non_blocking=True).contiguous() + # Use the HTTP implementation + stream_weights_via_http_impl( + params_generator=dtensor_params_generator(), + sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, + rank=self.rank, + worker_name=str(self), + current_device_uuid=current_device_uuid, + ) + @torch.no_grad() def broadcast_weights_for_collective( self, kv_scales: Optional[dict[str, float]] = None diff --git a/pyproject.toml b/pyproject.toml index 6225fe43a7..ca0ed6dd2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,26 @@ vllm = [ # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved "causal-conv1d", ] +sglang = [ + "sglang>=0.4.1", + "pybase64", + "orjson", + "uvloop", + "requests", + "openai", + "partial-json-parser", + "sentencepiece", + "sgl-kernel==0.3.17.post1", + "compressed-tensors", + "msgspec", + "python-multipart", + "torchao", + "xgrammar", + "interegular", + "openai-harmony", + "torch-memory-saver", + "einops", +] mcore = [ # also need cudnn 
(https://developer.nvidia.com/cudnn-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=deb_network) # wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb diff --git a/run.sh b/run.sh new file mode 100755 index 0000000000..fcea74f835 --- /dev/null +++ b/run.sh @@ -0,0 +1,20 @@ +#!/bin/bash +set -e + +VENV_NAME=".venv_test" +CONFIG_FILE="examples/configs/grpo_math_1B_sglang.yaml" + +if [ -d "$VENV_NAME" ]; then + echo "Removing existing virtual environment..." + rm -rf "$VENV_NAME" +fi + +uv venv "$VENV_NAME" +source "$VENV_NAME/bin/activate" +uv pip install -e ".[sglang]" + +echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES" + + +python examples/run_grpo_math.py --config "$CONFIG_FILE" +