
Commit b671719

Merge remote-tracking branch 'origin' into yifu/vllm0112_bump
Signed-off-by: Yi-Fu Wu <[email protected]>
2 parents eab6019 + 25ff3f6 commit b671719

File tree

22 files changed: +1348 −965 lines changed

docs/Makefile

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ ensure-docs-env:
 	@if [ ! -x "$(PYTHON)" ]; then \
 		echo "📦 Creating isolated docs environment..."; \
 		uv venv .venv; \
-		uv sync --no-config; \
+		uv sync --project ../pyproject.toml --group docs; \
 		echo "✅ Docs environment ready."; \
 		echo "📝 To activate it: $(ACTIVATE_CMD)"; \
 	fi

docs/fp8.md

Lines changed: 2 additions & 7 deletions
@@ -66,9 +66,9 @@ FP8 generations are recommended to be configured with the following settings:
 
 ## Compatibility Note for Deepseek-Style FP8 Training
 
-When using FP8 training with Deepseek-style FP8 (sub-channel scaling), be aware of the following compatibility issue:
+The TransformerEngine implementation for this recipe requires **CUDA version ≥ 12.9**. The latest nemo-rl depends on torch 2.8.0 + CUDA 12.9 (since this [commit](https://github.com/NVIDIA-NeMo/RL/commit/3f36d14b53e906b27c01c06e36dbbd2b8eb300cd)). Users should check out the latest code and build the container from `docker/Dockerfile` ([instructions](docker.md)).
 
-The TransformerEngine implementation for this recipe requires **cuBLAS version ≥ 12.9**. However, `nemo-rl` currently depends on **Torch 2.7.1**, which in turn requires **CUDA 12.8**. As a result, attempting to use the default setup will trigger the following error:
+If you are using nemo-rl from before this [commit](https://github.com/NVIDIA-NeMo/RL/commit/3f36d14b53e906b27c01c06e36dbbd2b8eb300cd), you will see the following error when trying to use FP8 training:
 
 ```
 File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/transformer_engine/pytorch/fp8.py", line 646, in fp8_autocast
@@ -78,11 +78,6 @@ assert fp8_block_available, reason_for_no_fp8_block
 ^^^^^^^^^^^^^^^^^^^
 AssertionError: FP8 block scaled GEMM requires Hopper and CUDA >= 12.9.
 ```
-This issue will be resolved once the Torch version is upgraded to **≥ 2.8.0** (Please follow [#1122](https://github.com/NVIDIA-NeMo/RL/issues/1122) for more progress on the upgrade). In the meantime, you can enable Deepseek-style FP8 training using the following workaround:
-
-- **Build the NGC PyTorch container** from `docker/Dockerfile.ngc_pytorch`.
-  This setup uses the system Python environment, which includes **CUDA version 12.9 or higher**, meeting the requirements for TransformerEngine’s FP8 implementation.
-
 
 
 ## Accuracy
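The doc change above boils down to a single runtime requirement. A minimal pre-flight check along these lines can confirm the environment before launching a blockwise-FP8 run; it is an illustrative sketch, not part of the commit, and relies only on `torch.version.cuda`:

```python
import torch

# Illustrative pre-flight check (not from this commit): verify the CUDA build that
# torch was compiled against is new enough for Deepseek-style (blockwise) FP8.
# Note: the original error message also requires Hopper-class hardware; this only
# checks the CUDA toolkit version.
cuda_build = torch.version.cuda  # e.g. "12.9"; None if this torch build has no CUDA support
assert cuda_build is not None, "CUDA-enabled torch build required for FP8 training"
major, minor = (int(x) for x in cuda_build.split(".")[:2])
if (major, minor) < (12, 9):
    raise RuntimeError(
        f"FP8 block-scaled GEMM needs CUDA >= 12.9, but torch was built with CUDA {cuda_build}"
    )
```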

docs/pyproject.toml

Lines changed: 0 additions & 22 deletions
This file was deleted.

docs/uv.lock

Lines changed: 0 additions & 846 deletions
This file was deleted.

examples/configs/grpo_math_1B.yaml

Lines changed: 2 additions & 0 deletions
@@ -229,6 +229,8 @@ policy:
       use_deep_gemm: False
       num_last_layers_in_bf16: 0
       num_first_layers_in_bf16: 0
+      enable_vllm_metrics_logger: true # Set to true to enable vLLM internal metrics logger, turn off for better performance
+      vllm_metrics_logger_interval: 0.5 # Interval in seconds to collect vLLM logger metrics
       vllm_kwargs: {}
     colocated:
       # true: generation shares training GPUs
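For context on what the two new keys control, the sketch below is illustrative only (not code from this commit): a background poller appends one reading per worker every `vllm_metrics_logger_interval` seconds, producing the per-worker timelines the training loop later collects. `get_inflight_batch_size` is a hypothetical stand-in for whatever counter the vLLM worker exposes.

```python
import threading
import time

# Illustrative sketch of the polling cadence implied by vllm_metrics_logger_interval.
def poll_metrics(get_inflight_batch_size, interval_s: float, out: list, stop: threading.Event) -> None:
    while not stop.is_set():
        out.append(get_inflight_batch_size())  # one sample per interval
        time.sleep(interval_s)

stop = threading.Event()
samples: list = []
t = threading.Thread(target=poll_metrics, args=(lambda: 0, 0.5, samples, stop), daemon=True)
t.start()
time.sleep(2.0)  # generation would happen here
stop.set()
t.join()
print(len(samples), "readings collected")  # roughly 4 readings at a 0.5 s interval
```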
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+defaults: ../../grpo_math_1B.yaml
+grpo:
+  val_period: -1
+  loss_fn:
+    reference_policy_kl_penalty: 0.04
+    use_importance_sampling_correction: true
+checkpointing:
+  enabled: false
+  checkpoint_dir: results/grpo_megatron
+  save_period: 10000
+policy:
+  model_name: moonshotai/Moonlight-16B-A3B-Instruct
+  train_micro_batch_size: 1
+  generation_batch_size: 64
+  logprob_batch_size: 1
+  max_total_sequence_length: 8192
+  dtensor_cfg:
+    enabled: false
+  sequence_packing:
+    algorithm: modified_ffd
+    make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}
+  optimizer: null
+  megatron_cfg:
+    enabled: true
+    moe_router_dtype: fp32
+    expert_model_parallel_size: 4
+    pipeline_model_parallel_size: 4
+    num_layers_in_first_pipeline_stage: 7
+    num_layers_in_last_pipeline_stage: 6
+    apply_rope_fusion: false
+    fp8_cfg:
+      enabled: true
+      fp8: e4m3
+      fp8_recipe: blockwise
+      fp8_param: false
+    optimizer:
+      lr: 1.0e-06
+      use_precision_aware_optimizer: false
+    scheduler:
+      lr_warmup_iters: 50
+    env_vars:
+      NVTE_FP8_BLOCK_SCALING_FP32_SCALES: '1'
+  generation:
+    vllm_cfg:
+      precision: fp8
+      use_deep_gemm: true
+      gpu_memory_utilization: 0.5
+      expert_parallel_size: 4
+      quantization_ignored_layer_kws: [
+        a_proj,
+        b_proj
+      ]
+logger:
+  monitor_gpus: false
+  wandb:
+    name: grpo-moonlight-16B-A3B-Instruct
+cluster:
+  gpus_per_node: 8
+  num_nodes: 4

examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ policy:
     PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False
   generation:
     vllm_cfg:
-      tensor_parallel_size: 4
+      tensor_parallel_size: 2
 logger:
   log_dir: logs/grpo-qwen3-30ba3b-4n8g
   wandb_enabled: true

nemo_rl/algorithms/grpo.py

Lines changed: 19 additions & 0 deletions
@@ -1073,6 +1073,8 @@ def grpo_train(
 
         dynamic_sampling_num_gen_batches += 1
         with timer.time("generation"):
+            # Clear vLLM logger metrics for each generation step
+            policy_generation.clear_vllm_logger_metrics()
             # Use penguin rollouts if enabled. We cascade penguin first since penguin requires async rollouts.
             if _should_use_penguin(master_config):
                 generation_config = master_config["policy"]["generation"]
@@ -1122,6 +1124,9 @@
                 greedy=False,
             )
             policy_generation.finish_generation()
+            # Collect vLLM logger metrics for performance reporting after each generation step
+            # inflight batch sizes and num pending samples are collected from each vLLM worker
+            vllm_logger_metrics = policy_generation.get_vllm_logger_metrics()
 
         repeated_batch = scale_rewards(
             repeated_batch, master_config["grpo"]["reward_scaling"]
@@ -1340,6 +1345,7 @@
                 metrics[k] = np.sum(v).item()
 
         metrics.update(rollout_metrics)
+        metrics["vllm_logger_metrics"] = vllm_logger_metrics
         total_valid_tokens += metrics["global_valid_toks"]
 
         ## Checkpointing
@@ -1907,6 +1913,9 @@ def async_grpo_train(
 
     print("✅ All setup complete, starting buffer wait...")
 
+    # Clear vLLM logger metrics at the start of training
+    policy_generation.clear_vllm_logger_metrics()
+
     # Wait for initial buffer fill
     print(
         f"⏳ Waiting for replay buffer to have sufficient trajectories ({min_trajectories_needed} trajectories)..."
@@ -2145,12 +2154,17 @@
         train_results = policy.train(train_data, loss_fn)
 
         print("🔄 Synchronizing policy weights to trajectory collector…")
+        vllm_logger_metrics = None
         if NEED_REFIT:
            # Measure pending-generation wait as exposed_generation time
            print("🔄 Coordinating with trajectory collector before refit...")
            with timer.time("exposed_generation"):
                ray.get(trajectory_collector.prepare_for_refit.remote())
 
+           # Collect vLLM logger metrics for performance reporting
+           # inflight batch sizes and num pending samples are collected from each vLLM worker
+           vllm_logger_metrics = policy_generation.get_vllm_logger_metrics()
+
            # Only the actual refit/weight transfer should be counted as weight_sync
            print("🔄 Performing policy generation refit...")
            with timer.time("weight_sync"):
@@ -2164,6 +2178,9 @@
            trajectory_collector.set_weight_version.remote(weight_version)
            trajectory_collector.resume_after_refit.remote()
 
+           # Clear vLLM logger metrics after each refit (weight sync), starting a new logging cycle
+           policy_generation.clear_vllm_logger_metrics()
+
        # Validation
        val_metrics, validation_timings = None, None
        is_last_step = step + 1 == master_config["grpo"]["max_num_steps"]
@@ -2241,6 +2258,8 @@
            else:
                metrics[k] = np.sum(v).item()
        metrics.update(rollout_metrics)
+       if vllm_logger_metrics is not None:
+           metrics["vllm_logger_metrics"] = vllm_logger_metrics
        total_valid_tokens += metrics["global_valid_toks"]
 
        # Checkpointing (same as sync version)
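Taken together, the sync and async changes follow the same clear → generate → collect cycle. The sketch below restates that lifecycle in isolation; the function, loop, and variable names around the two calls are illustrative stand-ins, and only `clear_vllm_logger_metrics()` and `get_vllm_logger_metrics()` come from the diff above.

```python
# Illustrative restatement of the metrics window introduced above (not nemo-rl code;
# `policy_generation`, `run_generation`, and `metrics` are stand-ins).
def generation_step_with_metrics(policy_generation, run_generation, metrics: dict) -> None:
    # Start a fresh logging window so readings from the previous step don't leak in.
    policy_generation.clear_vllm_logger_metrics()

    # All vLLM activity inside this window is sampled by the metrics logger.
    run_generation()

    # Per-worker timelines (inflight batch sizes, num pending samples) for this step only.
    metrics["vllm_logger_metrics"] = policy_generation.get_vllm_logger_metrics()
```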

nemo_rl/algorithms/utils.py

Lines changed: 125 additions & 5 deletions
@@ -16,7 +16,7 @@
 import random
 import warnings
 from functools import partial, wraps
-from typing import Optional
+from typing import Any, Optional
 
 import numpy as np
 import torch
@@ -384,7 +384,7 @@ def maybe_pad_last_batch(batch: dict, dp_size: int, mbs: int) -> dict:
 
 def print_performance_metrics(
     train_results: dict[str, float],
-    metrics: dict[str, float],
+    metrics: dict[str, Any],
     timing_metrics: dict[str, float],
     master_config: dict,
 ) -> dict[str, float]:
@@ -400,13 +400,14 @@ def visualize_per_worker_load(per_worker_token_counts: dict[int, int]) -> float:
         per_worker_load_ratio = [
             v / max(per_worker_token_counts_list) for v in per_worker_token_counts_list
         ]
-        max_rows_to_print = 100
+        max_rows_to_print = 1000
+        bar_length = 20
         print(" • Visualizing Token Imbalance per Generation Worker:")
         for i in range(min(len(per_worker_token_counts_list), max_rows_to_print)):
             print(
                 f" - Generated Tokens from Worker {i:3.0f}:"
-                f"{'■' * int(per_worker_load_ratio[i] * 10)}"
-                f"{'□' * (10 - int(per_worker_load_ratio[i] * 10))}"
+                f"{'■' * int(per_worker_load_ratio[i] * bar_length)}"
+                f"{'□' * (bar_length - int(per_worker_load_ratio[i] * bar_length))}"
                 f" Count: {per_worker_token_counts_list[i] / 1000:.1f}K"
             )
         estimated_idle_ratio = 1 - sum(per_worker_load_ratio) / len(
@@ -441,6 +442,125 @@ def visualize_per_worker_load(per_worker_token_counts: dict[int, int]) -> float:
         f" • Mean Total Tokens per Sample: {metrics['mean_total_tokens_per_sample']:.2f}"
     )
 
+    # =====================================================
+    # vLLM Logger Metrics (inflight batch sizes, num pending samples, etc.)
+    # =====================================================
+    def resize_timeline(data, new_size):
+        old_size = len(data)
+        x_old = np.linspace(0, 1, old_size)
+        x_new = np.linspace(0, 1, new_size)
+        return np.interp(x_new, x_old, data)
+
+    def get_min_idle_time(
+        metric_dict: dict[int, list[int]], timeline_interval: float
+    ) -> float:
+        min_idle_time = float("inf")
+        for _, metric_values in metric_dict.items():
+            count_zeros = lambda x: sum(v == 0 for v in x)
+            idle_time = count_zeros(metric_values) * timeline_interval
+            min_idle_time = min(min_idle_time, idle_time)
+        return min_idle_time
+
+    def visualize_per_worker_timeline(
+        metric_dict: dict[int, list[int]],
+        metric_name: str,
+        timeline_interval: float | None,
+    ) -> None:
+        dp_ranks = list(metric_dict.keys())
+        max_rows_to_print = 1000
+        max_timeline_length = 50
+        marker = {0: "▃", 1: "▅", 2: "▆", 3: "▉"}
+        zero_marker = "▁"
+
+        max_value = max((max(v) if v else 0) for v in metric_dict.values())
+        bin_width = (max_value + 1) / len(marker)
+
+        print(f" - {metric_name}:")
+        print(f" - Max value: {max_value}")
+        if timeline_interval is not None:
+            print(
+                f" - Min idle time: {get_min_idle_time(metric_dict, timeline_interval)} s"
+            )
+        print(
+            f" - Timeline (0: {zero_marker}, {', '.join(f'{1.0 if k == 0 else k * (max_value / len(marker))}-{(k + 1) * (max_value / len(marker))}: {marker[k]}' for k in marker.keys())}):"
+        )
+        for dp_idx, metric_values in metric_dict.items():
+            if dp_idx > max_rows_to_print:
+                break
+            timeline = []
+            length = len(metric_values)
+            if timeline_interval is not None:
+                count_zeros = lambda x: sum(v == 0 for v in x)
+                idle = count_zeros(metric_values) * timeline_interval
+                active = length * timeline_interval - idle
+            if length > max_timeline_length:
+                resized_metric_values = resize_timeline(
+                    metric_values, max_timeline_length
+                )
+            else:
+                resized_metric_values = metric_values
+
+            for i, value in enumerate(resized_metric_values):
+                m = (
+                    zero_marker
+                    if value == 0
+                    else marker[min(int(value // bin_width), len(marker) - 1)]
+                )
+                timeline.append(m)
+            if timeline_interval is not None:
+                print(
+                    f" - Generation Worker {dp_idx:3.0f}: {''.join(timeline)} (Active: {active:.2f} s, Idle: {idle:.2f} s)"
+                )
+            else:
+                print(f" - Generation Worker {dp_idx:3.0f}: {''.join(timeline)}")
+
+    is_vllm_metrics_logger_enabled = master_config["policy"]["generation"].get(
+        "vllm_cfg", {}
+    ).get("enable_vllm_metrics_logger", False) and master_config["policy"][
+        "generation"
+    ].get("vllm_cfg", {}).get("async_engine", False)
+    if is_vllm_metrics_logger_enabled:
+        vllm_logger_metrics = metrics["vllm_logger_metrics"]
+        # vllm_logger_metrics: dict[str (metric_name), dict[int (dp_idx), list[int] (metric_values)]]
+        # metric_name: "inflight_batch_sizes" or "num_pending_samples"
+
+        assert "inflight_batch_sizes" in vllm_logger_metrics, (
+            "inflight_batch_sizes not found in vllm_logger_metrics"
+        )
+        assert "num_pending_samples" in vllm_logger_metrics, (
+            "num_pending_samples not found in vllm_logger_metrics"
+        )
+        assert isinstance(vllm_logger_metrics["inflight_batch_sizes"], dict), (
+            "inflight_batch_sizes must be a dictionary"
+        )
+        assert isinstance(vllm_logger_metrics["num_pending_samples"], dict), (
+            "num_pending_samples must be a dictionary"
+        )
+
+        vllm_metrics_logger_interval = master_config["policy"]["generation"][
+            "vllm_cfg"
+        ]["vllm_metrics_logger_interval"]
+        print(" • vLLM Logger Metrics:")
+        # Visualize the inflight batch sizes timeline
+        if len(vllm_logger_metrics["inflight_batch_sizes"].values()) > 0:
+            visualize_per_worker_timeline(
+                vllm_logger_metrics["inflight_batch_sizes"],
+                "Inflight Batch Sizes",
+                vllm_metrics_logger_interval,
+            )
+        if len(vllm_logger_metrics["num_pending_samples"].values()) > 0:
+            max_num_pending_samples = max(
                (max(v) if v else 0)
+                for v in vllm_logger_metrics["num_pending_samples"].values()
+            )
+            # If there is at least one pending sample, visualize the timeline
+            if max_num_pending_samples > 0:
+                visualize_per_worker_timeline(
+                    vllm_logger_metrics["num_pending_samples"],
+                    "Num Pending Samples",
+                    None,
+                )
+
     # =====================================================
     # Throughputs
     # =====================================================
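The timeline rendering above hinges on two small pieces: `np.interp`-based downsampling to a fixed number of slots and bucketing each value into one of four block markers. A standalone rerun of that logic on made-up data (the sample values and sizes are invented for illustration) looks like this:

```python
import numpy as np

# Same downsampling as resize_timeline() in the diff: compress an arbitrary-length
# history onto a fixed number of timeline slots via linear interpolation.
def resize_timeline(data, new_size):
    x_old = np.linspace(0, 1, len(data))
    x_new = np.linspace(0, 1, new_size)
    return np.interp(x_new, x_old, data)

# Invented per-interval inflight batch sizes from one generation worker.
history = [0, 0, 4, 8, 8, 6, 3, 0, 0, 0]

marker = {0: "▃", 1: "▅", 2: "▆", 3: "▉"}  # same 4-level markers as the diff
zero_marker = "▁"
max_value = max(history)
bin_width = (max_value + 1) / len(marker)

timeline = "".join(
    zero_marker if v == 0 else marker[min(int(v // bin_width), len(marker) - 1)]
    for v in resize_timeline(history, 6)  # 10 readings squeezed into 6 slots
)
print(timeline)  # -> ▁▅▉▆▁▁ : taller blocks where the worker was busiest
```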
