feat: KV cache quantization support in fp8 rollout in GRPO #1212
terrykong merged 44 commits into NVIDIA-NeMo:main from
Conversation
📝 Walkthrough

Adds FP8 KV-cache scale handling across GRPO refit and vLLM update paths: computes/calibrates Q/K/V FP8 scales, caches and threads them through IPC/NCCL weight updates, and applies them post-load in vLLM. Introduces config for kv_cache_dtype, backend compatibility checks, and new calibration methods in policy interfaces and workers.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor Trainer as GRPO Trainer
    participant Policy as Policy
    participant Worker as MegatronPolicyWorker
    participant Gen as vLLMGeneration
    participant Wkr as vLLMWorker
    participant BE as vLLMBackend
    participant Model as Model
    rect rgba(230,240,255,0.5)
    note over Trainer,Worker: Optional FP8 Q/K/V scale calibration
    Trainer->>Policy: calibrate_qkv_fp8_scales(data, opts)
    Policy->>Worker: calibrate_qkv_fp8_scales(shard, opts)
    Worker-->>Policy: kv_scales (per-layer)
    Policy-->>Trainer: kv_scales
    end
    rect rgba(240,255,240,0.5)
    note over Trainer,Model: Refit with KV scales
    Trainer->>Gen: update_weights_from_ipc_handles/collective(kv_scales)
    alt IPC
    Gen->>Wkr: update_weights_from_ipc_handles(..., kv_scales)
    Wkr->>BE: update_weights_from_local_ipc_handles(..., kv_scales)
    else Collective
    Gen->>Wkr: update_weights_from_collective(kv_scales)
    Wkr->>BE: update_weights_from_collective(kv_scales)
    end
    BE->>Model: load weights (+kv scale tensors)
    BE->>Model: process_weights_after_loading(apply kv scales)
    Model-->>BE: ready
    BE-->>Wkr: ok
    Wkr-->>Gen: ok
    Gen-->>Trainer: ok
    end
    note over Trainer: Continue rollout/generation
```
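The calibration step in the diagram above boils down to simple per-tensor scale arithmetic: pick a scale so observed Q/K/V activations fit the FP8 dynamic range. A minimal sketch, assuming an E4M3 KV-cache format (max finite value 448) and a configurable margin; the constants here are illustrative assumptions, not values read from the NeMo-RL code:

```python
# Illustrative per-tensor FP8 scale math for KV-cache calibration.
FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3 value (assumed format)

def compute_scale(amax: float, margin: float = 1.0) -> float:
    """Scale such that value / scale stays inside [-448, 448]."""
    return (amax * margin) / FP8_E4M3_MAX

def fake_quantize(x: float, scale: float) -> float:
    """Round-trip x through the scaled FP8 range (saturation only, no rounding)."""
    q = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x / scale))
    return q * scale

scale = compute_scale(amax=12.5)
print(round(fake_quantize(12.5, scale), 6))   # a value at amax survives the round trip
print(round(fake_quantize(100.0, scale), 6))  # values beyond amax saturate near 12.5
```

Calibration then reduces to tracking amax (or a percentile of per-batch amax values) for Q, K, and V during forward passes and emitting one such scale per layer.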
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
nemo_rl/models/policy/megatron_policy_worker.py (1)
1851-1937: Optimizer state offloaded to CPU after saving but never restored — can break continued training

The block moves optimizer state tensors to CPU unconditionally after save. If training continues in the same process, subsequent steps will pay a large PCIe penalty or fail if code assumes CUDA tensors.
Apply a guard to only offload when not training (or behind an env flag), so regular checkpoints during training don’t degrade runtime. Example diff:
```diff
-        torch.randn(1).cuda()  # wake up torch allocator
-        if hasattr(self, "optimizer") and self.optimizer is not None:
+        torch.randn(1).cuda()  # wake up torch allocator
+        # Only offload optimizer state after save when not training, or if explicitly requested.
+        should_offload_opt = (
+            os.getenv("NRL_OFFLOAD_OPT_STATE_AFTER_SAVE", "0") == "1"
+        ) or (not is_training)
+        if should_offload_opt and hasattr(self, "optimizer") and self.optimizer is not None:
             # Iterate through the state dictionaries for each parameter group
             if isinstance(self.optimizer, ChainedOptimizer):
                 optimizer_state = self.optimizer.state
             else:
                 optimizer_state = self.optimizer._get_state()
             for _, state in optimizer_state.items():
                 # Iterate through the state items (e.g., momentum, variance) for a parameter
                 for k, v in state.items():
                     # Check if the item is a tensor and on the GPU
                     if torch.is_tensor(v) and v.is_cuda:
                         # Move the tensor to CPU and update the state dictionary
                         state[k] = v.to("cpu")
```

Optionally, if you do offload during training (env flag), add a follow-up reload-to-CUDA path after memory cleanup.
Consider replacing print() memory logs with logging or warnings to align with codebase practices. As per coding guidelines
nemo_rl/models/generation/vllm/vllm_generation.py (1)
798-825: Add `kv_scales` parameter to async collective update methods

`update_weights_from_collective_async` in `vllm_worker_async.py` and the `update_weights_from_collective` signature in `interfaces.py` both lack the `kv_scales: Optional[dict[str, float]] = None` parameter required by the orchestrator’s call. Add this optional argument to both to match the sync version.

nemo_rl/models/generation/vllm/vllm_worker.py (2)
768-775: Check all worker results, not just the first

Only inspecting `result_or_coro[0]` can silently ignore failures on other workers. Aggregate across all results.

Apply this diff:

```diff
-        worker_result = result_or_coro[0]
-
-        if not worker_result:
+        results = result_or_coro if isinstance(result_or_coro, (list, tuple)) else [result_or_coro]
+        if not all(bool(r) for r in results):
             print(
-                f"Error: Worker failed to update weights. Result: {worker_result}"
+                f"Error: Worker failed to update weights. Results: {results}"
             )
             return False
```
796-811: Collective update should verify all worker results

Same issue: only the first result is checked. Ensure all workers succeeded.

Apply this diff:

```diff
-        worker_result = result_or_coro[0]
-
-        if not worker_result:
+        results = result_or_coro if isinstance(result_or_coro, (list, tuple)) else [result_or_coro]
+        if not all(bool(r) for r in results):
             print(
-                f"Error: Worker failed to update weights. Result: {worker_result}"
+                f"Error: Worker failed to update weights. Results: {results}"
             )
             return False
```
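Both comments above suggest the same pattern. As a standalone sketch (the list-of-booleans result shape is an assumption mirroring the review's description of `result_or_coro`):

```python
# Sketch of the suggested aggregation: the update succeeds only if every
# worker reported success. Accepts either a per-worker list or a single value.
def all_workers_ok(result_or_coro) -> bool:
    results = (
        result_or_coro
        if isinstance(result_or_coro, (list, tuple))
        else [result_or_coro]
    )
    return all(bool(r) for r in results)

print(all_workers_ok([True, True]))   # True
print(all_workers_ok([True, False]))  # False: a non-first worker failed
print(all_workers_ok(True))           # True: single-result case
```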
🧹 Nitpick comments (12)
nemo_rl/models/policy/lm_policy.py (1)
579-643: DP sharding + worker dispatch for calibration is correct; add a small guard before indexing

Flow mirrors `get_logprobs` and properly replicates CP/TP/PP. Add a quick safety check before returning `results[0]`.

```diff
         results = self.worker_group.get_all_worker_results(futures)
-        return results[0]
+        assert len(results) > 0, "No calibration results returned from workers"
+        return results[0]
```

Confirm the DTensor policy worker won’t be invoked for this method (or raises NotImplementedError) since the current implementation targets Megatron only.
nemo_rl/models/policy/megatron_policy_worker.py (1)
1970-2173: Calibrate Q/K/V FP8 scales: good structure; tighten error handling and avoid zero scales

Overall approach (hooks + percentile + DP/TP/PP reduction) is sound. Improve robustness:
- Don’t swallow generic Exception; log traceback or narrow to expected exceptions.
- Avoid producing zero scales; clamp with epsilon.
- _percentile() needn’t allocate on CUDA; keep on CPU to reduce GPU noise.
```diff
-    def _percentile(values: list[float], p: float) -> float:
+    def _percentile(values: list[float], p: float) -> float:
         if not values:
             return 0.0
-        t = torch.tensor(sorted(values), device="cuda", dtype=torch.float32)
-        rank = max(0, min(len(values) - 1, int(round((p / 100.0) * (len(values) - 1)))))
-        return float(t[rank].item())
+        t = torch.tensor(sorted(values), dtype=torch.float32)  # CPU is fine
+        # Index by nearest rank; avoid int() on an int
+        rank = max(0, min(len(values) - 1, round((p / 100.0) * (len(values) - 1))))
+        return float(t[rank].item())
```

```diff
-    if include_q:
-        q_scale = (vals.get("q_amax_p", 0.0) * margin) / FP8_MAX_Q
-        out_entry["q_scale"] = float(q_scale)
-    k_scale = (vals.get("k_amax_p", 0.0) * margin) / FP8_MAX_K
-    v_scale = (vals.get("v_amax_p", 0.0) * margin) / FP8_MAX_V
+    eps = 1e-8
+    if include_q:
+        q_scale = max(eps, (vals.get("q_amax_p", 0.0) * margin) / FP8_MAX_Q)
+        out_entry["q_scale"] = float(q_scale)
+    k_scale = max(eps, (vals.get("k_amax_p", 0.0) * margin) / FP8_MAX_K)
+    v_scale = max(eps, (vals.get("v_amax_p", 0.0) * margin) / FP8_MAX_V)
```

Also rename the unused pre-hook module parameter to silence linters:
```diff
-    def _pre_hook(module, inputs):
+    def _pre_hook(_module, inputs):
```

The vLLM-side loaders expect `kv_scales` as a flat `dict[str, float]` of parameter names to scales (see nemo_rl/models/generation/vllm/vllm_backend.py). This method returns a nested structure `{format, percentile, margin, layers: {...}}` with layer keys like `"layer_N"`. Please confirm the translation from layer keys to vLLM parameter names (e.g., `model.layers.N.<...>._k_scale`/`_v_scale`) is performed upstream (e.g., in the GRPO refit flow) before passing to vLLM. If missing, I can propose a mapper.
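If such a mapper is missing, it could look like the following sketch. The parameter-name template (`model.layers.{i}.self_attn.attn._k_scale` etc.) is a hypothetical example for illustration only; the real names depend on the target model and vLLM version.

```python
# Hypothetical mapper from the nested calibration result
# ({"layers": {"layer_N": {...}}}) to the flat dict[str, float] the vLLM-side
# loaders expect. The name template below is an assumed example, not vLLM API.
def flatten_kv_scales(
    calib: dict,
    name_template: str = "model.layers.{i}.self_attn.attn.{suffix}",
) -> dict[str, float]:
    flat: dict[str, float] = {}
    for layer_key, entry in calib["layers"].items():
        i = int(layer_key.split("_")[1])  # "layer_3" -> 3
        for field, suffix in (("k_scale", "_k_scale"), ("v_scale", "_v_scale")):
            if field in entry:
                flat[name_template.format(i=i, suffix=suffix)] = float(entry[field])
    return flat

calib = {"format": "fp8", "layers": {"layer_0": {"k_scale": 0.02, "v_scale": 0.03}}}
print(flatten_kv_scales(calib))
# {'model.layers.0.self_attn.attn._k_scale': 0.02, 'model.layers.0.self_attn.attn._v_scale': 0.03}
```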
nemo_rl/models/generation/fp8.py (2)
278-279: Remove unused local variable `kv_cache_dtype`

It’s set then unused; drop to silence lint.
```diff
-        kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto")
```
303-309: Gate noisy debug print or switch to logger

The global config print will spam logs. Gate under an env flag or use logger at debug level.
```diff
-    # TODO: Remove this after debugging.
-    print(f"[KV_SCALES] Global FP8 config: {global_fp8_config}")
+    if os.getenv("NRL_FP8_DEBUG", "0") == "1":
+        print(f"[KV_SCALES] Global FP8 config: {global_fp8_config}")
```

nemo_rl/models/generation/vllm/vllm_backend.py (3)
183-185: Remove extraneous f-string prefix

No placeholders; triggers Ruff F541.
Apply this diff:
```diff
-        print(f"[KV_SCALES] Loading weights into FP8 model")
+        print("[KV_SCALES] Loading weights into FP8 model")
```

The Ruff hint (F541) flagged this; please run lint to confirm no other instances remain.
169-177: Reduce noisy debug prints or gate behind a flag

Repeated printing (including `.item()` on GPU tensors) induces syncs and slows the hot path. Gate behind an env var or remove before merge.

Apply this diff to gate under an env toggle:
```diff
-        # Debug print to check if the KV scales are in the weights
-        # TODO: Remove this after debugging.
-        kv_scale_params = [name for name, _ in weights if any(scale_name in name for scale_name in ['q_scale', 'k_scale', 'v_scale'])]
-        print(f"[KV_SCALES] KV scale parameters found in weights: {kv_scale_params}")
-        if kv_scale_params:
-            for name, tensor in weights:
-                if any(scale_name in name for scale_name in ['q_scale', 'k_scale', 'v_scale']):
-                    print(f"[KV_SCALES] Parameter {name}: shape={tensor.shape}, dtype={tensor.dtype}, value={tensor.item() if tensor.numel() == 1 else 'multi-element'}")
-        else:
-            print("[KV_SCALES] No KV scale parameters found in weights")
+        if os.environ.get("NRL_DEBUG_KV_SCALES") == "1":
+            kv_scale_params = [
+                name
+                for name, _ in weights
+                if any(scale_name in name for scale_name in ["q_scale", "k_scale", "v_scale"])
+            ]
+            print(f"[KV_SCALES] KV scale parameters found in weights: {kv_scale_params}")
```
231-237: Align device placement for KV scales in collective path

For consistency with other weights allocated on `device="cuda"`, consider using the worker’s device as well.

Apply this diff:
```diff
-        for param_name, scale_value in kv_scales.items():
-            # Convert scale to tensor
-            scale_tensor = torch.tensor(scale_value, dtype=torch.float32, device="cuda")
+        for param_name, scale_value in kv_scales.items():
+            # Convert scale to tensor
+            scale_tensor = torch.tensor(
+                scale_value, dtype=torch.float32, device=self.device
+            )
             weights.append((param_name, scale_tensor))
```

nemo_rl/algorithms/grpo.py (5)
394-405: Avoid using assert for runtime checks

Asserts can be stripped with -O, hiding critical compatibility guards. Raise explicit exceptions instead.
Apply this diff:
```diff
-    if kv_cache_dtype == "fp8":
-        policy_backend = "megatron" if policy_config.get("megatron_cfg", {}).get("enabled", False) else "dtensor"
-
-        # Validate KV cache FP8 compatibility
-        assert policy_backend != "dtensor", "DTensor backend is not supported with kv cache fp8 enabled."
-        assert not _should_use_async_rollouts(master_config), "Async rollouts is not supported with kv cache fp8 enabled."
-        assert policy_config.get("megatron_cfg", {}).get("pipeline_model_parallel_size", 1) == 1, "Pipeline model parallel size must be 1 for megatron backend with kv cache fp8 enabled."
+    if kv_cache_dtype == "fp8":
+        policy_backend = "megatron" if policy_config.get("megatron_cfg", {}).get("enabled", False) else "dtensor"
+        if policy_backend == "dtensor":
+            raise RuntimeError("DTensor backend is not supported with kv cache fp8 enabled.")
+        if _should_use_async_rollouts(master_config):
+            raise RuntimeError("Async rollouts is not supported with kv cache fp8 enabled.")
+        if policy_config.get("megatron_cfg", {}).get("pipeline_model_parallel_size", 1) != 1:
+            raise RuntimeError("Pipeline model parallel size must be 1 for megatron backend with kv cache fp8 enabled.")
```
496-504: Do not introduce implicit config defaults in code (nemo_rl/ rule)

Accessing `vllm_cfg` keys with `.get(..., "auto")` sets hidden defaults. Per guidelines, YAML is the single source of truth; assume presence or model optionality via TypedDict.

Apply this diff:
```diff
-    vllm_cfg = generation_config.get("vllm_cfg", {})
-    kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto")
-    vllm_precision = vllm_cfg.get("precision", "auto")
+    vllm_cfg = generation_config["vllm_cfg"]
+    kv_cache_dtype = vllm_cfg["kv_cache_dtype"]
+    vllm_precision = vllm_cfg["precision"]
```

As per coding guidelines
647-649: Remove extraneous f-string prefix

This print has no placeholders; triggers Ruff F541.
Apply this diff:
```diff
-    print(f"[KV_SCALES] FP8 KV cache detected, will sync q_scale, _k_scale and _v_scale during refit")
+    print("[KV_SCALES] FP8 KV cache detected, will sync q_scale, _k_scale and _v_scale during refit")
```

Rerun lint to ensure no other F541 remain.
909-924: Ensure `kv_scales_cache` is initialized before updates

If KV scales are recomputed before the initial cache creation (edge flows), assignments to `kv_scales_cache[...]` would fail. Initialize defensively.

Apply this diff:
```diff
-    if sync_kv_scales:
+    if sync_kv_scales:
+        if kv_scales_cache is None:
+            kv_scales_cache = {}
         with timer.time("recompute_kv_scales"):
```
642-650: KV-scale sync logging is fine; consider gating debug verbosity

Given frequency, consider using the existing Logger or an env flag to avoid noisy stdout in production runs.
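Several of the nitpicks above (the global-config print, the weight-loading prints, the refit-sync print) suggest the same fix; a shared helper could centralize it. The `NRL_FP8_DEBUG` variable name is an assumption carried over from the suggested diffs, not an existing NeMo-RL setting:

```python
import os

# Sketch of a shared env-gated debug logger for the [KV_SCALES] prints.
# Returns whether the message was emitted so callers/tests can check it.
def debug_log(msg: str) -> bool:
    enabled = os.getenv("NRL_FP8_DEBUG", "0") == "1"
    if enabled:
        print(msg)
    return enabled

os.environ["NRL_FP8_DEBUG"] = "0"
print(debug_log("[KV_SCALES] suppressed"))  # False: nothing was printed
os.environ["NRL_FP8_DEBUG"] = "1"
print(debug_log("[KV_SCALES] visible"))     # True: the message was printed first
```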
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- nemo_rl/algorithms/grpo.py (9 hunks)
- nemo_rl/models/generation/fp8.py (7 hunks)
- nemo_rl/models/generation/vllm/vllm_backend.py (3 hunks)
- nemo_rl/models/generation/vllm/vllm_generation.py (3 hunks)
- nemo_rl/models/generation/vllm/vllm_worker.py (5 hunks)
- nemo_rl/models/policy/interfaces.py (1 hunks)
- nemo_rl/models/policy/lm_policy.py (1 hunks)
- nemo_rl/models/policy/megatron_policy_worker.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
- nemo_rl/models/policy/lm_policy.py
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/algorithms/grpo.py
- nemo_rl/models/policy/interfaces.py
- nemo_rl/models/policy/megatron_policy_worker.py
- nemo_rl/models/generation/vllm/vllm_backend.py
- nemo_rl/models/generation/fp8.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
- nemo_rl/models/policy/lm_policy.py
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/algorithms/grpo.py
- nemo_rl/models/policy/interfaces.py
- nemo_rl/models/policy/megatron_policy_worker.py
- nemo_rl/models/generation/vllm/vllm_backend.py
- nemo_rl/models/generation/fp8.py
🧬 Code graph analysis (7)
nemo_rl/models/policy/lm_policy.py (6)
- nemo_rl/models/policy/megatron_policy_worker.py (1): calibrate_qkv_fp8_scales (1971-2173)
- nemo_rl/models/policy/interfaces.py (1): calibrate_qkv_fp8_scales (104-126)
- nemo_rl/distributed/batched_data_dict.py (2): BatchedDataDict (75-839), shard_by_batch_size (246-644)
- nemo_rl/models/generation/interfaces.py (1): GenerationDatumSpec (127-158)
- nemo_rl/distributed/worker_groups.py (2): run_all_workers_sharded_data (774-909), get_all_worker_results (911-928)
- nemo_rl/distributed/named_sharding.py (1): get_axis_size (209-211)

nemo_rl/models/generation/vllm/vllm_worker.py (2)
- nemo_rl/models/generation/vllm/vllm_generation.py (2): update_weights_from_ipc_handles (745-796), update_weights_from_collective (798-824)
- nemo_rl/models/generation/vllm/vllm_backend.py (1): update_weights_from_collective (216-269)

nemo_rl/models/generation/vllm/vllm_generation.py (3)
- nemo_rl/models/generation/vllm/vllm_worker.py (2): update_weights_from_ipc_handles (702-781), update_weights_from_collective (784-817)
- nemo_rl/distributed/worker_groups.py (2): run_all_workers_multiple_data (631-726), run_all_workers_single_data (728-772)
- nemo_rl/models/generation/vllm/vllm_backend.py (1): update_weights_from_collective (216-269)

nemo_rl/algorithms/grpo.py (9)
- nemo_rl/algorithms/sft.py (1): MasterConfig (75-81)
- nemo_rl/models/policy/megatron_policy_worker.py (4): get_weights_ipc_handles (1594-1684), broadcast_weights_for_collective (1687-1696), prepare_for_lp_inference (1698-1701), calibrate_qkv_fp8_scales (1971-2173)
- nemo_rl/models/policy/lm_policy.py (4): get_weights_ipc_handles (687-706), broadcast_weights_for_collective (708-714), prepare_for_lp_inference (554-558), calibrate_qkv_fp8_scales (579-642)
- nemo_rl/models/policy/dtensor_policy_worker_v2.py (3): get_weights_ipc_handles (1409-1445), broadcast_weights_for_collective (1448-1469), prepare_for_lp_inference (1472-1479)
- nemo_rl/models/generation/vllm/vllm_generation.py (2): update_weights_from_ipc_handles (745-796), update_weights_from_collective (798-824)
- nemo_rl/models/generation/vllm/vllm_worker.py (2): update_weights_from_ipc_handles (702-781), update_weights_from_collective (784-817)
- nemo_rl/models/generation/vllm/vllm_backend.py (1): update_weights_from_collective (216-269)
- nemo_rl/distributed/batched_data_dict.py (3): BatchedDataDict (75-839), get_multimodal_dict (88-99), to (804-811)
- nemo_rl/algorithms/loss_functions.py (1): ClippedPGLossDataDict (47-57)

nemo_rl/models/policy/interfaces.py (4)
- nemo_rl/models/policy/megatron_policy_worker.py (1): calibrate_qkv_fp8_scales (1971-2173)
- nemo_rl/models/policy/lm_policy.py (1): calibrate_qkv_fp8_scales (579-642)
- nemo_rl/distributed/batched_data_dict.py (1): BatchedDataDict (75-839)
- nemo_rl/models/generation/interfaces.py (1): GenerationDatumSpec (127-158)

nemo_rl/models/policy/megatron_policy_worker.py (3)
- nemo_rl/distributed/batched_data_dict.py (2): to (804-811), BatchedDataDict (75-839)
- nemo_rl/models/policy/lm_policy.py (1): calibrate_qkv_fp8_scales (579-642)
- nemo_rl/models/policy/interfaces.py (1): calibrate_qkv_fp8_scales (104-126)

nemo_rl/models/generation/vllm/vllm_backend.py (3)
- nemo_rl/models/generation/fp8.py (3): is_fp8_model (314-325), load_weights (401-428), process_weights_after_loading (504-550)
- nemo_rl/models/generation/vllm/vllm_generation.py (1): update_weights_from_collective (798-824)
- nemo_rl/models/generation/vllm/vllm_worker.py (1): update_weights_from_collective (784-817)
🪛 Ruff (0.13.1)
nemo_rl/models/generation/vllm/vllm_generation.py
794-794: Do not catch blind exception: Exception
(BLE001)
nemo_rl/algorithms/grpo.py
489-489: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
647-647: f-string without any placeholders
Remove extraneous f prefix
(F541)
nemo_rl/models/policy/megatron_policy_worker.py
2007-2007: Do not catch blind exception: Exception
(BLE001)
2033-2033: Unused function argument: module
(ARG001)
2046-2046: Do not catch blind exception: Exception
(BLE001)
2060-2060: Do not catch blind exception: Exception
(BLE001)
2080-2081: try-except-pass detected, consider logging the exception
(S110)
2080-2080: Do not catch blind exception: Exception
(BLE001)
2095-2095: Do not catch blind exception: Exception
(BLE001)
2104-2104: Value being cast to int is already an integer
Remove unnecessary int call
(RUF046)
2170-2171: try-except-pass detected, consider logging the exception
(S110)
2170-2170: Do not catch blind exception: Exception
(BLE001)
nemo_rl/models/generation/vllm/vllm_backend.py
184-184: f-string without any placeholders
Remove extraneous f prefix
(F541)
nemo_rl/models/generation/fp8.py
108-108: Unused function argument: self
(ARG001)
150-151: Avoid specifying long messages outside the exception class
(TRY003)
185-186: Do not assign a lambda expression, use a def
Rewrite is_singleton_float as a def
(E731)
185-186: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear
Parenthesize the and subexpression
(RUF021)
189-190: Avoid specifying long messages outside the exception class
(TRY003)
205-205: f-string without any placeholders
Remove extraneous f prefix
(F541)
278-278: Local variable kv_cache_dtype is assigned to but never used
Remove assignment to unused variable kv_cache_dtype
(F841)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Lint check
- GitHub Check: Post automodel integration comment / Comment on PR
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (9)
nemo_rl/models/policy/interfaces.py (1)
103-127: New abstract KV-scale calibration API looks good

Signature, typing, and docstring are consistent with the rest of the interface.
Please confirm all concrete Policy implementations (Megatron, DTensor, etc.) implement this method or are gated to avoid calls when unsupported. If DTensor is not yet supported, consider raising NotImplementedError in that backend to fail fast.
nemo_rl/models/policy/megatron_policy_worker.py (1)
19-21: Minor import additions

`json`/`re` are fine and used later.
nemo_rl/models/generation/vllm/vllm_generation.py (2)
745-756: IPC update path now accepts kv_scales — interface and propagation look correct

The optional `kv_scales` parameter is threaded to workers without breaking the default path.
Confirm update_weights_from_ipc_handles_async in the async worker also accepts kv_scales to avoid mismatched signatures across engine modes.
776-791: Correct replication of kv_scales across tied groups

Replicating `kv_scales` per worker’s device UUID set is appropriate.
nemo_rl/models/generation/fp8.py (3)
43-44: Config extension for kv_cache_dtype is fine

Field addition and propagation via `init_fp8` are consistent.
233-237: Patching BaseKVCacheMethod hook is appropriate

Keeps KV-scale params for dynamic updates; aligns with the goal of feeding calibrated scales later.
251-261: kv_cache_dtype propagated into global FP8Config

Good to thread through vLLM kwargs downstream.
nemo_rl/models/generation/vllm/vllm_backend.py (1)
89-96: API extension looks good

Accepting optional `kv_scales` here aligns with upstream worker calls and the generation path.

nemo_rl/algorithms/grpo.py (1)
575-584: Passing kv_scales only with the first IPC batch — confirm loader semantics

vLLM’s loader must be able to receive “extra” kv_scale tensors separate from the rest of the weights, and only once. If subsequent weight groups re-touch affected layers, scales must remain intact. Please confirm this is guaranteed for the targeted models and vLLM version.
Would you like me to add a guard to call `process_weights_after_loading` only once per refit and to assert presence of KV-scale params post-load?
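The guard offered above could be sketched as a small state holder. The class and method names here are hypothetical; a real implementation would live in the vLLM backend wrapper:

```python
# Illustrative once-per-refit guard: process_weights_after_loading-style work
# runs at most once between begin_refit() calls. All names are hypothetical.
class RefitGuard:
    def __init__(self) -> None:
        self._done = False

    def begin_refit(self) -> None:
        self._done = False

    def maybe_process(self, process_fn) -> bool:
        """Run process_fn only once per refit; return True if it ran."""
        if self._done:
            return False
        process_fn()
        self._done = True
        return True

guard = RefitGuard()
calls = []
guard.begin_refit()
print(guard.maybe_process(lambda: calls.append("apply kv scales")))  # True
print(guard.maybe_process(lambda: calls.append("apply kv scales")))  # False
print(len(calls))  # 1: scales applied exactly once this refit
```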
@zpqiu sorry for the long delay, I have put some comments; could you first merge in main and then address them? I think this solution can be further optimized in terms of performance, but I'm ok with merging this version first.
Sure. I will rebase the code and resolve these comments first.
ℹ️ File Consistency Check

Check based on commit: 47ea0c0 (PR #1212)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR. Please ensure that the changes are consistent between both files where applicable. This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
I'm ok with adding the test just to the nightly test suite, as you do now.
@terrykong please review
@terrykong this is the last FP8 functionality we want to merge before v0.5; after this I want to refactor the code to make it cleaner and more structured. Please review when you have time.
Synced offline w/ @guyueh1, who is following up with another PR to do some cleanup, so I'll defer my comments to her PR.


What does this PR do?
Hi @guyueh1, as discussed, I'm creating this draft PR for you to review the initial design so that we can discuss the refactor and next steps.
Current experiment results for Qwen3-8B (orange: bf16; green: default FP8 rollout; blue: default FP8 rollout + FP8 KV cache):

[Training curves comparing the three runs]
Note that there is room for calibration optimization to reduce the total step time.
@zpqiu fyi.
Usage

# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit