Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d9cf489
sglang support:initial commit
PrinsYin Nov 23, 2025
3eace5f
sglang:manually set cuda visible to let localran=0 to manage gpus of …
PrinsYin Nov 24, 2025
6fbbbb7
sglang: add sglang setup in grpo.py, add find available port to set u…
PrinsYin Nov 25, 2025
242612c
sglang: add shutdown
PrinsYin Nov 25, 2025
a3d8ad6
sglang server: fix gpu allocation when tp =1
PrinsYin Nov 28, 2025
88971e3
generate only first request
PrinsYin Nov 25, 2025
db8b07b
fix : choose the correct gpu using base gpu id
PrinsYin Nov 26, 2025
dd0e54f
asyncio to rollout all samples
PrinsYin Nov 26, 2025
21c54e3
fix new event loop for rollout
PrinsYin Nov 26, 2025
5e24fab
added mem_fraction
PrinsYin Nov 26, 2025
50189a9
modified build_sampling_paras and stop token handling
PrinsYin Nov 28, 2025
ec35b6b
temp: prevent server overload with semaphore
PrinsYin Nov 28, 2025
f099caa
sglang: refactor, move async loop position
PrinsYin Nov 30, 2025
a03eba8
sglang: fix total length in generate
PrinsYin Nov 30, 2025
e08cfd6
sglang: env setup
PrinsYin Nov 30, 2025
ccc66f6
from tensor:
PrinsYin Nov 27, 2025
2ce928b
sglang refit: fix sglang import
PrinsYin Nov 27, 2025
4aa1e74
fix: match fsdp ranks correctly with sglang
PrinsYin Nov 28, 2025
9098077
flush cache before update begins
PrinsYin Nov 28, 2025
9900a33
Fix SGLang compatibility: add hasattr checks for vLLM-specific methods
PrinsYin Dec 1, 2025
5cb78e3
sglang: modified config (increase mem_fration, enable wandb)
PrinsYin Dec 1, 2025
03d9d0c
refactor(grpo): extract init logic for generation backends
PrinsYin Dec 2, 2025
7ca9776
refactor SGLangConfig
PrinsYin Dec 2, 2025
f1c26dd
refactor: generalize logger metrics for all generation backends
PrinsYin Dec 4, 2025
255dcc6
refactor sglang config loading to make it consistent with other backends
PrinsYin Dec 4, 2025
ee01f91
resolved ai comments
PrinsYin Dec 6, 2025
e25e573
changed print to using logging
PrinsYin Dec 6, 2025
e93699f
Merge branch 'main' into sglang_server
PrinsYin Dec 9, 2025
85d6a92
Update nemo_rl/models/generation/sglang/sglang_worker.py
PrinsYin Dec 17, 2025
be1ae27
Merge branch 'main' into sglang_server
PrinsYin Dec 17, 2025
ede624f
fix comments about config defaults
PrinsYin Dec 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions examples/configs/grpo_math_1B_sglang.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
defaults: grpo_math_1B.yaml

grpo:
val_batch_size: 128

policy:
generation:
backend: "sglang"
sglang_cfg:
# SGLang specific configuration
model_path: ${policy.model_name}
gpus_per_server: 1
dtype: ${policy.precision}
context_length: 512 # Maximum context length
allow_auto_truncate: true
enable_memory_saver: false
dp_size: 1
pp_size: 1
ep_size: 1
max_running_requests: null
mem_fraction_static: 0.7
skip_server_warmup: true

logger:
wandb_enabled: true
238 changes: 147 additions & 91 deletions nemo_rl/algorithms/grpo.py

Large diffs are not rendered by default.

69 changes: 35 additions & 34 deletions nemo_rl/algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,46 +521,47 @@ def visualize_per_worker_timeline(
"generation"
].get("vllm_cfg", {}).get("async_engine", False)
if is_vllm_metrics_logger_enabled:
vllm_logger_metrics = metrics["vllm_logger_metrics"]
# vllm_logger_me trics: dict[str (metric_name), dict[int (dp_idx), list[int] (metric_values)]]
vllm_logger_metrics = metrics.get("generation_logger_metrics", {})
# vllm_logger_metrics: dict[str (metric_name), dict[int (dp_idx), list[int] (metric_values)]]
# metric_name: "inflight_batch_sizes" or "num_pending_samples"

assert "inflight_batch_sizes" in vllm_logger_metrics, (
"inflight_batch_sizes not found in vllm_logger_metrics"
)
assert "num_pending_samples" in vllm_logger_metrics, (
"num_pending_samples not found in vllm_logger_metrics"
)
assert isinstance(vllm_logger_metrics["inflight_batch_sizes"], dict), (
"inflight_batch_sizes must be a dictionary"
)
assert isinstance(vllm_logger_metrics["num_pending_samples"], dict), (
"num_pending_samples must be a dictionary"
)

vllm_metrics_logger_interval = master_config["policy"]["generation"][
"vllm_cfg"
]["vllm_metrics_logger_interval"]
print(" • vLLM Logger Metrics:")
# Visualize the inflight batch sizes timeline
if len(vllm_logger_metrics["inflight_batch_sizes"].values()) > 0:
visualize_per_worker_timeline(
vllm_logger_metrics["inflight_batch_sizes"],
"Inflight Batch Sizes",
vllm_metrics_logger_interval,
if vllm_logger_metrics:
assert "inflight_batch_sizes" in vllm_logger_metrics, (
"inflight_batch_sizes not found in vllm_logger_metrics"
)
if len(vllm_logger_metrics["num_pending_samples"].values()) > 0:
max_num_pending_samples = max(
(max(v) if v else 0)
for v in vllm_logger_metrics["num_pending_samples"].values()
assert "num_pending_samples" in vllm_logger_metrics, (
"num_pending_samples not found in vllm_logger_metrics"
)
# If there is at least one pending sample, visualize the timeline
if max_num_pending_samples > 0:
assert isinstance(vllm_logger_metrics["inflight_batch_sizes"], dict), (
"inflight_batch_sizes must be a dictionary"
)
assert isinstance(vllm_logger_metrics["num_pending_samples"], dict), (
"num_pending_samples must be a dictionary"
)

vllm_metrics_logger_interval = master_config["policy"]["generation"][
"vllm_cfg"
]["vllm_metrics_logger_interval"]
print(" • vLLM Logger Metrics:")
# Visualize the inflight batch sizes timeline
if len(vllm_logger_metrics["inflight_batch_sizes"].values()) > 0:
visualize_per_worker_timeline(
vllm_logger_metrics["num_pending_samples"],
"Num Pending Samples",
None,
vllm_logger_metrics["inflight_batch_sizes"],
"Inflight Batch Sizes",
vllm_metrics_logger_interval,
)
if len(vllm_logger_metrics["num_pending_samples"].values()) > 0:
max_num_pending_samples = max(
(max(v) if v else 0)
for v in vllm_logger_metrics["num_pending_samples"].values()
)
# If there is at least one pending sample, visualize the timeline
if max_num_pending_samples > 0:
visualize_per_worker_timeline(
vllm_logger_metrics["num_pending_samples"],
"Num Pending Samples",
None,
)

# =====================================================
# Throughputs
Expand Down
5 changes: 5 additions & 0 deletions nemo_rl/distributed/ray_actor_environment_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@
VLLM_EXECUTABLE = (
PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.VLLM
)
SGLANG_EXECUTABLE = (
PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.SGLANG
)
MCORE_EXECUTABLE = (
PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.MCORE
)

ACTOR_ENVIRONMENT_REGISTRY: dict[str, str] = {
"nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker": VLLM_EXECUTABLE,
"nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker": VLLM_EXECUTABLE,
"nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker": SGLANG_EXECUTABLE,
# Temporary workaround for the coupled implementation of DTensorPolicyWorker and vLLM.
# This will be reverted to PY_EXECUTABLES.BASE once https://github.com/NVIDIA-NeMo/RL/issues/501 is resolved.
"nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": VLLM_EXECUTABLE,
Expand Down Expand Up @@ -63,3 +67,4 @@ def get_actor_python_env(actor_class_fqn: str) -> str:
"adding a new generation framework or training backend), you'll need to specify the "
"appropriate environment. See uv.md for more details."
)

6 changes: 6 additions & 0 deletions nemo_rl/distributed/virtual_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,16 @@ class PY_EXECUTABLES:
# Use NeMo-RL direct dependencies and nemo-automodel.
AUTOMODEL = f"uv run --locked --extra automodel --directory {git_root}"

# Use NeMo-RL direct dependencies, nemo-automodel, and SGLang.
AUTOMODEL_SGLANG = f"uv run --locked --extra automodel --extra sglang --directory {git_root}"

# Use NeMo-RL direct dependencies and Megatron.
MCORE = f"uv run --locked --extra mcore --directory {git_root}"

# Use NeMo-Gym dependencies
NEMO_GYM = f"uv run --locked --extra nemo_gym --directory {git_root}"
# Use NeMo-RL direct dependencies and SGLang.
# NOTE: must be an f-string so {git_root} is interpolated, matching the
# sibling executables (VLLM, MCORE, AUTOMODEL_SGLANG) defined above.
SGLANG = f"uv run --locked --extra sglang --directory {git_root}"


@ray.remote # pragma: no cover
Expand Down Expand Up @@ -503,3 +508,4 @@ def __del__(self) -> None:
user calls shutdown().
"""
self.shutdown()

19 changes: 19 additions & 0 deletions nemo_rl/models/generation/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,22 @@ def update_weights_from_collective(self) -> list[ray.ObjectRef]:
# (e.g., vLLM prefix/KV caches) after weight updates.
def invalidate_kv_cache(self) -> bool:
return False

def clear_logger_metrics(self) -> None:
    """Reset any accumulated telemetry used for performance reporting.

    Optional hook for generation backends: a backend that tracks logger
    metrics can override this to wipe its counters. The base
    implementation is a deliberate no-op.
    """
    return None

def get_logger_metrics(self) -> dict[str, Any]:
    """Collect telemetry metrics for performance reporting.

    Optional hook for generation backends: a backend that tracks logger
    metrics can override this to expose them. The base implementation
    has nothing to report.

    Returns:
        A metrics dictionary; the exact layout is backend-specific.
        Empty by default.
    """
    metrics: dict[str, Any] = {}
    return metrics
23 changes: 23 additions & 0 deletions nemo_rl/models/generation/sglang/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nemo_rl.models.generation.sglang.config import SGLangConfig
from nemo_rl.models.generation.sglang.sglang_generation import SGLangGeneration
from nemo_rl.models.generation.sglang.sglang_worker import SGLangGenerationWorker

__all__ = [
"SGLangConfig",
"SGLangGeneration",
"SGLangGenerationWorker",
]

98 changes: 98 additions & 0 deletions nemo_rl/models/generation/sglang/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, NotRequired, TypedDict

from nemo_rl.models.generation.interfaces import GenerationConfig


class SglangSpecificArgs(TypedDict):
    """SGLang-specific configuration arguments.

    Most fields below map directly to SGLang's ServerArgs (see:
    https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py).
    """
    # --- Model / server placement ---
    model_path: NotRequired[str]
    gpus_per_server: NotRequired[int]  # NeMo-RL-side: GPUs allocated per SGLang server
    random_seed: NotRequired[int]
    skip_tokenizer_init: NotRequired[bool]
    # --- Engine feature toggles (passed through to ServerArgs) ---
    disable_cuda_graph: NotRequired[bool]
    disable_radix_cache: NotRequired[bool]
    disable_cuda_graph_padding: NotRequired[bool]
    enable_nccl_nvls: NotRequired[bool]
    disable_outlines_disk_cache: NotRequired[bool]
    disable_custom_all_reduce: NotRequired[bool]
    disable_overlap_schedule: NotRequired[bool]
    enable_mixed_chunk: NotRequired[bool]
    enable_dp_attention: NotRequired[bool]
    enable_ep_moe: NotRequired[bool]
    enable_torch_compile: NotRequired[bool]
    torch_compile_max_bs: NotRequired[int]
    cuda_graph_max_bs: NotRequired[int | None]
    cuda_graph_bs: NotRequired[list[int] | None]
    torchao_config: NotRequired[str]
    enable_nan_detection: NotRequired[bool]
    enable_p2p_check: NotRequired[bool]
    triton_attention_reduce_in_fp32: NotRequired[bool]
    triton_attention_num_kv_splits: NotRequired[int]
    num_continuous_decode_steps: NotRequired[int]
    enable_memory_saver: NotRequired[bool]
    allow_auto_truncate: NotRequired[bool]
    attention_backend: NotRequired[str | None]
    enable_multimodal: NotRequired[bool]
    sampling_backend: NotRequired[str | None]
    # --- Capacity / scheduling limits ---
    context_length: NotRequired[int | None]
    mem_fraction_static: NotRequired[float | None]
    max_running_requests: NotRequired[int | None]
    chunked_prefill_size: NotRequired[int | None]
    max_prefill_tokens: NotRequired[int]
    schedule_policy: NotRequired[str]
    schedule_conservativeness: NotRequired[float]
    cpu_offload_gb: NotRequired[int]
    # --- Precision / parallelism ---
    dtype: NotRequired[str]
    kv_cache_dtype: NotRequired[str]
    dp_size: NotRequired[int]  # only used for dp attention
    pp_size: NotRequired[int]  # pipeline parallel size
    ep_size: NotRequired[int]
    # lora
    enable_lora: NotRequired[bool | None]
    max_lora_rank: NotRequired[int | None]
    lora_target_modules: NotRequired[list[str] | None]
    lora_paths: NotRequired[list[str] | None]
    max_loaded_loras: NotRequired[int]
    max_loras_per_batch: NotRequired[int]
    lora_backend: NotRequired[str]
    # logging
    log_level: NotRequired[str]
    log_level_http: NotRequired[str | None]
    log_requests: NotRequired[bool]
    log_requests_level: NotRequired[int]
    show_time_cost: NotRequired[bool]
    enable_metrics: NotRequired[bool]  # Exports Prometheus-like metrics
    # The interval (in decoding iterations) to log throughput
    # and update prometheus metrics
    decode_log_interval: NotRequired[int]
    # Extra loader arguments
    enable_multithread_load: NotRequired[bool]
    enable_fast_load: NotRequired[bool]
    # Server warmup
    skip_server_warmup: NotRequired[bool]


class SGLangConfig(GenerationConfig):
    """Configuration for SGLang runtime."""
    # SGLang-specific server arguments; see SglangSpecificArgs for the full field list.
    sglang_cfg: SglangSpecificArgs
    # NOTE(review): presumably extra kwargs forwarded verbatim to the SGLang
    # server launcher — confirm against the worker that consumes this config.
    sglang_kwargs: NotRequired[dict[str, Any]]


Loading
Loading