
feat: KV cache quantization support in fp8 rollout in GRPO #1212

Merged
terrykong merged 44 commits into NVIDIA-NeMo:main from sharonyu-115:kv-cache-fp8
Dec 2, 2025

Conversation


@sharonyu-115 sharonyu-115 commented Sep 26, 2025

What does this PR do ?

Adds FP8 quantization of the KV cache to the FP8 rollout path in GRPO.

Hi @guyueh1 , as discussed, I'm creating this draft PR for you to review the initial design so that we can discuss the refactor and next steps etc.

Current experiment result of Qwen3-8B: (Orange line: bf16. Green: default FP8 rollout. Blue: default FP8 rollout + KV cache FP8.)
[training-curve image]
Note that there is room for calibration optimization to reduce the total step time.

@zpqiu fyi.

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this
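Since the snippet above is still a placeholder, here is a hedged configuration sketch. The key names (`vllm_cfg.kv_cache_dtype`, `precision`, the `megatron_cfg` constraints) are taken from the review discussion below; the exact YAML layout is an assumption, not the PR's final interface.

```yaml
# Hedged sketch: enable FP8 KV cache on top of FP8 rollout in GRPO.
# Key names follow the review discussion; the exact schema may differ.
policy:
  megatron_cfg:
    enabled: true                    # DTensor backend is not supported with FP8 KV cache
    pipeline_model_parallel_size: 1  # PP must be 1 with FP8 KV cache enabled
  generation:
    vllm_cfg:
      precision: fp8                 # FP8 rollout
      kv_cache_dtype: fp8            # new: quantize the KV cache to FP8 as well
```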

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features
    • Added FP8 KV-cache support with configurable kv_cache dtype and optional KV-scale calibration/synchronization across generation and training.
    • Enabled KV-scale propagation during model refits and weight updates with vLLM.
  • Bug Fixes
    • Improved checkpoint saving stability by reducing GPU memory pressure and clearing caches to avoid OOM errors.
  • Documentation
    • Expanded docs for FP8 KV-cache settings and KV-scale usage.
  • Chores
    • Added runtime diagnostics and compatibility checks for FP8 KV-cache across backends.

@sharonyu-115 sharonyu-115 requested review from a team as code owners September 26, 2025 12:04

coderabbitai bot commented Sep 26, 2025

📝 Walkthrough

Adds FP8 KV-cache scale handling across GRPO refit and vLLM update paths: computes/calibrates Q/K/V FP8 scales, caches and threads them through IPC/NCCL weight updates, and applies them post-load in vLLM. Introduces config for kv_cache_dtype, backend compatibility checks, and new calibration methods in policy interfaces and workers.

Changes

  • GRPO algorithm refit and sync — nemo_rl/algorithms/grpo.py: Adds FP8 KV-scale threading through refit_policy_generation (new kv_scales arg). Computes/reuses the kv_scales cache during generation/refit, propagates it to the IPC/NCCL update paths, and marks generation stale when recalibration occurs. Adds _should_sync_kv_scales and FP8 kv_cache compatibility checks.
  • vLLM update pipeline (IPC/collective) — nemo_rl/models/generation/vllm/vllm_backend.py, nemo_rl/models/generation/vllm/vllm_generation.py, nemo_rl/models/generation/vllm/vllm_worker.py: Extends update_weights_from_local_ipc_handles/update_weights_from_collective to accept optional kv_scales, append scale tensors to the payload, and invoke process_weights_after_loading when provided. Threads kv_scales through generation and worker RPCs; adds diagnostics.
  • FP8 config and weight processing — nemo_rl/models/generation/fp8.py: Adds FP8Config.kv_cache_dtype. Introduces kv_cache_process_weights_after_loading to retain and remap Q/K/V/prob scales on load. Wires kv_cache_dtype into vLLM init kwargs and weight loading; adds debug output.
  • Policy interfaces and implementations — nemo_rl/models/policy/interfaces.py, nemo_rl/models/policy/lm_policy.py, nemo_rl/models/policy/megatron_policy_worker.py: Adds calibrate_qkv_fp8_scales to PolicyInterface and Policy (sharded dispatch/gather). Implements calibration in MegatronPolicyWorker using forward hooks, percentile-based amax, margin, optional JSON save, and distributed merging. Also augments checkpoint save with memory handling.
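The calibration summary above (forward hooks, percentile-based amax, margin) can be illustrated with a minimal sketch of the scale computation; the constant and function names here are illustrative, not the PR's actual code:

```python
FP8_E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3 (illustrative constant)

def amax_percentile(amaxes: list[float], p: float) -> float:
    """Nearest-rank percentile over per-step activation amax observations."""
    if not amaxes:
        return 0.0
    s = sorted(amaxes)
    rank = max(0, min(len(s) - 1, round((p / 100.0) * (len(s) - 1))))
    return s[rank]

def kv_scale(amaxes: list[float], p: float = 99.0, margin: float = 1.0) -> float:
    """Scale so that (percentile amax * margin) maps to the FP8 max; clamped above zero."""
    return max(1e-8, amax_percentile(amaxes, p) * margin / FP8_E4M3_MAX)
```

The percentile (rather than the raw max) discards outlier activations, and the margin leaves headroom before values saturate in FP8.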

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Trainer as GRPO Trainer
  participant Policy as Policy
  participant Worker as MegatronPolicyWorker
  participant Gen as vLLMGeneration
  participant Wkr as vLLMWorker
  participant BE as vLLMBackend
  participant Model as Model

  rect rgba(230,240,255,0.5)
  note over Trainer,Worker: Optional FP8 Q/K/V scale calibration
  Trainer->>Policy: calibrate_qkv_fp8_scales(data, opts)
  Policy->>Worker: calibrate_qkv_fp8_scales(shard, opts)
  Worker-->>Policy: kv_scales (per-layer)
  Policy-->>Trainer: kv_scales
  end

  rect rgba(240,255,240,0.5)
  note over Trainer,Model: Refit with KV scales
  Trainer->>Gen: update_weights_from_ipc_handles/collective(kv_scales)
  alt IPC
    Gen->>Wkr: update_weights_from_ipc_handles(..., kv_scales)
    Wkr->>BE: update_weights_from_local_ipc_handles(..., kv_scales)
  else Collective
    Gen->>Wkr: update_weights_from_collective(kv_scales)
    Wkr->>BE: update_weights_from_collective(kv_scales)
  end
  BE->>Model: load weights (+kv scale tensors)
  BE->>Model: process_weights_after_loading(apply kv scales)
  Model-->>BE: ready
  BE-->>Wkr: ok
  Wkr-->>Gen: ok
  Gen-->>Trainer: ok
  end

  note over Trainer: Continue rollout/generation

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

CI:L1, r0.4.0

Suggested reviewers

  • yuki-97
  • parthchadha
  • joyang-nv
  • jgerh

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 68.75%, below the required 80.00% threshold. Resolution: run @coderabbitai generate docstrings to improve coverage.
  • Test Results For Major Changes ⚠️ Warning — The PR adds FP8 KV-cache quantization, a major feature that directly impacts numerics and performance, but the description only shows an illustrative image and leaves the testing checklist unchecked, with no concrete test results, regression checks, or performance metrics. Resolution: update the PR description with explicit test results or benchmarking details, including configurations and outcomes that demonstrate numerical stability and performance impact for the new FP8 KV-cache quantization flow.
✅ Passed checks (2 passed)
  • Title check ✅ Passed — The title accurately summarizes the main change: adding KV cache quantization support to the FP8 rollout in GRPO, which aligns with the core modifications across fp8.py, vllm_backend.py, vllm_generation.py, vllm_worker.py, and grpo.py.
  • Description check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
nemo_rl/models/policy/megatron_policy_worker.py (1)

1851-1937: Optimizer state offloaded to CPU after saving but never restored — can break continued training

The block moves optimizer state tensors to CPU unconditionally after save. If training continues in the same process, subsequent steps will pay a large PCIe penalty or fail if code assumes CUDA tensors.

Apply a guard to only offload when not training (or behind an env flag), so regular checkpoints during training don’t degrade runtime. Example diff:

-            torch.randn(1).cuda()  # wake up torch allocator
-            if hasattr(self, "optimizer") and self.optimizer is not None:
+            torch.randn(1).cuda()  # wake up torch allocator
+            # Only offload optimizer state after save when not training, or if explicitly requested.
+            should_offload_opt = (os.getenv("NRL_OFFLOAD_OPT_STATE_AFTER_SAVE", "0") == "1") or (not is_training)
+            if should_offload_opt and hasattr(self, "optimizer") and self.optimizer is not None:
                 # Iterate through the state dictionaries for each parameter group
                 if isinstance(self.optimizer, ChainedOptimizer):
                     optimizer_state = self.optimizer.state
                 else:
                     optimizer_state = self.optimizer._get_state()
                 for _, state in optimizer_state.items():
                     # Iterate through the state items (e.g., momentum, variance) for a parameter
                     for k, v in state.items():
                         # Check if the item is a tensor and on the GPU
                         if torch.is_tensor(v) and v.is_cuda:
                             # Move the tensor to CPU and update the state dictionary
                             state[k] = v.to("cpu")

Optionally, if you do offload during training (env flag), add a follow-up reload-to-CUDA path after memory cleanup.

Consider replacing print() memory logs with logging or warnings to align with codebase practices. As per coding guidelines

nemo_rl/models/generation/vllm/vllm_generation.py (1)

798-825: Add kv_scales parameter to async collective update methods
update_weights_from_collective_async in vllm_worker_async.py and the update_weights_from_collective signature in interfaces.py both lack the kv_scales: Optional[dict[str, float]] = None parameter required by the orchestrator’s call. Add this optional argument to both to match the sync version.

nemo_rl/models/generation/vllm/vllm_worker.py (2)

768-775: Check all worker results, not just the first

Only inspecting result_or_coro[0] can silently ignore failures on other workers. Aggregate across all results.

Apply this diff:

-            worker_result = result_or_coro[0]
-
-            if not worker_result:
+            results = result_or_coro if isinstance(result_or_coro, (list, tuple)) else [result_or_coro]
+            if not all(bool(r) for r in results):
                 print(
-                    f"Error: Worker failed to update weights. Result: {worker_result}"
+                    f"Error: Worker failed to update weights. Results: {results}"
                 )
                 return False

796-811: Collective update should verify all worker results

Same issue: only the first result is checked. Ensure all workers succeeded.

Apply this diff:

-            worker_result = result_or_coro[0]
-
-            if not worker_result:
+            results = result_or_coro if isinstance(result_or_coro, (list, tuple)) else [result_or_coro]
+            if not all(bool(r) for r in results):
                 print(
-                    f"Error: Worker failed to update weights. Result: {worker_result}"
+                    f"Error: Worker failed to update weights. Results: {results}"
                 )
                 return False
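The aggregation pattern suggested in both diffs can be factored into a small helper; this is an illustrative sketch, not code from the PR:

```python
def all_workers_succeeded(result_or_coro) -> bool:
    """True only if every per-worker result is truthy.

    Accepts either a single result or a list/tuple of per-worker results, so
    failures on workers other than the first are no longer silently ignored.
    """
    results = result_or_coro if isinstance(result_or_coro, (list, tuple)) else [result_or_coro]
    return all(bool(r) for r in results)
```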
🧹 Nitpick comments (12)
nemo_rl/models/policy/lm_policy.py (1)

579-643: DP sharding + worker dispatch for calibration is correct; add a small guard before indexing

Flow mirrors get_logprobs and properly replicates CP/TP/PP. Add a quick safety check before returning results[0].

         results = self.worker_group.get_all_worker_results(futures)
-        return results[0]
+        assert len(results) > 0, "No calibration results returned from workers"
+        return results[0]

Confirm the DTensor policy worker won’t be invoked for this method (or raises NotImplementedError) since the current implementation targets Megatron only. Based on learnings

nemo_rl/models/policy/megatron_policy_worker.py (1)

1970-2173: Calibrate Q/K/V FP8 scales: good structure; tighten error handling and avoid zero scales

Overall approach (hooks + percentile + DP/TP/PP reduction) is sound. Improve robustness:

  • Don’t swallow generic Exception; log traceback or narrow to expected exceptions.
  • Avoid producing zero scales; clamp with epsilon.
  • _percentile() needn’t allocate on CUDA; keep on CPU to reduce GPU noise.
-        def _percentile(values: list[float], p: float) -> float:
+        def _percentile(values: list[float], p: float) -> float:
             if not values:
                 return 0.0
-            t = torch.tensor(sorted(values), device="cuda", dtype=torch.float32)
-            rank = max(0, min(len(values) - 1, int(round((p / 100.0) * (len(values) - 1)))))
-            return float(t[rank].item())
+            t = torch.tensor(sorted(values), dtype=torch.float32)  # CPU is fine
+            # Index by nearest rank; avoid int() on an int
+            rank = max(0, min(len(values) - 1, round((p / 100.0) * (len(values) - 1))))
+            return float(t[rank].item())
-            if include_q:
-                q_scale = (vals.get("q_amax_p", 0.0) * margin) / FP8_MAX_Q
-                out_entry["q_scale"] = float(q_scale)
-            k_scale = (vals.get("k_amax_p", 0.0) * margin) / FP8_MAX_K
-            v_scale = (vals.get("v_amax_p", 0.0) * margin) / FP8_MAX_V
+            eps = 1e-8
+            if include_q:
+                q_scale = max(eps, (vals.get("q_amax_p", 0.0) * margin) / FP8_MAX_Q)
+                out_entry["q_scale"] = float(q_scale)
+            k_scale = max(eps, (vals.get("k_amax_p", 0.0) * margin) / FP8_MAX_K)
+            v_scale = max(eps, (vals.get("v_amax_p", 0.0) * margin) / FP8_MAX_V)

Also rename the unused pre-hook module parameter to silence linters:

-            def _pre_hook(module, inputs):
+            def _pre_hook(_module, inputs):

The vLLM-side loaders expect kv_scales as a flat dict[str, float] of parameter names to scales (see nemo_rl/models/generation/vllm/vllm_backend.py). This method returns a nested structure {format, percentile, margin, layers: {...}} with layer keys like "layer_N". Please confirm the translation from layer keys to vLLM parameter names (e.g., model.layers.N.<...>._k_scale/_v_scale) is performed upstream (e.g., in GRPO refit flow) before passing to vLLM. If missing, I can propose a mapper.
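If such a mapper is indeed missing, a sketch could look like the following. The nested input shape matches the structure described above, but the vLLM parameter path (model.layers.N.self_attn.attn._k_scale) is an assumption that may vary per model architecture:

```python
def flatten_kv_scales(calib: dict, prefix: str = "model.layers") -> dict[str, float]:
    """Flatten {"layers": {"layer_N": {"k_scale": ..., ...}}} into flat vLLM param names.

    Hypothetical helper: the parameter-name template below is an assumption.
    """
    flat: dict[str, float] = {}
    for layer_key, scales in calib["layers"].items():
        idx = int(layer_key.rsplit("_", 1)[-1])  # "layer_12" -> 12
        for name, value in scales.items():
            if name.endswith("_scale"):
                # "k_scale" -> "..._k_scale" to match vLLM's per-attention-layer params
                flat[f"{prefix}.{idx}.self_attn.attn._{name}"] = float(value)
    return flat
```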

nemo_rl/models/generation/fp8.py (2)

278-279: Remove unused local variable kv_cache_dtype

It’s set then unused; drop to silence lint.

-    kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto")

303-309: Gate noisy debug print or switch to logger

The global config print will spam logs. Gate under an env flag or use logger at debug level.

-    # TODO: Remove this after debugging.
-    print(f"[KV_SCALES] Global FP8 config: {global_fp8_config}")
+    if os.getenv("NRL_FP8_DEBUG", "0") == "1":
+        print(f"[KV_SCALES] Global FP8 config: {global_fp8_config}")
nemo_rl/models/generation/vllm/vllm_backend.py (3)

183-185: Remove extraneous f-string prefix

No placeholders; triggers Ruff F541.

Apply this diff:

-                print(f"[KV_SCALES] Loading weights into FP8 model")
+                print("[KV_SCALES] Loading weights into FP8 model")

The Ruff hint (F541) flagged this; please run lint to confirm no other instances remain.


169-177: Reduce noisy debug prints or gate behind a flag

Repeated printing (including .item() on GPU tensors) induces syncs and slows the hot path. Gate behind an env var or remove before merge.

Apply this diff to gate under an env toggle:

-            # Debug print to check if the KV scales are in the weights
-            # TODO: Remove this after debugging.
-            kv_scale_params = [name for name, _ in weights if any(scale_name in name for scale_name in ['q_scale', 'k_scale', 'v_scale'])]
-            print(f"[KV_SCALES] KV scale parameters found in weights: {kv_scale_params}")
-            if kv_scale_params:
-                for name, tensor in weights:
-                    if any(scale_name in name for scale_name in ['q_scale', 'k_scale', 'v_scale']):
-                        print(f"[KV_SCALES] Parameter {name}: shape={tensor.shape}, dtype={tensor.dtype}, value={tensor.item() if tensor.numel() == 1 else 'multi-element'}")
-            else:
-                print("[KV_SCALES] No KV scale parameters found in weights")
+            if os.environ.get("NRL_DEBUG_KV_SCALES") == "1":
+                kv_scale_params = [
+                    name
+                    for name, _ in weights
+                    if any(scale_name in name for scale_name in ["q_scale", "k_scale", "v_scale"])
+                ]
+                print(f"[KV_SCALES] KV scale parameters found in weights: {kv_scale_params}")

231-237: Align device placement for KV scales in collective path

For consistency with other weights allocated on device="cuda", consider using the worker’s device as well.

Apply this diff:

-                for param_name, scale_value in kv_scales.items():
-                    # Convert scale to tensor
-                    scale_tensor = torch.tensor(scale_value, dtype=torch.float32, device="cuda")
+                for param_name, scale_value in kv_scales.items():
+                    # Convert scale to tensor
+                    scale_tensor = torch.tensor(
+                        scale_value, dtype=torch.float32, device=self.device
+                    )
                     weights.append((param_name, scale_tensor))
nemo_rl/algorithms/grpo.py (5)

394-405: Avoid using assert for runtime checks

Asserts can be stripped with -O, hiding critical compatibility guards. Raise explicit exceptions instead.

Apply this diff:

-            if kv_cache_dtype == "fp8":
-                policy_backend = "megatron" if policy_config.get("megatron_cfg", {}).get("enabled", False) else "dtensor"
-                
-                # Validate KV cache FP8 compatibility
-                assert policy_backend != "dtensor", "DTensor backend is not supported with kv cache fp8 enabled."
-                assert not _should_use_async_rollouts(master_config), "Async rollouts is not supported with kv cache fp8 enabled."
-                assert policy_config.get("megatron_cfg", {}).get("pipeline_model_parallel_size", 1) == 1, "Pipeline model parallel size must be 1 for megatron backend with kv cache fp8 enabled."
+            if kv_cache_dtype == "fp8":
+                policy_backend = "megatron" if policy_config.get("megatron_cfg", {}).get("enabled", False) else "dtensor"
+                if policy_backend == "dtensor":
+                    raise RuntimeError("DTensor backend is not supported with kv cache fp8 enabled.")
+                if _should_use_async_rollouts(master_config):
+                    raise RuntimeError("Async rollouts is not supported with kv cache fp8 enabled.")
+                if policy_config.get("megatron_cfg", {}).get("pipeline_model_parallel_size", 1) != 1:
+                    raise RuntimeError("Pipeline model parallel size must be 1 for megatron backend with kv cache fp8 enabled.")

496-504: Do not introduce implicit config defaults in code (nemo_rl/ rule)

Accessing vllm_cfg keys with .get(..., "auto") sets hidden defaults. Per guidelines, YAML is the single source of truth; assume presence or model optionality via TypedDict.

Apply this diff:

-    vllm_cfg = generation_config.get("vllm_cfg", {})
-    kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto")
-    vllm_precision = vllm_cfg.get("precision", "auto")
+    vllm_cfg = generation_config["vllm_cfg"]
+    kv_cache_dtype = vllm_cfg["kv_cache_dtype"]
+    vllm_precision = vllm_cfg["precision"]

As per coding guidelines


647-649: Remove extraneous f-string prefix

This print has no placeholders; triggers Ruff F541.

Apply this diff:

-        print(f"[KV_SCALES] FP8 KV cache detected, will sync q_scale, _k_scale and _v_scale during refit")
+        print("[KV_SCALES] FP8 KV cache detected, will sync q_scale, _k_scale and _v_scale during refit")

Rerun lint to ensure no other F541 remain.


909-924: Ensure kv_scales_cache is initialized before updates

If KV scales are recomputed before the initial cache creation (edge flows), assignments to kv_scales_cache[...] would fail. Initialize defensively.

Apply this diff:

-                if sync_kv_scales:
+                if sync_kv_scales:
+                    if kv_scales_cache is None:
+                        kv_scales_cache = {}
                     with timer.time("recompute_kv_scales"):

642-650: KV-scale sync logging is fine; consider gating debug verbosity

Given frequency, consider using the existing Logger or an env flag to avoid noisy stdout in production runs.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f521459 and 2521fda.

📒 Files selected for processing (8)
  • nemo_rl/algorithms/grpo.py (9 hunks)
  • nemo_rl/models/generation/fp8.py (7 hunks)
  • nemo_rl/models/generation/vllm/vllm_backend.py (3 hunks)
  • nemo_rl/models/generation/vllm/vllm_generation.py (3 hunks)
  • nemo_rl/models/generation/vllm/vllm_worker.py (5 hunks)
  • nemo_rl/models/policy/interfaces.py (1 hunks)
  • nemo_rl/models/policy/lm_policy.py (1 hunks)
  • nemo_rl/models/policy/megatron_policy_worker.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • nemo_rl/models/policy/lm_policy.py
  • nemo_rl/models/generation/vllm/vllm_worker.py
  • nemo_rl/models/generation/vllm/vllm_generation.py
  • nemo_rl/algorithms/grpo.py
  • nemo_rl/models/policy/interfaces.py
  • nemo_rl/models/policy/megatron_policy_worker.py
  • nemo_rl/models/generation/vllm/vllm_backend.py
  • nemo_rl/models/generation/fp8.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/models/policy/lm_policy.py
  • nemo_rl/models/generation/vllm/vllm_worker.py
  • nemo_rl/models/generation/vllm/vllm_generation.py
  • nemo_rl/algorithms/grpo.py
  • nemo_rl/models/policy/interfaces.py
  • nemo_rl/models/policy/megatron_policy_worker.py
  • nemo_rl/models/generation/vllm/vllm_backend.py
  • nemo_rl/models/generation/fp8.py
🧬 Code graph analysis (7)
nemo_rl/models/policy/lm_policy.py (6)
nemo_rl/models/policy/megatron_policy_worker.py (1)
  • calibrate_qkv_fp8_scales (1971-2173)
nemo_rl/models/policy/interfaces.py (1)
  • calibrate_qkv_fp8_scales (104-126)
nemo_rl/distributed/batched_data_dict.py (2)
  • BatchedDataDict (75-839)
  • shard_by_batch_size (246-644)
nemo_rl/models/generation/interfaces.py (1)
  • GenerationDatumSpec (127-158)
nemo_rl/distributed/worker_groups.py (2)
  • run_all_workers_sharded_data (774-909)
  • get_all_worker_results (911-928)
nemo_rl/distributed/named_sharding.py (1)
  • get_axis_size (209-211)
nemo_rl/models/generation/vllm/vllm_worker.py (2)
nemo_rl/models/generation/vllm/vllm_generation.py (2)
  • update_weights_from_ipc_handles (745-796)
  • update_weights_from_collective (798-824)
nemo_rl/models/generation/vllm/vllm_backend.py (1)
  • update_weights_from_collective (216-269)
nemo_rl/models/generation/vllm/vllm_generation.py (3)
nemo_rl/models/generation/vllm/vllm_worker.py (2)
  • update_weights_from_ipc_handles (702-781)
  • update_weights_from_collective (784-817)
nemo_rl/distributed/worker_groups.py (2)
  • run_all_workers_multiple_data (631-726)
  • run_all_workers_single_data (728-772)
nemo_rl/models/generation/vllm/vllm_backend.py (1)
  • update_weights_from_collective (216-269)
nemo_rl/algorithms/grpo.py (9)
nemo_rl/algorithms/sft.py (1)
  • MasterConfig (75-81)
nemo_rl/models/policy/megatron_policy_worker.py (4)
  • get_weights_ipc_handles (1594-1684)
  • broadcast_weights_for_collective (1687-1696)
  • prepare_for_lp_inference (1698-1701)
  • calibrate_qkv_fp8_scales (1971-2173)
nemo_rl/models/policy/lm_policy.py (4)
  • get_weights_ipc_handles (687-706)
  • broadcast_weights_for_collective (708-714)
  • prepare_for_lp_inference (554-558)
  • calibrate_qkv_fp8_scales (579-642)
nemo_rl/models/policy/dtensor_policy_worker_v2.py (3)
  • get_weights_ipc_handles (1409-1445)
  • broadcast_weights_for_collective (1448-1469)
  • prepare_for_lp_inference (1472-1479)
nemo_rl/models/generation/vllm/vllm_generation.py (2)
  • update_weights_from_ipc_handles (745-796)
  • update_weights_from_collective (798-824)
nemo_rl/models/generation/vllm/vllm_worker.py (2)
  • update_weights_from_ipc_handles (702-781)
  • update_weights_from_collective (784-817)
nemo_rl/models/generation/vllm/vllm_backend.py (1)
  • update_weights_from_collective (216-269)
nemo_rl/distributed/batched_data_dict.py (3)
  • BatchedDataDict (75-839)
  • get_multimodal_dict (88-99)
  • to (804-811)
nemo_rl/algorithms/loss_functions.py (1)
  • ClippedPGLossDataDict (47-57)
nemo_rl/models/policy/interfaces.py (4)
nemo_rl/models/policy/megatron_policy_worker.py (1)
  • calibrate_qkv_fp8_scales (1971-2173)
nemo_rl/models/policy/lm_policy.py (1)
  • calibrate_qkv_fp8_scales (579-642)
nemo_rl/distributed/batched_data_dict.py (1)
  • BatchedDataDict (75-839)
nemo_rl/models/generation/interfaces.py (1)
  • GenerationDatumSpec (127-158)
nemo_rl/models/policy/megatron_policy_worker.py (3)
nemo_rl/distributed/batched_data_dict.py (2)
  • to (804-811)
  • BatchedDataDict (75-839)
nemo_rl/models/policy/lm_policy.py (1)
  • calibrate_qkv_fp8_scales (579-642)
nemo_rl/models/policy/interfaces.py (1)
  • calibrate_qkv_fp8_scales (104-126)
nemo_rl/models/generation/vllm/vllm_backend.py (3)
nemo_rl/models/generation/fp8.py (3)
  • is_fp8_model (314-325)
  • load_weights (401-428)
  • process_weights_after_loading (504-550)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
  • update_weights_from_collective (798-824)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
  • update_weights_from_collective (784-817)
🪛 Ruff (0.13.1)
nemo_rl/models/generation/vllm/vllm_generation.py

794-794: Do not catch blind exception: Exception

(BLE001)

nemo_rl/algorithms/grpo.py

489-489: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


647-647: f-string without any placeholders

Remove extraneous f prefix

(F541)

nemo_rl/models/policy/megatron_policy_worker.py

2007-2007: Do not catch blind exception: Exception

(BLE001)


2033-2033: Unused function argument: module

(ARG001)


2046-2046: Do not catch blind exception: Exception

(BLE001)


2060-2060: Do not catch blind exception: Exception

(BLE001)


2080-2081: try-except-pass detected, consider logging the exception

(S110)


2080-2080: Do not catch blind exception: Exception

(BLE001)


2095-2095: Do not catch blind exception: Exception

(BLE001)


2104-2104: Value being cast to int is already an integer

Remove unnecessary int call

(RUF046)


2170-2171: try-except-pass detected, consider logging the exception

(S110)


2170-2170: Do not catch blind exception: Exception

(BLE001)

nemo_rl/models/generation/vllm/vllm_backend.py

184-184: f-string without any placeholders

Remove extraneous f prefix

(F541)

nemo_rl/models/generation/fp8.py

108-108: Unused function argument: self

(ARG001)


150-151: Avoid specifying long messages outside the exception class

(TRY003)


185-186: Do not assign a lambda expression, use a def

Rewrite is_singleton_float as a def

(E731)


185-186: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


189-190: Avoid specifying long messages outside the exception class

(TRY003)


205-205: f-string without any placeholders

Remove extraneous f prefix

(F541)


278-278: Local variable kv_cache_dtype is assigned to but never used

Remove assignment to unused variable kv_cache_dtype

(F841)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Lint check
  • GitHub Check: Post automodel integration comment / Comment on PR
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (9)
nemo_rl/models/policy/interfaces.py (1)

103-127: New abstract KV-scale calibration API looks good

Signature, typing, and docstring are consistent with the rest of the interface.

Please confirm all concrete Policy implementations (Megatron, DTensor, etc.) implement this method or are gated to avoid calls when unsupported. If DTensor is not yet supported, consider raising NotImplementedError in that backend to fail fast.

nemo_rl/models/policy/megatron_policy_worker.py (1)

19-21: Minor import additions

json/re are fine and used later.

nemo_rl/models/generation/vllm/vllm_generation.py (2)

745-756: IPC update path now accepts kv_scales — interface and propagation look correct

The optional kv_scales parameter is threaded to workers without breaking the default path.

Confirm update_weights_from_ipc_handles_async in the async worker also accepts kv_scales to avoid mismatched signatures across engine modes.
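The backward-compatible threading described here amounts to an optional keyword argument defaulted to `None`, so the default path is untouched. A sketch with hypothetical helper names (the real worker API may differ):

```python
from typing import Any, Dict, Optional

applied = []  # records which steps ran; stand-in for real side effects

def load_weights(handles: Dict[str, Any]) -> None:
    applied.append(("weights", len(handles)))

def apply_kv_scales(scales: Dict[str, float]) -> None:
    applied.append(("kv_scales", len(scales)))

def update_weights_from_ipc_handles(
    ipc_handles: Dict[str, Any],
    kv_scales: Optional[Dict[str, float]] = None,
) -> bool:
    """Sketch: default path (kv_scales=None) is unchanged; scales are an extra step."""
    load_weights(ipc_handles)
    if kv_scales is not None:
        apply_kv_scales(kv_scales)
    return True
```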


776-791: Correct replication of kv_scales across tied groups

Replicating kv_scales per worker’s device UUID set is appropriate.

nemo_rl/models/generation/fp8.py (3)

43-44: Config extension for kv_cache_dtype is fine

Field addition and propagation via init_fp8 are consistent.


233-237: Patching BaseKVCacheMethod hook is appropriate

Keeps KV-scale params for dynamic updates; aligns with the goal of feeding calibrated scales later.


251-261: kv_cache_dtype propagated into global FP8Config

Good to thread through vLLM kwargs downstream.

nemo_rl/models/generation/vllm/vllm_backend.py (1)

89-96: API extension looks good

Accepting optional kv_scales here aligns with upstream worker calls and the generation path.

nemo_rl/algorithms/grpo.py (1)

575-584: Passing kv_scales only with the first IPC batch — confirm loader semantics

vLLM’s loader must be able to receive “extra” kv_scale tensors separate from the rest of the weights, and only once. If subsequent weight groups re-touch affected layers, scales must remain intact. Please confirm this is guaranteed for the targeted models and vLLM version.

Would you like me to add a guard to call process_weights_after_loading only once per refit and to assert presence of KV-scale params post-load?
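One possible shape for that guard, invoking post-load processing at most once per refit and asserting that KV-scale parameters survive later weight-group updates. All names here (`process_weights_after_loading`, `kv_scale_names`) are hypothetical stand-ins for the real vLLM hooks:

```python
class RefitGuard:
    """Run post-load processing at most once per refit cycle (sketch)."""

    def __init__(self):
        self._processed = False

    def maybe_process(self, model, expected_kv_scale_names):
        if self._processed:
            return
        # Hypothetical hook mirroring vLLM's post-load weight processing.
        model.process_weights_after_loading()
        # Assert the calibrated KV scales are still present after loading.
        missing = set(expected_kv_scale_names) - set(model.kv_scale_names())
        assert not missing, f"KV-scale params lost after load: {missing}"
        self._processed = True

    def reset(self):
        """Call at the start of each refit."""
        self._processed = False
```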

@guyueh1 guyueh1 self-requested a review September 26, 2025 15:32
@zpqiu zpqiu linked an issue Sep 26, 2025 that may be closed by this pull request
@zpqiu zpqiu removed request for a team September 26, 2025 15:34
@guyueh1
Contributor

guyueh1 commented Nov 10, 2025

@zpqiu sorry for the long delay; I have left some comments. Could you first merge in main and then address them? I think this solution can be further optimized in terms of performance, but I'm OK with merging this version first.

@zpqiu
Contributor

zpqiu commented Nov 11, 2025

@zpqiu sorry for the long delay; I have left some comments. Could you first merge in main and then address them? I think this solution can be further optimized in terms of performance, but I'm OK with merging this version first.

Sure. I will rebase the code and resolve these comments first.

@zpqiu zpqiu changed the title from "draft feat: KV cache quantization support in fp8 rollout in GRPO" to "feat: KV cache quantization support in fp8 rollout in GRPO" Nov 11, 2025
@guyueh1 guyueh1 added the CI:L0 Run doctests and unit tests label Nov 13, 2025
@zpqiu zpqiu marked this pull request as draft November 14, 2025 13:41
Shuang Yu added 4 commits November 17, 2025 06:03
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
… max.

Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
@zpqiu zpqiu added the CI:L2 Run doctests, unit tests, functional tests, and convergence tests label Nov 27, 2025
Signed-off-by: alexchiu <alexq@nvidia.com>
@github-actions

ℹ️ File Consistency Check

Check based on commit: 47ea0c0 (PR #1212 from kv-cache-fp8)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
guyueh1
guyueh1 previously approved these changes Nov 27, 2025
@guyueh1
Contributor

guyueh1 commented Nov 27, 2025

@zpqiu can you fix the functional test failure?
Also, I think the L1 functional tests are run on Ampere GPUs; you may need to conditionally skip for CUDA archs before sm_90.

If so, do we need to delete this test?

I'm OK with adding the test just to the nightly test suite, as you do now.

@guyueh1
Contributor

guyueh1 commented Nov 27, 2025

@terrykong please review

@guyueh1
Contributor

guyueh1 commented Nov 28, 2025

@terrykong this is the last FP8 functionality we want to merge before v0.5, after this I want to perform a refactor of code to make it cleaner and more structured. Please take a review when you have time.

Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
Signed-off-by: alexchiu <qiuzhaopeng@foxmail.com>

@sharonyu-115
Contributor Author

Here are the latest experimental results.

Configuration
Model: Qwen3-8B-Base.
Method: Dynamically calculate qkv scales at the end of each training step and synchronize them to vLLM.
Framework: NeMo-RL, vLLM + MCore, batch rollout mode.
Correction: Token-level TIS, C=2

(result plots omitted)

Observations:

  • Mismatch: Enabling FP8 for KV cache and attention increases mismatch compared to using FP8 only for Linear layers.

  • Accuracy: Applying token-level TIS realigns the accuracy curve with BF16.

  • KV Cache Capacity: FP8 KV cache provides an additional 2x token capacity and concurrency.
    BF16: GPU KV cache size: 249,952 tokens, Maximum concurrency: 11.10x
    FP8 Linear-only: GPU KV cache size: 299,344 tokens, Maximum concurrency: 13.29x
    FP8 Linear + KV Cache: GPU KV cache size: 598,672 tokens, Maximum concurrency: 26.57x

  • Speedup:
    Adding FP8 KV cache/Attention yields an additional ~30% rollout speedup over FP8 Linear only.
    Total speedup compared to BF16 is approximately 48%.

  • Observation: Longer response lengths benefit more due to the higher portion of computation spent in attention.
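The reported capacity and concurrency numbers are internally consistent with the claimed 2x gain from the FP8 KV cache (halving the bytes per cached token roughly doubles both), as a quick check confirms:

```python
# GPU KV-cache capacities (tokens) and max concurrency from the run above.
capacity = {"bf16": 249_952, "fp8_linear": 299_344, "fp8_linear_kv": 598_672}
concurrency = {"bf16": 11.10, "fp8_linear": 13.29, "fp8_linear_kv": 26.57}

cap_ratio = capacity["fp8_linear_kv"] / capacity["fp8_linear"]
conc_ratio = concurrency["fp8_linear_kv"] / concurrency["fp8_linear"]
print(f"capacity x{cap_ratio:.2f}, concurrency x{conc_ratio:.2f}")  # both ~2.00
```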

zpqiu added 2 commits December 1, 2025 23:47
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Signed-off-by: alexchiu <alexq@nvidia.com>

Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>


Labels

CI:L0 Run doctests and unit tests · Low Precision

Projects

None yet

Development

Successfully merging this pull request may close these issues.

KV cache quantization support in fp8 rollout in GRPO

4 participants