diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index b0971aa198..b959e2828f 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1289,6 +1289,10 @@ def grpo_train( policy_generation.prepare_for_generation() dynamic_sampling_num_gen_batches += 1 + if dynamic_sampling_num_gen_batches == 1 and hasattr( + policy_generation, "snapshot_step_metrics" + ): + policy_generation.snapshot_step_metrics() with timer.time("generation"): # Clear logger metrics for each generation step if policy_generation is not None: @@ -1431,6 +1435,9 @@ def grpo_train( # If the current batch is not enough to fill the buffer during dynamic sampling, we update the cache and process the next batch. if not is_batch_complete: continue + gen_step_metrics = {} + if hasattr(policy_generation, "get_step_metrics"): + gen_step_metrics = policy_generation.get_step_metrics() advantages = (rewards - baseline).unsqueeze(-1) if master_config["grpo"]["normalize_rewards"]: @@ -1649,6 +1656,7 @@ def grpo_train( metrics["reward"] = repeated_batch["total_reward"].numpy() metrics.update(train_results["all_mb_metrics"]) + metrics.update(gen_step_metrics) for k, v in metrics.items(): if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: valid_values = [x for x in v if not np.isinf(x)] diff --git a/nemo_rl/models/generation/__init__.py b/nemo_rl/models/generation/__init__.py index c50598cb86..25bb0596df 100644 --- a/nemo_rl/models/generation/__init__.py +++ b/nemo_rl/models/generation/__init__.py @@ -42,6 +42,17 @@ def configure_generation_config( config = cast(VllmConfig, config) # set load_format config["vllm_cfg"]["load_format"] = "auto" if is_eval else "dummy" + is_spec = "speculative_config" in config.get("vllm_kwargs", {}) + if is_spec: + # When speculative decoding is enabled but the draft model is not co-trained + # with the policy (i.e., no weight sync for the draft model), we must use + # load_format='auto' to load actual weights. 
Using 'dummy' would leave the + # draft model with random weights that never get updated. + warnings.warn( + "Speculative decoding is enabled. Setting vllm_cfg['load_format'] to 'auto'. " + "This may result in slower startup times as full model weights are loaded." + ) + config["vllm_cfg"]["load_format"] = "auto" # Respect the skip_tokenizer_init setting from the config. VLMs for example, require this to be False. if "skip_tokenizer_init" not in config["vllm_cfg"]: diff --git a/nemo_rl/models/generation/vllm/utils.py b/nemo_rl/models/generation/vllm/utils.py index d4a8cd88ef..4be7d95117 100644 --- a/nemo_rl/models/generation/vllm/utils.py +++ b/nemo_rl/models/generation/vllm/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict from typing import Any, Optional from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -82,3 +83,103 @@ def _get_regular_prompt(index: int): prompts.append(_get_regular_prompt(i)) return prompts if return_all else prompts[0] + + +def aggregate_spec_decode_counters( + worker_metrics: list[dict[str, float | list[float]]], +) -> dict[str | tuple[str, int], float]: + """Aggregate speculative decoding counters from multiple workers. + + Combines spec decode metrics collected from DP leader workers into + a single aggregated counter dictionary. + + Args: + worker_metrics: List of metric dictionaries from each worker. + Each dict maps metric names to float values or lists of floats + (for per-position metrics). + + Returns: + Dictionary mapping metric names to their aggregated float values. + Per-position metrics use (name, position) tuples as keys. 
+ + Example: + >>> metrics_from_workers = ray.get(futures) # workers' _get_raw_spec_counters() results + >>> counters = aggregate_spec_decode_counters(metrics_from_workers) + >>> print(counters.get("vllm:spec_decode_num_drafts", 0)) + 1234.0 + """ + counters: dict[str | tuple[str, int], float] = defaultdict(float) + + for report in worker_metrics: + for metric_name, value in report.items(): + if "spec_decode" in metric_name: + if isinstance(value, list): + # Per-position metrics (e.g., acceptance counts at each draft position) + for position, pos_value in enumerate(value, 1): + counters[metric_name, position] += pos_value + else: + counters[metric_name] += value + + return dict(counters) + + +def compute_spec_decode_metrics( + start_counters: dict[str | tuple[str, int], float], + end_counters: dict[str | tuple[str, int], float], +) -> dict[str, float]: + """Compute delta and derived metrics for speculative decoding. + + Calculates the difference between two counter snapshots and derives + acceptance rate and acceptance length metrics for logging. + + Args: + start_counters: Counter snapshot taken before generation. + end_counters: Counter snapshot taken after generation. + + Returns: + Dictionary of metrics suitable for logging to wandb/tensorboard. + Keys are prefixed with "vllm/" for namespace consistency. 
+ Includes: + - vllm/spec_num_drafts: Total number of draft batches + - vllm/spec_num_draft_tokens: Total draft tokens generated + - vllm/spec_num_accepted_tokens: Total tokens accepted + - vllm/spec_acceptance_length: Average accepted tokens per draft + 1 + - vllm/spec_acceptance_rate: Ratio of accepted to draft tokens + - vllm/{metric}-{position}: Per-position acceptance counts + - vllm/spec_acceptance_rate-pos-{position}: Per-position acceptance rates + """ + keys = set(start_counters) | set(end_counters) + delta = {k: end_counters.get(k, 0.0) - start_counters.get(k, 0.0) for k in keys} + + num_drafts = delta.get("vllm:spec_decode_num_drafts", 0.0) + num_draft_tokens = delta.get("vllm:spec_decode_num_draft_tokens", 0.0) + num_accepted_tokens = delta.get("vllm:spec_decode_num_accepted_tokens", 0.0) + + # acceptance_length = 1 + (accepted / drafts) represents average tokens + # generated per draft batch (1 target model token + accepted draft tokens) + acceptance_length = ( + 1.0 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1.0 + ) + acceptance_rate = ( + num_accepted_tokens / num_draft_tokens if num_draft_tokens > 0 else 0.0 + ) + + spec_metrics: dict[str, float] = { + "vllm/spec_num_drafts": num_drafts, + "vllm/spec_num_draft_tokens": num_draft_tokens, + "vllm/spec_num_accepted_tokens": num_accepted_tokens, + "vllm/spec_acceptance_length": acceptance_length, + "vllm/spec_acceptance_rate": acceptance_rate, + } + + # Add per-position metrics for detailed analysis + for key, value in delta.items(): + if isinstance(key, tuple): + metric_name, position = key + spec_metrics[f"vllm/{metric_name}-{position}"] = value + if num_drafts > 0: + spec_metrics[f"vllm/spec_acceptance_rate-pos-{position}"] = ( + value / num_drafts + ) + + return spec_metrics diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 1366ce28c5..6138dfdb43 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ 
b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -14,6 +14,7 @@ import asyncio import os +import warnings from collections import defaultdict from typing import ( Any, @@ -36,6 +37,10 @@ GenerationOutputSpec, ) from nemo_rl.models.generation.vllm.config import VllmConfig +from nemo_rl.models.generation.vllm.utils import ( + aggregate_spec_decode_counters, + compute_spec_decode_metrics, +) # Global thresholds for top_k and top_p validation. # While top-k/p are not supported, these values allow for token filtering while the logprobs should be compatible. @@ -223,6 +228,8 @@ def __init__( # Save the device uuids for the workers self.device_uuids = self._report_device_id() + self._step_metrics_snapshot: dict[str | tuple[str, int], float] | None = None + def _get_tied_worker_bundle_indices( self, cluster: RayVirtualCluster ) -> list[tuple[int, list[int]]]: @@ -381,6 +388,61 @@ def _post_init(self): results = ray.get(futures) return results + def _get_raw_spec_counters(self) -> dict[str | tuple[str, int], float]: + """Collect raw spec decode counters from workers.""" + futures = self.worker_group.run_all_workers_single_data( + "_get_raw_spec_counters", + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + ) + worker_metrics = ray.get(futures) + + # Aggregate across workers + return aggregate_spec_decode_counters(worker_metrics) + + def snapshot_step_metrics(self) -> None: + """Snapshot current spec decode counters to begin tracking a training step. + + Call this before generation to establish a baseline for metrics delta. + + Raises: + RuntimeWarning: If called twice without get_step_metrics() in between. + """ + if self._step_metrics_snapshot is not None: + warnings.warn( + "snapshot_step_metrics() called again without get_step_metrics(). 
" + "Previous snapshot will be overwritten.", + RuntimeWarning, + ) + self._step_metrics_snapshot = self._get_raw_spec_counters() + + def get_step_metrics(self) -> dict[str, float]: + """Get speculative decoding metrics delta since snapshot_step_metrics(). + + Returns: + Dictionary of delta metrics with 'vllm/' prefix. + Returns empty dict if snapshot_step_metrics() was not called. + + Raises: + RuntimeWarning: If called without snapshot_step_metrics() first. + """ + if self._step_metrics_snapshot is None: + warnings.warn( + "get_step_metrics() called without snapshot_step_metrics(). " + "Call snapshot_step_metrics() before generation to track metrics.", + RuntimeWarning, + ) + return {} + + counters_end = self._get_raw_spec_counters() + step_metrics = compute_spec_decode_metrics( + self._step_metrics_snapshot, counters_end + ) + + # Reset snapshot for next step + self._step_metrics_snapshot = None + + return step_metrics + def init_collective( self, ip: str, port: int, world_size: int, *, train_world_size: int ) -> list[ray.ObjectRef]: diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 9238533cd2..7e30a33aed 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -281,12 +281,50 @@ def _patch_vllm_vit_flash_attn_backend(): with open(file_to_patch, "w") as f: f.write(content) + def _patch_vllm_speculative_decoding_post_step(): + """Patch vLLM speculative decoding post_step call. + + Related PR: + - https://github.com/vllm-project/vllm/pull/30319 + + This patch fixes the InprocessClient.get_output method to properly + call post_step with the model_executed flag from step_fn. 
+ """ + file_to_patch = _get_vllm_file("v1/engine/core_client.py") + + with open(file_to_patch, "r") as f: + content = f.read() + + old_snippet = ( + " def get_output(self) -> EngineCoreOutputs:\n" + " outputs, _ = self.engine_core.step_fn()\n" + " return outputs and outputs.get(0) or EngineCoreOutputs()" + ) + + new_snippet = ( + " def get_output(self) -> EngineCoreOutputs:\n" + " outputs, model_executed = self.engine_core.step_fn()\n" + " self.engine_core.post_step(model_executed=model_executed)\n" + " return outputs and outputs.get(0) or EngineCoreOutputs()" + ) + + if new_snippet in content or old_snippet not in content: + return + + content = content.replace(old_snippet, new_snippet) + + with open(file_to_patch, "w") as f: + f.write(content) + logger.info("Successfully patched vllm speculative decoding post_step.") + _patch_vllm_init_workers_ray() logger.info("Successfully patched vllm _init_workers_ray.") _patch_vllm_vit_flash_attn_backend() logger.info("Successfully patched vllm vit flash attention backend.") + _patch_vllm_speculative_decoding_post_step() + try: import vllm @@ -415,7 +453,8 @@ def _patch_vllm_vit_flash_attn_backend(): trust_remote_code=True, worker_extension_cls="nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension", enable_sleep_mode=True, - disable_log_stats=True, + # Set disable_log_stats=False so that self.llm.get_metrics() works. + disable_log_stats=False, logprobs_mode="processed_logprobs", **vllm_kwargs, ) @@ -485,6 +524,28 @@ def stop_gpu_profiling(self) -> None: if self.llm is not None: self.llm.collective_rpc("stop_gpu_profiling", args=tuple()) + def _get_raw_spec_counters(self) -> dict[str, float | list[float]]: + """Get speculative decoding metrics from the vLLM engine. + + Collects spec decode counters including number of drafts, + draft tokens, and accepted tokens for monitoring acceptance rates. + + Returns: + Dictionary mapping metric names to their values. 
Values may be floats or lists of floats (for per-position metrics). + + Note: + Returns an empty dict if the vLLM engine has not been initialized. + """ + metrics: dict[str, float | list[float]] = {} + if self.llm is not None: + for metric in self.llm.get_metrics(): + if hasattr(metric, "values"): + metrics[metric.name] = metric.values + elif hasattr(metric, "value"): + metrics[metric.name] = metric.value + return metrics + @ray.remote( runtime_env={**get_nsight_config_if_pattern_matches("vllm_generation_worker")} diff --git a/tests/unit/models/generation/test_vllm_utils.py b/tests/unit/models/generation/test_vllm_utils.py index 4093b4c5ae..a3d366f952 100644 --- a/tests/unit/models/generation/test_vllm_utils.py +++ b/tests/unit/models/generation/test_vllm_utils.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math + +import pytest import torch from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.generation.vllm.utils import ( + aggregate_spec_decode_counters, + compute_spec_decode_metrics, format_prompt_for_vllm_generation, ) @@ -111,3 +116,72 @@ def test_vllm_utils_vlm_with_none_content_fallback_to_tokens_and_sample_idx(): p1 = format_prompt_for_vllm_generation(data, sample_idx=1) assert isinstance(p0, dict) and isinstance(p1, dict) assert "prompt_token_ids" in p0 and "prompt_token_ids" in p1 + + +@pytest.mark.vllm +def test_vllm_speculative_decoding_patch_still_needed(): + # This test reminds to remove the vLLM patch when no longer needed. 
+ # The patch was fixed upstream: https://github.com/vllm-project/vllm/pull/30319 + # When this test fails, remove _patch_vllm_speculative_decoding_post_step() + # from nemo_rl/models/generation/vllm/vllm_worker.py + from importlib.metadata import version + + from packaging.version import Version + + assert Version(version("vllm")) < Version("0.14.0"), ( + "vLLM >= 0.14.0 includes the speculative decoding fix from " + "https://github.com/vllm-project/vllm/pull/30319. " + "Please remove the _patch_vllm_speculative_decoding_post_step() function " + "from nemo_rl/models/generation/vllm/vllm_worker.py" + ) + + +def test_aggregate_spec_decode_counters(): + """Test aggregation of speculative decoding counters from multiple workers.""" + worker_metrics = [ + { + "vllm:spec_decode_num_drafts": 100.0, + "vllm:spec_decode_num_draft_tokens": 300.0, + "vllm:spec_decode_num_accepted_tokens": 240.0, + "other_metric": 999.0, # Should be ignored + }, + { + "vllm:spec_decode_num_drafts": 150.0, + "vllm:spec_decode_num_draft_tokens": 450.0, + "vllm:spec_decode_num_accepted_tokens": 360.0, + }, + ] + + counters = aggregate_spec_decode_counters(worker_metrics) + + assert counters["vllm:spec_decode_num_drafts"] == 250.0 + assert counters["vllm:spec_decode_num_draft_tokens"] == 750.0 + assert counters["vllm:spec_decode_num_accepted_tokens"] == 600.0 + assert "other_metric" not in counters + + +def test_compute_spec_decode_metrics(): + """Test computation of speculative decoding metrics from counter snapshots.""" + start_counters = { + "vllm:spec_decode_num_drafts": 100.0, + "vllm:spec_decode_num_draft_tokens": 300.0, + "vllm:spec_decode_num_accepted_tokens": 200.0, + } + end_counters = { + "vllm:spec_decode_num_drafts": 200.0, + "vllm:spec_decode_num_draft_tokens": 600.0, + "vllm:spec_decode_num_accepted_tokens": 440.0, + } + + metrics = compute_spec_decode_metrics(start_counters, end_counters) + + # Delta values + assert metrics["vllm/spec_num_drafts"] == 100.0 + assert 
metrics["vllm/spec_num_draft_tokens"] == 300.0 + assert metrics["vllm/spec_num_accepted_tokens"] == 240.0 + + # Derived metrics + # acceptance_length = 1 + (accepted / drafts) = 1 + (240 / 100) = 3.4 + assert math.isclose(metrics["vllm/spec_acceptance_length"], 3.4, rel_tol=1e-6) + # acceptance_rate = accepted / draft_tokens = 240 / 300 = 0.8 + assert math.isclose(metrics["vllm/spec_acceptance_rate"], 0.8, rel_tol=1e-6)