Skip to content
8 changes: 8 additions & 0 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,10 @@ def grpo_train(
policy_generation.prepare_for_generation()

dynamic_sampling_num_gen_batches += 1
if dynamic_sampling_num_gen_batches == 1 and hasattr(
policy_generation, "snapshot_step_metrics"
):
policy_generation.snapshot_step_metrics()
with timer.time("generation"):
# Clear logger metrics for each generation step
if policy_generation is not None:
Expand Down Expand Up @@ -1431,6 +1435,9 @@ def grpo_train(
# If the current batch is not enough to fill the buffer during dynamic sampling, we update the cache and process the next batch.
if not is_batch_complete:
continue
gen_step_metrics = {}
if hasattr(policy_generation, "get_step_metrics"):
gen_step_metrics = policy_generation.get_step_metrics()
advantages = (rewards - baseline).unsqueeze(-1)

if master_config["grpo"]["normalize_rewards"]:
Expand Down Expand Up @@ -1649,6 +1656,7 @@ def grpo_train(
metrics["reward"] = repeated_batch["total_reward"].numpy()

metrics.update(train_results["all_mb_metrics"])
metrics.update(gen_step_metrics)
for k, v in metrics.items():
if k in {"probs_ratio_min", "probs_ratio_clamped_min"}:
valid_values = [x for x in v if not np.isinf(x)]
Expand Down
11 changes: 11 additions & 0 deletions nemo_rl/models/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ def configure_generation_config(
config = cast(VllmConfig, config)
# set load_format
config["vllm_cfg"]["load_format"] = "auto" if is_eval else "dummy"
is_spec = "speculative_config" in config.get("vllm_kwargs", {})
if is_spec:
# When speculative decoding is enabled but the draft model is not co-trained
# with the policy (i.e., no weight sync for the draft model), we must use
# load_format='auto' to load actual weights. Using 'dummy' would leave the
# draft model with random weights that never get updated.
warnings.warn(
"Speculative decoding is enabled. Setting vllm_cfg['load_format'] to 'auto'. "
"This may result in slower startup times as full model weights are loaded."
)
config["vllm_cfg"]["load_format"] = "auto"

# Respect the skip_tokenizer_init setting from the config. VLMs for example, require this to be False.
if "skip_tokenizer_init" not in config["vllm_cfg"]:
Expand Down
101 changes: 101 additions & 0 deletions nemo_rl/models/generation/vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import Any, Optional

from nemo_rl.distributed.batched_data_dict import BatchedDataDict
Expand Down Expand Up @@ -82,3 +83,103 @@ def _get_regular_prompt(index: int):
prompts.append(_get_regular_prompt(i))

return prompts if return_all else prompts[0]


def aggregate_spec_decode_counters(
worker_metrics: list[dict[str, float | list[float]]],
) -> dict[str | tuple[str, int], float]:
"""Aggregate speculative decoding counters from multiple workers.

Combines spec decode metrics collected from DP leader workers into
a single aggregated counter dictionary.

Args:
worker_metrics: List of metric dictionaries from each worker.
Each dict maps metric names to float values or lists of floats
(for per-position metrics).

Returns:
Dictionary mapping metric names to their aggregated float values.
Per-position metrics use (name, position) tuples as keys.

Example:
>>> metrics_from_workers = policy_generation.get_metrics()
>>> counters = aggregate_spec_decode_counters(metrics_from_workers)
>>> print(counters.get("vllm:spec_decode_num_drafts", 0))
1234.0
"""
counters: dict[str | tuple[str, int], float] = defaultdict(float)

for report in worker_metrics:
for metric_name, value in report.items():
if "spec_decode" in metric_name:
if isinstance(value, list):
# Per-position metrics (e.g., acceptance counts at each draft position)
for position, pos_value in enumerate(value, 1):
counters[metric_name, position] += pos_value
else:
counters[metric_name] += value

return dict(counters)


def compute_spec_decode_metrics(
start_counters: dict[str | tuple[str, int], float],
end_counters: dict[str | tuple[str, int], float],
) -> dict[str, float]:
"""Compute delta and derived metrics for speculative decoding.

Calculates the difference between two counter snapshots and derives
acceptance rate and acceptance length metrics for logging.

Args:
start_counters: Counter snapshot taken before generation.
end_counters: Counter snapshot taken after generation.

Returns:
Dictionary of metrics suitable for logging to wandb/tensorboard.
Keys are prefixed with "vllm/" for namespace consistency.
Includes:
- vllm/spec_num_drafts: Total number of draft batches
- vllm/spec_num_draft_tokens: Total draft tokens generated
- vllm/spec_num_accepted_tokens: Total tokens accepted
- vllm/spec_acceptance_length: Average accepted tokens per draft + 1
- vllm/spec_acceptance_rate: Ratio of accepted to draft tokens
- vllm/{metric}-{position}: Per-position acceptance counts
- vllm/spec_acceptance_rate-pos-{position}: Per-position acceptance rates
"""
keys = set(start_counters) | set(end_counters)
delta = {k: end_counters.get(k, 0.0) - start_counters.get(k, 0.0) for k in keys}

num_drafts = delta.get("vllm:spec_decode_num_drafts", 0.0)
num_draft_tokens = delta.get("vllm:spec_decode_num_draft_tokens", 0.0)
num_accepted_tokens = delta.get("vllm:spec_decode_num_accepted_tokens", 0.0)

# acceptance_length = 1 + (accepted / drafts) represents average tokens
# generated per draft batch (1 target model token + accepted draft tokens)
acceptance_length = (
1.0 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1.0
)
acceptance_rate = (
num_accepted_tokens / num_draft_tokens if num_draft_tokens > 0 else 0.0
)

spec_metrics: dict[str, float] = {
"vllm/spec_num_drafts": num_drafts,
"vllm/spec_num_draft_tokens": num_draft_tokens,
"vllm/spec_num_accepted_tokens": num_accepted_tokens,
"vllm/spec_acceptance_length": acceptance_length,
"vllm/spec_acceptance_rate": acceptance_rate,
}

# Add per-position metrics for detailed analysis
for key, value in delta.items():
if isinstance(key, tuple):
metric_name, position = key
spec_metrics[f"vllm/{metric_name}-{position}"] = value
if num_drafts > 0:
spec_metrics[f"vllm/spec_acceptance_rate-pos-{position}"] = (
value / num_drafts
)

return spec_metrics
62 changes: 62 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import asyncio
import os
import warnings
from collections import defaultdict
from typing import (
Any,
Expand All @@ -36,6 +37,10 @@
GenerationOutputSpec,
)
from nemo_rl.models.generation.vllm.config import VllmConfig
from nemo_rl.models.generation.vllm.utils import (
aggregate_spec_decode_counters,
compute_spec_decode_metrics,
)

# Global thresholds for top_k and top_p validation.
# While top-k/p are not supported, these values allow for token filtering while the logprobs should be compatible.
Expand Down Expand Up @@ -223,6 +228,8 @@ def __init__(
# Save the device uuids for the workers
self.device_uuids = self._report_device_id()

self._step_metrics_snapshot: dict[str | tuple[str, int], float] | None = None

def _get_tied_worker_bundle_indices(
self, cluster: RayVirtualCluster
) -> list[tuple[int, list[int]]]:
Expand Down Expand Up @@ -381,6 +388,61 @@ def _post_init(self):
results = ray.get(futures)
return results

def _get_raw_spec_counters(self) -> dict[str | tuple[str, int], float]:
    """Collect and merge raw spec decode counters across workers.

    Uses run_rank_0_only_axes so only one worker per tensor/pipeline
    parallel group reports, i.e. each DP replica is counted once.
    """
    per_worker = ray.get(
        self.worker_group.run_all_workers_single_data(
            "_get_raw_spec_counters",
            run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
        )
    )
    return aggregate_spec_decode_counters(per_worker)

def snapshot_step_metrics(self) -> None:
    """Record a baseline of spec decode counters for the upcoming step.

    Call this before generation; get_step_metrics() later reports the
    delta against this baseline.

    Raises:
        RuntimeWarning: If called twice without get_step_metrics() in between.
    """
    overwriting = self._step_metrics_snapshot is not None
    if overwriting:
        warnings.warn(
            "snapshot_step_metrics() called again without get_step_metrics(). "
            "Previous snapshot will be overwritten.",
            RuntimeWarning,
        )
    self._step_metrics_snapshot = self._get_raw_spec_counters()

def get_step_metrics(self) -> dict[str, float]:
    """Get speculative decoding metrics delta since snapshot_step_metrics().

    Returns:
        Dictionary of delta metrics with 'vllm/' prefix.
        Returns empty dict if snapshot_step_metrics() was not called.

    Raises:
        RuntimeWarning: If called without snapshot_step_metrics() first.
    """
    baseline = self._step_metrics_snapshot
    if baseline is None:
        warnings.warn(
            "get_step_metrics() called without snapshot_step_metrics(). "
            "Call snapshot_step_metrics() before generation to track metrics.",
            RuntimeWarning,
        )
        return {}

    current = self._get_raw_spec_counters()
    # Consume the baseline so the next step starts from a fresh snapshot.
    self._step_metrics_snapshot = None
    return compute_spec_decode_metrics(baseline, current)

def init_collective(
self, ip: str, port: int, world_size: int, *, train_world_size: int
) -> list[ray.ObjectRef]:
Expand Down
63 changes: 62 additions & 1 deletion nemo_rl/models/generation/vllm/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,50 @@ def _patch_vllm_vit_flash_attn_backend():
with open(file_to_patch, "w") as f:
f.write(content)

def _patch_vllm_speculative_decoding_post_step():
    """Patch vLLM speculative decoding post_step call.

    Related PR:
    - https://github.com/vllm-project/vllm/pull/30319

    This patch fixes the InprocessClient.get_output method to properly
    call post_step with the model_executed flag from step_fn.
    """
    # NOTE(review): the snippet strings must match the installed vLLM source
    # byte-for-byte (including leading whitespace) for the replace to fire --
    # verify against the pinned vLLM version.
    file_to_patch = _get_vllm_file("v1/engine/core_client.py")

    with open(file_to_patch, "r") as f:
        content = f.read()

    old_snippet = (
        " def get_output(self) -> EngineCoreOutputs:\n"
        " outputs, _ = self.engine_core.step_fn()\n"
        " return outputs and outputs.get(0) or EngineCoreOutputs()"
    )

    new_snippet = (
        " def get_output(self) -> EngineCoreOutputs:\n"
        " outputs, model_executed = self.engine_core.step_fn()\n"
        " self.engine_core.post_step(model_executed=model_executed)\n"
        " return outputs and outputs.get(0) or EngineCoreOutputs()"
    )

    # Idempotent: do nothing when already patched, or when upstream has
    # diverged (old snippet absent) rather than risk corrupting the file.
    if new_snippet in content or old_snippet not in content:
        return

    content = content.replace(old_snippet, new_snippet)

    # Rewrite the file in place; logged only when a patch was applied.
    with open(file_to_patch, "w") as f:
        f.write(content)
    logger.info("Successfully patched vllm speculative decoding post_step.")

_patch_vllm_init_workers_ray()
logger.info("Successfully patched vllm _init_workers_ray.")

_patch_vllm_vit_flash_attn_backend()
logger.info("Successfully patched vllm vit flash attention backend.")

_patch_vllm_speculative_decoding_post_step()

try:
import vllm

Expand Down Expand Up @@ -415,7 +453,8 @@ def _patch_vllm_vit_flash_attn_backend():
trust_remote_code=True,
worker_extension_cls="nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension",
enable_sleep_mode=True,
disable_log_stats=True,
# Set disable_log_stats=False so that self.llm.get_metrics() works.
disable_log_stats=False,
logprobs_mode="processed_logprobs",
**vllm_kwargs,
)
Expand Down Expand Up @@ -485,6 +524,28 @@ def stop_gpu_profiling(self) -> None:
if self.llm is not None:
self.llm.collective_rpc("stop_gpu_profiling", args=tuple())

def _get_raw_spec_counters(self) -> dict[str, float | list[float]]:
    """Snapshot the vLLM engine's metrics as a flat name -> value mapping.

    Collects spec decode counters (number of drafts, draft tokens,
    accepted tokens) used to monitor acceptance rates.

    Returns:
        Mapping from metric name to a float (scalar metrics exposing
        ``.value``) or a list of floats (vector metrics exposing
        ``.values``, e.g. per-position counts). Empty dict when the vLLM
        engine has not been initialized yet.
    """
    snapshot: dict[str, float | list[float]] = {}
    if self.llm is None:
        return snapshot
    for metric in self.llm.get_metrics():
        # Vector metrics carry `.values`; scalar metrics carry `.value`.
        if hasattr(metric, "values"):
            snapshot[metric.name] = metric.values
        elif hasattr(metric, "value"):
            snapshot[metric.name] = metric.value
    return snapshot


@ray.remote(
runtime_env={**get_nsight_config_if_pattern_matches("vllm_generation_worker")}
Expand Down
Loading
Loading